From 600019f1117c66c1b14267cd727fc5cb3904db56 Mon Sep 17 00:00:00 2001 From: Martin Bernstorff Date: Wed, 22 Nov 2023 09:07:28 +0000 Subject: [PATCH] feat: refactor groupby --- functionalpy/_sequence.py | 21 +++++++++++++-------- functionalpy/_sequence.pyi | 4 ++-- functionalpy/test_sequence.py | 23 +++++++++++++++++++++++ 3 files changed, 38 insertions(+), 10 deletions(-) diff --git a/functionalpy/_sequence.py b/functionalpy/_sequence.py index 9d38e35..ba63dc0 100644 --- a/functionalpy/_sequence.py +++ b/functionalpy/_sequence.py @@ -11,7 +11,7 @@ @dataclass(frozen=True) class Group(Generic[_T0]): key: str - value: "Seq[_T0]" + value: _T0 class Seq(Generic[_T0]): @@ -48,14 +48,19 @@ def filter(self, func: Callable[[_T0], bool]) -> "Seq[_T0]": # noqa: A003 def reduce(self, func: Callable[[_T0, _T0], _T0]) -> _T0: return reduce(func, self._seq) - def group_by( + def groupby( self, func: Callable[[_T0], str] - ) -> "Seq[Group[_T0]]": - result = ( - Group(key=key, value=Seq(value)) - for key, value in itertools.groupby(self._seq, key=func) - ) - return Seq(result) + ) -> dict[str, tuple[_T0, ...]]: + sorted_input = sorted(self._seq, key=func) + + result = { + key: tuple(value) + for key, value in itertools.groupby( + sorted_input, key=func + ) + } + + return result def flatten(self) -> "Seq[_T0]": return Seq( diff --git a/functionalpy/_sequence.pyi b/functionalpy/_sequence.pyi index 82f194e..75c77e2 100644 --- a/functionalpy/_sequence.pyi +++ b/functionalpy/_sequence.pyi @@ -28,9 +28,9 @@ class Seq(Generic[_T0]): def filter(self, func: Callable[[_T0], bool]) -> Seq[_T0]: # noqa: A003 ... def reduce(self, func: Callable[[_T0, _T0], _T0]) -> _T0: ... - def group_by( + def groupby( self, func: Callable[[_T0], str] - ) -> Seq[Group[_T0]]: ... + ) -> dict[str, tuple[_T0, ...]]: ... @overload def flatten(self: Seq[list[_S]]) -> Seq[_S]: ... @overload diff --git a/functionalpy/test_sequence.py b/functionalpy/test_sequence.py index 52b7d13..4543677 100644 --- a/functionalpy/test_sequence.py +++ b/functionalpy/test_sequence.py @@ -1,3 +1,5 @@ +from dataclasses import dataclass + from functionalpy._sequence import Seq @@ -36,6 +38,27 @@ def test_flatten(): assert result.to_list() == [1, 2, 3, 4] +@dataclass(frozen=True) +class GroupbyInput: + key: str + value: int + + +def test_groupby(): + groupby_inputs = [ + GroupbyInput(key="a", value=1), + GroupbyInput(key="a", value=2), + GroupbyInput(key="b", value=3), + ] + sequence = Seq(groupby_inputs) + result = sequence.groupby(lambda x: x.key) + + assert len(result) == 2 + assert list(result.keys()) == ["a", "b"] + assert result["a"] == (groupby_inputs[0], groupby_inputs[1]) + assert result["b"] == (groupby_inputs[2],) + + class TestFlattenTypes: def test_flatten_tuple(self): test_input = ((1, 2), (3, 4))