diff --git a/functionalpy/_sequence.py b/functionalpy/_sequence.py index 950da52..0aa0c8f 100644 --- a/functionalpy/_sequence.py +++ b/functionalpy/_sequence.py @@ -4,18 +4,18 @@ from functools import reduce from typing import Generic, TypeVar -_T0 = TypeVar("_T0") -_T1 = TypeVar("_T1") +_S = TypeVar("_S") +_T = TypeVar("_T") @dataclass(frozen=True) -class Group(Generic[_T0]): +class Group(Generic[_T]): key: str - value: _T0 + value: _T -class Seq(Generic[_T0]): - def __init__(self, iterable: Iterable[_T0]): +class Seq(Generic[_T]): + def __init__(self, iterable: Iterable[_T]): self._seq = iterable ### Reductions @@ -23,34 +23,34 @@ def count(self) -> int: return sum(1 for _ in self._seq) ### Output - def to_list(self) -> list[_T0]: + def to_list(self) -> list[_T]: return list(self._seq) - def to_tuple(self) -> tuple[_T0, ...]: + def to_tuple(self) -> tuple[_T, ...]: return tuple(self._seq) - def to_iter(self) -> Iterator[_T0]: + def to_iter(self) -> Iterator[_T]: return iter(self._seq) - def to_set(self) -> set[_T0]: + def to_set(self) -> set[_T]: return set(self._seq) ### Transformations def map( # noqa: A003 # Ignore that it's shadowing a python built-in self, - func: Callable[[_T0], _T1], - ) -> "Seq[_T1]": + func: Callable[[_T], _S], + ) -> "Seq[_S]": return Seq(map(func, self._seq)) - def filter(self, func: Callable[[_T0], bool]) -> "Seq[_T0]": # noqa: A003 + def filter(self, func: Callable[[_T], bool]) -> "Seq[_T]": # noqa: A003 return Seq(filter(func, self._seq)) - def reduce(self, func: Callable[[_T0, _T0], _T0]) -> _T0: + def reduce(self, func: Callable[[_T, _T], _T]) -> _T: return reduce(func, self._seq) def groupby( - self, func: Callable[[_T0], str] - ) -> "Seq[Group[tuple[_T0, ...]]]": + self, func: Callable[[_T], str] + ) -> "Seq[dict[str, tuple[_T, ...]]]": # Itertools.groupby requires the input to be sorted sorted_input = sorted(self._seq, key=func) @@ -60,15 +60,11 @@ def groupby( sorted_input, key=func ) } + items = [{k: v} for k, v in result.items()] - groups = ( - Group(key=key, value=value) - for key, value in result.items() - ) - - return Seq(groups) + return Seq(items) - def flatten(self) -> "Seq[_T0]": + def flatten(self) -> "Seq[_T]": return Seq( item for sublist in self._seq diff --git a/functionalpy/benchmark/query_1/iterators_q1.py b/functionalpy/benchmark/query_1/iterators_q1.py deleted file mode 100644 index d20c611..0000000 --- a/functionalpy/benchmark/query_1/iterators_q1.py +++ /dev/null @@ -1,116 +0,0 @@ -import datetime as dt -import statistics as stats -from collections.abc import Mapping, Sequence -from dataclasses import dataclass -from enum import Enum -from typing import Any - -from functionalpy import Group, Seq -from functionalpy.benchmark.query_1.input_data import Q1_DATA -from functionalpy.benchmark.utils import benchmark_method - - -class LineStatus(Enum): - SHIPPED = 0 - PENDING = 1 - CANCELLED = 2 - BACKORDERED = 3 - - -@dataclass(frozen=True) -class Item: - ship_date: dt.datetime - quantity: int - extended_price: float - discount: float - tax: float - returned: bool - cancelled: bool - line_status: LineStatus - - -@dataclass(frozen=True) -class CategorySummary: - category_name: str - - sum_quantity: int - sum_base_price: float - sum_discount_price: float - sum_charge: float - - avg_quantity: float - avg_price: float - avg_discount: float - num_orders: int - - -def summarise_category(input_data: Group[Item]) -> CategorySummary: - group_id = input_data.key - rows = input_data.value.to_list() - - return CategorySummary( - category_name=group_id, - sum_quantity=sum(r.quantity for r in rows), - sum_base_price=sum(r.extended_price for r in rows), - sum_discount_price=sum( - calculate_discounted_price(r) for r in rows - ), - sum_charge=sum(calculate_charge(r) for r in rows), - avg_quantity=stats.mean(r.quantity for r in rows), - avg_price=stats.mean(r.extended_price for r in rows), - avg_discount=stats.mean(r.discount for r in rows), - num_orders=input_data.value.count(), - ) - - -def calculate_discounted_price(item: Item) -> float: - return item.extended_price * (1 - item.discount) - - -def calculate_charge(item: Item) -> float: - return item.extended_price * (1 - item.discount) * (1 - item.tax) - - -def parse_input_data( - input_data: Sequence[Mapping[str, Any]] -) -> Sequence[Item]: - parsed_data = ( - Seq(input_data) - .map( - lambda row: Item( - ship_date=row["ship_date"], - quantity=row["quantity"], - extended_price=row["extended_price"], - discount=row["discount"], - tax=row["tax"], - returned=row["returned"], - line_status=LineStatus[row["line_status"].upper()], - cancelled=row["cancelled"], - ), - ) - .to_list() - ) - - return parsed_data - - -def main_iterator(data: Sequence[Item]) -> Sequence[CategorySummary]: - sequence = ( - Seq(data) - .filter(lambda i: i.ship_date <= dt.datetime(2000, 1, 1)) - .groupby( - lambda i: f"status_{i.cancelled}_returned_{i.returned}" - ) - .map(summarise_category) - .to_list() - ) - - return sequence - - -if __name__ == "__main__": - benchmark = benchmark_method( - data_ingest=lambda: parse_input_data(input_data=Q1_DATA), - query=main_iterator, - method_title="iterators_q1", - ) diff --git a/functionalpy/test_sequence.py b/functionalpy/test_sequence.py index 4543677..ab9b7fc 100644 --- a/functionalpy/test_sequence.py +++ b/functionalpy/test_sequence.py @@ -44,6 +44,10 @@ class GroupbyInput: value: int +# TODO: Properly specify groupby +# Some desirable properties: +# - Direct indexing by group key +# - Still a sequence we can use .map etc. on def test_groupby(): groupby_inputs = [ GroupbyInput(key="a", value=1), @@ -51,12 +55,16 @@ def test_groupby(): GroupbyInput(key="b", value=3), ] sequence = Seq(groupby_inputs) - result = sequence.groupby(lambda x: x.key) + result = sequence.groupby(lambda x: x.key).to_tuple() assert len(result) == 2 - assert list(result.keys()) == ["a", "b"] - assert result["a"] == (groupby_inputs[0], groupby_inputs[1]) - assert result["b"] == (groupby_inputs[2],) + assert [x.key for x in result] == ["a", "b"] + assert [x.value.to_list() for x in result if x.key == "a"] == [ + (1, 2) + ] + assert [x.value.to_list() for x in result if x.key == "b"] == [ + (3,) + ] class TestFlattenTypes: