466 lines
16 KiB
Text
466 lines
16 KiB
Text
/-
|
||
Copyright (c) 2022 Henrik Böving. All rights reserved.
|
||
Released under Apache 2.0 license as described in the file LICENSE.
|
||
Authors: Henrik Böving
|
||
-/
|
||
import Lean.Compiler.LCNF.CompilerM
|
||
import Lean.Compiler.LCNF.PassManager
|
||
import Lean.Compiler.LCNF.PullFunDecls
|
||
import Lean.Compiler.LCNF.FVarUtil
|
||
import Lean.Compiler.LCNF.ScopeM
|
||
|
||
namespace Lean.Compiler.LCNF
|
||
|
||
namespace JoinPointFinder
|
||
|
||
open ScopeM
|
||
|
||
/--
Bookkeeping for a single join point candidate (a `fun` declaration) that is
collected during the find phase.
-/
structure CandidateInfo where
  /--
  The number of parameters of the candidate. A call must supply exactly this
  many arguments for the candidate to remain eligible.
  -/
  arity : Nat
  /--
  Candidates whose eligibility hinges on this one. When this candidate is
  erased, all of them are erased as well. See the documentation of `find`
  for the motivation.
  -/
  associated : HashSet FVarId
  deriving Inhabited
|
||
|
||
/--
The mutable state threaded through the join point candidate finder.
-/
structure FindState where
  /--
  Every join point candidate that is still alive, indexed by its `FVarId`.
  -/
  candidates : HashMap FVarId CandidateInfo := .empty
  /--
  The `FVarId`s of all `fun` declarations that were declared within the
  current `fun`.
  NOTE(review): nothing in this file reads this field; scope checks in `find`
  go through `ScopeM` (`addToScope` / `isInScope`) instead — confirm whether
  it is still needed.
  -/
  scope : HashSet FVarId := .empty
|
||
|
||
/--
Maps every `fun` that is turned into a join point to the fresh name its `jp`
receives.
-/
abbrev ReplaceCtx := HashMap FVarId Name

/--
The monad of the find phase. The reader component holds the `fun` whose body
is currently being visited, if any.
-/
abbrev FindM := ReaderT (Option FVarId) StateRefT FindState ScopeM

/--
The monad of the replace phase.
-/
abbrev ReplaceM := ReaderT ReplaceCtx CompilerM
|
||
|
||
/--
Look up the join point candidate registered under `fvarId`. Returns `none` if
there is no such candidate.
-/
private def findCandidate? (fvarId : FVarId) : FindM (Option CandidateInfo) := do
  let state ← get
  return state.candidates.find? fvarId
|
||
|
||
/--
Remove the join point candidate `fvarId` and, transitively, every candidate in
its `associated` set. Nothing happens (and no error is thrown) if `fvarId` is
not a candidate.
-/
private partial def eraseCandidate (fvarId : FVarId) : FindM Unit := do
  let some info ← findCandidate? fvarId | return ()
  -- Erase before recursing, so a later lookup of `fvarId` finds nothing and
  -- the recursion stops even if `associated` sets refer back to it.
  modify fun s => { s with candidates := s.candidates.erase fvarId }
  info.associated.forM eraseCandidate
|
||
|
||
/--
Apply `f` to the candidate map stored in the `FindM` state.
-/
private def modifyCandidates (f : HashMap FVarId CandidateInfo → HashMap FVarId CandidateInfo) : FindM Unit :=
  modify fun s => { s with candidates := f s.candidates }
|
||
|
||
/--
Erase every join point candidate that occurs free in `e`. `find` uses this for
occurrences that are not tail calls of the candidate: arguments of calls and
general `let` values.
-/
private partial def removeCandidatesContainedIn (e : Expr) : FindM Unit :=
  forFVarM (fun fvarId => eraseCandidate fvarId) e
|
||
|
||
/--
Register `fvarId` as a join point candidate with the given `arity` and an
empty set of dependent candidates. Any existing entry for `fvarId` is
overwritten.
-/
private def addCandidate (fvarId : FVarId) (arity : Nat) : FindM Unit :=
  modifyCandidates fun cs => cs.insert fvarId { arity, associated := .empty }
