fix(library/init/lean/compiler/ir/emitcpp): tail call
Implement fix used at 4d2837430a in the new IR compiler.
This commit is contained in:
parent
4d2837430a
commit
ef89945ea0
2 changed files with 58 additions and 7 deletions
|
|
@ -700,3 +700,16 @@ protected def max (n m : Nat) : Nat :=
|
|||
if n ≤ m then m else n
|
||||
|
||||
end Nat
|
||||
|
||||
namespace Prod
|
||||
|
||||
@[inline] def foldI {α : Type u} (f : Nat → α → α) (i : Nat × Nat) (a : α) : α :=
|
||||
Nat.foldAux f i.2 (i.2 - i.1) a
|
||||
|
||||
@[inline] def anyI (f : Nat → Bool) (i : Nat × Nat) : Bool :=
|
||||
Nat.anyAux f i.2 (i.2 - i.1)
|
||||
|
||||
@[inline] def allI (f : Nat → Bool) (i : Nat × Nat) : Bool :=
|
||||
!Nat.anyAux (λ a, !f a) i.2 (i.2 - i.1)
|
||||
|
||||
end Prod
|
||||
|
|
|
|||
|
|
@ -530,18 +530,56 @@ match v, b with
|
|||
| Expr.fap f _, FnBody.ret (Arg.var y) := pure $ f == ctx.mainFn && x == y
|
||||
| _, _ := pure false
|
||||
|
||||
def paramEqArg (p : Param) (x : Arg) : Bool :=
|
||||
match x with
|
||||
| Arg.var x := p.x == x
|
||||
| _ := false
|
||||
|
||||
/-
|
||||
Given `[p_0, ..., p_{n-1}]`, `[y_0, ..., y_{n-1}]`, representing the assignments
|
||||
```
|
||||
p_0 := y_0,
|
||||
...
|
||||
p_{n-1} := y_{n-1}
|
||||
```
|
||||
Return true iff we have `(i, j)` where `j > i`, and `y_j == p_i`.
|
||||
That is, we have
|
||||
```
|
||||
p_i := y_i,
|
||||
...
|
||||
p_j := p_i, -- p_i was overwritten above
|
||||
```
|
||||
-/
|
||||
def overwriteParam (ps : Array Param) (ys : Array Arg) : Bool :=
|
||||
let n := ps.size in
|
||||
n.any $ λ i,
|
||||
let p := ps.get i in
|
||||
(i+1, n).anyI $ λ j, paramEqArg p (ys.get j)
|
||||
|
||||
def emitTailCall (v : Expr) : M Unit :=
|
||||
match v with
|
||||
| Expr.fap _ ys := do
|
||||
ctx ← read,
|
||||
let ps := ctx.mainParams,
|
||||
unless (ys.size == ps.size) (throw "invalid tail call"),
|
||||
ys.size.mfor $ λ i, do {
|
||||
let p := ps.get i,
|
||||
let y := ys.get i,
|
||||
match y with
|
||||
| Arg.irrelevant := pure ()
|
||||
| Arg.var y := unless (p.x == y) (do emit p.x, emit " = ", emit y, emitLn ";")
|
||||
unless (ps.size == ys.size) (throw "invalid tail call"),
|
||||
if overwriteParam ps ys then do {
|
||||
emitLn "{",
|
||||
ps.size.mfor $ λ i, do {
|
||||
let p := ps.get i, let y := ys.get i,
|
||||
unless (paramEqArg p y) $ do {
|
||||
emit (toCppType p.ty), emit " _tmp_", emit i, emit " = ", emitArg y, emitLn ";"
|
||||
}
|
||||
},
|
||||
ps.size.mfor $ λ i, do {
|
||||
let p := ps.get i, let y := ys.get i,
|
||||
unless (paramEqArg p y) (do emit p.x, emit " = _tmp_", emit i, emitLn ";")
|
||||
},
|
||||
emitLn "}"
|
||||
} else do {
|
||||
ys.size.mfor $ λ i, do {
|
||||
let p := ps.get i, let y := ys.get i,
|
||||
unless (paramEqArg p y) (do emit p.x, emit " = ", emitArg y, emitLn ";")
|
||||
}
|
||||
},
|
||||
emitLn "goto _start;"
|
||||
| _ := throw "bug at emitTailCall"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue