chore(library/init/lean/ir/extract_cpp): cleanup

This commit is contained in:
Leonardo de Moura 2018-05-10 10:27:04 -07:00
parent e882d9d7cf
commit 9a261975f7

View file

@ -24,11 +24,19 @@ modify (++ (to_string a))
def emit_line : extract_m unit :=
emit "\n"
def paren {α} (a : extract_m α) : extract_m α :=
emit "(" >> a <* emit ")"
def comma (a b : extract_m unit) : extract_m unit :=
a >> emit ", " >> b
local infix `<+>`:65 := comma
def emit_var (x : var) : extract_m unit :=
emit $ name.mangle x "_x"
def emit_blockid (b : blockid) : extract_m unit :=
emit $ name.mangle b "_label"
emit $ name.mangle b "_lbl"
def emit_fnid (fid : fnid) : extract_m unit :=
do env ← read,
@ -41,7 +49,7 @@ def type2cpp : type → string
| type.uint16 := "unsigned short" | type.uint32 := "unsigned" | type.uint64 := "unsigned long long" | type.usize := "size_t"
| type.int16 := "short" | type.int32 := "int" | type.int64 := "long long"
| type.float := "float" | type.double := "double"
| type.object := "lean_obj*"
| type.object := "lean::lean_obj*"
def emit_type (ty : type) : extract_m unit :=
emit (type2cpp ty)
@ -54,10 +62,19 @@ def emit_sep_aux {α} (f : α → extract_m unit) (sep : string) : list α → e
def emit_sep {α} (l : list α) (f : α → extract_m unit) (sep := ",") : extract_m unit :=
emit_sep_aux f sep l
def emit_var_list (xs : list var) : extract_m unit :=
emit_sep xs emit_var
def emit_template_params (ts : list type) : extract_m unit :=
emit "<" >> emit_sep ts emit_type >> emit ">"
def emit_template_param (t : type) : extract_m unit :=
emit_template_params [t]
def emit_return : list result → extract_m unit
| [] := emit "void"
| [r] := emit_type r.ty
| rs := emit "std::tuple<" >> emit_sep rs (λ r, emit_type r.ty) >> emit ">"
| rs := emit "std::tuple" >> emit_template_params (rs.map result.ty)
def emit_arg_list (args : list arg) : extract_m unit :=
emit_sep args $ λ a, emit_type a.ty >> emit " : " >> emit_var a.n
@ -67,12 +84,12 @@ def emit_eos : extract_m unit :=
emit ";" >> emit_line
def emit_tag (x : var) : extract_m unit :=
emit "lean::cnstr_tag(" >> emit_var x >> emit ")"
emit "lean::cnstr_tag" >> paren(emit_var x)
def emit_return_vars : list var → extract_m unit
| [] := return ()
| [x] := emit_var x
| xs := emit "std::make_tuple(" >> emit_sep xs emit_var >> emit ")"
| xs := emit "std::make_tuple" >> paren(emit_var_list xs)
def emit_cases : list blockid → nat → extract_m unit
| [] n := throw "ill-formed case terminator"
@ -98,10 +115,10 @@ match term with
def emit_call_lhs : list var → extract_m unit
| [] := return ()
| [x] := emit_var x >> emit " := "
| xs := emit "std::tie(" >> emit_sep xs emit_var >> emit ") := "
| xs := emit "std::tie" >> paren(emit_var_list xs) >> emit " := "
def emit_type_size (ty : type) : extract_m unit :=
emit "sizeof(" >> emit_type ty >> emit ")"
emit "sizeof" >> paren(emit_type ty)
def emit_sizet : list (nat × type) → extract_m unit
| [] := emit 0
@ -109,7 +126,7 @@ def emit_sizet : list (nat × type) → extract_m unit
/-- Emit `op(x)` -/
def emit_op_x (op : string) (x : var) : extract_m unit :=
emit op >> emit "(" >> emit_var x >> emit ")"
emit op >> paren (emit_var x)
/-- Emit `x := y op z` -/
def emit_infix (x y z : var) (op : string) : extract_m unit :=
@ -117,7 +134,7 @@ emit_var x >> emit " := " >> emit_var y >> emit op >> emit_var z
/- Emit `x := big_op(y, z)` -/
def emit_big_binop (x y z : var) (big_op : string) : extract_m unit :=
emit_var x >> emit " := " >> emit big_op >> emit "(" >> emit_var y >> emit ", " >> emit_var z >> emit ")"
emit_var x >> emit " := " >> emit big_op >> paren (emit_var y <+> emit_var z)
def emit_arith (t : type) (x y z : var) (op : string) (big_op : string) : extract_m unit :=
match t with
@ -153,12 +170,12 @@ match op with
| binop.ne := emit_arith t x y z "!=" "lean::big_nq"
| binop.array_read :=
(match t with
| type.object := emit_var x >> emit " := lean::array_obj(" >> emit_var y >> emit ", " >> emit_var z >> emit ")"
| _ := emit_var x >> emit " := lean::sarray_data<" >> emit_type t >> emit ">(" >> emit_var y >> emit ", " >> emit_var z >> emit ")")
| type.object := emit_var x >> emit " := lean::array_obj" >> paren (emit_var y <+> emit_var z)
| _ := emit_var x >> emit " := lean::sarray_data" >> emit_template_param t >> paren (emit_var y <+> emit_var z))
/-- Emit `x := op(y)` -/
def emit_x_op_y (x : var) (op : string) (y : var) : extract_m unit :=
emit_var x >> emit " := " >> emit op >> emit "(" >> emit_var y >> emit ")"
emit_var x >> emit " := " >> emit op >> paren(emit_var y)
def unop2cpp (t : type) : unop → string
| unop.not := if t = type.bool then "!" else "~"
@ -175,7 +192,7 @@ def unop2cpp (t : type) : unop → string
| unop.string_len := "lean::string_len"
def emit_unop (x : var) (t : type) (op : unop) (y : var) : extract_m unit :=
emit_var x >> emit " := " >> emit (unop2cpp t op) >> emit "(" >> emit_var y >> emit ")"
emit_var x >> emit " := " >> emit (unop2cpp t op) >> paren(emit_var y)
def emit_num_suffix : type → extract_m unit
| type.uint32 := emit "u"
@ -187,7 +204,7 @@ def emit_lit (x : var) (t : type) (l : literal) : extract_m unit :=
match l with
| literal.bool tt := emit_var x >> emit " := true"
| literal.bool ff := emit_var x >> emit " := false"
| literal.str s := emit_var x >> emit " := lean::mk_string(" >> emit (repr s) >> emit ")"
| literal.str s := emit_var x >> emit " := lean::mk_string" >> paren(emit (repr s))
| literal.float v := emit_var x >> emit " := " >> emit v
| literal.num v := emit_var x >> emit " := " >> emit v >> emit_num_suffix t
@ -201,7 +218,7 @@ def unins2cpp : unins → string
| unins.sarray_pop := "lean::sarray_pop"
def emit_unary (op : unins) (x : var) : extract_m unit :=
emit (unins2cpp op) >> emit "(" >> emit_var x >> emit ")"
emit (unins2cpp op) >> paren(emit_var x)
def emit_instr (ins : instr) : extract_m unit :=
ins.decorate_error $
@ -209,26 +226,26 @@ ins.decorate_error $
| (instr.lit x t l) := emit_lit x t l
| (instr.unop x t op y) := emit_unop x t op y
| (instr.binop x t op y z) := emit_binop x t op y z
| (instr.call xs f ys) := emit_call_lhs xs >> emit_fnid f >> emit "(" >> emit_sep ys emit_var >> emit ")"
| (instr.cnstr o t n sz) := emit_var o >> emit " := lean::alloc_cnstr(" >> emit t >> emit ", " >> emit n >> emit ", " >> emit_sizet sz >> emit ")"
| (instr.set o i x) := emit "lean::set_cnstr_obj(" >> emit_var o >> emit ", " >> emit i >> emit ", " >> emit_var x >> emit ")"
| (instr.get x o i) := emit_var x >> emit " := lean::cnstr_obj(" >> emit_var o >> emit ", " >> emit i >> emit ")"
| (instr.sset o d x) := emit "lean::set_cnstr_scalar(" >> emit_var o >> emit ", " >> emit_sizet d >> emit ", " >> emit_var x >> emit ")"
| (instr.sget x t o d) := emit_var x >> emit " := lean::cnstr_scalar<" >> emit_type t >> emit ">(" >> emit_var o >> emit ", " >> emit_sizet d >> emit ")"
| (instr.call xs f ys) := emit_call_lhs xs >> emit_fnid f >> paren(emit_var_list ys)
| (instr.cnstr o t n sz) := emit_var o >> emit " := lean::alloc_cnstr" >> paren(emit t <+> emit n <+> emit_sizet sz)
| (instr.set o i x) := emit "lean::set_cnstr_obj" >> paren (emit_var o <+> emit i <+> emit_var x)
| (instr.get x o i) := emit_var x >> emit " := lean::cnstr_obj" >> paren(emit_var o <+> emit i)
| (instr.sset o d x) := emit "lean::set_cnstr_scalar" >> paren(emit_var o <+> emit_sizet d <+> emit_var x)
| (instr.sget x t o d) := emit_var x >> emit " := lean::cnstr_scalar" >> emit_template_param t >> paren(emit_var o <+> emit_sizet d)
| (instr.closure x f ys) := return () -- TODO
| (instr.apply x ys) := return () -- TODO
| (instr.array a sz c) := emit_var a >> emit " := lean::alloc_array(" >> emit_var sz >> emit ", " >> emit_var c >> emit ")"
| (instr.sarray a t sz c) := emit_var a >> emit " := lean::alloc_sarray(" >> emit_type_size t >> emit ", " >> emit_var sz >> emit ", " >> emit_var c >> emit ")"
| (instr.array a sz c) := emit_var a >> emit " := lean::alloc_array" >> paren(emit_var sz <+> emit_var c)
| (instr.sarray a t sz c) := emit_var a >> emit " := lean::alloc_sarray" >> paren(emit_type_size t <+> emit_var sz <+> emit_var c)
| (instr.array_write a i v) :=
do env ← read,
if env.ctx.find v = some type.object
then emit "lean::set_array_obj(" >> emit_var a >> emit ", " >> emit_var i >> emit ", " >> emit_var v >> emit ")"
else emit "lean::set_sarray_data(" >> emit_var a >> emit ", " >> emit_var i >> emit ", " >> emit_var v >> emit ")"
then emit "lean::set_array_obj" >> paren(emit_var a <+> emit_var i <+> emit_var v)
else emit "lean::set_sarray_data" >> paren(emit_var a <+> emit_var i <+> emit_var v)
| (instr.array_push a v) :=
do env ← read,
if env.ctx.find v = some type.object
then emit "lean::array_push(" >> emit_var a >> emit ", " >> emit_var v >> emit ")"
else emit "lean::sarray_push(" >> emit_var a >> emit ", " >> emit_var v >> emit ")"
then emit "lean::array_push" >> paren(emit_var a <+> emit_var v)
else emit "lean::sarray_push" >> paren(emit_var a <+> emit_var v)
| (instr.unary op x) := emit_unary op x)
>> emit_eos
@ -239,7 +256,7 @@ def emit_block (b : block) : extract_m unit :=
>> emit_terminator b.term
def emit_header (h : header) : extract_m unit :=
emit_return h.return >> emit " " >> emit_fnid h.n >> emit "(" >> emit_arg_list h.args >> emit ")"
emit_return h.return >> emit " " >> emit_fnid h.n >> paren(emit_arg_list h.args)
def decl_local (x : var) (ty : type) : extract_m unit :=
emit_var x >> emit_type ty >> emit_eos
@ -258,11 +275,11 @@ match d with
def emit_decl (env : environment) (external_names : fnid → option string) (d : decl) : except_t format (state_t string id) unit :=
do ctx ← monad_lift $ infer_types d env,
(emit_decl_core d).run { external_names := external_names, ctx := ctx }
def emit_decls (env : environment) (cpp_names : fnid → option string) (ds : list decl) : except format string :=
let out := file_header in
run (ds.mfor (emit_decl env cpp_names) >> get) out
end cpp
def extract_cpp (env : environment) (cpp_names : fnid → option string) (ds : list decl) : except format string :=
let out := cpp.file_header in
run (ds.mfor (cpp.emit_decl env cpp_names) >> get) out
end ir
end lean