fix: exponential blowup at LCNF simp
This commit is contained in:
parent
a5ac950b54
commit
727ee79f05
2 changed files with 102 additions and 79 deletions
|
|
@ -326,12 +326,14 @@ Result of `inlineCandidate?`.
|
|||
It contains information for inlining local and global functions.
|
||||
-/
|
||||
structure InlineCandidateInfo where
|
||||
isLocal : Bool
|
||||
params : Array Param
|
||||
isLocal : Bool
|
||||
params : Array Param
|
||||
/-- Value (lambda expression) of the function to be inlined. -/
|
||||
value : Code
|
||||
f : Expr
|
||||
args : Array Expr
|
||||
value : Code
|
||||
f : Expr
|
||||
args : Array Expr
|
||||
/-- `ifReduce = true` if the declaration being inlined was tagged with `inlineIfReduce`. -/
|
||||
ifReduce : Bool
|
||||
|
||||
/-- The arity (aka number of parameters) of the function to be inlined. -/
|
||||
def InlineCandidateInfo.arity : InlineCandidateInfo → Nat
|
||||
|
|
@ -383,15 +385,16 @@ def inlineCandidate? (e : Expr) : SimpM (Option InlineCandidateInfo) := do
|
|||
if inlineIfReduce then
|
||||
let some paramIdx := isCasesOnParam? decl | return none
|
||||
unless paramIdx < numArgs do return none
|
||||
let arg ← findCtor (e.getArg! paramIdx)
|
||||
let arg ← findExpr (e.getArg! paramIdx)
|
||||
unless arg.isConstructorApp (← getEnv) do return none
|
||||
let params := decl.instantiateParamsLevelParams us
|
||||
let value := decl.instantiateValueLevelParams us
|
||||
incInline
|
||||
return some {
|
||||
isLocal := false
|
||||
f := e.getAppFn
|
||||
args := e.getAppArgs
|
||||
isLocal := false
|
||||
f := e.getAppFn
|
||||
args := e.getAppArgs
|
||||
ifReduce := inlineIfReduce
|
||||
params, value
|
||||
}
|
||||
else if let some decl ← findFunDecl? f then
|
||||
|
|
@ -401,22 +404,16 @@ def inlineCandidate? (e : Expr) : SimpM (Option InlineCandidateInfo) := do
|
|||
incInlineLocal
|
||||
modify fun s => { s with inlineLocal := s.inlineLocal + 1 }
|
||||
return some {
|
||||
isLocal := true
|
||||
f := e.getAppFn
|
||||
args := e.getAppArgs
|
||||
params := decl.params
|
||||
value := decl.value
|
||||
isLocal := true
|
||||
f := e.getAppFn
|
||||
args := e.getAppArgs
|
||||
params := decl.params
|
||||
value := decl.value
|
||||
ifReduce := false
|
||||
}
|
||||
else
|
||||
return none
|
||||
|
||||
/--
|
||||
Add substitution `fvarId ↦ val`. `val` is a free variable, or
|
||||
it is a type, type former, or `lcErased`.
|
||||
-/
|
||||
def addSubst (fvarId : FVarId) (val : Expr) : SimpM Unit :=
|
||||
modify fun s => { s with subst := s.subst.insert fvarId val }
|
||||
|
||||
/--
|
||||
Return `true` if `c` has only one exit point.
|
||||
This is a quick approximation. It does not check cases
|
||||
|
|
@ -431,7 +428,7 @@ where
|
|||
match c with
|
||||
| .let _ k | .fun _ k => go k
|
||||
-- Approximation, the cases may have many unreachable alternatives, and only reachable.
|
||||
| .cases c => c.alts.size == 1 && c.alts.any fun alt => go alt.getCode
|
||||
| .cases c => c.alts.size == 1 && go c.alts[0]!.getCode
|
||||
-- Approximation, we assume that any code containing join points have more than one exit point
|
||||
| .jp .. | .jmp .. => false
|
||||
| .return .. | .unreach .. => true
|
||||
|
|
@ -468,56 +465,6 @@ def specializePartialApp (info : InlineCandidateInfo) : SimpM FunDecl := do
|
|||
updateFunDeclInfo code
|
||||
mkAuxFunDecl paramsNew code
|
||||
|
||||
/--
|
||||
If the value of the given let-declaration is an application that can be inlined, inline it.
|
||||
|
||||
`k` is the "continuation" for the let declaration.
|
||||
-/
|
||||
partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := do
|
||||
if k matches .unreach .. then return some k
|
||||
let some info ← inlineCandidate? letDecl.value | return none
|
||||
markSimplified
|
||||
let numArgs := info.args.size
|
||||
trace[Compiler.simp.inline] "inlining {letDecl.value}"
|
||||
let fvarId := letDecl.fvarId
|
||||
if numArgs < info.arity then
|
||||
let funDecl ← specializePartialApp info
|
||||
addSubst letDecl.fvarId (.fvar funDecl.fvarId)
|
||||
return some (.fun funDecl k)
|
||||
else
|
||||
let code ← betaReduce info.params info.value info.args[:info.arity]
|
||||
if k.isReturnOf fvarId && numArgs == info.arity then
|
||||
/- Easy case, the continuation `k` is just returning the result of the application. -/
|
||||
return code
|
||||
else if oneExitPointQuick code then
|
||||
/-
|
||||
`code` has only one exit point, thus we can attach the continuation directly there,
|
||||
and simplify the result.
|
||||
-/
|
||||
code.bind fun fvarId' => do
|
||||
/- fvarId' is the result of the computation -/
|
||||
if numArgs > info.arity then
|
||||
let decl ← mkAuxLetDecl (mkAppN (.fvar fvarId') info.args[info.arity:])
|
||||
let k ← replaceFVar k fvarId decl.fvarId
|
||||
return .let decl k
|
||||
else
|
||||
replaceFVar k fvarId fvarId'
|
||||
else
|
||||
/-
|
||||
`code` has multiple exit points, and the continuation is non-trivial
|
||||
Thus, we create an auxiliary join point.
|
||||
-/
|
||||
let jpParam ← mkAuxParam (← inferType (mkAppN info.f info.args[:info.arity]))
|
||||
let jpValue ← if numArgs > info.arity then
|
||||
let decl ← mkAuxLetDecl (mkAppN (.fvar jpParam.fvarId) info.args[info.arity:])
|
||||
let k ← replaceFVar k fvarId decl.fvarId
|
||||
pure <| .let decl k
|
||||
else
|
||||
replaceFVar k fvarId jpParam.fvarId
|
||||
let jpDecl ← mkAuxJpDecl #[jpParam] jpValue
|
||||
let code ← code.bind fun fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId]
|
||||
return Code.jp jpDecl code
|
||||
|
||||
/--
|
||||
Try to inline a join point.
|
||||
-/
|
||||
|
|
@ -868,11 +815,83 @@ def etaPolyApp? (letDecl : LetDecl) : OptionT SimpM FunDecl := do
|
|||
let value := mkAppN letDecl.value (params.map (.fvar ·.fvarId))
|
||||
let auxDecl ← mkAuxLetDecl value
|
||||
let funDecl ← mkAuxFunDecl params (.let auxDecl (.return auxDecl.fvarId))
|
||||
addSubst letDecl.fvarId (.fvar funDecl.fvarId)
|
||||
addFVarSubst letDecl.fvarId funDecl.fvarId
|
||||
eraseLetDecl letDecl
|
||||
return funDecl
|
||||
|
||||
/--
|
||||
Similar to `Code.isReturnOf`, but taking the current substitution into account.
|
||||
-/
|
||||
def isReturnOf (c : Code) (fvarId : FVarId) : SimpM Bool := do
|
||||
match c with
|
||||
| .return fvarId' => return (← normFVar fvarId') == fvarId
|
||||
| _ => return false
|
||||
|
||||
mutual
|
||||
/--
|
||||
If the value of the given let-declaration is an application that can be inlined,
|
||||
inline it and simplify the result.
|
||||
|
||||
`k` is the "continuation" for the let declaration, if the application is inlined,
|
||||
it will also be simplified.
|
||||
|
||||
Note: `inlineApp?` did not use to be in this mutually recursive declaration.
|
||||
It used to be invoked by `simp`, and would return `Option Code` that would be
|
||||
then simplified by `simp`. However, this simpler architecture produced an
|
||||
exponential blowup in when processing functions such as `Lean.Elab.Deriving.Ord.mkMatch.mkAlts`.
|
||||
The key problem is that when inlining a declaration we often can reduce the number
|
||||
of exit points by simplified the inlined code, and then connecting the result to the
|
||||
continuation `k`. However, this optimization is only possible if we simplify the
|
||||
inlined code **before** we attach it to the continuation.
|
||||
-/
|
||||
partial def inlineApp? (letDecl : LetDecl) (k : Code) : SimpM (Option Code) := do
|
||||
let some info ← inlineCandidate? letDecl.value | return none
|
||||
let numArgs := info.args.size
|
||||
trace[Compiler.simp.inline] "inlining {letDecl.value}"
|
||||
let fvarId := letDecl.fvarId
|
||||
if numArgs < info.arity then
|
||||
let funDecl ← specializePartialApp info
|
||||
addFVarSubst fvarId funDecl.fvarId
|
||||
markSimplified
|
||||
simp (.fun funDecl k)
|
||||
else
|
||||
let code ← betaReduce info.params info.value info.args[:info.arity]
|
||||
if k.isReturnOf fvarId && numArgs == info.arity then
|
||||
/- Easy case, the continuation `k` is just returning the result of the application. -/
|
||||
markSimplified
|
||||
simp code
|
||||
else
|
||||
let code ← simp code
|
||||
if oneExitPointQuick code then
|
||||
-- TODO: if `k` is small, we should also inline it here
|
||||
markSimplified
|
||||
code.bind fun fvarId' => do
|
||||
markUsedFVar fvarId'
|
||||
/- fvarId' is the result of the computation -/
|
||||
if numArgs > info.arity then
|
||||
let decl ← mkAuxLetDecl (mkAppN (.fvar fvarId') info.args[info.arity:])
|
||||
addFVarSubst fvarId decl.fvarId
|
||||
simp (.let decl k)
|
||||
else
|
||||
addFVarSubst fvarId fvarId'
|
||||
simp k
|
||||
-- else if info.ifReduce then
|
||||
-- eraseCode code
|
||||
-- return none
|
||||
else
|
||||
markSimplified
|
||||
let jpParam ← mkAuxParam (← inferType (mkAppN info.f info.args[:info.arity]))
|
||||
let jpValue ← if numArgs > info.arity then
|
||||
let decl ← mkAuxLetDecl (mkAppN (.fvar jpParam.fvarId) info.args[info.arity:])
|
||||
addFVarSubst fvarId decl.fvarId
|
||||
simp (.let decl k)
|
||||
else
|
||||
addFVarSubst fvarId jpParam.fvarId
|
||||
simp k
|
||||
let jpDecl ← mkAuxJpDecl #[jpParam] jpValue
|
||||
let code ← code.bind fun fvarId => return .jmp jpDecl.fvarId #[.fvar fvarId]
|
||||
return Code.jp jpDecl code
|
||||
|
||||
/--
|
||||
Simplify the given local function declaration.
|
||||
-/
|
||||
|
|
@ -904,11 +923,11 @@ partial def simpCasesOnCtor? (cases : Cases) : SimpM (Option Code) := do
|
|||
To make the code robust, we add auxiliary declarations whenever the `field` is not a free variable.
|
||||
-/
|
||||
if field.isFVar then
|
||||
addSubst param.fvarId field
|
||||
addFVarSubst param.fvarId field.fvarId!
|
||||
else
|
||||
let auxDecl ← mkAuxLetDecl field
|
||||
auxDecls := auxDecls.push (CodeDecl.let auxDecl)
|
||||
addSubst param.fvarId (.fvar auxDecl.fvarId)
|
||||
addFVarSubst param.fvarId auxDecl.fvarId
|
||||
let k ← simp k
|
||||
eraseParams params
|
||||
attachCodeDecls auxDecls k
|
||||
|
|
@ -927,14 +946,14 @@ partial def simp (code : Code) : SimpM Code := withIncRecDepth do
|
|||
simp (.fun funDecl k)
|
||||
else if decl.value.isFVar then
|
||||
/- Eliminate `let _x_i := _x_j;` -/
|
||||
addSubst decl.fvarId decl.value
|
||||
addFVarSubst decl.fvarId decl.value.fvarId!
|
||||
eraseLetDecl decl
|
||||
simp k
|
||||
else if let some code ← inlineApp? decl k then
|
||||
eraseLetDecl decl
|
||||
simp code
|
||||
return code
|
||||
else if let some (decls, fvarId) ← inlineProjInst? decl.value then
|
||||
addSubst decl.fvarId (.fvar fvarId)
|
||||
addFVarSubst decl.fvarId fvarId
|
||||
eraseLetDecl decl
|
||||
let k ← simp k
|
||||
attachCodeDecls decls k
|
||||
|
|
@ -959,7 +978,7 @@ partial def simp (code : Code) : SimpM Code := withIncRecDepth do
|
|||
else
|
||||
/-
|
||||
Note that functions in `decl` will be marked as used even if `decl` is not actually used.
|
||||
They will only be deleted in the next pass.
|
||||
They will only be deleted in the next pass. TODO: investigate whether this is a problem.
|
||||
-/
|
||||
if code.isFun then
|
||||
if decl.isEtaExpandCandidate then
|
||||
|
|
|
|||
4
tests/lean/run/simpExpBlowup.lean
Normal file
4
tests/lean/run/simpExpBlowup.lean
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
import Lean
|
||||
|
||||
set_option trace.Compiler.result true
|
||||
#eval Lean.Compiler.compile #[``Lean.Elab.Deriving.Ord.mkMatch.mkAlts]
|
||||
Loading…
Add table
Reference in a new issue