From 07d1735ea2ddd7cabd179e6f7fb1865d8d77c022 Mon Sep 17 00:00:00 2001 From: Sebastian Ullrich Date: Thu, 5 Aug 2021 09:12:59 +0200 Subject: [PATCH] feat: borrow inference: preserve mutual tail calls Fixes #603 --- src/Lean/Compiler/IR/Borrow.lean | 10 +++++----- tests/lean/603.lean | 10 ++++++++++ tests/lean/603.lean.expected.out | 30 ++++++++++++++++++++++++++++++ 3 files changed, 45 insertions(+), 5 deletions(-) create mode 100644 tests/lean/603.lean create mode 100644 tests/lean/603.lean.expected.out diff --git a/src/Lean/Compiler/IR/Borrow.lean b/src/Lean/Compiler/IR/Borrow.lean index 71c49ebc39..fb41bd53fe 100644 --- a/src/Lean/Compiler/IR/Borrow.lean +++ b/src/Lean/Compiler/IR/Borrow.lean @@ -137,6 +137,7 @@ def applyParamMap (decls : Array Decl) (map : ParamMap) : Array Decl := structure BorrowInfCtx where env : Environment + decls : Array Decl -- block of mutually recursive functions currFn : FunId := arbitrary -- Function being analyzed. paramSet : IndexSet := {} -- Set of all function parameters in scope. This is used to implement the heuristic at `ownArgsUsingParams` @@ -259,8 +260,7 @@ def preserveTailCall (x : VarId) (v : Expr) (b : FnBody) : M Unit := do let ctx ← read match v, b with | (Expr.fap g ys), (FnBody.ret (Arg.var z)) => - if ctx.currFn == g && x == z then - -- dbgTrace ("preserveTailCall " ++ toString b) $ fun _ => do + if ctx.decls.any (·.name == g) && x == z then let ps ← getParamInfo (ParamMap.Key.decl g) ownParamsUsingArgs ys ps | _, _ => pure () @@ -300,13 +300,13 @@ partial def collectDecl : Decl → M Unit else pure () -def collectDecls (decls : Array Decl) : M ParamMap := do - whileModifing (decls.forM collectDecl) +def collectDecls : M ParamMap := do + whileModifing ((← read).decls.forM collectDecl) let s ← get pure s.paramMap def infer (env : Environment) (decls : Array Decl) : ParamMap := - collectDecls decls { env := env } |>.run' { paramMap := mkInitParamMap env decls } + collectDecls { env, decls } |>.run' { paramMap := mkInitParamMap env decls } end Borrow diff --git a/tests/lean/603.lean b/tests/lean/603.lean new file mode 100644 index 0000000000..75b1f4a75e --- /dev/null +++ b/tests/lean/603.lean @@ -0,0 +1,10 @@ +set_option trace.compiler.ir.result true + +-- should be tail calls +mutual + + partial def even (a : Nat) : Nat := if a == 0 then 1 else odd (a - 1) + + partial def odd (a : Nat) : Nat := if a == 0 then 0 else even (a - 1) + +end diff --git a/tests/lean/603.lean.expected.out b/tests/lean/603.lean.expected.out new file mode 100644 index 0000000000..3ab1f82b4a --- /dev/null +++ b/tests/lean/603.lean.expected.out @@ -0,0 +1,30 @@ + +[result] +def even (x_1 : obj) : obj := + let x_2 : obj := 0; + let x_3 : u8 := Nat.decEq x_1 x_2; + case x_3 : u8 of + Bool.false → + let x_4 : obj := 1; + let x_5 : obj := Nat.sub x_1 x_4; + dec x_1; + let x_6 : obj := odd x_5; + ret x_6 + Bool.true → + dec x_1; + let x_7 : obj := 1; + ret x_7 +def odd (x_1 : obj) : obj := + let x_2 : obj := 0; + let x_3 : u8 := Nat.decEq x_1 x_2; + case x_3 : u8 of + Bool.false → + let x_4 : obj := 1; + let x_5 : obj := Nat.sub x_1 x_4; + dec x_1; + let x_6 : obj := even x_5; + ret x_6 + Bool.true → + dec x_1; + let x_7 : obj := 0; + ret x_7 \ No newline at end of file