Skip to content

Commit

Permalink
Support for annotations (#247)
Browse files Browse the repository at this point in the history
  • Loading branch information
chengluyu authored Jan 9, 2025
1 parent e9e1354 commit 1b5d920
Show file tree
Hide file tree
Showing 14 changed files with 433 additions and 81 deletions.
10 changes: 5 additions & 5 deletions hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL, scope: Scope):
val cr = freshVar(new TempSymbol(S(unq), "ctx"))
constrain(tryMkMono(ty, body), BbCtx.codeTy(tv, cr))
(tv, cr, eff)
case blk @ Term.Blk(LetDecl(sym) :: DefineVar(sym2, rhs) :: Nil, body) if sym2 is sym => // TODO: more than one!!
case blk @ Term.Blk(LetDecl(sym, _) :: DefineVar(sym2, rhs) :: Nil, body) if sym2 is sym => // TODO: more than one!!
val (rhsTy, rhsCtx, rhsEff) = typeCode(rhs)(using ctx)
val nestCtx = ctx.nextLevel
given BbCtx = nestCtx
Expand Down Expand Up @@ -414,19 +414,19 @@ class BBTyper(using elState: Elaborator.State, tl: TL, scope: Scope):
case (term: Term) :: stats =>
effBuff += typeCheck(term)._2
goStats(stats)
case LetDecl(sym) :: DefineVar(sym2, rhs) :: stats =>
case LetDecl(sym, _) :: DefineVar(sym2, rhs) :: stats =>
require(sym2 is sym)
val (rhsTy, eff) = typeCheck(rhs)
effBuff += eff
ctx += sym -> rhsTy
goStats(stats)
case TermDefinition(_, Fun, sym, ps :: Nil, sig, Some(body), _, _) :: stats =>
case TermDefinition(_, Fun, sym, ps :: Nil, sig, Some(body), _, _, _) :: stats =>
typeFunDef(sym, Term.Lam(ps, body), sig, ctx)
goStats(stats)
case TermDefinition(_, Fun, sym, Nil, sig, Some(body), _, _) :: stats =>
case TermDefinition(_, Fun, sym, Nil, sig, Some(body), _, _, _) :: stats =>
typeFunDef(sym, body, sig, ctx) // * may be a case expressions
goStats(stats)
case TermDefinition(_, Fun, sym, _, S(sig), None, _, _) :: stats =>
case TermDefinition(_, Fun, sym, _, S(sig), None, _, _, _) :: stats =>
ctx += sym -> typeType(sig)
goStats(stats)
case (clsDef: ClassDef) :: stats =>
Expand Down
23 changes: 21 additions & 2 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ class Lowering(using TL, Raise, Elaborator.State):
case st.Blk((d: Declaration) :: stats, res) =>
d match
case td: TermDefinition =>
reportAnnotations(td, td.annotations)
td.body match
case N => // abstract declarations have no lowering
term(st.Blk(stats, res))(k)
Expand All @@ -120,12 +121,15 @@ class Lowering(using TL, Raise, Elaborator.State):
term(st.Blk(stats, res))(k))
// case cls: ClassDef =>
case cls: ClassLikeDef =>
reportAnnotations(cls, cls.annotations)
val bodBlk = cls.body.blk
val (mtds, rest1) = bodBlk.stats.partitionMap:
case td: TermDefinition if td.k is syntax.Fun => L(td)
case s => R(s)
val (privateFlds, rest2) = rest1.partitionMap:
case LetDecl(sym: TermSymbol) => L(sym)
case decl @ LetDecl(sym: TermSymbol, annotations) =>
reportAnnotations(decl, annotations)
L(sym)
case s => R(s)
val publicFlds = rest2.collect:
case td @ TermDefinition(k = (_: syntax.Val)) => td
Expand All @@ -148,7 +152,8 @@ class Lowering(using TL, Raise, Elaborator.State):
case _ =>
// TODO handle
term(st.Blk(stats, res))(k)
case st.Blk((LetDecl(sym)) :: stats, res) =>
case st.Blk((decl @ LetDecl(sym, annotations)) :: stats, res) =>
reportAnnotations(decl, annotations)
term(st.Blk(stats, res))(k)
case st.Blk((DefineVar(sym, rhs)) :: stats, res) =>
subTerm(rhs): r =>
Expand Down Expand Up @@ -339,6 +344,12 @@ class Lowering(using TL, Raise, Elaborator.State):
subTerm(rhs): value =>
AssignField(ref, Tree.Ident("value"), value, k(value))(N)

