Skip to content

Commit

Permalink
Specify compiled patterns using annotations
Browse files Browse the repository at this point in the history
  • Loading branch information
chengluyu committed Jan 10, 2025
1 parent b757626 commit 70b1e59
Show file tree
Hide file tree
Showing 18 changed files with 199 additions and 942 deletions.
3 changes: 0 additions & 3 deletions hkmc2/jvm/src/test/scala/hkmc2/MLsDiffMaker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,11 @@ abstract class MLsDiffMaker extends DiffMaker:
val showUCS = Command("ucs"): ln =>
ln.split(" ").iterator.map(x => "ucs:" + x.trim).toSet

val compilePatterns = NullaryCommand("cp")

given Elaborator.State = new Elaborator.State:
override def dbg: Bool =
dbgParsing.isSet
|| dbgElab.isSet
|| debug.isSet
override def shouldCompilePatterns: Bool = compilePatterns.isSet

val etl = new TraceLogger:
override def doTrace = dbgElab.isSet || scope.exists:
Expand Down
41 changes: 24 additions & 17 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Desugarer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import hkmc2.syntax.Literal
import Keyword.{as, and, `do`, `else`, is, let, `then`}
import collection.mutable.{HashMap, SortedSet}
import Elaborator.{ctx, Ctxl}
import ucs.DesugaringBase
import ucs.{DesugaringBase, warn, error}

object Desugarer:
extension (op: Keyword.Infix)
Expand Down Expand Up @@ -66,6 +66,8 @@ class Desugarer(val elaborator: Elaborator)
// represents the context with bindings in the current match.

type Sequel = Ctx => Split

type Ctor = SynthSel | Sel | Ident

extension (sequel: Sequel)
def traced(pre: Str, post: Split => Str): Sequel =
Expand Down Expand Up @@ -409,6 +411,22 @@ class Desugarer(val elaborator: Elaborator)
*/
def expandMatch(scrutSymbol: BlockLocalSymbol, pattern: Tree, sequel: Sequel): Split => Sequel =
def ref = scrutSymbol.ref(/* FIXME ident? */)
def dealWithCtorCase(ctor: Ctor, compile: Bool)(fallback: Split): Sequel = ctx =>
val clsTrm = elaborator.cls(ctor, inAppPrefix = false)
clsTrm.symbol.flatMap(_.asClsLike) match
case S(cls: ClassSymbol) =>
if compile then warn(msg"Cannot compile the class `${cls.name}`" -> ctor.toLoc)
Branch(ref, Pattern.ClassLike(cls, clsTrm, N, false)(ctor), sequel(ctx)) ~: fallback
case S(mod: ModuleSymbol) =>
if compile then warn(msg"Cannot compile the module `${mod.name}`" -> ctor.toLoc)
Branch(ref, Pattern.ClassLike(mod, clsTrm, N, false)(ctor), sequel(ctx)) ~: fallback
case S(pat: PatternSymbol) =>
if compile then Branch(ref, Pattern.Synonym(pat, N), sequel(ctx)) ~: fallback
else makeUnapplyBranch(ref, clsTrm, sequel(ctx))(fallback)
case N =>
// Raise an error and discard `sequel`. Use `fallback` instead.
raise(ErrorReport(msg"Cannot use this ${ctor.describe} as a pattern" -> ctor.toLoc :: Nil))
fallback
pattern match
// A single wildcard pattern.
case Under() => _ => ctx => sequel(ctx)
Expand All @@ -423,22 +441,11 @@ class Desugarer(val elaborator: Elaborator)
val aliasSymbol = VarSymbol(id)
val ctxWithAlias = ctx + (nme -> aliasSymbol)
Split.Let(aliasSymbol, ref, sequel(ctxWithAlias) ++ fallback)
case ctor @ (_: Ident | _: SynthSel | _: Sel) => fallback => ctx =>
val clsTrm = elaborator.cls(ctor, inAppPrefix = false)
clsTrm.symbol.flatMap(_.asClsLike) match
case S(cls: ClassSymbol) =>
Branch(ref, Pattern.ClassLike(cls, clsTrm, N, false)(ctor), sequel(ctx)) ~: fallback
case S(cls: ModuleSymbol) =>
Branch(ref, Pattern.ClassLike(cls, clsTrm, N, false)(ctor), sequel(ctx)) ~: fallback
case S(psym: PatternSymbol) =>
if state.shouldCompilePatterns then
Branch(ref, Pattern.Synonym(psym, N), sequel(ctx)) ~: fallback
else
makeUnapplyBranch(ref, clsTrm, sequel(ctx))(fallback)
case N =>
// Raise an error and discard `sequel`. Use `fallback` instead.
raise(ErrorReport(msg"Cannot use this ${ctor.describe} as a pattern" -> ctor.toLoc :: Nil))
fallback
case ctor: Ctor => dealWithCtorCase(ctor, false)
case Annotated(Ident("compile"), ctor: Ctor) => dealWithCtorCase(ctor, true)
case Annotated(annotation, ctor: Ctor) =>
error(msg"Unrecognized annotation on patterns" -> annotation.toLoc)
dealWithCtorCase(ctor, false)
case Tree.Tup(args) => fallback => ctx => trace(
pre = s"expandMatch <<< ${args.mkString(", ")}",
post = (r: Split) => s"expandMatch >>> ${r.showDbg}"
Expand Down
13 changes: 5 additions & 8 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,6 @@ object Elaborator:
))
def dbg: Bool = false
def dbgUid(uid: Uid[Symbol]): Str = if dbg then s"$uid" else ""
def shouldCompilePatterns: Bool = false // TODO: remove after annotations introduced
transparent inline def State(using state: State): State = state

