Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion hkmc2/jvm/src/test/scala/hkmc2/CompileTestRunner.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ class CompileTestRunner

val preludePath = mainTestDir/"mlscript"/"decls"/"Prelude.mls"

given Config = Config.default
// Stack safety relies on the fact that runtime uses while loops for resumption
// and does not create extra stack depth. Hence we disable while loop rewriting here.
given Config = Config.default.copy(rewriteWhileLoops = false)

val compiler = MLsCompiler(
preludePath,
Expand Down
4 changes: 3 additions & 1 deletion hkmc2/shared/src/main/scala/hkmc2/Config.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ case class Config(
liftDefns: Opt[LiftDefns],
stageCode: Bool,
target: CompilationTarget,
rewriteWhileLoops: Bool,
):

def stackSafety: Opt[StackSafety] = effectHandlers.flatMap(_.stackSafety)
Expand All @@ -36,8 +37,9 @@ object Config:
// sanityChecks = S(SanityChecks(light = true)),
effectHandlers = N,
liftDefns = N,
target = CompilationTarget.JS,
rewriteWhileLoops = true,
stageCode = false,
target = CompilationTarget.JS
)

case class SanityChecks(light: Bool)
Expand Down
74 changes: 37 additions & 37 deletions hkmc2/shared/src/main/scala/hkmc2/codegen/Lowering.scala
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ object Thrw extends TailOp:


