Commit

add tests, part 1

bobokvsky committed Aug 16, 2024
1 parent 50c8f9f commit 0008bf9
Showing 4 changed files with 223 additions and 18 deletions.
31 changes: 18 additions & 13 deletions datapipe/step/batch_transform.py
@@ -542,7 +542,15 @@ def _apply_filters_to_run_config(
             return run_config
         else:
             filters: List[LabelDict]
-            if isinstance(self.filters, list) and all([isinstance(x, dict) for x in self.filters]):
+            if isinstance(self.filters, str):
+                dt = ds.get_table(self.filters)
+                df = dt.get_data()
+                filters = cast(List[LabelDict], df[dt.primary_keys].to_dict(orient="records"))
+            elif isinstance(self.filters, DataTable):
+                filters = cast(List[LabelDict], self.filters.get_data().to_dict(orient="records"))
+            elif isinstance(self.filters, pd.DataFrame):
+                filters = cast(List[LabelDict], self.filters.to_dict(orient="records"))
+            elif isinstance(self.filters, list) and all([isinstance(x, dict) for x in self.filters]):
                 filters = self.filters
             elif isinstance(self.filters, Callable):  # type: ignore
                 filters_func = cast(Callable[..., Union[List[LabelDict], IndexDF]], self.filters)
@@ -556,10 +564,6 @@
                     filters = cast(List[LabelDict], filters_res.to_dict(orient="records"))
                 else:
                     filters = filters_res
-            elif isinstance(self.filters, str):
-                dt = ds.get_table(self.filters)
-                df = dt.get_data()
-                filters = cast(List[LabelDict], df[dt.primary_keys].to_dict(orient="records"))

            if run_config is None:
                return RunConfig(filters=filters)
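
Taken together, the two hunks above move the string case to the front of the chain and add DataTable and DataFrame cases, so every supported shape of filters is normalized to a list of records. A minimal standalone sketch of the same normalization (illustrative only, not part of the commit; the str and DataTable branches are omitted since they need a DataStore):

import pandas as pd

def normalize_filters(filters):
    # hypothetical helper mirroring the branch order in the diff above
    if isinstance(filters, pd.DataFrame):
        return filters.to_dict(orient="records")
    if isinstance(filters, list) and all(isinstance(x, dict) for x in filters):
        return filters
    if callable(filters):
        res = filters()
        return res.to_dict(orient="records") if isinstance(res, pd.DataFrame) else res
    raise TypeError(f"unsupported filters type: {type(filters)}")

print(normalize_filters(pd.DataFrame({"item_id": [0, 1]})))
# [{'item_id': 0}, {'item_id': 1}]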
@@ -627,24 +631,25 @@ def get_full_process_ids(
         )

         # Keys from the filters that need to be added to the result
-        extra_filters: LabelDict
+        extra_filters: Optional[List[LabelDict]] = None
         if run_config is not None:
-            extra_filters = {
+            extra_filters = [{
                 k: v
-                for filter in run_config.filters
                 for k, v in filter.items()
                 if k not in join_keys
-            }
-        else:
-            extra_filters = {}
+            } for filter in run_config.filters]

         def alter_res_df():
             with ds.meta_dbconn.con.begin() as con:
                 for df in pd.read_sql_query(u1, con=con, chunksize=chunk_size):
                     df = df[self.transform_keys]

-                    for k, v in extra_filters.items():
-                        df[k] = v
+                    if extra_filters is not None:
+                        df__extra_filters = pd.DataFrame(extra_filters)
+                        if set(df__extra_filters.columns).intersection(df.columns):
+                            df = pd.merge(df, df__extra_filters)
+                        else:
+                            df = pd.merge(df, df__extra_filters, how="cross")

                     yield df
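The second hunk replaces per-column assignment of filter values with a merge, so a list of filter records restricts or fans out the id chunks instead of overwriting columns. A standalone sketch of that merge logic with hypothetical values (plain pandas, not part of the commit):

import pandas as pd

df = pd.DataFrame({"model_id": [0, 1]})             # id chunk from the transform-keys query
extra_filters = [{"image_id": 0}, {"image_id": 1}]  # records from run_config.filters

df__extra_filters = pd.DataFrame(extra_filters)
if set(df__extra_filters.columns).intersection(df.columns):
    # shared columns: inner merge keeps only the ids matching the filters
    df = pd.merge(df, df__extra_filters)
else:
    # no shared columns: attach every filter record to every id
    df = pd.merge(df, df__extra_filters, how="cross")

print(df)  # 2 model_ids x 2 image_ids = 4 rows via the cross join
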
4 changes: 1 addition & 3 deletions datapipe/types.py
@@ -9,15 +9,13 @@
     Dict,
     List,
     NewType,
     Optional,
     Set,
     Tuple,
     Type,
     TypeVar,
     Union,
     cast,
 )

 import pandas as pd
 from sqlalchemy import Column
@@ -44,7 +42,7 @@
 TransformResult = Union[DataDF, List[DataDF], Tuple[DataDF, ...]]

 LabelDict = Dict[str, Any]
-Filters = Union[str, IndexDF, List[LabelDict], Callable[..., List[LabelDict]], Callable[..., IndexDF]]
+Filters = Union[str, "DataTable", IndexDF, List[LabelDict], Callable[..., List[LabelDict]], Callable[..., IndexDF]]
 try:
     from sqlalchemy.orm import DeclarativeBase
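With the alias widened, a step's filters argument can now arrive in several shapes. A hedged sketch of the accepted forms (illustrative names and values only; a DataTable instance is also accepted and read via .get_data()):

from typing import Any, Callable, Dict, List

import pandas as pd

LabelDict = Dict[str, Any]

filters_as_name = "filters_data"                   # str: resolved with ds.get_table(...)
filters_as_records: List[LabelDict] = [{"item_id": 0}, {"item_id": 1}]
filters_as_df = pd.DataFrame(filters_as_records)   # IndexDF / plain DataFrame
filters_as_callable: Callable[..., List[LabelDict]] = (
    lambda ds=None, run_config=None: [{"item_id": 0}]
)
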
91 changes: 90 additions & 1 deletion tests/test_complex_pipeline.py
@@ -4,11 +4,12 @@

 from datapipe.compute import Catalog, Pipeline, Table, build_compute, run_steps
 from datapipe.datatable import DataStore
+from datapipe.step.batch_generate import BatchGenerate
 from datapipe.step.batch_transform import BatchTransform
 from datapipe.store.database import TableStoreDB
 from datapipe.types import IndexDF

-from .util import assert_datatable_equal
+from .util import assert_datatable_equal, assert_df_equal

 TEST__ITEM = pd.DataFrame(
     {
@@ -287,3 +288,91 @@ def train(
     assert len(
         ds.get_table("pipeline__is_trained_on__frozen_dataset").get_data()
     ) == len(TEST__FROZEN_DATASET) * len(TEST__TRAIN_CONFIG)


+def test_complex_transform_with_filters(dbconn):
+    ds = DataStore(dbconn, create_meta_table=True)
+    catalog = Catalog({
+        "tbl_image": Table(
+            store=TableStoreDB(
+                dbconn,
+                "tbl_image",
+                [
+                    Column("image_id", Integer, primary_key=True),
+                ],
+                True
+            )
+        ),
+        "tbl_prediction": Table(
+            store=TableStoreDB(
+                dbconn,
+                "tbl_prediction",
+                [
+                    Column("image_id", Integer, primary_key=True),
+                    Column("model_id", Integer, primary_key=True),
+                ],
+                True
+            )
+        ),
+        "tbl_output": Table(
+            store=TableStoreDB(
+                dbconn,
+                "tbl_output",
+                [
+                    Column("model_id", Integer, primary_key=True),
+                    Column("count", Integer),
+                ],
+                True
+            )
+        )
+    })
+
+    def gen_tbl(df):
+        yield df
+
+    test_df__image = pd.DataFrame(
+        {"image_id": [0, 1, 2, 3]}
+    )
+    test_df__prediction = pd.DataFrame({
+        "image_id": [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3],
+        "model_id": [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
+    })
+
+    def count_func(df__image: pd.DataFrame, df__prediction: pd.DataFrame):
+        df__image = pd.merge(df__image, df__prediction, on=["image_id"])
+        print(f"{df__image=}")
+        print(f"{df__prediction=}")
+        df__output = df__image.groupby("model_id").agg(len).reset_index().rename(columns={"image_id": "count"})
+        print(f"{df__output=}")
+        return df__output
+
+    pipeline = Pipeline(
+        [
+            BatchGenerate(
+                func=gen_tbl,
+                outputs=["tbl_image"],
+                kwargs=dict(df=test_df__image),
+            ),
+            BatchGenerate(
+                func=gen_tbl,
+                outputs=["tbl_prediction"],
+                kwargs=dict(df=test_df__prediction),
+            ),
+            BatchTransform(
+                func=count_func,
+                inputs=["tbl_image", "tbl_prediction"],
+                outputs=["tbl_output"],
+                transform_keys=["model_id"],
+                chunk_size=6,
+                # filters=[{"image_id": 0}, {"image_id": 1}, {"image_id": 2}]
+            ),
+        ]
+    )
+    steps = build_compute(ds, catalog, pipeline)
+    run_steps(ds, steps)
+    print(ds.get_table("tbl_output").get_data())
+    # assert_df_equal(
+    #     ds.get_table("tbl_output").get_data(),
+    #     count_func(test_df__image, test_df__prediction),
+    #     index_cols=["model_id"]
+    # )
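
Because the assertion is still commented out in this part-1 commit, the expected output is easy to miss. A hedged sketch of what count_func produces on the fixture data when run standalone (plain pandas, not part of the commit):

import pandas as pd

test_df__image = pd.DataFrame({"image_id": [0, 1, 2, 3]})
test_df__prediction = pd.DataFrame({
    "image_id": [0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3],
    "model_id": [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
})

# same steps as count_func: merge, then count rows per model_id
df = pd.merge(test_df__image, test_df__prediction, on=["image_id"])
df__output = (
    df.groupby("model_id").agg(len).reset_index().rename(columns={"image_id": "count"})
)
print(df__output)  # model_id 0 and model_id 1 each count 6 merged rows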
115 changes: 114 additions & 1 deletion tests/test_core_steps2.py
@@ -4,6 +4,7 @@
 # import pytest

 import time
+from typing import Optional, cast

 import pandas as pd
 from sqlalchemy import Column, Integer, String
@@ -14,7 +15,7 @@
 from datapipe.step.batch_generate import do_batch_generate
 from datapipe.step.batch_transform import BatchTransformStep
 from datapipe.store.database import MetaKey, TableStoreDB
-from datapipe.types import ChangeList, IndexDF
+from datapipe.types import ChangeList, Filters, IndexDF

 from .util import assert_datatable_equal, assert_df_equal
@@ -387,3 +388,115 @@ def update_df(products: pd.DataFrame, items: pd.DataFrame):
     items2_df = merged_df[["item_id", "pipeline_id", "product_id", "a"]]

     assert_df_equal(items2.get_data(), items2_df, index_cols=["item_id", "pipeline_id"])


+PRODUCTS_DF = pd.DataFrame(
+    {
+        "product_id": list(range(2)),
+        "pipeline_id": list(range(2)),
+        "b": range(10, 12),
+    }
+)
+
+ITEMS_DF = pd.DataFrame(
+    {
+        "item_id": list(range(5)) * 2,
+        "pipeline_id": list(range(2)) * 5,
+        "product_id": list(range(2)) * 5,
+        "a": range(10),
+    }
+)
+
+
+def batch_transform_with_filters(dbconn, filters: Filters, ds: Optional[DataStore] = None):
+    if ds is None:
+        ds = DataStore(dbconn, create_meta_table=True)
+
+    products = ds.create_table(
+        "products",
+        table_store=TableStoreDB(dbconn, "products_data", PRODUCTS_SCHEMA, True),
+    )
+
+    items = ds.create_table(
+        "items", table_store=TableStoreDB(dbconn, "items_data", ITEMS_SCHEMA, True)
+    )
+
+    items2 = ds.create_table(
+        "items2", table_store=TableStoreDB(dbconn, "items2_data", ITEMS_SCHEMA, True)
+    )
+
+    products.store_chunk(PRODUCTS_DF, now=0)
+    items.store_chunk(ITEMS_DF, now=0)
+
+    def update_df(products: pd.DataFrame, items: pd.DataFrame):
+        merged_df = pd.merge(items, products, on=["product_id", "pipeline_id"])
+        merged_df["a"] = merged_df.apply(lambda x: x["a"] + x["b"], axis=1)
+
+        return merged_df[["item_id", "pipeline_id", "product_id", "a"]]
+
+    step = BatchTransformStep(
+        ds=ds,
+        name="test",
+        func=update_df,
+        input_dts=[products, items],
+        output_dts=[items2],
+        filters=filters
+    )
+
+    step.run_full(ds)
+
+    merged_df = pd.merge(ITEMS_DF, PRODUCTS_DF, on=["product_id", "pipeline_id"])
+    merged_df["a"] = merged_df.apply(lambda x: x["a"] + x["b"], axis=1)
+
+    items2_df = merged_df[["item_id", "pipeline_id", "product_id", "a"]]
+    items2_df = items2_df[items2_df["item_id"].isin([0, 1, 2])]
+
+    assert_df_equal(items2.get_data(), items2_df, index_cols=["item_id", "pipeline_id"])
+
+
+def test_batch_transform_with_filters_as_str(dbconn):
+    ds = DataStore(dbconn, create_meta_table=True)
+    filters_data = pd.DataFrame([{"item_id": 0}, {"item_id": 1}, {"item_id": 2}])
+    filters = ds.create_table(
+        "filters_data", table_store=TableStoreDB(
+            dbconn, "filters_data", [Column("item_id", Integer, primary_key=True)], True
+        )
+    )
+    filters.store_chunk(filters_data, now=0)
+    batch_transform_with_filters(dbconn, filters="filters_data", ds=ds)
+
+
+def test_batch_transform_with_filters_as_datatable(dbconn):
+    ds = DataStore(dbconn, create_meta_table=True)
+    filters_data = pd.DataFrame([{"item_id": 0}, {"item_id": 1}, {"item_id": 2}])
+    filters = ds.create_table(
+        "filters_data", table_store=TableStoreDB(
+            dbconn, "filters_data", [Column("item_id", Integer, primary_key=True)], True
+        )
+    )
+    filters.store_chunk(filters_data, now=0)
+    batch_transform_with_filters(dbconn, filters=filters)
+
+
+def test_batch_transform_with_filters_as_IndexDF(dbconn):
+    batch_transform_with_filters(
+        dbconn, filters=cast(IndexDF, pd.DataFrame([{"item_id": 0}, {"item_id": 1}, {"item_id": 2}]))
+    )
+
+
+def test_batch_transform_with_filters_as_list_of_dict(dbconn):
+    batch_transform_with_filters(dbconn, filters=[{"item_id": 0}, {"item_id": 1}, {"item_id": 2}])
+
+
+def test_batch_transform_with_filters_as_callable_IndexDF(dbconn):
+    def callable(ds: DataStore, run_config: Optional[RunConfig]):
+        return cast(IndexDF, pd.DataFrame([{"item_id": 0}, {"item_id": 1}, {"item_id": 2}]))
+
+    batch_transform_with_filters(dbconn, filters=callable)
+
+
+def test_batch_transform_with_filters_as_callable_list_of_dict(dbconn):
+    def callable(ds: DataStore, run_config: Optional[RunConfig]):
+        return [{"item_id": 0}, {"item_id": 1}, {"item_id": 2}]
+
+    batch_transform_with_filters(dbconn, filters=callable)

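For reference, the frame the new helper asserts against can be computed standalone. A hedged sketch using the fixture data above (plain pandas, not part of the commit; the vectorized addition replaces the row-wise apply):

import pandas as pd

PRODUCTS_DF = pd.DataFrame({"product_id": [0, 1], "pipeline_id": [0, 1], "b": [10, 11]})
ITEMS_DF = pd.DataFrame({
    "item_id": list(range(5)) * 2,
    "pipeline_id": list(range(2)) * 5,
    "product_id": list(range(2)) * 5,
    "a": range(10),
})

merged_df = pd.merge(ITEMS_DF, PRODUCTS_DF, on=["product_id", "pipeline_id"])
merged_df["a"] = merged_df["a"] + merged_df["b"]
items2_df = merged_df[["item_id", "pipeline_id", "product_id", "a"]]
items2_df = items2_df[items2_df["item_id"].isin([0, 1, 2])]
print(items2_df)  # six rows: item_ids 0-2 across both pipelines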