case Annotated(prefix, receiver) =>
raise(WarningReport(
msg"This annotation has no effect." -> prefix.toLoc ::
msg"Annotations are not supported on ${receiver.describe} terms." -> receiver.toLoc :: Nil))
term(receiver)(k)

case Error => End("error")

// case _ =>
Expand Down Expand Up @@ -374,6 +385,14 @@ class Lowering(using TL, Raise, Elaborator.State):

def setupFunctionDef(paramLists: List[ParamList], bodyTerm: Term, name: Option[Str])(using Subst): (List[ParamList], Block) =
(paramLists, returnedTerm(bodyTerm))

def reportAnnotations(target: Statement, annotations: Ls[Term]): Unit = if annotations.nonEmpty then
raise(WarningReport(
(msg"This annotation has no effect." -> annotations.foldLeft[Opt[Loc]](N):
case (acc, term) => acc match
case N => term.toLoc
case S(loc) => S(loc ++ term.toLoc)) ::
Nil))


trait LoweringSelSanityChecks
Expand Down
4 changes: 2 additions & 2 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/BlockImpl.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package semantics

import mlscript.utils.*, shorthands.*
import syntax.Tree.*
import hkmc2.syntax.TypeOrTermDef
import hkmc2.syntax.{Annotations, TypeOrTermDef}


trait BlockImpl(using Elaborator.State):
Expand All @@ -14,7 +14,7 @@ trait BlockImpl(using Elaborator.State):
val definedSymbols: Array[Str -> BlockMemberSymbol] =
desugStmts
.flatMap:
case td: syntax.TypeOrTermDef =>
case Annotations(_, td: syntax.TypeOrTermDef) =>
td.name match
case L(_) => Nil
case R(id) =>
Expand Down
97 changes: 63 additions & 34 deletions hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -150,8 +150,8 @@ class Elaborator(val tl: TraceLogger, val wd: os.Path)
extends Importer:
import tl.*

def mkLetBinding(sym: LocalSymbol, rhs: Term): Ls[Statement] =
LetDecl(sym) :: DefineVar(sym, rhs) :: Nil
def mkLetBinding(sym: LocalSymbol, rhs: Term, annotations: Ls[Term]): Ls[Statement] =
LetDecl(sym, annotations) :: DefineVar(sym, rhs) :: Nil

