feat: Compiler pass for reducing common jp args
This commit is contained in:
parent
a608532fd4
commit
dac6127810
3 changed files with 184 additions and 0 deletions
|
|
@ -167,5 +167,20 @@ instance : TraverseFVar Alt where
|
|||
Code.forFVarM f c
|
||||
| .default c => Code.forFVarM f c
|
||||
|
||||
def anyFVarM [Monad m] [TraverseFVar α] (f : FVarId → m Bool) (x : α) : m Bool := do
|
||||
let (_, res) ← TraverseFVar.forFVarM go x |>.run false
|
||||
return res
|
||||
where
|
||||
-- TODO: StateRefT, early return?
|
||||
go (fvar : FVarId) : StateT Bool m Unit := do
|
||||
if (← f fvar) then set true
|
||||
|
||||
def allFVarM [Monad m] [TraverseFVar α] (f : FVarId → m Bool) (x : α) : m Bool := do
|
||||
let (_, res) ← TraverseFVar.forFVarM go x |>.run true
|
||||
return res
|
||||
where
|
||||
-- TODO: StateRefT, early return?
|
||||
go (fvar : FVarId) : StateT Bool m Unit := do
|
||||
if !(← f fvar) then set false
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ import Lean.Compiler.LCNF.PassManager
|
|||
import Lean.Compiler.LCNF.PullFunDecls
|
||||
import Lean.Compiler.LCNF.FVarUtil
|
||||
import Lean.Compiler.LCNF.ScopeM
|
||||
import Lean.Compiler.LCNF.InferType
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
|
||||
|
|
@ -439,6 +440,164 @@ where
|
|||
|
||||
end JoinPointContextExtender
|
||||
|
||||
namespace JoinPointCommonArgs
|
||||
|
||||
/--
|
||||
Context for `ReduceAnalysisM`.
|
||||
-/
|
||||
structure AnalysisCtx where
|
||||
/--
|
||||
The variables that are in scope at the time of the definition of
|
||||
the join point.
|
||||
-/
|
||||
jpScopes : FVarIdMap FVarIdSet := {}
|
||||
|
||||
/--
|
||||
State for `ReduceAnalysisM`.
|
||||
-/
|
||||
structure AnalysisState where
|
||||
/--
|
||||
Lists of names of arguments of jmps to join points to find duplicates.
|
||||
-/
|
||||
jpJmpArgs : FVarIdMap FVarSubst := {}
|
||||
|
||||
abbrev ReduceAnalysisM := ReaderT AnalysisCtx StateRefT AnalysisState ScopeM
|
||||
abbrev ReduceActionM := ReaderT AnalysisState CompilerM
|
||||
|
||||
def isInJpScope (jp : FVarId) (var : FVarId) : ReduceAnalysisM Bool := do
|
||||
return (← read).jpScopes.find! jp |>.contains var
|
||||
|
||||
open ScopeM
|
||||
|
||||
/--
|
||||
Take a look at each join point and each of their call sites. If all
|
||||
call sites of a join point have one or more arguments in common, for example:
|
||||
```
|
||||
jp _jp.1 a b c => ...
|
||||
...
|
||||
cases foo
|
||||
| n1 => jmp _jp.1 d e f
|
||||
| n2 => jmp _jp.1 g e h
|
||||
```
|
||||
We can get rid of the common argument in favour of inlining it directly
|
||||
into the join point (in this case the `e`). This reduces the amount of
|
||||
arguments we have to pass around drastically for example in `ReaderT` based
|
||||
monad stacks.
|
||||
|
||||
Note 1: This transformation can in certain niche cases obtain better results.
|
||||
For example:
|
||||
```
|
||||
jp foo a b => ..
|
||||
let x := ...
|
||||
cases discr
|
||||
| n1 => jmp foo x y
|
||||
| n2 => jmp foo x z
|
||||
```
|
||||
Here we will not collapse the `x` since it is defined after the join point `foo`
|
||||
and thus not accessible for substitution yet. We could however reorder the code in
|
||||
such a way that this is possible, this is currently not done since we observe
|
||||
than in praxis most of the applications of this transformation can occur naturally
|
||||
without reordering.
|
||||
|
||||
Note 2: This transformation is kind of the opposite of `JoinPointContextExtender`.
|
||||
However we still benefit from the extender because in the `simp` run after it
|
||||
we might be able to pull join point declarations further up in the hierarchy
|
||||
of nested functions/join points which in turn might enable additional optimizations.
|
||||
After we have performed all of these optimizations we can take away the
|
||||
(remaining) common arguments and end up with nicely floated and optimized
|
||||
code that has as little arguments as possible in the join points.
|
||||
-/
|
||||
partial def reduce (decl : Decl) : CompilerM Decl := do
|
||||
let (_, analysis) ← goAnalyze decl.value |>.run {} |>.run {} |>.run' {}
|
||||
let newValue ← goReduce decl.value |>.run analysis
|
||||
return { decl with value := newValue }
|
||||
where
|
||||
goAnalyzeFunDecl (fn : FunDecl) : ReduceAnalysisM Unit := do
|
||||
withNewScope do
|
||||
fn.params.forM (addToScope ·.fvarId)
|
||||
goAnalyze fn.value
|
||||
|
||||
goAnalyze (code : Code) : ReduceAnalysisM Unit := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
addToScope decl.fvarId
|
||||
goAnalyze k
|
||||
| .jp decl k =>
|
||||
goAnalyzeFunDecl decl
|
||||
let scope ← getScope
|
||||
withReader (fun ctx => { ctx with jpScopes := ctx.jpScopes.insert decl.fvarId scope }) do
|
||||
addToScope decl.fvarId
|
||||
goAnalyze k
|
||||
| .fun decl k =>
|
||||
goAnalyzeFunDecl decl
|
||||
addToScope decl.fvarId
|
||||
goAnalyze k
|
||||
| .cases cs =>
|
||||
let visitor alt := do
|
||||
withNewScope do
|
||||
alt.getParams.forM (addToScope ·.fvarId)
|
||||
goAnalyze alt.getCode
|
||||
cs.alts.forM visitor
|
||||
| .jmp fn args =>
|
||||
let decl ← getFunDecl fn
|
||||
if let some knownArgs := (← get).jpJmpArgs.find? fn then
|
||||
let mut newArgs := knownArgs
|
||||
for (param, arg) in decl.params.zip args do
|
||||
if let some knownVal := newArgs.find? param.fvarId then
|
||||
if arg != knownVal then
|
||||
newArgs := newArgs.erase param.fvarId
|
||||
modify fun s => { s with jpJmpArgs := s.jpJmpArgs.insert fn newArgs }
|
||||
else
|
||||
let folder := fun acc (param, arg) => do
|
||||
if (← allFVarM (isInJpScope fn) arg) then
|
||||
return acc.insert param.fvarId arg
|
||||
else
|
||||
return acc
|
||||
let interestingArgs ← decl.params.zip args |>.foldlM (init := {}) folder
|
||||
modify fun s => { s with jpJmpArgs := s.jpJmpArgs.insert fn interestingArgs }
|
||||
| .return .. | .unreach .. => return ()
|
||||
|
||||
goReduce (code : Code) : ReduceActionM Code := do
|
||||
match code with
|
||||
| .jp decl k =>
|
||||
if let some reducibleArgs := (← read).jpJmpArgs.find? decl.fvarId then
|
||||
let filter param := do
|
||||
let erasable := reducibleArgs.contains param.fvarId
|
||||
if erasable then
|
||||
eraseParam param
|
||||
return !erasable
|
||||
let newParams ← decl.params.filterM filter
|
||||
let mut newValue ← goReduce decl.value
|
||||
newValue ← replaceFVars newValue reducibleArgs false
|
||||
let newType ←
|
||||
if newParams.size != decl.params.size then
|
||||
mkForallParams newParams (← newValue.inferType)
|
||||
else
|
||||
pure decl.type
|
||||
let k ← goReduce k
|
||||
let decl ← decl.update newType newParams newValue
|
||||
return Code.updateFun! code decl k
|
||||
else
|
||||
return Code.updateFun! code decl (← goReduce k)
|
||||
| .jmp fn args =>
|
||||
let reducibleArgs := (← read).jpJmpArgs.find! fn
|
||||
let decl ← getFunDecl fn
|
||||
let newParams := decl.params.zip args
|
||||
|>.filter (!reducibleArgs.contains ·.fst.fvarId)
|
||||
|>.map Prod.snd
|
||||
return Code.updateJmp! code fn newParams
|
||||
| .let decl k =>
|
||||
return Code.updateLet! code decl (← goReduce k)
|
||||
| .fun decl k =>
|
||||
let decl ← decl.updateValue (← goReduce decl.value)
|
||||
return Code.updateFun! code decl (← goReduce k)
|
||||
| .cases cs =>
|
||||
let alts ← cs.alts.mapM (·.mapCodeM goReduce)
|
||||
return Code.updateCases! code cs.resultType cs.discr alts
|
||||
| .return .. | .unreach .. => return code
|
||||
|
||||
end JoinPointCommonArgs
|
||||
|
||||
/--
|
||||
Find all `fun` declarations in `decl` that qualify as join points then replace
|
||||
their definitions and call sites with `jp`/`jmp`.
|
||||
|
|
@ -463,4 +622,13 @@ def extendJoinPointContext : Pass :=
|
|||
builtin_initialize
|
||||
registerTraceClass `Compiler.extendJoinPointContext (inherited := true)
|
||||
|
||||
def Decl.commonJoinPointArgs (decl : Decl) : CompilerM Decl := do
|
||||
JoinPointCommonArgs.reduce decl
|
||||
|
||||
def commonJoinPointArgs : Pass :=
|
||||
.mkPerDeclaration `commonJoinPointArgs Decl.commonJoinPointArgs .mono
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.commonJoinPointArgs (inherited := true)
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
|
|||
|
|
@ -62,6 +62,7 @@ def builtinPassManager : PassManager := {
|
|||
extendJoinPointContext,
|
||||
floatLetIn (phase := .mono) (occurrence := 1),
|
||||
reduceArity,
|
||||
commonJoinPointArgs,
|
||||
simp (occurrence := 4) (phase := .mono),
|
||||
floatLetIn (phase := .mono) (occurrence := 2),
|
||||
lambdaLifting,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue