Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add defunctionalizer prototype #185

Merged
merged 22 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
302 changes: 192 additions & 110 deletions compiler/shared/main/scala/mlscript/compiler/ClassLifter.scala

Large diffs are not rendered by default.

36 changes: 36 additions & 0 deletions compiler/shared/main/scala/mlscript/compiler/DataType.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
package mlscript.compiler

abstract class DataType

object DataType:
sealed class Singleton(value: Expr.Literal, dataType: DataType) extends DataType:
override def toString(): String = value.toString()

enum Primitive(name: String) extends DataType:
case Integer extends Primitive("int")
case Decimal extends Primitive("real")
case Boolean extends Primitive("bool")
case String extends Primitive("str")
override def toString(): String = this.name
end Primitive

sealed case class Tuple(elementTypes: List[DataType]) extends DataType:
override def toString(): String = elementTypes.mkString("(", ", ", ")")

sealed case class Class(declaration: Item.TypeDecl) extends DataType:
override def toString(): String = s"class ${declaration.name.name}"

sealed case class Function(parameterTypes: List[DataType], returnType: DataType) extends DataType:
def this(returnType: DataType, parameterTypes: DataType*) =
this(parameterTypes.toList, returnType)
override def toString(): String =
val parameterList = parameterTypes.mkString("(", ", ", ")")
s"$parameterList -> $returnType"

sealed case class Record(fields: Map[String, DataType]) extends DataType:
def this(fields: (String, DataType)*) = this(Map.from(fields))
override def toString(): String =
fields.iterator.map { (name, ty) => s"$name: $ty" }.mkString("{", ", ", "}")

case object Unknown extends DataType:
override def toString(): String = "unknown"
18 changes: 18 additions & 0 deletions compiler/shared/main/scala/mlscript/compiler/DataTypeInferer.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
package mlscript.compiler
import mlscript.compiler.mono.MonomorphError

trait DataTypeInferer:
import DataType._

def findClassByName(name: String): Option[Item.TypeDecl]

def infer(expr: Expr, compatiableType: Option[DataType]): DataType =
expr match
case Expr.Tuple(elements) => DataType.Tuple(elements.map(infer(_, None)))
case lit @ Expr.Literal(value: BigInt) => Singleton(lit, Primitive.Integer)
case lit @ Expr.Literal(value: BigDecimal) => Singleton(lit, Primitive.Decimal)
case lit @ Expr.Literal(value: String) => Singleton(lit, Primitive.String)
case lit @ Expr.Literal(value: Boolean) => Singleton(lit, Primitive.Boolean)
case Expr.Apply(Expr.Ref(name), args) =>
findClassByName(name).fold(DataType.Unknown)(DataType.Class(_))
case _ => throw MonomorphError(s"I can't infer the type of $expr now")
196 changes: 196 additions & 0 deletions compiler/shared/main/scala/mlscript/compiler/Helpers.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,196 @@
package mlscript.compiler

import mlscript.{App, Asc, Assign, Bind, Blk, Bra, CaseOf, Lam, Let, Lit,
New, Rcd, Sel, Subs, Term, Test, Tup, With, Var, Fld, FldFlags, If, PolyType}
import mlscript.{IfBody, IfThen, IfElse, IfLet, IfOpApp, IfOpsApp, IfBlock}
import mlscript.UnitLit
import mlscript.codegen.Helpers.inspect as showStructure
import mlscript.compiler.mono.MonomorphError
import mlscript.NuTypeDef
import mlscript.NuFunDef
import scala.collection.mutable.ArrayBuffer
import mlscript.CaseBranches
import mlscript.Case
import mlscript.NoCases
import mlscript.Wildcard
import mlscript.DecLit
import mlscript.IntLit
import mlscript.StrLit
import mlscript.AppliedType
import mlscript.TypeName
import mlscript.TypeDefKind
import mlscript.compiler.mono.Monomorph

object Helpers:
/**
* Extract parameters for monomorphization from a `Tup`.
*/
def toFuncParams(term: Term): Iterator[Parameter] = term match
case Tup(fields) => fields.iterator.flatMap {
// The new parser emits `Tup(_: UnitLit(true))` from `fun f() = x`.
case (_, Fld(FldFlags(_, _, _), UnitLit(true))) => None
case (None, Fld(FldFlags(_, spec, _), Var(name))) => Some((spec, Expr.Ref(name)))
case (Some(Var(name)), Fld(FldFlags(_, spec, _), _)) => Some((spec, Expr.Ref(name)))
case _ => throw new MonomorphError(
s"only `Var` can be parameters but we meet ${showStructure(term)}"
)
}
case _ => throw MonomorphError("expect the list of parameters to be a `Tup`")

