Skip to content

Commit

Permalink
Add new IR forms for effect handlers (#253)
Browse files Browse the repository at this point in the history
Co-authored-by: CAG2Mark <git@markng.com>
  • Loading branch information
AnsonYeung and CAG2Mark authored Dec 21, 2024
1 parent 60191fe commit e37d8f8
Show file tree
Hide file tree
Showing 31 changed files with 286 additions and 185 deletions.
9 changes: 5 additions & 4 deletions compiler/shared/test/diff-ir/cpp/Makefile
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
CXX := g++
CFLAGS := $(CFLAGS) -O3 -Wall -Wextra -std=c++20 -I. -Wno-inconsistent-missing-override -I/opt/homebrew/include
LDFLAGS := $(LDFLAGS) -lmimalloc -lgmp -L/opt/homebrew/lib
CFLAGS += -O3 -Wall -Wextra -std=c++20 -I. -Wno-inconsistent-missing-override -I/opt/homebrew/include
LDFLAGS += -L/opt/homebrew/lib
LDLIBS := -lmimalloc -lgmp
SRC :=
INCLUDES = mlsprelude.h
INCLUDES := mlsprelude.h
DST :=
DEFAULT_TARGET := mls
TARGET := $(or $(DST),$(DEFAULT_TARGET))
Expand All @@ -23,4 +24,4 @@ clean:
auto: $(TARGET)

$(TARGET): $(SRC) $(INCLUDES)
$(CXX) $(CFLAGS) $(LDFLAGS) $(SRC) -o $(TARGET)
$(CXX) $(CFLAGS) $(LDFLAGS) $(SRC) $(LDLIBS) -o $(TARGET)
2 changes: 2 additions & 0 deletions hkmc2/jvm/src/test/scala/hkmc2/JSBackendDiffMaker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker:
val silent = NullaryCommand("silent")
val noSanityCheck = NullaryCommand("noSanityCheck")
val traceJS = NullaryCommand("traceJS")
val handler = NullaryCommand("handler")
val expect = Command("expect"): ln =>
ln.trim

Expand Down Expand Up @@ -78,6 +79,7 @@ abstract class JSBackendDiffMaker extends MLsDiffMaker:
new codegen.Lowering
with codegen.LoweringSelSanityChecks(noSanityCheck.isUnset)
with codegen.LoweringTraceLog(traceJS.isSet)
with codegen.LoweringHandler(handler.isSet)
given Elaborator.Ctx = curCtx
val jsb = new JSBuilder
with JSBuilderArgNumSanityChecks(noSanityCheck.isUnset)
Expand Down
18 changes: 17 additions & 1 deletion hkmc2/shared/src/main/scala/hkmc2/codegen/Block.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ sealed abstract class Block extends Product with AutoLocated:
case Break(_) => Set.empty
case Continue(_) => Set.empty
case Define(defn, rst) => rst.definedVars
case HandleBlock(lhs, res, cls, hdr, bod, rst) => bod.definedVars ++ rst.definedVars + lhs
case HandleBlockReturn(_) => Set.empty
case TryBlock(sub, fin, rst) => sub.definedVars ++ fin.definedVars ++ rst.definedVars
case Label(lbl, bod, rst) => bod.definedVars ++ rst.definedVars

Expand All @@ -56,6 +58,8 @@ sealed abstract class Block extends Product with AutoLocated:
case Begin(sub, rst) => Begin(sub, rst.mapTail(f))
case Assign(lhs, rhs, rst) => Assign(lhs, rhs, rst.mapTail(f))
case Define(defn, rst) => Define(defn, rst.mapTail(f))
case HandleBlock(lhs, res, cls, handlers, body, rest) =>
HandleBlock(lhs, res, cls, handlers.map(h => Handler(h.sym, h.resumeSym, h.params, h.body.mapTail(f))), body.mapTail(f), rest.mapTail(f))
case Match(scrut, arms, dflt, rst) =>
Match(scrut, arms.map(_ -> _.mapTail(f)), dflt.map(_.mapTail(f)), rst.mapTail(f))

Expand Down Expand Up @@ -89,10 +93,13 @@ case class TryBlock(sub: Block, finallyDo: Block, rest: Block) extends Block wit
case class Assign(lhs: Local, rhs: Result, rest: Block) extends Block with ProductWithTail
// case class Assign(lhs: Path, rhs: Result, rest: Block) extends Block with ProductWithTail

case class AssignField(lhs: Path, nme: Tree.Ident, rhs: Result, rest: Block)(symbol: Opt[FieldSymbol]) extends Block with ProductWithTail
case class AssignField(lhs: Path, nme: Tree.Ident, rhs: Result, rest: Block)(val symbol: Opt[FieldSymbol]) extends Block with ProductWithTail

case class Define(defn: Defn, rest: Block) extends Block with ProductWithTail

case class HandleBlock(lhs: Local, res: Local, cls: Path, handlers: Ls[Handler], body: Block, rest: Block) extends Block with ProductWithTail
case class HandleBlockReturn(res: Result) extends BlockTail

sealed abstract class Defn:
val sym: MemberSymbol[?]

Expand All @@ -112,12 +119,21 @@ final case class ValDefn(
final case class ClsLikeDefn(
sym: MemberSymbol[? <: ClassLikeDef],
k: syntax.ClsLikeKind,
parentSym: Opt[Path],
methods: Ls[FunDefn],
privateFields: Ls[TermSymbol],
publicFields: Ls[TermDefinition],
preCtor: Block,
ctor: Block,
) extends Defn

final case class Handler(
sym: BlockMemberSymbol,
resumeSym: LocalSymbol & NamedSymbol,
params: Ls[ParamList],
body: Block,
)

/* Represents either unreachable code (for functions that must return a result)
* or the end of a non-returning function or a REPL block */
case class End(msg: Str = "") extends BlockTail with ProductWithTail
Expand Down
32 changes: 31 additions & 1 deletion hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -129,14 +129,15 @@ class Lowering(using TL, Raise, Elaborator.State):
case s => R(s)
val publicFlds = rest2.collect:
case td @ TermDefinition(k = (_: syntax.Val)) => td
Define(ClsLikeDefn(cls.sym, syntax.Cls,
Define(ClsLikeDefn(cls.sym, syntax.Cls, N,
mtds.flatMap: td =>
td.body.map: bod =>
val (paramLists, bodyBlock) = setupFunctionDef(td.params, bod, S(td.sym.nme))
FunDefn(td.sym, paramLists, bodyBlock)
,
privateFlds,
publicFlds,
End(),
term(Blk(rest2, bodBlk.res))(ImplctRet).mapTail:
case Return(Value.Lit(syntax.Tree.UnitLit(true)), true) => End()
case t => t
Expand Down Expand Up @@ -308,6 +309,13 @@ class Lowering(using TL, Raise, Elaborator.State):
term(finallyDo)(_ => End()),
k(Value.Ref(l))
)

case Handle(lhs, rhs, defs) =>
raise(ErrorReport(
msg"Effect handlers are not enabled" ->
t.toLoc :: Nil,
source = Diagnostic.Source.Compilation))
End("error")

// * BbML-specific cases: t.Cls#field and mutable operations
case SelProj(prefix, _, proj) =>
Expand Down Expand Up @@ -470,3 +478,25 @@ trait LoweringTraceLog
) |>:
Ret(Value.Ref(resSym))
)


trait LoweringHandler
(instrument: Bool)(using TL, Raise, Elaborator.State)
extends Lowering:
override def term(t: st)(k: Result => Block)(using Subst): Block =
if !instrument then return super.term(t)(k)
t match
case st.Blk(Handle(lhs, rhs, defs) :: stmts, res) =>
val handlers = defs.map {
case HandlerTermDefinition(resumeSym, td) => td.body match
case None =>
raise(ErrorReport(msg"Handler function definitions cannot be empty" -> td.toLoc :: Nil))
N
case Some(bod) =>
val (paramLists, bodyBlock) = setupFunctionDef(td.params, bod, S(td.sym.nme))
S(Handler(td.sym, resumeSym, paramLists, bodyBlock))
}.collect{ case Some(v) => v }
val resSym = TempSymbol(S(t))
subTerm(rhs): cls =>
HandleBlock(lhs, resSym, cls, handlers, term(st.Blk(stmts, res))(HandleBlockReturn(_)), k(Value.Ref(resSym)))
case _ => super.term(t)(k)
2 changes: 1 addition & 1 deletion hkmc2/shared/src/main/scala/hkmc2/codegen/Printer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ object Printer:
doc"fun ${sym.nme}${docParams} { #{ # ${docBody} #} # }"
case ValDefn(owner, k, sym, rhs) =>
doc"val ${sym.nme} = ${mkDocument(rhs)}"
case ClsLikeDefn(sym, k, methods, privateFields, publicFields, ctor) =>
case ClsLikeDefn(sym, k, parentSym, methods, privateFields, publicFields, preCtor, ctor) =>
doc"class ${sym.nme} #{ #} "

def mkDocument(arg: Arg)(using Raise, Scope): Document =
Expand Down
Loading

0 comments on commit e37d8f8

Please sign in to comment.