|
||
|
||
/--
Record that candidate `src` can only become a join point if candidate `target`
does. `src` is added to the `associated` set of `target`, so erasing `target`
also erases `src`. If `target` is not a candidate, `src` is erased right away.
-/
private def addDependency (src : FVarId) (target : FVarId) : FindM Unit := do
  match ← findCandidate? target with
  | some targetInfo =>
    modifyCandidates fun cs =>
      cs.insert target { targetInfo with associated := targetInfo.associated.insert src }
  | none =>
    eraseCandidate src
|
||
|
||
/--
Find all `fun` declarations that qualify as a join point, that is:
- are always fully applied
- are always called in tail position

Where a `fun` `f` is in tail position iff it is called as follows:
```
let res := f arg
res
```
The majority (if not all) of tail calls will be brought into this form
by the simplifier pass.

Furthermore a `fun` disqualifies as a join point if turning it into a join
point would turn a call to it into an out of scope join point.
This can happen if we have something like:
```
def test (b : Bool) (x y : Nat) : Nat :=
  fun myjp x => Nat.add x (Nat.add x x)
  fun f y =>
    let x := Nat.add y y
    myjp x
  fun g y =>
    let x := Nat.mul y y
    myjp x
  cases b (f x) (g y)
```
`f` and `g` can be detected as join points right away, however
`myjp` can only ever be detected as a join point after we have established
this. This is because otherwise the calls to `myjp` in `f` and `g` would
produce out of scope join point jumps.

In this implementation a tail call from inside a `fun` body to a candidate
declared outside that `fun` always erases the callee (see the note in the
`.fun` case of `go`). Presumably `myjp` is therefore only picked up by a later
run of this pass, once `f` and `g` have become `jp`s. Visiting a `jp` body does
not change the reader. TODO confirm against the pass schedule.
-/
partial def find (decl : Decl) : CompilerM FindState := do
  let (_, state) ← go decl.value |>.run none |>.run {} |>.run' {}
  return state
where
  go : Code → FindM Unit
  | .let decl k => do
    match k, decl.value, decl.value.getAppFn with
    | .return valId, .app .., .fvar fvarId =>
      -- Candidates passed as arguments escape as values.
      decl.value.getAppArgs.forM removeCandidatesContainedIn
      -- Calls to a variable that is not a candidate need no handling here.
      if let some candidateInfo ← findCandidate? fvarId then
        -- Erase candidates that are not fully applied or applied outside of tail position
        if valId != decl.fvarId || decl.value.getAppNumArgs != candidateInfo.arity then
          eraseCandidate fvarId
        -- Out of scope join point candidate handling
        else if let some upperCandidate ← read then
          if !(← isInScope fvarId) then
            addDependency fvarId upperCandidate
    | _, _, _ =>
      removeCandidatesContainedIn decl.value
    go k
  | .fun decl k => do
    withReader (fun _ => some decl.fvarId) do
      withNewScope do
        go decl.value
    -- NOTE(review): `decl` is registered only after its body has been visited.
    -- Any `addDependency _ decl.fvarId` issued from inside the body therefore
    -- finds no candidate and erases the callee, and the entry created here
    -- starts with an empty `associated` set.
    -- This is conservative. Registering `decl` before the body is NOT a safe fix
    -- on its own: a dependency would then be recorded only on the innermost
    -- enclosing `fun`, not on every `fun` between the call and the callee's
    -- declaration.
    addCandidate decl.fvarId decl.getArity
    addToScope decl.fvarId
    go k
  | .jp decl k => do
    go decl.value
    go k
  | .jmp _ args => args.forM removeCandidatesContainedIn
  | .return val => eraseCandidate val
  | .cases c => do
    eraseCandidate c.discr
    c.alts.forM (·.forCodeM go)
  | .unreach .. => return ()
|
||
|
||
/--
Rewrite `decl` using the candidates collected by `find`:
- every candidate `fun` declaration becomes a `jp` with a fresh join point name;
- every tail call `let r := f args; r` to such a candidate becomes `jmp f args`.
-/
partial def replace (decl : Decl) (state : FindState) : CompilerM Decl := do
  -- Assign a fresh join point name to every surviving candidate.
  let replaceCtx : ReplaceCtx ← state.candidates.foldM (init := .empty) fun ctx fvarId _ => do
    let jpName ← mkFreshJpName
    return ctx.insert fvarId jpName
  let newValue ← go decl.value |>.run replaceCtx
  return { decl with value := newValue }
