fix: if inlined code returns a function and has more than one exit point, create an auxiliary function instead of a join point that takes a closure as argument

This commit is contained in:
Leonardo de Moura 2022-10-15 11:44:36 -07:00
parent 53b995386d
commit 4d9483d1fa

View file

@ -136,35 +136,41 @@ partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := d
simp code
else
let code ← simp code
let simpK (result : FVarId) : SimpM Code := do
/- `result` contains the result of the inlined code -/
if numArgs > info.arity then
let decl ← mkAuxLetDecl (mkAppN (.fvar result) info.args[info.arity:])
addFVarSubst fvarId decl.fvarId
simp (.let decl k)
else
addFVarSubst fvarId result
simp k
if oneExitPointQuick code then
-- TODO: if `k` is small, we should also inline it here
markSimplified
code.bind fun fvarId' => do
markUsedFVar fvarId'
/- fvarId' is the result of the computation -/
if numArgs > info.arity then
let decl ← mkAuxLetDecl (mkAppN (.fvar fvarId') info.args[info.arity:])
addFVarSubst fvarId decl.fvarId
simp (.let decl k)
else
addFVarSubst fvarId fvarId'
simp k
simpK fvarId'
-- else if info.ifReduce then
-- eraseCode code
-- return none
else
markSimplified
let jpParam ← mkAuxParam expectedType
let jpValue ← if numArgs > info.arity then
let decl ← mkAuxLetDecl (mkAppN (.fvar jpParam.fvarId) info.args[info.arity:])
addFVarSubst fvarId decl.fvarId
simp (.let decl k)
if expectedType.headBeta.isForall then
/-
If `code` returns a function, we create an auxiliary local function declaration (and eta-expand it)
instead of creating a joinpoint that takes a closure as an argument.
-/
let auxFunDecl ← mkAuxFunDecl #[] code
let auxFunDecl ← auxFunDecl.etaExpand
let k ← simpK auxFunDecl.fvarId
attachCodeDecls #[.fun auxFunDecl] k
else
addFVarSubst fvarId jpParam.fvarId
simp k
let jpDecl ← mkAuxJpDecl #[jpParam] jpValue
let code ← code.bind fun fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId]
return Code.jp jpDecl code
let jpParam ← mkAuxParam expectedType
let jpValue ← simpK jpParam.fvarId
let jpDecl ← mkAuxJpDecl #[jpParam] jpValue
let code ← code.bind fun fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId]
return Code.jp jpDecl code
/--
Simplify the given local function declaration.