diff --git a/datapipe/meta/sql_meta.py b/datapipe/meta/sql_meta.py
index edc29230..c56e2015 100644
--- a/datapipe/meta/sql_meta.py
+++ b/datapipe/meta/sql_meta.py
@@ -60,7 +60,7 @@ def __init__(
         primary_schema: DataSchema,
         meta_schema: MetaSchema = [],
         create_table: bool = False,
-    ):
+    ) -> None:
         self.dbconn = dbconn
         self.name = name

@@ -691,22 +691,28 @@ def sql_apply_filters_idx_to_subquery(
     return sql


+@dataclass
+class ComputeInputCTE:
+    cte: Any
+    keys: List[str]
+    join_type: Literal["inner", "full"]
+
+
 def _make_agg_of_agg(
     ds: "DataStore",
     transform_keys: List[str],
-    common_transform_keys: List[str],
     agg_col: str,
-    ctes: List[Tuple[List[str], Any]],
+    ctes: List[ComputeInputCTE],
 ) -> Any:
     assert len(ctes) > 0

     if len(ctes) == 1:
-        return ctes[0][1]
+        return ctes[0].cte

     coalesce_keys = []

     for key in transform_keys:
-        ctes_with_key = [subq for (subq_keys, subq) in ctes if key in subq_keys]
+        ctes_with_key = [cte.cte for cte in ctes if key in cte.keys]

         if len(ctes_with_key) == 0:
             raise ValueError(f"Key {key} not found in any of the input tables")
@@ -719,29 +725,41 @@ def _make_agg_of_agg(
             )

     agg = sa.func.max(
-        ds.meta_dbconn.func_greatest(*[subq.c[agg_col] for (_, subq) in ctes])
+        ds.meta_dbconn.func_greatest(*[cte.cte.c[agg_col] for cte in ctes])
     ).label(agg_col)

-    _, first_cte = ctes[0]
+    first_cte = ctes[0].cte

     sql = sa.select(*coalesce_keys + [agg]).select_from(first_cte)

-    for _, cte in ctes[1:]:
-        if len(common_transform_keys) > 0:
+    prev_ctes = [ctes[0]]
+
+    for cte in ctes[1:]:
+        onclause = []
+
+        for prev_cte in prev_ctes:
+            for key in cte.keys:
+                if key in prev_cte.keys:
+                    onclause.append(prev_cte.cte.c[key] == cte.cte.c[key])
+
+        if len(onclause) > 0:
             sql = sql.outerjoin(
-                cte,
-                onclause=sa.and_(
-                    *[first_cte.c[key] == cte.c[key] for key in common_transform_keys]
-                ),
+                cte.cte,
+                onclause=sa.and_(*onclause),
                 full=True,
             )
         else:
             sql = sql.outerjoin(
-                cte,
+                cte.cte,
                 onclause=sa.literal(True),
                 full=True,
             )

+        if cte.join_type == "inner":
+            sql = sql.where(sa.and_(*[cte.cte.c[key].isnot(None) for key in cte.keys]))
+
+        prev_ctes.append(cte)
+
     sql = sql.group_by(*coalesce_keys)

     return sql.cte(name=f"all__{agg_col}")
@@ -761,23 +779,20 @@ def build_changed_idx_sql(
     for col in itertools.chain(*[inp.dt.primary_schema for inp in input_dts]):
         all_input_keys_counts[col.name] = all_input_keys_counts.get(col.name, 0) + 1

-    inp_ctes = [
-        inp.dt.meta_table.get_agg_cte(
+    inp_ctes = []
+    for inp in input_dts:
+        keys, cte = inp.dt.meta_table.get_agg_cte(
             transform_keys=transform_keys,
             filters_idx=filters_idx,
             run_config=run_config,
         )
-        for inp in input_dts
-    ]
+        inp_ctes.append(ComputeInputCTE(cte=cte, keys=keys, join_type=inp.join_type))

     common_keys = [k for k, v in all_input_keys_counts.items() if v == len(input_dts)]
-    common_transform_keys = [k for k in transform_keys if k in common_keys]
-
-    inp = _make_agg_of_agg(
+    agg_of_aggs = _make_agg_of_agg(
         ds=ds,
         transform_keys=transform_keys,
-        common_transform_keys=common_transform_keys,
         ctes=inp_ctes,
         agg_col="update_ts",
     )
@@ -799,10 +814,10 @@ def build_changed_idx_sql(
     if len(transform_keys) == 0:
         join_onclause_sql: Any = sa.literal(True)
     elif len(transform_keys) == 1:
-        join_onclause_sql = inp.c[transform_keys[0]] == out.c[transform_keys[0]]
+        join_onclause_sql = agg_of_aggs.c[transform_keys[0]] == out.c[transform_keys[0]]
     else:  # len(transform_keys) > 1:
         join_onclause_sql = sa.and_(
-            *[inp.c[key] == out.c[key] for key in transform_keys]
+            *[agg_of_aggs.c[key] == out.c[key] for key in transform_keys]
         )

     sql = (
@@ -811,11 +826,11 @@ def build_changed_idx_sql(
             # пустом transform_keys
             sa.literal(1).label("_datapipe_dummy"),
             *[
-                sa.func.coalesce(inp.c[key], out.c[key]).label(key)
+                sa.func.coalesce(agg_of_aggs.c[key], out.c[key]).label(key)
                 for key in transform_keys
             ],
         )
-        .select_from(inp)
+        .select_from(agg_of_aggs)
         .outerjoin(
             out,
             onclause=join_onclause_sql,
@@ -825,7 +840,7 @@ def build_changed_idx_sql(
             sa.or_(
                 sa.and_(
                     out.c.is_success == True,  # noqa
-                    inp.c.update_ts > out.c.process_ts,
+                    agg_of_aggs.c.update_ts > out.c.process_ts,
                 ),
                 out.c.is_success != True,  # noqa
                 out.c.process_ts == None,  # noqa
diff --git a/tests/test_complex_pipeline.py b/tests/test_complex_pipeline.py
index 995d98d3..cd03394c 100644
--- a/tests/test_complex_pipeline.py
+++ b/tests/test_complex_pipeline.py
@@ -10,7 +10,7 @@
 from datapipe.step.batch_generate import BatchGenerate
 from datapipe.step.batch_transform import BatchTransform
 from datapipe.store.database import TableStoreDB
-from datapipe.types import IndexDF
+from datapipe.types import IndexDF, Required

 from .util import assert_datatable_equal, assert_df_equal

@@ -411,10 +411,12 @@ def get_some_prediction_only_on_best_model(
             BatchTransform(
                 func=get_some_prediction_only_on_best_model,
                 inputs=[
-                    "tbl_image",  # image_id
+                    Required("tbl_image"),  # image_id
                     "tbl_image__attribute",  # image_id, attribute
-                    "tbl_prediction",  # image_id, model_id, prediction__attribite
-                    "tbl_best_model",  # model_id
+                    Required(
+                        "tbl_prediction"
+                    ),  # image_id, model_id, prediction__attribite
+                    Required("tbl_best_model"),  # model_id
                 ],
                 outputs=["tbl_output"],
                 transform_keys=["image_id", "model_id"],
@@ -437,11 +439,9 @@ def test_complex_transform_with_many_recordings_N100(dbconn):
    complex_transform_with_many_recordings(dbconn, N=100)


-@pytest.mark.skip(reason="This test is slow")
 def test_complex_transform_with_many_recordings_N1000(dbconn):
     complex_transform_with_many_recordings(dbconn, N=1000)


-@pytest.mark.skip(reason="This test is slow")
 def test_complex_transform_with_many_recordings_N10000(dbconn):
     complex_transform_with_many_recordings(dbconn, N=10000)
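
Note (not part of the patch): the standalone sketch below illustrates, in plain SQLAlchemy with invented table and CTE names, the join strategy the reworked _make_agg_of_agg appears to use. Each input CTE is FULL OUTER JOINed on every transform key it shares with any previously joined CTE (that is what the accumulated prev_ctes list enables, versus the old code's single set of common_transform_keys against the first CTE), and inputs flagged join_type="inner" (which the test changes suggest corresponds to Required(...) inputs) additionally filter out rows where their own keys are NULL. The greatest() call and column layout here are simplifications, not the library's actual API surface.

    # Illustrative sketch only; table names, columns and greatest() usage are made up.
    import sqlalchemy as sa

    metadata = sa.MetaData()
    tbl_image = sa.Table("tbl_image", metadata, sa.Column("image_id", sa.Integer), sa.Column("update_ts", sa.Float))
    tbl_prediction = sa.Table("tbl_prediction", metadata, sa.Column("image_id", sa.Integer), sa.Column("model_id", sa.Integer), sa.Column("update_ts", sa.Float))
    tbl_best_model = sa.Table("tbl_best_model", metadata, sa.Column("model_id", sa.Integer), sa.Column("update_ts", sa.Float))

    transform_keys = ["image_id", "model_id"]

    # (cte, keys, join_type) triples stand in for ComputeInputCTE instances.
    ctes = [
        (sa.select(tbl_image).cte("agg_image"), ["image_id"], "full"),
        (sa.select(tbl_prediction).cte("agg_prediction"), ["image_id", "model_id"], "inner"),
        (sa.select(tbl_best_model).cte("agg_best_model"), ["model_id"], "inner"),
    ]

    # COALESCE each transform key across the CTEs that actually carry it.
    coalesce_keys = [
        sa.func.coalesce(*[cte.c[key] for cte, keys, _ in ctes if key in keys]).label(key)
        for key in transform_keys
    ]
    agg = sa.func.max(
        sa.func.greatest(*[cte.c["update_ts"] for cte, _, _ in ctes])
    ).label("update_ts")

    first_cte, _, _ = ctes[0]
    sql = sa.select(*coalesce_keys, agg).select_from(first_cte)

    prev_ctes = [ctes[0]]
    for cte, keys, join_type in ctes[1:]:
        # Join on every key this CTE shares with any CTE already in the FROM chain.
        onclause = [
            prev_cte.c[key] == cte.c[key]
            for prev_cte, prev_keys, _ in prev_ctes
            for key in keys
            if key in prev_keys
        ]
        sql = sql.outerjoin(
            cte,
            onclause=sa.and_(*onclause) if onclause else sa.literal(True),
            full=True,
        )
        if join_type == "inner":
            # Required-style input: drop result rows this input did not contribute to.
            sql = sql.where(sa.and_(*[cte.c[key].isnot(None) for key in keys]))
        prev_ctes.append((cte, keys, join_type))

    sql = sql.group_by(*coalesce_keys)
    print(sql)

Printing the compiled statement shows a chain of FULL OUTER JOINs plus IS NOT NULL guards for the two "inner" inputs; under the old common_transform_keys approach, tbl_best_model (which shares no key with the first input) could only be joined on a literal TRUE.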