where
  go (code : Code) : ReplaceM Code := do
    match code with
    | .let decl k =>
      match k, decl.value, decl.value.getAppFn with
      | .return valId, .app .., .fvar fvarId =>
        -- Only a call whose result is returned directly and whose callee is a
        -- join point is turned into a `jmp`. The continuation is a plain
        -- `return`, so there is nothing further to visit.
        if valId == decl.fvarId && (← read).contains fvarId then
          eraseLetDecl decl
          return .jmp fvarId decl.value.getAppArgs
        return code
      | _, _, _ =>
        return Code.updateLet! code decl (← go k)
    | .fun decl k =>
      match (← read).find? decl.fvarId with
      | some jpName =>
        let jpDecl := { decl with binderName := jpName, value := (← go decl.value) }
        modifyLCtx fun lctx => lctx.addFunDecl jpDecl
        return .jp jpDecl (← go k)
      | none =>
        let funDecl ← decl.updateValue (← go decl.value)
        return Code.updateFun! code funDecl (← go k)
    | .jp decl k =>
      let jpDecl ← decl.updateValue (← go decl.value)
      return Code.updateFun! code jpDecl (← go k)
    | .cases cs =>
      let alts ← cs.alts.mapM (·.mapCodeM go)
      return Code.updateCases! code cs.resultType cs.discr alts
    | .jmp .. | .return .. | .unreach .. =>
      return code
|
||
|
||
end JoinPointFinder
|
||
|
||
namespace JoinPointContextExtender
|
||
|
||
open ScopeM
|
||
|
||
/--
The context managed by `ExtendM`.
-/
structure ExtendContext where
  /--
  The `FVarId` of the current join point if we are currently inside one.
  It is reset to `none` whenever a local `fun` declaration is entered.
  -/
  currentJp? : Option FVarId := none
  /--
  The variables that the context of a join point may be extended by.
  - Only variables introduced while inside a join point are added. These are
    `let`, `fun` and `jp` declarations, `jp` parameters and the parameters of
    `cases` alternatives.
  - The set is cleared whenever a local `fun` declaration is entered, so
    nothing declared before the innermost enclosing `fun` is a candidate.
  -/
  candidates : FVarIdSet := {}
|
||
|
||
/--
The state managed by `ExtendM`.
-/
structure ExtendState where
  /--
  A map from join point `FVarId`s to a respective map from free variables
  to `Param`s. The free variables in this map are the ones that the context
  of said join point will be extended by. Each one is passed in through its
  respective new parameter.
  -/
  fvarMap : HashMap FVarId (HashMap FVarId Param) := {}
|
||
|
||
/--
The monad of the `extendJoinPointContext` pass. It is an `ExtendContext`
reader over `ExtendState` state, stacked on top of `ScopeM`.
-/
abbrev ExtendM := ReaderT ExtendContext StateRefT ExtendState ScopeM
|
||
|
||
/--
Return the variable to use in place of `fvar` at the current position.

`fvar` is swapped for its replacement parameter only when all of these hold:
- `fvar` is in the list of candidates;
- we are currently within a join point. Inside a local function nothing is
  ever replaced, since we don't extend the context of functions;
- said join point has a replacement parameter registered for `fvar`.

In every other case `fvar` itself is returned.
-/
def replaceFVar (fvar : FVarId) : ExtendM FVarId := do
  let ctx ← read
  unless ctx.candidates.contains fvar do
    return fvar
  let some currentJp := ctx.currentJp? | return fvar
  let translator := (← get).fvarMap.find! currentJp
  match translator.find? fvar with
  | some param => return param.fvarId
  | none => return fvar
|
||
|
||
/--
Add `fvar` to the current scope and run `x`. If we are inside a join point,
`fvar` is also added to the list of candidates while `x` runs.
-/
def withNewCandidate (fvar : FVarId) (x : ExtendM α) : ExtendM α := do
  addToScope fvar
  let insideJp := (← read).currentJp?.isSome
  if insideJp then
    withReader (fun ctx => { ctx with candidates := ctx.candidates.insert fvar }) x
  else
    x
|
||
|
||
/--
Multi-variable variant of `withNewCandidate`, with one difference: the `fvars`
are added to the scope *and* to the candidates only when we are inside a join
point. Otherwise `x` runs unchanged.
-/
def withNewCandidates (fvars : Array FVarId) (x : ExtendM α) : ExtendM α := do
  if (← read).currentJp?.isNone then
    x
  else
    let mut candidates := (← read).candidates
    for fvar in fvars do
      addToScope fvar
      candidates := candidates.insert fvar
    withReader (fun ctx => { ctx with candidates := candidates }) x