end Elaborator
Expand Down Expand Up @@ -810,13 +809,11 @@ extends Importer:
val owner = ctx.outer
newCtx.nest(S(patSym)).givenIn:
assert(body.isEmpty)
patSym.split = if state.shouldCompilePatterns then
td.extension.map: tree =>
val split = ucs.DeBrujinSplit.elaborate(tree, this)
scoped("ucs:rp:elaborated"):
log(s"elaborated nameless split:\n${split.display}")
split
else None
patSym.split = td.extension.map: tree =>
val split = ucs.DeBrujinSplit.elaborate(tree, this)
scoped("ucs:rp:elaborated"):
log(s"elaborated nameless split:\n${split.display}")
split
log(s"pattern body is ${td.extension}")
val translate = new ucs.Translator(this)
val bod = translate(ps.map(_.params).getOrElse(Nil), td.extension.getOrElse(die))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,16 @@ object DeBrujinSplit:
Branch(_, Literal(IntLit(-n)), _, _)
case App(Ident("-"), Tup(DecLit(n) :: Nil)) =>
Branch(_, Literal(DecLit(-n)), _, _)
// BEGIN TODO: Support range patterns. This is just to suppress the errors.
case (lo: StrLit) to (incl, hi: StrLit) =>
(_, _, alternative) => alternative
case (lo: IntLit) to (incl, hi: IntLit) =>
(_, _, alternative) => alternative
case (lo: DecLit) to (incl, hi: DecLit) =>
(_, _, alternative) => alternative
case (lo: syntax.Literal) to (_, hi: syntax.Literal) =>
(_, _, alternative) => alternative
// END TODO: Support range patterns
case App(ctor: (Ident | Sel), Tup(params)) => cls(ctor, params)
case literal: syntax.Literal => Branch(_, Literal(literal), _, _)
scoped("ucs:rp:elaborate"):
Expand Down Expand Up @@ -149,7 +159,6 @@ extension (branch: DeBrujinSplit.Branch)
case Branch(`scrutinee`, ClassLike(symbol: PatternSymbol), consequence, alternative) =>
val patternSplit = symbol.split.getOrElse:
lastWords(s"found unelaborated pattern: ${symbol.nme}")
// val consequence2 = consequence // TODO: why can't we expand the consequence?
val consequence2 = go(consequence, scrutinee)
val alternative2 = go(alternative, scrutinee)
patternSplit.expand(scrutinee :: Nil, consequence2) ++ alternative2
Expand Down
23 changes: 13 additions & 10 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/Translator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,19 @@ import Split.display, ucs.Normalization
import syntax.{Fun, Keyword, Literal, ParamBind, Tree}, Tree.*, Keyword.`as`
import scala.collection.mutable.{Buffer, Set as MutSet}

