Skip to content

Commit

Permalink
feat: refactor groupby
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBernstorff committed Nov 22, 2023
1 parent 5e8bef2 commit 600019f
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 10 deletions.
21 changes: 13 additions & 8 deletions functionalpy/_sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
@dataclass(frozen=True)
class Group(Generic[_T0]):
key: str
value: "Seq[_T0]"
value: _T0


class Seq(Generic[_T0]):
Expand Down Expand Up @@ -48,14 +48,19 @@ def filter(self, func: Callable[[_T0], bool]) -> "Seq[_T0]": # noqa: A003
def reduce(self, func: Callable[[_T0, _T0], _T0]) -> _T0:
return reduce(func, self._seq)

def group_by(
def groupby(
self, func: Callable[[_T0], str]
) -> "Seq[Group[_T0]]":
result = (
Group(key=key, value=Seq(value))
for key, value in itertools.groupby(self._seq, key=func)
)
return Seq(result)
) -> dict[str, tuple[_T0, ...]]:
sorted_input = sorted(self._seq, key=func)

result = {
key: tuple(value)
for key, value in itertools.groupby(
sorted_input, key=func
)
}

return result

def flatten(self) -> "Seq[_T0]":
return Seq(
Expand Down
4 changes: 2 additions & 2 deletions functionalpy/_sequence.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,9 @@ class Seq(Generic[_T0]):
def filter(self, func: Callable[[_T0], bool]) -> Seq[_T0]: # noqa: A003
...
def reduce(self, func: Callable[[_T0, _T0], _T0]) -> _T0: ...
def group_by(
def groupby(
self, func: Callable[[_T0], str]
) -> Seq[Group[_T0]]: ...
) -> dict[str, tuple[_T0, ...]]: ...
@overload
def flatten(self: Seq[list[_S]]) -> Seq[_S]: ...
@overload
Expand Down
23 changes: 23 additions & 0 deletions functionalpy/test_sequence.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from dataclasses import dataclass

from functionalpy._sequence import Seq


Expand Down Expand Up @@ -36,6 +38,27 @@ def test_flatten():
assert result.to_list() == [1, 2, 3, 4]


@dataclass(frozen=True)
class GroupbyInput:
key: str
value: int


def test_groupby():
groupby_inputs = [
GroupbyInput(key="a", value=1),
GroupbyInput(key="a", value=2),
GroupbyInput(key="b", value=3),
]
sequence = Seq(groupby_inputs)
result = sequence.groupby(lambda x: x.key)

assert len(result) == 2
assert list(result.keys()) == ["a", "b"]
assert result["a"] == (groupby_inputs[0], groupby_inputs[1])
assert result["b"] == (groupby_inputs[2],)


class TestFlattenTypes:
def test_flatten_tuple(self):
test_input = ((1, 2), (3, 4))
Expand Down

0 comments on commit 600019f

Please sign in to comment.