|
||
|
||
/--
Extend the context of the current join point (if we are within one) by `fvar`
when that is necessary. It is necessary if all of the following hold:
- `fvar` is not in scope, that is, it was declared outside of the current jp;
- the context has not already been extended by `fvar`;
- the list of candidates contains `fvar`.

The last condition matters in code such as:
```
let x := ..
fun f a =>
  jp j b =>
    let y := x
    y
```
Here there is no point in extending the context of `j` by `x`, because a join
point cannot be lifted outside of a local function declaration.

The extension creates a fresh parameter of the type of `fvar` and records it in
the join point's entry of `fvarMap`.
-/
def extendByIfNecessary (fvar : FVarId) : ExtendM Unit := do
  let some currentJp := (← read).currentJp? | return ()
  let translator := (← get).fvarMap.find! currentJp
  let candidates := (← read).candidates
  if translator.contains fvar || !candidates.contains fvar then
    return ()
  if (← isInScope fvar) then
    return ()
  let newParam ← mkAuxParam (← getType fvar)
  modify fun s => { s with fvarMap := s.fvarMap.insert currentJp (translator.insert fvar newParam) }
|
||
|
||
/--
After visiting a join point `jp` nested inside another join point, propagate
its context extension outwards. Given
```
jp j.1 ... =>
  jp j.2 .. =>
    ...
  ...
```
once `j.2` has been visited, the context of `j.1` must also be extended by
every free variable that `j.2`'s context was extended by. This is because the
call sites of `j.2` inside `j.1` have to pass those variables along.

Outside of any join point this does nothing.
-/
def mergeJpContextIfNecessary (jp : FVarId) : ExtendM Unit := do
  if (← read).currentJp?.isNone then
    return
  let extension := (← get).fvarMap.find! jp
  for (fvar, _) in extension.toArray do
    extendByIfNecessary fvar
|
||
|
||
/--
We call this whenever we enter a new local function.
- It clears both the current join point and the list of candidates, since we
  can't lift join points outside of functions (as explained in
  `extendByIfNecessary`).
- It also runs `x` in a fresh scope.
- `decl` itself is not inspected.
-/
def withNewFunScope (decl : FunDecl) (x : ExtendM α): ExtendM α := do
  withReader (fun ctx => { ctx with currentJp? := none, candidates := {} }) do
    withNewScope do
      x
|
||
|
||
/--
Enter the join point `decl` and run `x`.
- `decl` becomes the current join point, with an initially empty extension map.
- `x` runs in a fresh scope, with the join point's parameters added as
  candidates.

Adding the parameters matters for nested join points: if they refer to
parameters of this one, their context is extended by those parameters.
-/
def withNewJpScope (decl : FunDecl) (x : ExtendM α): ExtendM α := do
  modify fun s => { s with fvarMap := s.fvarMap.insert decl.fvarId {} }
  let params := decl.params.map (·.fvarId)
  withReader (fun ctx => { ctx with currentJp? := some decl.fvarId }) do
    withNewScope do
      withNewCandidates params x
|
||
|
||
/--
Visit one arm of a `cases` statement.
- The current scope is backed up via `withBackTrackingScope`, because after
  this arm the remaining arms are visited as well.
- The arm's parameters are added to the list of candidates while `x` runs.
-/
def withNewAltScope (alt : Alt) (x : ExtendM α) : ExtendM α := do
  let params := alt.getParams.map (·.fvarId)
  withBackTrackingScope do
    withNewCandidates params x
|
||
|
||
/--
Use all of the above functions to find free variables declared outside of join
points that said join points can reasonably be extended by.

"Reasonably" means the following: if the current join point is nested within a
function declaration, we will not extend it by free variables declared before
that function declaration. Join points cannot be lifted outside of function
declarations.

All of this is done to eliminate dependencies of join points on their position
within the code, so that `pullFunDecls` (run at the end) can pull them out as
far as possible. This hopefully enables new inlining possibilities in the next
simplifier run.
-/
partial def extend (decl : Decl) : CompilerM Decl := do
  let newValue ← go decl.value |>.run {} |>.run' {} |>.run' {}
  let decl := { decl with value := newValue }
  decl.pullFunDecls