def resolveField(srcTree: Tree, base: Opt[Symbol], nme: Ident): Opt[FieldSymbol] =
base match
Expand Down Expand Up @@ -205,7 +205,7 @@ extends Importer:
val lt = term(lhs)
val sym = TempSymbol(S(lt), "old")
Term.Blk(
LetDecl(sym) :: DefineVar(sym, lt) :: Nil, Term.Try(Term.Blk(
LetDecl(sym, Nil) :: DefineVar(sym, lt) :: Nil, Term.Try(Term.Blk(
Term.Assgn(lt, term(rhs)) :: Nil,
term(bod),
), Term.Assgn(lt, sym.ref(id))))
Expand Down Expand Up @@ -467,6 +467,13 @@ extends Importer:
case Under() =>
raise(ErrorReport(msg"Illegal position for '_' placeholder." -> tree.toLoc :: Nil))
Term.Error
case Annotated(lhs, rhs) =>
val annotation = lhs match
case App(_: (Ident | SynthSel | Sel), _) | _: (Ident | SynthSel | Sel) => term(lhs)
case _ =>
raise(ErrorReport(msg"Illegal annotation shape." -> lhs.toLoc :: Nil))
Term.Error
Term.Annotated(annotation, term(rhs))
// case _ =>
// ???

Expand Down Expand Up @@ -522,11 +529,25 @@ extends Importer:

// TODO extract this into a separate method
@tailrec
def go(sts: Ls[Tree], acc: Ls[Statement]): Ctxl[(Term.Blk, Ctx)] = sts match
def go(sts: Ls[Tree], annotations: Ls[Term], acc: Ls[Statement]): Ctxl[(Term.Blk, Ctx)] =
/** Call this function when the following term cannot be annotated. */
def reportUnusedAnnotations: Unit = if annotations.nonEmpty then raise:
WarningReport:
msg"This annotation has no effect" -> (annotations.foldLeft[Opt[Loc]](N):
case (acc, ann) => acc match
case N => ann.toLoc
case S(loc) => S(loc ++ ann.toLoc)
) :: (sts.headOption match
case N => msg"A target term is expected at the end of block" -> blk.toLoc.map(_.right)
case S(head) => msg"Annotations are not supported on ${head.describe} terms." -> head.toLoc
) :: Nil
sts match
case Nil =>
reportUnusedAnnotations
val res = unit
(Term.Blk(acc.reverse, res), ctx)
case Open(bod) :: sts =>
reportUnusedAnnotations
bod match
case Jux(bse, Block(sts)) =>
some(bse -> some(sts))
Expand All @@ -537,7 +558,7 @@ extends Importer:
raise(ErrorReport(msg"Illegal 'open' statement shape." -> bod.toLoc :: Nil))
N
match
case N => go(sts, acc)
case N => go(sts, annotations, acc)
case S((base, importedTrees)) =>
base match
case baseId: Ident =>
Expand All @@ -561,14 +582,15 @@ extends Importer:
raise(ErrorReport(msg"Illegal 'open' statement element." -> t.toLoc :: Nil))
Nil
(ctx elem_++ importedNames).givenIn:
go(sts, acc)
go(sts, Nil, acc)
case N =>
raise(ErrorReport(msg"Name not found: ${baseId.name}" -> baseId.toLoc :: Nil))
go(sts, acc)
go(sts, Nil, acc)
case _ =>
raise(ErrorReport(msg"Illegal 'open' statement base." -> base.toLoc :: Nil))
go(sts, acc)
go(sts, Nil, acc)
case (m @ Modified(Keyword.`import`, absLoc, arg)) :: sts =>
reportUnusedAnnotations
val (newCtx, newAcc) = arg match
case Tree.StrLit(path) =>
val stmt = importPath(path)
Expand All @@ -580,7 +602,7 @@ extends Importer:
arg.toLoc :: Nil))
(ctx, acc)
newCtx.givenIn:
go(sts, newAcc)
go(sts, Nil, newAcc)

case (hd @ LetLike(`let`, Apps(id: Ident, tups), rhso, N)) :: sts if id.name.headOption.exists(_.isLower) =>
val sym =
Expand All @@ -590,17 +612,18 @@ extends Importer:
case S(rhs) =>
val rrhs = tups.foldRight(rhs):
Tree.InfixApp(_, Keyword.`=>`, _)
mkLetBinding(sym, term(rrhs)) reverse_::: acc
mkLetBinding(sym, term(rrhs), annotations) reverse_::: acc
case N =>
if tups.nonEmpty then
raise(ErrorReport(msg"Expected a right-hand side for let bindings with parameters" -> hd.toLoc :: Nil))
LetDecl(sym) :: acc
LetDecl(sym, annotations) :: acc
(ctx + (id.name -> sym)) givenIn:
go(sts, newAcc)
go(sts, Nil, newAcc)
case (tree @ LetLike(`let`, lhs, S(rhs), N)) :: sts =>
raise(ErrorReport(msg"Unsupported let binding shape" -> tree.toLoc :: Nil))
go(sts, Term.Error :: acc)
go(sts, Nil, Term.Error :: acc)
case (hd @ Handle(id: Ident, cls: Ident, Block(sts_), N)) :: sts =>
reportUnusedAnnotations
val sym = fieldOrVarSym(HandlerBind, id)
log(s"Processing `handle` statement $id (${sym}) ${ctx.outer}")

