fix(library/init/lean/compiler/ir/emitcpp): tail call

Implement fix used at 4d2837430a in the new IR compiler.
This commit is contained in:
Leonardo de Moura 2019-05-22 07:57:24 -07:00
parent 4d2837430a
commit ef89945ea0
2 changed files with 58 additions and 7 deletions

View file

@ -700,3 +700,16 @@ protected def max (n m : Nat) : Nat :=
if n ≤ m then m else n
end Nat
namespace Prod
@[inline] def foldI {α : Type u} (f : Nat → αα) (i : Nat × Nat) (a : α) : α :=
Nat.foldAux f i.2 (i.2 - i.1) a
@[inline] def anyI (f : Nat → Bool) (i : Nat × Nat) : Bool :=
Nat.anyAux f i.2 (i.2 - i.1)
@[inline] def allI (f : Nat → Bool) (i : Nat × Nat) : Bool :=
!Nat.anyAux (λ a, !f a) i.2 (i.2 - i.1)
end Prod

View file

@ -530,18 +530,56 @@ match v, b with
| Expr.fap f _, FnBody.ret (Arg.var y) := pure $ f == ctx.mainFn && x == y
| _, _ := pure false
def paramEqArg (p : Param) (x : Arg) : Bool :=
match x with
| Arg.var x := p.x == x
| _ := false
/-
Given `[p_0, ..., p_{n-1}]`, `[y_0, ..., y_{n-1}]`, representing the assignments
```
p_0 := y_0,
...
p_{n-1} := y_{n-1}
```
Return true iff we have `(i, j)` where `j > i`, and `y_j == p_i`.
That is, we have
```
p_i := y_i,
...
p_j := p_i, -- p_i was overwritten above
```
-/
def overwriteParam (ps : Array Param) (ys : Array Arg) : Bool :=
let n := ps.size in
n.any $ λ i,
let p := ps.get i in
(i+1, n).anyI $ λ j, paramEqArg p (ys.get j)
def emitTailCall (v : Expr) : M Unit :=
match v with
| Expr.fap _ ys := do
ctx ← read,
let ps := ctx.mainParams,
unless (ys.size == ps.size) (throw "invalid tail call"),
ys.size.mfor $ λ i, do {
let p := ps.get i,
let y := ys.get i,
match y with
| Arg.irrelevant := pure ()
| Arg.var y := unless (p.x == y) (do emit p.x, emit " = ", emit y, emitLn ";")
unless (ps.size == ys.size) (throw "invalid tail call"),
if overwriteParam ps ys then do {
emitLn "{",
ps.size.mfor $ λ i, do {
let p := ps.get i, let y := ys.get i,
unless (paramEqArg p y) $ do {
emit (toCppType p.ty), emit " _tmp_", emit i, emit " = ", emitArg y, emitLn ";"
}
},
ps.size.mfor $ λ i, do {
let p := ps.get i, let y := ys.get i,
unless (paramEqArg p y) (do emit p.x, emit " = _tmp_", emit i, emitLn ";")
},
emitLn "}"
} else do {
ys.size.mfor $ λ i, do {
let p := ps.get i, let y := ys.get i,
unless (paramEqArg p y) (do emit p.x, emit " = ", emitArg y, emitLn ";")
}
},
emitLn "goto _start;"
| _ := throw "bug at emitTailCall"