Skip to content

Commit

Permalink
Add support for product types
Browse files Browse the repository at this point in the history
  • Loading branch information
taig committed Dec 2, 2024
1 parent 7a34b56 commit 779035e
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 22 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
package io.taig.enumeration.ext

import cats.data.NonEmptyList

import scala.deriving.Mirror
Expand All @@ -22,27 +21,37 @@ object EnumerationValues:
override type Out = B
override def toNonEmptyList: NonEmptyList[Out] = values

inline given [A](using
mirror: Mirror.SumOf[A],
values: EnumerationValues.Aux[mirror.MirroredElemTypes, A]
inline given sum[A, B <: Tuple](using
mirror: Mirror.SumOf[A] { type MirroredElemTypes = B },
values: EnumerationValues.Aux[B, A]
): EnumerationValues.Aux[A, A] = EnumerationValues(values = values.toNonEmptyList)

inline given singleton[A <: Singleton, B <: Tuple, C >: A](using
values: EnumerationValues.Aux[B, C]
): EnumerationValues.Aux[A *: B, C] = EnumerationValues(values = valueOf[A] :: values.toNonEmptyList)
inline given product[A](using
mirror: Mirror.ProductOf[A],
values: EnumerationValues.Aux[mirror.MirroredElemTypes, mirror.MirroredElemTypes]
): EnumerationValues.Aux[A, A] = EnumerationValues(values = values.toNonEmptyList.map(mirror.fromTuple))

inline given nested[A, B <: Tuple, C >: A](using
mirror: Mirror.SumOf[A],
head: EnumerationValues.Aux[mirror.MirroredElemTypes, C],
tail: EnumerationValues.Aux[B, C]
): EnumerationValues.Aux[A *: B, C] = EnumerationValues(values = head.toNonEmptyList.concatNel(tail.toNonEmptyList))
inline given sum1[A, B >: A](using values: EnumerationValues.Aux[A, A]): EnumerationValues.Aux[A *: EmptyTuple, B] =
EnumerationValues(values = values.toNonEmptyList)

inline given nestedOne[A, B >: A](using
mirror: Mirror.SumOf[A],
head: EnumerationValues.Aux[mirror.MirroredElemTypes, B]
): EnumerationValues.Aux[A *: EmptyTuple, B] = EnumerationValues(values = head.toNonEmptyList)
inline given sumN[A, B <: Tuple, C >: A](using
head: EnumerationValues.Aux[A, A],
tail: EnumerationValues.Aux[B, C]
): EnumerationValues.Aux[A *: B, C] =
EnumerationValues(values = head.toNonEmptyList.concatNel(tail.toNonEmptyList))

inline given last[A <: Singleton, B >: A]: EnumerationValues.Aux[A *: EmptyTuple, B] =
inline given singleton[A <: Singleton]: EnumerationValues.Aux[A, A] =
EnumerationValues(values = NonEmptyList.one(valueOf[A]))

inline given product1[A](using
values: EnumerationValues.Aux[A, A]
): EnumerationValues.Aux[A *: EmptyTuple, A *: EmptyTuple] =
EnumerationValues(values = values.toNonEmptyList.map(_ *: EmptyTuple))

inline given productN[A, B <: Tuple](using
head: EnumerationValues.Aux[A, A],
tail: EnumerationValues.Aux[B, B]
): EnumerationValues.Aux[A *: B, A *: B] =
EnumerationValues(values = head.toNonEmptyList.flatMap(head => tail.toNonEmptyList.map(head *: _)))

def valuesOf[A](using values: EnumerationValues.Aux[A, A]): NonEmptyList[A] = values.toNonEmptyList
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,15 @@ package io.taig.enumeration.ext
import cats.data.NonEmptyList
import munit.FunSuite

import scala.deriving.Mirror

final class EnumerationValuesTest extends FunSuite:
test("singleton"):
assertEquals(
obtained = valuesOf["foo"],
expected = NonEmptyList.of("foo")
)

test("enum"):
enum Animal:
case Bird
Expand Down Expand Up @@ -31,20 +39,61 @@ final class EnumerationValuesTest extends FunSuite:
assertEquals(obtained = valuesOf[Animal], expected = NonEmptyList.of(Animal.Bird, Animal.Cat, Animal.Dog))

test("nested"):
sealed abstract class Foobar
object Foobar:
enum Foo extends Foobar:
sealed abstract class Nested
object Nested:
enum Foo extends Nested:
case A
case B
case C

sealed abstract class Bar extends Foobar
sealed abstract class Bar extends Nested
object Bar:
case object X extends Bar
case object Y extends Bar
case object Z extends Bar

assertEquals(
obtained = valuesOf[Nested],
expected = NonEmptyList.of(Nested.Foo.A, Nested.Foo.B, Nested.Foo.C, Nested.Bar.X, Nested.Bar.Y, Nested.Bar.Z)
)

test("tuple"):
enum Foo:
case A
case B

enum Bar:
case A
case B

assertEquals(
obtained = valuesOf[(Foo, Bar)],
expected = NonEmptyList.of(
(Foo.A, Bar.A),
(Foo.A, Bar.B),
(Foo.B, Bar.A),
(Foo.B, Bar.B)
)
)

test("case class"):
final case class Foobar(foo: Foobar.Foo, bar: Foobar.Bar)

object Foobar:
enum Foo:
case A
case B

enum Bar:
case A
case B

assertEquals(
obtained = valuesOf[Foobar],
expected = NonEmptyList.of(Foobar.Foo.A, Foobar.Foo.B, Foobar.Foo.C, Foobar.Bar.X, Foobar.Bar.Y, Foobar.Bar.Z)
expected = NonEmptyList.of(
Foobar(Foobar.Foo.A, Foobar.Bar.A),
Foobar(Foobar.Foo.A, Foobar.Bar.B),
Foobar(Foobar.Foo.B, Foobar.Bar.A),
Foobar(Foobar.Foo.B, Foobar.Bar.B)
)
)

0 comments on commit 779035e

Please sign in to comment.