lean4-htt/src/Lean/Compiler/LCNF/JoinPoints.lean
2023-10-08 10:46:05 +02:00

635 lines
22 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2022 Henrik Böving. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Henrik Böving
-/
import Lean.Compiler.LCNF.CompilerM
import Lean.Compiler.LCNF.PassManager
import Lean.Compiler.LCNF.PullFunDecls
import Lean.Compiler.LCNF.FVarUtil
import Lean.Compiler.LCNF.ScopeM
import Lean.Compiler.LCNF.InferType
namespace Lean.Compiler.LCNF
namespace JoinPointFinder
open ScopeM
/--
Info about a join point candidate (a `fun` declaration) during the find phase.
-/
structure CandidateInfo where
/--
The arity of the candidate
-/
arity : Nat
/--
The set of candidates that rely on this candidate to be a join point.
For a more detailed explanation see the documentation of `find`
-/
associated : HashSet FVarId
deriving Inhabited
/--
The state for the join point candidate finder.
-/
structure FindState where
/--
All current join point candidates accessible by their `FVarId`.
-/
candidates : HashMap FVarId CandidateInfo := .empty
/--
The `FVarId`s of all `fun` declarations that were declared within the
current `fun`.
-/
scope : HashSet FVarId := .empty
abbrev ReplaceCtx := HashMap FVarId Name
abbrev FindM := ReaderT (Option FVarId) StateRefT FindState ScopeM
abbrev ReplaceM := ReaderT ReplaceCtx CompilerM
/--
Attempt to find a join point candidate by its `FVarId`.
-/
private def findCandidate? (fvarId : FVarId) : FindM (Option CandidateInfo) := do
return (← get).candidates.find? fvarId
/--
Erase a join point candidate as well as all the ones that depend on it
by its `FVarId`, no error is thrown is the candidate does not exist.
-/
private partial def eraseCandidate (fvarId : FVarId) : FindM Unit := do
if let some info ← findCandidate? fvarId then
modify (fun state => { state with candidates := state.candidates.erase fvarId })
info.associated.forM eraseCandidate
/--
Combinator for modifying the candidates in `FindM`.
-/
private def modifyCandidates (f : HashMap FVarId CandidateInfo → HashMap FVarId CandidateInfo) : FindM Unit :=
modify (fun state => {state with candidates := f state.candidates })
/--
Remove all join point candidates contained in `a`.
-/
private partial def removeCandidatesInArg (a : Arg) : FindM Unit := do
forFVarM eraseCandidate a
/--
Remove all join point candidates contained in `a`.
-/
private partial def removeCandidatesInLetValue (e : LetValue) : FindM Unit := do
forFVarM eraseCandidate e
/--
Add a new join point candidate to the state.
-/
private def addCandidate (fvarId : FVarId) (arity : Nat) : FindM Unit := do
let cinfo := { arity, associated := .empty }
modifyCandidates (fun cs => cs.insert fvarId cinfo )
/--
Add a new join point dependency from `src` to `dst`.
-/
private def addDependency (src : FVarId) (target : FVarId) : FindM Unit := do
if let some targetInfo ← findCandidate? target then
modifyCandidates (fun cs => cs.insert target { targetInfo with associated := targetInfo.associated.insert src })
else
eraseCandidate src
/--
Find all `fun` declarations that qualify as a join point, that is:
- are always fully applied
- are always called in tail position
Where a `fun` `f` is in tail position iff it is called as follows:
```
let res := f arg
res
```
The majority (if not all) tail calls will be brought into this form
by the simplifier pass.
Furthermore a `fun` disqualifies as a join point if turning it into a join
point would turn a call to it into an out of scope join point.
This can happen if we have something like:
```
def test (b : Bool) (x y : Nat) : Nat :=
fun myjp x => Nat.add x (Nat.add x x)
fun f y =>
let x := Nat.add y y
myjp x
fun f y =>
let x := Nat.mul y y
myjp x
cases b (f x) (g y)
```
`f` and `g` can be detected as a join point right away, however
`myjp` can only ever be detected as a join point after we have established
this. This is because otherwise the calls to `myjp` in `f` and `g` would
produce out of scope join point jumps.
-/
partial def find (decl : Decl) : CompilerM FindState := do
let (_, candidates) ← go decl.value |>.run none |>.run {} |>.run' {}
return candidates
where
go : Code → FindM Unit
| .let decl k => do
match k, decl.value with
| .return valId, .fvar fvarId args =>
args.forM removeCandidatesInArg
if let some candidateInfo ← findCandidate? fvarId then
-- Erase candidate that are not fully applied or applied outside of tail position
if valId != decl.fvarId || args.size != candidateInfo.arity then
eraseCandidate fvarId
-- Out of scope join point candidate handling
else if let some upperCandidate ← read then
if !(← isInScope fvarId) then
addDependency fvarId upperCandidate
else
eraseCandidate fvarId
| _, _ =>
removeCandidatesInLetValue decl.value
go k
| .fun decl k => do
withReader (fun _ => some decl.fvarId) do
withNewScope do
go decl.value
addCandidate decl.fvarId decl.getArity
addToScope decl.fvarId
go k
| .jp decl k => do
go decl.value
go k
| .jmp _ args => args.forM removeCandidatesInArg
| .return val => eraseCandidate val
| .cases c => do
eraseCandidate c.discr
c.alts.forM (·.forCodeM go)
| .unreach .. => return ()
/--
Replace all join point candidate `fun` declarations with `jp` ones
and all calls to them with `jmp`s.
-/
partial def replace (decl : Decl) (state : FindState) : CompilerM Decl := do
let mapper := fun acc cname _ => do return acc.insert cname (← mkFreshJpName)
let replaceCtx : ReplaceCtx ← state.candidates.foldM (init := .empty) mapper
let newValue ← go decl.value |>.run replaceCtx
return { decl with value := newValue }
where
go (code : Code) : ReplaceM Code := do
match code with
| .let decl k =>
match k, decl.value with
| .return valId, .fvar fvarId args =>
if valId == decl.fvarId then
if (← read).contains fvarId then
eraseLetDecl decl
return .jmp fvarId args
else
return code
else
return code
| _, _ => return Code.updateLet! code decl (← go k)
| .fun decl k =>
if let some replacement := (← read).find? decl.fvarId then
let newDecl := { decl with
binderName := replacement,
value := (← go decl.value)
}
modifyLCtx fun lctx => lctx.addFunDecl newDecl
return .jp newDecl (← go k)
else
let newDecl ← decl.updateValue (← go decl.value)
return Code.updateFun! code newDecl (← go k)
| .jp decl k =>
let newDecl ← decl.updateValue (← go decl.value)
return Code.updateFun! code newDecl (← go k)
| .cases cs =>
return Code.updateCases! code cs.resultType cs.discr (← cs.alts.mapM (·.mapCodeM go))
| .jmp .. | .return .. | .unreach .. =>
return code
end JoinPointFinder
namespace JoinPointContextExtender
open ScopeM
/--
The context managed by `ExtendM`.
-/
structure ExtendContext where
/--
The `FVarId` of the current join point if we are currently inside one.
-/
currentJp? : Option FVarId := none
/--
The list of valid candidates for extending the context. This will be
all `let` and `fun` declarations as well as all `jp` parameters up
until the last `fun` declaration in the tree.
-/
candidates : FVarIdSet := {}
/--
The state managed by `ExtendM`.
-/
structure ExtendState where
/--
A map from join point `FVarId`s to a respective map from free variables
to `Param`s. The free variables in this map are the once that the context
of said join point will be extended by by passing in the respective parameter.
-/
fvarMap : HashMap FVarId (HashMap FVarId Param) := {}
/--
The monad for the `extendJoinPointContext` pass.
-/
abbrev ExtendM := ReaderT ExtendContext StateRefT ExtendState ScopeM
/--
Replace a free variable if necessary, that is:
- It is in the list of candidates
- We are currently within a join point (if we are within a function there
cannot be a need to replace them since we dont extend their context)
- Said join point actually has a replacement parameter registered.
otherwise just return `fvar`.
-/
def replaceFVar (fvar : FVarId) : ExtendM FVarId := do
if (← read).candidates.contains fvar then
if let some currentJp := (← read).currentJp? then
if let some replacement := (← get).fvarMap.find! currentJp |>.find? fvar then
return replacement.fvarId
return fvar
/--
Add a new candidate to the current scope + to the list of candidates
if we are currently within a join point. Then execute `x`.
-/
def withNewCandidate (fvar : FVarId) (x : ExtendM α) : ExtendM α := do
addToScope fvar
if (← read).currentJp?.isSome then
withReader (fun ctx => { ctx with candidates := ctx.candidates.insert fvar }) do
x
else
x
/--
Same as `withNewCandidate` but with multiple `FVarId`s.
-/
def withNewCandidates (fvars : Array FVarId) (x : ExtendM α) : ExtendM α := do
if (← read).currentJp?.isSome then
let candidates := (← read).candidates
let folder (acc : FVarIdSet) (val : FVarId) := do
addToScope val
return acc.insert val
let newCandidates ← fvars.foldlM (init := candidates) folder
withReader (fun ctx => { ctx with candidates := newCandidates }) do
x
else
x
/--
Extend the context of the current join point (if we are within one)
by `fvar` if necessary.
This is necessary if:
- `fvar` is not in scope (that is, was declared outside of the current jp)
- we have not already extended the context by `fvar`
- the list of candidates contains `fvar`. This is because if we have something
like:
```
let x := ..
fun f a =>
jp j b =>
let y := x
y
```
There is no point in extending the context of `j` by `x` because we
cannot lift a join point outside of a local function declaration.
-/
def extendByIfNecessary (fvar : FVarId) : ExtendM Unit := do
if let some currentJp := (← read).currentJp? then
let mut translator := (← get).fvarMap.find! currentJp
let candidates := (← read).candidates
if !(← isInScope fvar) && !translator.contains fvar && candidates.contains fvar then
let typ ← getType fvar
let newParam ← mkAuxParam typ
translator := translator.insert fvar newParam
modify fun s => { s with fvarMap := s.fvarMap.insert currentJp translator }
/--
Merge the extended context of two join points if necessary. That is
if we have a structure such as:
```
jp j.1 ... =>
jp j.2 .. =>
...
...
```
And we are just done visiting `j.2` we want to extend the context of
`j.1` by all free variables that the context of `j.2` was extended by
as well because we need to drag these variables through at the call sites
of `j.2` in `j.1`.
-/
def mergeJpContextIfNecessary (jp : FVarId) : ExtendM Unit := do
if (← read).currentJp?.isSome then
let additionalArgs := (← get).fvarMap.find! jp |>.toArray
for (fvar, _) in additionalArgs do
extendByIfNecessary fvar
/--
We call this whenever we enter a new local function. It clears both the
current join point and the list of candidates since we can't lift join
points outside of functions as explained in `mergeJpContextIfNecessary`.
-/
def withNewFunScope (decl : FunDecl) (x : ExtendM α): ExtendM α := do
withReader (fun ctx => { ctx with currentJp? := none, candidates := {} }) do
withNewScope do
x
/--
We call this whenever we enter a new join point. It will set the current
join point and extend the list of candidates by all of the parameters of
the join point. This is so in the case of nested join points that refer
to parameters of the current one we extend the context of the nested
join points by said parameters.
-/
def withNewJpScope (decl : FunDecl) (x : ExtendM α): ExtendM α := do
withReader (fun ctx => { ctx with currentJp? := some decl.fvarId }) do
modify fun s => { s with fvarMap := s.fvarMap.insert decl.fvarId {} }
withNewScope do
withNewCandidates (decl.params.map (·.fvarId)) do
x
/--
We call this whenever we visit a new arm of a cases statement.
It will back up the current scope (since we are doing a case split
and want to continue with other arms afterwards) and add all of the
parameters of the match arm to the list of candidates.
-/
def withNewAltScope (alt : Alt) (x : ExtendM α) : ExtendM α := do
withBackTrackingScope do
withNewCandidates (alt.getParams.map (·.fvarId)) do
x
/--
Use all of the above functions to find free variables declared outside
of join points that said join points can be reasonaly extended by. Reasonable
meaning that in case the current join point is nested within a function
declaration we will not extend it by free variables declared before the
function declaration because we cannot lift join points outside of function
declarations.
All of this is done to eliminate dependencies of join points onto their
position within the code so we can pull them out as far as possible, hopefully
enabling new inlining possibilities in the next simplifier run.
-/
partial def extend (decl : Decl) : CompilerM Decl := do
let newValue ← go decl.value |>.run {} |>.run' {} |>.run' {}
let decl := { decl with value := newValue }
decl.pullFunDecls
where
goFVar (fvar : FVarId) : ExtendM FVarId := do
extendByIfNecessary fvar
replaceFVar fvar
go (code : Code) : ExtendM Code := do
match code with
| .let decl k =>
let decl ← decl.updateValue (← mapFVarM goFVar decl.value)
withNewCandidate decl.fvarId do
return Code.updateLet! code decl (← go k)
| .jp decl k =>
let decl ← withNewJpScope decl do
let value ← go decl.value
let additionalParams := (← get).fvarMap.find! decl.fvarId |>.toArray |>.map Prod.snd
let newType := additionalParams.foldr (init := decl.type) (fun val acc => .forallE val.binderName val.type acc .default)
decl.update newType (additionalParams ++ decl.params) value
mergeJpContextIfNecessary decl.fvarId
withNewCandidate decl.fvarId do
return Code.updateFun! code decl (← go k)
| .fun decl k =>
let decl ← withNewFunScope decl do
decl.updateValue (← go decl.value)
withNewCandidate decl.fvarId do
return Code.updateFun! code decl (← go k)
| .cases cs =>
extendByIfNecessary cs.discr
let discr ← replaceFVar cs.discr
let visitor := fun alt => do
withNewAltScope alt do
alt.mapCodeM go
let alts ← cs.alts.mapM visitor
return Code.updateCases! code cs.resultType discr alts
| .jmp fn args =>
let mut newArgs ← args.mapM (mapFVarM goFVar)
let additionalArgs := (← get).fvarMap.find! fn |>.toArray |>.map Prod.fst
if let some _currentJp := (← read).currentJp? then
let f := fun arg => do
return .fvar (← goFVar arg)
newArgs := (←additionalArgs.mapM f) ++ newArgs
else
newArgs := (additionalArgs.map .fvar) ++ newArgs
return Code.updateJmp! code fn newArgs
| .return var =>
extendByIfNecessary var
return Code.updateReturn! code (← replaceFVar var)
| .unreach .. => return code
end JoinPointContextExtender
namespace JoinPointCommonArgs
/--
Context for `ReduceAnalysisM`.
-/
structure AnalysisCtx where
/--
The variables that are in scope at the time of the definition of
the join point.
-/
jpScopes : FVarIdMap FVarIdSet := {}
/--
State for `ReduceAnalysisM`.
-/
structure AnalysisState where
/--
A map, that for each join point id contains a map from all (so far)
duplicated argument ids to the respective duplicate value
-/
jpJmpArgs : FVarIdMap FVarSubst := {}
abbrev ReduceAnalysisM := ReaderT AnalysisCtx StateRefT AnalysisState ScopeM
abbrev ReduceActionM := ReaderT AnalysisState CompilerM
def isInJpScope (jp : FVarId) (var : FVarId) : ReduceAnalysisM Bool := do
return (← read).jpScopes.find! jp |>.contains var
open ScopeM
/--
Take a look at each join point and each of their call sites. If all
call sites of a join point have one or more arguments in common, for example:
```
jp _jp.1 a b c => ...
...
cases foo
| n1 => jmp _jp.1 d e f
| n2 => jmp _jp.1 g e h
```
We can get rid of the common argument in favour of inlining it directly
into the join point (in this case the `e`). This reduces the amount of
arguments we have to pass around drastically for example in `ReaderT` based
monad stacks.
Note 1: This transformation can in certain niche cases obtain better results.
For example:
```
jp foo a b => ..
let x := ...
cases discr
| n1 => jmp foo x y
| n2 => jmp foo x z
```
Here we will not collapse the `x` since it is defined after the join point `foo`
and thus not accessible for substitution yet. We could however reorder the code in
such a way that this is possible, this is currently not done since we observe
than in praxis most of the applications of this transformation can occur naturally
without reordering.
Note 2: This transformation is kind of the opposite of `JoinPointContextExtender`.
However we still benefit from the extender because in the `simp` run after it
we might be able to pull join point declarations further up in the hierarchy
of nested functions/join points which in turn might enable additional optimizations.
After we have performed all of these optimizations we can take away the
(remaining) common arguments and end up with nicely floated and optimized
code that has as little arguments as possible in the join points.
-/
partial def reduce (decl : Decl) : CompilerM Decl := do
let (_, analysis) ← goAnalyze decl.value |>.run {} |>.run {} |>.run' {}
let newValue ← goReduce decl.value |>.run analysis
return { decl with value := newValue }
where
goAnalyzeFunDecl (fn : FunDecl) : ReduceAnalysisM Unit := do
withNewScope do
fn.params.forM (addToScope ·.fvarId)
goAnalyze fn.value
goAnalyze (code : Code) : ReduceAnalysisM Unit := do
match code with
| .let decl k =>
addToScope decl.fvarId
goAnalyze k
| .jp decl k =>
goAnalyzeFunDecl decl
let scope ← getScope
withReader (fun ctx => { ctx with jpScopes := ctx.jpScopes.insert decl.fvarId scope }) do
addToScope decl.fvarId
goAnalyze k
| .fun decl k =>
goAnalyzeFunDecl decl
addToScope decl.fvarId
goAnalyze k
| .cases cs =>
let visitor alt := do
withNewScope do
alt.getParams.forM (addToScope ·.fvarId)
goAnalyze alt.getCode
cs.alts.forM visitor
| .jmp fn args =>
let decl ← getFunDecl fn
if let some knownArgs := (← get).jpJmpArgs.find? fn then
let mut newArgs := knownArgs
for (param, arg) in decl.params.zip args do
if let some knownVal := newArgs.find? param.fvarId then
if arg.toExpr != knownVal then
newArgs := newArgs.erase param.fvarId
modify fun s => { s with jpJmpArgs := s.jpJmpArgs.insert fn newArgs }
else
let folder := fun acc (param, arg) => do
if (← allFVarM (isInJpScope fn) arg) then
return acc.insert param.fvarId arg.toExpr
else
return acc
let interestingArgs ← decl.params.zip args |>.foldlM (init := {}) folder
modify fun s => { s with jpJmpArgs := s.jpJmpArgs.insert fn interestingArgs }
| .return .. | .unreach .. => return ()
goReduce (code : Code) : ReduceActionM Code := do
match code with
| .jp decl k =>
if let some reducibleArgs := (← read).jpJmpArgs.find? decl.fvarId then
let filter param := do
let erasable := reducibleArgs.contains param.fvarId
if erasable then
eraseParam param
return !erasable
let newParams ← decl.params.filterM filter
let mut newValue ← goReduce decl.value
newValue ← replaceFVars newValue reducibleArgs false
let newType ←
if newParams.size != decl.params.size then
mkForallParams newParams (← newValue.inferType)
else
pure decl.type
let k ← goReduce k
let decl ← decl.update newType newParams newValue
return Code.updateFun! code decl k
else
return Code.updateFun! code decl (← goReduce k)
| .jmp fn args =>
let reducibleArgs := (← read).jpJmpArgs.find! fn
let decl ← getFunDecl fn
let newParams := decl.params.zip args
|>.filter (!reducibleArgs.contains ·.fst.fvarId)
|>.map Prod.snd
return Code.updateJmp! code fn newParams
| .let decl k =>
return Code.updateLet! code decl (← goReduce k)
| .fun decl k =>
let decl ← decl.updateValue (← goReduce decl.value)
return Code.updateFun! code decl (← goReduce k)
| .cases cs =>
let alts ← cs.alts.mapM (·.mapCodeM goReduce)
return Code.updateCases! code cs.resultType cs.discr alts
| .return .. | .unreach .. => return code
end JoinPointCommonArgs
/--
Find all `fun` declarations in `decl` that qualify as join points then replace
their definitions and call sites with `jp`/`jmp`.
-/
def Decl.findJoinPoints (decl : Decl) : CompilerM Decl := do
let findResult ← JoinPointFinder.find decl
trace[Compiler.findJoinPoints] "Found: {findResult.candidates.size} jp candidates"
JoinPointFinder.replace decl findResult
def findJoinPoints : Pass :=
.mkPerDeclaration `findJoinPoints Decl.findJoinPoints .base
builtin_initialize
registerTraceClass `Compiler.findJoinPoints (inherited := true)
def Decl.extendJoinPointContext (decl : Decl) : CompilerM Decl := do
JoinPointContextExtender.extend decl
def extendJoinPointContext (occurrence : Nat := 0) (phase := Phase.mono) (_h : phase ≠ .base := by simp): Pass :=
.mkPerDeclaration `extendJoinPointContext Decl.extendJoinPointContext phase (occurrence := occurrence)
builtin_initialize
registerTraceClass `Compiler.extendJoinPointContext (inherited := true)
def Decl.commonJoinPointArgs (decl : Decl) : CompilerM Decl := do
JoinPointCommonArgs.reduce decl
def commonJoinPointArgs : Pass :=
.mkPerDeclaration `commonJoinPointArgs Decl.commonJoinPointArgs .mono
builtin_initialize
registerTraceClass `Compiler.commonJoinPointArgs (inherited := true)
end Lean.Compiler.LCNF