diff --git a/perhaps/maybe.py b/perhaps/maybe.py index aae3aab..2518594 100644 --- a/perhaps/maybe.py +++ b/perhaps/maybe.py @@ -13,6 +13,7 @@ T = TypeVar("T") R = TypeVar("R") +U = TypeVar("U") class Maybe(Generic[T], ABC): @@ -34,6 +35,27 @@ def map(self, f: Callable[[T], R]) -> "Maybe[R]": """ ... + @abstractmethod + def lift2(self, f: Callable[[T, R], U], other: "Maybe[R]") -> "Maybe[U]": + """ + Given a function that takes two arguments, and another Maybe, apply the + function to the value of this Maybe and the other Maybe, if they both + have values. + + Analogous to the `Option::zip_wth` in Rust and the `liftA2` in Haskell. + + >>> Just(1).lift2(lambda x, y: x + y, Just(2)) + Just(3) + + >>> Just(1).lift2(lambda x, y: x + y, Nothing()) + Nothing() + + >>> Nothing().lift2(lambda x, y: x + y, Just(2)) + Nothing() + + """ + ... + @abstractmethod def bind(self, f: Callable[[T], "Maybe[R]"]) -> "Maybe[R]": """ @@ -168,10 +190,6 @@ def from_try( except exc: return Nothing() - @abstractmethod - def __and__(self, other: "Maybe[R]") -> Union["Maybe[T]", "Maybe[R]"]: - ... - @abstractmethod def __or__(self, other: "Maybe[R]") -> Union["Maybe[T]", "Maybe[R]"]: ... @@ -200,10 +218,25 @@ def __init__(self, value: T): def map(self, f: Callable[[T], R]) -> "Just[R]": return Just(f(self.value)) - def bind(self, f: Callable[[T], "Maybe[R]"]) -> "Maybe[R]": + @overload + def lift2(self, f: Callable[[T, R], U], other: "Just[R]") -> "Just[U]": + ... + + @overload + def lift2(self, f: Callable[[T, R], U], other: "Nothing[R]") -> "Nothing[U]": + ... + + @overload + def lift2(self, f: Callable[[T, R], U], other: "Maybe[R]") -> "Maybe[U]": + ... + + def lift2(self, f: Callable[[T, R], U], other: Maybe[R]) -> Maybe[U]: + return other.map(lambda x: f(self.value, x)) + + def bind(self, f: Callable[[T], Maybe[R]]) -> Maybe[R]: return f(self.value) - def and_then(self, f: Callable[[T], "Maybe[R]"]) -> "Maybe[R]": + def and_then(self, f: Callable[[T], Maybe[R]]) -> Maybe[R]: return self.bind(f) def unwrap( @@ -232,10 +265,10 @@ def __and__(self, other: "Nothing[R]") -> "Nothing[R]": ... @overload - def __and__(self, other: "Maybe[R]") -> Union["Maybe[T]", "Maybe[R]"]: + def __and__(self, other: Maybe[R]) -> Union[Maybe[T], Maybe[R]]: ... - def __and__(self, other: "Maybe[R]") -> Union["Maybe[T]", "Maybe[R]"]: + def __and__(self, other: Maybe[R]) -> Union[Maybe[T], Maybe[R]]: return other @overload @@ -247,10 +280,10 @@ def __or__(self, other: "Nothing") -> "Just[T]": ... @overload - def __or__(self, other: "Maybe[R]") -> Union["Maybe[T]", "Maybe[R]"]: + def __or__(self, other: Maybe[R]) -> Union[Maybe[T], Maybe[R]]: ... - def __or__(self, other: "Maybe[R]") -> Union["Maybe[T]", "Maybe[R]"]: + def __or__(self, other: Maybe[R]) -> Union[Maybe[T], Maybe[R]]: return self def __eq__(self, other: object) -> bool: @@ -276,10 +309,13 @@ def __new__(cls) -> "Nothing[T]": def map(self, f: Callable[[T], R]) -> "Nothing[R]": return Nothing() - def bind(self, f: Callable[[T], "Maybe[R]"]) -> "Nothing[R]": + def lift2(self, f: Callable[[T, R], U], other: Maybe[R]) -> Maybe[U]: + return Nothing() + + def bind(self, f: Callable[[T], Maybe[R]]) -> "Nothing[R]": return Nothing() - def and_then(self, f: Callable[[T], "Maybe[R]"]) -> "Nothing[R]": + def and_then(self, f: Callable[[T], Maybe[R]]) -> "Nothing[R]": return self.bind(f) def unwrap( @@ -309,7 +345,7 @@ def __and__(self, other: "Just") -> "Nothing[T]": def __and__(self, other: "Nothing[R]") -> "Nothing[R]": ... - def __and__(self, other: "Maybe[R]") -> Union["Nothing[T]", "Maybe[R]"]: + def __and__(self, other: Maybe[R]) -> Union["Nothing[T]", Maybe[R]]: return other if isinstance(other, Nothing) else self @overload @@ -320,7 +356,7 @@ def __or__(self, other: "Just[R]") -> "Just[R]": def __or__(self, other: "Nothing[R]") -> "Nothing[R]": ... - def __or__(self, other: "Maybe[R]") -> Union["Nothing[T]", "Maybe[R]"]: + def __or__(self, other: Maybe[R]) -> Union["Nothing[T]", Maybe[R]]: return other def __eq__(self, other: object) -> bool: diff --git a/tests/test_maybe.py b/tests/test_maybe.py index d12b4aa..7d3cf46 100644 --- a/tests/test_maybe.py +++ b/tests/test_maybe.py @@ -1,3 +1,5 @@ +from typing import cast + import pytest from perhaps import Just, Maybe, Nothing @@ -24,6 +26,19 @@ def test_map(): assert Nothing().map(lambda x: x + 1) == Nothing() +def test_lift2(): + assert Just(1).lift2(lambda x, y: x + y, Just(2)) == Just(3) + assert ( + Just(1).lift2(lambda x, y: x + y, Nothing[int]()) == Nothing() + ) # annotation for Nothing is required here, though most often it can be inferred + assert Nothing().lift2(lambda x, y: x + y, Just(2)) == Nothing() + maybe1 = cast(Maybe[int], Nothing()) + assert maybe1.lift2(lambda x, y: x + y, Just(2)) == Nothing() + maybe2 = cast(Maybe[int], Just(3)) + assert maybe1.lift2(lambda x, y: x + y, maybe2) == Nothing() + assert maybe2.lift2(lambda x, y: x * y, maybe2) == Just(9) + + def test_bind(): assert Just(1).bind(lambda x: Just(x + 1)) == Just(2) assert Nothing().bind(lambda x: Just(x + 1)) == Nothing()