Expand All @@ -611,10 +634,10 @@ extends Importer:
case trm => raise(WarningReport(msg"Terms in handler block do nothing" -> trm.toLoc :: Nil))

val tds = elabed.stats.map {
case td @ TermDefinition(owner, Fun, sym, params, sign, body, resSym, flags) =>
case td @ TermDefinition(owner, Fun, sym, params, sign, body, resSym, flags, annotations) =>
params.reverse match
case ParamList(_, value :: Nil, _) :: newParams =>
val newTd = TermDefinition(owner, Fun, sym, newParams.reverse, sign, body, resSym, flags)
val newTd = TermDefinition(owner, Fun, sym, newParams.reverse, sign, body, resSym, flags, annotations)
S(HandlerTermDefinition(value.sym, newTd))
case _ =>
raise(ErrorReport(msg"Handler function is missing resumption parameter" -> td.toLoc :: Nil))
Expand All @@ -627,29 +650,30 @@ extends Importer:

val newAcc = Term.Handle(sym, term(cls), tds) :: acc
ctx + (id.name -> sym) givenIn:
go(sts, newAcc)
go(sts, Nil, newAcc)
case (tree @ Handle(_, _, _, N)) :: sts =>
raise(ErrorReport(msg"Unsupported handle binding shape" -> tree.toLoc :: Nil))
go(sts, Term.Error :: acc)
go(sts, Nil, Term.Error :: acc)

case Def(lhs, rhs) :: sts =>
reportUnusedAnnotations
lhs match
case id: Ident =>
val r = term(rhs)
ctx.get(id.name) match
case S(elem) =>
elem.symbol match
case S(sym: LocalSymbol) => go(sts, DefineVar(sym, r) :: acc)
case S(sym: LocalSymbol) => go(sts, Nil, DefineVar(sym, r) :: acc)
case N =>
// TODO lookup in members? inherited/refined stuff?
raise(ErrorReport(msg"Name not found: ${id.name}" -> id.toLoc :: Nil))
go(sts, Term.Error :: acc)
go(sts, Nil, Term.Error :: acc)
case App(base, args) =>
go(Def(base, InfixApp(args, Keyword.`=>`, rhs)) :: sts, acc)
go(Def(base, InfixApp(args, Keyword.`=>`, rhs)) :: sts, Nil, acc)
case _ =>
raise(ErrorReport(msg"Unrecognized definitional assignment left-hand side: ${lhs.describe}"
-> lhs.toLoc :: Nil)) // TODO BE
go(sts, Term.Error :: acc)
go(sts, Nil, Term.Error :: acc)
case (td @ TermDef(k, nme, rhs)) :: sts =>
log(s"Processing term definition $nme")
td.name match
Expand All @@ -674,7 +698,7 @@ extends Importer:
val b = rhs.map(term(_)(using newCtx))
val r = FlowSymbol(s"‹result of ${sym}")
val tdf = TermDefinition(owner, k, sym, pss, s, b, r,
TermDefFlags.empty.copy(isModMember = isModMember))
TermDefFlags.empty.copy(isModMember = isModMember), annotations)
sym.defn = S(tdf)

// indicates if the function really returns a module
Expand Down Expand Up @@ -705,10 +729,11 @@ extends Importer:
case _ => ()

