fix: if inlined code returns a function and has more than one exit point, create an auxiliary function instead of a join point that takes a closure as argument
This commit is contained in:
parent
53b995386d
commit
4d9483d1fa
1 changed files with 24 additions and 18 deletions
|
|
@ -136,35 +136,41 @@ partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := d
|
|||
simp code
|
||||
else
|
||||
let code ← simp code
|
||||
let simpK (result : FVarId) : SimpM Code := do
|
||||
/- `result` contains the result of the inlined code -/
|
||||
if numArgs > info.arity then
|
||||
let decl ← mkAuxLetDecl (mkAppN (.fvar result) info.args[info.arity:])
|
||||
addFVarSubst fvarId decl.fvarId
|
||||
simp (.let decl k)
|
||||
else
|
||||
addFVarSubst fvarId result
|
||||
simp k
|
||||
if oneExitPointQuick code then
|
||||
-- TODO: if `k` is small, we should also inline it here
|
||||
markSimplified
|
||||
code.bind fun fvarId' => do
|
||||
markUsedFVar fvarId'
|
||||
/- fvarId' is the result of the computation -/
|
||||
if numArgs > info.arity then
|
||||
let decl ← mkAuxLetDecl (mkAppN (.fvar fvarId') info.args[info.arity:])
|
||||
addFVarSubst fvarId decl.fvarId
|
||||
simp (.let decl k)
|
||||
else
|
||||
addFVarSubst fvarId fvarId'
|
||||
simp k
|
||||
simpK fvarId'
|
||||
-- else if info.ifReduce then
|
||||
-- eraseCode code
|
||||
-- return none
|
||||
else
|
||||
markSimplified
|
||||
let jpParam ← mkAuxParam expectedType
|
||||
let jpValue ← if numArgs > info.arity then
|
||||
let decl ← mkAuxLetDecl (mkAppN (.fvar jpParam.fvarId) info.args[info.arity:])
|
||||
addFVarSubst fvarId decl.fvarId
|
||||
simp (.let decl k)
|
||||
if expectedType.headBeta.isForall then
|
||||
/-
|
||||
If `code` returns a function, we create an auxiliary local function declaration (and eta-expand it)
|
||||
instead of creating a joinpoint that takes a closure as an argument.
|
||||
-/
|
||||
let auxFunDecl ← mkAuxFunDecl #[] code
|
||||
let auxFunDecl ← auxFunDecl.etaExpand
|
||||
let k ← simpK auxFunDecl.fvarId
|
||||
attachCodeDecls #[.fun auxFunDecl] k
|
||||
else
|
||||
addFVarSubst fvarId jpParam.fvarId
|
||||
simp k
|
||||
let jpDecl ← mkAuxJpDecl #[jpParam] jpValue
|
||||
let code ← code.bind fun fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId]
|
||||
return Code.jp jpDecl code
|
||||
let jpParam ← mkAuxParam expectedType
|
||||
let jpValue ← simpK jpParam.fvarId
|
||||
let jpDecl ← mkAuxJpDecl #[jpParam] jpValue
|
||||
let code ← code.bind fun fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId]
|
||||
return Code.jp jpDecl code
|
||||
|
||||
/--
|
||||
Simplify the given local function declaration.
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue