From ef89945ea0fd0ac2165cdbaa46c8d420d51e23b5 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 22 May 2019 07:57:24 -0700 Subject: [PATCH] fix(library/init/lean/compiler/ir/emitcpp): tail call Implement fix used at 4d2837430acd in the new IR compiler. --- library/init/data/nat/basic.lean | 13 ++++++ library/init/lean/compiler/ir/emitcpp.lean | 52 +++++++++++++++++++--- 2 files changed, 58 insertions(+), 7 deletions(-) diff --git a/library/init/data/nat/basic.lean b/library/init/data/nat/basic.lean index 19e2475f4c..df7c957fc0 100644 --- a/library/init/data/nat/basic.lean +++ b/library/init/data/nat/basic.lean @@ -700,3 +700,16 @@ protected def max (n m : Nat) : Nat := if n ≤ m then m else n end Nat + +namespace Prod + +@[inline] def foldI {α : Type u} (f : Nat → α → α) (i : Nat × Nat) (a : α) : α := +Nat.foldAux f i.2 (i.2 - i.1) a + +@[inline] def anyI (f : Nat → Bool) (i : Nat × Nat) : Bool := +Nat.anyAux f i.2 (i.2 - i.1) + +@[inline] def allI (f : Nat → Bool) (i : Nat × Nat) : Bool := +!Nat.anyAux (λ a, !f a) i.2 (i.2 - i.1) + +end Prod diff --git a/library/init/lean/compiler/ir/emitcpp.lean b/library/init/lean/compiler/ir/emitcpp.lean index 64ea852dcd..1a6ec0fedf 100644 --- a/library/init/lean/compiler/ir/emitcpp.lean +++ b/library/init/lean/compiler/ir/emitcpp.lean @@ -530,18 +530,56 @@ match v, b with | Expr.fap f _, FnBody.ret (Arg.var y) := pure $ f == ctx.mainFn && x == y | _, _ := pure false +def paramEqArg (p : Param) (x : Arg) : Bool := +match x with +| Arg.var x := p.x == x +| _ := false + +/- +Given `[p_0, ..., p_{n-1}]`, `[y_0, ..., y_{n-1}]`, representing the assignments +``` +p_0 := y_0, +... +p_{n-1} := y_{n-1} +``` +Return true iff we have `(i, j)` where `j > i`, and `y_j == p_i`. +That is, we have +``` + p_i := y_i, + ... + p_j := p_i, -- p_i was overwritten above +``` +-/ +def overwriteParam (ps : Array Param) (ys : Array Arg) : Bool := +let n := ps.size in +n.any $ λ i, + let p := ps.get i in + (i+1, n).anyI $ λ j, paramEqArg p (ys.get j) + def emitTailCall (v : Expr) : M Unit := match v with | Expr.fap _ ys := do ctx ← read, let ps := ctx.mainParams, - unless (ys.size == ps.size) (throw "invalid tail call"), - ys.size.mfor $ λ i, do { - let p := ps.get i, - let y := ys.get i, - match y with - | Arg.irrelevant := pure () - | Arg.var y := unless (p.x == y) (do emit p.x, emit " = ", emit y, emitLn ";") + unless (ps.size == ys.size) (throw "invalid tail call"), + if overwriteParam ps ys then do { + emitLn "{", + ps.size.mfor $ λ i, do { + let p := ps.get i, let y := ys.get i, + unless (paramEqArg p y) $ do { + emit (toCppType p.ty), emit " _tmp_", emit i, emit " = ", emitArg y, emitLn ";" + } + }, + ps.size.mfor $ λ i, do { + let p := ps.get i, let y := ys.get i, + unless (paramEqArg p y) (do emit p.x, emit " = _tmp_", emit i, emitLn ";") + }, + emitLn "}" + } else do { + ys.size.mfor $ λ i, do { + let p := ps.get i, let y := ys.get i, + unless (paramEqArg p y) (do emit p.x, emit " = ", emitArg y, emitLn ";") + } }, emitLn "goto _start;" | _ := throw "bug at emitTailCall"