Skip to content

Commit

Permalink
introduce JoinType wrapper for input_dts
Browse files Browse the repository at this point in the history
  • Loading branch information
elephantum committed Aug 25, 2024
1 parent 6f55b3c commit feb7b9b
Show file tree
Hide file tree
Showing 10 changed files with 196 additions and 143 deletions.
24 changes: 15 additions & 9 deletions datapipe/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Iterable, List, Optional, Tuple
from typing import Dict, Iterable, List, Literal, Optional, Tuple

from opentelemetry import trace

Expand Down Expand Up @@ -87,6 +87,12 @@ class StepStatus:
changed_idx_count: int


@dataclass
class JoinType:
    """Wrapper pairing an input ``DataTable`` with the join semantics to use
    when that table participates as an input of a compute step.

    Introduced so ``input_dts`` lists carry per-input join information
    instead of bare ``DataTable`` objects; callers access the table itself
    via the ``dt`` attribute.
    """

    # The wrapped input table.
    dt: DataTable
    # Join mode for this input; defaults to "full".
    # NOTE(review): presumably "inner"/"full" select inner- vs full-join
    # semantics when changed indices from several inputs are merged —
    # the visible code only ever constructs join_type="full"; confirm
    # against the consumers of this field.
    join_type: Literal["inner", "full"] = "full"


