feat: avoid generation of auxliary join points when inlining functions that have only one exit point
This commit is contained in:
parent
23be59b747
commit
bf2c0bf5b7
2 changed files with 104 additions and 70 deletions
|
|
@ -181,23 +181,41 @@ def simpAppApp? (e : Expr) : OptionT SimpM Expr := do
|
|||
def shouldInline (localDecl : LocalDecl) : SimpM Bool :=
|
||||
return (← read).localInline && (← read).stats.shouldInline localDecl.userName
|
||||
|
||||
def inlineCandidate? (e : Expr) : SimpM (Option Nat) := do
|
||||
structure InlineCandidateInfo where
|
||||
isLocal : Bool
|
||||
arity : Nat
|
||||
/-- Value (lambda expression) of the function to be inlined. -/
|
||||
value : Expr
|
||||
|
||||
def inlineCandidate? (e : Expr) : SimpM (Option InlineCandidateInfo) := do
|
||||
let f := e.getAppFn
|
||||
let arity ← match f with
|
||||
| .const declName _ =>
|
||||
unless hasInlineAttribute (← getEnv) declName do return none
|
||||
-- TODO: check whether function is recursive or not.
|
||||
-- We can skip the test and store function inline so far.
|
||||
let some decl ← getStage1Decl? declName | return none
|
||||
pure decl.getArity
|
||||
| _ =>
|
||||
match (← findLambda? f) with
|
||||
| none => return none
|
||||
| some localDecl =>
|
||||
unless (← shouldInline localDecl) do return none
|
||||
pure (getLambdaArity localDecl.value)
|
||||
if e.getAppNumArgs < arity then return none
|
||||
return e.getAppNumArgs - arity
|
||||
match f with
|
||||
| .const declName us =>
|
||||
unless hasInlineAttribute (← getEnv) declName do return none
|
||||
-- TODO: check whether function is recursive or not.
|
||||
-- We can skip the test and store function inline so far.
|
||||
let some decl ← getStage1Decl? declName | return none
|
||||
let numArgs := e.getAppNumArgs
|
||||
let arity := decl.getArity
|
||||
if numArgs < arity then return none
|
||||
return some {
|
||||
arity
|
||||
isLocal := false
|
||||
value := decl.value.instantiateLevelParams decl.levelParams us
|
||||
}
|
||||
| _ =>
|
||||
match (← findLambda? f) with
|
||||
| none => return none
|
||||
| some localDecl =>
|
||||
unless (← shouldInline localDecl) do return none
|
||||
let numArgs := e.getAppNumArgs
|
||||
let arity := getLambdaArity localDecl.value
|
||||
if numArgs < arity then return none
|
||||
return some {
|
||||
arity
|
||||
isLocal := true
|
||||
value := localDecl.value
|
||||
}
|
||||
|
||||
/--
|
||||
If `e` if a free variable that expands to a valid LCNF terminal `let`-block expression `e'`,
|
||||
|
|
@ -255,23 +273,6 @@ partial def inlineProjInst? (e : Expr) : OptionT SimpM Expr := do
|
|||
saved.restore
|
||||
return none
|
||||
|
||||
partial def inlineApp (e : Expr) (jp? : Option Expr := none) : SimpM Expr := do
|
||||
let f := e.getAppFn
|
||||
trace[Compiler.simp.inline] "inlining {e}"
|
||||
let value ← match f with
|
||||
| .const declName us =>
|
||||
let some decl ← getStage1Decl? declName | unreachable!
|
||||
pure <| decl.value.instantiateLevelParams decl.levelParams us
|
||||
| _ =>
|
||||
let some localDecl ← findLambda? f | unreachable!
|
||||
pure localDecl.value
|
||||
let args := e.getAppArgs
|
||||
let value := value.beta args
|
||||
let value ← attachOptJp value jp?
|
||||
assert! !value.isLambda
|
||||
markSimplified
|
||||
withLocalInline (!f.isConst) do visitLet value
|
||||
|
||||
/--
|
||||
If `e` is an application that can be inlined, inline it.
|
||||
|
||||
|
|
@ -280,45 +281,62 @@ that need to instantiated with `xs`. That is, if `k? = some k`, then `k.instanti
|
|||
is an expression without loose bound variables.
|
||||
-/
|
||||
partial def inlineApp? (e : Expr) (xs : Array Expr) (k? : Option Expr) : SimpM (Option Expr) := do
|
||||
let some numExtraArgs ← inlineCandidate? e | return none
|
||||
let some info ← inlineCandidate? e | return none
|
||||
let args := e.getAppArgs
|
||||
if k?.isNone && numExtraArgs == 0 then
|
||||
-- Easy case, there is not continuation and `e` is not over applied
|
||||
inlineApp e
|
||||
let numArgs := args.size
|
||||
trace[Compiler.simp.inline] "inlining {e}"
|
||||
markSimplified
|
||||
withLocalInline info.isLocal do
|
||||
if !(← manyExitPoints info.value) then
|
||||
-- If `info.value` has only one exit point, we don't need to create a new join point
|
||||
let value := info.value.beta args[:info.arity]
|
||||
let value ← visitLet value #[]
|
||||
trace[Meta.debug] "value: {value}"
|
||||
match numArgs == info.arity, k? with
|
||||
| true, none => return value
|
||||
| false, none => return mkAppN (← mkAuxLetDecl value) args[info.arity:]
|
||||
| true, some k => let x ← mkAuxLetDecl value; visitLet k (xs.push x)
|
||||
| false, some k =>
|
||||
let x ← mkAuxLetDecl value
|
||||
let x ← mkAuxLetDecl (mkAppN x args[info.arity:])
|
||||
visitLet k (xs.push x)
|
||||
else
|
||||
/-
|
||||
There is a continuation `k` or `e` is over applied.
|
||||
If `e` is over applied, the extra arguments act as continuation.
|
||||
-/
|
||||
let toInline := mkAppN e.getAppFn args[:args.size - numExtraArgs]
|
||||
/-
|
||||
`toInline` is the application that is going to be inline
|
||||
We create a new join point
|
||||
```
|
||||
let jp := fun y =>
|
||||
let x := y <extra-arguments> -- if `e` is over applied
|
||||
k
|
||||
```
|
||||
Recall that `visitLet` incorporates the current continuation
|
||||
to the new join point `jp`.
|
||||
-/
|
||||
let jpDomain ← inferType toInline
|
||||
let binderName ← mkFreshUserName `_y
|
||||
let jp ← withNewScope do
|
||||
let y ← mkLocalDecl binderName jpDomain
|
||||
let body ← if numExtraArgs == 0 then
|
||||
visitLet k?.get! (xs.push y)
|
||||
else
|
||||
let x ← mkAuxLetDecl (mkAppN y args[args.size - numExtraArgs:])
|
||||
if let some k := k? then
|
||||
visitLet k (xs.push x)
|
||||
let args := e.getAppArgs
|
||||
if k?.isNone && numArgs == info.arity then
|
||||
-- Easy case, there is no continuation and `e` is not overapplied
|
||||
return info.value.beta args
|
||||
else
|
||||
/-
|
||||
There is a continuation `k` or `e` is over applied.
|
||||
If `e` is over applied, the extra arguments act as a continuation.
|
||||
|
||||
We create a new join point
|
||||
```
|
||||
let jp := fun y =>
|
||||
let x := y <extra-arguments> -- if `e` is over applied
|
||||
k
|
||||
```
|
||||
Recall that `visitLet` incorporates the current continuation
|
||||
to the new join point `jp`.
|
||||
-/
|
||||
let jpDomain ← inferType (mkAppN e.getAppFn args[:info.arity])
|
||||
let binderName ← mkFreshUserName `_y
|
||||
let jp ← withNewScope do
|
||||
let y ← mkLocalDecl binderName jpDomain
|
||||
let body ← if numArgs == info.arity then
|
||||
visitLet k?.get! (xs.push y)
|
||||
else
|
||||
visitLet x (xs.push x)
|
||||
let body ← mkLetUsingScope body
|
||||
mkLambda #[y] body
|
||||
let jp ← mkJpDeclIfNotSimple jp
|
||||
/- Inline `toInline` and "go-to" `jp` with the result. -/
|
||||
inlineApp toInline jp
|
||||
let x ← mkAuxLetDecl (mkAppN y args[info.arity:])
|
||||
if let some k := k? then
|
||||
visitLet k (xs.push x)
|
||||
else
|
||||
visitLet x (xs.push x)
|
||||
let body ← mkLetUsingScope body
|
||||
mkLambda #[y] body
|
||||
let jp ← mkJpDeclIfNotSimple jp
|
||||
let value := info.value.beta args[:info.arity]
|
||||
let value ← attachJp value jp
|
||||
visitLet value
|
||||
|
||||
/-- Try to apply simple simplifications. -/
|
||||
partial def simpValue? (e : Expr) : SimpM (Option Expr) :=
|
||||
|
|
@ -365,6 +383,7 @@ def Decl.simp? (decl : Decl) : CoreM (Option Decl) := do
|
|||
trace[Compiler.simp.inline.stats] "{decl.name}:{Format.nest 2 (format stats)}"
|
||||
let (value, s) ← Simp.visitLambda decl.value |>.run { stats } |>.run { simplified := false } |>.run' {}
|
||||
trace[Compiler.simp.step] "{decl.name} :=\n{decl.value}"
|
||||
trace[Compiler.simp.stat] "{decl.name}: {← getLCNFSize decl.value}"
|
||||
if s.simplified then
|
||||
return some { decl with value }
|
||||
else
|
||||
|
|
@ -379,7 +398,8 @@ partial def Decl.simp (decl : Decl) : CoreM Decl := do
|
|||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.simp.inline
|
||||
registerTraceClass `Compiler.simp.stat
|
||||
registerTraceClass `Compiler.simp.step
|
||||
registerTraceClass `Compiler.simp.inline.stats
|
||||
|
||||
end Lean.Compiler
|
||||
end Lean.Compiler
|
||||
|
|
|
|||
|
|
@ -175,4 +175,18 @@ def isJump? [Monad m] [MonadLCtx m] (e : Expr) : m (Option FVarId) := do
|
|||
else
|
||||
return none
|
||||
|
||||
/--
|
||||
Return if the LCNF expression has many exit points.
|
||||
It assumes `cases` expressions only occur at the end of `let`-blocks.
|
||||
That is, `terminalCases` has already been applied.
|
||||
It also assumes that if contains a join point, then it has multiple
|
||||
exit points. This is a reasonable assumption because the simplifier
|
||||
inlines any join point that was used only once.
|
||||
-/
|
||||
def manyExitPoints (e : Expr) : CoreM Bool := do
|
||||
match e with
|
||||
| .lam _ _ b _ => manyExitPoints b
|
||||
| .letE n _ _ b _ => pure (isJpBinderName n) <||> manyExitPoints b
|
||||
| e => return (← isCasesApp? e).isSome
|
||||
|
||||
end Lean.Compiler
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue