diff --git a/iterpy/__init__.py b/iterpy/__init__.py index bb8d745..53627ed 100644 --- a/iterpy/__init__.py +++ b/iterpy/__init__.py @@ -1 +1,2 @@ -from .iter import Iter # noqa: F401 # type: ignore +from .iter import Iter # type: ignore +from .arr import Arr # type: ignore diff --git a/iterpy/arr.py b/iterpy/arr.py new file mode 100644 index 0000000..0ecf741 --- /dev/null +++ b/iterpy/arr.py @@ -0,0 +1,229 @@ +from __future__ import annotations + +import copy +import multiprocessing +from collections import defaultdict +from functools import reduce +from itertools import islice +from typing import TYPE_CHECKING, Any, Generator, Generic, Sequence, TypeVar, overload + +if TYPE_CHECKING: + from collections.abc import Callable, Iterable, Iterator + +T = TypeVar("T") +S = TypeVar("S") + + +class Arr(Generic[T]): + def __init__(self, iterable: Iterable[T]) -> None: + self._iter = list(iterable) + self._current_index: int = 0 + + @property + def _iterator(self) -> Iterator[T]: + return iter(self._iter) + + def __bool__(self) -> bool: + return bool(self._iter) + + def __iter__(self) -> Arr[T]: + return self + + def __next__(self) -> T: + try: + item = self._iter[self._current_index] + except IndexError: + raise StopIteration # noqa: B904 + self._current_index += 1 + return item + + @overload + def __getitem__(self, index: int) -> T: ... + @overload + def __getitem__(self, index: slice) -> Arr[T]: ... + + def __getitem__(self, index: int | slice) -> T | Arr[T]: + if isinstance(index, int) and index >= 0: + try: + return next(islice(self._iter, index, index + 1)) + except StopIteration: + raise IndexError("Index out of range") from None + elif isinstance(index, slice): + return Arr(islice(self._iterator, index.start, index.stop, index.step)) + else: + raise KeyError(f"Key must be non-negative integer or slice, not {index}") + + def __repr__(self) -> str: + return f"Arr{self._iter}" + + def __eq__(self, other: object) -> bool: + if not isinstance(other, Arr): + return False + return self._iter == other._iter # type: ignore + + ### Reductions + def reduce(self, func: Callable[[T, T], T]) -> T: + return reduce(func, self._iterator) + + def count(self) -> int: + return sum(1 for _ in self._iterator) + + ### Output + def to_list(self) -> list[T]: + return list(self._iter) + + def to_tuple(self) -> tuple[T, ...]: + return tuple(self._iterator) # pragma: no cover + + def to_consumable(self) -> Iterator[T]: + return iter(self._iterator) # pragma: no cover + + def to_set(self) -> set[T]: + return set(self._iterator) # pragma: no cover + + ### Transformations + def map( # Ignore that it's shadowing a python built-in + self, func: Callable[[T], S] + ) -> Arr[S]: + return Arr(map(func, self._iterator)) + + def pmap(self, func: Callable[[T], S]) -> Arr[S]: + """Parallel map using multiprocessing.Pool + + Not that lambdas are not supported by multiprocessing.Pool.map. + """ + with multiprocessing.Pool() as pool: + return Arr(pool.map(func, self._iterator)) + + def filter(self, func: Callable[[T], bool]) -> Arr[T]: + return Arr(filter(func, self._iterator)) # type: ignore + + def groupby(self, func: Callable[[T], str]) -> Arr[tuple[str, list[T]]]: + groups_with_values: defaultdict[str, list[T]] = defaultdict(list) + + for value in self._iterator: + value_key = func(value) + groups_with_values[value_key].append(value) + + tuples = list(groups_with_values.items()) + return Arr(tuples) + + def take(self, n: int = 1) -> Arr[T]: + return Arr(islice(self._iter, n)) + + def any(self, func: Callable[[T], bool]) -> bool: + return any(func(i) for i in self._iterator) + + def all(self, func: Callable[[T], bool]) -> bool: + return all(func(i) for i in self._iterator) + + def unique(self) -> Arr[T]: + return Arr(set(self._iterator)) + + def unique_by(self, func: Callable[[T], S]) -> Arr[T]: + seen: set[S] = set() + values: list[T] = [] + + for value in self._iterator: + key = func(value) + if key not in seen: + seen.add(key) + values.append(value) + + return Arr(values) + + def enumerate(self) -> Arr[tuple[int, T]]: + return Arr(enumerate(self._iterator)) + + def find(self, func: Callable[[T], bool]) -> T | None: + for value in self._iterator: + if func(value): + return value + return None + + def clone(self) -> Arr[T]: + return copy.deepcopy(self) + + def zip(self, other: Arr[S]) -> Arr[tuple[T, S]]: + return Arr(zip(self, other)) + + ############################################################ + # Auto-generated overloads for flatten # + # Code for generating the following is in _generate_pyi.py # + ############################################################ + # Overloads are technically incompatible, because they use generic S instead of T. However, this is required for the flattening logic to work. + + @overload + def flatten(self: Arr[Iterable[S]]) -> Arr[S]: ... + @overload + def flatten(self: Arr[Iterable[S] | S]) -> Arr[S]: ... + + # Iterator[S] # noqa: ERA001 + @overload + def flatten(self: Arr[Iterator[S]]) -> Arr[S]: ... + @overload + def flatten(self: Arr[Iterator[S] | S]) -> Arr[S]: ... + + # tuple[S, ...] # noqa: ERA001 + @overload + def flatten(self: Arr[tuple[S, ...]]) -> Arr[S]: ... + @overload + def flatten(self: Arr[tuple[S, ...] | S]) -> Arr[S]: ... + + # Sequence[S] # noqa: ERA001 + @overload + def flatten(self: Arr[Sequence[S]]) -> Arr[S]: ... + @overload + def flatten(self: Arr[Sequence[S] | S]) -> Arr[S]: ... + + # list[S] # noqa: ERA001 + @overload + def flatten(self: Arr[list[S]]) -> Arr[S]: ... + @overload + def flatten(self: Arr[list[S] | S]) -> Arr[S]: ... + + # set[S] # noqa: ERA001 + @overload + def flatten(self: Arr[set[S]]) -> Arr[S]: ... + @overload + def flatten(self: Arr[set[S] | S]) -> Arr[S]: ... + + # frozenset[S] # noqa: ERA001 + @overload + def flatten(self: Arr[frozenset[S]]) -> Arr[S]: ... + @overload + def flatten(self: Arr[frozenset[S] | S]) -> Arr[S]: ... + + # Arr[S] # noqa: ERA001 + @overload + def flatten(self: Arr[Arr[S]]) -> Arr[S]: ... + @overload + def flatten(self: Arr[Arr[S] | S]) -> Arr[S]: ... + + # str + @overload + def flatten(self: Arr[str]) -> Arr[str]: ... + @overload + def flatten(self: Arr[str | S]) -> Arr[S]: ... + + # Generic + @overload + def flatten(self: Arr[S]) -> Arr[S]: ... + + def flatten(self) -> Arr[T]: # type: ignore - + depth = 1 + + def walk(node: Any, level: int) -> Generator[T, None, None]: + if (level > depth) or isinstance(node, str): + yield node # type: ignore + return + try: + tree = iter(node) + except TypeError: + yield node + return + else: + for child in tree: + yield from walk(child, level + 1) + + return Arr(walk(self, level=0)) # type: ignore diff --git a/iterpy/test_arr.py b/iterpy/test_arr.py new file mode 100644 index 0000000..b2fb4e7 --- /dev/null +++ b/iterpy/test_arr.py @@ -0,0 +1,192 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from iterpy import Arr + +if TYPE_CHECKING: + from collections.abc import Sequence + + +def test_inexhaustable(): + test_input = [1, 2, 3] + test_iterator = Arr(test_input) + assert test_iterator.to_list() == test_input + assert test_iterator.to_list() == test_input + + +def test_chaining(): + iterator = Arr([1, 2]) + result: list[int] = iterator.filter(lambda x: x % 2 == 0).map(lambda x: x * 2).to_list() + assert result == [4] + + +def test_map(): + iterator = Arr([1, 2]) + result: list[int] = iterator.map(lambda x: x * 2).to_list() + assert result == [2, 4] + + +def multiple_by_2(num: int) -> int: + return num * 2 # pragma: no cover + + +def test_pmap(): + iterator = Arr([1, 2]) + result: list[int] = iterator.pmap(multiple_by_2).to_list() + assert result == [2, 4] + + +def test_filter(): + iterator = Arr([1, 2]) + result: list[int] = iterator.filter(lambda x: x % 2 == 0).to_list() + assert result == [2] + + +def test_reduce(): + iterator = Arr([1, 2]) + result: int = iterator.reduce(lambda x, y: x + y) + assert result == 3 + + +def test_count(): + iterator = Arr([1, 2]) + result: int = iterator.count() + assert result == 2 + + +def test_grouped_filter(): + iterator = Arr([1, 2, 3, 4]) + + def is_even(num: int) -> str: + if num % 2 == 0: + return "even" + return "odd" + + grouped: list[tuple[str, list[int]]] = iterator.groupby(is_even).to_list() + assert grouped == [("odd", [1, 3]), ("even", [2, 4])] + + +def test_getitem(): + test_input = [1, 2, 3] + test_iterator = Arr(test_input) + assert test_iterator[0] == 1 + test_iterator = Arr(test_input) + arr_slice = test_iterator[0:2] + assert arr_slice.to_list() == [1, 2] + + +def test_iteration(): + test_iterator = Arr([1, 2, 3]) + for i in test_iterator.to_consumable(): + assert i in [1, 2, 3] + + +def test_take(): + test_iterator = Arr([1, 2, 3]) + assert test_iterator.take(2).to_list() == [1, 2] + + +def test_any(): + test_iterator = Arr([1, 2, 3]) + assert test_iterator.any(lambda x: x == 2) is True + assert test_iterator.any(lambda x: x == 4) is False + + +def test_all(): + test_iterator = Arr([1, 2, 3]) + assert test_iterator.all(lambda x: x < 4) is True + + test_iterator = Arr([1, 2, 3]) + assert test_iterator.all(lambda x: x < 3) is False + + +def test_unique(): + test_iterator = Arr([1, 2, 2, 3]) + assert test_iterator.unique().to_list() == [1, 2, 3] + + +def test_unique_by(): + test_iterator = Arr([1, 2, 2, 3]) + assert test_iterator.unique_by(lambda x: x % 2).to_list() == [1, 2] + + +def test_enumerate(): + test_iterator = Arr([1, 2, 3]) + assert test_iterator.enumerate().to_list() == [(0, 1), (1, 2), (2, 3)] + + +def test_find(): + test_iterator = Arr([1, 2, 3]) + assert test_iterator.find(lambda x: x == 2) == 2 + assert test_iterator.find(lambda x: x == 4) is None + + +def test_zip(): + iter1 = Arr([1, 2, 3]) + iter2 = Arr(["a", "b", "c"]) + result: Arr[tuple[int, str]] = iter1.zip(iter2) + assert result.to_list() == [(1, "a"), (2, "b"), (3, "c")] + + +def test_flatten(): + test_input: list[list[int]] = [[1, 2], [3, 4]] + iterator = Arr(test_input) + result: Arr[int] = iterator.flatten() + assert result.to_list() == [1, 2, 3, 4] + + +@pytest.mark.benchmark() +def test_benchmark_large_flattening(): + test_input = Arr(range(100_000)).map(lambda x: Arr([x])) + assert test_input.flatten().to_list() == list(range(100_000)) + + +class TestFlattenTypes: + """Intentionally not parametrised so we can manually specify the return type hints and they can be checked by pyright""" + + def test_flatten_generator(self): + iterator: Arr[tuple[int, int]] = Arr((i, i + 1) for i in range(1, 3)) + result: Arr[int] = iterator.flatten() + assert result.to_list() == [1, 2, 2, 3] + + def test_flatten_tuple(self): + iterator: Arr[tuple[int, int]] = Arr((i, i + 1) for i in range(1, 3)) + result: Arr[int] = iterator.flatten() + assert result.to_list() == [1, 2, 2, 3] + + def test_flatten_list(self): + test_input: list[list[int]] = [[1, 2], [3, 4]] + iterator = Arr(test_input) + result: Arr[int] = iterator.flatten() + assert result.to_list() == [1, 2, 3, 4] + + def test_flatten_iterator(self): + test_input: Sequence[Sequence[int]] = [[1, 2], [3, 4]] + iterator = Arr(test_input) + result: Arr[int] = iterator.flatten() + assert result.to_list() == [1, 2, 3, 4] + + def test_flatten_iter_iter(self): + iterator: Arr[int] = Arr([1, 2]) + nested_iter: Arr[Arr[int]] = Arr([iterator]) + unnested_iter: Arr[int] = nested_iter.flatten() + assert unnested_iter.to_list() == [1, 2] + + def test_flatten_str(self): + test_input: list[str] = ["abcd"] + iterator = Arr(test_input) + result: Arr[str] = iterator.flatten() + assert result.to_list() == ["abcd"] + + def test_flatten_includes_primitives(self): + test_input: list[str | list[int] | None] = ["first", [2], None] + result: Arr[int | str | None] = Arr(test_input).flatten() + assert result.to_list() == ["first", 2, None] + + def test_flatten_removes_empty_iterators(self): + test_input: list[list[int]] = [[1], []] + result: Arr[int] = Arr(test_input).flatten() + assert result.to_list() == [1] diff --git a/iterpy/test_iter.py b/iterpy/test_iter.py index a6d34ec..e7efaf7 100644 --- a/iterpy/test_iter.py +++ b/iterpy/test_iter.py @@ -16,6 +16,13 @@ def test_chaining(): assert result == [4] +def test_exhaustable(): + test_input = [1, 2, 3] + test_iterator = Iter(test_input) + assert test_iterator.to_list() == test_input + assert test_iterator.to_list() == [] + + def test_map(): iterator = Iter([1, 2]) result: list[int] = iterator.map(lambda x: x * 2).to_list()