// * No longer in meaningful use and could be removed if we don't find a use for it:
class Subst(initMap: Map[Local, Value]):
class LoweringCtx(initMap: Map[Local, Value], val mayRet: Bool):
val map = initMap
/*
def +(kv: (Local, Value)): Subst =
Expand All @@ -49,12 +49,13 @@ class Subst(initMap: Map[Local, Value]):
def apply(v: Value): Value = v match
case Value.Ref(l) => map.getOrElse(l, v)
case _ => v
object Subst:
val empty = Subst(Map.empty)
def subst(using sub: Subst): Subst = sub
end Subst
object LoweringCtx:
val empty = LoweringCtx(Map.empty, false)
def subst(using sub: LoweringCtx): LoweringCtx = sub
def nestFunc(using sub: LoweringCtx): LoweringCtx = LoweringCtx(sub.map, true)
end LoweringCtx

import Subst.subst
import LoweringCtx.subst


class Lowering()(using Config, TL, Raise, State, Ctx):
Expand All @@ -75,7 +76,6 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
def unit: Path =
Select(Value.Ref(State.runtimeSymbol), Tree.Ident("Unit"))(S(State.unitSymbol))


def fail(err: ErrorReport): Block =
raise(err)
End("error")
Expand All @@ -84,9 +84,9 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
// type Rcd = (mut: Bool, args: List[RcdArg]) // * Better, but Scala's patmat exhaustiveness chokes on it
type Rcd = (Bool, List[RcdArg])

def returnedTerm(t: st)(using Subst): Block = term(t)(Ret)
def returnedTerm(t: st)(using LoweringCtx): Block = term(t)(Ret)(using LoweringCtx.nestFunc)

def parentConstructor(cls: Term, args: Ls[Term])(using Subst) =
def parentConstructor(cls: Term, args: Ls[Term])(using LoweringCtx) =
if args.length > 1 then
raise:
ErrorReport(
Expand All @@ -101,7 +101,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
)(c => Return(c, implct = true))

// * Used to work around Scala's @tailrec annotation for those few calls that are not in tail position.
final def term_nonTail(t: st, inStmtPos: Bool = false)(k: Result => Block)(using Subst): Block =
final def term_nonTail(t: st, inStmtPos: Bool = false)(k: Result => Block)(using LoweringCtx): Block =
term(t: st, inStmtPos: Bool)(k)


Expand All @@ -120,12 +120,12 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
(imps.reverse, funs.reverse, rest.reverse)


def block(stats: Ls[Statement], res: Rcd \/ Term)(k: Result => Block)(using Subst): Block =
def block(stats: Ls[Statement], res: Rcd \/ Term)(k: Result => Block)(using LoweringCtx): Block =
// TODO we should also isolate and reorder classes by inheritance topological sort
val (imps, funs, rest) = splitBlock(stats, Nil, Nil, Nil)
blockImpl(imps ::: funs ::: rest, res)(k)

def blockImpl(stats: Ls[Statement], res: Rcd \/ Term)(k: Result => Block)(using Subst): Block =
def blockImpl(stats: Ls[Statement], res: Rcd \/ Term)(k: Result => Block)(using LoweringCtx): Block =
stats match
case (t: sem.Term) :: stats =>
subTerm(t, inStmtPos = true): r =>
Expand Down Expand Up @@ -178,7 +178,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
// Assign(td.sym, r,
// term(st.Blk(stats, res))(k)))
Define(ValDefn(td.tsym, td.sym, r),
blockImpl(stats, res)(k)))
blockImpl(stats, res)(k)))(using LoweringCtx.nestFunc)
case syntax.Fun =>
val (paramLists, bodyBlock) = setupFunctionOrByNameDef(td.params, bod, S(td.sym.nme))
Define(FunDefn(td.owner, td.sym, paramLists, bodyBlock),
Expand Down Expand Up @@ -302,17 +302,17 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
blockImpl(stats, res)(k)


def lowerCall(fr: Path, isMlsFun: Bool, arg: Opt[Term], loc: Opt[Loc])(k: Result => Block)(using Subst): Block =
def lowerCall(fr: Path, isMlsFun: Bool, arg: Opt[Term], loc: Opt[Loc])(k: Result => Block)(using LoweringCtx): Block =
arg match
case S(a) =>
lowerCall(fr, isMlsFun, a, loc)(k)
case N =>
// * No arguments means a nullary call, e.g., `f()`
k(Call(fr, Nil)(isMlsFun, true).withLoc(loc))
def lowerCall(fr: Path, isMlsFun: Bool, arg: Term, loc: Opt[Loc])(k: Result => Block)(using Subst): Block =
def lowerCall(fr: Path, isMlsFun: Bool, arg: Term, loc: Opt[Loc])(k: Result => Block)(using LoweringCtx): Block =
lowerArg(arg)(as => k(Call(fr, as)(isMlsFun, true).withLoc(loc)))

def lowerArg(arg: Term)(k: Ls[Arg] => Block)(using Subst): Block =
def lowerArg(arg: Term)(k: Ls[Arg] => Block)(using LoweringCtx): Block =
arg match
case Tup(fs) =>
if fs.exists(e => e match
Expand All @@ -329,7 +329,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
k(Arg(spread = S(true), ar) :: Nil)

@tailrec
final def term(t: st, inStmtPos: Bool = false)(k: Result => Block)(using Subst): Block =
final def term(t: st, inStmtPos: Bool = false)(k: Result => Block)(using LoweringCtx): Block =
tl.log(s"Lowering.term ${t.showDbg.truncate(100, "[...]")}${
if inStmtPos then " (in stmt)" else ""}${
t.resolvedSym.fold("")(" – symbol " + _)}")
Expand Down Expand Up @@ -680,17 +680,17 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
// case _ =>
// subTerm(t)(k)

def setupTerm(name: Str, args: Ls[Path])(k: Result => Block)(using Subst): Block =
def setupTerm(name: Str, args: Ls[Path])(k: Result => Block)(using LoweringCtx): Block =
k(Instantiate(mut = false, Value.Ref(State.termSymbol).selSN(name), args.map(_.asArg)))

def setupQuotedKeyword(kw: Str): Path =
Value.Ref(State.termSymbol).selSN("Keyword").selSN(kw)

def setupSymbol(symbol: Local)(k: Result => Block)(using Subst): Block =
def setupSymbol(symbol: Local)(k: Result => Block)(using LoweringCtx): Block =
k(Instantiate(mut = false, Value.Ref(State.termSymbol).selSN("Symbol"),
Value.Lit(Tree.StrLit(symbol.nme)).asArg :: Nil))

def quotePattern(p: FlatPattern)(k: Result => Block)(using Subst): Block = p match
def quotePattern(p: FlatPattern)(k: Result => Block)(using LoweringCtx): Block = p match
case FlatPattern.Lit(lit) => setupTerm("LitPattern", Value.Lit(lit) :: Nil)(k)
case _ => // TODO
fail:
Expand All @@ -700,7 +700,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
source = Diagnostic.Source.Compilation
)

def quoteSplit(split: Split)(k: Result => Block)(using Subst): Block = split match
def quoteSplit(split: Split)(k: Result => Block)(using LoweringCtx): Block = split match
case Split.Cons(Branch(scrutinee, pattern, continuation), tail) => quote(scrutinee): r1 =>
val l1, l2, l3, l4, l5 = new TempSymbol(N)
blockBuilder.assign(l1, r1)
Expand All @@ -725,7 +725,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
val state = summon[State]
Value.Ref(state.importSymbol).selSN("meta").selSN("url")

def quote(t: st)(k: Result => Block)(using Subst): Block = t match
def quote(t: st)(k: Result => Block)(using LoweringCtx): Block = t match
case Lit(lit) =>
setupTerm("Lit", Value.Lit(lit) :: Nil)(k)
case Ref(sym) if Elaborator.binaryOps.contains(sym.nme) => // builtin symbols
Expand Down Expand Up @@ -756,7 +756,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
source = Diagnostic.Source.Compilation
)
case Lam(params, body) =>
def rec(ps: Ls[LocalSymbol & NamedSymbol], ds: Ls[Path])(k: Result => Block)(using Subst): Block = ps match
def rec(ps: Ls[LocalSymbol & NamedSymbol], ds: Ls[Path])(k: Result => Block)(using LoweringCtx): Block = ps match
case Nil => quote(body): r =>
val l = new TempSymbol(N)
val arr = new TempSymbol(N, "arr")
Expand Down Expand Up @@ -818,7 +818,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
source = Diagnostic.Source.Compilation
)

def gatherMembers(clsBody: ObjBody)(using Subst)
def gatherMembers(clsBody: ObjBody)(using LoweringCtx)
: (Ls[FunDefn], Ls[BlockMemberSymbol -> TermSymbol], Ls[TermSymbol], Block) =
val mtds = clsBody.methods
.flatMap: td =>
Expand All @@ -838,7 +838,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
case t => t
(mtds, publicFlds, privateFlds, ctor)

def args(elems: Ls[Elem])(k: Ls[Arg] => Block)(using Subst): Block =
def args(elems: Ls[Elem])(k: Ls[Arg] => Block)(using LoweringCtx): Block =
val as = elems.map:
case sem.Fld(sem.FldFlags.benign(), value, N) => R(N -> value)
case sem.Fld(sem.FldFlags.benign(), idx, S(rhs)) => L(idx -> rhs)
Expand Down Expand Up @@ -880,10 +880,10 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
k((Arg(N, Value.Ref(rcdSym)) :: asr).reverse)))


inline def plainArgs(ts: Ls[st])(k: Ls[Arg] => Block)(using Subst): Block =
inline def plainArgs(ts: Ls[st])(k: Ls[Arg] => Block)(using LoweringCtx): Block =
subTerms(ts)(asr => k(asr.map(Arg(N, _))))

inline def subTerms(ts: Ls[st])(k: Ls[Path] => Block)(using Subst): Block =
inline def subTerms(ts: Ls[st])(k: Ls[Path] => Block)(using LoweringCtx): Block =
// @tailrec // TODO
def rec(as: Ls[st], asr: Ls[Path]): Block = as match
case Nil => k(asr.reverse)
Expand All @@ -892,10 +892,10 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
rec(as, ar :: asr)
rec(ts, Nil)

def subTerm_nonTail(t: st, inStmtPos: Bool = false)(k: Path => Block)(using Subst): Block =
def subTerm_nonTail(t: st, inStmtPos: Bool = false)(k: Path => Block)(using LoweringCtx): Block =
subTerm(t: st, inStmtPos: Bool)(k)

inline def subTerm(t: st, inStmtPos: Bool = false)(k: Path => Block)(using Subst): Block =
inline def subTerm(t: st, inStmtPos: Bool = false)(k: Path => Block)(using LoweringCtx): Block =
term(t, inStmtPos = inStmtPos):
case v: Value => k(v)
case p: Path => k(p)
Expand All @@ -912,7 +912,7 @@ class Lowering()(using Config, TL, Raise, State, Ctx):

val (imps, funs, rest) = splitBlock(main.stats, Nil, Nil, Nil)

val blk = block(funs ::: rest, R(main.res))(ImplctRet)(using Subst.empty)
val blk = block(funs ::: rest, R(main.res))(ImplctRet)(using LoweringCtx.empty)

val desug = LambdaRewriter.desugar(blk)

Expand Down Expand Up @@ -945,20 +945,20 @@ class Lowering()(using Config, TL, Raise, State, Ctx):
)


def setupSelection(prefix: Term, nme: Tree.Ident, sym: Opt[FieldSymbol])(k: Result => Block)(using Subst): Block =
def setupSelection(prefix: Term, nme: Tree.Ident, sym: Opt[FieldSymbol])(k: Result => Block)(using LoweringCtx): Block =
subTerm(prefix): p =>
val selRes = TempSymbol(N, "selRes")
k(Select(p, nme)(sym))

final def setupFunctionOrByNameDef(paramLists: List[ParamList], bodyTerm: Term, name: Option[Str])
(using Subst): (List[ParamList], Block) =
(using LoweringCtx): (List[ParamList], Block) =
val physicalParams = paramLists match
case Nil => ParamList(ParamListFlags.empty, Nil, N) :: Nil
case ps => ps
setupFunctionDef(physicalParams, bodyTerm, name)

def setupFunctionDef(paramLists: List[ParamList], bodyTerm: Term, name: Option[Str])
(using Subst): (List[ParamList], Block) =
(using LoweringCtx): (List[ParamList], Block) =
(paramLists, returnedTerm(bodyTerm))

def reportAnnotations(target: Statement, annotations: Ls[Annot]): Unit =
Expand All @@ -974,7 +974,7 @@ trait LoweringSelSanityChecks(using Config, TL, Raise, State)

private val instrument: Bool = config.sanityChecks.isDefined

override def setupSelection(prefix: st, nme: Tree.Ident, sym: Opt[FieldSymbol])(k: Result => Block)(using Subst): Block =
override def setupSelection(prefix: st, nme: Tree.Ident, sym: Opt[FieldSymbol])(k: Result => Block)(using LoweringCtx): Block =
if !instrument then return super.setupSelection(prefix, nme, sym)(k)
subTerm(prefix): p =>
val selRes = TempSymbol(N, "selRes")
Expand Down Expand Up @@ -1021,7 +1021,7 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State)


override def setupFunctionDef(paramLists: List[ParamList], bodyTerm: st, name: Option[Str])
(using Subst): (List[ParamList], Block) =
(using LoweringCtx): (List[ParamList], Block) =
if instrument then
val (ps, bod) = handleMultipleParamLists(paramLists, bodyTerm)
val instrumentedBody = setupFunctionBody(ps, bod, name)
Expand All @@ -1037,7 +1037,7 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State)
case h :: t => go(t, Term.Lam(h, bod))
go(paramLists.reverse, bod)

def setupFunctionBody(params: ParamList, bod: Term, name: Option[Str])(using Subst): Block =
def setupFunctionBody(params: ParamList, bod: Term, name: Option[Str])(using LoweringCtx): Block =
val enterMsgSym = TempSymbol(N, dbgNme = "traceLogEnterMsg")
val prevIndentLvlSym = TempSymbol(N, dbgNme = "traceLogPrevIndent")
val resSym = TempSymbol(N, dbgNme = "traceLogRes")
Expand Down Expand Up @@ -1073,7 +1073,7 @@ trait LoweringTraceLog(instrument: Bool)(using TL, Raise, State)
TempSymbol(N) -> pureCall(traceLogFn, Arg(N, Value.Ref(retMsgSym)) :: Nil)
) |>:
Ret(Value.Ref(resSym))
)
)(using LoweringCtx.nestFunc)


object TrivialStatementsAndMatch:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ object Elaborator:
given State = this
val globalThisSymbol = TopLevelSymbol("globalThis")
val unitSymbol = ModuleOrObjectSymbol(DummyTypeDef(syntax.Obj), Ident("Unit"))
val loopEndSymbol = ModuleOrObjectSymbol(DummyTypeDef(syntax.Obj), Ident("LoopEnd"))
// In JavaScript, `import` can be used for getting current file path, as `import.meta`
val importSymbol = new VarSymbol(Ident("import"))
val runtimeSymbol = TempSymbol(N, "runtime")
Expand Down
Loading
Loading