feat: Compiler pass for reducing common jp args

This commit is contained in:
Henrik Böving 2022-10-20 00:51:32 +02:00 committed by Leonardo de Moura
parent a608532fd4
commit dac6127810
3 changed files with 184 additions and 0 deletions

View file

@ -167,5 +167,20 @@ instance : TraverseFVar Alt where
Code.forFVarM f c
| .default c => Code.forFVarM f c
def anyFVarM [Monad m] [TraverseFVar α] (f : FVarId → m Bool) (x : α) : m Bool := do
let (_, res) ← TraverseFVar.forFVarM go x |>.run false
return res
where
-- TODO: StateRefT, early return?
go (fvar : FVarId) : StateT Bool m Unit := do
if (← f fvar) then set true
def allFVarM [Monad m] [TraverseFVar α] (f : FVarId → m Bool) (x : α) : m Bool := do
let (_, res) ← TraverseFVar.forFVarM go x |>.run true
return res
where
-- TODO: StateRefT, early return?
go (fvar : FVarId) : StateT Bool m Unit := do
if !(← f fvar) then set false
end Lean.Compiler.LCNF

View file

@ -8,6 +8,7 @@ import Lean.Compiler.LCNF.PassManager
import Lean.Compiler.LCNF.PullFunDecls
import Lean.Compiler.LCNF.FVarUtil
import Lean.Compiler.LCNF.ScopeM
import Lean.Compiler.LCNF.InferType
namespace Lean.Compiler.LCNF
@ -439,6 +440,164 @@ where
end JoinPointContextExtender
namespace JoinPointCommonArgs
/--
Context for `ReduceAnalysisM`.
-/
structure AnalysisCtx where
/--
The variables that are in scope at the time of the definition of
the join point.
-/
jpScopes : FVarIdMap FVarIdSet := {}
/--
State for `ReduceAnalysisM`.
-/
structure AnalysisState where
/--
Lists of names of arguments of jmps to join points to find duplicates.
-/
jpJmpArgs : FVarIdMap FVarSubst := {}
abbrev ReduceAnalysisM := ReaderT AnalysisCtx StateRefT AnalysisState ScopeM
abbrev ReduceActionM := ReaderT AnalysisState CompilerM
def isInJpScope (jp : FVarId) (var : FVarId) : ReduceAnalysisM Bool := do
return (← read).jpScopes.find! jp |>.contains var
open ScopeM
/--
Take a look at each join point and each of their call sites. If all
call sites of a join point have one or more arguments in common, for example:
```
jp _jp.1 a b c => ...
...
cases foo
| n1 => jmp _jp.1 d e f
| n2 => jmp _jp.1 g e h
```
We can get rid of the common argument in favour of inlining it directly
into the join point (in this case the `e`). This reduces the amount of
arguments we have to pass around drastically for example in `ReaderT` based
monad stacks.
Note 1: This transformation can in certain niche cases obtain better results.
For example:
```
jp foo a b => ..
let x := ...
cases discr
| n1 => jmp foo x y
| n2 => jmp foo x z
```
Here we will not collapse the `x` since it is defined after the join point `foo`
and thus not accessible for substitution yet. We could however reorder the code in
such a way that this is possible, this is currently not done since we observe
than in praxis most of the applications of this transformation can occur naturally
without reordering.
Note 2: This transformation is kind of the opposite of `JoinPointContextExtender`.
However we still benefit from the extender because in the `simp` run after it
we might be able to pull join point declarations further up in the hierarchy
of nested functions/join points which in turn might enable additional optimizations.
After we have performed all of these optimizations we can take away the
(remaining) common arguments and end up with nicely floated and optimized
code that has as little arguments as possible in the join points.
-/
partial def reduce (decl : Decl) : CompilerM Decl := do
let (_, analysis) ← goAnalyze decl.value |>.run {} |>.run {} |>.run' {}
let newValue ← goReduce decl.value |>.run analysis
return { decl with value := newValue }
where
goAnalyzeFunDecl (fn : FunDecl) : ReduceAnalysisM Unit := do
withNewScope do
fn.params.forM (addToScope ·.fvarId)
goAnalyze fn.value
goAnalyze (code : Code) : ReduceAnalysisM Unit := do
match code with
| .let decl k =>
addToScope decl.fvarId
goAnalyze k
| .jp decl k =>
goAnalyzeFunDecl decl
let scope ← getScope
withReader (fun ctx => { ctx with jpScopes := ctx.jpScopes.insert decl.fvarId scope }) do
addToScope decl.fvarId
goAnalyze k
| .fun decl k =>
goAnalyzeFunDecl decl
addToScope decl.fvarId
goAnalyze k
| .cases cs =>
let visitor alt := do
withNewScope do
alt.getParams.forM (addToScope ·.fvarId)
goAnalyze alt.getCode
cs.alts.forM visitor
| .jmp fn args =>
let decl ← getFunDecl fn
if let some knownArgs := (← get).jpJmpArgs.find? fn then
let mut newArgs := knownArgs
for (param, arg) in decl.params.zip args do
if let some knownVal := newArgs.find? param.fvarId then
if arg != knownVal then
newArgs := newArgs.erase param.fvarId
modify fun s => { s with jpJmpArgs := s.jpJmpArgs.insert fn newArgs }
else
let folder := fun acc (param, arg) => do
if (← allFVarM (isInJpScope fn) arg) then
return acc.insert param.fvarId arg
else
return acc
let interestingArgs ← decl.params.zip args |>.foldlM (init := {}) folder
modify fun s => { s with jpJmpArgs := s.jpJmpArgs.insert fn interestingArgs }
| .return .. | .unreach .. => return ()
goReduce (code : Code) : ReduceActionM Code := do
match code with
| .jp decl k =>
if let some reducibleArgs := (← read).jpJmpArgs.find? decl.fvarId then
let filter param := do
let erasable := reducibleArgs.contains param.fvarId
if erasable then
eraseParam param
return !erasable
let newParams ← decl.params.filterM filter
let mut newValue ← goReduce decl.value
newValue ← replaceFVars newValue reducibleArgs false
let newType ←
if newParams.size != decl.params.size then
mkForallParams newParams (← newValue.inferType)
else
pure decl.type
let k ← goReduce k
let decl ← decl.update newType newParams newValue
return Code.updateFun! code decl k
else
return Code.updateFun! code decl (← goReduce k)
| .jmp fn args =>
let reducibleArgs := (← read).jpJmpArgs.find! fn
let decl ← getFunDecl fn
let newParams := decl.params.zip args
|>.filter (!reducibleArgs.contains ·.fst.fvarId)
|>.map Prod.snd
return Code.updateJmp! code fn newParams
| .let decl k =>
return Code.updateLet! code decl (← goReduce k)
| .fun decl k =>
let decl ← decl.updateValue (← goReduce decl.value)
return Code.updateFun! code decl (← goReduce k)
| .cases cs =>
let alts ← cs.alts.mapM (·.mapCodeM goReduce)
return Code.updateCases! code cs.resultType cs.discr alts
| .return .. | .unreach .. => return code
end JoinPointCommonArgs
/--
Find all `fun` declarations in `decl` that qualify as join points then replace
their definitions and call sites with `jp`/`jmp`.
@ -463,4 +622,13 @@ def extendJoinPointContext : Pass :=
builtin_initialize
registerTraceClass `Compiler.extendJoinPointContext (inherited := true)
def Decl.commonJoinPointArgs (decl : Decl) : CompilerM Decl := do
JoinPointCommonArgs.reduce decl
def commonJoinPointArgs : Pass :=
.mkPerDeclaration `commonJoinPointArgs Decl.commonJoinPointArgs .mono
builtin_initialize
registerTraceClass `Compiler.commonJoinPointArgs (inherited := true)
end Lean.Compiler.LCNF

View file

@ -62,6 +62,7 @@ def builtinPassManager : PassManager := {
extendJoinPointContext,
floatLetIn (phase := .mono) (occurrence := 1),
reduceArity,
commonJoinPointArgs,
simp (occurrence := 4) (phase := .mono),
floatLetIn (phase := .mono) (occurrence := 2),
lambdaLifting,