From b91b83bd09fa66721d33e4c8b285c051778eacb4 Mon Sep 17 00:00:00 2001
From: Alexander Kozlov
Date: Sat, 31 Aug 2024 14:54:39 +0000
Subject: [PATCH 1/2] added new unpleasant tests

---
 tests/test_complex_pipeline.py | 152 ++++++++++++++++++++++++++++++++-
 1 file changed, 150 insertions(+), 2 deletions(-)

diff --git a/tests/test_complex_pipeline.py b/tests/test_complex_pipeline.py
index e0ea30ab..a599c974 100644
--- a/tests/test_complex_pipeline.py
+++ b/tests/test_complex_pipeline.py
@@ -1,14 +1,17 @@
+from typing import cast
+
 import pandas as pd
 from sqlalchemy import Column
 from sqlalchemy.sql.sqltypes import Integer, String
 
 from datapipe.compute import Catalog, Pipeline, Table, build_compute, run_steps
 from datapipe.datatable import DataStore
+from datapipe.step.batch_generate import BatchGenerate
 from datapipe.step.batch_transform import BatchTransform
 from datapipe.store.database import TableStoreDB
 from datapipe.types import IndexDF
 
-from .util import assert_datatable_equal
+from .util import assert_datatable_equal, assert_df_equal
 
 TEST__ITEM = pd.DataFrame(
     {
@@ -152,7 +155,7 @@ def complex_function(
         TEST__PIPELINE,
         TEST__PREDICTION,
         TEST__KEYPOINT,
-        idx=pd.DataFrame(columns=["item_id", "pipeline_id"]),
+        idx=cast(IndexDF, pd.DataFrame(columns=["item_id", "pipeline_id"])),
     )
     run_steps(ds, steps)
     assert_datatable_equal(ds.get_table("output"), TEST_RESULT)
@@ -287,3 +290,148 @@ def train(
     assert len(
         ds.get_table("pipeline__is_trained_on__frozen_dataset").get_data()
     ) == len(TEST__FROZEN_DATASET) * len(TEST__TRAIN_CONFIG)
+
+
+def complex_transform_with_many_recordings_but_take_only_val_items(dbconn, train_N: int, val_N: int):
+    ds = DataStore(dbconn, create_meta_table=True)
+    catalog = Catalog(
+        {
+            "tbl_image": Table(
+                store=TableStoreDB(
+                    dbconn,
+                    "tbl_image",
+                    [
+                        Column("image_id", Integer, primary_key=True),
+                        Column("image__attribute", Integer),
+                    ],
+                    True,
+                )
+            ),
+            "tbl_subset__has__image": Table(
+                store=TableStoreDB(
+                    dbconn,
+                    "tbl_subset__has__image",
+                    [
+                        Column("image_id", Integer, primary_key=True),
+                        Column("subset_id", Integer, primary_key=True),
+                    ],
+                    True,
+                )
+            ),
+            "tbl_model": Table(
+                store=TableStoreDB(
+                    dbconn,
+                    "tbl_best_model",
+                    [
+                        Column("model_id", Integer, primary_key=True),
+                        Column("model__attribute", Integer),
+                    ],
+                    True,
+                )
+            ),
+            "tbl_prediction": Table(
+                store=TableStoreDB(
+                    dbconn,
+                    "tbl_prediction",
+                    [
+                        Column("image_id", Integer, primary_key=True),
+                        Column("model_id", Integer, primary_key=True),
+                        Column("prediction__attribute", Integer),
+                    ],
+                    True,
+                )
+            ),
+        }
+    )
+
+    def gen_tbls(df1, df2, df3):
+        yield df1, df2, df3
+
+    test_df__image = pd.DataFrame({
+        "image_id": range(train_N + val_N),
+        "image__attribute": range(train_N + val_N),
+    })
+    test_df__subset__has__image = pd.DataFrame(
+        {
+            "image_id": list(range(train_N)) + list(range(train_N, train_N + val_N)),
+            "subset_id": ["train"] * train_N + ["val"] * val_N,
+        }
+    )
+    test_df__model = pd.DataFrame(
+        {
+            "model_id": [0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
+            "model__attribute": [100 * model_id for model_id in range(10)],
+        }
+    )
+
+    def inference_model(
+        df__image: pd.DataFrame,
+        df__image__has__subset: pd.DataFrame,
+        df__model: pd.DataFrame,
+    ):
+        df__image__has__subset = df__image__has__subset[
+            df__image__has__subset["subset_id"] == "val"  # FIXME: this filtering should live outside the function
+        ]
+        df__image = pd.merge(df__image, df__image__has__subset)
+        # cross join: one prediction row per (image, model) pair
+        df__prediction = pd.merge(df__image, df__model, how="cross")
+        df__prediction["prediction__attribute"] = (
+            df__prediction["image__attribute"] + df__prediction["model__attribute"]
+        )
+        return df__prediction[["image_id", "model_id", "prediction__attribute"]]
+
+    pipeline = Pipeline(
+        [
+            BatchGenerate(
+                func=gen_tbls,
+                outputs=[
+                    "tbl_image",
+                    "tbl_subset__has__image",
+                    "tbl_model",
+                ],
+                kwargs=dict(
+                    df1=test_df__image,
+                    df2=test_df__subset__has__image,
+                    df3=test_df__model,
+                ),
+            ),
+            BatchTransform(
+                func=inference_model,
+                inputs=[
+                    "tbl_image",  # ["image_id"]
+                    "tbl_subset__has__image",  # ["image_id", "subset_id"]
+                    "tbl_model",  # ["model_id"]
+                ],
+                outputs=["tbl_prediction"],  # ["image_id", "model_id", "prediction__attribute"]
+                transform_keys=["image_id", "model_id"],
+                chunk_size=100,
+            ),
+        ]
+    )
+    steps = build_compute(ds, catalog, pipeline)
+    run_steps(ds, steps)
+    test__df_prediction = pd.DataFrame(
+        {
+            "image_id": list(range(train_N, train_N + val_N)) * 10,
+            "model_id": [model_id for model_id in range(10) for _ in range(val_N)],
+            "prediction__attribute": [
+                (model_id * 100) + x
+                for model_id in range(10)
+                for x in range(train_N, train_N + val_N)
+            ],
+        }
+    )
+    assert_df_equal(
+        ds.get_table("tbl_prediction").get_data(),
+        test__df_prediction,
+        index_cols=["image_id", "model_id"],
+    )
+
+
+def test_complex_transform_with_many_recordings_but_take_only_val_items_trainN100_valN100(dbconn):
+    complex_transform_with_many_recordings_but_take_only_val_items(dbconn, train_N=100, val_N=100)
+
+
+def test_complex_transform_with_many_recordings_but_take_only_val_items_trainN10000_valN1000(dbconn):
+    complex_transform_with_many_recordings_but_take_only_val_items(dbconn, train_N=10000, val_N=1000)
+
+
+def test_complex_transform_with_many_recordings_but_take_only_val_items_trainN100000_valN1000(dbconn):
+    complex_transform_with_many_recordings_but_take_only_val_items(dbconn, train_N=100000, val_N=1000)

From 706a02b3878dfdade502932c4143936503ab2671 Mon Sep 17 00:00:00 2001
From: Alexander Kozlov
Date: Sat, 31 Aug 2024 14:59:10 +0000
Subject: [PATCH 2/2] fix subset_id string type

---
 tests/test_complex_pipeline.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/test_complex_pipeline.py b/tests/test_complex_pipeline.py
index a599c974..2db09190 100644
--- a/tests/test_complex_pipeline.py
+++ b/tests/test_complex_pipeline.py
@@ -313,7 +313,7 @@ def complex_transform_with_many_recordings_but_take_only_val_items(dbconn, train
                     "tbl_subset__has__image",
                     [
                         Column("image_id", Integer, primary_key=True),
-                        Column("subset_id", Integer, primary_key=True),
+                        Column("subset_id", String, primary_key=True),
                     ],
                     True,
                 )