feat: avoid generation of auxliary join points when inlining functions that have only one exit point

This commit is contained in:
Leonardo de Moura 2022-08-18 00:12:24 -07:00
parent 23be59b747
commit bf2c0bf5b7
2 changed files with 104 additions and 70 deletions

View file

@ -181,23 +181,41 @@ def simpAppApp? (e : Expr) : OptionT SimpM Expr := do
def shouldInline (localDecl : LocalDecl) : SimpM Bool :=
return (← read).localInline && (← read).stats.shouldInline localDecl.userName
def inlineCandidate? (e : Expr) : SimpM (Option Nat) := do
structure InlineCandidateInfo where
isLocal : Bool
arity : Nat
/-- Value (lambda expression) of the function to be inlined. -/
value : Expr
def inlineCandidate? (e : Expr) : SimpM (Option InlineCandidateInfo) := do
let f := e.getAppFn
let arity ← match f with
| .const declName _ =>
unless hasInlineAttribute (← getEnv) declName do return none
-- TODO: check whether function is recursive or not.
-- We can skip the test and store function inline so far.
let some decl ← getStage1Decl? declName | return none
pure decl.getArity
| _ =>
match (← findLambda? f) with
| none => return none
| some localDecl =>
unless (← shouldInline localDecl) do return none
pure (getLambdaArity localDecl.value)
if e.getAppNumArgs < arity then return none
return e.getAppNumArgs - arity
match f with
| .const declName us =>
unless hasInlineAttribute (← getEnv) declName do return none
-- TODO: check whether function is recursive or not.
-- We can skip the test and store function inline so far.
let some decl ← getStage1Decl? declName | return none
let numArgs := e.getAppNumArgs
let arity := decl.getArity
if numArgs < arity then return none
return some {
arity
isLocal := false
value := decl.value.instantiateLevelParams decl.levelParams us
}
| _ =>
match (← findLambda? f) with
| none => return none
| some localDecl =>
unless (← shouldInline localDecl) do return none
let numArgs := e.getAppNumArgs
let arity := getLambdaArity localDecl.value
if numArgs < arity then return none
return some {
arity
isLocal := true
value := localDecl.value
}
/--
If `e` if a free variable that expands to a valid LCNF terminal `let`-block expression `e'`,
@ -255,23 +273,6 @@ partial def inlineProjInst? (e : Expr) : OptionT SimpM Expr := do
saved.restore
return none
partial def inlineApp (e : Expr) (jp? : Option Expr := none) : SimpM Expr := do
let f := e.getAppFn
trace[Compiler.simp.inline] "inlining {e}"
let value ← match f with
| .const declName us =>
let some decl ← getStage1Decl? declName | unreachable!
pure <| decl.value.instantiateLevelParams decl.levelParams us
| _ =>
let some localDecl ← findLambda? f | unreachable!
pure localDecl.value
let args := e.getAppArgs
let value := value.beta args
let value ← attachOptJp value jp?
assert! !value.isLambda
markSimplified
withLocalInline (!f.isConst) do visitLet value
/--
If `e` is an application that can be inlined, inline it.
@ -280,45 +281,62 @@ that need to instantiated with `xs`. That is, if `k? = some k`, then `k.instanti
is an expression without loose bound variables.
-/
partial def inlineApp? (e : Expr) (xs : Array Expr) (k? : Option Expr) : SimpM (Option Expr) := do
let some numExtraArgs ← inlineCandidate? e | return none
let some info ← inlineCandidate? e | return none
let args := e.getAppArgs
if k?.isNone && numExtraArgs == 0 then
-- Easy case, there is not continuation and `e` is not over applied
inlineApp e
let numArgs := args.size
trace[Compiler.simp.inline] "inlining {e}"
markSimplified
withLocalInline info.isLocal do
if !(← manyExitPoints info.value) then
-- If `info.value` has only one exit point, we don't need to create a new join point
let value := info.value.beta args[:info.arity]
let value ← visitLet value #[]
trace[Meta.debug] "value: {value}"
match numArgs == info.arity, k? with
| true, none => return value
| false, none => return mkAppN (← mkAuxLetDecl value) args[info.arity:]
| true, some k => let x ← mkAuxLetDecl value; visitLet k (xs.push x)
| false, some k =>
let x ← mkAuxLetDecl value
let x ← mkAuxLetDecl (mkAppN x args[info.arity:])
visitLet k (xs.push x)
else
/-
There is a continuation `k` or `e` is over applied.
If `e` is over applied, the extra arguments act as continuation.
-/
let toInline := mkAppN e.getAppFn args[:args.size - numExtraArgs]
/-
`toInline` is the application that is going to be inline
We create a new join point
```
let jp := fun y =>
let x := y <extra-arguments> -- if `e` is over applied
k
```
Recall that `visitLet` incorporates the current continuation
to the new join point `jp`.
-/
let jpDomain ← inferType toInline
let binderName ← mkFreshUserName `_y
let jp ← withNewScope do
let y ← mkLocalDecl binderName jpDomain
let body ← if numExtraArgs == 0 then
visitLet k?.get! (xs.push y)
else
let x ← mkAuxLetDecl (mkAppN y args[args.size - numExtraArgs:])
if let some k := k? then
visitLet k (xs.push x)
let args := e.getAppArgs
if k?.isNone && numArgs == info.arity then
-- Easy case, there is no continuation and `e` is not overapplied
return info.value.beta args
else
/-
There is a continuation `k` or `e` is over applied.
If `e` is over applied, the extra arguments act as a continuation.
We create a new join point
```
let jp := fun y =>
let x := y <extra-arguments> -- if `e` is over applied
k
```
Recall that `visitLet` incorporates the current continuation
to the new join point `jp`.
-/
let jpDomain ← inferType (mkAppN e.getAppFn args[:info.arity])
let binderName ← mkFreshUserName `_y
let jp ← withNewScope do
let y ← mkLocalDecl binderName jpDomain
let body ← if numArgs == info.arity then
visitLet k?.get! (xs.push y)
else
visitLet x (xs.push x)
let body ← mkLetUsingScope body
mkLambda #[y] body
let jp ← mkJpDeclIfNotSimple jp
/- Inline `toInline` and "go-to" `jp` with the result. -/
inlineApp toInline jp
let x ← mkAuxLetDecl (mkAppN y args[info.arity:])
if let some k := k? then
visitLet k (xs.push x)
else
visitLet x (xs.push x)
let body ← mkLetUsingScope body
mkLambda #[y] body
let jp ← mkJpDeclIfNotSimple jp
let value := info.value.beta args[:info.arity]
let value ← attachJp value jp
visitLet value
/-- Try to apply simple simplifications. -/
partial def simpValue? (e : Expr) : SimpM (Option Expr) :=
@ -365,6 +383,7 @@ def Decl.simp? (decl : Decl) : CoreM (Option Decl) := do
trace[Compiler.simp.inline.stats] "{decl.name}:{Format.nest 2 (format stats)}"
let (value, s) ← Simp.visitLambda decl.value |>.run { stats } |>.run { simplified := false } |>.run' {}
trace[Compiler.simp.step] "{decl.name} :=\n{decl.value}"
trace[Compiler.simp.stat] "{decl.name}: {← getLCNFSize decl.value}"
if s.simplified then
return some { decl with value }
else
@ -379,7 +398,8 @@ partial def Decl.simp (decl : Decl) : CoreM Decl := do
builtin_initialize
registerTraceClass `Compiler.simp.inline
registerTraceClass `Compiler.simp.stat
registerTraceClass `Compiler.simp.step
registerTraceClass `Compiler.simp.inline.stats
end Lean.Compiler
end Lean.Compiler

View file

@ -175,4 +175,18 @@ def isJump? [Monad m] [MonadLCtx m] (e : Expr) : m (Option FVarId) := do
else
return none
/--
Return if the LCNF expression has many exit points.
It assumes `cases` expressions only occur at the end of `let`-blocks.
That is, `terminalCases` has already been applied.
It also assumes that if contains a join point, then it has multiple
exit points. This is a reasonable assumption because the simplifier
inlines any join point that was used only once.
-/
def manyExitPoints (e : Expr) : CoreM Bool := do
match e with
| .lam _ _ b _ => manyExitPoints b
| .letE n _ _ b _ => pure (isJpBinderName n) <||> manyExitPoints b
| e => return (← isCasesApp? e).isSome
end Lean.Compiler