fix: exponential blowup at LCNF simp

This commit is contained in:
Leonardo de Moura 2022-09-20 17:03:40 -07:00
parent a5ac950b54
commit 727ee79f05
2 changed files with 102 additions and 79 deletions

View file

@ -326,12 +326,14 @@ Result of `inlineCandidate?`.
It contains information for inlining local and global functions.
-/
structure InlineCandidateInfo where
isLocal : Bool
params : Array Param
isLocal : Bool
params : Array Param
/-- Value (lambda expression) of the function to be inlined. -/
value : Code
f : Expr
args : Array Expr
value : Code
f : Expr
args : Array Expr
/-- `ifReduce = true` if the declaration being inlined was tagged with `inlineIfReduce`. -/
ifReduce : Bool
/-- The arity (aka number of parameters) of the function to be inlined. -/
def InlineCandidateInfo.arity : InlineCandidateInfo → Nat
@ -383,15 +385,16 @@ def inlineCandidate? (e : Expr) : SimpM (Option InlineCandidateInfo) := do
if inlineIfReduce then
let some paramIdx := isCasesOnParam? decl | return none
unless paramIdx < numArgs do return none
let arg ← findCtor (e.getArg! paramIdx)
let arg ← findExpr (e.getArg! paramIdx)
unless arg.isConstructorApp (← getEnv) do return none
let params := decl.instantiateParamsLevelParams us
let value := decl.instantiateValueLevelParams us
incInline
return some {
isLocal := false
f := e.getAppFn
args := e.getAppArgs
isLocal := false
f := e.getAppFn
args := e.getAppArgs
ifReduce := inlineIfReduce
params, value
}
else if let some decl ← findFunDecl? f then
@ -401,22 +404,16 @@ def inlineCandidate? (e : Expr) : SimpM (Option InlineCandidateInfo) := do
incInlineLocal
modify fun s => { s with inlineLocal := s.inlineLocal + 1 }
return some {
isLocal := true
f := e.getAppFn
args := e.getAppArgs
params := decl.params
value := decl.value
isLocal := true
f := e.getAppFn
args := e.getAppArgs
params := decl.params
value := decl.value
ifReduce := false
}
else
return none
/--
Add substitution `fvarId ↦ val`. `val` is a free variable, or
it is a type, type former, or `lcErased`.
-/
def addSubst (fvarId : FVarId) (val : Expr) : SimpM Unit :=
modify fun s => { s with subst := s.subst.insert fvarId val }
/--
Return `true` if `c` has only one exit point.
This is a quick approximation. It does not check cases
@ -431,7 +428,7 @@ where
match c with
| .let _ k | .fun _ k => go k
-- Approximation, the cases may have many unreachable alternatives, and only reachable.
| .cases c => c.alts.size == 1 && c.alts.any fun alt => go alt.getCode
| .cases c => c.alts.size == 1 && go c.alts[0]!.getCode
-- Approximation, we assume that any code containing join points have more than one exit point
| .jp .. | .jmp .. => false
| .return .. | .unreach .. => true
@ -468,56 +465,6 @@ def specializePartialApp (info : InlineCandidateInfo) : SimpM FunDecl := do
updateFunDeclInfo code
mkAuxFunDecl paramsNew code
/--
If the value of the given let-declaration is an application that can be inlined, inline it.
`k` is the "continuation" for the let declaration.
-/
partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := do
if k matches .unreach .. then return some k
let some info ← inlineCandidate? letDecl.value | return none
markSimplified
let numArgs := info.args.size
trace[Compiler.simp.inline] "inlining {letDecl.value}"
let fvarId := letDecl.fvarId
if numArgs < info.arity then
let funDecl ← specializePartialApp info
addSubst letDecl.fvarId (.fvar funDecl.fvarId)
return some (.fun funDecl k)
else
let code ← betaReduce info.params info.value info.args[:info.arity]
if k.isReturnOf fvarId && numArgs == info.arity then
/- Easy case, the continuation `k` is just returning the result of the application. -/
return code
else if oneExitPointQuick code then
/-
`code` has only one exit point, thus we can attach the continuation directly there,
and simplify the result.
-/
code.bind fun fvarId' => do
/- fvarId' is the result of the computation -/
if numArgs > info.arity then
let decl ← mkAuxLetDecl (mkAppN (.fvar fvarId') info.args[info.arity:])
let k ← replaceFVar k fvarId decl.fvarId
return .let decl k
else
replaceFVar k fvarId fvarId'
else
/-
`code` has multiple exit points, and the continuation is non-trivial
Thus, we create an auxiliary join point.
-/
let jpParam ← mkAuxParam (← inferType (mkAppN info.f info.args[:info.arity]))
let jpValue ← if numArgs > info.arity then
let decl ← mkAuxLetDecl (mkAppN (.fvar jpParam.fvarId) info.args[info.arity:])
let k ← replaceFVar k fvarId decl.fvarId
pure <| .let decl k
else
replaceFVar k fvarId jpParam.fvarId
let jpDecl ← mkAuxJpDecl #[jpParam] jpValue
let code ← code.bind fun fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId]
return Code.jp jpDecl code
/--
Try to inline a join point.
-/
@ -868,11 +815,83 @@ def etaPolyApp? (letDecl : LetDecl) : OptionT SimpM FunDecl := do
let value := mkAppN letDecl.value (params.map (.fvar ·.fvarId))
let auxDecl ← mkAuxLetDecl value
let funDecl ← mkAuxFunDecl params (.let auxDecl (.return auxDecl.fvarId))
addSubst letDecl.fvarId (.fvar funDecl.fvarId)
addFVarSubst letDecl.fvarId funDecl.fvarId
eraseLetDecl letDecl
return funDecl
/--
Similar to `Code.isReturnOf`, but taking the current substitution into account.
-/
def isReturnOf (c : Code) (fvarId : FVarId) : SimpM Bool := do
match c with
| .return fvarId' => return (← normFVar fvarId') == fvarId
| _ => return false
mutual
/--
If the value of the given let-declaration is an application that can be inlined,
inline it and simplify the result.
`k` is the "continuation" for the let declaration, if the application is inlined,
it will also be simplified.
Note: `inlineApp?` did not use to be in this mutually recursive declaration.
It used to be invoked by `simp`, and would return `Option Code` that would be
then simplified by `simp`. However, this simpler architecture produced an
exponential blowup in when processing functions such as `Lean.Elab.Deriving.Ord.mkMatch.mkAlts`.
The key problem is that when inlining a declaration we often can reduce the number
of exit points by simplified the inlined code, and then connecting the result to the
continuation `k`. However, this optimization is only possible if we simplify the
inlined code **before** we attach it to the continuation.
-/
partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := do
let some info ← inlineCandidate? letDecl.value | return none
let numArgs := info.args.size
trace[Compiler.simp.inline] "inlining {letDecl.value}"
let fvarId := letDecl.fvarId
if numArgs < info.arity then
let funDecl ← specializePartialApp info
addFVarSubst fvarId funDecl.fvarId
markSimplified
simp (.fun funDecl k)
else
let code ← betaReduce info.params info.value info.args[:info.arity]
if k.isReturnOf fvarId && numArgs == info.arity then
/- Easy case, the continuation `k` is just returning the result of the application. -/
markSimplified
simp code
else
let code ← simp code
if oneExitPointQuick code then
-- TODO: if `k` is small, we should also inline it here
markSimplified
code.bind fun fvarId' => do
markUsedFVar fvarId'
/- fvarId' is the result of the computation -/
if numArgs > info.arity then
let decl ← mkAuxLetDecl (mkAppN (.fvar fvarId') info.args[info.arity:])
addFVarSubst fvarId decl.fvarId
simp (.let decl k)
else
addFVarSubst fvarId fvarId'
simp k
-- else if info.ifReduce then
-- eraseCode code
-- return none
else
markSimplified
let jpParam ← mkAuxParam (← inferType (mkAppN info.f info.args[:info.arity]))
let jpValue ← if numArgs > info.arity then
let decl ← mkAuxLetDecl (mkAppN (.fvar jpParam.fvarId) info.args[info.arity:])
addFVarSubst fvarId decl.fvarId
simp (.let decl k)
else
addFVarSubst fvarId jpParam.fvarId
simp k
let jpDecl ← mkAuxJpDecl #[jpParam] jpValue
let code ← code.bind fun fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId]
return Code.jp jpDecl code
/--
Simplify the given local function declaration.
-/
@ -904,11 +923,11 @@ partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do
To make the code robust, we add auxiliary declarations whenever the `field` is not a free variable.
-/
if field.isFVar then
addSubst param.fvarId field
addFVarSubst param.fvarId field.fvarId!
else
let auxDecl ← mkAuxLetDecl field
auxDecls := auxDecls.push (CodeDecl.let auxDecl)
addSubst param.fvarId (.fvar auxDecl.fvarId)
addFVarSubst param.fvarId auxDecl.fvarId
let k ← simp k
eraseParams params
attachCodeDecls auxDecls k
@ -927,14 +946,14 @@ partial def simp (code : Code) : SimpM Code := withIncRecDepth do
simp (.fun funDecl k)
else if decl.value.isFVar then
/- Eliminate `let _x_i := _x_j;` -/
addSubst decl.fvarId decl.value
addFVarSubst decl.fvarId decl.value.fvarId!
eraseLetDecl decl
simp k
else if let some code ← inlineApp? decl k then
eraseLetDecl decl
simp code
return code
else if let some (decls, fvarId) ← inlineProjInst? decl.value then
addSubst decl.fvarId (.fvar fvarId)
addFVarSubst decl.fvarId fvarId
eraseLetDecl decl
let k ← simp k
attachCodeDecls decls k
@ -959,7 +978,7 @@ partial def simp (code : Code) : SimpM Code := withIncRecDepth do
else
/-
Note that functions in `decl` will be marked as used even if `decl` is not actually used.
They will only be deleted in the next pass.
They will only be deleted in the next pass. TODO: investigate whether this is a problem.
-/
if code.isFun then
if decl.isEtaExpandCandidate then

View file

@ -0,0 +1,4 @@
import Lean
set_option trace.Compiler.result true
#eval Lean.Compiler.compile #[``Lean.Elab.Deriving.Ord.mkMatch.mkAlts]