start extracting meta-related stuff to datapipe.meta
elephantum committed Aug 27, 2024
1 parent c4a1444 commit 9979018
Showing 7 changed files with 144 additions and 101 deletions.
8 changes: 2 additions & 6 deletions datapipe/datatable.py
@@ -8,7 +8,6 @@
     Dict,
     Iterator,
     List,
-    Literal,
    Optional,
     Tuple,
     cast,
@@ -22,12 +21,9 @@
 from sqlalchemy.sql.expression import and_, func, or_, select
 
 from datapipe.event_logger import EventLogger
+from datapipe.meta.sql_meta import sql_apply_filters_idx_to_subquery
 from datapipe.run_config import RunConfig
-from datapipe.sql_util import (
-    sql_apply_filters_idx_to_subquery,
-    sql_apply_idx_filter_to_table,
-    sql_apply_runconfig_filter,
-)
+from datapipe.sql_util import sql_apply_idx_filter_to_table, sql_apply_runconfig_filter
 from datapipe.store.database import DBConn, MetaKey
 from datapipe.store.table_store import TableStore
 from datapipe.types import (
Empty file added datapipe/meta/__init__.py
85 changes: 85 additions & 0 deletions datapipe/meta/sql_meta.py
@@ -0,0 +1,85 @@
from typing import TYPE_CHECKING, Any, List, Optional, Tuple

import pandas as pd
import sqlalchemy as sa

if TYPE_CHECKING:
    from datapipe.datatable import DataStore


def sql_apply_filters_idx_to_subquery(
    sql: Any,
    keys: List[str],
    filters_idx: Optional[pd.DataFrame],
) -> Any:
    if filters_idx is None:
        return sql

    applicable_filter_keys = [i for i in filters_idx.columns if i in keys]
    if len(applicable_filter_keys) > 0:
        sql = sql.where(
            sa.tuple_(*[sa.column(i) for i in applicable_filter_keys]).in_(
                [
                    sa.tuple_(*[r[k] for k in applicable_filter_keys])
                    for r in filters_idx.to_dict(orient="records")
                ]
            )
        )

    return sql


def make_agg_of_agg(
    ds: "DataStore",
    transform_keys: List[str],
    common_transform_keys: List[str],
    agg_col: str,
    ctes: List[Tuple[List[str], Any]],
) -> Any:
    assert len(ctes) > 0

    if len(ctes) == 1:
        return ctes[0][1]

    coalesce_keys = []

    for key in transform_keys:
        ctes_with_key = [subq for (subq_keys, subq) in ctes if key in subq_keys]

        if len(ctes_with_key) == 0:
            raise ValueError(f"Key {key} not found in any of the input tables")

        if len(ctes_with_key) == 1:
            coalesce_keys.append(ctes_with_key[0].c[key])
        else:
            coalesce_keys.append(
                sa.func.coalesce(*[cte.c[key] for cte in ctes_with_key]).label(key)
            )

    agg = sa.func.max(
        ds.meta_dbconn.func_greatest(*[subq.c[agg_col] for (_, subq) in ctes])
    ).label(agg_col)

    _, first_cte = ctes[0]

    sql = sa.select(*coalesce_keys + [agg]).select_from(first_cte)

    for _, cte in ctes[1:]:
        if len(common_transform_keys) > 0:
            sql = sql.outerjoin(
                cte,
                onclause=sa.and_(
                    *[first_cte.c[key] == cte.c[key] for key in common_transform_keys]
                ),
                full=True,
            )
        else:
            sql = sql.outerjoin(
                cte,
                onclause=sa.literal(True),
                full=True,
            )

    sql = sql.group_by(*coalesce_keys)

    return sql.cte(name=f"all__{agg_col}")
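
The two relocated helpers can be exercised in isolation. A minimal sketch of the filter helper, assuming hypothetical table and column names (only the import path and the signature come from this commit):

import pandas as pd
import sqlalchemy as sa

from datapipe.meta.sql_meta import sql_apply_filters_idx_to_subquery

# Hypothetical metadata table with two key columns and a timestamp.
tbl = sa.table(
    "tbl_meta",
    sa.column("image_id"),
    sa.column("model_id"),
    sa.column("update_ts"),
)
sql = sa.select(tbl.c.image_id, tbl.c.model_id, tbl.c.update_ts)

# Keep only the (image_id, model_id) pairs listed in filters_idx;
# filter columns that are not in `keys` would be ignored.
filters_idx = pd.DataFrame({"image_id": [1, 2], "model_id": [10, 20]})
sql = sql_apply_filters_idx_to_subquery(
    sql, keys=["image_id", "model_id"], filters_idx=filters_idx
)
print(sql.compile(compile_kwargs={"literal_binds": True}))

And a sketch of make_agg_of_agg with two inputs that share no key, so the CTEs are combined by a full outer join on a constant-true clause. The stub DataStore is an assumption standing in for the real one, since only meta_dbconn.func_greatest is touched here:

from types import SimpleNamespace

import sqlalchemy as sa

from datapipe.meta.sql_meta import make_agg_of_agg

# Stub: on real backends func_greatest resolves to the dialect's GREATEST.
ds = SimpleNamespace(meta_dbconn=SimpleNamespace(func_greatest=sa.func.greatest))

img = sa.select(sa.column("image_id"), sa.column("update_ts")).cte("img__update_ts")
mdl = sa.select(sa.column("model_id"), sa.column("update_ts")).cte("mdl__update_ts")

cte = make_agg_of_agg(
    ds=ds,
    transform_keys=["image_id", "model_id"],
    common_transform_keys=[],  # no shared keys -> join on literal(True)
    agg_col="update_ts",
    ctes=[(["image_id"], img), (["model_id"], mdl)],
)
# cte selects image_id, model_id and max(greatest(...)) over both inputs,
# grouped by the keys, under the name "all__update_ts".
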
22 changes: 0 additions & 22 deletions datapipe/sql_util.py
@@ -7,28 +7,6 @@
 from datapipe.types import IndexDF
 
 
-def sql_apply_filters_idx_to_subquery(
-    sql: Any,
-    keys: List[str],
-    filters_idx: Optional[pd.DataFrame],
-) -> Any:
-    if filters_idx is None:
-        return sql
-
-    applicable_filter_keys = [i for i in filters_idx.columns if i in keys]
-    if len(applicable_filter_keys) > 0:
-        sql = sql.where(
-            tuple_(*[column(i) for i in applicable_filter_keys]).in_(
-                [
-                    tuple_(*[r[k] for k in applicable_filter_keys])
-                    for r in filters_idx.to_dict(orient="records")
-                ]
-            )
-        )
-
-    return sql
-
-
 def sql_apply_idx_filter_to_table(
     sql: Any,
     table: Table,
109 changes: 40 additions & 69 deletions datapipe/step/batch_transform.py
@@ -51,20 +51,20 @@
 )
 from datapipe.datatable import DataStore, DataTable, MetaTable
 from datapipe.executor import Executor, ExecutorConfig, SingleThreadExecutor
+from datapipe.meta.sql_meta import make_agg_of_agg, sql_apply_filters_idx_to_subquery
 from datapipe.run_config import LabelDict, RunConfig
-from datapipe.sql_util import (
-    sql_apply_filters_idx_to_subquery,
-    sql_apply_runconfig_filter,
-)
+from datapipe.sql_util import sql_apply_runconfig_filter
 from datapipe.store.database import DBConn
 from datapipe.types import (
     ChangeList,
     DataDF,
     DataSchema,
     IndexDF,
+    JoinSpec,
     Labels,
     MetaSchema,
+    PipelineInput,
+    Required,
     TableOrName,
     TransformResult,
     data_to_index,
@@ -398,64 +398,6 @@ def _build_changed_idx_sql(
         for col in itertools.chain(*[inp.dt.primary_schema for inp in self.input_dts]):
             all_input_keys_counts[col.name] = all_input_keys_counts.get(col.name, 0) + 1
 
-        common_keys = [
-            k for k, v in all_input_keys_counts.items() if v == len(self.input_dts)
-        ]
-
-        common_transform_keys = [k for k in self.transform_keys if k in common_keys]
-
-        def _make_agg_of_agg(ctes, agg_col):
-            assert len(ctes) > 0
-
-            if len(ctes) == 1:
-                return ctes[0][1]
-
-            coalesce_keys = []
-
-            for key in self.transform_keys:
-                ctes_with_key = [subq for (subq_keys, subq) in ctes if key in subq_keys]
-
-                if len(ctes_with_key) == 0:
-                    raise ValueError(f"Key {key} not found in any of the input tables")
-
-                if len(ctes_with_key) == 1:
-                    coalesce_keys.append(ctes_with_key[0].c[key])
-                else:
-                    coalesce_keys.append(
-                        func.coalesce(*[cte.c[key] for cte in ctes_with_key]).label(key)
-                    )
-
-            agg = func.max(
-                ds.meta_dbconn.func_greatest(*[subq.c[agg_col] for (_, subq) in ctes])
-            ).label(agg_col)
-
-            _, first_cte = ctes[0]
-
-            sql = select(*coalesce_keys + [agg]).select_from(first_cte)
-
-            for _, cte in ctes[1:]:
-                if len(common_transform_keys) > 0:
-                    sql = sql.outerjoin(
-                        cte,
-                        onclause=and_(
-                            *[
-                                first_cte.c[key] == cte.c[key]
-                                for key in common_transform_keys
-                            ]
-                        ),
-                        full=True,
-                    )
-                else:
-                    sql = sql.outerjoin(
-                        cte,
-                        onclause=literal(True),
-                        full=True,
-                    )
-
-            sql = sql.group_by(*coalesce_keys)
-
-            return sql.cte(name=f"all__{agg_col}")
-
         inp_ctes = [
             inp.dt.get_agg_cte(
                 transform_keys=self.transform_keys,
@@ -465,7 +407,19 @@ def _make_agg_of_agg(ctes, agg_col):
             for inp in self.input_dts
         ]
 
-        inp = _make_agg_of_agg(inp_ctes, "update_ts")
+        common_keys = [
+            k for k, v in all_input_keys_counts.items() if v == len(self.input_dts)
+        ]
+
+        common_transform_keys = [k for k in self.transform_keys if k in common_keys]
+
+        inp = make_agg_of_agg(
+            ds=ds,
+            transform_keys=self.transform_keys,
+            common_transform_keys=common_transform_keys,
+            ctes=inp_ctes,
+            agg_col="update_ts",
+        )
 
         tr_tbl = self.meta_table.sql_table
         out: Any = (
@@ -983,7 +937,7 @@ def process_batch_dts(
 @dataclass
 class BatchTransform(PipelineStep):
     func: BatchTransformFunc
-    inputs: List[TableOrName]
+    inputs: List[PipelineInput]
     outputs: List[TableOrName]
     chunk_size: int = 1000
     kwargs: Optional[Dict[str, Any]] = None
@@ -994,18 +948,35 @@ class BatchTransform(PipelineStep):
     order_by: Optional[List[str]] = None
     order: Literal["asc", "desc"] = "asc"
 
+    def pipeline_input_to_compute_input(
+        self, ds: DataStore, catalog: Catalog, input: PipelineInput
+    ) -> ComputeInput:
+        if isinstance(input, Required):
+            return ComputeInput(
+                dt=catalog.get_datatable(ds, input.table),
+                join_type="inner",
+            )
+        elif isinstance(input, JoinSpec):
+            # This should not happen, but just in case
+            return ComputeInput(
+                dt=catalog.get_datatable(ds, input.table),
+                join_type="full",
+            )
+        else:
+            return ComputeInput(dt=catalog.get_datatable(ds, input), join_type="full")
+
     def build_compute(self, ds: DataStore, catalog: Catalog) -> List[ComputeStep]:
-        input_dts = [catalog.get_datatable(ds, name) for name in self.inputs]
+        input_dts = [
+            self.pipeline_input_to_compute_input(ds, catalog, input)
+            for input in self.inputs
+        ]
         output_dts = [catalog.get_datatable(ds, name) for name in self.outputs]
 
         return [
             BatchTransformStep(
                 ds=ds,
                 name=f"{self.func.__name__}",
-                input_dts=[
-                    ComputeInput(dt=input_dts, join_type="full")
-                    for input_dts in input_dts
-                ],
+                input_dts=input_dts,
                 output_dts=output_dts,
                 func=self.func,
                 kwargs=self.kwargs,
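
In practice the new PipelineInput handling means a plain table name keeps the previous full-outer-join behaviour, while Required(...) requests an inner join, so the step only processes index rows that have a match in that input. A sketch under assumed names (the tables and the transform function are hypothetical; BatchTransform, Required and the join semantics are from this diff):

import pandas as pd

from datapipe.step.batch_transform import BatchTransform
from datapipe.types import Required


def enrich(events_df: pd.DataFrame, users_df: pd.DataFrame) -> pd.DataFrame:
    # Hypothetical transform body.
    return events_df.merge(users_df, on="user_id")


step = BatchTransform(
    func=enrich,
    inputs=[
        "tbl_events",           # plain name -> join_type="full", as before
        Required("tbl_users"),  # Required -> join_type="inner"
    ],
    outputs=["tbl_enriched"],
)
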
13 changes: 13 additions & 0 deletions datapipe/types.py
@@ -54,6 +54,19 @@
 TableOrName = Union[str, OrmTable, "Table"]
 
 
+@dataclass
+class JoinSpec:
+    table: TableOrName
+
+
+@dataclass
+class Required(JoinSpec):
+    pass
+
+
+PipelineInput = Union[TableOrName, JoinSpec]
+
+
 @dataclass
 class ChangeList:
     changes: Dict[str, IndexDF] = field(
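
Note that Required subclasses JoinSpec, which is why the isinstance dispatch in pipeline_input_to_compute_input tests Required before the JoinSpec fallback. A quick sketch of the new types (names from this diff, table names hypothetical):

from datapipe.types import JoinSpec, PipelineInput, Required

spec = Required("tbl_users")
assert isinstance(spec, JoinSpec)  # subclass, hence the check order upstream
assert spec.table == "tbl_users"

plain: PipelineInput = "tbl_events"  # a bare name is still a valid input
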
8 changes: 4 additions & 4 deletions tests/test_complex_pipeline.py
@@ -409,10 +409,10 @@ def get_some_prediction_only_on_best_model(
         BatchTransform(
             func=get_some_prediction_only_on_best_model,
             inputs=[
-                "tbl_image",
-                "tbl_image__attribute",
-                "tbl_prediction",
-                "tbl_best_model",
+                "tbl_image",  # image_id
+                "tbl_image__attribute",  # image_id, attribute
+                "tbl_prediction",  # image_id, model_id, prediction__attribute
+                "tbl_best_model",  # model_id
             ],
             outputs=["tbl_output"],
             transform_keys=["image_id", "model_id"],
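
The new inline comments document each input's key columns. Note that tbl_image and tbl_best_model share no key at all, which is exactly the case make_agg_of_agg covers: each transform key is coalesced over only the input CTEs that carry it, and inputs with no common key are combined by a constant-true full outer join.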
