Skip to content

Commit

Permalink
Clean code
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilKleistGao committed Jan 13, 2025
1 parent a42a346 commit 7fda9e6
Show file tree
Hide file tree
Showing 8 changed files with 39 additions and 51 deletions.
8 changes: 4 additions & 4 deletions hkmc2/shared/src/main/scala/hkmc2/bbml/ConstraintSolver.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class ConstraintSolver(infVarState: InfVarUid.State, elState: Elaborator.State,
import hkmc2.bbml.NormalForm.*

private def freshXVar(lvl: Int, sym: Symbol, hint: Str): InfVar =
InfVar(lvl, infVarState.nextUid, new VarState(), S(false))(InstSymbol(sym)(using elState), hint)
InfVar(lvl, infVarState.nextUid, new VarState(), false)(InstSymbol(sym)(using elState), hint)

def extrude(ty: Type)(using lvl: Int, pol: Bool, cache: ExtrudeCache, bbctx: BbCtx, cctx: CCtx, tl: TL): Type =
trace[Type](s"Extruding[${printPol(pol)}] ${ty.showDbg}", r => s"~> ${r.showDbg}"):
Expand All @@ -50,7 +50,7 @@ class ConstraintSolver(infVarState: InfVarUid.State, elState: Elaborator.State,
Wildcard(extrude(in)(using lvl, !pol), extrude(out))
case t: Type => Wildcard(extrude(t)(using lvl, !pol), extrude(t))
})
case v @ InfVar(_, uid, state, _) if v.isSkolem => // * skolem
case v @ InfVar(_, uid, state, true) => // * skolem
cache.getOrElse(uid -> pol, {
val nv = freshXVar(lvl, v.sym, v.hint)
cache += uid -> pol -> nv
Expand All @@ -60,7 +60,7 @@ class ConstraintSolver(infVarState: InfVarUid.State, elState: Elaborator.State,
constrainImpl(nv, state.lowerBounds.foldLeft[Type](Bot)(_ | _))
nv
})
case v @ InfVar(_, uid, _, _) =>
case v @ InfVar(_, uid, _, false) =>
cache.getOrElse(uid -> pol, {
val nv = freshXVar(lvl, v.sym, v.hint)
cache += uid -> pol -> nv
Expand Down Expand Up @@ -129,7 +129,7 @@ class ConstraintSolver(infVarState: InfVarUid.State, elState: Elaborator.State,
constrainImpl(lhs.posPart, rhs.posPart)

private def inlineSkolemBounds(ty: Type, pol: Bool)(using cache: Set[Uid[InfVar]]): Type = ty.toBasic match
case v @ InfVar(_, uid, state, _) if v.isSkolem && !cache(uid) =>
case v @ InfVar(_, uid, state, skolem) if skolem && !cache(uid) =>
given Set[Uid[InfVar]] = cache + uid
inlineSkolemBounds(if pol then state.upperBounds.foldLeft[Type](v)(_ & _) else state.lowerBounds.foldLeft[Type](v)(_ | _), pol)
case ComposedType(lhs, rhs, p) => ComposedType(inlineSkolemBounds(lhs, pol), inlineSkolemBounds(rhs, pol), p)
Expand Down
2 changes: 1 addition & 1 deletion hkmc2/shared/src/main/scala/hkmc2/bbml/NormalForm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ object Conj:
// * Conj objects cannot be created with `new` except in this file.
// * This is because we want to sort the vars in the apply function.
def apply(i: Inter, u: Union, vars: Ls[(InfVar, Bool)]) = new Conj(i, u, vars.sortWith {
case ((v1 @ InfVar(lv1, _, _, _), _), (v2 @ InfVar(lv2, _, _, _), _)) => !(v1.isSkolem || !v2.isSkolem && lv1 <= lv2)
case ((v1 @ InfVar(lv1, _, _, sk1), _), (v2 @ InfVar(lv2, _, _, sk2), _)) => !(sk1 || !sk2 && lv1 <= lv2)
}){}
lazy val empty: Conj = Conj(Inter.empty, Union.empty, Nil)
def mkVar(v: InfVar, pol: Bool) = Conj(Inter.empty, Union.empty, (v, pol) :: Nil)
Expand Down
5 changes: 4 additions & 1 deletion hkmc2/shared/src/main/scala/hkmc2/bbml/TypeSimplifier.scala
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ class TypeSimplifier(tl: TraceLogger):
super.apply(pol)(ty)
// traversingTVs -= tv
curPath = oldPath
case pt @ PolyType(_, outer, _) => // Avoid simplify outer variables to Top unexpectedly
posVars += outer
negVars += outer
super.apply(pol)(pt)
case _ =>
val oldPath = curPath
pastPathsSet ++= oldPath
Expand Down Expand Up @@ -148,7 +152,6 @@ class TypeSimplifier(tl: TraceLogger):
def subst(ty: GeneralType): GeneralType = trace[GeneralType](s"subst(${ty.showDbg})", r => s"= ${r.showDbg}"):
ty match
case ty if ty.lvl <= lvl => ty // TODO NOPE
case InfVar(_, _, _, N) => ty // Ignore outer variables
case _tv: IV =>
val tv = Analysis.getRepr(_tv)
log(s"Repr: ${tv.showDbg}")
Expand Down
12 changes: 6 additions & 6 deletions hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,9 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
private val solver = new ConstraintSolver(infVarState, elState, tl)

private def freshSkolem(sym: Symbol, hint: Str = "")(using ctx: BbCtx): InfVar =
InfVar(ctx.lvl, infVarState.nextUid, new VarState(), S(true))(sym, hint)
InfVar(ctx.lvl, infVarState.nextUid, new VarState(), true)(sym, hint)
private def freshVar(sym: Symbol, hint: Str = "")(using ctx: BbCtx): InfVar =
InfVar(ctx.lvl, infVarState.nextUid, new VarState(), S(false))(sym, hint)
InfVar(ctx.lvl, infVarState.nextUid, new VarState(), false)(sym, hint)
private def freshWildcard(sym: Symbol)(using ctx: BbCtx) =
val in = freshVar(sym, "-")
val out = freshVar(sym, "+")
Expand All @@ -91,14 +91,14 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
private def freshReg(sym: Symbol)(using ctx: BbCtx) =
val state = new VarState()
state.upperBounds = ctx.getRegEnv.! :: Nil
InfVar(ctx.lvl + 1, infVarState.nextUid, state, S(true))(sym, "")
InfVar(ctx.lvl + 1, infVarState.nextUid, state, true)(sym, "")
private def freshOuter(sym: Symbol)(using ctx: BbCtx): InfVar =
InfVar(ctx.lvl + 1, infVarState.nextUid, new VarState(), N)(sym, "env@")
InfVar(ctx.lvl + 1, infVarState.nextUid, new VarState(), true)(sym, "env@")
private def freshEnv(sym: Symbol)(using ctx: BbCtx): InfVar =
val state = new VarState()
state.upperBounds = ctx.getRegEnv :: Nil
state.lowerBounds = ctx.getRegEnv :: Nil
InfVar(ctx.lvl, infVarState.nextUid, state, S(false))(sym, "")
InfVar(ctx.lvl, infVarState.nextUid, state, false)(sym, "")

private def error(msg: Ls[Message -> Opt[Loc]])(using BbCtx) =
raise(ErrorReport(msg))
Expand Down Expand Up @@ -393,7 +393,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL):
val (ty, ef) = ascribe(f.term, t)
resEff |= ef
(ret, resEff)
case (ft @ FunType(params, ret, eff), lhsEff) => app((PolyFunType(params, ret, eff), lhsEff), rhs, t)
case (FunType(params, ret, eff), lhsEff) => app((PolyFunType(params, ret, eff), lhsEff), rhs, t)
case (ty: PolyType, eff) => app((instantiate(ty), eff), rhs, t)
case (funTy, lhsEff) =>
val (argTy, argEff) = rhs.flatMap:
Expand Down
20 changes: 8 additions & 12 deletions hkmc2/shared/src/main/scala/hkmc2/bbml/types.scala
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ sealed abstract class BasicType extends Type:

def mapBasic(f: Type => Type): Type = this match
case ClassLikeType(name, targs) => ClassLikeType(name, targs.map(_.mapArg(f)))
case ft @ FunType(args, ret, eff) => FunType(args.map(f), f(ret), f(eff))
case FunType(args, ret, eff) => FunType(args.map(f), f(ret), f(eff))
case ComposedType(lhs, rhs, pol) => Type.mkComposedType(f(lhs), f(rhs), pol)
case NegType(ty) => Type.mkNegType(f(ty))
case Top | Bot | _: InfVar => this
Expand All @@ -168,9 +168,9 @@ sealed abstract class BasicType extends Type:
this match
case ClassLikeType(name, targs) =>
if targs.isEmpty then s"${name.nme}" else s"${name.nme}[${targs.map(_.show).mkString(", ")}]"
case v @ InfVar(lvl, uid, _, _) =>
case v @ InfVar(lvl, uid, _, isSkolem) =>
val name = scope.lookup(v.sym).getOrElse(scope.allocateName(v.sym, v.hint))
if v.isSkolem then name else s"'${name}"
if isSkolem then name else s"'${name}"
case FunType(arg :: Nil, ret, eff) => s"${arg.paren} ->${printEff(eff)} ${ret.paren}"
case FunType(args, ret, eff) => s"(${args.map(_.show).mkString(", ")}) ->${printEff(eff)} ${ret.paren}"
case ComposedType(lhs, rhs, pol) => s"${lhs.paren} ${if pol then "" else ""} ${rhs.paren}"
Expand All @@ -181,9 +181,9 @@ sealed abstract class BasicType extends Type:
override def showDbg: Str = this match
case ClassLikeType(name, targs) =>
if targs.isEmpty then s"${name.nme}" else s"${name.nme}[${targs.map(_.showDbg).mkString(", ")}]"
case v @ InfVar(lvl, uid, _, _) =>
case v @ InfVar(lvl, uid, _, isSkolem) =>
val name = if v.hint.isEmpty then s"${v.sym.nme}" else s"${v.sym.nme}(${v.hint})"
if v.isSkolem then s"${name}${uid}_${lvl}" else s"'${name}${uid}_${lvl}"
if isSkolem then s"${name}${uid}_${lvl}" else s"'${name}${uid}_${lvl}"
case FunType(arg :: Nil, ret, eff) => s"${arg.parenDbg} ->{${eff.showDbg}} ${ret.parenDbg}"
case FunType(args, ret, eff) => s"(${args.map(_.showDbg).mkString(", ")}) ->{${eff.showDbg}} ${ret.parenDbg}"
case ComposedType(lhs, rhs, pol) => s"${lhs.parenDbg} ${if pol then "" else ""} ${rhs.parenDbg}"
Expand Down Expand Up @@ -249,11 +249,8 @@ case class ClassLikeType(name: TypeSymbol | ModuleSymbol, targs: Ls[TypeArg]) ex
case ty: Type => ty.subst
})

// * skolemFlag: S(true) -> skolem, S(false) -> normal tv, N -> outer, always skolem
final case class InfVar(vlvl: Int, uid: Uid[InfVar], state: VarState, skolemFlag: Opt[Bool])(val sym: Symbol, val hint: Str) extends BasicType:
final case class InfVar(vlvl: Int, uid: Uid[InfVar], state: VarState, isSkolem: Bool)(val sym: Symbol, val hint: Str) extends BasicType:
override def subst(using map: Map[Uid[InfVar], InfVar]): ThisType = map.get(uid).getOrElse(this)
val isSkolem = skolemFlag.getOrElse(true)
val isOuter = skolemFlag.isEmpty

given Ordering[InfVar] = Ordering.by(_.uid)

Expand Down Expand Up @@ -322,15 +319,15 @@ case class PolyType(tvs: Ls[InfVar], outer: InfVar, body: GeneralType) extends G
// * Note that by this point, the state is supposed to be frozen/treated as immutable
// * `outer` is already skolemized when it is declared
val map = tvs.map(v =>
val sk = InfVar(lvl, nextUid, new VarState(), S(true))(v.sym, v.hint)
val sk = InfVar(lvl, nextUid, new VarState(), true)(v.sym, v.hint)
tl.log(s"skolemize ${v.showDbg} ~> ${sk.showDbg}")
v.uid -> sk
).toMap
substAndGetBody(using map)

def instantiate(nextUid: => Uid[InfVar], env: InfVar, lvl: Int)(tl: TL)(using State): GeneralType =
val map = ((outer.uid -> env) :: tvs.map(v =>
val nv = InfVar(lvl, nextUid, new VarState(), S(false))(new InstSymbol(v.sym), v.hint)
val nv = InfVar(lvl, nextUid, new VarState(), false)(new InstSymbol(v.sym), v.hint)
tl.log(s"instantiate ${v.showDbg} ~> ${nv.showDbg}")
v.uid -> nv
)).toMap
Expand All @@ -341,7 +338,6 @@ object PolyType:
val tvs = MutSet[InfVar]()
object CollectTVs extends TypeTraverser:
override def apply(pol: Boolean)(ty: GeneralType): Unit = ty match
case InfVar(_, _, _, N) => () // Ignore outer variables here
case v @ InfVar(vlvl, _, state, _) if vlvl > lvl =>
if tvs.add(v) then
state.lowerBounds.foreach: bd =>
Expand Down
39 changes: 14 additions & 25 deletions hkmc2/shared/src/test/mlscript/bbml/bbDisjoint.mls
Original file line number Diff line number Diff line change
Expand Up @@ -102,15 +102,15 @@ fun helper(r1) =
region r2 in
fork((_ => r1.ref 1), (_ => r2.ref 2))
helper
//│ Type: forall 'reg, 'reg1: Region[in 'reg out 'reg1] ->{'reg1} Pair[out Ref[Int, out 'reg1], out Ref[Int, out ¬env@helper]]
//│ Type: forall env@helper, 'reg, 'reg1: Region[in 'reg out 'reg1] ->{'reg1} Pair[out Ref[Int, out 'reg1], out Ref[Int, out ¬env@helper]]
//│ Where:
//│ 'reg <: env@helper
//│ 'reg <: 'reg1


region x in
helper(x)
//│ Type: Pair[out Ref[Int, ?], out Ref[Int, ?]]
//│ Type: Pair[out Ref[Int, ?], out Ref[Int, ]]


region x in
Expand All @@ -127,24 +127,13 @@ region x in
:e
region x in
(region y in let t = helper(x) in 42): [A] -> Int
//│ ╔══[ERROR] Type error in reference with expected type 'r1
//│ ║ l.129: (region y in let t = helper(x) in 42): [A] -> Int
//│ ║ ^
//│ ╟── because: cannot constrain Region[x] <: 'r1
//│ ╟── because: cannot constrain Region[in x out x] <: 'r1
//│ ╟── because: cannot constrain Region[in x out x] <: Region[in 'reg out 'reg1]
//│ ╟── because: cannot constrain x <: 'reg1
//│ ╟── because: cannot constrain x <: 'reg1
//│ ╟── because: cannot constrain x <: 'env
//│ ╟── because: cannot constrain x <: 'env
//│ ╙── because: cannot constrain x <: env@outer ∨ y
//│ ╔══[ERROR] Type error in region expression with expected type forall 'A: Int
//│ ║ l.129: (region y in let t = helper(x) in 42): [A] -> Int
//│ ║ ^^^^^^^^^^^^^^^
//│ ╟── because: cannot constrain 'eff <: ⊥
//│ ╟── because: cannot constrain 'eff <: ¬()
//│ ╟── because: cannot constrain (¬⊥ ∧ ¬'y1) ∧ x <: ¬()
//│ ╟── because: cannot constrain x <: 'y1
//│ ╟── because: cannot constrain (¬⊥ ∧ ¬'y) ∧ x <: ¬()
//│ ╟── because: cannot constrain x <: 'y
//│ ╙── because: cannot constrain x <: ¬()
//│ Type: Int

Expand All @@ -166,10 +155,10 @@ fun badanno: outer
:e
fun badanno2: [outer A, outer B] -> Int ->{A | B} Int
//│ ╔══[ERROR] Only one outer variable can be bound.
//│ ║ l.167: fun badanno2: [outer A, outer B] -> Int ->{A | B} Int
//│ ║ l.156: fun badanno2: [outer A, outer B] -> Int ->{A | B} Int
//│ ╙── ^^^^^^^^^^^^^^^^^^
//│ ╔══[ERROR] Illegal forall annotation.
//│ ║ l.167: fun badanno2: [outer A, outer B] -> Int ->{A | B} Int
//│ ║ l.156: fun badanno2: [outer A, outer B] -> Int ->{A | B} Int
//│ ╙── ^^^^^^^^^^^^^^^^^^
//│ ═══[ERROR] Invalid type
//│ Type: ⊤
Expand Down Expand Up @@ -198,13 +187,13 @@ fun foo(r1) =
fork((_ => r1.ref 1), (_ => r2.ref 2))
foo(r2)
//│ ╔══[ERROR] Type error in function literal
//│ ║ l.196: fun foo(r1) =
//│ ║ l.185: fun foo(r1) =
//│ ║ ^^^^^
//│ ║ l.197: region r2 in
//│ ║ l.186: region r2 in
//│ ║ ^^^^^^^^^^^^^^
//│ ║ l.198: fork((_ => r1.ref 1), (_ => r2.ref 2))
//│ ║ l.187: fork((_ => r1.ref 1), (_ => r2.ref 2))
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//│ ║ l.199: foo(r2)
//│ ║ l.188: foo(r2)
//│ ║ ^^^^^^^^^^^
//│ ╟── because: cannot constrain 'r1 ->{'eff} 'app <: 'foo
//│ ╟── because: cannot constrain ('r1) ->{'eff} ('app) <: 'foo
Expand All @@ -220,13 +209,13 @@ fun foo(r1) =
//│ ╟── because: cannot constrain 'r22 <: ¬(¬foo)
//│ ╙── because: cannot constrain ¬foo <: ¬(¬foo)
//│ ╔══[ERROR] Type error in function literal
//│ ║ l.196: fun foo(r1) =
//│ ║ l.185: fun foo(r1) =
//│ ║ ^^^^^
//│ ║ l.197: region r2 in
//│ ║ l.186: region r2 in
//│ ║ ^^^^^^^^^^^^^^
//│ ║ l.198: fork((_ => r1.ref 1), (_ => r2.ref 2))
//│ ║ l.187: fork((_ => r1.ref 1), (_ => r2.ref 2))
//│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
//│ ║ l.199: foo(r2)
//│ ║ l.188: foo(r2)
//│ ║ ^^^^^^^^^^^
//│ ╟── because: cannot constrain 'r1 ->{'eff} 'app <: 'foo
//│ ╟── because: cannot constrain ('r1) ->{'eff} ('app) <: 'foo
Expand Down
2 changes: 1 addition & 1 deletion hkmc2/shared/src/test/mlscript/bbml/bbRef.mls
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ let r = region x in x.ref 42

fun mkRef() = region x in x.ref 42
mkRef
//│ Type: forall : () -> Ref[Int, out ¬env@mkRef]
//│ Type: forall env@mkRef: () -> Ref[Int, out ¬env@mkRef]

:e
let t = region x in x in t.ref 42
Expand Down
2 changes: 1 addition & 1 deletion hkmc2/shared/src/test/mlscript/bbml/bbSeq.mls
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ fun mapi = s => f =>
i := !i + 1
f(!i, x)
mapi
//│ Type: forall 'A, 'E, 'app, 'eff: Seq[out 'A, out 'E] -> (((Int, 'A) ->{'eff} 'app) -> Seq[out 'app, out (¬env@mapi ∨ 'eff) ∨ 'E])
//│ Type: forall env@mapi, 'A, 'E, 'app, 'eff: Seq[out 'A, out 'E] -> (((Int, 'A) ->{'eff} 'app) -> Seq[out 'app, out (¬env@mapi ∨ 'eff) ∨ 'E])

// * This version is correct as it keeps the mutation encapsulated within the region
fun mapi_force: [A, E] -> Seq[out A, out E] -> ((Int, A) ->{E} A) ->{E} Seq[out A, Nothing]
Expand Down

0 comments on commit 7fda9e6

Please sign in to comment.