Linting
bruno-f-cruz committed Jul 24, 2024
1 parent e4b8365 commit 265ce73
Showing 4 changed files with 49 additions and 58 deletions.
setup.py (2 changes: 1 addition & 1 deletion)
@@ -1,4 +1,4 @@
 from setuptools import setup
 
 if __name__ == "__main__":
-    setup()
\ No newline at end of file
+    setup()
src/aind_behavior_core_analysis/io/_core.py (66 changes: 44 additions & 22 deletions)
@@ -8,23 +8,22 @@
     Any,
     Generic,
     List,
+    NamedTuple,
     Optional,
     Protocol,
     Sequence,
-    Tuple,
     Type,
     TypeVar,
     overload,
     runtime_checkable,
 )
 
-from aind_behavior_core_analysis.io._utils import StrPattern, validate_str_pattern
+from aind_behavior_core_analysis.io._utils import StrPattern
 
 TData = TypeVar("TData", bound=Any)
 
 
 class DataStream(abc.ABC, Generic[TData]):
-
     def __init__(
         self,
         /,
@@ -35,7 +34,6 @@ def __init__(
         _data: Optional[TData] = None,
         **kwargs,
     ) -> None:
-
         self._auto_load = auto_load
         self._data = _data
 
@@ -90,7 +88,6 @@ def data(self) -> TData:
         return self._data
 
     def load(self, /, path: Optional[PathLike] = None, *, force_reload: bool = False, **kwargs) -> TData:
-
         if force_reload is False and self._data:
             pass
         else:
@@ -108,11 +105,13 @@ def __str__(self) -> str:
 
 @runtime_checkable
 class _DataStreamSourceBuilder(Protocol):
-    def build(self, /, source: Optional[DataStreamSource] = None, **kwargs) -> StreamCollection:
-        ...
+
+    def build(self, /, source: Optional[DataStreamSource] = None, **kwargs) -> StreamCollection: ...
 
 
-_SequenceDataStreamBuilderPattern = Sequence[Tuple[Type[DataStream], StrPattern]]
+class DataStreamBuilderPattern(NamedTuple):
+    stream_type: Type[DataStream]
+    pattern: StrPattern
 
 
 class DataStreamSource:
@@ -129,19 +128,34 @@ def __init__(
         name: Optional[str] = None,
         auto_load: bool = False,
         **kwargs,
-    ) -> None: ...
+    ) -> None:
+        ...
+
+    @overload
+    def __init__(
+        self,
+        /,
+        path: PathLike,
+        builder: DataStreamBuilderPattern,
+        *,
+        name: Optional[str] = None,
+        auto_load: bool = False,
+        **kwargs,
+    ) -> None:
+        ...
 
     @overload
     def __init__(
         self,
         /,
         path: PathLike,
-        builder: _SequenceDataStreamBuilderPattern,
+        builder: Sequence[DataStreamBuilderPattern],
         *,
         name: Optional[str] = None,
         auto_load: bool = False,
         **kwargs,
-    ) -> None: ...
+    ) -> None:
+        ...
 
     @overload
     def __init__(
@@ -153,7 +167,8 @@ def __init__(
         name: Optional[str] = None,
         auto_load: bool = False,
         **kwargs,
-    ) -> None: ...
+    ) -> None:
+        ...
 
     @overload
     def __init__(
@@ -165,19 +180,23 @@ def __init__(
         name: Optional[str] = None,
         auto_load: bool = False,
         **kwargs,
-    ) -> None: ...
+    ) -> None:
+        ...
 
     def __init__(
         self,
         /,
         path: PathLike,
-        builder: None | Type[DataStream] | _SequenceDataStreamBuilderPattern | _DataStreamSourceBuilder = None,
+        builder: None
+        | Type[DataStream]
+        | DataStreamBuilderPattern
+        | Sequence[DataStreamBuilderPattern]
+        | _DataStreamSourceBuilder = None,
         *,
         name: Optional[str] = None,
         auto_load: bool = False,
         **kwargs,
     ) -> None:
-
         self._streams: StreamCollection
         self._path = Path(path)
 
@@ -206,15 +225,18 @@ def __init__(
         self.reload_streams()
 
     @staticmethod
-    def _normalize_builder_from_data_stream(builder: Type[DataStream] | Sequence) -> _SequenceDataStreamBuilderPattern:
+    def _normalize_builder_from_data_stream(
+        builder: Type[DataStream] | DataStreamBuilderPattern | Sequence
+    ) -> Sequence[DataStreamBuilderPattern]:
         _builder: Sequence
         if isinstance(builder, type(DataStream)):  # If only a single data stream class is provided
-            _builder = ((builder, "*"),)
+            _builder = (DataStreamBuilderPattern(stream_type=builder, pattern="*"),)
+        if isinstance(builder, DataStreamBuilderPattern):  # If only a single data stream class is provided
+            _builder = (builder,)
 
         for _tuple in _builder:
-            if not isinstance(_tuple[0], type(DataStream)):
+            if not isinstance(_tuple.stream_type, type(DataStream)):
                 raise ValueError("builder must be a DataStream type")
-            validate_str_pattern(_tuple[1])
         return _builder
 
     @property
@@ -246,10 +268,10 @@ def _get_data_streams_helper(
         return streams
 
     @classmethod
-    def _build_from_data_stream(cls, path: PathLike, builder: _SequenceDataStreamBuilderPattern) -> StreamCollection:
+    def _build_from_data_stream(cls, path: PathLike, builder: Sequence[DataStreamBuilderPattern]) -> StreamCollection:
         streams = StreamCollection()
-        for stream_type, pattern in builder:
-            _this_type_stream = cls._get_data_streams_helper(path, stream_type, pattern)
+        for stream_builder in builder:
+            _this_type_stream = cls._get_data_streams_helper(path, stream_builder.stream_type, stream_builder.pattern)
             for stream in _this_type_stream:
                 if stream.name is None:
                     raise ValueError(f"Stream {stream} does not have a name")
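For context, a minimal sketch of how the new builder API reads at a call site, assuming the package is importable; the session path and the choice of SingletonStream are hypothetical:

from aind_behavior_core_analysis.io._core import DataStreamBuilderPattern, DataStreamSource
from aind_behavior_core_analysis.io.data_stream import SingletonStream

# Before this commit the entry was a bare (type, pattern) tuple:
#     builder = ((SingletonStream, "*.json"),)
# The NamedTuple names both fields explicitly:
builder = (DataStreamBuilderPattern(stream_type=SingletonStream, pattern="*.json"),)
source = DataStreamSource("my_session", builder=builder, auto_load=False)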
src/aind_behavior_core_analysis/io/_utils.py (28 changes: 2 additions & 26 deletions)
@@ -1,27 +1,3 @@
-import re
-from typing import Sequence, Union
+from typing import List, Union
 
-StrPattern = Union[str, Sequence[str]]
-
-
-def validate_str_pattern(pattern: StrPattern) -> None:
-    """
-    Validates a string pattern or a sequence of string patterns.
-
-    Args:
-        pattern (StrPattern): The string pattern or sequence of string patterns to validate.
-
-    Raises:
-        re.error: If any of the patterns is not a valid regex pattern.
-
-    Returns:
-        None
-    """
-    if isinstance(pattern, Sequence):
-        for pat in pattern:
-            validate_str_pattern(pat)
-    else:
-        try:
-            re.compile(pattern)
-        except re.error as err:
-            raise re.error(f"Pattern {pattern} is not a valid regex pattern") from err
+StrPattern = Union[str, List[str]]
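For reference, a short sketch of the simplified alias in use; the annotated variables are illustrative. The regex-based validator could be dropped because these patterns appear to be consumed as glob patterns (note the "*" default and the path.glob usage elsewhere in this diff), not regular expressions:

from typing import List, Union

StrPattern = Union[str, List[str]]

single_pattern: StrPattern = "*.csv"             # one glob pattern
many_patterns: StrPattern = ["*.csv", "*.json"]  # or a list of them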
src/aind_behavior_core_analysis/io/data_stream.py (11 changes: 2 additions & 9 deletions)
@@ -121,16 +121,13 @@ def _reader(
         col_names: Optional[List[str]] = None,
         **kwargs,
     ) -> DataFrameOrSeries:
-
         has_header = csv.Sniffer().has_header(value)
         _header = 0 if has_header is True else None
         df = pd.read_csv(io.StringIO(value), header=_header, index_col=infer_index_col, names=col_names)
         return df
 
 
 class SingletonStream(DataStream[str | BaseModel]):
     """Represents a generic Software event."""
-
     def __init__(
         self,
         /,
@@ -174,7 +171,6 @@ def _apply_inner_parser(self, value: Optional[str | BaseModel]) -> str | BaseModel:
 
 
 class HarpDataStream(DataStream[DataFrameOrSeries]):
-
     def __init__(
         self,
         /,
@@ -221,7 +217,6 @@ def load(self, /, path: Optional[PathLike] = None, *, force_reload: bool = False
 def _bin_file_inference_helper(
     root_path: PathLike, register_reader: harp.reader.RegisterReader, name_hint: Optional[str] = None
 ) -> Path:
-
     root_path = Path(root_path)
     candidate_files = list(root_path.glob(f"*_{register_reader.register.address}.bin"))
 
@@ -244,7 +239,6 @@ def _bin_file_inference_helper(
 
 
 class HarpDataStreamSourceBuilder(_DataStreamSourceBuilder):
-
     _reader_default_params = {
         "include_common_registers": True,
         "keep_type": True,
@@ -258,17 +252,16 @@ def __init__(
         device_hint: Optional[DeviceReader | WhoAmI | PathLike] = None,
         default_inference_mode: _available_inference_modes = "yml",
     ) -> None:
-
         self.device_hint = device_hint
         self.default_inference_mode = default_inference_mode
 
     @overload
-    def build(self, /, source: Optional[DataStreamSource] = None, **kwargs) -> StreamCollection: ...
+    def build(self, /, source: Optional[DataStreamSource] = None, **kwargs) -> StreamCollection:
+        ...
 
     def build(
         self, /, source: Optional[DataStreamSource] = None, *, path: Optional[PathLike] = None, **kwargs
     ) -> StreamCollection:
-
         # Leaving this undocumented here for now...
         device_hint = kwargs.get("device_hint", self.device_hint)
         default_inference_mode = kwargs.get("default_inference_mode", self.default_inference_mode)
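And a minimal sketch of wiring the Harp builder into a DataStreamSource, based only on the signatures above; the session path and the device.yml hint are hypothetical:

from aind_behavior_core_analysis.io._core import DataStreamSource
from aind_behavior_core_analysis.io.data_stream import HarpDataStreamSourceBuilder

# device_hint accepts a DeviceReader, a WhoAmI, or a path-like hint such as a yml file;
# default_inference_mode mirrors the "yml" default in __init__ above.
harp_builder = HarpDataStreamSourceBuilder(device_hint="device.yml")
source = DataStreamSource("my_session/behavior", builder=harp_builder, auto_load=True)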