tdf
go(sts, tdf :: acc)
go(sts, Nil, tdf :: acc)
case L(d) =>
reportUnusedAnnotations
raise(d)
go(sts, acc)
go(sts, Nil, acc)
case (td @ TypeDef(k, head, extension, body)) :: sts =>
assert((k is Als) || (k is Cls) || (k is Mod) || (k is Obj) || (k is Pat), k)
val nme = td.name match
Expand Down Expand Up @@ -759,7 +784,7 @@ extends Importer:
assert(body.isEmpty)
val d =
given Ctx = newCtx
semantics.TypeDef(alsSym, tps, extension.map(term(_)), N)
semantics.TypeDef(alsSym, tps, extension.map(term(_)), N, annotations)
alsSym.defn = S(d)
d
case Pat =>
Expand All @@ -770,7 +795,7 @@ extends Importer:
log(s"pattern body is ${td.extension}")
val translate = new ucs.Translator(this)
val bod = translate(ps.map(_.params).getOrElse(Nil), td.extension.getOrElse(die))
val pd = PatternDef(owner, patSym, tps, ps, ObjBody(Term.Blk(bod, Term.Lit(UnitLit(true)))))
val pd = PatternDef(owner, patSym, tps, ps, ObjBody(Term.Blk(bod, Term.Lit(UnitLit(true)))), annotations)
patSym.defn = S(pd)
pd
case k: (Mod.type | Obj.type) =>
Expand All @@ -784,7 +809,7 @@ extends Importer:
// case S(t) => block(t :: Nil)
case S(t) => ???
case N => (new Term.Blk(Nil, Term.Lit(UnitLit(true))), ctx)
ModuleDef(owner, clsSym, tps, ps, k, ObjBody(bod))
ModuleDef(owner, clsSym, tps, ps, k, ObjBody(bod), annotations)
clsSym.defn = S(cd)
cd
case Cls =>
Expand All @@ -798,28 +823,32 @@ extends Importer:
// case S(t) => block(t :: Nil)
case S(t) => ???
case N => (new Term.Blk(Nil, Term.Lit(UnitLit(true))), ctx)
ClassDef(owner, Cls, clsSym, tps, ps, ObjBody(bod))
ClassDef(owner, Cls, clsSym, tps, ps, ObjBody(bod), annotations)
clsSym.defn = S(cd)
cd
go(sts, defn :: acc)
go(sts, Nil, defn :: acc)

case Modified(Keyword.`abstract`, absLoc, body) :: sts =>
???
// TODO: pass abstract to `go`
go(body :: sts, acc)
go(body :: sts, annotations, acc)
case Modified(Keyword.`declare`, absLoc, body) :: sts =>
// TODO: pass declare to `go`
go(body :: sts, acc)
go(body :: sts, annotations, acc)
case Annotated(annotation, target) :: sts =>
go(target :: sts, annotations :+ term(annotation), acc)
case (result: Tree) :: Nil =>
reportUnusedAnnotations
val res = term(result)
(Term.Blk(acc.reverse, res), ctx)
case (st: Tree) :: sts =>
reportUnusedAnnotations
val res = term(st) // TODO reject plain term statements? Currently, `(1, 2)` is allowed to elaborate (tho it should be rejected in type checking later)
go(sts, res :: acc)
go(sts, Nil, res :: acc)
end go

c.withMembers(members, c.outer).givenIn:
go(blk.desugStmts, Nil)
go(blk.desugStmts, Nil, Nil)


def fieldOrVarSym(k: TermDefKind, id: Ident)(using Ctx): LocalSymbol & NamedSymbol =
Expand Down Expand Up @@ -902,7 +931,7 @@ extends Importer:
def computeVariances(s: Statement): Unit =
val trav = VarianceTraverser()
def go(s: Statement): Unit = s match
case TermDefinition(_, k, sym, pss, sign, body, r, _) =>
case TermDefinition(_, k, sym, pss, sign, body, r, _, _) =>
pss.foreach(ps => ps.params.foreach(trav.traverseType(S(false))))
sign.foreach(trav.traverseType(S(true)))
body match
Expand Down
Loading

0 comments on commit 1b5d920

Please sign in to comment.