class ComputeStep:
"""
Шаг вычислений в графе вычислений.
Expand All @@ -106,7 +112,7 @@ class ComputeStep:
def __init__(
self,
name: str,
input_dts: List[DataTable],
input_dts: List[JoinType],
output_dts: List[DataTable],
labels: Optional[Labels] = None,
executor_config: Optional[ExecutorConfig] = None,
Expand All @@ -121,7 +127,7 @@ def get_name(self) -> str:
ss = [
self.__class__.__name__,
self._name,
*[i.name for i in self.input_dts],
*[i.dt.name for i in self.input_dts],
*[o.name for o in self.output_dts],
]

Expand All @@ -143,7 +149,7 @@ def get_status(self, ds: DataStore) -> StepStatus:

# TODO: move to lints
def validate(self) -> None:
inp_p_keys_arr = [set(inp.primary_keys) for inp in self.input_dts if inp]
inp_p_keys_arr = [set(inp.dt.primary_keys) for inp in self.input_dts if inp]
out_p_keys_arr = [set(out.primary_keys) for out in self.output_dts if out]

inp_p_keys = set.intersection(*inp_p_keys_arr) if len(inp_p_keys_arr) else set()
Expand All @@ -153,7 +159,7 @@ def validate(self) -> None:
key_to_column_type_inp = {
column.name: type(column.type)
for inp in self.input_dts
for column in inp.primary_schema
for column in inp.dt.primary_schema
if column.name in join_keys
}
key_to_column_type_out = {
Expand Down Expand Up @@ -271,11 +277,11 @@ def run_steps(
) -> None:
for step in steps:
with tracer.start_as_current_span(
f"{step.get_name()} {[i.name for i in step.input_dts]} -> {[i.name for i in step.output_dts]}"
f"{step.get_name()} {[i.dt.name for i in step.input_dts]} -> {[i.name for i in step.output_dts]}"
):
logger.info(
f"Running {step.get_name()} "
f"{[i.name for i in step.input_dts]} -> {[i.name for i in step.output_dts]}"
f"{[i.dt.name for i in step.input_dts]} -> {[i.name for i in step.output_dts]}"
)

step.run_full(ds=ds, run_config=run_config, executor=executor)
Expand Down Expand Up @@ -323,11 +329,11 @@ def run_steps_changelist(
for step in steps:
with tracer.start_as_current_span(
f"{step.get_name()} "
f"{[i.name for i in step.input_dts]} -> {[i.name for i in step.output_dts]}"
f"{[i.dt.name for i in step.input_dts]} -> {[i.name for i in step.output_dts]}"
):
logger.info(
f"Running {step.get_name()} "
f"{[i.name for i in step.input_dts]} -> {[i.name for i in step.output_dts]}"
f"{[i.dt.name for i in step.input_dts]} -> {[i.name for i in step.output_dts]}"
)

if isinstance(step, BaseBatchTransformStep):
Expand Down
13 changes: 11 additions & 2 deletions datapipe/datatable.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,18 @@
import copy
import logging
import math
import time
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, cast
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterator,
List,
Literal,
Optional,
Tuple,
cast,
)

import pandas as pd
from cityhash import CityHash32
Expand Down
32 changes: 17 additions & 15 deletions datapipe/step/batch_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
from sqlalchemy.sql.expression import select
from tqdm_loggable.auto import tqdm

from datapipe.compute import Catalog, ComputeStep, PipelineStep, StepStatus
from datapipe.compute import Catalog, ComputeStep, JoinType, PipelineStep, StepStatus
from datapipe.datatable import DataStore, DataTable, MetaTable
from datapipe.executor import Executor, ExecutorConfig, SingleThreadExecutor
from datapipe.run_config import LabelDict, RunConfig
Expand Down Expand Up @@ -299,7 +299,7 @@ def __init__(
self,
ds: DataStore,
name: str,
input_dts: List[DataTable],
input_dts: List[JoinType],
output_dts: List[DataTable],
transform_keys: Optional[List[str]] = None,
chunk_size: int = 1000,
Expand All @@ -325,8 +325,8 @@ def __init__(
transform_keys = list(transform_keys)

self.transform_keys, self.transform_schema = self.compute_transform_schema(
[i.meta_table for i in input_dts],
[i.meta_table for i in output_dts],
[inp.dt.meta_table for inp in input_dts],
[out.meta_table for out in output_dts],
transform_keys,
)

Expand Down Expand Up @@ -388,7 +388,7 @@ def _build_changed_idx_sql(
run_config: Optional[RunConfig] = None, # TODO remove
) -> Tuple[Iterable[str], Any]:
all_input_keys_counts: Dict[str, int] = {}
for col in itertools.chain(*[dt.primary_schema for dt in self.input_dts]):
for col in itertools.chain(*[inp.dt.primary_schema for inp in self.input_dts]):
all_input_keys_counts[col.name] = all_input_keys_counts.get(col.name, 0) + 1

common_keys = [
Expand Down Expand Up @@ -450,12 +450,12 @@ def _make_agg_of_agg(ctes, agg_col):
return sql.cte(name=f"all__{agg_col}")

inp_ctes = [
tbl.get_agg_cte(
inp.dt.get_agg_cte(
transform_keys=self.transform_keys,
filters_idx=filters_idx,
run_config=run_config,
)
for tbl in self.input_dts
for inp in self.input_dts
]

inp = _make_agg_of_agg(inp_ctes, "update_ts")
Expand Down Expand Up @@ -639,8 +639,8 @@ def get_change_list_process_ids(
changes = [pd.DataFrame(columns=self.transform_keys)]

for inp in self.input_dts:
if inp.name in change_list.changes:
idx = change_list.changes[inp.name]
if inp.dt.name in change_list.changes:
idx = change_list.changes[inp.dt.name]
if any([key not in idx.columns for key in self.transform_keys]):
_, sql = self._build_changed_idx_sql(
ds=ds,
Expand Down Expand Up @@ -761,7 +761,7 @@ def get_batch_input_dfs(
idx: IndexDF,
run_config: Optional[RunConfig] = None,
) -> List[DataDF]:
return [inp.get_data(idx) for inp in self.input_dts]
return [inp.dt.get_data(idx) for inp in self.input_dts]

def process_batch_dfs(
self,
Expand Down Expand Up @@ -922,7 +922,7 @@ def build_compute(self, ds: DataStore, catalog: Catalog) -> List[ComputeStep]:
ds=ds,
name=f"{self.func.__name__}",
func=self.func,
input_dts=input_dts,
input_dts=[JoinType(dt=inp, join_type="full") for inp in input_dts],
output_dts=output_dts,
kwargs=self.kwargs,
transform_keys=self.transform_keys,
Expand All @@ -938,7 +938,7 @@ def __init__(
ds: DataStore,
name: str,
func: DatatableBatchTransformFunc,
input_dts: List[DataTable],
input_dts: List[JoinType],
output_dts: List[DataTable],
kwargs: Optional[Dict] = None,
transform_keys: Optional[List[str]] = None,
Expand Down Expand Up @@ -967,7 +967,7 @@ def process_batch_dts(
return self.func(
ds=ds,
idx=idx,
input_dts=self.input_dts,
input_dts=[inp.dt for inp in self.input_dts],
run_config=run_config,
kwargs=self.kwargs,
)
Expand Down Expand Up @@ -995,7 +995,9 @@ def build_compute(self, ds: DataStore, catalog: Catalog) -> List[ComputeStep]:
BatchTransformStep(
ds=ds,
name=f"{self.func.__name__}", # type: ignore # mypy bug: https://github.com/python/mypy/issues/10976
input_dts=input_dts,
input_dts=[
JoinType(dt=input_dts, join_type="full") for input_dts in input_dts
],
output_dts=output_dts,
func=self.func,
kwargs=self.kwargs,
Expand All @@ -1016,7 +1018,7 @@ def __init__(
ds: DataStore,
name: str,
func: BatchTransformFunc,
input_dts: List[DataTable],
input_dts: List[JoinType],
output_dts: List[DataTable],
kwargs: Optional[Dict[str, Any]] = None,
transform_keys: Optional[List[str]] = None,
Expand Down
11 changes: 7 additions & 4 deletions datapipe/step/datatable_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from opentelemetry import trace

from datapipe.compute import Catalog, ComputeStep, PipelineStep
from datapipe.compute import Catalog, ComputeStep, JoinType, PipelineStep
from datapipe.datatable import DataStore, DataTable
from datapipe.executor import Executor
from datapipe.run_config import RunConfig
Expand Down Expand Up @@ -32,7 +32,7 @@ class DatatableTransformStep(ComputeStep):
def __init__(
self,
name: str,
input_dts: List[DataTable],
input_dts: List[JoinType],
output_dts: List[DataTable],
func: DatatableTransformFunc,
kwargs: Optional[Dict] = None,
Expand Down Expand Up @@ -76,7 +76,7 @@ def run_full(
try:
self.func(
ds=ds,
input_dts=self.input_dts,
input_dts=[inp.dt for inp in self.input_dts],
output_dts=self.output_dts,
run_config=run_config,
kwargs=self.kwargs,
Expand All @@ -101,7 +101,10 @@ def build_compute(self, ds: DataStore, catalog: Catalog) -> List["ComputeStep"]:
return [
DatatableTransformStep(
name=self.func.__name__,
input_dts=[catalog.get_datatable(ds, i) for i in self.inputs],
input_dts=[
JoinType(dt=catalog.get_datatable(ds, i), join_type="full")
for i in self.inputs
],
output_dts=[catalog.get_datatable(ds, i) for i in self.outputs],
func=self.func,
kwargs=self.kwargs,
Expand Down
2 changes: 2 additions & 0 deletions docs/source/SUMMARY.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
# Introduction

- [Introduction](./introduction.md)
- TODO: Concepts
- TODO: How merging works

# Command Line Interface

Expand Down
11 changes: 9 additions & 2 deletions tests/test_batch_transform_scheduling.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pandas as pd
from sqlalchemy import Column, Integer

from datapipe.compute import JoinType
from datapipe.datatable import DataStore
from datapipe.step.batch_transform import BatchTransformStep
from datapipe.store.database import TableStoreDB
Expand Down Expand Up @@ -42,7 +43,9 @@ def id_func(df):
ds=ds,
name="step",
func=id_func,
input_dts=[tbl1],
input_dts=[
JoinType(dt=tbl1, join_type="full"),
],
output_dts=[tbl2],
)

Expand Down Expand Up @@ -148,7 +151,11 @@ def test_aux_input(dbconn) -> None:
ds=ds,
name="step",
func=lambda items1, items2, aux: items1,
input_dts=[tbl_items1, tbl_items2, tbl_aux],
input_dts=[
JoinType(dt=tbl_items1, join_type="full"),
JoinType(dt=tbl_items2, join_type="full"),
JoinType(dt=tbl_aux, join_type="full"),
],
output_dts=[tbl_out],
transform_keys=["id"],
)
Expand Down
Loading

0 comments on commit feb7b9b

Please sign in to comment.