From 5273b86fb044f39d2e90793cedeb80757a48ca23 Mon Sep 17 00:00:00 2001 From: Anson Yeung Date: Wed, 5 Nov 2025 19:23:26 +0800 Subject: [PATCH 1/5] Add test case --- .../src/test/mlscript/codegen/While.mls | 42 +++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/hkmc2/shared/src/test/mlscript/codegen/While.mls b/hkmc2/shared/src/test/mlscript/codegen/While.mls index 8ed5867acb..12faf9ee1b 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/While.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/While.mls @@ -309,3 +309,45 @@ while //│ ╙── ^^^^ +// Mix of local mutable variable, while loop and lambdas +:fixme +:expect 2 +fun f() = + let + cnt = 2 + savedLambda = null + while cnt > 0 do + let local = cnt + if cnt == 2 do + // This lambda should always return 2, + // since local is 2 and not reassigned in the while loop + savedLambda = () => local + set cnt -= 1 + // Calls the saved lambda + savedLambda() +f() +//│ ═══[RUNTIME ERROR] Expected: '2', got: '1' +//│ = 1 + +class Lazy[out A](f: () -> A) with + mut val cached: A | () = () + fun get = + if cached === () do + set cached = f() + cached + +:fixme +:expect [1, 2, 3] +let arr = [1, 2, 3] +let output = mut [] +let i = 0 +while i < arr.length do + let elem = arr.[i] + output.push(new Lazy(() => elem)) + set i += 1 +[output.[0].get, output.[1].get, output.[2].get] +//│ ═══[RUNTIME ERROR] Expected: '[1, 2, 3]', got: '[3, 3, 3]' +//│ = [3, 3, 3] +//│ arr = [1, 2, 3] +//│ i = 3 +//│ output = [Lazy(_), Lazy(_), Lazy(_)] From 8b414e92fd5bd1ab6d4f2c3279a2eeff748d2ba7 Mon Sep 17 00:00:00 2001 From: Anson Yeung Date: Mon, 10 Nov 2025 14:03:45 +0800 Subject: [PATCH 2/5] Lower while loop into tail recursive functions --- .../main/scala/hkmc2/codegen/Lowering.scala | 14 +- .../src/test/mlscript/codegen/While.mls | 128 ++++++++++-------- 2 files changed, 81 insertions(+), 61 deletions(-) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index 1cce08105d..b707b7fa64 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -607,8 +607,8 @@ class Lowering()(using Config, TL, Raise, State, Ctx): usesResTmp = true new TempSymbol(S(t)) - lazy val lbl = - new TempSymbol(S(t)) + lazy val f = + new BlockMemberSymbol("while", Nil, false) def go(split: Split, topLevel: Bool)(using Subst): Block = split match case Split.Let(sym, trm, tl) => @@ -671,7 +671,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): else term_nonTail(els): r => Assign(l, r, - if isWhile && !topLevel then Continue(lbl) + if isWhile && !topLevel then Return(Call(Value.Ref(f), Nil)(true, true), false) else End() ) case Split.End => @@ -687,7 +687,13 @@ class Lowering()(using Config, TL, Raise, State, Ctx): if k.isInstanceOf[TailOp] && isIf then go(normalized, topLevel = true) else val body = if isWhile - then Label(lbl, go(normalized, topLevel = true), End()) + then blockBuilder + .assign(l, unit) + .define(FunDefn(N, f, PlainParamList(Nil) :: Nil, + Begin(go(normalized, topLevel = false), Return(unit, false)) + )) + .assign(new TempSymbol(N), Call(Value.Ref(f), Nil)(true, true)) + .end else go(normalized, topLevel = true) Begin( body, diff --git a/hkmc2/shared/src/test/mlscript/codegen/While.mls b/hkmc2/shared/src/test/mlscript/codegen/While.mls index 12faf9ee1b..50a40f244f 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/While.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/While.mls @@ -6,17 +6,20 @@ //│ JS (unsanitized): //│ let lambda; //│ lambda = (undefined, function () { -//│ let scrut, tmp; -//│ tmp1: while (true) { +//│ let tmp, while1, tmp1; +//│ tmp = runtime.Unit; +//│ while1 = (undefined, function () { +//│ let scrut; //│ scrut = true; //│ if (scrut === true) { //│ tmp = 0; -//│ continue tmp1 +//│ return while1() //│ } else { //│ throw globalThis.Object.freeze(new globalThis.Error("match error")) //│ } -//│ break; -//│ } +//│ return runtime.Unit +//│ }); +//│ tmp1 = while1(); //│ return tmp //│ }); //│ lambda @@ -53,17 +56,20 @@ while x set x = false else 42 //│ JS (unsanitized): -//│ let tmp4, tmp5; -//│ tmp6: while (true) { +//│ let tmp4, while3, tmp5; +//│ tmp4 = runtime.Unit; +//│ while3 = (undefined, function () { +//│ let tmp6; //│ if (x2 === true) { -//│ tmp4 = Predef.print("Hello World"); +//│ tmp6 = Predef.print("Hello World"); //│ x2 = false; -//│ tmp5 = runtime.Unit; -//│ continue tmp6 -//│ } else { tmp5 = 42; } -//│ break; -//│ } -//│ tmp5 +//│ tmp4 = runtime.Unit; +//│ return while3() +//│ } else { tmp4 = 42; } +//│ return runtime.Unit +//│ }); +//│ tmp5 = while3(); +//│ tmp4 //│ > Hello World //│ = 42 @@ -105,19 +111,24 @@ while //│ JS (unsanitized): //│ let lambda2; //│ lambda2 = (undefined, function () { -//│ let i2, scrut3, tmp19, tmp20; -//│ tmp21: while (true) { +//│ let tmp12, while7, tmp13; +//│ tmp12 = runtime.Unit; +//│ while7 = (undefined, function () { +//│ let i2, scrut, tmp14; //│ i2 = 0; -//│ scrut3 = i2 < 10; -//│ if (scrut3 === true) { -//│ tmp19 = i2 + 1; -//│ i2 = tmp19; -//│ tmp20 = runtime.Unit; -//│ continue tmp21 -//│ } else { tmp20 = runtime.Unit; } -//│ break; -//│ } -//│ return tmp20 +//│ scrut = i2 < 10; +//│ if (scrut === true) { +//│ tmp14 = i2 + 1; +//│ i2 = tmp14; +//│ tmp12 = runtime.Unit; +//│ return while7() +//│ } else { +//│ tmp12 = runtime.Unit; +//│ } +//│ return runtime.Unit +//│ }); +//│ tmp13 = while7(); +//│ return tmp12 //│ }); //│ lambda2 //│ = fun @@ -197,20 +208,25 @@ fun f(ls) = //│ JS (unsanitized): //│ let f; //│ f = function f(ls) { -//│ let param0, param1, h, tl, tmp28; -//│ tmp29: while (true) { +//│ let tmp18, while10, tmp19; +//│ tmp18 = runtime.Unit; +//│ while10 = (undefined, function () { +//│ let param0, param1, h, tl; //│ if (ls instanceof Cons1.class) { //│ param0 = ls.hd; //│ param1 = ls.tl; //│ h = param0; //│ tl = param1; //│ ls = tl; -//│ tmp28 = Predef.print(h); -//│ continue tmp29 -//│ } else { tmp28 = Predef.print("Done!"); } -//│ break; -//│ } -//│ return tmp28 +//│ tmp18 = Predef.print(h); +//│ return while10() +//│ } else { +//│ tmp18 = Predef.print("Done!"); +//│ } +//│ return runtime.Unit +//│ }); +//│ tmp19 = while10(); +//│ return tmp18 //│ }; f(0) @@ -248,15 +264,17 @@ let x = 1 :sjs while x is {} do() //│ JS (unsanitized): -//│ let tmp37; -//│ tmp38: while (true) { +//│ let tmp27, while11, tmp28; +//│ tmp27 = runtime.Unit; +//│ while11 = (undefined, function () { //│ if (x3 instanceof Object) { -//│ tmp37 = runtime.Unit; -//│ continue tmp38 -//│ } else { tmp37 = runtime.Unit; } -//│ break; -//│ } -//│ tmp37 +//│ tmp27 = runtime.Unit; +//│ return while11() +//│ } else { tmp27 = runtime.Unit; } +//│ return runtime.Unit +//│ }); +//│ tmp28 = while11(); +//│ tmp27 // ——— FIXME: ——— @@ -270,10 +288,10 @@ while print("Hello World"); false then 0(0) else 1 //│ ╔══[PARSE ERROR] Unexpected 'then' keyword here -//│ ║ l.270: then 0(0) +//│ ║ l.288: then 0(0) //│ ╙── ^^^^ //│ ╔══[ERROR] Unrecognized term split (false literal). -//│ ║ l.269: while print("Hello World"); false +//│ ║ l.287: while print("Hello World"); false //│ ╙── ^^^^^ //│ > Hello World //│ ═══[RUNTIME ERROR] Error: match error @@ -283,12 +301,12 @@ while { print("Hello World"), false } then 0(0) else 1 //│ ╔══[ERROR] Unexpected infix use of keyword 'then' here -//│ ║ l.282: while { print("Hello World"), false } +//│ ║ l.300: while { print("Hello World"), false } //│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -//│ ║ l.283: then 0(0) +//│ ║ l.301: then 0(0) //│ ╙── ^^^^^^^^^^^ //│ ╔══[ERROR] Illegal position for prefix keyword 'else'. -//│ ║ l.284: else 1 +//│ ║ l.302: else 1 //│ ╙── ^^^^ :fixme @@ -298,19 +316,18 @@ while then 0(0) else 1 //│ ╔══[ERROR] Unexpected infix use of keyword 'then' here -//│ ║ l.296: print("Hello World") +//│ ║ l.314: print("Hello World") //│ ║ ^^^^^^^^^^^^^^^^^^^^ -//│ ║ l.297: false +//│ ║ l.315: false //│ ║ ^^^^^^^^^ -//│ ║ l.298: then 0(0) +//│ ║ l.316: then 0(0) //│ ╙── ^^^^^^^^^^^ //│ ╔══[ERROR] Illegal position for prefix keyword 'else'. -//│ ║ l.299: else 1 +//│ ║ l.317: else 1 //│ ╙── ^^^^ // Mix of local mutable variable, while loop and lambdas -:fixme :expect 2 fun f() = let @@ -326,8 +343,7 @@ fun f() = // Calls the saved lambda savedLambda() f() -//│ ═══[RUNTIME ERROR] Expected: '2', got: '1' -//│ = 1 +//│ = 2 class Lazy[out A](f: () -> A) with mut val cached: A | () = () @@ -336,7 +352,6 @@ class Lazy[out A](f: () -> A) with set cached = f() cached -:fixme :expect [1, 2, 3] let arr = [1, 2, 3] let output = mut [] @@ -346,8 +361,7 @@ while i < arr.length do output.push(new Lazy(() => elem)) set i += 1 [output.[0].get, output.[1].get, output.[2].get] -//│ ═══[RUNTIME ERROR] Expected: '[1, 2, 3]', got: '[3, 3, 3]' -//│ = [3, 3, 3] +//│ = [1, 2, 3] //│ arr = [1, 2, 3] //│ i = 3 //│ output = [Lazy(_), Lazy(_), Lazy(_)] From cb5aef3f672042bce79be7e8a7d05237e1de662a Mon Sep 17 00:00:00 2001 From: Anson Yeung Date: Mon, 10 Nov 2025 15:42:47 +0800 Subject: [PATCH 3/5] Fix return statement inside while loop --- .../test/scala/hkmc2/CompileTestRunner.scala | 3 +- .../shared/src/main/scala/hkmc2/Config.scala | 4 +- .../main/scala/hkmc2/codegen/Lowering.scala | 134 +++++++++++------- .../scala/hkmc2/codegen/js/JSBuilder.scala | 2 - .../scala/hkmc2/semantics/Elaborator.scala | 1 + .../src/test/mlscript-compile/Runtime.mjs | 12 ++ .../src/test/mlscript-compile/Runtime.mls | 2 +- .../src/test/mlscript/codegen/While.mls | 67 +++++---- .../mlscript/handlers/UserThreadsSafe.mls | 10 +- .../mlscript/handlers/UserThreadsUnsafe.mls | 10 +- .../shared/src/test/mlscript/lifter/Loops.mls | 60 +++++--- .../ucs/general/LogicalConnectives.mls | 10 +- .../src/test/scala/hkmc2/MLsDiffMaker.scala | 1 + 13 files changed, 195 insertions(+), 121 deletions(-) diff --git a/hkmc2/jvm/src/test/scala/hkmc2/CompileTestRunner.scala b/hkmc2/jvm/src/test/scala/hkmc2/CompileTestRunner.scala index 6d04813ef2..da27841215 100644 --- a/hkmc2/jvm/src/test/scala/hkmc2/CompileTestRunner.scala +++ b/hkmc2/jvm/src/test/scala/hkmc2/CompileTestRunner.scala @@ -48,7 +48,8 @@ class CompileTestRunner val preludePath = mainTestDir/"mlscript"/"decls"/"Prelude.mls" - given Config = Config.default + // while loop is currently not rewritten so that stack safety works correctly as runtime relies on them. + given Config = Config.default.copy(rewriteWhileLoops = false) val compiler = MLsCompiler( preludePath, diff --git a/hkmc2/shared/src/main/scala/hkmc2/Config.scala b/hkmc2/shared/src/main/scala/hkmc2/Config.scala index 44dcbe6303..bec055fb8f 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/Config.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/Config.scala @@ -21,6 +21,7 @@ case class Config( effectHandlers: Opt[EffectHandlers], liftDefns: Opt[LiftDefns], target: CompilationTarget, + rewriteWhileLoops: Bool, ): def stackSafety: Opt[StackSafety] = effectHandlers.flatMap(_.stackSafety) @@ -35,7 +36,8 @@ object Config: // sanityChecks = S(SanityChecks(light = true)), effectHandlers = N, liftDefns = N, - target = CompilationTarget.JS + target = CompilationTarget.JS, + rewriteWhileLoops = true, ) case class SanityChecks(light: Bool) diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index b707b7fa64..9854127fae 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -34,7 +34,7 @@ object Thrw extends TailOp: // * No longer in meaningful use and could be removed if we don't find a use for it: -class Subst(initMap: Map[Local, Value]): +class LoweringCtx(initMap: Map[Local, Value], val mayRet: Bool): val map = initMap /* def +(kv: (Local, Value)): Subst = @@ -47,12 +47,13 @@ class Subst(initMap: Map[Local, Value]): def apply(v: Value): Value = v match case Value.Ref(l) => map.getOrElse(l, v) case _ => v -object Subst: - val empty = Subst(Map.empty) - def subst(using sub: Subst): Subst = sub -end Subst +object LoweringCtx: + val empty = LoweringCtx(Map.empty, false) + val func = LoweringCtx(Map.empty, true) + def subst(using sub: LoweringCtx): LoweringCtx = sub +end LoweringCtx -import Subst.subst +import LoweringCtx.subst class Lowering()(using Config, TL, Raise, State, Ctx): @@ -73,6 +74,9 @@ class Lowering()(using Config, TL, Raise, State, Ctx): def unit: Path = Select(Value.Ref(State.runtimeSymbol), Tree.Ident("Unit"))(S(State.unitSymbol)) + def loopEnd: Path = + Select(Value.Ref(State.runtimeSymbol), Tree.Ident("LoopEnd"))(S(State.loopEndSymbol)) + def fail(err: ErrorReport): Block = raise(err) @@ -82,9 +86,9 @@ class Lowering()(using Config, TL, Raise, State, Ctx): // type Rcd = (mut: Bool, args: List[RcdArg]) // * Better, but Scala's patmat exhaustiveness chokes on it type Rcd = (Bool, List[RcdArg]) - def returnedTerm(t: st)(using Subst): Block = term(t)(Ret) + def returnedTerm(t: st)(using LoweringCtx): Block = term(t)(Ret)(using LoweringCtx.func) - def parentConstructor(cls: Term, args: Ls[Term])(using Subst) = + def parentConstructor(cls: Term, args: Ls[Term])(using LoweringCtx) = if args.length > 1 then raise: ErrorReport( @@ -99,7 +103,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): )(c => Return(c, implct = true)) // * Used to work around Scala's @tailrec annotation for those few calls that are not in tail position. - final def term_nonTail(t: st, inStmtPos: Bool = false)(k: Result => Block)(using Subst): Block = + final def term_nonTail(t: st, inStmtPos: Bool = false)(k: Result => Block)(using LoweringCtx): Block = term(t: st, inStmtPos: Bool)(k) @@ -118,12 +122,12 @@ class Lowering()(using Config, TL, Raise, State, Ctx): (imps.reverse, funs.reverse, rest.reverse) - def block(stats: Ls[Statement], res: Rcd \/ Term)(k: Result => Block)(using Subst): Block = + def block(stats: Ls[Statement], res: Rcd \/ Term)(k: Result => Block)(using LoweringCtx): Block = // TODO we should also isolate and reorder classes by inheritance topological sort val (imps, funs, rest) = splitBlock(stats, Nil, Nil, Nil) blockImpl(imps ::: funs ::: rest, res)(k) - def blockImpl(stats: Ls[Statement], res: Rcd \/ Term)(k: Result => Block)(using Subst): Block = + def blockImpl(stats: Ls[Statement], res: Rcd \/ Term)(k: Result => Block)(using LoweringCtx): Block = stats match case (t: sem.Term) :: stats => subTerm(t, inStmtPos = true): r => @@ -176,7 +180,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): // Assign(td.sym, r, // term(st.Blk(stats, res))(k))) Define(ValDefn(td.tsym, td.sym, r), - blockImpl(stats, res)(k))) + blockImpl(stats, res)(k)))(using LoweringCtx.func) case syntax.Fun => val (paramLists, bodyBlock) = setupFunctionOrByNameDef(td.params, bod, S(td.sym.nme)) Define(FunDefn(td.owner, td.sym, paramLists, bodyBlock), @@ -287,17 +291,17 @@ class Lowering()(using Config, TL, Raise, State, Ctx): blockImpl(stats, res)(k) - def lowerCall(fr: Path, isMlsFun: Bool, arg: Opt[Term], loc: Opt[Loc])(k: Result => Block)(using Subst): Block = + def lowerCall(fr: Path, isMlsFun: Bool, arg: Opt[Term], loc: Opt[Loc])(k: Result => Block)(using LoweringCtx): Block = arg match case S(a) => lowerCall(fr, isMlsFun, a, loc)(k) case N => // * No arguments means a nullary call, e.g., `f()` k(Call(fr, Nil)(isMlsFun, true).withLoc(loc)) - def lowerCall(fr: Path, isMlsFun: Bool, arg: Term, loc: Opt[Loc])(k: Result => Block)(using Subst): Block = + def lowerCall(fr: Path, isMlsFun: Bool, arg: Term, loc: Opt[Loc])(k: Result => Block)(using LoweringCtx): Block = lowerArg(arg)(as => k(Call(fr, as)(isMlsFun, true).withLoc(loc))) - def lowerArg(arg: Term)(k: Ls[Arg] => Block)(using Subst): Block = + def lowerArg(arg: Term)(k: Ls[Arg] => Block)(using LoweringCtx): Block = arg match case Tup(fs) => if fs.exists(e => e match @@ -314,7 +318,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): k(Arg(spread = S(true), ar) :: Nil) @tailrec - final def term(t: st, inStmtPos: Bool = false)(k: Result => Block)(using Subst): Block = + final def term(t: st, inStmtPos: Bool = false)(k: Result => Block)(using LoweringCtx): Block = tl.log(s"Lowering.term ${t.showDbg.truncate(100, "[...]")}${ if inStmtPos then " (in stmt)" else ""}${ t.resolvedSym.fold("")(" – symbol " + _)}") @@ -607,16 +611,18 @@ class Lowering()(using Config, TL, Raise, State, Ctx): usesResTmp = true new TempSymbol(S(t)) + lazy val lbl = + new TempSymbol(S(t)) lazy val f = new BlockMemberSymbol("while", Nil, false) - def go(split: Split, topLevel: Bool)(using Subst): Block = split match + def go(split: Split, topLevel: Bool)(using LoweringCtx): Block = split match case Split.Let(sym, trm, tl) => term_nonTail(trm): r => Assign(sym, r, go(tl, topLevel)) case Split.Cons(Branch(scrut, pat, tail), restSplit) => subTerm_nonTail(scrut): sr => - tl.log(s"Binding scrut $scrut to $sr (${summon[Subst].map})") + tl.log(s"Binding scrut $scrut to $sr (${summon[LoweringCtx].map})") // val cse = def mkMatch(cse: Case -> Block) = Match(sr, cse :: Nil, S(go(restSplit, topLevel = true)), @@ -631,7 +637,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): // Normalization should reject cases where the user provides // more sub-patterns than there are actual class parameters. assert(argsOpt.isEmpty || args.length <= clsParams.length, (argsOpt, clsParams)) - def mkArgs(args: Ls[TermSymbol -> BlockLocalSymbol])(using Subst): Case -> Block = args match + def mkArgs(args: Ls[TermSymbol -> BlockLocalSymbol])(using LoweringCtx): Case -> Block = args match case Nil => Case.Cls(ctorSym, st) -> go(tail, topLevel = false) case (param, arg) :: args => @@ -671,7 +677,10 @@ class Lowering()(using Config, TL, Raise, State, Ctx): else term_nonTail(els): r => Assign(l, r, - if isWhile && !topLevel then Return(Call(Value.Ref(f), Nil)(true, true), false) + if isWhile && !topLevel then + if config.rewriteWhileLoops then + Return(Call(Value.Ref(f), Nil)(true, true), false) + else Continue(lbl) else End() ) case Split.End => @@ -686,20 +695,39 @@ class Lowering()(using Config, TL, Raise, State, Ctx): if k.isInstanceOf[TailOp] && isIf then go(normalized, topLevel = true) else - val body = if isWhile - then blockBuilder - .assign(l, unit) + + val body = go(normalized, topLevel = true) + val rst = if usesResTmp then + k(Value.Ref(l)) + else + k(unit) + + if isWhile && config.rewriteWhileLoops + then + val loopResult = TempSymbol(N) + val isReturned = TempSymbol(N) + val blk = blockBuilder + .assign(l, Value.Lit(Tree.UnitLit(false))) .define(FunDefn(N, f, PlainParamList(Nil) :: Nil, - Begin(go(normalized, topLevel = false), Return(unit, false)) + Begin(go(normalized, topLevel = false), Return(loopEnd, false)) )) - .assign(new TempSymbol(N), Call(Value.Ref(f), Nil)(true, true)) - .end - else go(normalized, topLevel = true) - Begin( - body, - if usesResTmp then k(Value.Ref(l)) - else k(unit) // * it seems this currently never happens - ) + .assign(loopResult, Call(Value.Ref(f), Nil)(true, true)) + if summon[LoweringCtx].mayRet then + blk + .assign(isReturned, Call(Value.Ref(State.builtinOpsMap("!==")), + loopResult.asPath.asArg :: loopEnd.asArg :: Nil)(true, false)) + .ifthen(Value.Ref(isReturned), Case.Lit(Tree.BoolLit(true)), + Return(Value.Ref(loopResult), false), + S(rst) + ) + .end + else + blk.rest(rst) + else if isWhile + then + Label(lbl, body, rst) + else + Begin(body, rst) case sel @ Sel(prefix, nme) => setupSelection(prefix, nme, sel.sym)(k) @@ -806,17 +834,17 @@ class Lowering()(using Config, TL, Raise, State, Ctx): // case _ => // subTerm(t)(k) - def setupTerm(name: Str, args: Ls[Path])(k: Result => Block)(using Subst): Block = + def setupTerm(name: Str, args: Ls[Path])(k: Result => Block)(using LoweringCtx): Block = k(Instantiate(mut = false, Value.Ref(State.termSymbol).selSN(name), args.map(_.asArg))) def setupQuotedKeyword(kw: Str): Path = Value.Ref(State.termSymbol).selSN("Keyword").selSN(kw) - def setupSymbol(symbol: Local)(k: Result => Block)(using Subst): Block = + def setupSymbol(symbol: Local)(k: Result => Block)(using LoweringCtx): Block = k(Instantiate(mut = false, Value.Ref(State.termSymbol).selSN("Symbol"), Value.Lit(Tree.StrLit(symbol.nme)).asArg :: Nil)) - def quotePattern(p: FlatPattern)(k: Result => Block)(using Subst): Block = p match + def quotePattern(p: FlatPattern)(k: Result => Block)(using LoweringCtx): Block = p match case FlatPattern.Lit(lit) => setupTerm("LitPattern", Value.Lit(lit) :: Nil)(k) case _ => // TODO fail: @@ -826,7 +854,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): source = Diagnostic.Source.Compilation ) - def quoteSplit(split: Split)(k: Result => Block)(using Subst): Block = split match + def quoteSplit(split: Split)(k: Result => Block)(using LoweringCtx): Block = split match case Split.Cons(Branch(scrutinee, pattern, continuation), tail) => quote(scrutinee): r1 => val l1, l2, l3, l4, l5 = new TempSymbol(N) blockBuilder.assign(l1, r1) @@ -851,7 +879,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): val state = summon[State] Value.Ref(state.importSymbol).selSN("meta").selSN("url") - def quote(t: st)(k: Result => Block)(using Subst): Block = t match + def quote(t: st)(k: Result => Block)(using LoweringCtx): Block = t match case Lit(lit) => setupTerm("Lit", Value.Lit(lit) :: Nil)(k) case Ref(sym) if Elaborator.binaryOps.contains(sym.nme) => // builtin symbols @@ -882,7 +910,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): source = Diagnostic.Source.Compilation ) case Lam(params, body) => - def rec(ps: Ls[LocalSymbol & NamedSymbol], ds: Ls[Path])(k: Result => Block)(using Subst): Block = ps match + def rec(ps: Ls[LocalSymbol & NamedSymbol], ds: Ls[Path])(k: Result => Block)(using LoweringCtx): Block = ps match case Nil => quote(body): r => val l = new TempSymbol(N) val arr = new TempSymbol(N, "arr") @@ -944,7 +972,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): source = Diagnostic.Source.Compilation ) - def gatherMembers(clsBody: ObjBody)(using Subst) + def gatherMembers(clsBody: ObjBody)(using LoweringCtx) : (Ls[FunDefn], Ls[BlockMemberSymbol -> TermSymbol], Ls[TermSymbol], Block) = val mtds = clsBody.methods .flatMap: td => @@ -967,7 +995,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): /** Compile the pattern definition into `unapply` and `unapplyStringPrefix` * methods using the `NaiveCompiler`, which transliterate the pattern into * UCS splits that backtrack without any optimizations. */ - def compilePatternMethods(defn: PatternDef)(using Subst): + def compilePatternMethods(defn: PatternDef)(using LoweringCtx): // The return type is intended to be consistent with `gatherMembers` (Ls[FunDefn], Ls[BlockMemberSymbol -> TermSymbol], Ls[TermSymbol], Block) = val compiler = new ups.NaiveCompiler @@ -979,7 +1007,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): FunDefn(td.owner, td.sym, paramLists, bodyBlock) (mtds, Nil, Nil, End()) - def args(elems: Ls[Elem])(k: Ls[Arg] => Block)(using Subst): Block = + def args(elems: Ls[Elem])(k: Ls[Arg] => Block)(using LoweringCtx): Block = val as = elems.map: case sem.Fld(sem.FldFlags.benign(), value, N) => R(N -> value) case sem.Fld(sem.FldFlags.benign(), idx, S(rhs)) => L(idx -> rhs) @@ -1021,10 +1049,10 @@ class Lowering()(using Config, TL, Raise, State, Ctx): k((Arg(N, Value.Ref(rcdSym)) :: asr).reverse))) - inline def plainArgs(ts: Ls[st])(k: Ls[Arg] => Block)(using Subst): Block = + inline def plainArgs(ts: Ls[st])(k: Ls[Arg] => Block)(using LoweringCtx): Block = subTerms(ts)(asr => k(asr.map(Arg(N, _)))) - inline def subTerms(ts: Ls[st])(k: Ls[Path] => Block)(using Subst): Block = + inline def subTerms(ts: Ls[st])(k: Ls[Path] => Block)(using LoweringCtx): Block = // @tailrec // TODO def rec(as: Ls[st], asr: Ls[Path]): Block = as match case Nil => k(asr.reverse) @@ -1033,10 +1061,10 @@ class Lowering()(using Config, TL, Raise, State, Ctx): rec(as, ar :: asr) rec(ts, Nil) - def subTerm_nonTail(t: st, inStmtPos: Bool = false)(k: Path => Block)(using Subst): Block = + def subTerm_nonTail(t: st, inStmtPos: Bool = false)(k: Path => Block)(using LoweringCtx): Block = subTerm(t: st, inStmtPos: Bool)(k) - inline def subTerm(t: st, inStmtPos: Bool = false)(k: Path => Block)(using Subst): Block = + inline def subTerm(t: st, inStmtPos: Bool = false)(k: Path => Block)(using LoweringCtx): Block = term(t, inStmtPos = inStmtPos): case v: Value => k(v) case p: Path => k(p) @@ -1053,7 +1081,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): val (imps, funs, rest) = splitBlock(main.stats, Nil, Nil, Nil) - val blk = block(funs ::: rest, R(main.res))(ImplctRet)(using Subst.empty) + val blk = block(funs ::: rest, R(main.res))(ImplctRet)(using LoweringCtx.empty) val desug = LambdaRewriter.desugar(blk) @@ -1082,20 +1110,20 @@ class Lowering()(using Config, TL, Raise, State, Ctx): ) - def setupSelection(prefix: Term, nme: Tree.Ident, sym: Opt[FieldSymbol])(k: Result => Block)(using Subst): Block = + def setupSelection(prefix: Term, nme: Tree.Ident, sym: Opt[FieldSymbol])(k: Result => Block)(using LoweringCtx): Block = subTerm(prefix): p => val selRes = TempSymbol(N, "selRes") k(Select(p, nme)(sym)) final def setupFunctionOrByNameDef(paramLists: List[ParamList], bodyTerm: Term, name: Option[Str]) - (using Subst): (List[ParamList], Block) = + (using LoweringCtx): (List[ParamList], Block) = val physicalParams = paramLists match case Nil => ParamList(ParamListFlags.empty, Nil, N) :: Nil case ps => ps setupFunctionDef(physicalParams, bodyTerm, name) def setupFunctionDef(paramLists: List[ParamList], bodyTerm: Term, name: Option[Str]) - (using Subst): (List[ParamList], Block) = + (using LoweringCtx): (List[ParamList], Block) = (paramLists, returnedTerm(bodyTerm)) def reportAnnotations(target: Statement, annotations: Ls[Annot]): Unit = @@ -1110,7 +1138,7 @@ trait LoweringSelSanityChecks(using Config, TL, Raise, State) private val instrument: Bool = config.sanityChecks.isDefined - override def setupSelection(prefix: st, nme: Tree.Ident, sym: Opt[FieldSymbol])(k: Result => Block)(using Subst): Block = + override def setupSelection(prefix: st, nme: Tree.Ident, sym: Opt[FieldSymbol])(k: Result => Block)(using LoweringCtx): Block = if !instrument then return super.setupSelection(prefix, nme, sym)(k) subTerm(prefix): p => val selRes = TempSymbol(N, "selRes") @@ -1157,7 +1185,7 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State) override def setupFunctionDef(paramLists: List[ParamList], bodyTerm: st, name: Option[Str]) - (using Subst): (List[ParamList], Block) = + (using LoweringCtx): (List[ParamList], Block) = if instrument then val (ps, bod) = handleMultipleParamLists(paramLists, bodyTerm) val instrumentedBody = setupFunctionBody(ps, bod, name) @@ -1173,7 +1201,7 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State) case h :: t => go(t, Term.Lam(h, bod)) go(paramLists.reverse, bod) - def setupFunctionBody(params: ParamList, bod: Term, name: Option[Str])(using Subst): Block = + def setupFunctionBody(params: ParamList, bod: Term, name: Option[Str])(using LoweringCtx): Block = val enterMsgSym = TempSymbol(N, dbgNme = "traceLogEnterMsg") val prevIndentLvlSym = TempSymbol(N, dbgNme = "traceLogPrevIndent") val resSym = TempSymbol(N, dbgNme = "traceLogRes") @@ -1209,7 +1237,7 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State) TempSymbol(N) -> pureCall(traceLogFn, Arg(N, Value.Ref(retMsgSym)) :: Nil) ) |>: Ret(Value.Ref(resSym)) - ) + )(using LoweringCtx.func) object TrivialStatementsAndMatch: diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala index cf02ddcddf..663593e975 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala @@ -490,8 +490,6 @@ class JSBuilder(using TL, State, Ctx) extends CodeBuilder: case Label(lbl, bod, rst) => scope.allocateName(lbl) - // [fixme:0] TODO check scope and allocate local variables here (see: https://github.com/hkust-taco/mlscript/pull/293#issuecomment-2792229849) - doc" # ${getVar(lbl, lbl.toLoc)}: while (true) " :: braced { returningTerm(bod, endSemi = true) :/: doc"break;" } :: returningTerm(rst, endSemi) diff --git a/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala b/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala index 2982c96c8c..ba70ab1335 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/semantics/Elaborator.scala @@ -227,6 +227,7 @@ object Elaborator: given State = this val globalThisSymbol = TopLevelSymbol("globalThis") val unitSymbol = ModuleOrObjectSymbol(DummyTypeDef(syntax.Obj), Ident("Unit")) + val loopEndSymbol = ModuleOrObjectSymbol(DummyTypeDef(syntax.Obj), Ident("LoopEnd")) // In JavaScript, `import` can be used for getting current file path, as `import.meta` val importSymbol = new VarSymbol(Ident("import")) val runtimeSymbol = TempSymbol(N, "runtime") diff --git a/hkmc2/shared/src/test/mlscript-compile/Runtime.mjs b/hkmc2/shared/src/test/mlscript-compile/Runtime.mjs index 4f92f5f27a..6139f8e021 100644 --- a/hkmc2/shared/src/test/mlscript-compile/Runtime.mjs +++ b/hkmc2/shared/src/test/mlscript-compile/Runtime.mjs @@ -42,6 +42,18 @@ globalThis.Object.freeze(class Runtime { [prettyPrint]() { return this.toString(); } static [definitionMetadata] = ["object", "Unit"]; }); + globalThis.Object.freeze(class LoopEnd { + static { + Runtime.LoopEnd = globalThis.Object.freeze(new this) + } + constructor() { + Object.defineProperty(this, "class", { + value: LoopEnd + }) + } + toString() { return runtime.render(this); } + static [definitionMetadata] = ["object", "LoopEnd"]; + }); this.short_and = RuntimeJS.short_and; this.short_or = RuntimeJS.short_or; this.bitand = RuntimeJS.bitand; diff --git a/hkmc2/shared/src/test/mlscript-compile/Runtime.mls b/hkmc2/shared/src/test/mlscript-compile/Runtime.mls index b91e3d303e..9c7d9a856a 100644 --- a/hkmc2/shared/src/test/mlscript-compile/Runtime.mls +++ b/hkmc2/shared/src/test/mlscript-compile/Runtime.mls @@ -10,7 +10,7 @@ module Runtime with ... object Unit with fun toString() = "()" - +object LoopEnd val short_and = RuntimeJS.short_and val short_or = RuntimeJS.short_or diff --git a/hkmc2/shared/src/test/mlscript/codegen/While.mls b/hkmc2/shared/src/test/mlscript/codegen/While.mls index 50a40f244f..f97cf69849 100644 --- a/hkmc2/shared/src/test/mlscript/codegen/While.mls +++ b/hkmc2/shared/src/test/mlscript/codegen/While.mls @@ -6,8 +6,8 @@ //│ JS (unsanitized): //│ let lambda; //│ lambda = (undefined, function () { -//│ let tmp, while1, tmp1; -//│ tmp = runtime.Unit; +//│ let tmp, while1, tmp1, tmp2; +//│ tmp = undefined; //│ while1 = (undefined, function () { //│ let scrut; //│ scrut = true; @@ -17,10 +17,11 @@ //│ } else { //│ throw globalThis.Object.freeze(new globalThis.Error("match error")) //│ } -//│ return runtime.Unit +//│ return runtime.LoopEnd //│ }); //│ tmp1 = while1(); -//│ return tmp +//│ tmp2 = tmp1 !== runtime.LoopEnd; +//│ if (tmp2 === true) { return tmp1 } else { return tmp } //│ }); //│ lambda //│ = fun @@ -57,7 +58,7 @@ while x else 42 //│ JS (unsanitized): //│ let tmp4, while3, tmp5; -//│ tmp4 = runtime.Unit; +//│ tmp4 = undefined; //│ while3 = (undefined, function () { //│ let tmp6; //│ if (x2 === true) { @@ -66,7 +67,7 @@ while x //│ tmp4 = runtime.Unit; //│ return while3() //│ } else { tmp4 = 42; } -//│ return runtime.Unit +//│ return runtime.LoopEnd //│ }); //│ tmp5 = while3(); //│ tmp4 @@ -111,24 +112,25 @@ while //│ JS (unsanitized): //│ let lambda2; //│ lambda2 = (undefined, function () { -//│ let tmp12, while7, tmp13; -//│ tmp12 = runtime.Unit; +//│ let tmp12, while7, tmp13, tmp14; +//│ tmp12 = undefined; //│ while7 = (undefined, function () { -//│ let i2, scrut, tmp14; +//│ let i2, scrut, tmp15; //│ i2 = 0; //│ scrut = i2 < 10; //│ if (scrut === true) { -//│ tmp14 = i2 + 1; -//│ i2 = tmp14; +//│ tmp15 = i2 + 1; +//│ i2 = tmp15; //│ tmp12 = runtime.Unit; //│ return while7() //│ } else { //│ tmp12 = runtime.Unit; //│ } -//│ return runtime.Unit +//│ return runtime.LoopEnd //│ }); //│ tmp13 = while7(); -//│ return tmp12 +//│ tmp14 = tmp13 !== runtime.LoopEnd; +//│ if (tmp14 === true) { return tmp13 } else { return tmp12 } //│ }); //│ lambda2 //│ = fun @@ -208,8 +210,8 @@ fun f(ls) = //│ JS (unsanitized): //│ let f; //│ f = function f(ls) { -//│ let tmp18, while10, tmp19; -//│ tmp18 = runtime.Unit; +//│ let tmp18, while10, tmp19, tmp20; +//│ tmp18 = undefined; //│ while10 = (undefined, function () { //│ let param0, param1, h, tl; //│ if (ls instanceof Cons1.class) { @@ -223,10 +225,11 @@ fun f(ls) = //│ } else { //│ tmp18 = Predef.print("Done!"); //│ } -//│ return runtime.Unit +//│ return runtime.LoopEnd //│ }); //│ tmp19 = while10(); -//│ return tmp18 +//│ tmp20 = tmp19 !== runtime.LoopEnd; +//│ if (tmp20 === true) { return tmp19 } else { return tmp18 } //│ }; f(0) @@ -265,13 +268,13 @@ let x = 1 while x is {} do() //│ JS (unsanitized): //│ let tmp27, while11, tmp28; -//│ tmp27 = runtime.Unit; +//│ tmp27 = undefined; //│ while11 = (undefined, function () { //│ if (x3 instanceof Object) { //│ tmp27 = runtime.Unit; //│ return while11() //│ } else { tmp27 = runtime.Unit; } -//│ return runtime.Unit +//│ return runtime.LoopEnd //│ }); //│ tmp28 = while11(); //│ tmp27 @@ -288,10 +291,10 @@ while print("Hello World"); false then 0(0) else 1 //│ ╔══[PARSE ERROR] Unexpected 'then' keyword here -//│ ║ l.288: then 0(0) +//│ ║ l.291: then 0(0) //│ ╙── ^^^^ //│ ╔══[ERROR] Unrecognized term split (false literal). -//│ ║ l.287: while print("Hello World"); false +//│ ║ l.290: while print("Hello World"); false //│ ╙── ^^^^^ //│ > Hello World //│ ═══[RUNTIME ERROR] Error: match error @@ -301,12 +304,12 @@ while { print("Hello World"), false } then 0(0) else 1 //│ ╔══[ERROR] Unexpected infix use of keyword 'then' here -//│ ║ l.300: while { print("Hello World"), false } +//│ ║ l.303: while { print("Hello World"), false } //│ ║ ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -//│ ║ l.301: then 0(0) +//│ ║ l.304: then 0(0) //│ ╙── ^^^^^^^^^^^ //│ ╔══[ERROR] Illegal position for prefix keyword 'else'. -//│ ║ l.302: else 1 +//│ ║ l.305: else 1 //│ ╙── ^^^^ :fixme @@ -316,14 +319,14 @@ while then 0(0) else 1 //│ ╔══[ERROR] Unexpected infix use of keyword 'then' here -//│ ║ l.314: print("Hello World") +//│ ║ l.317: print("Hello World") //│ ║ ^^^^^^^^^^^^^^^^^^^^ -//│ ║ l.315: false +//│ ║ l.318: false //│ ║ ^^^^^^^^^ -//│ ║ l.316: then 0(0) +//│ ║ l.319: then 0(0) //│ ╙── ^^^^^^^^^^^ //│ ╔══[ERROR] Illegal position for prefix keyword 'else'. -//│ ║ l.317: else 1 +//│ ║ l.320: else 1 //│ ╙── ^^^^ @@ -365,3 +368,11 @@ while i < arr.length do //│ arr = [1, 2, 3] //│ i = 3 //│ output = [Lazy(_), Lazy(_), Lazy(_)] + +// Returning inside while loop +:expect 42 +fun f() = + while true do + return 42 +f() +//│ = 42 diff --git a/hkmc2/shared/src/test/mlscript/handlers/UserThreadsSafe.mls b/hkmc2/shared/src/test/mlscript/handlers/UserThreadsSafe.mls index c240949e57..26059f2a32 100644 --- a/hkmc2/shared/src/test/mlscript/handlers/UserThreadsSafe.mls +++ b/hkmc2/shared/src/test/mlscript/handlers/UserThreadsSafe.mls @@ -44,7 +44,8 @@ in //│ > f 2 //│ ═══[RUNTIME ERROR] Error: Unhandled effect Handler$h$ //│ at f (UserThreadsSafe.mls:17:3) -//│ at drain (UserThreadsSafe.mls:10:7) +//│ at while (UserThreadsSafe.mls:10:7) +//│ at drain (pc=6) //│ at fork (UserThreadsSafe.mls:35:5) //│ at fork (UserThreadsSafe.mls:35:5) @@ -66,8 +67,9 @@ in //│ > f 0 //│ ═══[RUNTIME ERROR] Error: Unhandled effect Handler$h$1 //│ at f (UserThreadsSafe.mls:17:3) -//│ at drain (UserThreadsSafe.mls:10:7) -//│ at fork (UserThreadsSafe.mls:57:5) -//│ at fork (UserThreadsSafe.mls:57:5) +//│ at while (UserThreadsSafe.mls:10:7) +//│ at drain (pc=6) +//│ at fork (UserThreadsSafe.mls:58:5) +//│ at fork (UserThreadsSafe.mls:58:5) diff --git a/hkmc2/shared/src/test/mlscript/handlers/UserThreadsUnsafe.mls b/hkmc2/shared/src/test/mlscript/handlers/UserThreadsUnsafe.mls index 402bfe41c3..c22ef1eadf 100644 --- a/hkmc2/shared/src/test/mlscript/handlers/UserThreadsUnsafe.mls +++ b/hkmc2/shared/src/test/mlscript/handlers/UserThreadsUnsafe.mls @@ -47,7 +47,8 @@ in //│ > f 2 //│ ═══[RUNTIME ERROR] Error: Unhandled effect Handler$h$ //│ at f (UserThreadsUnsafe.mls:12:3) -//│ at drain (UserThreadsUnsafe.mls:29:5) +//│ at while (UserThreadsUnsafe.mls:29:5) +//│ at drain (pc=6) //│ at fork (UserThreadsUnsafe.mls:38:5) //│ at fork (UserThreadsUnsafe.mls:38:5) @@ -69,8 +70,9 @@ in //│ > f 1 //│ ═══[RUNTIME ERROR] Error: Unhandled effect Handler$h$ //│ at f (UserThreadsUnsafe.mls:12:3) -//│ at drain (UserThreadsUnsafe.mls:29:5) -//│ at fork (UserThreadsUnsafe.mls:60:5) -//│ at fork (UserThreadsUnsafe.mls:60:5) +//│ at while (UserThreadsUnsafe.mls:29:5) +//│ at drain (pc=6) +//│ at fork (UserThreadsUnsafe.mls:61:5) +//│ at fork (UserThreadsUnsafe.mls:61:5) diff --git a/hkmc2/shared/src/test/mlscript/lifter/Loops.mls b/hkmc2/shared/src/test/mlscript/lifter/Loops.mls index a49b86a560..46f9fd6c64 100644 --- a/hkmc2/shared/src/test/mlscript/lifter/Loops.mls +++ b/hkmc2/shared/src/test/mlscript/lifter/Loops.mls @@ -27,13 +27,11 @@ fun foo() = set i += 1 fs.push of () => x -// * Note that this also fails without lifting, as we need to fix [fixme:0] -:fixme +// Usage of while loop local varable with capturing :expect 1 foo() fs.0() -//│ ═══[RUNTIME ERROR] Expected: '1', got: '4' -//│ = 4 +//│ = 1 :sjs @@ -43,29 +41,47 @@ fun foo() = set x += 1 return () => x //│ JS (unsanitized): -//│ let foo2, lambda2, lambda$2; -//│ lambda$2 = function lambda$(x) { -//│ return x +//│ let foo2, lambda2, while$2, lambda$2, foo$capture5; +//│ lambda$2 = function lambda$(foo$capture6) { +//│ return foo$capture6.x$capture$0 //│ }; -//│ lambda2 = (undefined, function (x) { +//│ lambda2 = (undefined, function (foo$capture6) { //│ return () => { -//│ return lambda$2(x) +//│ return lambda$2(foo$capture6) //│ } //│ }); -//│ foo2 = function foo() { -//│ let x, scrut, tmp2, tmp3, lambda$here; -//│ x = 1; -//│ tmp4: while (true) { -//│ scrut = true; -//│ if (scrut === true) { -//│ tmp2 = x + 1; -//│ x = tmp2; -//│ lambda$here = runtime.safeCall(lambda2(x)); -//│ return lambda$here -//│ } else { tmp3 = runtime.Unit; } -//│ break; +//│ while$2 = function while$(foo$capture6) { +//│ let scrut, tmp2, lambda$here; +//│ scrut = true; +//│ if (scrut === true) { +//│ tmp2 = foo$capture6.x$capture$0 + 1; +//│ foo$capture6.x$capture$0 = tmp2; +//│ lambda$here = runtime.safeCall(lambda2(foo$capture6)); +//│ return lambda$here +//│ } else { +//│ foo$capture6.tmp$capture$1 = runtime.Unit; +//│ } +//│ return runtime.LoopEnd +//│ }; +//│ globalThis.Object.freeze(class foo$capture4 { +//│ static { +//│ foo$capture5 = this //│ } -//│ return tmp3 +//│ constructor(x$capture$0, tmp$capture$1) { +//│ this.tmp$capture$1 = tmp$capture$1; +//│ this.x$capture$0 = x$capture$0; +//│ } +//│ toString() { return runtime.render(this); } +//│ static [definitionMetadata] = ["class", "foo$capture"]; +//│ }); +//│ foo2 = function foo() { +//│ let tmp2, tmp3, capture; +//│ capture = new foo$capture5(null, null); +//│ capture.x$capture$0 = 1; +//│ capture.tmp$capture$1 = undefined; +//│ tmp2 = while$2(capture); +//│ tmp3 = tmp2 !== runtime.LoopEnd; +//│ if (tmp3 === true) { return tmp2 } else { return capture.tmp$capture$1 } //│ }; :expect 2 diff --git a/hkmc2/shared/src/test/mlscript/ucs/general/LogicalConnectives.mls b/hkmc2/shared/src/test/mlscript/ucs/general/LogicalConnectives.mls index b770c8d3af..c4ee348716 100644 --- a/hkmc2/shared/src/test/mlscript/ucs/general/LogicalConnectives.mls +++ b/hkmc2/shared/src/test/mlscript/ucs/general/LogicalConnectives.mls @@ -25,11 +25,11 @@ fun test(x) = :sjs true and test(42) //│ JS (unsanitized): -//│ let scrut6, scrut7; -//│ scrut6 = true; -//│ if (scrut6 === true) { -//│ scrut7 = test(42); -//│ if (scrut7 === true) { true } else { false } +//│ let scrut4, scrut5; +//│ scrut4 = true; +//│ if (scrut4 === true) { +//│ scrut5 = test(42); +//│ if (scrut5 === true) { true } else { false } //│ } else { false } //│ > 42 //│ = false diff --git a/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala b/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala index 5e8323dc74..9bc75bdcdf 100644 --- a/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala +++ b/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala @@ -94,6 +94,7 @@ abstract class MLsDiffMaker extends DiffMaker: )), liftDefns = Opt.when(liftDefns.isSet)(LiftDefns()), target = if wasm.isSet then CompilationTarget.Wasm else CompilationTarget.JS, + rewriteWhileLoops = true, ) From 366575dc611799a6cce70db95b634faad2e5356f Mon Sep 17 00:00:00 2001 From: Anson Yeung Date: Tue, 11 Nov 2025 01:48:15 +0800 Subject: [PATCH 4/5] Address PR comments --- .../jvm/src/test/scala/hkmc2/CompileTestRunner.scala | 3 ++- .../src/main/scala/hkmc2/codegen/Lowering.scala | 12 ++++++------ .../src/main/scala/hkmc2/codegen/js/JSBuilder.scala | 2 ++ 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/hkmc2/jvm/src/test/scala/hkmc2/CompileTestRunner.scala b/hkmc2/jvm/src/test/scala/hkmc2/CompileTestRunner.scala index da27841215..6587504b5b 100644 --- a/hkmc2/jvm/src/test/scala/hkmc2/CompileTestRunner.scala +++ b/hkmc2/jvm/src/test/scala/hkmc2/CompileTestRunner.scala @@ -48,7 +48,8 @@ class CompileTestRunner val preludePath = mainTestDir/"mlscript"/"decls"/"Prelude.mls" - // while loop is currently not rewritten so that stack safety works correctly as runtime relies on them. + // Stack safety relies on the fact that runtime uses while loops for resumption + // and does not create extra stack depth. Hence we disable while loop rewriting here. given Config = Config.default.copy(rewriteWhileLoops = false) val compiler = MLsCompiler( diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala index 15fb6b2474..ea6ec6d2c9 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala @@ -51,8 +51,8 @@ class LoweringCtx(initMap: Map[Local, Value], val mayRet: Bool): case _ => v object LoweringCtx: val empty = LoweringCtx(Map.empty, false) - val func = LoweringCtx(Map.empty, true) def subst(using sub: LoweringCtx): LoweringCtx = sub + def nestFunc(using sub: LoweringCtx): LoweringCtx = LoweringCtx(sub.map, true) end LoweringCtx import LoweringCtx.subst @@ -84,7 +84,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): // type Rcd = (mut: Bool, args: List[RcdArg]) // * Better, but Scala's patmat exhaustiveness chokes on it type Rcd = (Bool, List[RcdArg]) - def returnedTerm(t: st)(using LoweringCtx): Block = term(t)(Ret)(using LoweringCtx.func) + def returnedTerm(t: st)(using LoweringCtx): Block = term(t)(Ret)(using LoweringCtx.nestFunc) def parentConstructor(cls: Term, args: Ls[Term])(using LoweringCtx) = if args.length > 1 then @@ -178,7 +178,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx): // Assign(td.sym, r, // term(st.Blk(stats, res))(k))) Define(ValDefn(td.tsym, td.sym, r), - blockImpl(stats, res)(k)))(using LoweringCtx.func) + blockImpl(stats, res)(k)))(using LoweringCtx.nestFunc) case syntax.Fun => val (paramLists, bodyBlock) = setupFunctionOrByNameDef(td.params, bod, S(td.sym.nme)) Define(FunDefn(td.owner, td.sym, paramLists, bodyBlock), @@ -572,9 +572,9 @@ class Lowering()(using Config, TL, Raise, State, Ctx): case iftrm: st.IfLike => ucs.Normalization(this)(iftrm)(k) - - case iftrm: st.SynthIf => ucs.Normalization(this)(iftrm)(k) + case iftrm: st.SynthIf => ucs.Normalization(this)(iftrm)(k) + case sel @ Sel(prefix, nme) => setupSelection(prefix, nme, sel.sym)(k) @@ -1073,7 +1073,7 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State) TempSymbol(N) -> pureCall(traceLogFn, Arg(N, Value.Ref(retMsgSym)) :: Nil) ) |>: Ret(Value.Ref(resSym)) - )(using LoweringCtx.func) + )(using LoweringCtx.nestFunc) object TrivialStatementsAndMatch: diff --git a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala index c8a5a56ea5..f9a6858107 100644 --- a/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala +++ b/hkmc2/shared/src/main/scala/hkmc2/codegen/js/JSBuilder.scala @@ -490,6 +490,8 @@ class JSBuilder(using TL, State, Ctx) extends CodeBuilder: case Label(lbl, loop, bod, rst) => scope.allocateName(lbl) + // [fixme:0] TODO check scope and allocate local variables here (see: https://github.com/hkust-taco/mlscript/pull/293#issuecomment-2792229849) + doc" # ${getVar(lbl, lbl.toLoc)}:${if loop then doc" while (true)" else ""} " :: braced { returningTerm(bod, endSemi = true) :: (if loop then doc" # break;" else doc"") } :: returningTerm(rst, endSemi) From 9dc68404a9ac264d1f66de745742d0ac5f8f2ebe Mon Sep 17 00:00:00 2001 From: Anson Yeung Date: Tue, 11 Nov 2025 01:57:16 +0800 Subject: [PATCH 5/5] Implement dontRewriteWhile for testing lifter --- hkmc2/shared/src/test/mlscript/lifter/Loops.mls | 14 +++++++++----- .../src/test/scala/hkmc2/MLsDiffMaker.scala | 3 ++- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/hkmc2/shared/src/test/mlscript/lifter/Loops.mls b/hkmc2/shared/src/test/mlscript/lifter/Loops.mls index 46f9fd6c64..f3202dd07c 100644 --- a/hkmc2/shared/src/test/mlscript/lifter/Loops.mls +++ b/hkmc2/shared/src/test/mlscript/lifter/Loops.mls @@ -20,6 +20,7 @@ fs.0() let fs = mut [] //│ fs = [] +:dontRewriteWhile fun foo() = let i = 1 while i < 5 do @@ -27,11 +28,14 @@ fun foo() = set i += 1 fs.push of () => x -// Usage of while loop local varable with capturing +// * Note that this works with while loop rewriting +// * See [fixme:0] for cause of the issue +:fixme :expect 1 foo() fs.0() -//│ = 1 +//│ ═══[RUNTIME ERROR] Expected: '1', got: '4' +//│ = 4 :sjs @@ -41,7 +45,7 @@ fun foo() = set x += 1 return () => x //│ JS (unsanitized): -//│ let foo2, lambda2, while$2, lambda$2, foo$capture5; +//│ let foo2, lambda2, while$1, lambda$2, foo$capture5; //│ lambda$2 = function lambda$(foo$capture6) { //│ return foo$capture6.x$capture$0 //│ }; @@ -50,7 +54,7 @@ fun foo() = //│ return lambda$2(foo$capture6) //│ } //│ }); -//│ while$2 = function while$(foo$capture6) { +//│ while$1 = function while$(foo$capture6) { //│ let scrut, tmp2, lambda$here; //│ scrut = true; //│ if (scrut === true) { @@ -79,7 +83,7 @@ fun foo() = //│ capture = new foo$capture5(null, null); //│ capture.x$capture$0 = 1; //│ capture.tmp$capture$1 = undefined; -//│ tmp2 = while$2(capture); +//│ tmp2 = while$1(capture); //│ tmp3 = tmp2 !== runtime.LoopEnd; //│ if (tmp3 === true) { return tmp2 } else { return capture.tmp$capture$1 } //│ }; diff --git a/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala b/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala index 6ca53da8d1..f527cb10d8 100644 --- a/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala +++ b/hkmc2DiffTests/src/test/scala/hkmc2/MLsDiffMaker.scala @@ -71,6 +71,7 @@ abstract class MLsDiffMaker extends DiffMaker: val liftDefns = NullaryCommand("lift") val importQQ = NullaryCommand("qq") val stageCode = NullaryCommand("staging") + val dontRewriteWhile = NullaryCommand("dontRewriteWhile") def mkConfig: Config = import Config.* @@ -98,7 +99,7 @@ abstract class MLsDiffMaker extends DiffMaker: liftDefns = Opt.when(liftDefns.isSet)(LiftDefns()), stageCode = stageCode.isSet, target = if wasm.isSet then CompilationTarget.Wasm else CompilationTarget.JS, - rewriteWhileLoops = true, + rewriteWhileLoops = !dontRewriteWhile.isSet, )