diff --git a/.github/workflows/scala.yml b/.github/workflows/scala.yml index 5610a96e..4677da77 100644 --- a/.github/workflows/scala.yml +++ b/.github/workflows/scala.yml @@ -78,3 +78,4 @@ jobs: sbt 'testOnly gensym.wasm.TestScriptRun' sbt 'testOnly gensym.wasm.TestConcolic' sbt 'testOnly gensym.wasm.TestDriver' + sbt 'testOnly gensym.wasm.TestStagedEval' diff --git a/benchmarks/wasm/global.wat b/benchmarks/wasm/global.wat new file mode 100644 index 00000000..236467ef --- /dev/null +++ b/benchmarks/wasm/global.wat @@ -0,0 +1,19 @@ +(module + (type (;0;) (func (result i32))) + (type (;1;) (func)) + + (func (;0;) (type 0) (result i32) + i32.const 42 + global.set 0 + global.get 0 + ) + (func (;1;) (type 1) + call 0 + ;; should be 42 + ;; drop + ) + (start 1) + (memory (;0;) 2) + (export "main" (func 1)) + (global (;0;) (mut i32) (i32.const 0)) +) \ No newline at end of file diff --git a/benchmarks/wasm/staged/brtable.wat b/benchmarks/wasm/staged/brtable.wat new file mode 100644 index 00000000..91133d70 --- /dev/null +++ b/benchmarks/wasm/staged/brtable.wat @@ -0,0 +1,11 @@ +(module $push-drop + (global (;0;) (mut i32) (i32.const 1048576)) + (func (;0;) (type 1) (result i32) + i32.const 2 + (block + (block + br_table 0 1 0 + ) + ) + ) + (start 0)) diff --git a/src/main/scala/wasm/AST.scala b/src/main/scala/wasm/AST.scala index c59eefc9..274b2b50 100644 --- a/src/main/scala/wasm/AST.scala +++ b/src/main/scala/wasm/AST.scala @@ -270,7 +270,16 @@ case class RefType(kind: HeapType) extends ValueType case class GlobalType(ty: ValueType, mut: Boolean) extends WasmType -abstract class BlockType extends WIR +abstract class BlockType extends WIR { + def funcType: FuncType = + this match { + case VarBlockType(_, None) => + ??? // TODO: fill this branch until we handle type index correctly + case VarBlockType(_, Some(tipe)) => tipe + case ValBlockType(Some(tipe)) => FuncType(List(), List(), List(tipe)) + case ValBlockType(None) => FuncType(List(), List(), List()) + } +} case class VarBlockType(index: Int, tipe: Option[FuncType]) extends BlockType case class ValBlockType(tipe: Option[ValueType]) extends BlockType; diff --git a/src/main/scala/wasm/ConcolicMiniWasm.scala b/src/main/scala/wasm/ConcolicMiniWasm.scala index 849fd831..fec869fe 100644 --- a/src/main/scala/wasm/ConcolicMiniWasm.scala +++ b/src/main/scala/wasm/ConcolicMiniWasm.scala @@ -229,15 +229,6 @@ object Primitives { case NumType(F32Type) => F32V(rng.nextFloat()) case NumType(F64Type) => F64V(rng.nextDouble()) } - - def getFuncType(ty: BlockType): FuncType = - ty match { - case VarBlockType(_, None) => - ??? // TODO: fill this branch until we handle type index correctly - case VarBlockType(_, Some(tipe)) => tipe - case ValBlockType(Some(tipe)) => FuncType(List(), List(), List(tipe)) - case ValBlockType(None) => FuncType(List(), List(), List()) - } } case class Frame(module: ModuleInstance, locals: ArrayBuffer[Value], symLocals: ArrayBuffer[SymVal]) @@ -383,7 +374,7 @@ case class Evaluator(module: ModuleInstance) { eval(rest, concStack, symStack, frame, kont, trail) case Unreachable => throw new RuntimeException("Unreachable") case Block(ty, inner) => - val funcTy = getFuncType(ty) + val funcTy = ty.funcType val (inputSize, outputSize) = (funcTy.inps.size, funcTy.out.size) val (inputs, restStack) = concStack.splitAt(inputSize) val (symInputs, restSymStack) = symStack.splitAt(inputSize) @@ -391,7 +382,7 @@ case class Evaluator(module: ModuleInstance) { eval(rest, retStack.take(outputSize) ++ restStack, retSymStack.take(outputSize) ++ restSymStack, frame, kont, trail)(tree) eval(inner, inputs, symInputs, frame, restK, restK :: trail) case Loop(ty, inner) => - val funcTy = getFuncType(ty) + val funcTy = ty.funcType val (inputSize, outputSize) = (funcTy.inps.size, funcTy.out.size) val (inputs, restStack) = concStack.splitAt(inputSize) val (symInputs, restSymStack) = symStack.splitAt(inputSize) diff --git a/src/main/scala/wasm/MiniWasm.scala b/src/main/scala/wasm/MiniWasm.scala index 11eb301b..84a8bd88 100644 --- a/src/main/scala/wasm/MiniWasm.scala +++ b/src/main/scala/wasm/MiniWasm.scala @@ -229,15 +229,6 @@ object Primtives { case VecType(kind) => ??? case RefType(kind) => RefNullV(kind) } - - def getFuncType(ty: BlockType): FuncType = - ty match { - case VarBlockType(_, None) => - ??? // TODO: fill this branch until we handle type index correctly - case VarBlockType(_, Some(tipe)) => tipe - case ValBlockType(Some(tipe)) => FuncType(List(), List(), List(tipe)) - case ValBlockType(None) => FuncType(List(), List(), List()) - } } case class Frame(locals: ArrayBuffer[Value]) @@ -264,8 +255,8 @@ case class Evaluator(module: ModuleInstance) { val frameLocals = args ++ locals.map(zero(_)) val newFrame = Frame(ArrayBuffer(frameLocals: _*)) if (isTail) - // when tail call, share the continuation for returning with the callee - eval(body, List(), newFrame, kont, List(kont)) + // when tail call, return to the caller's return continuation + eval(body, List(), newFrame, trail.last, List(trail.last)) else { val restK: Cont[Ans] = (retStack) => eval(rest, retStack.take(ty.out.size) ++ newStack, frame, kont, trail) @@ -380,7 +371,7 @@ case class Evaluator(module: ModuleInstance) { eval(rest, stack, frame, kont, trail) case Unreachable => throw Trap() case Block(ty, inner) => - val funcTy = getFuncType(ty) + val funcTy = ty.funcType val (inputs, restStack) = stack.splitAt(funcTy.inps.size) val restK: Cont[Ans] = (retStack) => eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) @@ -389,7 +380,7 @@ case class Evaluator(module: ModuleInstance) { // We construct two continuations, one for the break (to the begining of the loop), // and one for fall-through to the next instruction following the syntactic structure // of the program. - val funcTy = getFuncType(ty) + val funcTy = ty.funcType val (inputs, restStack) = stack.splitAt(funcTy.inps.size) val restK: Cont[Ans] = (retStack) => eval(rest, retStack.take(funcTy.out.size) ++ restStack, frame, kont, trail) @@ -397,7 +388,7 @@ case class Evaluator(module: ModuleInstance) { eval(inner, retStack.take(funcTy.inps.size), frame, restK, loop _ :: trail) loop(inputs) case If(ty, thn, els) => - val funcTy = getFuncType(ty) + val funcTy = ty.funcType val I32V(cond) :: newStack = stack val inner = if (cond != 0) thn else els val (inputs, restStack) = newStack.splitAt(funcTy.inps.size) diff --git a/src/main/scala/wasm/Parser.scala b/src/main/scala/wasm/Parser.scala index 40b497e0..0ce9fa94 100644 --- a/src/main/scala/wasm/Parser.scala +++ b/src/main/scala/wasm/Parser.scala @@ -314,7 +314,7 @@ class GSWasmVisitor extends WatParserBaseVisitor[WIR] { else if (ctx.LOCAL_GET() != null) LocalGet(getVar(ctx.idx(0)).toInt) else if (ctx.LOCAL_SET() != null) LocalSet(getVar(ctx.idx(0)).toInt) else if (ctx.LOCAL_TEE() != null) LocalTee(getVar(ctx.idx(0)).toInt) - else if (ctx.GLOBAL_SET() != null) GlobalGet(getVar(ctx.idx(0)).toInt) + else if (ctx.GLOBAL_SET() != null) GlobalSet(getVar(ctx.idx(0)).toInt) else if (ctx.GLOBAL_GET() != null) GlobalGet(getVar(ctx.idx(0)).toInt) else if (ctx.load() != null) { val ty = visitNumType(ctx.load.numType) diff --git a/src/main/scala/wasm/StagedMiniWasm.scala b/src/main/scala/wasm/StagedMiniWasm.scala new file mode 100644 index 00000000..c1358894 --- /dev/null +++ b/src/main/scala/wasm/StagedMiniWasm.scala @@ -0,0 +1,805 @@ +package gensym.wasm.miniwasm + +import scala.collection.mutable.{ArrayBuffer, HashMap} + +import lms.core.stub.Adapter +import lms.core.virtualize +import lms.macros.SourceContext +import lms.core.stub.{Base, ScalaGenBase, CGenBase} +import lms.core.Backend._ +import lms.core.Backend.{Block => LMSBlock} + +import gensym.wasm.ast._ +import gensym.wasm.ast.{Const => WasmConst, Block => WasmBlock} +import gensym.lmsx.{SAIDriver, StringOps, SAIOps, SAICodeGenBase, CppSAIDriver, CppSAICodeGenBase} + +@virtualize +trait StagedWasmEvaluator extends SAIOps { + def module: ModuleInstance + // NOTE: we don't need the following statements anymore, but where are they initialized? + // reset and initialize the internal state of Adapter + // Adapter.resetState + // Adapter.g = Adapter.mkGraphBuilder + + trait Slice + + trait Frame + + type Cont[A] = Unit => A + type Trail[A] = List[Rep[Cont[A]]] + + // a cache storing the compiled code for each function, to reduce re-compilation + val compileCache = new HashMap[Int, Rep[(Cont[Unit]) => Unit]] + + // NOTE: We don't support Ans type polymorphism yet + def eval(insts: List[Instr], + kont: Rep[Cont[Unit]], + trail: Trail[Unit]): Rep[Unit] = { + if (insts.isEmpty) return kont() + val (inst, rest) = (insts.head, insts.tail) + inst match { + case Drop => + Stack.pop() + eval(rest, kont, trail) + case WasmConst(num) => + Stack.push(num) + eval(rest, kont, trail) + case LocalGet(i) => + Stack.push(Frames.get(i)) + eval(rest, kont, trail) + case LocalSet(i) => + Frames.set(i, Stack.pop()) + eval(rest, kont, trail) + case LocalTee(i) => + Frames.set(i, Stack.peek) + eval(rest, kont, trail) + case GlobalGet(i) => + Stack.push(Global.globalGet(i)) + eval(rest, kont, trail) + case GlobalSet(i) => + val value = Stack.pop() + module.globals(i).ty match { + case GlobalType(tipe, true) => Global.globalSet(i, value) + case _ => throw new Exception("Cannot set immutable global") + } + eval(rest, kont, trail) + case MemorySize => ??? + case MemoryGrow => ??? + case MemoryFill => ??? + case Nop => + eval(rest, kont, trail) + case Unreachable => unreachable() + case Test(op) => + val v = Stack.pop() + Stack.push(evalTestOp(op, v)) + eval(rest, kont, trail) + case Unary(op) => + val v = Stack.pop() + Stack.push(evalUnaryOp(op, v)) + eval(rest, kont, trail) + case Binary(op) => + val v2 = Stack.pop() + val v1 = Stack.pop() + Stack.push(evalBinOp(op, v1, v2)) + eval(rest, kont, trail) + case Compare(op) => + val v2 = Stack.pop() + val v1 = Stack.pop() + Stack.push(evalRelOp(op, v1, v2)) + eval(rest, kont, trail) + case WasmBlock(ty, inner) => + // no need to modify the stack when entering a block + // the type system guarantees that we will never take more than the input size from the stack + val funcTy = ty.funcType + // TODO: somehow the type of exitSize in residual program is nothing + val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size + val restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { + Stack.reset(exitSize) + eval(rest, kont, trail) + }) + eval(inner, restK, restK :: trail) + case Loop(ty, inner) => + val funcTy = ty.funcType + val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size + val restK = fun((_: Rep[Unit]) => { + Stack.reset(exitSize) + eval(rest, kont, trail) + }) + def loop(_u: Rep[Unit]): Rep[Unit] = + eval(inner, restK, fun(loop _) :: trail) + loop(()) + case If(ty, thn, els) => + val funcTy = ty.funcType + val exitSize = Stack.size - funcTy.inps.size + funcTy.out.size + val cond = Stack.pop() + // TODO: can we avoid code duplication here? + val restK = fun((_: Rep[Unit]) => { + Stack.reset(exitSize) + eval(rest, kont, trail) + }) + if (cond != Values.I32(0)) { + eval(thn, restK, restK :: trail) + } else { + eval(els, restK, restK :: trail) + } + case Br(label) => + info(s"Jump to $label") + trail(label)(()) + case BrIf(label) => + val cond = Stack.pop() + info(s"The br_if(${label})'s condition is ", cond) + if (cond != Values.I32(0)) { + info(s"Jump to $label") + trail(label)(()) + } else { + info(s"Continue") + eval(rest, kont, trail) + } + case BrTable(labels, default) => + val cond = Stack.pop() + def aux(choices: List[Int], idx: Int): Rep[Unit] = { + if (choices.isEmpty) trail(default)(()) + else { + if (cond.toInt == idx) trail(choices.head)(()) + else aux(choices.tail, idx + 1) + } + } + aux(labels, 0) + case Return => trail.last(()) + case Call(f) => evalCall(rest, kont, trail, f, false) + case ReturnCall(f) => evalCall(rest, kont, trail, f, true) + case _ => + val todo = "todo-op".reflectCtrlWith[Unit]() + eval(rest, kont, trail) + } + } + + def evalCall(rest: List[Instr], + kont: Rep[Cont[Unit]], + trail: Trail[Unit], + funcIndex: Int, + isTail: Boolean): Rep[Unit] = { + module.funcs(funcIndex) match { + case FuncDef(_, FuncBodyDef(ty, _, locals, body)) => + val returnSize = Stack.size - ty.inps.size + ty.out.size + val args = Stack.take(ty.inps.size) + info("New frame:", Frames.top) + val callee = + if (compileCache.contains(funcIndex)) { + compileCache(funcIndex) + } else { + val callee = topFun( + (kont: Rep[Cont[Unit]]) => { + info(s"Entered the function at $funcIndex, stackSize =", Stack.size, ", frame =", Frames.top) + eval(body, kont, kont::Nil): Rep[Unit] + } + ) + compileCache(funcIndex) = callee + callee + } + val frameSize = ty.inps.size + locals.size + if (isTail) { + // when tail call, return to the caller's return continuation + Frames.popFrame() + Frames.pushFrame(frameSize) + Frames.putAll(args) + callee(trail.last) + } else { + val restK: Rep[Cont[Unit]] = fun((_: Rep[Unit]) => { + Stack.reset(returnSize) + Frames.popFrame() + eval(rest, kont, trail) + }) + // We make a new trail by `restK`, since function creates a new block to escape + // (more or less like `return`) + Frames.pushFrame(frameSize) + Frames.putAll(args) + callee(restK) + } + case Import("console", "log", _) + | Import("spectest", "print_i32", _) => + //println(s"[DEBUG] current stack: $stack") + val v = Stack.pop() + println(v) + eval(rest, kont, trail) + case Import(_, _, _) => throw new Exception(s"Unknown import at $funcIndex") + case _ => throw new Exception(s"Definition at $funcIndex is not callable") + } + } + + def evalTestOp(op: TestOp, value: Rep[Num]): Rep[Num] = op match { + case Eqz(_) => if (value.toInt == 0) Values.I32(1) else Values.I32(0) + } + + def evalUnaryOp(op: UnaryOp, value: Rep[Num]): Rep[Num] = op match { + case Clz(_) => value.clz() + case Ctz(_) => value.ctz() + case Popcnt(_) => value.popcnt() + case _ => ??? + } + + def evalBinOp(op: BinOp, v1: Rep[Num], v2: Rep[Num]): Rep[Num] = op match { + case Add(_) => v1 + v2 + case Mul(_) => v1 * v2 + case Sub(_) => v1 - v2 + case Shl(_) => v1 << v2 + // case ShrS(_) => v1 >> v2 // TODO: signed shift right + case ShrU(_) => v1 >> v2 + case And(_) => v1 & v2 + case _ => ??? + } + + def evalRelOp(op: RelOp, v1: Rep[Num], v2: Rep[Num]): Rep[Num] = op match { + case Eq(_) => v1 numEq v2 + case Ne(_) => v1 numNe v2 + case LtS(_) => v1 < v2 + case LtU(_) => v1 ltu v2 + case GtS(_) => v1 > v2 + case GtU(_) => v1 gtu v2 + case LeS(_) => v1 <= v2 + case LeU(_) => v1 leu v2 + case GeS(_) => v1 >= v2 + case GeU(_) => v1 geu v2 + case _ => ??? + } + + def evalTop(kont: Rep[Cont[Unit]], main: Option[String]): Rep[Unit] = { + val funBody: FuncBodyDef = main match { + case Some(func_name) => + module.defs.flatMap({ + case Export(`func_name`, ExportFunc(fid)) => + Predef.println(s"Now compiling start with function $main") + module.funcs(fid) match { + case FuncDef(_, body@FuncBodyDef(_,_,_,_)) => Some(body) + case _ => throw new Exception("Entry function has no concrete body") + } + case _ => None + }).head + case None => + val startIds = module.defs.flatMap { + case Start(id) => Some(id) + case _ => None + } + val startId = startIds.headOption.getOrElse { throw new Exception("No start function") } + module.funcs(startId) match { + case FuncDef(_, body@FuncBodyDef(_,_,_,_)) => body + case _ => + throw new Exception("Entry function has no concrete body") + } + } + val (instrs, localSize) = (funBody.body, funBody.locals.size) + Stack.initialize() + Frames.pushFrame(localSize) + eval(instrs, kont, kont::Nil) + Frames.popFrame() + } + + def evalTop(main: Option[String], printRes: Boolean = false): Rep[Unit] = { + val haltK: Rep[Unit] => Rep[Unit] = (_) => { + if (printRes) { + Stack.print() + } + "no-op".reflectCtrlWith[Unit]() + } + val temp: Rep[Cont[Unit]] = fun(haltK) + evalTop(temp, main) + } + + // stack creation and operations + object Stack { + def initialize(): Rep[Unit] = { + "stack-init".reflectCtrlWith[Unit]() + } + + def pop(): Rep[Num] = { + "stack-pop".reflectCtrlWith[Num]() + } + + def peek: Rep[Num] = { + "stack-peek".reflectCtrlWith[Num]() + } + + def push(v: Rep[Num]): Rep[Unit] = { + "stack-push".reflectCtrlWith[Unit](v) + } + + def drop(n: Int): Rep[Unit] = { + "stack-drop".reflectCtrlWith[Unit](n) + } + + def print(): Rep[Unit] = { + "stack-print".reflectCtrlWith[Unit]() + } + + def size: Rep[Int] = { + "stack-size".reflectCtrlWith[Int]() + } + + def reset(x: Rep[Int]): Rep[Unit] = { + "stack-reset".reflectCtrlWith[Unit](x) + } + + def take(n: Int): Rep[Slice] = { + "stack-take".reflectCtrlWith[Slice](n) + } + } + + object Frames { + def get(i: Int): Rep[Num] = { + "frame-get".reflectCtrlWith[Num](i) + } + + def set(i: Int, v: Rep[Num]): Rep[Unit] = { + "frame-set".reflectCtrlWith(i, v) + } + + def pushFrame(i: Int): Rep[Unit] = { + "frame-push".reflectCtrlWith[Unit](i) + } + + def popFrame(): Rep[Unit] = { + "frame-pop".reflectCtrlWith[Unit]() + } + + def putAll(args: Rep[Slice]): Rep[Unit] = { + "frame-putAll".reflectCtrlWith[Unit](args) + } + + def top: Rep[Frame] = { + "frame-top".reflectCtrlWith[Frame]() + } + } + + + // call unreachable + def unreachable(): Rep[Unit] = { + "unreachable".reflectCtrlWith[Unit]() + } + + def info(xs: Rep[_]*): Rep[Unit] = { + "info".reflectCtrlWith[Unit](xs: _*) + } + + // runtime values + object Values { + def lift(num: Num): Rep[Num] = { + num match { + case I32V(i) => I32(i) + case I64V(i) => I64(i) + } + } + + def I32(i: Rep[Int]): Rep[Num] = { + "I32V".reflectWith[Num](i) + } + + def I64(i: Rep[Long]): Rep[Num] = { + "I64V".reflectWith[Num](i) + } + } + + // global read/write + object Global{ + def globalGet(i: Int): Rep[Num] = { + "global-get".reflectWith[Num](i) + } + + def globalSet(i: Int, value: Rep[Num]): Rep[Unit] = { + "global-set".reflectCtrlWith[Unit](i, value) + } + } + + // runtime Num type + implicit class NumOps(num: Rep[Num]) { + + def toInt: Rep[Int] = "num-to-int".reflectWith[Int](num) + + def clz(): Rep[Num] = "unary-clz".reflectWith[Num](num) + + def ctz(): Rep[Num] = "unary-ctz".reflectWith[Num](num) + + def popcnt(): Rep[Num] = "unary-popcnt".reflectWith[Num](num) + + def +(rhs: Rep[Num]): Rep[Num] = "binary-add".reflectWith[Num](num, rhs) + + def -(rhs: Rep[Num]): Rep[Num] = "binary-sub".reflectWith[Num](num, rhs) + + def *(rhs: Rep[Num]): Rep[Num] = "binary-mul".reflectWith[Num](num, rhs) + + def /(rhs: Rep[Num]): Rep[Num] = "binary-div".reflectWith[Num](num, rhs) + + def <<(rhs: Rep[Num]): Rep[Num] = "binary-shl".reflectWith[Num](num, rhs) + + def >>(rhs: Rep[Num]): Rep[Num] = "binary-shr".reflectWith[Num](num, rhs) + + def &(rhs: Rep[Num]): Rep[Num] = "binary-and".reflectWith[Num](num, rhs) + + def numEq(rhs: Rep[Num]): Rep[Num] = "relation-eq".reflectWith[Num](num, rhs) + + def numNe(rhs: Rep[Num]): Rep[Num] = "relation-ne".reflectWith[Num](num, rhs) + + def <(rhs: Rep[Num]): Rep[Num] = "relation-lt".reflectWith[Num](num, rhs) + + def ltu(rhs: Rep[Num]): Rep[Num] = "relation-ltu".reflectWith[Num](num, rhs) + + def >(rhs: Rep[Num]): Rep[Num] = "relation-gt".reflectWith[Num](num, rhs) + + def gtu(rhs: Rep[Num]): Rep[Num] = "relation-gtu".reflectWith[Num](num, rhs) + + def <=(rhs: Rep[Num]): Rep[Num] = "relation-le".reflectWith[Num](num, rhs) + + def leu(rhs: Rep[Num]): Rep[Num] = "relation-leu".reflectWith[Num](num, rhs) + + def >=(rhs: Rep[Num]): Rep[Num] = "relation-ge".reflectWith[Num](num, rhs) + + def geu(rhs: Rep[Num]): Rep[Num] = "relation-geu".reflectWith[Num](num, rhs) + } + implicit class SliceOps(slice: Rep[Slice]) { + def reverse: Rep[Slice] = "slice-reverse".reflectWith[Slice](slice) + } +} + +trait StagedWasmScalaGen extends ScalaGenBase with SAICodeGenBase { + override def traverse(n: Node): Unit = n match { + case Node(_, "stack-push", List(value), _) => + emit("Stack.push("); shallow(value); emit(")\n") + case Node(_, "stack-drop", List(n), _) => + emit("Stack.drop("); shallow(n); emit(")\n") + case Node(_, "stack-reset", List(n), _) => + emit("Stack.reset("); shallow(n); emit(")\n") + case Node(_, "stack-init", _, _) => + emit("Stack.initialize()\n") + case Node(_, "stack-print", _, _) => + emit("Stack.print()\n") + case Node(_, "frame-push", List(i), _) => + emit("Frames.pushFrame("); shallow(i); emit(")\n") + case Node(_, "frame-pop", _, _) => + emit("Frames.popFrame()\n") + case Node(_, "frame-putAll", List(args), _) => + emit("Frames.putAll("); shallow(args); emit(")\n") + case Node(_, "frame-set", List(i, value), _) => + emit("Frames.set("); shallow(i); emit(", "); shallow(value); emit(")\n") + case Node(_, "global-set", List(i, value), _) => + emit("Global.globalSet("); shallow(i); emit(", "); shallow(value); emit(")\n") + case _ => super.traverse(n) + } + + // code generation for pure nodes + override def shallow(n: Node): Unit = n match { + case Node(_, "frame-get", List(i), _) => + emit("Frames.get("); shallow(i); emit(")") + case Node(_, "stack-pop", _, _) => + emit("Stack.pop()") + case Node(_, "frame-pop", _, _) => + emit("Frames.popFrame()") + case Node(_, "stack-peek", _, _) => + emit("Stack.peek") + case Node(_, "stack-take", List(n), _) => + emit("Stack.take("); shallow(n); emit(")") + case Node(_, "slice-reverse", List(slice), _) => + shallow(slice); emit(".reverse") + case Node(_, "stack-size", _, _) => + emit("Stack.size") + case Node(_, "global-get", List(i), _) => + emit("Global.globalGet("); shallow(i); emit(")") + case Node(_, "frame-top", _, _) => + emit("Frames.top") + case Node(_, "binary-add", List(lhs, rhs), _) => + shallow(lhs); emit(" + "); shallow(rhs) + case Node(_, "binary-sub", List(lhs, rhs), _) => + shallow(lhs); emit(" - "); shallow(rhs) + case Node(_, "binary-mul", List(lhs, rhs), _) => + shallow(lhs); emit(" * "); shallow(rhs) + case Node(_, "binary-div", List(lhs, rhs), _) => + shallow(lhs); emit(" / "); shallow(rhs) + case Node(_, "binary-shl", List(lhs, rhs), _) => + shallow(lhs); emit(" << "); shallow(rhs) + case Node(_, "binary-shr", List(lhs, rhs), _) => + shallow(lhs); emit(" >> "); shallow(rhs) + case Node(_, "binary-and", List(lhs, rhs), _) => + shallow(lhs); emit(" & "); shallow(rhs) + case Node(_, "relation-eq", List(lhs, rhs), _) => + shallow(lhs); emit(" == "); shallow(rhs) + case Node(_, "relation-ne", List(lhs, rhs), _) => + shallow(lhs); emit(" != "); shallow(rhs) + case Node(_, "relation-lt", List(lhs, rhs), _) => + shallow(lhs); emit(" < "); shallow(rhs) + case Node(_, "relation-ltu", List(lhs, rhs), _) => + shallow(lhs); emit(" < "); shallow(rhs) + case Node(_, "relation-gt", List(lhs, rhs), _) => + shallow(lhs); emit(" > "); shallow(rhs) + case Node(_, "relation-gtu", List(lhs, rhs), _) => + shallow(lhs); emit(" > "); shallow(rhs) + case Node(_, "relation-le", List(lhs, rhs), _) => + shallow(lhs); emit(" <= "); shallow(rhs) + case Node(_, "relation-leu", List(lhs, rhs), _) => + shallow(lhs); emit(" <= "); shallow(rhs) + case Node(_, "relation-ge", List(lhs, rhs), _) => + shallow(lhs); emit(" >= "); shallow(rhs) + case Node(_, "relation-geu", List(lhs, rhs), _) => + shallow(lhs); emit(" >= "); shallow(rhs) + case Node(_, "num-to-int", List(num), _) => + shallow(num); emit(".toInt") + case Node(_, "no-op", _, _) => + emit("()") + case _ => super.shallow(n) + } +} + +trait WasmToScalaCompilerDriver[A, B] + extends SAIDriver[A, B] with StagedWasmEvaluator { q => + override val codegen = new StagedWasmScalaGen { + val IR: q.type = q + import IR._ + override def remap(m: Manifest[_]): String = { + if (m.toString.endsWith("Stack")) "Stack" + else if(m.toString.endsWith("Frame")) "Frame" + else if(m.toString.endsWith("Slice")) "Slice" + else super.remap(m) + } + } + + override val prelude = + """ +object Prelude { + sealed abstract class Num { + def +(that: Num): Num = (this, that) match { + case (I32V(x), I32V(y)) => I32V(x + y) + case (I64V(x), I64V(y)) => I64V(x + y) + case _ => throw new RuntimeException("Invalid addition") + } + + def -(that: Num): Num = (this, that) match { + case (I32V(x), I32V(y)) => I32V(x - y) + case (I64V(x), I64V(y)) => I64V(x - y) + case _ => throw new RuntimeException("Invalid subtraction") + } + + def !=(that: Num): Num = (this, that) match { + case (I32V(x), I32V(y)) => I32V(if (x != y) 1 else 0) + case (I64V(x), I64V(y)) => I32V(if (x != y) 1 else 0) + case _ => throw new RuntimeException("Invalid inequality") + } + + def toInt: Int = this match { + case I32V(i) => i + case I64V(i) => i.toInt + } + } + case class I32V(i: Int) extends Num + case class I64V(i: Long) extends Num + +object Stack { + private val buffer = new scala.collection.mutable.ArrayBuffer[Num]() + def push(v: Num): Unit = buffer.append(v) + def pop(): Num = { + buffer.remove(buffer.size - 1) + } + def peek: Num = { + buffer.last + } + def size: Int = buffer.size + def drop(n: Int): Unit = { + buffer.remove(buffer.size - n, n) + } + def take(n: Int): List[Num] = { + val xs = buffer.takeRight(n).toList + drop(n) + xs + } + def reset(size: Int): Unit = { + info(s"Reset stack to size $size") + while (buffer.size > size) { + buffer.remove(buffer.size - 1) + } + } + def initialize(): Unit = buffer.clear() + def print(): Unit = { + println("Stack: " + buffer.mkString(", ")) + } +} + + type Slice = List[Num] + + class Frame(val size: Int) { + private val data = new Array[Num](size) + def apply(i: Int): Num = { + info(s"frame(${i}) is ${data(i)}") + data(i) + } + def update(i: Int, v: Num): Unit = { + info(s"set frame(${i}) to ${v}") + data(i) = v + } + def putAll(xs: List[Num]): Unit = { + for (i <- 0 until xs.size) { + data(i) = xs(i) + } + } + override def toString: String = { + "Frame(" + data.mkString(", ") + ")" + } + } + + object Frames { + private var frames = List[Frame]() + def pushFrame(size: Int): Unit = { + frames = new Frame(size) :: frames + } + def popFrame(): Unit = { + frames = frames.tail + } + def top: Frame = frames.head + def set(i: Int, v: Num): Unit = { + top(i) = v + } + def get(i: Int): Num = { + top(i) + } + def putAll(xs: Slice) = { + for (i <- 0 until xs.size) { + top(i) = xs(i) + } + } + } + + object Global { + // TODO: create global with specific size + private val globals = new Array[Num](10) + def globalGet(i: Int): Num = globals(i) + def globalSet(i: Int, v: Num): Unit = globals(i) = v + } + + def info(xs: Any*): Unit = { + if (System.getenv("DEBUG") != null) { + println("[INFO] " + xs.mkString(" ")) + } + } +} +import Prelude._ + + +object Main { + def main(args: Array[String]): Unit = { + val snippet = new Snippet() + snippet(()) + } +} +""" +} + + + +object WasmToScalaCompiler { + def apply(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): String = { + println(s"Now compiling wasm module with entry function $main") + val code = new WasmToScalaCompilerDriver[Unit, Unit] { + def module: ModuleInstance = moduleInst + def snippet(x: Rep[Unit]): Rep[Unit] = { + evalTop(main, printRes) + } + } + code.code + } +} + +trait StagedWasmCppGen extends CGenBase with CppSAICodeGenBase { + override def remap(m: Manifest[_]): String = { + if (m.toString.endsWith("Num")) "Num" + else if (m.toString.endsWith("Slice")) "Slice" + else if (m.toString.endsWith("Frame")) "Frame" + else if (m.toString.endsWith("Stack")) "Stack" + else if (m.toString.endsWith("Global")) "Global" + else if (m.toString.endsWith("I32V")) "I32V" + else if (m.toString.endsWith("I64V")) "I64V" + else super.remap(m) + } + + // for now, the traverse/shallow is same as the scala backend's + override def traverse(n: Node): Unit = n match { + case Node(_, "stack-push", List(value), _) => + emit("Stack.push("); shallow(value); emit(");\n") + case Node(_, "stack-drop", List(n), _) => + emit("Stack.drop("); shallow(n); emit(");\n") + case Node(_, "stack-reset", List(n), _) => + emit("Stack.reset("); shallow(n); emit(");\n") + case Node(_, "stack-init", _, _) => + emit("Stack.initialize();\n") + case Node(_, "stack-print", _, _) => + emit("Stack.print();\n") + case Node(_, "frame-push", List(i), _) => + emit("Frames.pushFrame("); shallow(i); emit(");\n") + case Node(_, "frame-pop", _, _) => + emit("Frames.popFrame();\n") + case Node(_, "frame-putAll", List(args), _) => + emit("Frames.putAll("); shallow(args); emit(");\n") + case Node(_, "frame-set", List(i, value), _) => + emit("Frames.set("); shallow(i); emit(", "); shallow(value); emit(");\n") + case Node(_, "global-set", List(i, value), _) => + emit("Global.globalSet("); shallow(i); emit(", "); shallow(value); emit(");\n") + case _ => super.traverse(n) + } + + // code generation for pure nodes + override def shallow(n: Node): Unit = n match { + case Node(_, "frame-get", List(i), _) => + emit("Frames.get("); shallow(i); emit(")") + case Node(_, "stack-pop", _, _) => + emit("Stack.pop()") + case Node(_, "frame-pop", _, _) => + emit("Frames.popFrame()") + case Node(_, "stack-peek", _, _) => + emit("Stack.peek") + case Node(_, "stack-take", List(n), _) => + emit("Stack.take("); shallow(n); emit(")") + case Node(_, "slice-reverse", List(slice), _) => + shallow(slice); emit(".reverse") + case Node(_, "stack-size", _, _) => + emit("Stack.size()") + case Node(_, "global-get", List(i), _) => + emit("Global.globalGet("); shallow(i); emit(")") + case Node(_, "frame-top", _, _) => + emit("Frames.top()") + case Node(_, "binary-add", List(lhs, rhs), _) => + shallow(lhs); emit(" + "); shallow(rhs) + case Node(_, "binary-sub", List(lhs, rhs), _) => + shallow(lhs); emit(" - "); shallow(rhs) + case Node(_, "binary-mul", List(lhs, rhs), _) => + shallow(lhs); emit(" * "); shallow(rhs) + case Node(_, "binary-div", List(lhs, rhs), _) => + shallow(lhs); emit(" / "); shallow(rhs) + case Node(_, "binary-shl", List(lhs, rhs), _) => + shallow(lhs); emit(" << "); shallow(rhs) + case Node(_, "binary-shr", List(lhs, rhs), _) => + shallow(lhs); emit(" >> "); shallow(rhs) + case Node(_, "binary-and", List(lhs, rhs), _) => + shallow(lhs); emit(" & "); shallow(rhs) + case Node(_, "relation-eq", List(lhs, rhs), _) => + shallow(lhs); emit(" == "); shallow(rhs) + case Node(_, "relation-ne", List(lhs, rhs), _) => + shallow(lhs); emit(" != "); shallow(rhs) + case Node(_, "relation-lt", List(lhs, rhs), _) => + shallow(lhs); emit(" < "); shallow(rhs) + case Node(_, "relation-ltu", List(lhs, rhs), _) => + shallow(lhs); emit(" < "); shallow(rhs) + case Node(_, "relation-gt", List(lhs, rhs), _) => + shallow(lhs); emit(" > "); shallow(rhs) + case Node(_, "relation-gtu", List(lhs, rhs), _) => + shallow(lhs); emit(" > "); shallow(rhs) + case Node(_, "relation-le", List(lhs, rhs), _) => + shallow(lhs); emit(" <= "); shallow(rhs) + case Node(_, "relation-leu", List(lhs, rhs), _) => + shallow(lhs); emit(" <= "); shallow(rhs) + case Node(_, "relation-ge", List(lhs, rhs), _) => + shallow(lhs); emit(" >= "); shallow(rhs) + case Node(_, "relation-geu", List(lhs, rhs), _) => + shallow(lhs); emit(" >= "); shallow(rhs) + case Node(_, "num-to-int", List(num), _) => + shallow(num); emit(".toInt") + case Node(_, "no-op", _, _) => + emit("()") + case _ => super.shallow(n) + } +} + + +trait WasmToCppCompilerDriver[A, B] extends CppSAIDriver[A, B] with StagedWasmEvaluator { q => + override val codegen = new StagedWasmCppGen { + val IR: q.type = q + import IR._ + } +} + +object WasmToCppCompiler { + def apply(moduleInst: ModuleInstance, main: Option[String], printRes: Boolean = false): String = { + println(s"Now compiling wasm module with entry function $main") + val code = new WasmToCppCompilerDriver[Unit, Unit] { + def module: ModuleInstance = moduleInst + def snippet(x: Rep[Unit]): Rep[Unit] = { + evalTop(main, printRes) + } + } + code.code + } +} + diff --git a/src/test/scala/genwasym/TestEval.scala b/src/test/scala/genwasym/TestEval.scala index 38453996..2e358375 100644 --- a/src/test/scala/genwasym/TestEval.scala +++ b/src/test/scala/genwasym/TestEval.scala @@ -81,6 +81,9 @@ class TestEval extends FunSuite { test("loop block - poly br") { testFile("./benchmarks/wasm/loop_poly.wat", None, ExpStack(List(2, 1))) } + test("global") { + testFile("./benchmarks/wasm/global.wat", None, ExpInt(42)) + } // just a test for .bin.wast utility // the complete tests can be seen at https://github.com/Generative-Program-Analysis/wasm-cps/ diff --git a/src/test/scala/genwasym/TestStagedEval.scala b/src/test/scala/genwasym/TestStagedEval.scala new file mode 100644 index 00000000..47afddce --- /dev/null +++ b/src/test/scala/genwasym/TestStagedEval.scala @@ -0,0 +1,35 @@ +package gensym.wasm + +import org.scalatest.FunSuite + +import lms.core.stub.Adapter + +import gensym.wasm.parser._ +import gensym.wasm.miniwasm._ + +class TestStagedEval extends FunSuite { + def testFileToScala(filename: String, main: Option[String] = None, printRes: Boolean = false) = { + val moduleInst = ModuleInstance(Parser.parseFile(filename)) + val code = WasmToScalaCompiler(moduleInst, main, true) + println(code) + } + + test("ack-scala") { testFileToScala("./benchmarks/wasm/ack.wat", Some("real_main"), printRes = true) } + + test("brtable-scala") { + testFileToScala("./benchmarks/wasm/staged/brtable.wat") + } + + def testFileToCpp(filename: String, main: Option[String] = None, printRes: Boolean = false) = { + val moduleInst = ModuleInstance(Parser.parseFile(filename)) + val code = WasmToCppCompiler(moduleInst, main, true) + println(code) + } + + test("ack-cpp") { testFileToCpp("./benchmarks/wasm/ack.wat", Some("real_main"), printRes = true) } + + test("brtable-cpp") { + testFileToCpp("./benchmarks/wasm/staged/brtable.wat") + } + +} diff --git a/third-party/lms-clean b/third-party/lms-clean index b6f019ae..f3338d3a 160000 --- a/third-party/lms-clean +++ b/third-party/lms-clean @@ -1 +1 @@ -Subproject commit b6f019aef1df5f1f12bcd0b43a9136d7f9ce7704 +Subproject commit f3338d3ab0ea74e90e44acfdbbda7912c249a7fc