Skip to content

Commit

Permalink
mypy fixes + add IndexDF support
Browse files Browse the repository at this point in the history
  • Loading branch information
bobokvsky committed Aug 16, 2024
1 parent 526b1ab commit 50c8f9f
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
11 changes: 7 additions & 4 deletions datapipe/step/batch_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,15 +545,18 @@ def _apply_filters_to_run_config(
if isinstance(self.filters, list) and all([isinstance(x, dict) for x in self.filters]):
filters = self.filters
elif isinstance(self.filters, Callable): # type: ignore
filters_func = cast(Callable[..., List[LabelDict]], self.filters)
filters_func = cast(Callable[..., Union[List[LabelDict], IndexDF]], self.filters)
parameters = inspect.signature(filters_func).parameters
kwargs = {
**({"ds": ds} if "ds" in parameters else {}),
**({"run_config": run_config} if "run_config" in parameters else {})
}
filters = filters_func(**kwargs)

if isinstance(self.filters, str):
filters_res = filters_func(**kwargs)
if isinstance(filters_res, pd.DataFrame):
filters = cast(List[LabelDict], filters_res.to_dict(orient="records"))
else:
filters = filters_res
elif isinstance(self.filters, str):
dt = ds.get_table(self.filters)
df = dt.get_data()
filters = cast(List[LabelDict], df[dt.primary_keys].to_dict(orient="records"))
Expand Down
4 changes: 2 additions & 2 deletions datapipe/store/filedir.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import re
from abc import ABC
from pathlib import Path
from typing import IO, Any, Dict, Iterator, List, Optional, Union, cast
from typing import IO, Any, Dict, Iterator, List, Optional, Union, cast, Set

import fsspec
import numpy as np
Expand Down Expand Up @@ -492,7 +492,7 @@ def read_rows_meta_pseudo_df(
ids: Dict[str, List[str]] = {attrname: [] for attrname in self.attrnames}
ukeys = []
filepaths = []
looked_keys = set()
looked_keys: Set[Any] = set()

for f in files:
for filemath_match_suffix in self.filename_match_suffixes:
Expand Down
3 changes: 1 addition & 2 deletions datapipe/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@
TransformResult = Union[DataDF, List[DataDF], Tuple[DataDF, ...]]

LabelDict = Dict[str, Any]

Filters = Union[str, List[LabelDict], Callable[..., List[LabelDict]]]
Filters = Union[str, IndexDF, List[LabelDict], Callable[..., List[LabelDict]], Callable[..., IndexDF]]
try:
from sqlalchemy.orm import DeclarativeBase

Expand Down

0 comments on commit 50c8f9f

Please sign in to comment.