Skip to content

Commit

Permalink
Implement rudimentary PreTyper and a prototype of new UCS desugarer
Browse files Browse the repository at this point in the history
  • Loading branch information
chengluyu committed Nov 23, 2023
1 parent 8894b74 commit c51e5f7
Show file tree
Hide file tree
Showing 41 changed files with 1,771 additions and 33 deletions.
15 changes: 15 additions & 0 deletions shared/src/main/scala/mlscript/JSBackend.scala
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ class JSBackend(allowUnresolvedSymbols: Boolean) {
JSRecord(fields map { case (key, Fld(_, value)) =>
key.name -> translateTerm(value)
})
case Sel(receiver, JSBackend.TupleIndex(n)) =>
JSField(translateTerm(receiver), n.toString)
case Sel(receiver, fieldName) =>
JSField(translateTerm(receiver), fieldName.name)
// Turn let into an IIFE.
Expand Down Expand Up @@ -1576,4 +1578,17 @@ object JSBackend {

def isSafeInteger(value: BigInt): Boolean =
MinimalSafeInteger <= value && value <= MaximalSafeInteger

// Temporary measurement until we adopt the new tuple index.
object TupleIndex {
def unapply(fieldName: Var): Opt[Int] = {
val name = fieldName.name
if (name.startsWith("_") && name.forall(_.isDigit))
name.drop(1).toIntOption match {
case S(n) if n > 0 => S(n - 1)
case _ => N
}
else N
}
}
}
7 changes: 5 additions & 2 deletions shared/src/main/scala/mlscript/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1207,8 +1207,11 @@ class Typer(var dbg: Boolean, var verbose: Bool, var explainErrors: Bool, val ne
}
con(s_ty, req, cs_ty)
case elf: If =>
try typeTerm(desugarIf(elf)) catch {
case e: ucs.DesugaringException => err(e.messages)
elf.desugaredTerm match {
case S(desugared) => typeTerm(desugared)
case N => try typeTerm(desugarIf(elf)) catch {
case e: ucs.DesugaringException => err(e.messages)
}
}
case AdtMatchWith(cond, arms) =>
println(s"typed condition term ${cond}")
Expand Down
4 changes: 2 additions & 2 deletions shared/src/main/scala/mlscript/codegen/Helpers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@ object Helpers {
case App(lhs, rhs) => s"App(${inspect(lhs)}, ${inspect(rhs)})"
case Tup(fields) =>
val entries = fields map {
case (S(name), Fld(_, value)) => s"$name: ${inspect(value)}"
case (N, Fld(_, value)) => s"_: ${inspect(value)}"
case (S(name), Fld(_, value)) => s"(S(${inspect(name)}), ${inspect(value)})"
case (N, Fld(_, value)) => s"(N, ${inspect(value)})"
}
s"Tup(${entries mkString ", "})"
case Rcd(fields) =>
Expand Down
24 changes: 24 additions & 0 deletions shared/src/main/scala/mlscript/helpers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,14 @@ trait TypingUnitImpl extends Located { self: TypingUnit =>
}.mkString("{", "; ", "}")
override def toString: String = s"${entities.mkString("; ")}"
lazy val children: List[Located] = entities
def describe: Str = entities.iterator.map {
case term: Term => term.describe
case NuFunDef(S(rec), nme, _, _, _) =>
s"let ${if (rec) "rec " else ""}$nme"
case NuFunDef(N, nme, _, _, _) => s"fun $nme"
case typ: NuTypeDef => typ.describe
case other => "?"
}.mkString("{", "; ", "}")
}

trait ConstructorImpl { self: Constructor =>
Expand All @@ -493,6 +501,7 @@ trait TypeNameImpl extends Ordered[TypeName] { self: TypeName =>
def targs: Ls[Type] = Nil
def compare(that: TypeName): Int = this.name compare that.name
lazy val toVar: Var = Var(name).withLocOf(this)
var symbol: Opt[pretyper.TypeSymbol] = N
}

trait FldImpl extends Located { self: Fld =>
Expand Down Expand Up @@ -727,6 +736,21 @@ trait VarImpl { self: Var =>
(name.head.isLetter && name.head.isLower || name.head === '_' || name.head === '$') && name =/= "true" && name =/= "false"
def toVar: Var = this
var uid: Opt[Int] = N

// PreTyper additions
import pretyper.{Symbol}

private var _symbol: Opt[Symbol] = N
def symbolOption: Opt[Symbol] = _symbol
def symbol: Symbol = _symbol.getOrElse(???)
def symbol_=(symbol: Symbol): Unit =
_symbol match {
case N => _symbol = S(symbol)
case S(_) => ???
}
// TODO: Remove this methods if they are useless.
// def withSymbol: Var = { symbol = S(new ValueSymbol(this, false)); this }
// def withSymbol(s: TermSymbol): Var = { symbol = S(s); this }
}

trait TupImpl { self: Tup =>
Expand Down
186 changes: 186 additions & 0 deletions shared/src/main/scala/mlscript/pretyper/PreTyper.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
package mlscript.pretyper

import mlscript.ucs.DesugarUCS
import mlscript._, utils._, shorthands._
import mlscript.codegen.Helpers.inspect

class PreTyper(override val debugLevel: Opt[Int], useNewDefs: Bool) extends Traceable with DesugarUCS {
private def extractParameters(fields: Term): Ls[ValueSymbol] = fields match {
case Tup(arguments) =>
if (useNewDefs) {
arguments.map {
case (S(nme: Var), Fld(_, _)) => new ValueSymbol(nme, false)
case (_, Fld(_, nme: Var)) => new ValueSymbol(nme, false)
case (_, Fld(_, x)) => println(x.toString); ???
}
} else {
arguments.map {
case (_, Fld(_, nme: Var)) => new ValueSymbol(nme, false)
case (_, Fld(_, x)) => println(x.toString); ???
}
}
case PlainTup(arguments @ _*) =>
arguments.map {
case nme: Var => new ValueSymbol(nme, false)
case other => println("Unknown parameters: " + inspect(other)); ??? // TODO: bad
}.toList
case other => println("Unknown parameters: " + inspect(other)); ??? // TODO: bad
}

// `visitIf` is meaningless because it represents patterns with terms.

protected def resolveVar(v: Var)(implicit scope: Scope): Unit =
trace(s"resolveVar(name = \"$v\")") {
scope.get(v.name) match {
case Some(sym: ValueSymbol) =>
println(s"Resolve variable $v to a value.", 2)
v.symbol = sym
case Some(sym: SubValueSymbol) =>
println(s"Resolve variable $v to a value.", 2)
v.symbol = sym
case Some(sym: FunctionSymbol) =>
println(s"Resolve variable $v to a function.", 2)
v.symbol = sym
case Some(sym: TypeSymbol) =>
if (sym.defn.kind == Cls) {
println(s"Resolve variable $v to a class.", 2)
v.symbol = sym
} else {
throw new Exception(s"Name $v refers to a type")
}
case None => throw new Exception(s"Variable $v not found in scope")
}
}()

protected def visitVar(v: Var)(implicit scope: Scope): Unit =
trace(s"visitVar(name = \"$v\")") {
v.symbolOption match {
case N => resolveVar(v)
case S(symbol) => scope.get(v.name) match {
case S(other) if other === symbol => ()
case S(other) => throw new Exception(s"Variable $v refers to a different symbol")
case N => throw new Exception(s"Variable $v not found in scope. It is possibly a free variable.")
}
}
}()

protected def visitTerm(term: Term)(implicit scope: Scope): Unit =
trace(s"visitTerm <== ${shortName(term)}") {
term match {
case Assign(lhs, rhs) => visitTerm(lhs); visitTerm(rhs)
case Bra(_, trm) => visitTerm(trm)
case Lam(lhs, rhs) =>
visitTerm(rhs)(scope ++ extractParameters(lhs))
case Sel(receiver, fieldName) => visitTerm(receiver)
case Let(isRec, nme, rhs, body) =>
visitTerm(rhs)
visitTerm(body)(scope + new ValueSymbol(nme, false))
case New(head, body) =>
case Tup(fields) => fields.foreach { case (_, Fld(_, t)) => visitTerm(t) }
case Asc(trm, ty) => visitTerm(trm)
case ef @ If(_, _) => visitIf(ef)(scope)
case TyApp(lhs, targs) => // TODO: When?
case Eqn(lhs, rhs) => ??? // TODO: How?
case Blk(stmts) => stmts.foreach {
case t: Term => visitTerm(t)
case _ => ??? // TODO: When?
}
case Subs(arr, idx) => visitTerm(arr); visitTerm(idx)
case Bind(lhs, rhs) => visitTerm(lhs); visitTerm(rhs)
case Splc(fields) => fields.foreach {
case L(t) => visitTerm(t)
case R(Fld(_, t)) => visitTerm(t)
}
case Forall(params, body) => ??? // TODO: When?
case Rcd(fields) => fields.foreach { case (_, Fld(_, t)) => visitTerm(t) }
case CaseOf(trm, cases) =>
case With(trm, fields) => visitTerm(trm); visitTerm(fields)
case Where(body, where) => ??? // TODO: When?
case App(lhs, rhs) => visitTerm(lhs); visitTerm(rhs)
case Test(trm, ty) => visitTerm(trm)
case _: Lit | _: Super => ()
case v: Var => visitVar(v)
case AdtMatchWith(cond, arms) => ??? // TODO: How?
case Inst(body) => visitTerm(body)
}
}(_ => s"visitTerm ==> ${shortName(term)}")

private def visitNuTypeDef(symbol: TypeSymbol, defn: NuTypeDef)(implicit scope: Scope): Unit =
trace(s"visitNuTypeDef <== ${defn.kind} ${defn.nme.name}") {
visitTypingUnit(defn.body, defn.nme.name, scope)
()
}(_ => s"visitNuTypeDef <== ${defn.kind} ${defn.nme.name}")

private def visitFunction(symbol: FunctionSymbol, defn: NuFunDef)(implicit scope: Scope): Unit =
trace(s"visitFunction <== ${defn.nme.name}") {
defn.rhs match {
case Left(term) =>
val subScope = if (defn.isLetRec == S(false)) scope else scope + symbol
visitTerm(term)(subScope)
case Right(value) => ()
}
}(_ => s"visitFunction ==> ${defn.nme.name}")

private def visitLetBinding(symbol: ValueSymbol, rec: Bool, rhs: Term)(implicit scope: Scope): Unit =
trace(s"visitLetBinding(rec = $rec, ${symbol.name})") {

}()

private def visitTypingUnit(typingUnit: TypingUnit, name: Str, parentScope: Scope): (Scope, TypeContents) =
trace(s"visitTypingUnit <== $name: ${typingUnit.describe}") {
import mlscript.{Cls, Trt, Mxn, Als, Mod}
// Pass 1: Build a scope with hoisted symbols.
val hoistedScope = typingUnit.entities.foldLeft(parentScope.derive) {
case (acc, _: Term) => acc // Skip
case (acc, defn: NuTypeDef) =>
val `var` = Var(defn.nme.name).withLoc(defn.nme.toLoc)
// Create a type symbol but do not visit its inner members
acc ++ (new TypeSymbol(defn.nme, defn) ::
(defn.kind match {
case Mod => new ValueSymbol(`var`, true) :: Nil
case Als | Cls | Mxn | Trt => Nil
}))
case (acc, defn: NuFunDef) if defn.isLetRec.isEmpty =>
acc + new FunctionSymbol(defn.nme, defn)
case (acc, _: NuFunDef) => acc
case (acc, _: Constructor | _: DataDefn | _: DatatypeDefn | _: Def | _: LetS | _: TypeDef) => ??? // TODO: When?
}
println(hoistedScope.symbols.map(_.name).mkString("1. scope = {", ", ", "}"))
// Pass 2: Visit non-hoisted and build a complete scope.
val completeScope = typingUnit.entities.foldLeft[Scope](hoistedScope) {
case (acc, term: Term) => visitTerm(term)(acc); acc
case (acc, defn: NuTypeDef) => acc
case (acc, defn @ NuFunDef(Some(rec), nme, _, _, L(rhs))) =>
val symbol = new ValueSymbol(defn.nme, true)
val scopeWithVar = acc + symbol
visitLetBinding(symbol, rec, rhs)(if (rec) { scopeWithVar } else { acc })
scopeWithVar
case (acc, _: NuFunDef) => acc
case (acc, _: Constructor | _: DataDefn | _: DatatypeDefn | _: Def | _: LetS | _: TypeDef) => ??? // TODO: When?
}
println(hoistedScope.symbols.map(_.name).mkString("2. scope = {", ", ", "}"))
import pretyper.TypeSymbol
// Pass 3: Visit hoisted symbols.
completeScope.symbols.foreach {
case symbol: TypeSymbol =>
val innerScope = symbol.defn.kind match {
case Cls =>
completeScope.derive ++ (symbol.defn.params match {
case N => Nil
case S(fields) => extractParameters(fields)
})
case Als | Mod | Mxn | Trt => completeScope
}
visitNuTypeDef(symbol, symbol.defn)(innerScope)
case symbol: FunctionSymbol => visitFunction(symbol, symbol.defn)(completeScope)
case _: ValueSymbol => ()
case _: SubValueSymbol => ()
}
(completeScope, new TypeContents)
}({ case (scope, contents) => s"visitTypingUnit ==> ${scope.showLocalSymbols}" })

def process(typingUnit: TypingUnit, scope: Scope, name: Str): (Scope, TypeContents) =
trace(s"process <== $name: ${typingUnit.describe}") {
visitTypingUnit(typingUnit, name, scope)
}({ case (scope, contents) => s"process ==> ${scope.showLocalSymbols}" })
}
47 changes: 47 additions & 0 deletions shared/src/main/scala/mlscript/pretyper/Scope.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
package mlscript.pretyper

import collection.immutable.Map
import mlscript.utils._, shorthands._
import mlscript.Var

class Scope(val enclosing: Opt[Scope], val entries: Map[String, Symbol]) {
@inline
def get(name: String): Opt[Symbol] = entries.get(name) match {
case Some(sym) => S(sym)
case None => enclosing.fold(N: Opt[Symbol])(_.get(name))
}

@inline
def +(sym: Symbol): Scope = new Scope(S(this), entries + (sym.name -> sym))

@inline
def ++(syms: IterableOnce[Symbol]): Scope =
new Scope(S(this), entries ++ syms.iterator.map(sym => sym.name -> sym))

def withEntries(syms: IterableOnce[Var -> Symbol]): Scope =
new Scope(S(this), entries ++ syms.iterator.map {
case (nme, sym) => nme.name -> sym
})

@inline
def symbols: Iterable[Symbol] = entries.values

def derive: Scope = new Scope(S(this), Map.empty)

def showLocalSymbols: Str = entries.iterator.map(_._1).mkString(", ")
}

object Scope {
def from(symbols: IterableOnce[Symbol]): Scope =
new Scope(N, Map.from(symbols.iterator.map(sym => sym.name -> sym)))

val global: Scope = Scope.from(
"""true,false,document,window,typeof,toString,not,succ,log,discard,negate,
|round,add,sub,mul,div,sqrt,lt,le,gt,ge,slt,sle,sgt,sge,length,concat,eq,
|ne,error,id,if,emptyArray,+,-,*,%,/,<,>,<=,>=,==,===,<>,&&,||"""
.stripMargin
.split(",")
.iterator
.map(name => new ValueSymbol(Var(name), false))
)
}
Loading

0 comments on commit c51e5f7

Please sign in to comment.