removed "Keys from filters must be in transform_keys" error
bobokvsky committed Sep 6, 2024
1 parent b1de4b8 commit 8055860
Showing 3 changed files with 31 additions and 5 deletions.
2 changes: 0 additions & 2 deletions datapipe/step/batch_transform.py
@@ -207,8 +207,6 @@ def _get_filters(
         keys = set([key for keys in filters for key in keys])
         if not all(len(filter) == len(keys) for filter in filters):
             raise ValueError("Size of keys from filters must have same length")
-        if not all([key in self.transform_keys for key in keys]):
-            raise ValueError(f"Keys from filters must be in transform_keys={self.transform_keys}.")
         return filters

     def get_status(self, ds: DataStore) -> StepStatus:
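
For context, the validation that remains in _get_filters after this change can be sketched standalone as below; the helper name check_filter_keys and the sample filter dicts are illustrative only, not part of datapipe. Every filter dict still has to use the same key set, but the keys no longer have to appear in transform_keys.

from typing import Any, Dict, List

def check_filter_keys(filters: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    # Union of all keys used across the filter dicts.
    keys = set(key for f in filters for key in f)
    # Each filter dict must cover that full key set (same length as the union).
    if not all(len(f) == len(keys) for f in filters):
        raise ValueError("Size of keys from filters must have same length")
    # The check removed in this commit additionally required every key to be
    # in self.transform_keys; keys outside the transform index now pass through.
    return filters

check_filter_keys([{"pipeline_id": 0}, {"pipeline_id": 1}])  # OK: same key set
# check_filter_keys([{"pipeline_id": 0}, {"model_id": 1}])   # would raise ValueError
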
3 changes: 1 addition & 2 deletions tests/test_complex_pipeline.py
@@ -458,11 +458,10 @@ def gen_tbl(df):
test_df__model = pd.DataFrame({
    "model_id": [0, 1, 2, 3, 4]
})

def filters_images():
    return [{"image_id": i} for i in range(N // 2)]


def make_prediction(
    df__image: pd.DataFrame,
    df__model: pd.DataFrame,
31 changes: 30 additions & 1 deletion tests/test_core_steps2.py
@@ -147,7 +147,7 @@ def test_batch_transform_with_filter_in_run_config(dbconn):
     assert_datatable_equal(tbl2, TEST_DF1_1.query("pipeline_id == 0"))


-def test_batch_transform_with_filter_not_in_transform_index(dbconn):
+def test_batch_transform_with_filter_in_run_config_not_in_transform_index(dbconn):
     ds = DataStore(dbconn, create_meta_table=True)

     tbl1 = ds.create_table(
@@ -176,6 +176,35 @@ def test_batch_transform_with_filter_not_in_transform_index(dbconn):
     assert_datatable_equal(tbl2, TEST_DF1_2.query("pipeline_id == 0")[["item_id", "a"]])


+def test_batch_transform_with_filter_not_in_transform_index(dbconn):
+    ds = DataStore(dbconn, create_meta_table=True)
+
+    tbl1 = ds.create_table(
+        "tbl1", table_store=TableStoreDB(dbconn, "tbl1_data", TEST_SCHEMA1, True)
+    )
+
+    tbl2 = ds.create_table(
+        "tbl2", table_store=TableStoreDB(dbconn, "tbl2_data", TEST_SCHEMA2, True)
+    )
+
+    tbl1.store_chunk(TEST_DF1_2, now=0)
+
+    step = BatchTransformStep(
+        ds=ds,
+        name="test",
+        func=lambda df: df[["item_id", "a"]],
+        input_dts=[ComputeInput(dt=tbl1, join_type="full")],
+        output_dts=[tbl2],
+        filters=[{"pipeline_id": 0}]
+    )
+
+    step.run_full(
+        ds,
+    )
+
+    assert_datatable_equal(tbl2, TEST_DF1_2.query("pipeline_id == 0")[["item_id", "a"]])
+
+
 def test_batch_transform_with_dt_on_input_and_output(dbconn):
     ds = DataStore(dbconn, create_meta_table=True)