where
  goExpr (e : Expr) : ExtendM Expr :=
    let visitor := fun fvar => do
      extendByIfNecessary fvar
      replaceFVar fvar
    mapFVarM visitor e
  go (code : Code) : ExtendM Code := do
    match code with
    | .let decl k =>
      let decl ← decl.updateValue (← goExpr decl.value)
      withNewCandidate decl.fvarId do
        return Code.updateLet! code decl (← go k)
    | .jp decl k =>
      let decl ← withNewJpScope decl do
        let value ← go decl.value
        -- Prepend one parameter per free variable this join point was extended by.
        let additionalParams := (← get).fvarMap.find! decl.fvarId |>.toArray |>.map Prod.snd
        let newType := additionalParams.foldr (init := decl.type) (fun val acc => .forallE val.binderName val.type acc .default)
        decl.update newType (additionalParams ++ decl.params) value
      mergeJpContextIfNecessary decl.fvarId
      withNewCandidate decl.fvarId do
        return Code.updateFun! code decl (← go k)
    | .fun decl k =>
      let decl ← withNewFunScope decl do
        decl.updateValue (← go decl.value)
      withNewCandidate decl.fvarId do
        return Code.updateFun! code decl (← go k)
    | .cases cs =>
      extendByIfNecessary cs.discr
      let discr ← replaceFVar cs.discr
      let visitor := fun alt => do
        withNewAltScope alt do
          alt.mapCodeM go
      let alts ← cs.alts.mapM visitor
      return Code.updateCases! code cs.resultType discr alts
    | .jmp fn args =>
      let mut newArgs ← args.mapM goExpr
      -- The jump must also pass every variable `fn`'s context was extended by,
      -- in the same order in which the extra parameters were prepended.
      let additionalArgs := (← get).fvarMap.find! fn |>.toArray |>.map Prod.fst
      if (← read).currentJp?.isSome then
        -- Those variables may themselves be free in the current join point, e.g.
        -- a nested jp jumping to a previously extended jp. The current join
        -- point must therefore be extended by them too before they are
        -- translated to its parameters.
        let extendArg : FVarId → ExtendM Expr := fun arg => do
          extendByIfNecessary arg
          return Expr.fvar (← replaceFVar arg)
        newArgs := (← additionalArgs.mapM extendArg) ++ newArgs
      else
        newArgs := (additionalArgs.map .fvar) ++ newArgs
      return Code.updateJmp! code fn newArgs
    | .return var =>
      extendByIfNecessary var
      return Code.updateReturn! code (← replaceFVar var)
    | .unreach .. => return code
|
||
|
||
end JoinPointContextExtender
|
||
|
||
/--
Run the join point finder on `decl`. Every `fun` declaration that qualifies as
a join point has its definition rewritten to a `jp` and its call sites to
`jmp`s. The number of candidates found is reported under the
`Compiler.findJoinPoints` trace class.
-/
def Decl.findJoinPoints (decl : Decl) : CompilerM Decl := do
  let findResult ← JoinPointFinder.find decl
  trace[Compiler.findJoinPoints] "Found: {findResult.candidates.size} jp candidates"
  JoinPointFinder.replace (state := findResult) decl
|
||
|
||
/--
The `findJoinPoints` pass. It applies `Decl.findJoinPoints` per declaration
and is registered for the `.base` phase.
-/
def findJoinPoints : Pass :=
  Pass.mkPerDeclaration `findJoinPoints Decl.findJoinPoints .base

builtin_initialize
  registerTraceClass `Compiler.findJoinPoints (inherited := true)
|
||
|
||
/--
Extend the contexts of the join points in `decl` and pull them out as far as
possible (see `JoinPointContextExtender.extend`).
-/
def Decl.extendJoinPointContext (decl : Decl) : CompilerM Decl :=
  JoinPointContextExtender.extend decl
|
||
|
||
/--
The `extendJoinPointContext` pass. It applies `Decl.extendJoinPointContext`
per declaration and is registered for the `.mono` phase.
-/
def extendJoinPointContext : Pass :=
  Pass.mkPerDeclaration `extendJoinPointContext Decl.extendJoinPointContext .mono

-- NOTE(review): no `trace[Compiler.extendJoinPointContext]` call appears in this file.
builtin_initialize
  registerTraceClass `Compiler.extendJoinPointContext (inherited := true)
|
||
|
||
end Lean.Compiler.LCNF
|