def toFuncArgs(term: Term): IterableOnce[Term] = term match
// The new parser generates `(undefined, )` when no arguments.
// Let's do this temporary fix.
case Tup((_, Fld(FldFlags(_, _, _), UnitLit(true))) :: Nil) => Iterable.empty
case Tup(fields) => fields.iterator.map(_._2.value)
case _ => Some(term)

def term2Expr(term: Term): Expr = {
term match
case Var(name) => Expr.Ref(name)
case Lam(lhs, rhs) =>
val params = toFuncParams(lhs).toList
Expr.Lambda(params, term2Expr(rhs))
case App(App(Var("=>"), Bra(false, args: Tup)), body) =>
val params = toFuncParams(args).toList
Expr.Lambda(params, term2Expr(body))
case App(App(Var("."), self), App(Var(method), args: Tup)) =>
Expr.Apply(Expr.Select(term2Expr(self), Expr.Ref(method)), List.from(toFuncArgs(args).map(term2Expr)))
case App(lhs, rhs) =>
val callee = term2Expr(lhs)
val arguments = toFuncArgs(rhs).map(term2Expr).toList
Expr.Apply(callee, arguments)
case Tup(fields) =>
Expr.Tuple(fields.map {
case (_, Fld(FldFlags(mut, spec, genGetter), value)) => term2Expr(value)
})
case Rcd(fields) =>
Expr.Record(fields.map {
case (name, Fld(FldFlags(mut, spec, genGetter), value)) => (Expr.Ref(name.name), term2Expr(value))
})
case Sel(receiver, fieldName) =>
Expr.Select(term2Expr(receiver), Expr.Ref(fieldName.name))
case Let(rec, Var(name), rhs, body) =>
val exprRhs = term2Expr(rhs)
val exprBody = term2Expr(body)
Expr.LetIn(rec, Expr.Ref(name), exprRhs, exprBody)
case Blk(stmts) => Expr.Block(stmts.flatMap[Expr | Item.FuncDecl | Item.FuncDefn] {
case term: Term => Some(term2Expr(term))
case tyDef: NuTypeDef => ???
case funDef: NuFunDef =>
val NuFunDef(_, nme, sn, targs, rhs) = funDef
val ret: Item.FuncDecl | Item.FuncDefn = rhs match
case Left(Lam(params, body)) =>
Item.FuncDecl(Expr.Ref(nme.name), toFuncParams(params).toList, term2Expr(body))
case Left(body: Term) => Item.FuncDecl(Expr.Ref(nme.name), Nil, term2Expr(body))
case Right(tp) => Item.FuncDefn(Expr.Ref(nme.name), targs, PolyType(Nil, tp)) //TODO: Check correctness in Type -> Polytype conversion
Some(ret)
case mlscript.DataDefn(_) => throw MonomorphError("unsupported DataDefn")
case mlscript.DatatypeDefn(_, _) => throw MonomorphError("unsupported DatatypeDefn")
case mlscript.TypeDef(_, _, _, _, _, _, _, _) => throw MonomorphError("unsupported TypeDef")
case mlscript.Def(_, _, _, _) => throw MonomorphError("unsupported Def")
case mlscript.LetS(_, _, _) => throw MonomorphError("unsupported LetS")
case mlscript.Constructor(_, _) => throw MonomorphError("unsupported Constructor")
})
case Bra(rcd, term) => term2Expr(term)
case Asc(term, ty) => Expr.As(term2Expr(term), ty)
case _: Bind => throw MonomorphError("cannot monomorphize `Bind`")
case _: Test => throw MonomorphError("cannot monomorphize `Test`")
case With(term, Rcd(fields)) =>
Expr.With(term2Expr(term), Expr.Record(fields.map {
case (name, Fld(FldFlags(mut, spec, getGetter), value)) => (Expr.Ref(name.name), term2Expr(term))
}))
case CaseOf(term, cases) =>
def rec(bra: CaseBranches)(using buffer: ArrayBuffer[CaseBranch]): Unit = bra match
case Case(pat, body, rest) =>
val newCase = pat match
case Var(name) => CaseBranch.Instance(Expr.Ref(name), Expr.Ref("_"), term2Expr(body))
case DecLit(value) => CaseBranch.Constant(Expr.Literal(value), term2Expr(body))
case IntLit(value) => CaseBranch.Constant(Expr.Literal(value), term2Expr(body))
case StrLit(value) => CaseBranch.Constant(Expr.Literal(value), term2Expr(body))
case UnitLit(undefinedOrNull) => CaseBranch.Constant(Expr.Literal(UnitValue.Undefined), term2Expr(body))
buffer.addOne(newCase)
rec(rest)
case NoCases => ()
case Wildcard(body) =>
buffer.addOne(CaseBranch.Wildcard(term2Expr(body)))
val branchBuffer = ArrayBuffer[CaseBranch]()
rec(cases)(using branchBuffer)
Expr.Match(term2Expr(term), branchBuffer)

case Subs(array, index) =>
Expr.Subscript(term2Expr(array), term2Expr(index))
case Assign(lhs, rhs) =>
Expr.Assign(term2Expr(lhs), term2Expr(rhs))
case New(None, body) =>
???
case New(Some((constructor, args)), body) =>
val typeName = constructor match
case AppliedType(TypeName(name), _) => name
case TypeName(name) => name
Expr.New(TypeName(typeName), toFuncArgs(args).map(term2Expr).toList)
// case Blk(unit) => Expr.Isolated(trans2Expr(TypingUnit(unit)))
case If(body, alternate) => body match
case IfThen(condition, consequent) =>
Expr.IfThenElse(
term2Expr(condition),
term2Expr(consequent),
alternate.map(term2Expr)
)
case term: IfElse => throw MonomorphError("unsupported IfElse")
case term: IfLet => throw MonomorphError("unsupported IfLet")
case term: IfOpApp => throw MonomorphError("unsupported IfOpApp")
case term: IfOpsApp => throw MonomorphError("unsupported IfOpsApp")
case term: IfBlock => throw MonomorphError("unsupported IfBlock")
case IntLit(value) => Expr.Literal(value)
case DecLit(value) => Expr.Literal(value)
case StrLit(value) => Expr.Literal(value)
case UnitLit(undefinedOrNull) =>
Expr.Literal(if undefinedOrNull
then UnitValue.Undefined
else UnitValue.Null)
case _ => throw MonomorphError("unsupported term"+ term.toString)
}

