From 4d2837430acde4e2e6fa3634da1581126e8eabad Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 21 May 2019 23:07:10 -0700 Subject: [PATCH] fix(library/compiler/emit_cpp): tail call Add temporary hack to fix `emit_tail_call`. TODO: find a cleaner solution for the new IR compiler. --- src/library/compiler/emit_cpp.cpp | 36 ++++++++- tests/compiler/t4.lean | 114 ++++++++++++++++++++++++++++ tests/compiler/t4.lean.expected.out | 7 ++ 3 files changed, 153 insertions(+), 4 deletions(-) create mode 100644 tests/compiler/t4.lean create mode 100644 tests/compiler/t4.lean.expected.out diff --git a/src/library/compiler/emit_cpp.cpp b/src/library/compiler/emit_cpp.cpp index e484363aab..b90ab372e2 100644 --- a/src/library/compiler/emit_cpp.cpp +++ b/src/library/compiler/emit_cpp.cpp @@ -752,17 +752,45 @@ struct emit_fn_fn { return none_expr(); } + bool overwrite_param(buffer const & args) { + lean_assert(args.size() == m_fn_args.size()); + for (unsigned i = 0; i < m_fn_args.size(); i++) { + expr p = m_fn_args[i]; + for (unsigned j = i+1; j < args.size(); j++) { + if (args[j] == p) + return true; + } + } + return false; + } + void emit_tail_call(expr const & e) { buffer args; expr fn = get_app_args(e, args); lean_assert(is_constant(fn) && const_name(fn) == m_fn_name); lean_assert(args.size() == m_fn_args.size()); - for (unsigned i = 0; i < args.size(); i++) { - if (args[i] != m_fn_args[i]) { - emit_fvar(m_fn_args[i]); m_out << " = "; emit_arg(args[i]); m_out << ";\n"; + if (overwrite_param(args)) { + m_out << "{\n"; + for (unsigned i = 0; i < args.size(); i++) { + if (args[i] != m_fn_args[i]) { + m_out << "auto y_" << i << " = "; emit_arg(args[i]); m_out << ";\n"; + } } + for (unsigned i = 0; i < args.size(); i++) { + if (args[i] != m_fn_args[i]) { + emit_fvar(m_fn_args[i]); m_out << " = y_" << i << ";\n"; + } + } + m_out << "}\n"; + m_out << "goto _start;\n"; + } else { + for (unsigned i = 0; i < args.size(); i++) { + if (args[i] != m_fn_args[i]) { + emit_fvar(m_fn_args[i]); m_out << " = "; emit_arg(args[i]); m_out << ";\n"; + } + } + m_out << "goto _start;\n"; } - m_out << "goto _start;\n"; } void emit_terminal(expr const & e, bool tail_call) { diff --git a/tests/compiler/t4.lean b/tests/compiler/t4.lean new file mode 100644 index 0000000000..e79fe5e8a2 --- /dev/null +++ b/tests/compiler/t4.lean @@ -0,0 +1,114 @@ +/- Benchmark for new code generator -/ +inductive Expr +| Val : Int → Expr +| Var : String → Expr +| Add : Expr → Expr → Expr +| Mul : Expr → Expr → Expr +| Pow : Expr → Expr → Expr +| Ln : Expr → Expr + +open Expr + +def Expr.toString : Expr → String +| (Val n) := toString n +| (Var x) := x +| (Add f g) := "(" ++ Expr.toString f ++ " + " ++ Expr.toString g ++ ")" +| (Mul f g) := "(" ++ Expr.toString f ++ " * " ++ Expr.toString g ++ ")" +| (Pow f g) := "(" ++ Expr.toString f ++ " ^ " ++ Expr.toString g ++ ")" +| (Ln f) := "ln(" ++ Expr.toString f ++ ")" + +instance : HasToString Expr := +⟨Expr.toString⟩ + +partial def pown : Int → Int → Int +| a 0 := 1 +| a 1 := a +| a n := + let b := pown a (n / 2) in + b * b * (if n % 2 = 0 then 1 else a) + +partial def addAux : Expr → Expr → Expr +| (Val n) (Val m) := Val (n + m) +| (Val 0) f := f +| f (Val 0) := f +| f (Val n) := addAux (Val n) f +| (Val n) (Add (Val m) f) := addAux (Val (n+m)) f +| f (Add (Val n) g) := addAux (Val n) (addAux f g) +| (Add f g) h := addAux f (addAux g h) +| f g := Add f g + +def add (a b : Expr) : Expr := +-- dbgTrace (">> add (" ++ toString a ++ ", " ++ toString b ++ ")") $ λ _, +addAux a b + +-- set_option trace.compiler.borrowed_inference true + +partial def mulAux : Expr → Expr → Expr +| (Val n) (Val m) := Val (n*m) +| (Val 0) _ := Val 0 +| _ (Val 0) := Val 0 +| (Val 1) f := f +| f (Val 1) := f +| f (Val n) := mulAux (Val n) f +| (Val n) (Mul (Val m) f) := mulAux (Val (n*m)) f +| f (Mul (Val n) g) := mulAux (Val n) (mulAux f g) +| (Mul f g) h := mulAux f (mulAux g h) +| f g := Mul f g + +def mul (a b : Expr) : Expr := +-- dbgTrace (">> mul (" ++ toString a ++ ", " ++ toString b ++ ")") $ λ _, +mulAux a b + +def pow : Expr → Expr → Expr +| (Val m) (Val n) := Val (pown m n) +| _ (Val 0) := Val 1 +| f (Val 1) := f +| (Val 0) _ := Val 0 +| f g := Pow f g + +def ln : Expr → Expr +| (Val 1) := Val 0 +| f := Ln f + +def d (x : String) : Expr → Expr +| (Val _) := Val 0 +| (Var y) := if x = y then Val 1 else Val 0 +| (Add f g) := add (d f) (d g) +| (Mul f g) := + -- dbgTrace (">> d (" ++ toString f ++ ", " ++ toString g ++ ")") $ λ _, + add (mul f (d g)) (mul g (d f)) +| (Pow f g) := mul (pow f g) (add (mul (mul g (d f)) (pow f (Val (-1)))) (mul (ln f) (d g))) +| (Ln f) := mul (d f) (pow f (Val (-1))) + +def count : Expr → Nat +| (Val _) := 1 +| (Var _) := 1 +| (Add f g) := count f + count g +| (Mul f g) := count f + count g +| (Pow f g) := count f + count g +| (Ln f) := count f + + +def nestAux (s : Nat) (f : Nat → Expr → IO Expr) : Nat → Expr → IO Expr +| 0 x := pure x +| m@(n+1) x := f (s - m) x >>= nestAux n + +def nest (f : Nat → Expr → IO Expr) (n : Nat) (e : Expr) : IO Expr := +nestAux n f n e + +def deriv (i : Nat) (f : Expr) : IO Expr := +do + let d := d "x" f, + IO.println (toString (i+1) ++ " count: " ++ (toString $ count d)), + IO.println (toString d), + pure d + +def main (xs : List String) : IO UInt32 := +do let x := Var "x", + let f := add x (mul x (mul x (add x x))), + IO.println f, + nest deriv 3 f, + pure 0 + +-- setOption profiler True +-- #eval main [] diff --git a/tests/compiler/t4.lean.expected.out b/tests/compiler/t4.lean.expected.out new file mode 100644 index 0000000000..4522ca225b --- /dev/null +++ b/tests/compiler/t4.lean.expected.out @@ -0,0 +1,7 @@ +(x + (x * (x * (x + x)))) +1 count: 9 +(1 + ((x * ((2 * x) + (x + x))) + (x * (x + x)))) +2 count: 10 +((4 * x) + ((2 * x) + (x + (x + ((2 * x) + (x + x)))))) +3 count: 1 +12