743 lines
24 KiB
Text
743 lines
24 KiB
Text
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/

prelude
|
||
import init.control.conditional
|
||
import init.lean.runtime
|
||
import init.lean.compiler.namemangling
|
||
import init.lean.compiler.exportattr
|
||
import init.lean.compiler.initattr
|
||
import init.lean.compiler.ir.compilerm
|
||
import init.lean.compiler.ir.emitutil
|
||
import init.lean.compiler.ir.normids
|
||
import init.lean.compiler.ir.simpcase
|
||
import init.lean.compiler.ir.boxing
|
||
|
||
namespace Lean
|
||
namespace IR
|
||
open ExplicitBoxing (requiresBoxedVersion mkBoxedName isBoxedName)
|
||
namespace EmitCpp
|
||
|
||
-- C++ symbol used for the Lean declaration `main`. `toBaseCppName` and
-- `toCppName` map `main` to this name, and the generated C `main` calls it.
def leanMainFn : String := "_lean_main"

-- Read-only state available to every emitter action.
structure Context :=
(env : Environment)                        -- environment holding the declarations to emit
(modName : Name)                           -- module being compiled; names `initialize_<module>`
(varMap : VarTypeMap := {})                -- variable types of the declaration currently emitted (see `isObj`)
(jpMap : JPParamsMap := {})                -- join point parameters of the declaration currently emitted
(mainFn : FunId := default _)              -- function whose body is being emitted; target of self tail calls
(mainParams : Array Param := Array.empty)  -- parameters of `mainFn`, reassigned by `emitTailCall`

-- Emitter monad: reads a `Context`, accumulates the generated C++ text in
-- its `String` state (see `emit`), and fails with a `String` message.
abbrev M := ReaderT Context (EState String String)

-- Environment containing the declarations being compiled.
def getEnv : M Environment :=
do ctx ← read; pure ctx.env
-- Name of the module being compiled.
def getModName : M Name :=
do ctx ← read; pure ctx.modName

-- Look up the IR declaration named `n`, failing with an
-- "unknown declaration" error when the environment has none.
def getDecl (n : Name) : M Decl :=
do env ← getEnv;
   match findEnvDecl env n with
   | none   => throw ("unknown declaration '" ++ toString n ++ "'")
   | some d => pure d

-- Append the `toString` rendering of `a` to the generated output.
@[inline] def emit {α : Type} [HasToString α] (a : α) : M Unit :=
modify (fun acc => acc ++ toString a)

-- Append the rendering of `a` followed by a newline.
@[inline] def emitLn {α : Type} [HasToString α] (a : α) : M Unit :=
do emit a; emit "\n"

-- Emit every element of `as` on its own line, in list order.
def emitLns {α : Type} [HasToString α] (as : List α) : M Unit :=
as.mfor emitLn

-- C++ expression for an IR argument: a variable prints as its name; any
-- other kind of argument is rendered as `lean::box(0)`.
def argToCppString (x : Arg) : String :=
match x with
| Arg.var v => toString v
| _         => "lean::box(0)"

-- Emit the C++ rendering of an IR argument (see `argToCppString`).
def emitArg (x : Arg) : M Unit :=
emit $ argToCppString x

-- C++ type used to represent values of an IR type. Boxed, possibly-boxed
-- and irrelevant values all share `obj*`; scalar names refer to the typedefs
-- emitted by `emitFileHeader`.
def toCppType : IRType → String
| IRType.object     => "obj*"
| IRType.tobject    => "obj*"
| IRType.irrelevant => "obj*"
| IRType.float      => "double"
| IRType.uint8      => "uint8"
| IRType.uint16     => "uint16"
| IRType.uint32     => "uint32"
| IRType.uint64     => "uint64"
| IRType.usize      => "usize"

-- Emit `namespace s {` for each string component of the name, outermost
-- first. Any other component kind (e.g. numeric) is rejected.
def openNamespacesAux : Name → M Unit
| Name.anonymous      => pure ()
| (Name.mkString p s) => do openNamespacesAux p; emitLn ("namespace " ++ s ++ " {")
| n                   => throw ("invalid namespace '" ++ toString n ++ "'")

-- Open the C++ namespaces enclosing the qualified name `n`, i.e. every
-- component except the last one.
def openNamespaces (n : Name) : M Unit :=
let pre := n.getPrefix; openNamespacesAux pre

-- When declaration `n` has an export name, open the namespaces that the
-- export name lives in; otherwise emit nothing.
def openNamespacesFor (n : Name) : M Unit :=
do env ← getEnv;
   match getExportNameFor env n with
   | some exportName => openNamespaces exportName
   | none            => pure ()

-- Emit one closing `}` per string component of the name; the counterpart
-- of `openNamespacesAux`. Any other component kind is rejected.
def closeNamespacesAux : Name → M Unit
| Name.anonymous      => pure ()
| (Name.mkString p _) => do emitLn "}"; closeNamespacesAux p
| n                   => throw ("invalid namespace '" ++ toString n ++ "'")

-- Close the namespaces opened by `openNamespaces n`.
def closeNamespaces (n : Name) : M Unit :=
closeNamespacesAux (Name.getPrefix n)

-- Close the namespaces opened by `openNamespacesFor n`; emits nothing when
-- `n` has no export name.
def closeNamespacesFor (n : Name) : M Unit :=
do env ← getEnv;
   match getExportNameFor env n with
   | some exportName => closeNamespaces exportName
   | none            => pure ()

-- Fail with an "invalid export name" error that mentions `n`.
def throwInvalidExportName {α : Type} (n : Name) : M α :=
let msg := "invalid export name '" ++ toString n ++ "'"; throw msg

