Commit 0e32abb

seems to be working

elephantum committed Aug 31, 2024
1 parent e5a01ca commit 0e32abb

Showing 2 changed files with 48 additions and 33 deletions.
datapipe/meta/sql_meta.py (42 additions, 27 deletions)
```diff
@@ -60,7 +60,7 @@ def __init__(
         primary_schema: DataSchema,
         meta_schema: MetaSchema = [],
         create_table: bool = False,
-    ):
+    ) -> None:
         self.dbconn = dbconn
         self.name = name
```
```diff
@@ -691,22 +691,28 @@ def sql_apply_filters_idx_to_subquery(
     return sql
 
 
+@dataclass
+class ComputeInputCTE:
+    cte: Any
+    keys: List[str]
+    join_type: Literal["inner", "full"]
+
+
 def _make_agg_of_agg(
     ds: "DataStore",
     transform_keys: List[str],
-    common_transform_keys: List[str],
     agg_col: str,
-    ctes: List[Tuple[List[str], Any]],
+    ctes: List[ComputeInputCTE],
 ) -> Any:
     assert len(ctes) > 0
 
     if len(ctes) == 1:
-        return ctes[0][1]
+        return ctes[0].cte
 
     coalesce_keys = []
 
     for key in transform_keys:
-        ctes_with_key = [subq for (subq_keys, subq) in ctes if key in subq_keys]
+        ctes_with_key = [cte.cte for cte in ctes if key in cte.keys]
 
         if len(ctes_with_key) == 0:
             raise ValueError(f"Key {key} not found in any of the input tables")
```
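The new `ComputeInputCTE` dataclass replaces the bare `(keys, cte)` tuples that `_make_agg_of_agg` previously received, and adds a `join_type` field so the SQL assembly below can tell required inputs (`"inner"`) from optional ones (`"full"`). A minimal sketch of how one such record might be built; the table and CTE here are hypothetical stand-ins for a real meta-table aggregate:

```python
from dataclasses import dataclass
from typing import Any, List, Literal

import sqlalchemy as sa


@dataclass
class ComputeInputCTE:
    cte: Any                             # per-input aggregate CTE
    keys: List[str]                      # transform keys this input provides
    join_type: Literal["inner", "full"]  # "inner" = input is required


# Hypothetical input table, standing in for a real per-input aggregate.
tbl = sa.table("tbl_image", sa.column("image_id"), sa.column("update_ts"))
cte = sa.select(tbl.c.image_id, tbl.c.update_ts).cte("tbl_image__agg")

inp = ComputeInputCTE(cte=cte, keys=["image_id"], join_type="inner")
print(inp.keys, inp.join_type)  # ['image_id'] inner
```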
```diff
@@ -719,29 +725,41 @@ def _make_agg_of_agg(
         )
 
     agg = sa.func.max(
-        ds.meta_dbconn.func_greatest(*[subq.c[agg_col] for (_, subq) in ctes])
+        ds.meta_dbconn.func_greatest(*[cte.cte.c[agg_col] for cte in ctes])
     ).label(agg_col)
 
-    _, first_cte = ctes[0]
+    first_cte = ctes[0].cte
 
     sql = sa.select(*coalesce_keys + [agg]).select_from(first_cte)
 
-    for _, cte in ctes[1:]:
-        if len(common_transform_keys) > 0:
+    prev_ctes = [ctes[0]]
+
+    for cte in ctes[1:]:
+        onclause = []
+
+        for prev_cte in prev_ctes:
+            for key in cte.keys:
+                if key in prev_cte.keys:
+                    onclause.append(prev_cte.cte.c[key] == cte.cte.c[key])
+
+        if len(onclause) > 0:
             sql = sql.outerjoin(
-                cte,
-                onclause=sa.and_(
-                    *[first_cte.c[key] == cte.c[key] for key in common_transform_keys]
-                ),
+                cte.cte,
+                onclause=sa.and_(*onclause),
                 full=True,
             )
         else:
             sql = sql.outerjoin(
-                cte,
+                cte.cte,
                 onclause=sa.literal(True),
                 full=True,
             )
 
+        if cte.join_type == "inner":
+            sql = sql.where(sa.and_(*[cte.cte.c[key].isnot(None) for key in cte.keys]))
+
+        prev_ctes.append(cte)
+
     sql = sql.group_by(*coalesce_keys)
 
     return sql.cte(name=f"all__{agg_col}")
```
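The rewritten loop replaces the old strategy (join every CTE to the *first* one on the precomputed `common_transform_keys`) with a pairwise one: each input CTE is full-outer-joined on whatever keys it shares with *any* previously joined CTE, and a required input (`join_type == "inner"`) is then enforced by filtering out rows whose keys came back `NULL` from the full outer join. The following standalone sketch (not datapipe's actual code; the table names and the plain `sa.func.greatest` are illustrative stand-ins for `ds.meta_dbconn.func_greatest`) shows the shape of the generated SQL for two inputs sharing `image_id`, with the second input required:

```python
import sqlalchemy as sa

a = sa.table("agg_a", sa.column("image_id"), sa.column("update_ts"))
b = sa.table("agg_b", sa.column("image_id"), sa.column("update_ts"))
cte_a = sa.select(a.c.image_id, a.c.update_ts).cte("cte_a")
cte_b = sa.select(b.c.image_id, b.c.update_ts).cte("cte_b")

# COALESCE the join key so rows present on either side keep a value.
key = sa.func.coalesce(cte_a.c.image_id, cte_b.c.image_id).label("image_id")
# Aggregate the freshest update_ts across inputs, as _make_agg_of_agg does.
upd = sa.func.max(
    sa.func.greatest(cte_a.c.update_ts, cte_b.c.update_ts)
).label("update_ts")

sql = (
    sa.select(key, upd)
    .select_from(cte_a)
    .outerjoin(cte_b, onclause=cte_a.c.image_id == cte_b.c.image_id, full=True)
    # join_type == "inner" for cte_b: drop rows that cte_b does not have.
    .where(cte_b.c.image_id.isnot(None))
    .group_by(key)
)

print(sql)  # SELECT ... FULL OUTER JOIN ... WHERE ... IS NOT NULL GROUP BY ...
```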
```diff
@@ -761,23 +779,20 @@ def build_changed_idx_sql(
     for col in itertools.chain(*[inp.dt.primary_schema for inp in input_dts]):
         all_input_keys_counts[col.name] = all_input_keys_counts.get(col.name, 0) + 1
 
-    inp_ctes = [
-        inp.dt.meta_table.get_agg_cte(
+    inp_ctes = []
+    for inp in input_dts:
+        keys, cte = inp.dt.meta_table.get_agg_cte(
             transform_keys=transform_keys,
             filters_idx=filters_idx,
             run_config=run_config,
         )
-        for inp in input_dts
-    ]
+        inp_ctes.append(ComputeInputCTE(cte=cte, keys=keys, join_type=inp.join_type))
 
     common_keys = [k for k, v in all_input_keys_counts.items() if v == len(input_dts)]
 
-    common_transform_keys = [k for k in transform_keys if k in common_keys]
-
-    inp = _make_agg_of_agg(
+    agg_of_aggs = _make_agg_of_agg(
         ds=ds,
         transform_keys=transform_keys,
-        common_transform_keys=common_transform_keys,
         ctes=inp_ctes,
         agg_col="update_ts",
     )
```
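Each input's `join_type` is taken from its `ComputeInput` wrapper in the pipeline definition; judging by the test changes below, wrapping a table in `Required(...)` is presumably what sets `join_type="inner"` for that input, while plain string inputs keep the default full-outer behavior.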
```diff
@@ -799,10 +814,10 @@ def build_changed_idx_sql(
     if len(transform_keys) == 0:
         join_onclause_sql: Any = sa.literal(True)
     elif len(transform_keys) == 1:
-        join_onclause_sql = inp.c[transform_keys[0]] == out.c[transform_keys[0]]
+        join_onclause_sql = agg_of_aggs.c[transform_keys[0]] == out.c[transform_keys[0]]
     else:  # len(transform_keys) > 1:
         join_onclause_sql = sa.and_(
-            *[inp.c[key] == out.c[key] for key in transform_keys]
+            *[agg_of_aggs.c[key] == out.c[key] for key in transform_keys]
         )
 
     sql = (
```
```diff
@@ -811,11 +826,11 @@ def build_changed_idx_sql(
             # ...with an empty transform_keys
             sa.literal(1).label("_datapipe_dummy"),
             *[
-                sa.func.coalesce(inp.c[key], out.c[key]).label(key)
+                sa.func.coalesce(agg_of_aggs.c[key], out.c[key]).label(key)
                 for key in transform_keys
             ],
         )
-        .select_from(inp)
+        .select_from(agg_of_aggs)
         .outerjoin(
             out,
             onclause=join_onclause_sql,
```
```diff
@@ -825,7 +840,7 @@ def build_changed_idx_sql(
             sa.or_(
                 sa.and_(
                     out.c.is_success == True,  # noqa
-                    inp.c.update_ts > out.c.process_ts,
+                    agg_of_aggs.c.update_ts > out.c.process_ts,
                 ),
                 out.c.is_success != True,  # noqa
                 out.c.process_ts == None,  # noqa
```
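Taken together, this `WHERE` clause selects the index rows that still need processing: some input's `update_ts` is newer than the recorded `process_ts`, the last run was not successful, or the row has never been processed at all (`process_ts IS NULL`).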
tests/test_complex_pipeline.py (6 additions, 6 deletions)
```diff
@@ -10,7 +10,7 @@
 from datapipe.step.batch_generate import BatchGenerate
 from datapipe.step.batch_transform import BatchTransform
 from datapipe.store.database import TableStoreDB
-from datapipe.types import IndexDF
+from datapipe.types import IndexDF, Required
 
 from .util import assert_datatable_equal, assert_df_equal
```
```diff
@@ -411,10 +411,12 @@ def get_some_prediction_only_on_best_model(
     BatchTransform(
         func=get_some_prediction_only_on_best_model,
         inputs=[
-            "tbl_image",  # image_id
+            Required("tbl_image"),  # image_id
             "tbl_image__attribute",  # image_id, attribute
-            "tbl_prediction",  # image_id, model_id, prediction__attribite
-            "tbl_best_model",  # model_id
+            Required(
+                "tbl_prediction"
+            ),  # image_id, model_id, prediction__attribite
+            Required("tbl_best_model"),  # model_id
         ],
         outputs=["tbl_output"],
         transform_keys=["image_id", "model_id"],
```
```diff
@@ -437,11 +439,9 @@ def test_complex_transform_with_many_recordings_N100(dbconn):
     complex_transform_with_many_recordings(dbconn, N=100)
 
 
-@pytest.mark.skip(reason="This test is slow")
 def test_complex_transform_with_many_recordings_N1000(dbconn):
     complex_transform_with_many_recordings(dbconn, N=1000)
 
 
-@pytest.mark.skip(reason="This test is slow")
 def test_complex_transform_with_many_recordings_N10000(dbconn):
     complex_transform_with_many_recordings(dbconn, N=10000)
```
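The test now wraps three of the four inputs in `Required(...)`, which (assuming `Required` is what sets `join_type="inner"`, as the sql_meta changes above suggest) means the transform only computes rows for which those inputs actually have data, while the unwrapped `tbl_image__attribute` stays optional. A toy, datapipe-free sketch of that semantics with hypothetical data:

```python
# Hypothetical data, illustrating required vs. optional inputs only.
tbl_image = {1, 2}                               # image_id (Required)
tbl_image__attribute = {1}                       # image_id (optional)
tbl_prediction = {(1, "a"), (2, "a"), (2, "b")}  # (image_id, model_id) (Required)
tbl_best_model = {"a"}                           # model_id (Required)

to_process = {
    (img, mdl)
    for (img, mdl) in tbl_prediction  # Required: the pair must have a prediction
    if img in tbl_image               # Required: the image must exist
    and mdl in tbl_best_model         # Required: only the best model counts
    # tbl_image__attribute is optional: a missing attribute drops no rows
}
print(to_process)  # {(1, 'a'), (2, 'a')}
```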