def func2Item(funDef: NuFunDef): Item.FuncDecl | Item.FuncDefn =
val NuFunDef(_, nme, sn, targs, rhs) = funDef
rhs match
case Left(Lam(params, body)) =>
Item.FuncDecl(Expr.Ref(nme.name), toFuncParams(params).toList, term2Expr(body))
case Left(body: Term) => Item.FuncDecl(Expr.Ref(nme.name), Nil, term2Expr(body))
case Right(tp) => Item.FuncDefn(Expr.Ref(nme.name), targs, PolyType(Nil, tp)) //TODO: Check correctness in Type -> Polytype conversion

def type2Item(tyDef: NuTypeDef): Item.TypeDecl =
val NuTypeDef(kind, className, tparams, params, _, _, parents, _, _, body) = tyDef
val isolation = Isolation(body.entities.flatMap {
// Question: Will there be pure terms in class body?
case term: Term =>
Some(term2Expr(term))
case subTypeDef: NuTypeDef => ???
case subFunDef: NuFunDef =>
Some(func2Item(subFunDef))
case term => throw MonomorphError(term.toString)
})
val typeDecl: Item.TypeDecl = Item.TypeDecl(
Expr.Ref(className.name), // name
kind, // kind
tparams.map(_._2), // typeParams
toFuncParams(params.getOrElse(Tup(Nil))).toList, // params
parents.map {
case Var(name) => (TypeName(name), Nil)
case App(Var(name), args) => (TypeName(name), term2Expr(args) match{
case Expr.Tuple(fields) => fields
case _ => Nil
})
case _ => throw MonomorphError("unsupported parent term")
}, // parents
isolation // body
)
typeDecl

private given Conversion[TypeDefKind, TypeDeclKind] with
import mlscript.{Als, Cls, Trt}
def apply(kind: TypeDefKind): TypeDeclKind = kind match
case Als => TypeDeclKind.Alias
case Cls => TypeDeclKind.Class
case Trt => TypeDeclKind.Trait
case _ => ???
Loading