diff --git a/hkmc2/shared/src/main/scala/hkmc2/bbml/ConstraintSolver.scala b/hkmc2/shared/src/main/scala/hkmc2/bbml/ConstraintSolver.scala index be0fa308d..186992bc9 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/bbml/ConstraintSolver.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/bbml/ConstraintSolver.scala @@ -39,7 +39,7 @@ class ConstraintSolver(infVarState: InfVarUid.State, elState: Elaborator.State, import hkmc2.bbml.NormalForm.* private def freshXVar(lvl: Int, sym: Symbol, hint: Str): InfVar = - InfVar(lvl, infVarState.nextUid, new VarState(), S(false))(InstSymbol(sym)(using elState), hint) + InfVar(lvl, infVarState.nextUid, new VarState(), false)(InstSymbol(sym)(using elState), hint) def extrude(ty: Type)(using lvl: Int, pol: Bool, cache: ExtrudeCache, bbctx: BbCtx, cctx: CCtx, tl: TL): Type = trace[Type](s"Extruding[${printPol(pol)}] ${ty.showDbg}", r => s"~> ${r.showDbg}"): @@ -50,7 +50,7 @@ class ConstraintSolver(infVarState: InfVarUid.State, elState: Elaborator.State, Wildcard(extrude(in)(using lvl, !pol), extrude(out)) case t: Type => Wildcard(extrude(t)(using lvl, !pol), extrude(t)) }) - case v @ InfVar(_, uid, state, _) if v.isSkolem => // * skolem + case v @ InfVar(_, uid, state, true) => // * skolem cache.getOrElse(uid -> pol, { val nv = freshXVar(lvl, v.sym, v.hint) cache += uid -> pol -> nv @@ -60,7 +60,7 @@ class ConstraintSolver(infVarState: InfVarUid.State, elState: Elaborator.State, constrainImpl(nv, state.lowerBounds.foldLeft[Type](Bot)(_ | _)) nv }) - case v @ InfVar(_, uid, _, _) => + case v @ InfVar(_, uid, _, false) => cache.getOrElse(uid -> pol, { val nv = freshXVar(lvl, v.sym, v.hint) cache += uid -> pol -> nv @@ -129,7 +129,7 @@ class ConstraintSolver(infVarState: InfVarUid.State, elState: Elaborator.State, constrainImpl(lhs.posPart, rhs.posPart) private def inlineSkolemBounds(ty: Type, pol: Bool)(using cache: Set[Uid[InfVar]]): Type = ty.toBasic match - case v @ InfVar(_, uid, state, _) if v.isSkolem && !cache(uid) => + case v @ InfVar(_, uid, state, skolem) if skolem && !cache(uid) => given Set[Uid[InfVar]] = cache + uid inlineSkolemBounds(if pol then state.upperBounds.foldLeft[Type](v)(_ & _) else state.lowerBounds.foldLeft[Type](v)(_ | _), pol) case ComposedType(lhs, rhs, p) => ComposedType(inlineSkolemBounds(lhs, pol), inlineSkolemBounds(rhs, pol), p) diff --git a/hkmc2/shared/src/main/scala/hkmc2/bbml/NormalForm.scala b/hkmc2/shared/src/main/scala/hkmc2/bbml/NormalForm.scala index 48c160da5..78c0588dc 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/bbml/NormalForm.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/bbml/NormalForm.scala @@ -72,7 +72,7 @@ object Conj: // * Conj objects cannot be created with `new` except in this file. // * This is because we want to sort the vars in the apply function. def apply(i: Inter, u: Union, vars: Ls[(InfVar, Bool)]) = new Conj(i, u, vars.sortWith { - case ((v1 @ InfVar(lv1, _, _, _), _), (v2 @ InfVar(lv2, _, _, _), _)) => !(v1.isSkolem || !v2.isSkolem && lv1 <= lv2) + case ((v1 @ InfVar(lv1, _, _, sk1), _), (v2 @ InfVar(lv2, _, _, sk2), _)) => !(sk1 || !sk2 && lv1 <= lv2) }){} lazy val empty: Conj = Conj(Inter.empty, Union.empty, Nil) def mkVar(v: InfVar, pol: Bool) = Conj(Inter.empty, Union.empty, (v, pol) :: Nil) diff --git a/hkmc2/shared/src/main/scala/hkmc2/bbml/TypeSimplifier.scala b/hkmc2/shared/src/main/scala/hkmc2/bbml/TypeSimplifier.scala index cdc2590a6..256f65365 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/bbml/TypeSimplifier.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/bbml/TypeSimplifier.scala @@ -119,6 +119,10 @@ class TypeSimplifier(tl: TraceLogger): super.apply(pol)(ty) // traversingTVs -= tv curPath = oldPath + case pt @ PolyType(_, outer, _) => // Avoid simplify outer variables to Top unexpectedly + posVars += outer + negVars += outer + super.apply(pol)(pt) case _ => val oldPath = curPath pastPathsSet ++= oldPath @@ -148,7 +152,6 @@ class TypeSimplifier(tl: TraceLogger): def subst(ty: GeneralType): GeneralType = trace[GeneralType](s"subst(${ty.showDbg})", r => s"= ${r.showDbg}"): ty match case ty if ty.lvl <= lvl => ty // TODO NOPE - case InfVar(_, _, _, N) => ty // Ignore outer variables case _tv: IV => val tv = Analysis.getRepr(_tv) log(s"Repr: ${tv.showDbg}") diff --git a/hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala b/hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala index a088125c4..2c9d4bbdf 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/bbml/bbML.scala @@ -80,9 +80,9 @@ class BBTyper(using elState: Elaborator.State, tl: TL): private val solver = new ConstraintSolver(infVarState, elState, tl) private def freshSkolem(sym: Symbol, hint: Str = "")(using ctx: BbCtx): InfVar = - InfVar(ctx.lvl, infVarState.nextUid, new VarState(), S(true))(sym, hint) + InfVar(ctx.lvl, infVarState.nextUid, new VarState(), true)(sym, hint) private def freshVar(sym: Symbol, hint: Str = "")(using ctx: BbCtx): InfVar = - InfVar(ctx.lvl, infVarState.nextUid, new VarState(), S(false))(sym, hint) + InfVar(ctx.lvl, infVarState.nextUid, new VarState(), false)(sym, hint) private def freshWildcard(sym: Symbol)(using ctx: BbCtx) = val in = freshVar(sym, "-") val out = freshVar(sym, "+") @@ -91,14 +91,14 @@ class BBTyper(using elState: Elaborator.State, tl: TL): private def freshReg(sym: Symbol)(using ctx: BbCtx) = val state = new VarState() state.upperBounds = ctx.getRegEnv.! :: Nil - InfVar(ctx.lvl + 1, infVarState.nextUid, state, S(true))(sym, "") + InfVar(ctx.lvl + 1, infVarState.nextUid, state, true)(sym, "") private def freshOuter(sym: Symbol)(using ctx: BbCtx): InfVar = - InfVar(ctx.lvl + 1, infVarState.nextUid, new VarState(), N)(sym, "env@") + InfVar(ctx.lvl + 1, infVarState.nextUid, new VarState(), true)(sym, "env@") private def freshEnv(sym: Symbol)(using ctx: BbCtx): InfVar = val state = new VarState() state.upperBounds = ctx.getRegEnv :: Nil state.lowerBounds = ctx.getRegEnv :: Nil - InfVar(ctx.lvl, infVarState.nextUid, state, S(false))(sym, "") + InfVar(ctx.lvl, infVarState.nextUid, state, false)(sym, "") private def error(msg: Ls[Message -> Opt[Loc]])(using BbCtx) = raise(ErrorReport(msg)) @@ -393,7 +393,7 @@ class BBTyper(using elState: Elaborator.State, tl: TL): val (ty, ef) = ascribe(f.term, t) resEff |= ef (ret, resEff) - case (ft @ FunType(params, ret, eff), lhsEff) => app((PolyFunType(params, ret, eff), lhsEff), rhs, t) + case (FunType(params, ret, eff), lhsEff) => app((PolyFunType(params, ret, eff), lhsEff), rhs, t) case (ty: PolyType, eff) => app((instantiate(ty), eff), rhs, t) case (funTy, lhsEff) => val (argTy, argEff) = rhs.flatMap: diff --git a/hkmc2/shared/src/main/scala/hkmc2/bbml/types.scala b/hkmc2/shared/src/main/scala/hkmc2/bbml/types.scala index 3feca21cf..fb36a2cfb 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/bbml/types.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/bbml/types.scala @@ -155,7 +155,7 @@ sealed abstract class BasicType extends Type: def mapBasic(f: Type => Type): Type = this match case ClassLikeType(name, targs) => ClassLikeType(name, targs.map(_.mapArg(f))) - case ft @ FunType(args, ret, eff) => FunType(args.map(f), f(ret), f(eff)) + case FunType(args, ret, eff) => FunType(args.map(f), f(ret), f(eff)) case ComposedType(lhs, rhs, pol) => Type.mkComposedType(f(lhs), f(rhs), pol) case NegType(ty) => Type.mkNegType(f(ty)) case Top | Bot | _: InfVar => this @@ -168,9 +168,9 @@ sealed abstract class BasicType extends Type: this match case ClassLikeType(name, targs) => if targs.isEmpty then s"${name.nme}" else s"${name.nme}[${targs.map(_.show).mkString(", ")}]" - case v @ InfVar(lvl, uid, _, _) => + case v @ InfVar(lvl, uid, _, isSkolem) => val name = scope.lookup(v.sym).getOrElse(scope.allocateName(v.sym, v.hint)) - if v.isSkolem then name else s"'${name}" + if isSkolem then name else s"'${name}" case FunType(arg :: Nil, ret, eff) => s"${arg.paren} ->${printEff(eff)} ${ret.paren}" case FunType(args, ret, eff) => s"(${args.map(_.show).mkString(", ")}) ->${printEff(eff)} ${ret.paren}" case ComposedType(lhs, rhs, pol) => s"${lhs.paren} ${if pol then "∨" else "∧"} ${rhs.paren}" @@ -181,9 +181,9 @@ sealed abstract class BasicType extends Type: override def showDbg: Str = this match case ClassLikeType(name, targs) => if targs.isEmpty then s"${name.nme}" else s"${name.nme}[${targs.map(_.showDbg).mkString(", ")}]" - case v @ InfVar(lvl, uid, _, _) => + case v @ InfVar(lvl, uid, _, isSkolem) => val name = if v.hint.isEmpty then s"${v.sym.nme}" else s"${v.sym.nme}(${v.hint})" - if v.isSkolem then s"${name}${uid}_${lvl}" else s"'${name}${uid}_${lvl}" + if isSkolem then s"${name}${uid}_${lvl}" else s"'${name}${uid}_${lvl}" case FunType(arg :: Nil, ret, eff) => s"${arg.parenDbg} ->{${eff.showDbg}} ${ret.parenDbg}" case FunType(args, ret, eff) => s"(${args.map(_.showDbg).mkString(", ")}) ->{${eff.showDbg}} ${ret.parenDbg}" case ComposedType(lhs, rhs, pol) => s"${lhs.parenDbg} ${if pol then "∨" else "∧"} ${rhs.parenDbg}" @@ -249,11 +249,8 @@ case class ClassLikeType(name: TypeSymbol | ModuleSymbol, targs: Ls[TypeArg]) ex case ty: Type => ty.subst }) -// * skolemFlag: S(true) -> skolem, S(false) -> normal tv, N -> outer, always skolem -final case class InfVar(vlvl: Int, uid: Uid[InfVar], state: VarState, skolemFlag: Opt[Bool])(val sym: Symbol, val hint: Str) extends BasicType: +final case class InfVar(vlvl: Int, uid: Uid[InfVar], state: VarState, isSkolem: Bool)(val sym: Symbol, val hint: Str) extends BasicType: override def subst(using map: Map[Uid[InfVar], InfVar]): ThisType = map.get(uid).getOrElse(this) - val isSkolem = skolemFlag.getOrElse(true) - val isOuter = skolemFlag.isEmpty given Ordering[InfVar] = Ordering.by(_.uid) @@ -322,7 +319,7 @@ case class PolyType(tvs: Ls[InfVar], outer: InfVar, body: GeneralType) extends G // * Note that by this point, the state is supposed to be frozen/treated as immutable // * `outer` is already skolemized when it is declared val map = tvs.map(v => - val sk = InfVar(lvl, nextUid, new VarState(), S(true))(v.sym, v.hint) + val sk = InfVar(lvl, nextUid, new VarState(), true)(v.sym, v.hint) tl.log(s"skolemize ${v.showDbg} ~> ${sk.showDbg}") v.uid -> sk ).toMap @@ -330,7 +327,7 @@ case class PolyType(tvs: Ls[InfVar], outer: InfVar, body: GeneralType) extends G def instantiate(nextUid: => Uid[InfVar], env: InfVar, lvl: Int)(tl: TL)(using State): GeneralType = val map = ((outer.uid -> env) :: tvs.map(v => - val nv = InfVar(lvl, nextUid, new VarState(), S(false))(new InstSymbol(v.sym), v.hint) + val nv = InfVar(lvl, nextUid, new VarState(), false)(new InstSymbol(v.sym), v.hint) tl.log(s"instantiate ${v.showDbg} ~> ${nv.showDbg}") v.uid -> nv )).toMap @@ -341,7 +338,6 @@ object PolyType: val tvs = MutSet[InfVar]() object CollectTVs extends TypeTraverser: override def apply(pol: Boolean)(ty: GeneralType): Unit = ty match - case InfVar(_, _, _, N) => () // Ignore outer variables here case v @ InfVar(vlvl, _, state, _) if vlvl > lvl => if tvs.add(v) then state.lowerBounds.foreach: bd => diff --git a/hkmc2/shared/src/test/mlscript/bbml/bbDisjoint.mls b/hkmc2/shared/src/test/mlscript/bbml/bbDisjoint.mls index e885c3104..c12aae2ae 100644 --- a/hkmc2/shared/src/test/mlscript/bbml/bbDisjoint.mls +++ b/hkmc2/shared/src/test/mlscript/bbml/bbDisjoint.mls @@ -102,7 +102,7 @@ fun helper(r1) = region r2 in fork((_ => r1.ref 1), (_ => r2.ref 2)) helper -//│ Type: forall 'reg, 'reg1: Region[in 'reg out 'reg1] ->{'reg1} Pair[out Ref[Int, out 'reg1], out Ref[Int, out ¬env@helper]] +//│ Type: forall env@helper, 'reg, 'reg1: Region[in 'reg out 'reg1] ->{'reg1} Pair[out Ref[Int, out 'reg1], out Ref[Int, out ¬env@helper]] //│ Where: //│ 'reg <: env@helper //│ 'reg <: 'reg1 @@ -110,7 +110,7 @@ helper region x in helper(x) -//│ Type: Pair[out Ref[Int, ?], out Ref[Int, ?]] +//│ Type: Pair[out Ref[Int, ?], out Ref[Int, ⊥]] region x in @@ -127,24 +127,13 @@ region x in :e region x in (region y in let t = helper(x) in 42): [A] -> Int -//│ ╔══[ERROR] Type error in reference with expected type 'r1 -//│ ║ l.129: (region y in let t = helper(x) in 42): [A] -> Int -//│ ║ ^ -//│ ╟── because: cannot constrain Region[x] <: 'r1 -//│ ╟── because: cannot constrain Region[in x out x] <: 'r1 -//│ ╟── because: cannot constrain Region[in x out x] <: Region[in 'reg out 'reg1] -//│ ╟── because: cannot constrain x <: 'reg1 -//│ ╟── because: cannot constrain x <: 'reg1 -//│ ╟── because: cannot constrain x <: 'env -//│ ╟── because: cannot constrain x <: 'env -//│ ╙── because: cannot constrain x <: env@outer ∨ y //│ ╔══[ERROR] Type error in region expression with expected type forall 'A: Int //│ ║ l.129: (region y in let t = helper(x) in 42): [A] -> Int //│ ║ ^^^^^^^^^^^^^^^ //│ ╟── because: cannot constrain 'eff <: ⊥ //│ ╟── because: cannot constrain 'eff <: ¬() -//│ ╟── because: cannot constrain (¬⊥ ∧ ¬'y1) ∧ x <: ¬() -//│ ╟── because: cannot constrain x <: 'y1 +//│ ╟── because: cannot constrain (¬⊥ ∧ ¬'y) ∧ x <: ¬() +//│ ╟── because: cannot constrain x <: 'y //│ ╙── because: cannot constrain x <: ¬() //│ Type: Int @@ -166,10 +155,10 @@ fun badanno: outer :e fun badanno2: [outer A, outer B] -> Int ->{A | B} Int //│ ╔══[ERROR] Only one outer variable can be bound. -//│ ║ l.167: fun badanno2: [outer A, outer B] -> Int ->{A | B} Int +//│ ║ l.156: fun badanno2: [outer A, outer B] -> Int ->{A | B} Int //│ ╙── ^^^^^^^^^^^^^^^^^^ //│ ╔══[ERROR] Illegal forall annotation. -//│ ║ l.167: fun badanno2: [outer A, outer B] -> Int ->{A | B} Int +//│ ║ l.156: fun badanno2: [outer A, outer B] -> Int ->{A | B} Int //│ ╙── ^^^^^^^^^^^^^^^^^^ //│ ═══[ERROR] Invalid type //│ Type: ⊤ @@ -198,13 +187,13 @@ fun foo(r1) = fork((_ => r1.ref 1), (_ => r2.ref 2)) foo(r2) //│ ╔══[ERROR] Type error in function literal -//│ ║ l.196: fun foo(r1) = +//│ ║ l.185: fun foo(r1) = //│ ║ ^^^^^ -//│ ║ l.197: region r2 in +//│ ║ l.186: region r2 in //│ ║ ^^^^^^^^^^^^^^ -//│ ║ l.198: fork((_ => r1.ref 1), (_ => r2.ref 2)) +//│ ║ l.187: fork((_ => r1.ref 1), (_ => r2.ref 2)) //│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -//│ ║ l.199: foo(r2) +//│ ║ l.188: foo(r2) //│ ║ ^^^^^^^^^^^ //│ ╟── because: cannot constrain 'r1 ->{'eff} 'app <: 'foo //│ ╟── because: cannot constrain ('r1) ->{'eff} ('app) <: 'foo @@ -220,13 +209,13 @@ fun foo(r1) = //│ ╟── because: cannot constrain 'r22 <: ¬(¬foo) //│ ╙── because: cannot constrain ¬foo <: ¬(¬foo) //│ ╔══[ERROR] Type error in function literal -//│ ║ l.196: fun foo(r1) = +//│ ║ l.185: fun foo(r1) = //│ ║ ^^^^^ -//│ ║ l.197: region r2 in +//│ ║ l.186: region r2 in //│ ║ ^^^^^^^^^^^^^^ -//│ ║ l.198: fork((_ => r1.ref 1), (_ => r2.ref 2)) +//│ ║ l.187: fork((_ => r1.ref 1), (_ => r2.ref 2)) //│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -//│ ║ l.199: foo(r2) +//│ ║ l.188: foo(r2) //│ ║ ^^^^^^^^^^^ //│ ╟── because: cannot constrain 'r1 ->{'eff} 'app <: 'foo //│ ╟── because: cannot constrain ('r1) ->{'eff} ('app) <: 'foo diff --git a/hkmc2/shared/src/test/mlscript/bbml/bbRef.mls b/hkmc2/shared/src/test/mlscript/bbml/bbRef.mls index 24d9b490d..98da20a4b 100644 --- a/hkmc2/shared/src/test/mlscript/bbml/bbRef.mls +++ b/hkmc2/shared/src/test/mlscript/bbml/bbRef.mls @@ -36,7 +36,7 @@ let r = region x in x.ref 42 fun mkRef() = region x in x.ref 42 mkRef -//│ Type: forall : () -> Ref[Int, out ¬env@mkRef] +//│ Type: forall env@mkRef: () -> Ref[Int, out ¬env@mkRef] :e let t = region x in x in t.ref 42 diff --git a/hkmc2/shared/src/test/mlscript/bbml/bbSeq.mls b/hkmc2/shared/src/test/mlscript/bbml/bbSeq.mls index 633f3bada..1a063b077 100644 --- a/hkmc2/shared/src/test/mlscript/bbml/bbSeq.mls +++ b/hkmc2/shared/src/test/mlscript/bbml/bbSeq.mls @@ -96,7 +96,7 @@ fun mapi = s => f => i := !i + 1 f(!i, x) mapi -//│ Type: forall 'A, 'E, 'app, 'eff: Seq[out 'A, out 'E] -> (((Int, 'A) ->{'eff} 'app) -> Seq[out 'app, out (¬env@mapi ∨ 'eff) ∨ 'E]) +//│ Type: forall env@mapi, 'A, 'E, 'app, 'eff: Seq[out 'A, out 'E] -> (((Int, 'A) ->{'eff} 'app) -> Seq[out 'app, out (¬env@mapi ∨ 'eff) ∨ 'E]) // * This version is correct as it keeps the mutation encapsulated within the region fun mapi_force: [A, E] -> Seq[out A, out E] -> ((Int, A) ->{E} A) ->{E} Seq[out A, Nothing]