object Translator:
/** String range bounds must be single characters. */
def isInvalidStringBounds(lo: StrLit, hi: StrLit)(using Raise): Bool =
val ds = Buffer.empty[(Message, Option[Loc])]
if lo.value.length != 1 then
ds += msg"String range bounds must have only one character." -> lo.toLoc
if hi.value.length != 1 then
ds += msg"String range bounds must have only one character." -> hi.toLoc
if ds.nonEmpty then error(ds.toSeq*)
ds.nonEmpty

import Translator.*

/** This class translates a tree describing a pattern into functions that can
* perform pattern matching on terms described by the pattern.
*/
Expand Down Expand Up @@ -42,16 +55,6 @@ class Translator(val elaborator: Elaborator)
val test2 = app(upperOp.ref(), tup(scrutFld, fld(Term.Lit(hi))), "ltHi")
plainTest(test1, "gtLo")(plainTest(test2, "ltHi")(inner(Map.empty)))

/** String range bounds must be single characters. */
private def isInvalidStringBounds(lo: StrLit, hi: StrLit)(using Raise): Bool =
val ds = Buffer.empty[(Message, Option[Loc])]
if lo.value.length != 1 then
ds += msg"String range bounds must have only one character." -> lo.toLoc
if hi.value.length != 1 then
ds += msg"String range bounds must have only one character." -> hi.toLoc
if ds.nonEmpty then error(ds.toSeq*)
ds.nonEmpty

