From feb7b9b6ba85a3cae2e83cf8bcff6894be6fb9c3 Mon Sep 17 00:00:00 2001
From: Andrey Tatarinov
Date: Sun, 25 Aug 2024 16:38:17 +0400
Subject: [PATCH] introduce `JoinType` wrapper for `input_dts`

---
 datapipe/compute.py                      |  24 +--
 datapipe/datatable.py                    |  13 +-
 datapipe/step/batch_transform.py         |  32 ++--
 datapipe/step/datatable_transform.py     |  11 +-
 docs/source/SUMMARY.md                   |   2 +
 tests/test_batch_transform_scheduling.py |  11 +-
 tests/test_complex_pipeline.py           | 192 ++++++++++++-----
 tests/test_core_steps1.py                |  30 ++--
 tests/test_core_steps2.py                |  21 ++-
 tests/test_image_pipeline.py             |   3 +-
 10 files changed, 196 insertions(+), 143 deletions(-)

diff --git a/datapipe/compute.py b/datapipe/compute.py
index b3196c0b..1ed52526 100644
--- a/datapipe/compute.py
+++ b/datapipe/compute.py
@@ -2,7 +2,7 @@
 import logging
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import Dict, Iterable, List, Optional, Tuple
+from typing import Dict, Iterable, List, Literal, Optional, Tuple
 
 from opentelemetry import trace
 
@@ -87,6 +87,12 @@ class StepStatus:
     changed_idx_count: int
 
 
+@dataclass
+class JoinType:
+    dt: DataTable
+    join_type: Literal["inner", "full"] = "full"
+
+
 class ComputeStep:
     """
     A computation step in the computation graph.
@@ -106,7 +112,7 @@ class ComputeStep:
     def __init__(
         self,
         name: str,
-        input_dts: List[DataTable],
+        input_dts: List[JoinType],
         output_dts: List[DataTable],
         labels: Optional[Labels] = None,
         executor_config: Optional[ExecutorConfig] = None,
@@ -121,7 +127,7 @@ def get_name(self) -> str:
         ss = [
             self.__class__.__name__,
             self._name,
-            *[i.name for i in self.input_dts],
+            *[i.dt.name for i in self.input_dts],
             *[o.name for o in self.output_dts],
         ]
 
@@ -143,7 +149,7 @@ def get_status(self, ds: DataStore) -> StepStatus:
 
     # TODO: move to lints
     def validate(self) -> None:
-        inp_p_keys_arr = [set(inp.primary_keys) for inp in self.input_dts if inp]
+        inp_p_keys_arr = [set(inp.dt.primary_keys) for inp in self.input_dts if inp]
         out_p_keys_arr = [set(out.primary_keys) for out in self.output_dts if out]
 
         inp_p_keys = set.intersection(*inp_p_keys_arr) if len(inp_p_keys_arr) else set()
@@ -153,7 +159,7 @@ def validate(self) -> None:
         key_to_column_type_inp = {
             column.name: type(column.type)
             for inp in self.input_dts
-            for column in inp.primary_schema
+            for column in inp.dt.primary_schema
             if column.name in join_keys
         }
         key_to_column_type_out = {
@@ -271,11 +277,11 @@ def run_steps(
 ) -> None:
     for step in steps:
         with tracer.start_as_current_span(
-            f"{step.get_name()} {[i.name for i in step.input_dts]} -> {[i.name for i in step.output_dts]}"
+            f"{step.get_name()} {[i.dt.name for i in step.input_dts]} -> {[i.name for i in step.output_dts]}"
         ):
             logger.info(
                 f"Running {step.get_name()} "
-                f"{[i.name for i in step.input_dts]} -> {[i.name for i in step.output_dts]}"
+                f"{[i.dt.name for i in step.input_dts]} -> {[i.name for i in step.output_dts]}"
             )
 
             step.run_full(ds=ds, run_config=run_config, executor=executor)
@@ -323,11 +329,11 @@ def run_steps_changelist(
     for step in steps:
         with tracer.start_as_current_span(
             f"{step.get_name()} "
-            f"{[i.name for i in step.input_dts]} -> {[i.name for i in step.output_dts]}"
+            f"{[i.dt.name for i in step.input_dts]} -> {[i.name for i in step.output_dts]}"
         ):
             logger.info(
                 f"Running {step.get_name()} "
-                f"{[i.name for i in step.input_dts]} -> {[i.name for i in step.output_dts]}"
+                f"{[i.dt.name for i in step.input_dts]} -> {[i.name for i in step.output_dts]}"
             )
 
             if isinstance(step, BaseBatchTransformStep):
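Note on the new wrapper: `JoinType` pairs each input `DataTable` with the join
semantics the consuming step should apply when computing changed indices. This
patch only threads the wrapper through the call sites; nothing branches on
`join_type` yet (the `TODO: How merging works` entry added to the docs SUMMARY
below is the placeholder for those semantics). A minimal sketch of wrapping
inputs by hand, assuming `images_dt` and `labels_dt` are existing `DataTable`
objects (hypothetical names):

    from datapipe.compute import JoinType

    inputs = [
        JoinType(dt=images_dt),  # join_type defaults to "full"
        JoinType(dt=labels_dt, join_type="inner"),  # presumably: only matching ids
    ]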
diff --git a/datapipe/datatable.py b/datapipe/datatable.py
index d7a19c7a..db2febbd 100644
--- a/datapipe/datatable.py
+++ b/datapipe/datatable.py
@@ -1,9 +1,18 @@
-import copy
 import logging
 import math
 import time
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, cast
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    Dict,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Tuple,
+    cast,
+)
 
 import pandas as pd
 from cityhash import CityHash32
diff --git a/datapipe/step/batch_transform.py b/datapipe/step/batch_transform.py
index 27c6b610..a4ddce4c 100644
--- a/datapipe/step/batch_transform.py
+++ b/datapipe/step/batch_transform.py
@@ -42,7 +42,7 @@
 from sqlalchemy.sql.expression import select
 from tqdm_loggable.auto import tqdm
 
-from datapipe.compute import Catalog, ComputeStep, PipelineStep, StepStatus
+from datapipe.compute import Catalog, ComputeStep, JoinType, PipelineStep, StepStatus
 from datapipe.datatable import DataStore, DataTable, MetaTable
 from datapipe.executor import Executor, ExecutorConfig, SingleThreadExecutor
 from datapipe.run_config import LabelDict, RunConfig
@@ -299,7 +299,7 @@ def __init__(
         self,
         ds: DataStore,
         name: str,
-        input_dts: List[DataTable],
+        input_dts: List[JoinType],
         output_dts: List[DataTable],
         transform_keys: Optional[List[str]] = None,
         chunk_size: int = 1000,
@@ -325,8 +325,8 @@ def __init__(
             transform_keys = list(transform_keys)
 
         self.transform_keys, self.transform_schema = self.compute_transform_schema(
-            [i.meta_table for i in input_dts],
-            [i.meta_table for i in output_dts],
+            [inp.dt.meta_table for inp in input_dts],
+            [out.meta_table for out in output_dts],
             transform_keys,
         )
@@ -388,7 +388,7 @@ def _build_changed_idx_sql(
         run_config: Optional[RunConfig] = None,  # TODO remove
     ) -> Tuple[Iterable[str], Any]:
         all_input_keys_counts: Dict[str, int] = {}
-        for col in itertools.chain(*[dt.primary_schema for dt in self.input_dts]):
+        for col in itertools.chain(*[inp.dt.primary_schema for inp in self.input_dts]):
             all_input_keys_counts[col.name] = all_input_keys_counts.get(col.name, 0) + 1
 
         common_keys = [
@@ -450,12 +450,12 @@ def _make_agg_of_agg(ctes, agg_col):
             return sql.cte(name=f"all__{agg_col}")
 
         inp_ctes = [
-            tbl.get_agg_cte(
+            inp.dt.get_agg_cte(
                 transform_keys=self.transform_keys,
                 filters_idx=filters_idx,
                 run_config=run_config,
             )
-            for tbl in self.input_dts
+            for inp in self.input_dts
         ]
 
         inp = _make_agg_of_agg(inp_ctes, "update_ts")
@@ -639,8 +639,8 @@ def get_change_list_process_ids(
         changes = [pd.DataFrame(columns=self.transform_keys)]
 
         for inp in self.input_dts:
-            if inp.name in change_list.changes:
-                idx = change_list.changes[inp.name]
+            if inp.dt.name in change_list.changes:
+                idx = change_list.changes[inp.dt.name]
                 if any([key not in idx.columns for key in self.transform_keys]):
                     _, sql = self._build_changed_idx_sql(
                         ds=ds,
@@ -761,7 +761,7 @@ def get_batch_input_dfs(
         idx: IndexDF,
         run_config: Optional[RunConfig] = None,
     ) -> List[DataDF]:
-        return [inp.get_data(idx) for inp in self.input_dts]
+        return [inp.dt.get_data(idx) for inp in self.input_dts]
 
     def process_batch_dfs(
         self,
@@ -922,7 +922,7 @@ def build_compute(self, ds: DataStore, catalog: Catalog) -> List[ComputeStep]:
                 ds=ds,
                 name=f"{self.func.__name__}",
                 func=self.func,
-                input_dts=input_dts,
+                input_dts=[JoinType(dt=inp, join_type="full") for inp in input_dts],
                 output_dts=output_dts,
                 kwargs=self.kwargs,
                 transform_keys=self.transform_keys,
@@ -938,7 +938,7 @@ def __init__(
         self,
         ds: DataStore,
         name: str,
        func: DatatableBatchTransformFunc,
-        input_dts: List[DataTable],
+        input_dts: List[JoinType],
         output_dts: List[DataTable],
         kwargs: Optional[Dict] = None,
         transform_keys: Optional[List[str]] = None,
@@ -967,7 +967,7 @@ def process_batch_dts(
         return self.func(
             ds=ds,
             idx=idx,
-            input_dts=self.input_dts,
+            input_dts=[inp.dt for inp in self.input_dts],
             run_config=run_config,
             kwargs=self.kwargs,
         )
@@ -995,7 +995,9 @@ def build_compute(self, ds: DataStore, catalog: Catalog) -> List[ComputeStep]:
             BatchTransformStep(
                 ds=ds,
                 name=f"{self.func.__name__}",  # type: ignore # mypy bug: https://github.com/python/mypy/issues/10976
-                input_dts=input_dts,
+                input_dts=[
+                    JoinType(dt=inp, join_type="full") for inp in input_dts
+                ],
                 output_dts=output_dts,
                 func=self.func,
                 kwargs=self.kwargs,
@@ -1016,7 +1018,7 @@ def __init__(
         self,
         ds: DataStore,
         name: str,
         func: BatchTransformFunc,
-        input_dts: List[DataTable],
+        input_dts: List[JoinType],
         output_dts: List[DataTable],
         kwargs: Optional[Dict[str, Any]] = None,
         transform_keys: Optional[List[str]] = None,
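With this change every `BatchTransformStep` constructor call receives
`JoinType` wrappers instead of bare tables, so the migration for direct
callers is mechanical. A sketch following the same pattern as the test
updates later in this patch, assuming `ds`, `tbl1` and `tbl2` are already
created as in those tests:

    step = BatchTransformStep(
        ds=ds,
        name="example_step",  # hypothetical name
        func=lambda df: df,  # identity transform, as in the tests
        input_dts=[JoinType(dt=tbl1, join_type="full")],
        output_dts=[tbl2],
    )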
diff --git a/datapipe/step/datatable_transform.py b/datapipe/step/datatable_transform.py
index d40400a8..7b37467e 100644
--- a/datapipe/step/datatable_transform.py
+++ b/datapipe/step/datatable_transform.py
@@ -4,7 +4,7 @@
 
 from opentelemetry import trace
 
-from datapipe.compute import Catalog, ComputeStep, PipelineStep
+from datapipe.compute import Catalog, ComputeStep, JoinType, PipelineStep
 from datapipe.datatable import DataStore, DataTable
 from datapipe.executor import Executor
 from datapipe.run_config import RunConfig
@@ -32,7 +32,7 @@ class DatatableTransformStep(ComputeStep):
     def __init__(
         self,
         name: str,
-        input_dts: List[DataTable],
+        input_dts: List[JoinType],
         output_dts: List[DataTable],
         func: DatatableTransformFunc,
         kwargs: Optional[Dict] = None,
@@ -76,7 +76,7 @@ def run_full(
         try:
             self.func(
                 ds=ds,
-                input_dts=self.input_dts,
+                input_dts=[inp.dt for inp in self.input_dts],
                 output_dts=self.output_dts,
                 run_config=run_config,
                 kwargs=self.kwargs,
@@ -101,7 +101,10 @@ def build_compute(self, ds: DataStore, catalog: Catalog) -> List["ComputeStep"]:
         return [
             DatatableTransformStep(
                 name=self.func.__name__,
-                input_dts=[catalog.get_datatable(ds, i) for i in self.inputs],
+                input_dts=[
+                    JoinType(dt=catalog.get_datatable(ds, i), join_type="full")
+                    for i in self.inputs
+                ],
                 output_dts=[catalog.get_datatable(ds, i) for i in self.outputs],
                 func=self.func,
                 kwargs=self.kwargs,
diff --git a/docs/source/SUMMARY.md b/docs/source/SUMMARY.md
index 8ec0adb7..6f206870 100644
--- a/docs/source/SUMMARY.md
+++ b/docs/source/SUMMARY.md
@@ -3,6 +3,8 @@
 
 # Introduction
 
 - [Introduction](./introduction.md)
+- TODO: Concepts
+- TODO: How merging works
 
 # Command Line Interface
diff --git a/tests/test_batch_transform_scheduling.py b/tests/test_batch_transform_scheduling.py
index c66afe05..a4dfd85c 100644
--- a/tests/test_batch_transform_scheduling.py
+++ b/tests/test_batch_transform_scheduling.py
@@ -3,6 +3,7 @@
 import pandas as pd
 from sqlalchemy import Column, Integer
 
+from datapipe.compute import JoinType
 from datapipe.datatable import DataStore
 from datapipe.step.batch_transform import BatchTransformStep
 from datapipe.store.database import TableStoreDB
@@ -42,7 +43,9 @@ def id_func(df):
         ds=ds,
         name="step",
         func=id_func,
-        input_dts=[tbl1],
+        input_dts=[
+            JoinType(dt=tbl1, join_type="full"),
+        ],
         output_dts=[tbl2],
     )
@@ -148,7 +151,11 @@ def test_aux_input(dbconn) -> None:
         ds=ds,
         name="step",
         func=lambda items1, items2, aux: items1,
-        input_dts=[tbl_items1, tbl_items2, tbl_aux],
+        input_dts=[
+            JoinType(dt=tbl_items1, join_type="full"),
+            JoinType(dt=tbl_items2, join_type="full"),
+            JoinType(dt=tbl_aux, join_type="full"),
+        ],
         output_dts=[tbl_out],
         transform_keys=["id"],
     )
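The declarative `Pipeline` API is unchanged by the wrapper: as the hunks above
show, `build_compute` wraps every input in `JoinType(..., join_type="full")`
on the user's behalf, which is why pipeline-level code such as the tests that
follow keeps passing plain table names:

    pipeline = Pipeline(
        [
            BatchTransform(
                func=my_transform,  # hypothetical function
                inputs=["tbl_a"],  # hypothetical table names
                outputs=["tbl_b"],
            ),
        ]
    )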
join_type="full"), + JoinType(dt=tbl_items2, join_type="full"), + JoinType(dt=tbl_aux, join_type="full"), + ], output_dts=[tbl_out], transform_keys=["id"], ) diff --git a/tests/test_complex_pipeline.py b/tests/test_complex_pipeline.py index e5b45c1c..bbc80056 100644 --- a/tests/test_complex_pipeline.py +++ b/tests/test_complex_pipeline.py @@ -1,10 +1,11 @@ import pandas as pd +import pytest from sqlalchemy import Column from sqlalchemy.sql.sqltypes import Integer, String -from datapipe.step.batch_generate import BatchGenerate from datapipe.compute import Catalog, Pipeline, Table, build_compute, run_steps from datapipe.datatable import DataStore +from datapipe.step.batch_generate import BatchGenerate from datapipe.step.batch_transform import BatchTransform from datapipe.store.database import TableStoreDB from datapipe.types import IndexDF @@ -292,116 +293,127 @@ def train( def complex_transform_with_many_recordings(dbconn, N: int): ds = DataStore(dbconn, create_meta_table=True) - catalog = Catalog({ - "tbl_image": Table( - store=TableStoreDB( - dbconn, - "tbl_image", - [ - Column("image_id", Integer, primary_key=True), - ], - True - ) - ), - "tbl_image__attribute": Table( - store=TableStoreDB( - dbconn, - "tbl_image__attribute", - [ - Column("image_id", Integer, primary_key=True), - Column("attribute", Integer), - ], - True - ) - ), - "tbl_prediction": Table( - store=TableStoreDB( - dbconn, - "tbl_prediction", - [ - Column("image_id", Integer, primary_key=True), - Column("model_id", Integer, primary_key=True), - Column("prediction__attribite", Integer), - ], - True - ) - ), - "tbl_best_model": Table( - store=TableStoreDB( - dbconn, - "tbl_best_model", - [ - Column("model_id", Integer, primary_key=True), - ], - True - ) - ), - "tbl_output": Table( - store=TableStoreDB( - dbconn, - "tbl_output", - [ - Column("image_id", Integer, primary_key=True), - Column("model_id", Integer, primary_key=True), - Column("result", Integer), - ], - True - ) - ), - }) + catalog = Catalog( + { + "tbl_image": Table( + store=TableStoreDB( + dbconn, + "tbl_image", + [ + Column("image_id", Integer, primary_key=True), + ], + True, + ) + ), + "tbl_image__attribute": Table( + store=TableStoreDB( + dbconn, + "tbl_image__attribute", + [ + Column("image_id", Integer, primary_key=True), + Column("attribute", Integer), + ], + True, + ) + ), + "tbl_prediction": Table( + store=TableStoreDB( + dbconn, + "tbl_prediction", + [ + Column("image_id", Integer, primary_key=True), + Column("model_id", Integer, primary_key=True), + Column("prediction__attribite", Integer), + ], + True, + ) + ), + "tbl_best_model": Table( + store=TableStoreDB( + dbconn, + "tbl_best_model", + [ + Column("model_id", Integer, primary_key=True), + ], + True, + ) + ), + "tbl_output": Table( + store=TableStoreDB( + dbconn, + "tbl_output", + [ + Column("image_id", Integer, primary_key=True), + Column("model_id", Integer, primary_key=True), + Column("result", Integer), + ], + True, + ) + ), + } + ) def gen_tbls(df1, df2, df3, df4): yield df1, df2, df3, df4 - test_df__image = pd.DataFrame({ - "image_id": range(N) - }) - test_df__image__attribute = pd.DataFrame({ - "image_id": range(N), - "attribute": [5*x for x in range(N)] - }) - test_df__prediction = pd.DataFrame({ - "image_id": list(range(N)) * 5, - "model_id": [0] * N + [1] * N + [2] * N + [3] * N + [4] * N, - "prediction__attribite": ( - [1*x for x in range(N)] + # model_id=0 - [2*x for x in range(N)] + # model_id=1 - [3*x for x in range(N)] + # model_id=2 - [4*x for x in range(N)] + # model_id=3 - [5*x for 
diff --git a/tests/test_core_steps1.py b/tests/test_core_steps1.py
index de1b8684..a15fa63a 100644
--- a/tests/test_core_steps1.py
+++ b/tests/test_core_steps1.py
@@ -8,6 +8,7 @@
 from sqlalchemy import Column
 from sqlalchemy.sql.sqltypes import JSON, Integer
 
+from datapipe.compute import JoinType
 from datapipe.datatable import DataStore
 from datapipe.step.batch_generate import do_batch_generate
 from datapipe.step.batch_transform import BatchTransformStep
@@ -104,7 +105,7 @@ def id_func(df):
         ds=ds,
         name="test",
         func=id_func,
-        input_dts=[tbl1],
+        input_dts=[JoinType(dt=tbl1, join_type="full")],
         output_dts=[tbl2],
     )
@@ -139,7 +140,7 @@ def id_func(df):
         ds=ds,
         name="test",
         func=id_func,
-        input_dts=[tbl1],
+        input_dts=[JoinType(dt=tbl1, join_type="full")],
         output_dts=[tbl2],
     )
@@ -180,7 +181,7 @@ def id_func(df):
         ds=ds,
         name="test",
         func=id_func,
-        input_dts=[tbl1],
+        input_dts=[JoinType(dt=tbl1, join_type="full")],
         output_dts=[tbl2],
     )
@@ -255,7 +256,7 @@ def inc_func(df):
         ds=ds,
         name="step_inc",
         func=inc_func,
-        input_dts=[tbl],
+        input_dts=[JoinType(dt=tbl, join_type="full")],
         output_dts=[tbl1, tbl2, tbl3],
     )
@@ -281,7 +282,7 @@ def inc_func_inv(df):
         ds=ds,
         name="step_inc_inv",
         func=inc_func_inv,
-        input_dts=[tbl],
+        input_dts=[JoinType(dt=tbl, join_type="full")],
         output_dts=[tbl3, tbl2, tbl1],
     )
@@ -338,7 +339,10 @@ def inc_func(df1, df2):
         ds=ds,
         name="test",
         func=inc_func,
-        input_dts=[tbl1, tbl2],
+        input_dts=[
+            JoinType(dt=tbl1, join_type="full"),
+            JoinType(dt=tbl2, join_type="full"),
+        ],
         output_dts=[tbl],
     )
@@ -432,7 +436,7 @@ def inc_func(df):
         ds=ds,
         name="test",
         func=inc_func,
-        input_dts=[tbl],
+        input_dts=[JoinType(dt=tbl, join_type="full")],
         output_dts=[tbl_good, tbl_bad],
     )
@@ -484,14 +488,14 @@ def inc_func_pack(df):
         ds=ds,
         name="unpack",
         func=inc_func_unpack,
-        input_dts=[tbl],
+        input_dts=[JoinType(dt=tbl, join_type="full")],
         output_dts=[tbl_rel],
     )
     step_pack = BatchTransformStep(
         ds=ds,
         name="pack",
         func=inc_func_pack,
-        input_dts=[tbl_rel],
+        input_dts=[JoinType(dt=tbl_rel, join_type="full")],
         output_dts=[tbl2],
     )
@@ -556,14 +560,14 @@ def inc_func_pack(df):
         ds=ds,
         name="unpack",
         func=inc_func_unpack,
-        input_dts=[tbl],
+        input_dts=[JoinType(dt=tbl, join_type="full")],
         output_dts=[tbl_rel],
     )
     step_pack = BatchTransformStep(
         ds=ds,
         name="pack",
         func=inc_func_pack,
-        input_dts=[tbl_rel],
+        input_dts=[JoinType(dt=tbl_rel, join_type="full")],
         output_dts=[tbl2],
     )
@@ -669,7 +673,7 @@ def inc_func_good(df):
         ds=ds,
         name="bad",
         func=inc_func_bad,
-        input_dts=[tbl],
+        input_dts=[JoinType(dt=tbl, join_type="full")],
         output_dts=[tbl_good],
         chunk_size=1,
     )
@@ -681,7 +685,7 @@ def inc_func_good(df):
         ds=ds,
         name="good",
         func=inc_func_good,
-        input_dts=[tbl],
+        input_dts=[JoinType(dt=tbl, join_type="full")],
         output_dts=[tbl_good],
         chunk_size=CHUNKSIZE,
     )
diff --git a/tests/test_core_steps2.py b/tests/test_core_steps2.py
index d2bdcaca..f90c0199 100644
--- a/tests/test_core_steps2.py
+++ b/tests/test_core_steps2.py
@@ -9,6 +9,7 @@
 from sqlalchemy import Column, String
 from sqlalchemy.sql.sqltypes import Integer
 
+from datapipe.compute import JoinType
 from datapipe.datatable import DataStore
 from datapipe.run_config import RunConfig
 from datapipe.step.batch_generate import do_batch_generate
@@ -94,7 +95,7 @@ def test_batch_transform(dbconn):
         ds=ds,
         name="test",
         func=lambda df: df,
-        input_dts=[tbl1],
+        input_dts=[JoinType(dt=tbl1, join_type="full")],
         output_dts=[tbl2],
     )
@@ -132,7 +133,7 @@ def test_batch_transform_with_filter(dbconn):
         ds=ds,
         name="test",
         func=lambda df: df,
-        input_dts=[tbl1],
+        input_dts=[JoinType(dt=tbl1, join_type="full")],
         output_dts=[tbl2],
     )
     step.run_full(
@@ -162,7 +163,7 @@ def test_batch_transform_with_filter_not_in_transform_index(dbconn):
         ds=ds,
         name="test",
         func=lambda df: df[["item_id", "a"]],
-        input_dts=[tbl1],
+        input_dts=[JoinType(dt=tbl1, join_type="full")],
         output_dts=[tbl2],
     )
@@ -203,7 +204,10 @@ def update_df(df1: pd.DataFrame, df2: pd.DataFrame):
         ds=ds,
         name="test",
         func=update_df,
-        input_dts=[tbl1, tbl2],
+        input_dts=[
+            JoinType(dt=tbl1, join_type="full"),
+            JoinType(dt=tbl2, join_type="full"),
+        ],
         output_dts=[tbl2],
     )
@@ -258,7 +262,7 @@ def transform_func(df, context=context):
         ds=ds,
         name="step1",
         func=transform_func,
-        input_dts=[tbl1],
+        input_dts=[JoinType(dt=tbl1, join_type="full")],
         output_dts=[tbl2],
     )
@@ -325,7 +329,7 @@ def func(df):
         ds=ds,
         name="test",
         func=func,
-        input_dts=[tbl1],
+        input_dts=[JoinType(dt=tbl1, join_type="full")],
         output_dts=[tbl2],
     )
@@ -375,7 +379,10 @@ def update_df(products: pd.DataFrame, items: pd.DataFrame):
         ds=ds,
         name="test",
         func=update_df,
-        input_dts=[products, items],
+        input_dts=[
+            JoinType(dt=products, join_type="full"),
+            JoinType(dt=items, join_type="full"),
+        ],
         output_dts=[items2],
     )
join_type="full"), + JoinType(dt=tbl2, join_type="full"), + ], output_dts=[tbl2], ) @@ -258,7 +262,7 @@ def transform_func(df, context=context): ds=ds, name="step1", func=transform_func, - input_dts=[tbl1], + input_dts=[JoinType(dt=tbl1, join_type="full")], output_dts=[tbl2], ) @@ -325,7 +329,7 @@ def func(df): ds=ds, name="test", func=func, - input_dts=[tbl1], + input_dts=[JoinType(dt=tbl1, join_type="full")], output_dts=[tbl2], ) @@ -375,7 +379,10 @@ def update_df(products: pd.DataFrame, items: pd.DataFrame): ds=ds, name="test", func=update_df, - input_dts=[products, items], + input_dts=[ + JoinType(dt=products, join_type="full"), + JoinType(dt=items, join_type="full"), + ], output_dts=[items2], ) diff --git a/tests/test_image_pipeline.py b/tests/test_image_pipeline.py index 31c9b26a..05d1dc04 100644 --- a/tests/test_image_pipeline.py +++ b/tests/test_image_pipeline.py @@ -6,6 +6,7 @@ from datapipe.compute import ( Catalog, + JoinType, Pipeline, Table, build_compute, @@ -71,7 +72,7 @@ def test_image_datatables(dbconn, tmp_dir): ds=ds, name="resize_images", func=resize_images, - input_dts=[tbl1], + input_dts=[JoinType(dt=tbl1, join_type="full")], output_dts=[tbl2], )