Commit 40589ef
Merge pull request #333 from epoch8/elephantum/issue332
Add the ability to specify an SQLA table directly instead of a schema in TableStoreDB
elephantum authored Aug 11, 2024
2 parents 4ce1688 + 0a44cde commit 40589ef
Showing 18 changed files with 450 additions and 265 deletions.
2 changes: 1 addition & 1 deletion .vscode/tasks.json
@@ -6,7 +6,7 @@
         {
             "label": "mypy-whole-project",
             "type": "shell",
-            "command": "source .venv/bin/activate; mypy -p datapipe --show-column-numbers --show-error-codes --ignore-missing-imports --namespace-packages",
+            "command": "poetry run mypy -p datapipe --show-column-numbers --show-error-codes --ignore-missing-imports --namespace-packages",
             "presentation": {
                 "echo": true,
                 "reveal": "never",
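The task now goes through poetry run, which resolves the project's virtualenv on its own, so the hard-coded source .venv/bin/activate step is no longer needed.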
46 changes: 43 additions & 3 deletions datapipe/compute.py
@@ -9,8 +9,9 @@
 from datapipe.datatable import DataStore, DataTable
 from datapipe.executor import Executor, ExecutorConfig
 from datapipe.run_config import RunConfig
+from datapipe.store.database import TableStoreDB
 from datapipe.store.table_store import TableStore
-from datapipe.types import ChangeList, IndexDF, Labels
+from datapipe.types import ChangeList, IndexDF, Labels, TableOrName

 logger = logging.getLogger("datapipe.compute")
 tracer = trace.get_tracer("datapipe.compute")
@@ -19,6 +20,7 @@
 @dataclass
 class Table:
     store: TableStore
+    name: Optional[str] = None


 class Catalog:
@@ -36,8 +38,46 @@ def init_all_tables(self, ds: DataStore):
         for name in self.catalog.keys():
             self.get_datatable(ds, name)

-    def get_datatable(self, ds: DataStore, name: str) -> DataTable:
-        return ds.get_or_create_table(name=name, table_store=self.catalog[name].store)
+    def get_datatable(self, ds: DataStore, table: TableOrName) -> DataTable:
+        if isinstance(table, str):
+            assert table in self.catalog, f"Table {table} not found in catalog"
+            return ds.get_or_create_table(
+                name=table, table_store=self.catalog[table].store
+            )
+
+        elif isinstance(table, Table):
+            assert table.name is not None, f"Table name must be specified for {table}"
+
+            if table.name not in self.catalog:
+                self.add_datatable(table.name, table)
+            else:
+                existing_table = self.catalog[table.name]
+                assert existing_table.store == table.store, (
+                    f"Table {table.name} already exists in catalog "
+                    f"with different store {existing_table.store}"
+                )
+
+            return ds.get_or_create_table(name=table.name, table_store=table.store)
+
+        else:
+            table_store = TableStoreDB(ds.meta_dbconn, orm_table=table)
+            if table_store.name not in self.catalog:
+                self.add_datatable(table_store.name, Table(store=table_store))
+            else:
+                existing_table_store = self.catalog[table_store.name].store
+                assert isinstance(existing_table_store, TableStoreDB), (
+                    f"Table {table_store.name} already exists in catalog "
+                    f"with different store {existing_table_store}"
+                )
+
+                assert existing_table_store.data_table == table.__table__, (  # type: ignore
+                    f"Table {table_store.name} already exists in catalog "
+                    f"with different orm_table {existing_table_store.data_table}"
+                )
+
+            return ds.get_or_create_table(
+                name=table_store.name, table_store=table_store
+            )


 @dataclass
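With this change, Catalog.get_datatable accepts a TableOrName: a plain string is looked up in the catalog, a Table wrapper is registered under its (now mandatory) name, and any other value is treated as a SQLAlchemy ORM class and wrapped into a TableStoreDB on the fly. A minimal sketch of the three call shapes; the Events model, connection string, and column names are hypothetical, not part of this PR:

# Hypothetical sketch of the three TableOrName forms get_datatable accepts.
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base

from datapipe.compute import Catalog, Table
from datapipe.datatable import DataStore
from datapipe.store.database import DBConn, TableStoreDB

Base = declarative_base()


class Events(Base):  # hypothetical ORM model
    __tablename__ = "events"
    event_id = Column(Integer, primary_key=True)
    payload = Column(String)


dbconn = DBConn("sqlite:///pipeline.db", sqla_metadata=Base.metadata)
ds = DataStore(dbconn, create_meta_table=True)
catalog = Catalog({})

# ORM class: wrapped into a TableStoreDB and auto-registered as "events"
events_dt = catalog.get_datatable(ds, Events)

# Table wrapper: must now carry an explicit name to be registered
raw = Table(
    store=TableStoreDB(dbconn, "raw", [Column("id", Integer, primary_key=True)]),
    name="raw",
)
raw_dt = catalog.get_datatable(ds, raw)

# Plain string: must already be present in the catalog
events_dt_again = catalog.get_datatable(ds, "events")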
11 changes: 10 additions & 1 deletion datapipe/datatable.py
@@ -518,7 +518,16 @@ def store_chunk(

         with tracer.start_as_current_span("store metadata"):
             self.meta_table.update_rows(
-                cast(MetadataDF, pd.concat([new_meta_df, changed_meta_df]))
+                cast(
+                    MetadataDF,
+                    pd.concat(
+                        [
+                            df
+                            for df in [new_meta_df, changed_meta_df]
+                            if not df.empty
+                        ]
+                    ),
+                )
             )

         if not new_df.empty:
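The list comprehension drops empty frames before concatenation. Since pandas 2.1, pd.concat with empty or all-NA entries emits a FutureWarning and lets the empty frame participate in dtype inference, which can silently change result dtypes. A standalone illustration of the pattern (not datapipe code):

# Standalone illustration of the non-empty-concat pattern used above.
import pandas as pd

new_meta_df = pd.DataFrame({"id": [1, 2], "hash": ["a", "b"]})
changed_meta_df = pd.DataFrame({"id": [], "hash": []})  # nothing changed this run

# Naive version: pd.concat([new_meta_df, changed_meta_df]) warns on pandas >= 2.1
# and can skew the dtypes of the result via the empty frame's object columns.
merged = pd.concat([df for df in [new_meta_df, changed_meta_df] if not df.empty])
print(merged.dtypes)

# Caveat: if every frame is empty, pd.concat([]) raises ValueError
# ("No objects to concatenate"); the call site presumably never hits that case.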
4 changes: 2 additions & 2 deletions datapipe/step/batch_generate.py
@@ -14,7 +14,7 @@
     DatatableTransformFunc,
     DatatableTransformStep,
 )
-from datapipe.types import Labels, TransformResult, cast
+from datapipe.types import Labels, TableOrName, TransformResult, cast

 logger = logging.getLogger("datapipe.step.batch_generate")
 tracer = trace.get_tracer("datapipe.step.batch_generate")
@@ -88,7 +88,7 @@ def do_batch_generate(
 @dataclass
 class BatchGenerate(PipelineStep):
     func: BatchGenerateFunc
-    outputs: List[str]
+    outputs: List[TableOrName]
     kwargs: Optional[Dict] = None
     labels: Optional[Labels] = None
     delete_stale: bool = True
17 changes: 11 additions & 6 deletions datapipe/step/batch_transform.py
@@ -58,6 +58,7 @@
     IndexDF,
     Labels,
     MetaSchema,
+    TableOrName,
     TransformResult,
     data_to_index,
 )
@@ -725,12 +726,16 @@ def store_batch_err(
     ) -> None:
         run_config = self._apply_filters_to_run_config(run_config)

-        logger.error(f"Process batch failed: {str(e)}")
+        idx_records = idx.to_dict(orient="records")
+
+        logger.error(
+            f"Process batch in transform {self.name} on idx {idx_records} failed: {str(e)}"
+        )
         ds.event_logger.log_exception(
             e,
             run_config=RunConfig.add_labels(
                 run_config,
-                {"idx": idx.to_dict(orient="records"), "process_ts": process_ts},
+                {"idx": idx_records, "process_ts": process_ts},
             ),
         )
@@ -901,8 +906,8 @@ def run_idx(
 @dataclass
 class DatatableBatchTransform(PipelineStep):
     func: DatatableBatchTransformFunc
-    inputs: List[str]
-    outputs: List[str]
+    inputs: List[TableOrName]
+    outputs: List[TableOrName]
     chunk_size: int = 1000
     transform_keys: Optional[List[str]] = None
     kwargs: Optional[Dict] = None
@@ -971,8 +976,8 @@ def process_batch_dts(
 @dataclass
 class BatchTransform(PipelineStep):
     func: BatchTransformFunc
-    inputs: List[str]
-    outputs: List[str]
+    inputs: List[TableOrName]
+    outputs: List[TableOrName]
     chunk_size: int = 1000
     kwargs: Optional[Dict[str, Any]] = None
     transform_keys: Optional[List[str]] = None
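With inputs and outputs widened from List[str] to List[TableOrName], a transform can point at ORM models directly instead of catalog names. A hedged sketch, with both models and the transform function invented for illustration:

# Hypothetical sketch: ORM classes passed straight into a pipeline step.
import pandas as pd
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base

from datapipe.step.batch_transform import BatchTransform

Base = declarative_base()


class Events(Base):  # hypothetical input model
    __tablename__ = "events"
    event_id = Column(Integer, primary_key=True)
    payload = Column(String)


class EventsEnriched(Base):  # hypothetical output model
    __tablename__ = "events_enriched"
    event_id = Column(Integer, primary_key=True)
    payload_len = Column(Integer)


def enrich(df: pd.DataFrame) -> pd.DataFrame:
    return df.assign(payload_len=df["payload"].str.len())[["event_id", "payload_len"]]


step = BatchTransform(
    func=enrich,
    inputs=[Events],           # previously only the string name "events" worked
    outputs=[EventsEnriched],
    chunk_size=500,
)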
9 changes: 4 additions & 5 deletions datapipe/step/datatable_transform.py
@@ -8,7 +8,7 @@
 from datapipe.datatable import DataStore, DataTable
 from datapipe.executor import Executor
 from datapipe.run_config import RunConfig
-from datapipe.types import Labels
+from datapipe.types import Labels, TableOrName

 logger = logging.getLogger("datapipe.step.datatable_transform")
 tracer = trace.get_tracer("datapipe.step.datatable_transform")
@@ -25,8 +25,7 @@ def __call__(
         run_config: Optional[RunConfig],
         # Possibly better to pass this as a named variable rather than **
         **kwargs,
-    ) -> None:
-        ...
+    ) -> None: ...


 class DatatableTransformStep(ComputeStep):
@@ -92,8 +91,8 @@ def run_full(
 @dataclass
 class DatatableTransform(PipelineStep):
    func: DatatableTransformFunc
-    inputs: List[str]
-    outputs: List[str]
+    inputs: List[TableOrName]
+    outputs: List[TableOrName]
     check_for_changes: bool = True
     kwargs: Optional[Dict[str, Any]] = None
     labels: Optional[Labels] = None
9 changes: 6 additions & 3 deletions datapipe/step/update_external_table.py
@@ -12,7 +12,7 @@
     DatatableTransformFunc,
     DatatableTransformStep,
 )
-from datapipe.types import Labels, MetadataDF, cast
+from datapipe.types import Labels, MetadataDF, TableOrName, cast

 logger = logging.getLogger("datapipe.step.update_external_table")

@@ -35,7 +35,10 @@ def update_external_table(
     # TODO switch to iterative store_chunk and table.sync_meta_by_process_ts

     table.meta_table.update_rows(
-        cast(MetadataDF, pd.concat([new_meta_df, changed_meta_df]))
+        cast(
+            MetadataDF,
+            pd.concat(df for df in [new_meta_df, changed_meta_df] if not df.empty),
+        ),
     )

     for stale_idx in table.meta_table.get_stale_idx(now, run_config=run_config):
@@ -55,7 +58,7 @@
 class UpdateExternalTable(PipelineStep):
     def __init__(
         self,
-        output: str,
+        output: TableOrName,
         labels: Optional[Labels] = None,
     ) -> None:
         self.output_table_name = output
77 changes: 59 additions & 18 deletions datapipe/store/database.py
@@ -15,7 +15,7 @@
 from datapipe.run_config import RunConfig
 from datapipe.sql_util import sql_apply_idx_filter_to_table, sql_apply_runconfig_filter
 from datapipe.store.table_store import TableStore
-from datapipe.types import DataDF, DataSchema, IndexDF, MetaSchema, TAnyDF
+from datapipe.types import DataDF, DataSchema, IndexDF, MetaSchema, OrmTable, TAnyDF

 logger = logging.getLogger("datapipe.store.database")
 tracer = trace.get_tracer("datapipe.store.database")
@@ -27,15 +27,17 @@ def __init__(
         connstr: str,
         schema: Optional[str] = None,
         create_engine_kwargs: Optional[Dict[str, Any]] = None,
+        sqla_metadata: Optional[MetaData] = None,
     ):
         create_engine_kwargs = create_engine_kwargs or {}
-        self._init(connstr, schema, create_engine_kwargs)
+        self._init(connstr, schema, create_engine_kwargs, sqla_metadata)

     def _init(
         self,
         connstr: str,
         schema: Optional[str],
         create_engine_kwargs: Dict[str, Any],
+        sqla_metadata: Optional[MetaData] = None,
     ) -> None:
         self.connstr = connstr
         self.schema = schema
@@ -75,11 +77,14 @@ def _init(
             poolclass=QueuePool,
             pool_pre_ping=True,
             pool_recycle=3600,
-            **create_engine_kwargs
+            **create_engine_kwargs,
             # pool_size=25,
         )

-        self.sqla_metadata = MetaData(schema=schema)
+        if sqla_metadata is None:
+            self.sqla_metadata = MetaData(schema=schema)
+        else:
+            self.sqla_metadata = sqla_metadata

     def __reduce__(self) -> Tuple[Any, ...]:
         return self.__class__, (
@@ -119,28 +124,64 @@ class TableStoreDB(TableStore):
     def __init__(
         self,
         dbconn: Union["DBConn", str],
-        name: str,
-        data_sql_schema: List[Column],
+        name: Optional[str] = None,
+        data_sql_schema: Optional[List[Column]] = None,
         create_table: bool = False,
+        orm_table: Optional[OrmTable] = None,
     ) -> None:
         if isinstance(dbconn, str):
             self.dbconn = DBConn(dbconn)
         else:
             self.dbconn = dbconn
-        self.name = name
-
-        self.data_sql_schema = data_sql_schema
-
-        self.data_keys = [
-            column.name for column in self.data_sql_schema if not column.primary_key
-        ]
+        if orm_table is not None:
+            assert name is None, "name should be None if orm_table is provided"
+            assert (
+                data_sql_schema is None
+            ), "data_sql_schema should be None if orm_table is provided"
+
+            orm_table__table = orm_table.__table__  # type: ignore
+            self.data_table = cast(Table, orm_table__table)
+
+            self.name = self.data_table.name
+
+            self.data_sql_schema = [
+                Column(
+                    column.name,
+                    column.type,
+                    primary_key=column.primary_key,
+                    nullable=column.nullable,
+                    unique=column.unique,
+                    *column.constraints,
+                )
+                for column in self.data_table.columns
+            ]
+            self.data_keys = [
+                column.name for column in self.data_sql_schema if not column.primary_key
+            ]
-        self.data_table = Table(
-            self.name,
-            self.dbconn.sqla_metadata,
-            *[copy.copy(i) for i in self.data_sql_schema],
-            extend_existing=True,
-        )
+        else:
+            assert (
+                name is not None
+            ), "name should be provided if data_table is not provided"
+            assert (
+                data_sql_schema is not None
+            ), "data_sql_schema should be provided if data_table is not provided"
+
+            self.name = name
+
+            self.data_sql_schema = data_sql_schema
+
+            self.data_keys = [
+                column.name for column in self.data_sql_schema if not column.primary_key
+            ]
+
+            self.data_table = Table(
+                self.name,
+                self.dbconn.sqla_metadata,
+                *[copy.copy(i) for i in self.data_sql_schema],
+                extend_existing=True,
+            )

         if create_table:
             self.data_table.create(self.dbconn.con, checkfirst=True)
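The two construction paths of TableStoreDB are now mutually exclusive: either the classic name plus data_sql_schema pair, or a single orm_table whose __table__ supplies both the name and the column schema. The new sqla_metadata parameter on DBConn lets the caller hand in their declarative Base.metadata, presumably so the ORM tables and the tables datapipe builds live in one MetaData registry. A hedged sketch of both paths; the Items model and connection string are invented:

# Hedged sketch of the two TableStoreDB construction paths after this change.
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import declarative_base

from datapipe.store.database import DBConn, TableStoreDB

Base = declarative_base()


class Items(Base):  # hypothetical ORM model
    __tablename__ = "items"
    item_id = Column(Integer, primary_key=True)
    label = Column(String)


# Share the declarative metadata with DBConn (new optional parameter)
dbconn = DBConn("sqlite:///store.db", sqla_metadata=Base.metadata)

# New path: pass the ORM class; name and schema come from Items.__table__
store_orm = TableStoreDB(dbconn, orm_table=Items, create_table=True)
assert store_orm.name == "items"

# Old path, still supported: explicit name plus a Column list
store_plain = TableStoreDB(
    dbconn,
    name="items_plain",
    data_sql_schema=[
        Column("item_id", Integer, primary_key=True),
        Column("label", String),
    ],
    create_table=True,
)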
2 changes: 1 addition & 1 deletion datapipe/store/filedir.py
@@ -153,7 +153,7 @@ def _pattern_to_match(pat: str) -> str:
 class Replacer:
     def __init__(self, values: List[str]):
         self.counter = -1
-        self.values = values
+        self.values = list(values)

     def __call__(self, matchobj):
         self.counter += 1
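Materializing the argument with list(values) gives Replacer its own copy (and index access even when the caller passes a non-list iterable), so later mutation of the caller's list cannot shift the substitution sequence. A standalone illustration; the diff above truncates __call__, so its return statement here is reconstructed from context:

# Standalone illustration of the Replacer copy semantics (reconstructed).
import re
from typing import List


class Replacer:
    def __init__(self, values: List[str]):
        self.counter = -1
        self.values = list(values)  # defensive copy of the caller's iterable

    def __call__(self, matchobj):
        self.counter += 1
        return self.values[self.counter]  # assumed: emit the next value per match


vals = ["2024", "08"]
replacer = Replacer(vals)
vals.clear()  # with the old aliasing, this would have emptied replacer.values too
print(re.sub(r"\{.+?\}", replacer, "data/{yyyy}/{mm}/file.json"))
# -> data/2024/08/file.json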