fix(library/compiler/emit_cpp): tail call
Add temporary hack to fix `emit_tail_call`. TODO: find a cleaner solution for the new IR compiler.
This commit is contained in:
parent
64525116f5
commit
4d2837430a
3 changed files with 153 additions and 4 deletions
|
|
@ -752,17 +752,45 @@ struct emit_fn_fn {
|
|||
return none_expr();
|
||||
}
|
||||
|
||||
bool overwrite_param(buffer<expr> const & args) {
|
||||
lean_assert(args.size() == m_fn_args.size());
|
||||
for (unsigned i = 0; i < m_fn_args.size(); i++) {
|
||||
expr p = m_fn_args[i];
|
||||
for (unsigned j = i+1; j < args.size(); j++) {
|
||||
if (args[j] == p)
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void emit_tail_call(expr const & e) {
|
||||
buffer<expr> args;
|
||||
expr fn = get_app_args(e, args);
|
||||
lean_assert(is_constant(fn) && const_name(fn) == m_fn_name);
|
||||
lean_assert(args.size() == m_fn_args.size());
|
||||
for (unsigned i = 0; i < args.size(); i++) {
|
||||
if (args[i] != m_fn_args[i]) {
|
||||
emit_fvar(m_fn_args[i]); m_out << " = "; emit_arg(args[i]); m_out << ";\n";
|
||||
if (overwrite_param(args)) {
|
||||
m_out << "{\n";
|
||||
for (unsigned i = 0; i < args.size(); i++) {
|
||||
if (args[i] != m_fn_args[i]) {
|
||||
m_out << "auto y_" << i << " = "; emit_arg(args[i]); m_out << ";\n";
|
||||
}
|
||||
}
|
||||
for (unsigned i = 0; i < args.size(); i++) {
|
||||
if (args[i] != m_fn_args[i]) {
|
||||
emit_fvar(m_fn_args[i]); m_out << " = y_" << i << ";\n";
|
||||
}
|
||||
}
|
||||
m_out << "}\n";
|
||||
m_out << "goto _start;\n";
|
||||
} else {
|
||||
for (unsigned i = 0; i < args.size(); i++) {
|
||||
if (args[i] != m_fn_args[i]) {
|
||||
emit_fvar(m_fn_args[i]); m_out << " = "; emit_arg(args[i]); m_out << ";\n";
|
||||
}
|
||||
}
|
||||
m_out << "goto _start;\n";
|
||||
}
|
||||
m_out << "goto _start;\n";
|
||||
}
|
||||
|
||||
void emit_terminal(expr const & e, bool tail_call) {
|
||||
|
|
|
|||
114
tests/compiler/t4.lean
Normal file
114
tests/compiler/t4.lean
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
/- Benchmark for new code generator -/
|
||||
inductive Expr
|
||||
| Val : Int → Expr
|
||||
| Var : String → Expr
|
||||
| Add : Expr → Expr → Expr
|
||||
| Mul : Expr → Expr → Expr
|
||||
| Pow : Expr → Expr → Expr
|
||||
| Ln : Expr → Expr
|
||||
|
||||
open Expr
|
||||
|
||||
def Expr.toString : Expr → String
|
||||
| (Val n) := toString n
|
||||
| (Var x) := x
|
||||
| (Add f g) := "(" ++ Expr.toString f ++ " + " ++ Expr.toString g ++ ")"
|
||||
| (Mul f g) := "(" ++ Expr.toString f ++ " * " ++ Expr.toString g ++ ")"
|
||||
| (Pow f g) := "(" ++ Expr.toString f ++ " ^ " ++ Expr.toString g ++ ")"
|
||||
| (Ln f) := "ln(" ++ Expr.toString f ++ ")"
|
||||
|
||||
instance : HasToString Expr :=
|
||||
⟨Expr.toString⟩
|
||||
|
||||
partial def pown : Int → Int → Int
|
||||
| a 0 := 1
|
||||
| a 1 := a
|
||||
| a n :=
|
||||
let b := pown a (n / 2) in
|
||||
b * b * (if n % 2 = 0 then 1 else a)
|
||||
|
||||
partial def addAux : Expr → Expr → Expr
|
||||
| (Val n) (Val m) := Val (n + m)
|
||||
| (Val 0) f := f
|
||||
| f (Val 0) := f
|
||||
| f (Val n) := addAux (Val n) f
|
||||
| (Val n) (Add (Val m) f) := addAux (Val (n+m)) f
|
||||
| f (Add (Val n) g) := addAux (Val n) (addAux f g)
|
||||
| (Add f g) h := addAux f (addAux g h)
|
||||
| f g := Add f g
|
||||
|
||||
def add (a b : Expr) : Expr :=
|
||||
-- dbgTrace (">> add (" ++ toString a ++ ", " ++ toString b ++ ")") $ λ _,
|
||||
addAux a b
|
||||
|
||||
-- set_option trace.compiler.borrowed_inference true
|
||||
|
||||
partial def mulAux : Expr → Expr → Expr
|
||||
| (Val n) (Val m) := Val (n*m)
|
||||
| (Val 0) _ := Val 0
|
||||
| _ (Val 0) := Val 0
|
||||
| (Val 1) f := f
|
||||
| f (Val 1) := f
|
||||
| f (Val n) := mulAux (Val n) f
|
||||
| (Val n) (Mul (Val m) f) := mulAux (Val (n*m)) f
|
||||
| f (Mul (Val n) g) := mulAux (Val n) (mulAux f g)
|
||||
| (Mul f g) h := mulAux f (mulAux g h)
|
||||
| f g := Mul f g
|
||||
|
||||
def mul (a b : Expr) : Expr :=
|
||||
-- dbgTrace (">> mul (" ++ toString a ++ ", " ++ toString b ++ ")") $ λ _,
|
||||
mulAux a b
|
||||
|
||||
def pow : Expr → Expr → Expr
|
||||
| (Val m) (Val n) := Val (pown m n)
|
||||
| _ (Val 0) := Val 1
|
||||
| f (Val 1) := f
|
||||
| (Val 0) _ := Val 0
|
||||
| f g := Pow f g
|
||||
|
||||
def ln : Expr → Expr
|
||||
| (Val 1) := Val 0
|
||||
| f := Ln f
|
||||
|
||||
def d (x : String) : Expr → Expr
|
||||
| (Val _) := Val 0
|
||||
| (Var y) := if x = y then Val 1 else Val 0
|
||||
| (Add f g) := add (d f) (d g)
|
||||
| (Mul f g) :=
|
||||
-- dbgTrace (">> d (" ++ toString f ++ ", " ++ toString g ++ ")") $ λ _,
|
||||
add (mul f (d g)) (mul g (d f))
|
||||
| (Pow f g) := mul (pow f g) (add (mul (mul g (d f)) (pow f (Val (-1)))) (mul (ln f) (d g)))
|
||||
| (Ln f) := mul (d f) (pow f (Val (-1)))
|
||||
|
||||
def count : Expr → Nat
|
||||
| (Val _) := 1
|
||||
| (Var _) := 1
|
||||
| (Add f g) := count f + count g
|
||||
| (Mul f g) := count f + count g
|
||||
| (Pow f g) := count f + count g
|
||||
| (Ln f) := count f
|
||||
|
||||
|
||||
def nestAux (s : Nat) (f : Nat → Expr → IO Expr) : Nat → Expr → IO Expr
|
||||
| 0 x := pure x
|
||||
| m@(n+1) x := f (s - m) x >>= nestAux n
|
||||
|
||||
def nest (f : Nat → Expr → IO Expr) (n : Nat) (e : Expr) : IO Expr :=
|
||||
nestAux n f n e
|
||||
|
||||
def deriv (i : Nat) (f : Expr) : IO Expr :=
|
||||
do
|
||||
let d := d "x" f,
|
||||
IO.println (toString (i+1) ++ " count: " ++ (toString $ count d)),
|
||||
IO.println (toString d),
|
||||
pure d
|
||||
|
||||
def main (xs : List String) : IO UInt32 :=
|
||||
do let x := Var "x",
|
||||
let f := add x (mul x (mul x (add x x))),
|
||||
IO.println f,
|
||||
nest deriv 3 f,
|
||||
pure 0
|
||||
|
||||
-- setOption profiler True
|
||||
-- #eval main []
|
||||
7
tests/compiler/t4.lean.expected.out
Normal file
7
tests/compiler/t4.lean.expected.out
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
(x + (x * (x * (x + x))))
|
||||
1 count: 9
|
||||
(1 + ((x * ((2 * x) + (x + x))) + (x * (x + x))))
|
||||
2 count: 10
|
||||
((4 * x) + ((2 * x) + (x + (x + ((2 * x) + (x + x))))))
|
||||
3 count: 1
|
||||
12
|
||||
Loading…
Add table
Reference in a new issue