diff --git a/datapipe/datatable.py b/datapipe/datatable.py
index db2febbd..835beae3 100644
--- a/datapipe/datatable.py
+++ b/datapipe/datatable.py
@@ -8,7 +8,6 @@
     Dict,
     Iterator,
     List,
-    Literal,
    Optional,
    Tuple,
    cast,
@@ -22,12 +21,9 @@
 from sqlalchemy.sql.expression import and_, func, or_, select
 
 from datapipe.event_logger import EventLogger
+from datapipe.meta.sql_meta import sql_apply_filters_idx_to_subquery
 from datapipe.run_config import RunConfig
-from datapipe.sql_util import (
-    sql_apply_filters_idx_to_subquery,
-    sql_apply_idx_filter_to_table,
-    sql_apply_runconfig_filter,
-)
+from datapipe.sql_util import sql_apply_idx_filter_to_table, sql_apply_runconfig_filter
 from datapipe.store.database import DBConn, MetaKey
 from datapipe.store.table_store import TableStore
 from datapipe.types import (
diff --git a/datapipe/meta/__init__.py b/datapipe/meta/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/datapipe/meta/sql_meta.py b/datapipe/meta/sql_meta.py
new file mode 100644
index 00000000..0e3dc4fb
--- /dev/null
+++ b/datapipe/meta/sql_meta.py
@@ -0,0 +1,85 @@
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple
+
+import pandas as pd
+import sqlalchemy as sa
+
+if TYPE_CHECKING:
+    from datapipe.datatable import DataStore
+
+
+def sql_apply_filters_idx_to_subquery(
+    sql: Any,
+    keys: List[str],
+    filters_idx: Optional[pd.DataFrame],
+) -> Any:
+    if filters_idx is None:
+        return sql
+
+    applicable_filter_keys = [i for i in filters_idx.columns if i in keys]
+    if len(applicable_filter_keys) > 0:
+        sql = sql.where(
+            sa.tuple_(*[sa.column(i) for i in applicable_filter_keys]).in_(
+                [
+                    sa.tuple_(*[r[k] for k in applicable_filter_keys])
+                    for r in filters_idx.to_dict(orient="records")
+                ]
+            )
+        )
+
+    return sql
+
+
+def make_agg_of_agg(
+    ds: "DataStore",
+    transform_keys: List[str],
+    common_transform_keys: List[str],
+    agg_col: str,
+    ctes: List[Tuple[List[str], Any]],
+) -> Any:
+    assert len(ctes) > 0
+
+    if len(ctes) == 1:
+        return ctes[0][1]
+
+    coalesce_keys = []
+
+    for key in transform_keys:
+        ctes_with_key = [subq for (subq_keys, subq) in ctes if key in subq_keys]
+
+        if len(ctes_with_key) == 0:
+            raise ValueError(f"Key {key} not found in any of the input tables")
+
+        if len(ctes_with_key) == 1:
+            coalesce_keys.append(ctes_with_key[0].c[key])
+        else:
+            coalesce_keys.append(
+                sa.func.coalesce(*[cte.c[key] for cte in ctes_with_key]).label(key)
+            )
+
+    agg = sa.func.max(
+        ds.meta_dbconn.func_greatest(*[subq.c[agg_col] for (_, subq) in ctes])
+    ).label(agg_col)
+
+    _, first_cte = ctes[0]
+
+    sql = sa.select(*coalesce_keys + [agg]).select_from(first_cte)
+
+    for _, cte in ctes[1:]:
+        if len(common_transform_keys) > 0:
+            sql = sql.outerjoin(
+                cte,
+                onclause=sa.and_(
+                    *[first_cte.c[key] == cte.c[key] for key in common_transform_keys]
+                ),
+                full=True,
+            )
+        else:
+            sql = sql.outerjoin(
+                cte,
+                onclause=sa.literal(True),
+                full=True,
+            )
+
+    sql = sql.group_by(*coalesce_keys)
+
+    return sql.cte(name=f"all__{agg_col}")
diff --git a/datapipe/sql_util.py b/datapipe/sql_util.py
index a326ef66..0311b2fc 100644
--- a/datapipe/sql_util.py
+++ b/datapipe/sql_util.py
@@ -7,28 +7,6 @@
 from datapipe.types import IndexDF
 
 
-def sql_apply_filters_idx_to_subquery(
-    sql: Any,
-    keys: List[str],
-    filters_idx: Optional[pd.DataFrame],
-) -> Any:
-    if filters_idx is None:
-        return sql
-
-    applicable_filter_keys = [i for i in filters_idx.columns if i in keys]
-    if len(applicable_filter_keys) > 0:
-        sql = sql.where(
-            tuple_(*[column(i) for i in applicable_filter_keys]).in_(
-                [
-                    tuple_(*[r[k] for k in applicable_filter_keys])
-                    for r in filters_idx.to_dict(orient="records")
-                ]
-            )
-        )
-
-    return sql
-
-
 def sql_apply_idx_filter_to_table(
     sql: Any,
     table: Table,
diff --git a/datapipe/step/batch_transform.py b/datapipe/step/batch_transform.py
index 69154890..36deb150 100644
--- a/datapipe/step/batch_transform.py
+++ b/datapipe/step/batch_transform.py
@@ -51,20 +51,20 @@
 )
 from datapipe.datatable import DataStore, DataTable, MetaTable
 from datapipe.executor import Executor, ExecutorConfig, SingleThreadExecutor
+from datapipe.meta.sql_meta import make_agg_of_agg, sql_apply_filters_idx_to_subquery
 from datapipe.run_config import LabelDict, RunConfig
-from datapipe.sql_util import (
-    sql_apply_filters_idx_to_subquery,
-    sql_apply_runconfig_filter,
-)
+from datapipe.sql_util import sql_apply_runconfig_filter
 from datapipe.store.database import DBConn
 from datapipe.types import (
     ChangeList,
     DataDF,
     DataSchema,
     IndexDF,
+    JoinSpec,
     Labels,
     MetaSchema,
     PipelineInput,
+    Required,
     TableOrName,
     TransformResult,
     data_to_index,
@@ -398,64 +398,6 @@ def _build_changed_idx_sql(
         for col in itertools.chain(*[inp.dt.primary_schema for inp in self.input_dts]):
             all_input_keys_counts[col.name] = all_input_keys_counts.get(col.name, 0) + 1
 
-        common_keys = [
-            k for k, v in all_input_keys_counts.items() if v == len(self.input_dts)
-        ]
-
-        common_transform_keys = [k for k in self.transform_keys if k in common_keys]
-
-        def _make_agg_of_agg(ctes, agg_col):
-            assert len(ctes) > 0
-
-            if len(ctes) == 1:
-                return ctes[0][1]
-
-            coalesce_keys = []
-
-            for key in self.transform_keys:
-                ctes_with_key = [subq for (subq_keys, subq) in ctes if key in subq_keys]
-
-                if len(ctes_with_key) == 0:
-                    raise ValueError(f"Key {key} not found in any of the input tables")
-
-                if len(ctes_with_key) == 1:
-                    coalesce_keys.append(ctes_with_key[0].c[key])
-                else:
-                    coalesce_keys.append(
-                        func.coalesce(*[cte.c[key] for cte in ctes_with_key]).label(key)
-                    )
-
-            agg = func.max(
-                ds.meta_dbconn.func_greatest(*[subq.c[agg_col] for (_, subq) in ctes])
-            ).label(agg_col)
-
-            _, first_cte = ctes[0]
-
-            sql = select(*coalesce_keys + [agg]).select_from(first_cte)
-
-            for _, cte in ctes[1:]:
-                if len(common_transform_keys) > 0:
-                    sql = sql.outerjoin(
-                        cte,
-                        onclause=and_(
-                            *[
-                                first_cte.c[key] == cte.c[key]
-                                for key in common_transform_keys
-                            ]
-                        ),
-                        full=True,
-                    )
-                else:
-                    sql = sql.outerjoin(
-                        cte,
-                        onclause=literal(True),
-                        full=True,
-                    )
-
-            sql = sql.group_by(*coalesce_keys)
-
-            return sql.cte(name=f"all__{agg_col}")
-
         inp_ctes = [
             inp.dt.get_agg_cte(
                 transform_keys=self.transform_keys,
@@ -465,7 +407,19 @@ def _make_agg_of_agg(ctes, agg_col):
             for inp in self.input_dts
         ]
 
-        inp = _make_agg_of_agg(inp_ctes, "update_ts")
+        common_keys = [
+            k for k, v in all_input_keys_counts.items() if v == len(self.input_dts)
+        ]
+
+        common_transform_keys = [k for k in self.transform_keys if k in common_keys]
+
+        inp = make_agg_of_agg(
+            ds=ds,
+            transform_keys=self.transform_keys,
+            common_transform_keys=common_transform_keys,
+            ctes=inp_ctes,
+            agg_col="update_ts",
+        )
 
         tr_tbl = self.meta_table.sql_table
         out: Any = (
@@ -983,7 +937,7 @@ def process_batch_dts(
 @dataclass
 class BatchTransform(PipelineStep):
     func: BatchTransformFunc
-    inputs: List[TableOrName]
+    inputs: List[PipelineInput]
     outputs: List[TableOrName]
     chunk_size: int = 1000
     kwargs: Optional[Dict[str, Any]] = None
@@ -994,18 +948,35 @@ class BatchTransform(PipelineStep):
     order_by: Optional[List[str]] = None
     order: Literal["asc", "desc"] = "asc"
 
+    def pipeline_input_to_compute_input(
+        self, ds: DataStore, catalog: Catalog, input: PipelineInput
+    ) -> ComputeInput:
+        if isinstance(input, Required):
+            return ComputeInput(
+                dt=catalog.get_datatable(ds, input.table),
+                join_type="inner",
+            )
+        elif isinstance(input, JoinSpec):
+            # This should not happen, but just in case
+            return ComputeInput(
+                dt=catalog.get_datatable(ds, input.table),
+                join_type="full",
+            )
+        else:
+            return ComputeInput(dt=catalog.get_datatable(ds, input), join_type="full")
+
     def build_compute(self, ds: DataStore, catalog: Catalog) -> List[ComputeStep]:
-        input_dts = [catalog.get_datatable(ds, name) for name in self.inputs]
+        input_dts = [
+            self.pipeline_input_to_compute_input(ds, catalog, input)
+            for input in self.inputs
+        ]
         output_dts = [catalog.get_datatable(ds, name) for name in self.outputs]
 
         return [
             BatchTransformStep(
                 ds=ds,
                 name=f"{self.func.__name__}",
-                input_dts=[
-                    ComputeInput(dt=input_dts, join_type="full")
-                    for input_dts in input_dts
-                ],
+                input_dts=input_dts,
                 output_dts=output_dts,
                 func=self.func,
                 kwargs=self.kwargs,
diff --git a/datapipe/types.py b/datapipe/types.py
index c0585c68..45c8ad2c 100644
--- a/datapipe/types.py
+++ b/datapipe/types.py
@@ -54,6 +54,19 @@
 TableOrName = Union[str, OrmTable, "Table"]
 
 
+@dataclass
+class JoinSpec:
+    table: TableOrName
+
+
+@dataclass
+class Required(JoinSpec):
+    pass
+
+
+PipelineInput = Union[TableOrName, JoinSpec]
+
+
 @dataclass
 class ChangeList:
     changes: Dict[str, IndexDF] = field(
diff --git a/tests/test_complex_pipeline.py b/tests/test_complex_pipeline.py
index bbc80056..b839438c 100644
--- a/tests/test_complex_pipeline.py
+++ b/tests/test_complex_pipeline.py
@@ -409,10 +409,10 @@ def get_some_prediction_only_on_best_model(
         BatchTransform(
             func=get_some_prediction_only_on_best_model,
             inputs=[
-                "tbl_image",
-                "tbl_image__attribute",
-                "tbl_prediction",
-                "tbl_best_model",
+                "tbl_image",  # image_id
+                "tbl_image__attribute",  # image_id, attribute
+                "tbl_prediction",  # image_id, model_id, prediction__attribite
+                "tbl_best_model",  # model_id
             ],
             outputs=["tbl_output"],
             transform_keys=["image_id", "model_id"],