-- Unqualified C++ identifier for declaration `n`:
--   * the last component of its export name, when it has one;
--   * `leanMainFn` for `main`;
--   * the mangled Lean name otherwise.
-- Export names whose last component is not a string are rejected.
def toBaseCppName (n : Name) : M String :=
do env ← getEnv;
   match getExportNameFor env n with
   | none                     => pure (if n == `main then leanMainFn else n.mangle)
   | some (Name.mkString _ s) => pure s
   | some _                   => throwInvalidExportName n

-- Fully qualified C++ name for declaration `n`: its export name joined with
-- `::` when it has one, otherwise the same name `toBaseCppName` produces.
def toCppName (n : Name) : M String :=
do env ← getEnv;
   pure $ match getExportNameFor env n with
   | some s => s.toStringWithSep "::"
   | none   => if n == `main then leanMainFn else n.mangle

-- Emit the fully qualified C++ name of declaration `n`.
def emitCppName (n : Name) : M Unit :=
do s ← toCppName n; emit s

-- Qualified name of the C++ function that computes constant `n`: the base
-- name prefixed with `_init_`, kept inside the export namespace when `n`
-- has an export name.
def toCppInitName (n : Name) : M String :=
do env ← getEnv;
   match getExportNameFor env n with
   | none                     => pure ("_init_" ++ n.mangle)
   | some (Name.mkString p s) => pure ((Name.mkString p ("_init_" ++ s)).toStringWithSep "::")
   | some _                   => throwInvalidExportName n

-- Emit the qualified name of the `_init_` function for constant `n`.
def emitCppInitName (n : Name) : M Unit :=
do s ← toCppInitName n; emit s

-- Emit the forward declaration `<result type> cppBaseName(<param types>);`
-- for `decl`. A declaration without parameters (a constant) is declared as
-- a plain variable, prefixed with `extern ` when `addExternForConsts` holds.
-- Boxed versions with more than `closureMaxArgs` parameters take all their
-- arguments through a single `obj**` array.
def emitFnDeclAux (decl : Decl) (cppBaseName : String) (addExternForConsts : Bool) : M Unit :=
do
  let ps := decl.params;
  when (ps.isEmpty && addExternForConsts) (emit "extern ");
  emit (toCppType decl.resultType ++ " " ++ cppBaseName);
  unless (ps.isEmpty) $ do {
    emit "(";
    -- Must agree with the `obj** _args` signature emitted by `emitDeclAux`.
    if ps.size > closureMaxArgs && isBoxedName decl.name then
      emit "obj**"
    else
      ps.size.mfor $ fun i => do {
        when (i > 0) (emit ", ");
        emit (toCppType (ps.get i).ty)
      };
    emit ")"
  };
  emitLn ";"

-- Emit the forward declaration of `decl` under its Lean-derived C++ name,
-- wrapped in its export name's namespaces when it has one.
def emitFnDecl (decl : Decl) (addExternForConsts : Bool) : M Unit :=
do
  let n := decl.name;
  openNamespacesFor n;
  baseName ← toBaseCppName n;
  emitFnDeclAux decl baseName addExternForConsts;
  closeNamespacesFor n

-- Convert a `::`-separated C++ qualified name into a `Name` that has one
-- string component per segment, in order.
def cppQualifiedNameToName (s : String) : Name :=
let parts := s.split "::"; parts.foldl Name.mkString Name.anonymous

-- Emit the forward declaration of a declaration implemented by external C++
-- code named `cppName`, which may be `::`-qualified. The declaration is
-- placed inside the namespaces of `cppName` and is prefixed with
-- `extern "C"` when `isExternC` holds for it.
def emitExternDeclAux (decl : Decl) (cppName : String) : M Unit :=
do
  let qCppName := cppQualifiedNameToName cppName;
  openNamespaces qCppName;
  env ← getEnv;
  let extC := isExternC env decl.name;
  when extC (emit "extern \"C\" ");
  -- The last component of the qualified name is the unqualified identifier.
  (Name.mkString _ qCppBaseName) ← pure qCppName | throw "invalid name";
  -- The extra `extern ` for constants is only requested for non-C externs.
  -- NOTE(review): presumably because `extern "C"` on a single declaration
  -- already makes it non-defining — confirm.
  emitFnDeclAux decl qCppBaseName (!extC);
  closeNamespaces qCppName

-- Emit forward declarations for every declaration returned by `getDecls`
-- and for every declaration they use (`collectUsedDecls`). A declaration
-- with a `cpp` extern name gets an extern declaration. The others use their
-- Lean-derived name, and constants not in `getDecls` are marked `extern `.
def emitFnDecls : M Unit :=
do
  env ← getEnv;
  let decls := getDecls env;
  -- Names of the declarations defined in this module.
  let modDecls : NameSet := decls.foldl (fun s d => s.insert d.name) {};
  -- This module's declarations plus everything they reference.
  let usedDecls : NameSet := decls.foldl (fun s d => collectUsedDecls env d (s.insert d.name)) {};
  let usedDecls := usedDecls.toList;
  usedDecls.mfor $ fun n => do
    decl ← getDecl n;
    match getExternNameFor env `cpp decl.name with
    | some cppName => emitExternDeclAux decl cppName
    | none         => emitFnDecl decl (!modDecls.contains n)

-- Emit the C/C++ `main` entry point for a module that defines `main`.
-- The generated function:
--   * calls `lean::initialize()` when `main` uses the `lean` namespace
--     (`usesLeanNamespace`), otherwise only the runtime initializer;
--   * creates the IO world token and runs this module's initializer;
--   * if initialization succeeded, calls `leanMainFn`. When the Lean `main`
--     takes two parameters, it passes `argv[1..argc-1]` as a chain of tag-1
--     cells ending in `lean::box(0)`; when it takes one, only the world;
--   * returns the unboxed IO result as the exit code, or prints the IO error
--     and returns 1.
-- Fails unless `main` is an `fdecl` with one or two parameters.
def emitMainFn : M Unit :=
do
  d ← getDecl `main;
  match d with
  | Decl.fdecl f xs t b => do
    unless (xs.size == 2 || xs.size == 1) (throw "invalid main function, incorrect arity when generating code");
    env ← getEnv;
    let usesLeanAPI := usesLeanNamespace env d;
    when usesLeanAPI (emitLn "namespace lean { void initialize(); }");
    emitLn "int main(int argc, char ** argv) {";
    if usesLeanAPI then
      emitLn "lean::initialize();"
    else
      emitLn "lean::initialize_runtime_module();";
    emitLn "obj * w = lean::io_mk_world();";
    modName ← getModName;
    emitLn ("w = initialize_" ++ (modName.mangle "") ++ "(w);");
    emitLns ["lean::io_mark_end_initialization();",
             "if (io_result_is_ok(w)) {",
             "lean::scoped_task_manager tmanager(lean::hardware_concurrency());"];
    if xs.size == 2 then do {
      -- Build the argument list from the last argument backwards, so the
      -- resulting list is in `argv` order.
      emitLns ["obj* in = lean::box(0);",
               "int i = argc;",
               "while (i > 1) {",
               " i--;",
               " obj* n = lean::alloc_cnstr(1,2,0); lean::cnstr_set(n, 0, lean::mk_string(argv[i])); lean::cnstr_set(n, 1, in);",
               " in = n;",
               "}"];
      emitLn ("w = " ++ leanMainFn ++ "(in, w);")
    } else do {
      emitLn ("w = " ++ leanMainFn ++ "(w);")
    };
    emitLn "}";
    emitLns ["if (io_result_is_ok(w)) {",
             " int ret = lean::unbox(io_result_get_value(w));",
             " lean::dec_ref(w);",
             " return ret;",
             "} else {",
             " lean::io_result_show_error(w);",
             " lean::dec_ref(w);",
             " return 1;",
             "}"];
    emitLn "}"
  | other => throw "function declaration expected"

-- Whether this module's declarations (`getDecls`) include one named `main`.
def hasMainFn : M Bool :=
do env ← getEnv;
   pure ((getDecls env).any (fun d => d.name == `main))

-- Emit the C `main` entry point only if the module defines `main`.
def emitMainFnIfNeeded : M Unit :=
do hasMain ← hasMainFn; when hasMain emitMainFn

-- Emit the file prologue:
--   * a comment naming the module and its imports;
--   * the runtime includes, plus `runtime/init_module.h` when the module
--     defines `main`;
--   * typedefs for the scalar type names produced by `toCppType`;
--   * compiler pragmas that silence unused-parameter/label warnings.
def emitFileHeader : M Unit :=
do
  env ← getEnv;
  modName ← getModName;
  emitLn "// Lean compiler output";
  emitLn ("// Module: " ++ toString modName);
  emit "// Imports:";
  env.imports.mfor $ fun m => emit (" " ++ toString m);
  emitLn "";
  emitLn "#include \"runtime/object.h\"";
  emitLn "#include \"runtime/apply.h\"";
  mwhen hasMainFn $ emitLn "#include \"runtime/init_module.h\"";
  emitLns [
    "typedef lean::object obj; typedef lean::usize usize;",
    "typedef lean::uint8 uint8; typedef lean::uint16 uint16;",
    "typedef lean::uint32 uint32; typedef lean::uint64 uint64;",
    "#if defined(__clang__)",
    "#pragma clang diagnostic ignored \"-Wunused-parameter\"",
    "#pragma clang diagnostic ignored \"-Wunused-label\"",
    -- Clang's predefined macro is the lowercase `__clang__`. The previous
    -- `__CLANG__` is never defined, so that half of the guard was dead.
    "#elif defined(__GNUC__) && !defined(__clang__)",
    "#pragma GCC diagnostic ignored \"-Wunused-parameter\"",
    "#pragma GCC diagnostic ignored \"-Wunused-label\"",
    "#pragma GCC diagnostic ignored \"-Wunused-but-set-variable\"",
    "#endif"]

-- Fail with an "unknown variable" error that mentions `x`.
def throwUnknownVar {α : Type} (x : VarId) : M α :=
throw $ "unknown variable '" ++ toString x ++ "'"

-- Whether variable `x` has an object type in the current declaration's
-- variable map. Fails for variables missing from the map.
def isObj (x : VarId) : M Bool :=
do ctx ← read;
   match ctx.varMap.find x with
   | none   => throwUnknownVar x
   | some t => pure t.isObj

-- Parameters of join point `j` in the current declaration; fails with
-- "unknown join point" when `j` is not in the join point map.
def getJPParams (j : JoinPointId) : M (Array Param) :=
do ctx ← read;
   match ctx.jpMap.find j with
   | none    => throw "unknown join point"
   | some ps => pure ps

-- Emit the C++ local declaration `<type> x; ` (no trailing newline).
def declareVar (x : VarId) (t : IRType) : M Unit :=
emit (toCppType t ++ " " ++ toString x ++ "; ")

-- Emit a local declaration for each parameter, in order.
def declareParams (ps : Array Param) : M Unit :=
ps.mfor (fun param => declareVar param.x param.ty)

-- Emit C++ declarations for the variables bound along a function body:
-- each `vdecl` variable and each join point's parameters. It walks the
-- continuation of every non-terminal statement and stops at terminal
-- statements. Join point bodies are not entered here; `emitJPs` emits them
-- through `emitBody`, which declares their variables in their own block.
-- The Bool argument/result tracks whether anything has been declared.
partial def declareVars : FnBody → Bool → M Bool
| e@(FnBody.vdecl x t _ b), d => do
  ctx ← read;
  -- A self tail call is compiled to parameter reassignment plus a jump
  -- (`emitTailCall`), so its result variable is never stored.
  if isTailCallTo ctx.mainFn e then
    pure d
  else
    declareVar x t *> declareVars b true
| FnBody.jdecl j xs _ b, d => declareParams xs *> declareVars b (d || xs.size > 0)
| e, d => if e.isTerminal then pure d else declareVars e.body d

-- Emit an expression for the constructor tag of `x`: `lean::obj_tag(x)` for
-- object variables, the variable itself for scalar (enum-like) values.
def emitTag (x : VarId) : M Unit :=
do xIsObj ← isObj x;
   if xIsObj then
     emit "lean::obj_tag(" *> emit x *> emit ")"
   else
     emit x

-- Recognize a two-way case whose first alternative is a constructor. It
-- returns that constructor's index, the first alternative's body and the
-- second alternative's body.
def isIf (alts : Array Alt) : Option (Nat × FnBody × FnBody) :=
if alts.size != 2 then none
else match alts.get 0 with
  | Alt.default _ => none
  | Alt.ctor c b  => some (c.cidx, b, (alts.get 1).body)

-- Emit `if (<tag of x> == tag) <t> else <e>`, emitting both branches with
-- `emitBody`.
def emitIf (emitBody : FnBody → M Unit) (x : VarId) (tag : Nat) (t : FnBody) (e : FnBody) : M Unit :=
do
  emit "if (";
  emitTag x;
  emit (" == " ++ toString tag);
  emitLn ")";
  emitBody t;
  emitLn "else";
  emitBody e

-- Emit a case analysis on variable `x`. Shapes accepted by `isIf` become an
-- `if`/`else` (`emitIf`). Everything else becomes a `switch` on the tag,
-- with one `case` per constructor alternative and a `default:` alternative
-- guaranteed by `ensureHasDefault`.
def emitCase (emitBody : FnBody → M Unit) (x : VarId) (alts : Array Alt) : M Unit :=
match isIf alts with
| some (tag, t, e) => emitIf emitBody x tag t e
| _ => do
  emit "switch ("; emitTag x; emitLn ") {";
  let alts := ensureHasDefault alts;
  alts.mfor $ fun alt => match alt with
    | Alt.ctor c b => emit "case " *> emit c.cidx *> emitLn ":" *> emitBody b
    | Alt.default b => emitLn "default: " *> emitBody b;
  emitLn "}"

-- Emit a reference-count increment of `x` by `n`. The count argument is
-- omitted when `n` is 1. `lean::inc` is used when `checkRef` is set,
-- otherwise `lean::inc_ref`.
def emitInc (x : VarId) (n : Nat) (checkRef : Bool) : M Unit :=
do
  let callee := if checkRef then "lean::inc" else "lean::inc_ref";
  emit (callee ++ "(" ++ toString x);
  unless (n == 1) (emit (", " ++ toString n));
  emitLn ");"

-- Emit a reference-count decrement of `x` by `n`. The count argument is
-- omitted when `n` is 1. `lean::dec` is used when `checkRef` is set,
-- otherwise `lean::dec_ref`.
def emitDec (x : VarId) (n : Nat) (checkRef : Bool) : M Unit :=
do
  let callee := if checkRef then "lean::dec" else "lean::dec_ref";
  emit (callee ++ "(" ++ toString x);
  unless (n == 1) (emit (", " ++ toString n));
  emitLn ");"

-- Emit `lean::free_heap_obj(x);`.
def emitDel (x : VarId) : M Unit :=
emitLn ("lean::free_heap_obj(" ++ toString x ++ ");")

-- Emit `lean::cnstr_set_tag(x, i);`.
def emitSetTag (x : VarId) (i : Nat) : M Unit :=
emitLn ("lean::cnstr_set_tag(" ++ toString x ++ ", " ++ toString i ++ ");")

-- Emit `lean::cnstr_set(x, i, y);`, storing argument `y` in field `i` of `x`.
def emitSet (x : VarId) (i : Nat) (y : Arg) : M Unit :=
emitLn ("lean::cnstr_set(" ++ toString x ++ ", " ++ toString i ++ ", " ++ argToCppString y ++ ");")

-- Emit a byte offset into a constructor's scalar area:
-- `sizeof(void*)*n`, plus ` + offset` when `offset` is non-zero. When `n`
-- is 0 the offset is emitted alone.
def emitOffset (n : Nat) (offset : Nat) : M Unit :=
if n == 0 then
  emit offset
else do
  emit "sizeof(void*)*"; emit n;
  when (offset > 0) (emit " + " *> emit offset)

-- Store the `usize` value `y` in slot `n` of the scalar area of `x`.
def emitUSet (x : VarId) (n : Nat) (y : VarId) : M Unit :=
do
  emit ("lean::cnstr_set_scalar(" ++ toString x ++ ", ");
  emitOffset n 0;
  emitLn (", " ++ toString y ++ ");")

-- Store the scalar `y` at byte position `sizeof(void*)*n + offset` of `x`.
def emitSSet (x : VarId) (n : Nat) (offset : Nat) (y : VarId) : M Unit :=
do
  emit ("lean::cnstr_set_scalar(" ++ toString x ++ ", ");
  emitOffset n offset;
  emitLn (", " ++ toString y ++ ");")

-- Jump to join point `j`: assign each argument to the matching join point
-- parameter, then `goto j;`. Fails on an arity mismatch.
def emitJmp (j : JoinPointId) (xs : Array Arg) : M Unit :=
do
  ps ← getJPParams j;
  unless (xs.size == ps.size) (throw "invalid goto");
  xs.size.mfor (fun i => emitLn (toString (ps.get i).x ++ " = " ++ argToCppString (xs.get i) ++ ";"));
  emitLn ("goto " ++ toString j ++ ";")

-- Emit the left-hand side `z = ` of an assignment.
def emitLhs (z : VarId) : M Unit :=
emit (toString z ++ " = ")

-- Emit the arguments as a comma-separated C++ argument list.
def emitArgs (ys : Array Arg) : M Unit :=
ys.size.mfor $ fun i =>
  (if i == 0 then pure () else emit ", ") *> emitArg (ys.get i)

-- Emit the size in bytes of a constructor's scalar area: `usize` slots of
-- `sizeof(size_t)` bytes plus `ssize` further bytes. Zero terms are omitted.
def emitCtorScalarSize (usize : Nat) (ssize : Nat) : M Unit :=
if usize == 0 then emit ssize
else do
  emit "sizeof(size_t)*"; emit usize;
  unless (ssize == 0) (emit " + " *> emit ssize)

-- Emit `lean::alloc_cnstr(<tag>, <num object fields>, <scalar size>);`.
def emitAllocCtor (c : CtorInfo) : M Unit :=
do
  emit ("lean::alloc_cnstr(" ++ toString c.cidx ++ ", " ++ toString c.size ++ ", ");
  emitCtorScalarSize c.usize c.ssize;
  emitLn ");"

-- Store each argument `ys[i]` into object field `i` of `z`.
def emitCtorSetArgs (z : VarId) (ys : Array Arg) : M Unit :=
ys.size.mfor (fun i => emitSet z i (ys.get i))

-- Emit `z = <constructor application>`. A constructor with no fields and
-- no scalar data becomes the boxed tag `lean::box(cidx)`. Otherwise a cell
-- is allocated and the arguments are stored in it.
def emitCtor (z : VarId) (c : CtorInfo) (ys : Array Arg) : M Unit :=
do
  emitLhs z;
  let noData := c.size == 0 && c.usize == 0 && c.ssize == 0;
  if noData then
    emitLn ("lean::box(" ++ toString c.cidx ++ ");")
  else do
    emitAllocCtor c;
    emitCtorSetArgs z ys

-- Emit a `reset` of `x` into `z`:
--   * if `x` is exclusive, release its first `n` fields and reuse `x` as `z`;
--   * otherwise drop a reference to `x` and set `z` to `lean::box(0)`.
def emitReset (z : VarId) (n : Nat) (x : VarId) : M Unit :=
do
  let xName := toString x;
  emitLn ("if (lean::is_exclusive(" ++ xName ++ ")) {");
  n.mfor (fun i => emitLn (" lean::cnstr_release(" ++ xName ++ ", " ++ toString i ++ ");"));
  emit " "; emitLhs z; emitLn (xName ++ ";");
  emitLn "} else {";
  emitLn (" lean::dec_ref(" ++ xName ++ ");");
  emit " "; emitLhs z; emitLn "lean::box(0);";
  emitLn "}"

-- Emit a constructor application that reuses the cell `x`. If `x` is a
-- scalar (no cell to reuse), a fresh cell is allocated. Otherwise `x` is
-- reused, and its tag is rewritten when `updtHeader` is set. Either way the
-- arguments are then stored into `z`.
def emitReuse (z : VarId) (x : VarId) (c : CtorInfo) (updtHeader : Bool) (ys : Array Arg) : M Unit :=
do
  let xName := toString x;
  emitLn ("if (lean::is_scalar(" ++ xName ++ ")) {");
  emit " "; emitLhs z; emitAllocCtor c;
  emitLn "} else {";
  emit " "; emitLhs z; emitLn (xName ++ ";");
  when updtHeader (emitLn (" lean::cnstr_set_tag(" ++ toString z ++ ", " ++ toString c.cidx ++ ");"));
  emitLn "}";
  emitCtorSetArgs z ys

-- Emit `z = lean::cnstr_get(x, i);` (object field `i` of `x`).
def emitProj (z : VarId) (i : Nat) (x : VarId) : M Unit :=
emitLhs z *> emitLn ("lean::cnstr_get(" ++ toString x ++ ", " ++ toString i ++ ");")

-- Emit a read of the `usize` stored in slot `i` of the scalar area of `x`.
def emitUProj (z : VarId) (i : Nat) (x : VarId) : M Unit :=
emitLhs z *> emitLn ("lean::cnstr_get_scalar<usize>(" ++ toString x ++ ", sizeof(void*)*" ++ toString i ++ ");")

-- Emit a read of a scalar of type `t` at byte position
-- `sizeof(void*)*n + offset` of `x`.
def emitSProj (z : VarId) (t : IRType) (n offset : Nat) (x : VarId) : M Unit :=
do
  emitLhs z;
  emit ("lean::cnstr_get_scalar<" ++ toCppType t ++ ">(" ++ toString x ++ ", ");
  emitOffset n offset;
  emitLn ");"

-- C++ renderings of the arguments, in order.
def toStringArgs (ys : Array Arg) : List String :=
List.map argToCppString (Array.toList ys)

-- Emit `z = <saturated call of f on ys>;`.
--   * For an `extern` declaration, the call text comes from its `cpp` extern
--     entry (`mkExternCall`); failure to build it is an error.
--   * Otherwise `f`'s C++ name is used, without parentheses when there are
--     no arguments (parameterless declarations are declared as variables by
--     `emitFnDeclAux`).
def emitFullApp (z : VarId) (f : FunId) (ys : Array Arg) : M Unit :=
do
  emitLhs z;
  decl ← getDecl f;
  match decl with
  | Decl.extern _ _ _ extData =>
    match mkExternCall extData `cpp (toStringArgs ys) with
    | some c => emit c *> emitLn ";"
    | none => throw "failed to emit extern application"
  | _ => do emitCppName f; when (ys.size > 0) (do emit "("; emitArgs ys; emit ")"); emitLn ";"

-- Emit a partial application of `f` to `ys`:
--   * allocate a closure pointing at `f`, recording `f`'s arity and the
--     number of captured arguments;
--   * store each argument in the closure.
def emitPartialApp (z : VarId) (f : FunId) (ys : Array Arg) : M Unit :=
do
  decl ← getDecl f;
  emitLhs z;
  emit "lean::alloc_closure(reinterpret_cast<void*>(";
  emitCppName f;
  emitLn ("), " ++ toString decl.params.size ++ ", " ++ toString ys.size ++ ");");
  ys.size.mfor (fun i =>
    emitLn ("lean::closure_set(" ++ toString z ++ ", " ++ toString i ++ ", " ++ argToCppString (ys.get i) ++ ");"))

-- Emit an application of closure `f` to `ys`. Up to `closureMaxArgs`
-- arguments use `lean::apply_<n>`. Larger applications pass the arguments
-- through a local array to `lean::apply_m`.
def emitApp (z : VarId) (f : VarId) (ys : Array Arg) : M Unit :=
if ys.size ≤ closureMaxArgs then do
  emitLhs z;
  emit ("lean::apply_" ++ toString ys.size ++ "(" ++ toString f ++ ", ");
  emitArgs ys;
  emitLn ");"
else do
  emit "{ obj* _aargs[] = {"; emitArgs ys; emitLn "};";
  emitLhs z;
  emit ("lean::apply_m(" ++ toString f ++ ", " ++ toString ys.size);
  emitLn ", _aargs); }"

-- Emit the name of the runtime function that boxes a value of `xType`.
-- Floats are not supported.
def emitBoxFn (xType : IRType) : M Unit :=
match xType with
| IRType.float  => throw "floats are not supported yet"
| IRType.usize  => emit "lean::box_size_t"
| IRType.uint32 => emit "lean::box_uint32"
| IRType.uint64 => emit "lean::box_uint64"
| _             => emit "lean::box"

-- Emit `z = <box fn for xType>(x);`.
def emitBox (z : VarId) (x : VarId) (xType : IRType) : M Unit :=
do
  emitLhs z;
  emitBoxFn xType;
  emitLn ("(" ++ toString x ++ ");")

-- Emit `z = <unbox fn>(x);`, choosing the unboxing helper from the target
-- type `t` (the counterpart of `emitBoxFn`). Floats are not supported.
def emitUnbox (z : VarId) (t : IRType) (x : VarId) : M Unit :=
do
  emitLhs z;
  match t with
  | IRType.usize => emit "lean::unbox_size_t"
  | IRType.uint32 => emit "lean::unbox_uint32"
  | IRType.uint64 => emit "lean::unbox_uint64"
  | IRType.float => throw "floats are not supported yet"
  | other => emit "lean::unbox";
  emit "("; emit x; emitLn ");"

-- Emit `z = !lean::is_exclusive(x);`.
def emitIsShared (z : VarId) (x : VarId) : M Unit :=
emitLhs z *> emitLn ("!lean::is_exclusive(" ++ toString x ++ ");")

-- Emit `z = !lean::is_scalar(x);`.
def emitIsTaggedPtr (z : VarId) (x : VarId) : M Unit :=
emitLhs z *> emitLn ("!lean::is_scalar(" ++ toString x ++ ");")

-- One-character string holding the digit for `c` (via `Nat.digitChar`).
def toHexDigit (c : Nat) : String :=
let ch := c.digitChar; String.singleton ch

-- Render `s` as a C++ string literal, including the surrounding quotes.
-- Newline, tab, backslash and double quote use their standard escapes.
-- Other control characters (code point <= 31) become three-digit octal
-- escapes. An octal escape consumes at most three digits, whereas `\x`
-- consumes every following hex digit: "\x09" followed by 'a' would be read
-- as the single escape `\x09a`. Octal is therefore safe whatever follows.
def quoteString (s : String) : String :=
let q := "\"";
let q := s.foldl
  (fun q c => q ++
    if c == '\n' then "\\n"
    else if c == '\t' then "\\t"
    else if c == '\\' then "\\\\"
    else if c == '\"' then "\\\""
    else if c.toNat <= 31 then
      "\\" ++ toHexDigit (c.toNat / 64) ++ toHexDigit ((c.toNat / 8) % 8) ++ toHexDigit (c.toNat % 8)
    -- TODO(Leo): we should use `\unnnn` for escaping unicode characters.
    else String.singleton c)
  q;
q ++ "\""

-- Emit a numeric literal of IR type `t`.
--   * Object-typed values are wrapped in `lean::mk_nat_obj`. Values below
--     `uint32Sz` are passed as an unsigned literal; larger ones go through a
--     decimal `lean::mpz` string.
--   * Scalar values below `uint32Sz` are emitted as plain decimals. Larger
--     ones get a `ULL` suffix: an unsuffixed decimal literal beyond the range
--     of `long long` is ill-formed C++. The literal is only ever assigned to
--     the destination variable (`emitLit`), so the suffix changes no result.
def emitNumLit (t : IRType) (v : Nat) : M Unit :=
if t.isObj then do
  emit "lean::mk_nat_obj(";
  if v < uint32Sz then emit v *> emit "u"
  else emit "lean::mpz(\"" *> emit v *> emit "\")";
  emit ")"
else if v < uint32Sz then
  emit v
else
  emit v *> emit "ULL"

-- Emit `z = <literal>;` for a numeric or string literal of type `t`.
def emitLit (z : VarId) (t : IRType) (v : LitVal) : M Unit :=
do
  emitLhs z;
  match v with
  | LitVal.num n => do emitNumLit t n; emitLn ";"
  | LitVal.str s => emitLn ("lean::mk_string(" ++ quoteString s ++ ");")

-- Emit the assignment `z := v` for a variable declaration of type `t`,
-- dispatching on the kind of right-hand side expression.
def emitVDecl (z : VarId) (t : IRType) (v : Expr) : M Unit :=
match v with
| Expr.lit l                 => emitLit z t l
| Expr.fap g args            => emitFullApp z g args
| Expr.pap g args            => emitPartialApp z g args
| Expr.ap clo args           => emitApp z clo args
| Expr.ctor info args        => emitCtor z info args
| Expr.reset n y             => emitReset z n y
| Expr.reuse y info upd args => emitReuse z y info upd args
| Expr.proj i y              => emitProj z i y
| Expr.uproj i y             => emitUProj z i y
| Expr.sproj n o y           => emitSProj z t n o y
| Expr.box ty y              => emitBox z y ty
| Expr.unbox y               => emitUnbox z t y
| Expr.isShared y            => emitIsShared z y
| Expr.isTaggedPtr y         => emitIsTaggedPtr z y

-- Whether `x := v; b` is a self tail call: `v` fully applies the function
-- being emitted and `b` immediately returns `x`. (The emitter in this file
-- uses `isTailCallTo` instead.)
def isTailCall (x : VarId) (v : Expr) (b : FnBody) : M Bool :=
do
  ctx ← read;
  pure $ match v, b with
  | Expr.fap g _, FnBody.ret (Arg.var y) => g == ctx.mainFn && x == y
  | _, _                                 => false

-- Whether argument `x` is exactly the variable of parameter `p`.
def paramEqArg (p : Param) (x : Arg) : Bool :=
match x with
| Arg.var y => p.x == y
| _         => false

/-
Given `[p_0, ..., p_{n-1}]` and `[y_0, ..., y_{n-1}]`, representing the
sequential assignments
```
p_0 := y_0,
...
p_{n-1} := y_{n-1}
```
return true iff there is a pair `(i, j)` with `j > i` and `y_j == p_i`.
That is, we have
```
p_i := y_i,
...
p_j := p_i, -- p_i was overwritten above
```
`emitTailCall` uses this to decide whether the new parameter values must be
staged in temporaries before any parameter is assigned.
-/
def overwriteParam (ps : Array Param) (ys : Array Arg) : Bool :=
let n := ps.size;
n.any $ fun i =>
  let p := ps.get i;
  -- Look for a later argument (index j in [i+1, n)) that reads p_i.
  (i+1, n).anyI $ fun j => paramEqArg p (ys.get j)

-- Emit a self tail call `v` as parameter reassignment followed by
-- `goto _start;` (the label `emitDeclAux` places at the top of each function
-- body). Parameters whose argument is the parameter itself are left alone.
-- When `overwriteParam` reports that some argument reads a parameter
-- assigned earlier, all new values are first copied into `_tmp_<i>`
-- temporaries inside a `{ ... }` block.
def emitTailCall (v : Expr) : M Unit :=
match v with
| Expr.fap _ ys => do
  ctx ← read;
  let ps := ctx.mainParams;
  unless (ps.size == ys.size) (throw "invalid tail call");
  if overwriteParam ps ys then do {
    emitLn "{";
    -- Stage every changed value before touching any parameter.
    ps.size.mfor $ fun i => do {
      let p := ps.get i;
      let y := ys.get i;
      unless (paramEqArg p y) $ do {
        emit (toCppType p.ty); emit " _tmp_"; emit i; emit " = "; emitArg y; emitLn ";"
      }
    };
    ps.size.mfor $ fun i => do {
      let p := ps.get i;
      let y := ys.get i;
      unless (paramEqArg p y) (do emit p.x; emit " = _tmp_"; emit i; emitLn ";")
    };
    emitLn "}"
  } else do {
    -- No argument reads an earlier-assigned parameter: assign in order.
    ys.size.mfor $ fun i => do {
      let p := ps.get i;
      let y := ys.get i;
      unless (paramEqArg p y) (do emit p.x; emit " = "; emitArg y; emitLn ";")
    }
  };
  emitLn "goto _start;"
| _ => throw "bug at emitTailCall"

-- Emit the statements of a function body in order, up to and including its
-- terminal statement. Nested bodies (case alternatives) are emitted with
-- `emitBody`.
partial def emitBlock (emitBody : FnBody → M Unit) : FnBody → M Unit
-- Join points are skipped here; `emitJPs` emits them after the block.
| FnBody.jdecl j xs v b => emitBlock b
-- A self tail call replaces the rest of the block with a jump to `_start`.
| d@(FnBody.vdecl x t v b) =>
  do ctx ← read; if isTailCallTo ctx.mainFn d then emitTailCall v else emitVDecl x t v *> emitBlock b
| FnBody.inc x n c b => emitInc x n c *> emitBlock b
| FnBody.dec x n c b => emitDec x n c *> emitBlock b
| FnBody.del x b => emitDel x *> emitBlock b
| FnBody.setTag x i b => emitSetTag x i *> emitBlock b
| FnBody.set x i y b => emitSet x i y *> emitBlock b
| FnBody.uset x i y b => emitUSet x i y *> emitBlock b
| FnBody.sset x i o y _ b => emitSSet x i o y *> emitBlock b
-- Metadata produces no code.
| FnBody.mdata _ b => emitBlock b
| FnBody.ret x => emit "return " *> emitArg x *> emitLn ";"
| FnBody.case _ x alts => emitCase emitBody x alts
| FnBody.jmp j xs => emitJmp j xs
| FnBody.unreachable => emitLn "lean_unreachable();"

-- Emit each join point declared along the body as `label:` followed by its
-- body (emitted with `emitBody`). Stops at the terminal statement.
partial def emitJPs (emitBody : FnBody → M Unit) : FnBody → M Unit
| (FnBody.jdecl j _ v b) => do emit j; emitLn ":"; emitBody v; emitJPs b
| e                      => if e.isTerminal then pure () else emitJPs e.body

-- Emit a function (or join point) body as a C++ `{ ... }` block containing:
--   * declarations for its variables, followed by a blank line if any were
--     declared;
--   * its statements;
--   * its join points, as labeled sections.
partial def emitFnBody : FnBody → M Unit
| b => do
  emitLn "{";
  declared ← declareVars b false;
  when declared (emitLn "");
  emitBlock emitFnBody b;
  emitJPs emitFnBody b;
  emitLn "}"

-- Emit the C++ definition of a function declaration. Declarations for which
-- `hasInitAttr` holds, and non-`fdecl` declarations, produce nothing.
--   * A declaration with parameters becomes a C++ function. Boxed versions
--     with more than `closureMaxArgs` parameters take an `obj** _args` array
--     and unpack it into locals.
--   * A parameterless declaration becomes `_init_<name>()`, which computes
--     the constant's value (called from the module initializer, see
--     `emitDeclInit`).
-- The body starts with the `_start:` label targeted by self tail calls.
def emitDeclAux (d : Decl) : M Unit :=
do
  env ← getEnv;
  let (vMap, jpMap) := mkVarJPMaps d;
  adaptReader (fun (ctx : Context) => { varMap := vMap, jpMap := jpMap, .. ctx }) $ do
  unless (hasInitAttr env d.name) $
  match d with
  | Decl.fdecl f xs t b => do
    openNamespacesFor f;
    baseName ← toBaseCppName f;
    emit (toCppType t); emit " ";
    if xs.size > 0 then do {
      emit baseName;
      emit "(";
      -- Must agree with the `obj**` signature declared by `emitFnDeclAux`.
      if xs.size > closureMaxArgs && isBoxedName d.name then
        emit "obj** _args"
      else
        xs.size.mfor $ fun i => do {
          when (i > 0) (emit ", ");
          let x := xs.get i;
          emit (toCppType x.ty); emit " "; emit x.x
        };
      emit ")"
    } else do {
      emit ("_init_" ++ baseName ++ "()")
    };
    emitLn " {";
    -- Unpack the argument array into named parameters.
    when (xs.size > closureMaxArgs && isBoxedName d.name) $
      xs.size.mfor $ fun i => do {
        let x := xs.get i;
        emit "obj * "; emit x.x; emit " = _args["; emit i; emitLn "];"
      };
    emitLn "_start:";
    -- Record the function and its parameters for self tail-call detection.
    adaptReader (fun (ctx : Context) => { mainFn := f, mainParams := xs, .. ctx }) (emitFnBody b);
    emitLn "}";
    closeNamespacesFor f
  | _ => pure ()

-- Emit the definition of `d` after normalizing its ids. Any error is
-- re-raised with the offending declaration appended for context.
def emitDecl (d : Decl) : M Unit :=
let normalized := d.normalizeIds;
catch (emitDeclAux normalized)
  (fun err => throw (err ++ "\ncompiling:\n" ++ toString normalized))

-- Emit the definitions of this module's declarations, in the reverse of
-- the order returned by `getDecls`.
def emitFns : M Unit :=
do env ← getEnv; (getDecls env).reverse.mfor emitDecl

-- Emit the module-initializer code for declaration `d`:
--   * when `isIOUnitInitFn` holds, the function is simply run on the world;
--   * a parameterless declaration registered with an init function
--     (`getInitFnNameFor`) gets the value that function returns;
--   * any other parameterless declaration is computed by its `_init_`
--     function;
--   * object-typed constants are then marked persistent.
-- An IO error from any action is returned from the initializer immediately.
def emitDeclInit (d : Decl) : M Unit :=
do
  env ← getEnv;
  let n := d.name;
  if isIOUnitInitFn env n then do {
    emit "w = "; emitCppName n; emitLn "(w);";
    emitLn "if (io_result_is_error(w)) return w;"
  } else when (d.params.size == 0) $ do {
    match getInitFnNameFor env d.name with
    | some initFn => do {
      emit "w = "; emitCppName initFn; emitLn "(w);";
      emitLn "if (io_result_is_error(w)) return w;";
      emitCppName n; emitLn " = io_result_get_value(w);"
    }
    | _ => do {
      emitCppName n; emit " = "; emitCppInitName n; emitLn "();"
    };
    when d.resultType.isObj $ do {
      emit "lean::mark_persistent("; emitCppName n; emitLn ");"
    }
  }

-- Emit the module initializer `initialize_<module>(obj* w)`:
--   * forward-declare the initializers of the imported modules;
--   * guard the body with a static flag so it runs at most once;
--   * initialize every imported module;
--   * run `emitDeclInit` for this module's declarations, in reverse
--     `getDecls` order.
-- An IO error from any step is returned immediately.
def emitInitFn : M Unit :=
do
  env ← getEnv;
  modName ← getModName;
  env.imports.mfor $ fun m => emitLn ("obj* initialize_" ++ m.mangle "" ++ "(obj*);");
  emitLns [
    "static bool _G_initialized = false;",
    "obj* initialize_" ++ modName.mangle "" ++ "(obj* w) {",
    "if (_G_initialized) return w;",
    "_G_initialized = true;",
    "if (io_result_is_error(w)) return w;"
  ];
  env.imports.mfor $ fun m => emitLns [
    "w = initialize_" ++ m.mangle "" ++ "(w);",
    "if (io_result_is_error(w)) return w;"
  ];
  let decls := getDecls env;
  decls.reverse.mfor emitDeclInit;
  emitLns [
    "return w;",
    "}"]

-- Emit the whole C++ translation unit, in this order: header, forward
-- declarations, definitions, module initializer, and the C `main` when the
-- module defines one.
def main : M Unit :=
do emitFileHeader; emitFnDecls; emitFns; emitInitFn; emitMainFnIfNeeded

end EmitCpp
|
||
|
||
-- Generate the C++ code for module `modName` of `env`. Returns the emitted
-- text, or the emitter's error message.
@[export lean.ir.emit_cpp_core]
def emitCpp (env : Environment) (modName : Name) : Except String String :=
let ctx : EmitCpp.Context := { env := env, modName := modName };
match (EmitCpp.main ctx).run "" with
| EState.Result.error err _ => Except.error err
| EState.Result.ok _ out    => Except.ok out

end IR
|
||
end Lean
|