feat: borrow inference: preserve mutual tail calls

Fixes #603
This commit is contained in:
Sebastian Ullrich 2021-08-05 09:12:59 +02:00 committed by Leonardo de Moura
parent 4cdfbde93b
commit 07d1735ea2
3 changed files with 45 additions and 5 deletions

View file

@ -137,6 +137,7 @@ def applyParamMap (decls : Array Decl) (map : ParamMap) : Array Decl :=
structure BorrowInfCtx where
env : Environment
decls : Array Decl -- block of mutually recursive functions
currFn : FunId := arbitrary -- Function being analyzed.
paramSet : IndexSet := {} -- Set of all function parameters in scope. This is used to implement the heuristic at `ownArgsUsingParams`
@ -259,8 +260,7 @@ def preserveTailCall (x : VarId) (v : Expr) (b : FnBody) : M Unit := do
let ctx ← read
match v, b with
| (Expr.fap g ys), (FnBody.ret (Arg.var z)) =>
if ctx.currFn == g && x == z then
-- dbgTrace ("preserveTailCall " ++ toString b) $ fun _ => do
if ctx.decls.any (·.name == g) && x == z then
let ps ← getParamInfo (ParamMap.Key.decl g)
ownParamsUsingArgs ys ps
| _, _ => pure ()
@ -300,13 +300,13 @@ partial def collectDecl : Decl → M Unit
else
pure ()
def collectDecls (decls : Array Decl) : M ParamMap := do
whileModifing (decls.forM collectDecl)
def collectDecls : M ParamMap := do
whileModifing ((← read).decls.forM collectDecl)
let s ← get
pure s.paramMap
def infer (env : Environment) (decls : Array Decl) : ParamMap :=
collectDecls decls { env := env } |>.run' { paramMap := mkInitParamMap env decls }
collectDecls { env, decls } |>.run' { paramMap := mkInitParamMap env decls }
end Borrow

10
tests/lean/603.lean Normal file
View file

@ -0,0 +1,10 @@
set_option trace.compiler.ir.result true
-- should be tail calls
mutual
partial def even (a : Nat) : Nat := if a == 0 then 1 else odd (a - 1)
partial def odd (a : Nat) : Nat := if a == 0 then 0 else even (a - 1)
end

View file

@ -0,0 +1,30 @@
[result]
def even (x_1 : obj) : obj :=
let x_2 : obj := 0;
let x_3 : u8 := Nat.decEq x_1 x_2;
case x_3 : u8 of
Bool.false →
let x_4 : obj := 1;
let x_5 : obj := Nat.sub x_1 x_4;
dec x_1;
let x_6 : obj := odd x_5;
ret x_6
Bool.true →
dec x_1;
let x_7 : obj := 1;
ret x_7
def odd (x_1 : obj) : obj :=
let x_2 : obj := 0;
let x_3 : u8 := Nat.decEq x_1 x_2;
case x_3 : u8 of
Bool.false →
let x_4 : obj := 1;
let x_5 : obj := Nat.sub x_1 x_4;
dec x_1;
let x_6 : obj := even x_5;
ret x_6
Bool.true →
dec x_1;
let x_7 : obj := 0;
ret x_7