diff --git a/datapipe/step/batch_transform.py b/datapipe/step/batch_transform.py
index 1db7d3e6..88b3e82b 100644
--- a/datapipe/step/batch_transform.py
+++ b/datapipe/step/batch_transform.py
@@ -207,8 +207,6 @@ def _get_filters(
         keys = set([key for keys in filters for key in keys])
         if not all(len(filter) == len(keys) for filter in filters):
             raise ValueError("Size of keys from filters must have same length")
-        if not all([key in self.transform_keys for key in keys]):
-            raise ValueError(f"Keys from filters must be in transform_keys={self.transform_keys}.")
         return filters
 
     def get_status(self, ds: DataStore) -> StepStatus:
diff --git a/tests/test_complex_pipeline.py b/tests/test_complex_pipeline.py
index 7e3321c2..f9dc079b 100644
--- a/tests/test_complex_pipeline.py
+++ b/tests/test_complex_pipeline.py
@@ -458,11 +458,10 @@ def gen_tbl(df):
     test_df__model = pd.DataFrame({
         "model_id": [0, 1, 2, 3, 4]
     })
-
+
     def filters_images():
         return [{"image_id": i} for i in range(N // 2)]
 
-
     def make_prediction(
         df__image: pd.DataFrame,
         df__model: pd.DataFrame,
diff --git a/tests/test_core_steps2.py b/tests/test_core_steps2.py
index 604d8145..90237339 100644
--- a/tests/test_core_steps2.py
+++ b/tests/test_core_steps2.py
@@ -147,7 +147,7 @@ def test_batch_transform_with_filter_in_run_config(dbconn):
     assert_datatable_equal(tbl2, TEST_DF1_1.query("pipeline_id == 0"))
 
 
-def test_batch_transform_with_filter_not_in_transform_index(dbconn):
+def test_batch_transform_with_filter_in_run_config_not_in_transform_index(dbconn):
     ds = DataStore(dbconn, create_meta_table=True)
 
     tbl1 = ds.create_table(
@@ -176,6 +176,35 @@ def test_batch_transform_with_filter_not_in_transform_index(dbconn):
     assert_datatable_equal(tbl2, TEST_DF1_2.query("pipeline_id == 0")[["item_id", "a"]])
 
 
+def test_batch_transform_with_filter_not_in_transform_index(dbconn):
+    ds = DataStore(dbconn, create_meta_table=True)
+
+    tbl1 = ds.create_table(
+        "tbl1", table_store=TableStoreDB(dbconn, "tbl1_data", TEST_SCHEMA1, True)
+    )
+
+    tbl2 = ds.create_table(
+        "tbl2", table_store=TableStoreDB(dbconn, "tbl2_data", TEST_SCHEMA2, True)
+    )
+
+    tbl1.store_chunk(TEST_DF1_2, now=0)
+
+    step = BatchTransformStep(
+        ds=ds,
+        name="test",
+        func=lambda df: df[["item_id", "a"]],
+        input_dts=[ComputeInput(dt=tbl1, join_type="full")],
+        output_dts=[tbl2],
+        filters=[{"pipeline_id": 0}]
+    )
+
+    step.run_full(
+        ds,
+    )
+
+    assert_datatable_equal(tbl2, TEST_DF1_2.query("pipeline_id == 0")[["item_id", "a"]])
+
+
 def test_batch_transform_with_dt_on_input_and_output(dbconn):
     ds = DataStore(dbconn, create_meta_table=True)