/** Generate a split that consumes the entire scrutinee. */
private def full(scrut: Scrut, pat: Tree, inner: Inner)(using Raise): Split = trace(
pre = s"full <<< $pat",
Expand Down
4 changes: 2 additions & 2 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/ucs/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@ package hkmc2
package semantics

package object ucs:
private[ucs] def error(msgs: (Message, Option[Loc])*)(using Raise): Unit =
def error(msgs: (Message, Option[Loc])*)(using Raise): Unit =
raise(ErrorReport(msgs.toList))

private[ucs] def warn(msgs: (Message, Option[Loc])*)(using Raise): Unit =
def warn(msgs: (Message, Option[Loc])*)(using Raise): Unit =
raise(WarningReport(msgs.toList))
end ucs
33 changes: 13 additions & 20 deletions hkmc2/shared/src/test/mlscript/rp/Future.mls
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,7 @@ pattern Email(name, domain) =
:todo // View patterns
pattern GreaterThan(value) = case
n and n > value then n
//│ ╔══[ERROR] Unrecognized pattern.
//│ ║ l.29: n and n > value then n
//│ ╙── ^^^^^^^^^^^^^^^^^^^^^^
//│ /!!!\ Uncaught error: java.lang.AssertionError: assertion failed
//│ /!!!\ Uncaught error: scala.MatchError: Case(None,Block(List(InfixApp(InfixApp(Ident(n),keyword 'and',App(Ident(>),Tup(List(Ident(n), Ident(value))))),keyword 'then',Ident(n))))) (of class hkmc2.syntax.Tree$Case)

:todo
// Normal view pattern
Expand All @@ -42,11 +39,11 @@ fun foo(x) = if x is
Unit then ....
Arrow(...) then ....
//│ ╔══[ERROR] Unrecognized pattern split.
//│ ║ l.41: view as
//│ ║ l.38: view as
//│ ║ ^^^^^^^
//│ ║ l.42: Unit then ....
//│ ║ l.39: Unit then ....
//│ ║ ^^^^^^^^^^^^^^^^^^
//│ ║ l.43: Arrow(...) then ....
//│ ║ l.40: Arrow(...) then ....
//│ ╙── ^^^^^^^^^^^^^^^^^^^^^^^^


Expand All @@ -64,16 +61,14 @@ pattern Star(pattern T) = "" | Many(T)
pattern Email(name, domains) =
Rep(Char | ".") as name ~ "@" ~ Rep(Rep(Char) ~ ) as domain
//│ ╔══[PARSE ERROR] Expected end of input; found literal instead
//│ ║ l.61: pattern Char = "a" to "z" | "A" to "Z" | "0" to "9" | "_" | "-"
//│ ║ l.58: pattern Char = "a" to "z" | "A" to "Z" | "0" to "9" | "_" | "-"
//│ ╙── ^^^
//│ ╔══[ERROR] Unrecognized pattern.
//│ ║ l.61: pattern Char = "a" to "z" | "A" to "Z" | "0" to "9" | "_" | "-"
//│ ╙── ^^^^^^
//│ /!!!\ Uncaught error: scala.MatchError: Jux(StrLit(a),Ident(to)) (of class hkmc2.syntax.Tree$Jux)

:todo
pattern Test(foo, bar) = ("foo" as foo) ~ ("bar" as bar)
//│ ╔══[ERROR] Unrecognized pattern.
//│ ║ l.74: pattern Test(foo, bar) = ("foo" as foo) ~ ("bar" as bar)
//│ ║ l.69: pattern Test(foo, bar) = ("foo" as foo) ~ ("bar" as bar)
//│ ╙── ^^^^^^^^^^^^
//│ /!!!\ Uncaught error: java.lang.AssertionError: assertion failed

Expand All @@ -96,7 +91,7 @@ pattern Lines(pattern L) = case
:todo
if input is Lines of Email then
//│ ╔══[PARSE ERROR] Expected start of statement in this position; found end of input instead
//│ ║ l.97: if input is Lines of Email then
//│ ║ l.92: if input is Lines of Email then
//│ ╙── ^
//│ /!!!\ Uncaught error: scala.MatchError: TypeDef(Pat,Ident(L),None,None) (of class hkmc2.syntax.Tree$TypeDef)

Expand All @@ -112,13 +107,13 @@ pattern Email(name, domain) = ...
:todo
if input is Opt(Email, Some((n, d))) then ...
//│ ╔══[ERROR] Name not found: input
//│ ║ l.113: if input is Opt(Email, Some((n, d))) then ...
//│ ║ l.108: if input is Opt(Email, Some((n, d))) then ...
//│ ╙── ^^^^^
//│ ╔══[ERROR] Name not found: Opt
//│ ║ l.113: if input is Opt(Email, Some((n, d))) then ...
//│ ║ l.108: if input is Opt(Email, Some((n, d))) then ...
//│ ╙── ^^^
//│ ╔══[ERROR] Cannot use this identifier as an extractor
//│ ║ l.113: if input is Opt(Email, Some((n, d))) then ...
//│ ║ l.108: if input is Opt(Email, Some((n, d))) then ...
//│ ╙── ^^^

:todo
Expand All @@ -136,8 +131,6 @@ pattern Opt(pattern P) = case
:todo
pattern Digits = "0" to "9" ~ (Digits | "")
//│ ╔══[PARSE ERROR] Expected end of input; found literal instead
//│ ║ l.137: pattern Digits = "0" to "9" ~ (Digits | "")
//│ ║ l.132: pattern Digits = "0" to "9" ~ (Digits | "")
//│ ╙── ^^^
//│ ╔══[ERROR] Unrecognized pattern.
//│ ║ l.137: pattern Digits = "0" to "9" ~ (Digits | "")
//│ ╙── ^^^^^^
//│ /!!!\ Uncaught error: scala.MatchError: Jux(StrLit(0),Ident(to)) (of class hkmc2.syntax.Tree$Jux)
11 changes: 7 additions & 4 deletions hkmc2/shared/src/test/mlscript/rp/RangePatterns.mls
Original file line number Diff line number Diff line change
Expand Up @@ -55,14 +55,17 @@ pattern UnsignedByte = 0..< 256
//│ ╔══[ERROR] Name not found: .<
//│ ║ l.51: pattern UnsignedByte = 0..< 256
//│ ╙── ^^
//│ ╔══[ERROR] Name not found: .<
//│ ║ l.51: pattern UnsignedByte = 0..< 256
//│ ╙── ^^
//│ ╔══[ERROR] Cannot use this identifier as an extractor
//│ ║ l.51: pattern UnsignedByte = 0..< 256
//│ ╙── ^^

:e
pattern BadRange = "s"..=0
//│ ╔══[ERROR] Incompatible range types: string literal to integer literal
//│ ║ l.63: pattern BadRange = "s"..=0
//│ ║ l.66: pattern BadRange = "s"..=0
//│ ╙── ^^^^^^^

// It becomes an absurd pattern.
Expand All @@ -72,14 +75,14 @@ pattern BadRange = "s"..=0
:e
pattern BadRange = 0 ..= "s"
//│ ╔══[ERROR] Incompatible range types: integer literal to string literal
//│ ║ l.73: pattern BadRange = 0 ..= "s"
//│ ║ l.76: pattern BadRange = 0 ..= "s"
//│ ╙── ^^^^^^^^^

:e
pattern BadRange = "yolo" ..= "swag"
//│ ╔══[ERROR] String range bounds must have only one character.
//│ ║ l.79: pattern BadRange = "yolo" ..= "swag"
//│ ║ l.82: pattern BadRange = "yolo" ..= "swag"
//│ ║ ^^^^^^
//│ ╟── String range bounds must have only one character.
//│ ║ l.79: pattern BadRange = "yolo" ..= "swag"
//│ ║ l.82: pattern BadRange = "yolo" ..= "swag"
//│ ╙── ^^^^^^
3 changes: 3 additions & 0 deletions hkmc2/shared/src/test/mlscript/rp/examples/Identifier.mls
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ pattern Digit = "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9"

:ucs rp:elaborated
pattern Lower = "a"..="z"
//│ elaborated nameless split:
//│ > a =>
//│ > reject

:expect true
"a" is Lower
Expand Down
22 changes: 2 additions & 20 deletions hkmc2/shared/src/test/mlscript/rp/nondeterminism/EvenOddTree.mls
Original file line number Diff line number Diff line change
Expand Up @@ -7,29 +7,11 @@ class Pair[A, B](first: A, second: B)
[A, B]
//│ = [ A { class: [class A] }, B { class: [class B] } ]

:cp
pattern OddTree = A | Pair(EvenTree, OddTree) | Pair(OddTree, EvenTree)
pattern EvenTree = B | Pair(EvenTree, EvenTree) | Pair(OddTree, OddTree)

// This does not work for now.

:cp
:ucs desugared rp:normalize rp:expand rp:memo
:todo
A is EvenTree
//│ Desugared:
//│ > if
//│ > let $scrut = globalThis:block#1#666(.)A‹member:A›
//│ > $scrut is EvenTree then true
//│ > else false
//│ FAILURE: Unexpected exception
//│ /!!!\ Uncaught error: java.lang.Exception: Internal Error: found unelaborated pattern: EvenTree
//│ at: mlscript.utils.package$.lastWords(package.scala:230)
//│ at: hkmc2.semantics.ucs.DeBrujinSplit$package$.$anonfun$10(DeBrujinSplit.scala:292)
//│ at: scala.Option.getOrElse(Option.scala:201)
//│ at: hkmc2.semantics.ucs.DeBrujinSplit$package$.go$12$$anonfun$1(DeBrujinSplit.scala:292)
//│ at: hkmc2.utils.TraceLogger.scoped(TraceLogger.scala:37)
//│ at: hkmc2.semantics.ucs.DeBrujinSplit$package$.go$12(DeBrujinSplit.scala:350)
//│ at: hkmc2.semantics.ucs.DeBrujinSplit$package$.go$12$$anonfun$1(DeBrujinSplit.scala:285)
//│ at: hkmc2.utils.TraceLogger.scoped(TraceLogger.scala:37)
//│ at: hkmc2.semantics.ucs.DeBrujinSplit$package$.go$12(DeBrujinSplit.scala:350)
//│ at: hkmc2.semantics.ucs.DeBrujinSplit$package$.normalize(DeBrujinSplit.scala:352)
//│ = false
19 changes: 11 additions & 8 deletions hkmc2/shared/src/test/mlscript/rp/recursion/BitSeq.mls
Original file line number Diff line number Diff line change
@@ -1,32 +1,35 @@
:js
:cp

class Pair[A, B](val first: A, val second: B)

pattern Bit = 0 | 1

:expect true
0 is Bit
0 is @compile Bit
//│ = true

:expect true
1 is Bit
1 is @compile Bit
//│ = true

:expect false
42 is Bit
42 is @compile Bit
//│ = false

pattern BitSeq = null | Pair(Bit, BitSeq)

null is BitSeq
:expect true
null is @compile BitSeq
//│ = true

Pair(0, null) is BitSeq
:expect true
Pair(0, null) is @compile BitSeq
//│ = true

Pair(1, Pair(0, null)) is BitSeq
:expect true
Pair(1, Pair(0, null)) is @compile BitSeq
//│ = true

Pair(2, null) is BitSeq
:expect false
Pair(2, null) is @compile BitSeq
//│ = false
Loading

0 comments on commit 70b1e59

Please sign in to comment.