lean4-htt/src/Lean/Compiler/LCNF/JoinPoints.lean
2022-10-13 18:56:17 -07:00

466 lines
16 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2022 Henrik Böving. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Henrik Böving
-/
import Lean.Compiler.LCNF.CompilerM
import Lean.Compiler.LCNF.PassManager
import Lean.Compiler.LCNF.PullFunDecls
import Lean.Compiler.LCNF.FVarUtil
import Lean.Compiler.LCNF.ScopeM
namespace Lean.Compiler.LCNF
namespace JoinPointFinder
open ScopeM
/--
Info about a join point candidate (a `fun` declaration) during the find phase.
-/
structure CandidateInfo where
/--
The arity of the candidate. `find` compares it against the number of
arguments at each call site to decide whether the call is fully applied.
-/
arity : Nat
/--
The set of candidates that rely on this candidate to be a join point.
Erasing this candidate also erases every candidate in this set.
For a more detailed explanation see the documentation of `find`.
-/
associated : HashSet FVarId
deriving Inhabited
/--
The state for the join point candidate finder.
-/
structure FindState where
/--
All current join point candidates accessible by their `FVarId`.
-/
candidates : HashMap FVarId CandidateInfo := .empty
/--
The `FVarId`s of all `fun` declarations that were declared within the
current `fun`.
NOTE(review): no code visible in this file reads or writes this field; scope
tracking appears to go through `ScopeM` instead — confirm whether it is still needed.
-/
scope : HashSet FVarId := .empty
/--
Context of the replace phase: maps the `FVarId` of each join point candidate
to the fresh join point name generated for it in `replace`.
-/
abbrev ReplaceCtx := HashMap FVarId Name
/--
The monad of the find phase. The reader value is the `FVarId` of the `fun`
declaration whose body is currently being visited, or `none` outside of any `fun`.
-/
abbrev FindM := ReaderT (Option FVarId) StateRefT FindState ScopeM
/--
The monad of the replace phase: `CompilerM` with read access to a `ReplaceCtx`.
-/
abbrev ReplaceM := ReaderT ReplaceCtx CompilerM
/--
Look up the join point candidate registered under `fvarId`, returning `none`
if there is no such candidate.
-/
private def findCandidate? (fvarId : FVarId) : FindM (Option CandidateInfo) := do
  let state ← get
  return state.candidates.find? fvarId
/--
Erase the join point candidate `fvarId` together with, transitively, all
candidates that depend on it. No error is thrown if the candidate does not exist.
-/
private partial def eraseCandidate (fvarId : FVarId) : FindM Unit := do
  -- Nothing to do for an unknown (or already erased) candidate.
  let some info ← findCandidate? fvarId | return ()
  -- Remove the candidate itself before recursing into its dependents.
  modify fun state => { state with candidates := state.candidates.erase fvarId }
  info.associated.forM eraseCandidate
/--
Apply `f` to the candidate map stored in the `FindM` state.
-/
private def modifyCandidates (f : HashMap FVarId CandidateInfo → HashMap FVarId CandidateInfo) : FindM Unit :=
  modify fun s => { s with candidates := f s.candidates }
/--
Erase every join point candidate whose free variable occurs in `e`.
-/
private partial def removeCandidatesContainedIn (e : Expr) : FindM Unit := do
  forFVarM (fun fvarId => eraseCandidate fvarId) e
/--
Register `fvarId` as a new join point candidate with the given `arity` and
no dependent candidates.
-/
private def addCandidate (fvarId : FVarId) (arity : Nat) : FindM Unit := do
  modifyCandidates fun candidates =>
    candidates.insert fvarId { arity := arity, associated := .empty }
/--
Add a new join point dependency from `src` to `target`: `src` is recorded in
the `associated` set of `target`. If `target` is not a candidate, `src` is
erased instead.
-/
private def addDependency (src : FVarId) (target : FVarId) : FindM Unit := do
  match ← findCandidate? target with
  | some targetInfo =>
    let updated := { targetInfo with associated := targetInfo.associated.insert src }
    modifyCandidates fun cs => cs.insert target updated
  | none =>
    eraseCandidate src
/--
Find all `fun` declarations that qualify as a join point, that is:
- are always fully applied
- are always called in tail position

Where a `fun` `f` is in tail position iff it is called as follows:
```
let res := f arg
res
```
The majority (if not all) tail calls will be brought into this form
by the simplifier pass.

Furthermore a `fun` disqualifies as a join point if turning it into a join
point would turn a call to it into an out of scope join point.
This can happen if we have something like:
```
def test (b : Bool) (x y : Nat) : Nat :=
  fun myjp x => Nat.add x (Nat.add x x)
  fun f y =>
    let x := Nat.add y y
    myjp x
  fun g y =>
    let x := Nat.mul y y
    myjp x
  cases b (f x) (g y)
```
`f` and `g` can be detected as a join point right away, however
`myjp` can only ever be detected as a join point after we have established
this. This is because otherwise the calls to `myjp` in `f` and `g` would
produce out of scope join point jumps.

Returns the final `FindState`; its `candidates` are the declarations that
were not erased during the traversal.
-/
partial def find (decl : Decl) : CompilerM FindState := do
let (_, candidates) ← go decl.value |>.run none |>.run {} |>.run' {}
return candidates
where
go : Code → FindM Unit
| .let decl k => do
match k, decl.value, decl.value.getAppFn with
| .return valId, .app .., .fvar fvarId =>
-- Candidates passed as arguments are used as values, so they are erased.
decl.value.getAppArgs.forM removeCandidatesContainedIn
if let some candidateInfo ← findCandidate? fvarId then
-- Erase candidate that are not fully applied or applied outside of tail position
if valId != decl.fvarId || decl.value.getAppNumArgs != candidateInfo.arity then
eraseCandidate fvarId
-- Out of scope join point candidate handling
else if let some upperCandidate ← read then
if !(← isInScope fvarId) then
addDependency fvarId upperCandidate
-- NOTE(review): indentation is missing in this copy of the file, so it is unclear
-- which `if` the following `else` pairs with; confirm against the upstream source.
else
eraseCandidate fvarId
| _, _, _ =>
removeCandidatesContainedIn decl.value
go k
| .fun decl k => do
-- While visiting the body, the reader value is this declaration's `FVarId`.
withReader (fun _ => some decl.fvarId) do
withNewScope do
go decl.value
addCandidate decl.fvarId decl.getArity
addToScope decl.fvarId
go k
| .jp decl k => do
go decl.value
go k
| .jmp _ args => args.forM removeCandidatesContainedIn
-- A candidate that is returned directly is erased.
| .return val => eraseCandidate val
| .cases c => do
-- A candidate used as the discriminant is erased.
eraseCandidate c.discr
c.alts.forM (·.forCodeM go)
| .unreach .. => return ()
/--
Turn every `fun` declaration that is still a candidate in `state` into a `jp`
declaration and every tail call to such a declaration into a `jmp`.

Each candidate is first assigned a fresh join point name; the body of `decl`
is then rewritten with that mapping as context.
-/
partial def replace (decl : Decl) (state : FindState) : CompilerM Decl := do
  -- Assign a fresh join point name to every remaining candidate.
  let assignName := fun ctx candidate _ => do return ctx.insert candidate (← mkFreshJpName)
  let replaceCtx : ReplaceCtx ← state.candidates.foldM (init := .empty) assignName
  let rewritten ← go decl.value |>.run replaceCtx
  return { decl with value := rewritten }
where
  go (code : Code) : ReplaceM Code := do
    match code with
    | .jmp .. | .return .. | .unreach .. =>
      return code
    | .cases cs =>
      let alts ← cs.alts.mapM (·.mapCodeM go)
      return Code.updateCases! code cs.resultType cs.discr alts
    | .jp decl k =>
      let newDecl ← decl.updateValue (← go decl.value)
      return Code.updateFun! code newDecl (← go k)
    | .fun decl k =>
      match (← read).find? decl.fvarId with
      | some replacement =>
        -- This `fun` is a join point: rename it and emit a `jp` instead.
        let newValue ← go decl.value
        let jpDecl := { decl with binderName := replacement, value := newValue }
        modifyLCtx fun lctx => lctx.addFunDecl jpDecl
        return .jp jpDecl (← go k)
      | none =>
        let newDecl ← decl.updateValue (← go decl.value)
        return Code.updateFun! code newDecl (← go k)
    | .let decl k =>
      match k, decl.value, decl.value.getAppFn with
      | .return valId, .app .., (.fvar fvarId) =>
        -- `let r := f args; r` with `f` a join point becomes `jmp f args`.
        if valId == decl.fvarId && (← read).contains fvarId then
          eraseLetDecl decl
          return .jmp fvarId decl.value.getAppArgs
        else
          return code
      | _, _, _ => return Code.updateLet! code decl (← go k)
end JoinPointFinder
namespace JoinPointContextExtender
open ScopeM
/--
The context managed by `ExtendM`.
-/
structure ExtendContext where
/--
The `FVarId` of the current join point if we are currently inside one.
-/
currentJp? : Option FVarId := none
/--
The set of valid candidates for extending the context. This will be
all `let` and `fun` declarations as well as all `jp` parameters up
until the last `fun` declaration in the tree. Candidates are only
collected while inside a join point, and the set is cleared again
whenever a `fun` declaration is entered.
-/
candidates : FVarIdSet := {}
/--
The state managed by `ExtendM`.
-/
structure ExtendState where
/--
A map from join point `FVarId`s to a respective map from free variables
to `Param`s. The free variables in this map are the ones that the context
of said join point will be extended by, by passing in the respective parameter.
-/
fvarMap : HashMap FVarId (HashMap FVarId Param) := {}
/--
The monad for the `extendJoinPointContext` pass: `ScopeM` extended with a
read-only `ExtendContext` and a mutable `ExtendState`.
-/
abbrev ExtendM := ReaderT ExtendContext StateRefT ExtendState ScopeM
/--
Return the free variable that should be used in place of `fvar`.
A replacement is only made when all of the following hold:
- `fvar` is in the set of candidates
- we are currently within a join point (if we are within a function there
  cannot be a need to replace them since we don't extend their context)
- said join point actually has a replacement parameter registered for `fvar`.

In every other case `fvar` itself is returned.
-/
def replaceFVar (fvar : FVarId) : ExtendM FVarId := do
  unless (← read).candidates.contains fvar do
    return fvar
  let some currentJp := (← read).currentJp? | return fvar
  let translator := (← get).fvarMap.find! currentJp
  match translator.find? fvar with
  | some replacement => return replacement.fvarId
  | none => return fvar
/--
Add `fvar` to the current scope and, if we are currently within a join
point, also to the set of candidates. Then execute `x`.
-/
def withNewCandidate (fvar : FVarId) (x : ExtendM α) : ExtendM α := do
  addToScope fvar
  match (← read).currentJp? with
  | some _ =>
    withReader (fun ctx => { ctx with candidates := ctx.candidates.insert fvar }) x
  | none =>
    x
/--
Variant of `withNewCandidate` for several `FVarId`s at once. When we are
within a join point every element of `fvars` is added to the scope and to the
set of candidates before `x` runs; otherwise `x` runs unchanged.
-/
def withNewCandidates (fvars : Array FVarId) (x : ExtendM α) : ExtendM α := do
  if (← read).currentJp?.isNone then
    x
  else
    let mut extended := (← read).candidates
    for fvar in fvars do
      addToScope fvar
      extended := extended.insert fvar
    withReader (fun ctx => { ctx with candidates := extended }) x
/--
Extend the context of the current join point (if we are within one)
by `fvar` if necessary.
This is necessary if:
- `fvar` is not in scope (that is, was declared outside of the current jp)
- we have not already extended the context by `fvar`
- the set of candidates contains `fvar`. This is because if we have something
  like:
  ```
  let x := ..
  fun f a =>
    jp j b =>
      let y := x
      y
  ```
  There is no point in extending the context of `j` by `x` because we
  cannot lift a join point outside of a local function declaration.

When an extension happens, a fresh auxiliary parameter with the type of
`fvar` is registered for the current join point.
-/
def extendByIfNecessary (fvar : FVarId) : ExtendM Unit := do
  let some currentJp := (← read).currentJp? | return ()
  let translator := (← get).fvarMap.find! currentJp
  let isCandidate := (← read).candidates.contains fvar
  if !(← isInScope fvar) && !translator.contains fvar && isCandidate then
    let newParam ← mkAuxParam (← getType fvar)
    let extended := translator.insert fvar newParam
    modify fun s => { s with fvarMap := s.fvarMap.insert currentJp extended }
/--
Merge the extended context of two join points if necessary. That is
if we have a structure such as:
```
jp j.1 ... =>
  jp j.2 .. =>
    ...
  ...
```
And we are just done visiting `j.2` we want to extend the context of
`j.1` by all free variables that the context of `j.2` was extended by
as well because we need to drag these variables through at the call sites
of `j.2` in `j.1`.

Does nothing when we are not currently within a join point.
-/
def mergeJpContextIfNecessary (jp : FVarId) : ExtendM Unit := do
  if (← read).currentJp?.isNone then
    return ()
  -- Snapshot the variables `jp` was extended by before touching the state.
  let inherited := (← get).fvarMap.find! jp
  for entry in inherited.toArray do
    extendByIfNecessary entry.fst
/--
We call this whenever we enter a new local function. It clears both the
current join point and the set of candidates since we can't lift join
points outside of functions as explained in `extendByIfNecessary`, and
runs `x` in a new scope. The `decl` argument is not inspected.
-/
def withNewFunScope (decl : FunDecl) (x : ExtendM α): ExtendM α := do
  let resetCtx := fun (ctx : ExtendContext) => { ctx with currentJp? := none, candidates := {} }
  withReader resetCtx (withNewScope x)
/--
We call this whenever we enter a new join point. It registers an empty
extension map for the join point, sets it as the current join point and
extends the set of candidates by all of its parameters. This is so in the
case of nested join points that refer to parameters of the current one we
extend the context of the nested join points by said parameters.
-/
def withNewJpScope (decl : FunDecl) (x : ExtendM α): ExtendM α := do
  let paramIds := decl.params.map fun param => param.fvarId
  withReader (fun ctx => { ctx with currentJp? := some decl.fvarId }) do
    -- Start with an empty set of additional parameters for this join point.
    modify fun s => { s with fvarMap := s.fvarMap.insert decl.fvarId {} }
    withNewScope (withNewCandidates paramIds x)
/--
We call this whenever we visit a new arm of a cases statement.
It runs `x` in a backtracking scope (since we are doing a case split
and want to continue with other arms afterwards) with all of the
parameters of the match arm added as candidates.
-/
def withNewAltScope (alt : Alt) (x : ExtendM α) : ExtendM α := do
  let paramIds := alt.getParams.map fun param => param.fvarId
  withBackTrackingScope (withNewCandidates paramIds x)
/--
Use all of the above functions to find free variables declared outside
of join points that said join points can be reasonably extended by. Reasonable
meaning that in case the current join point is nested within a function
declaration we will not extend it by free variables declared before the
function declaration because we cannot lift join points outside of function
declarations.

All of this is done to eliminate dependencies of join points onto their
position within the code so we can pull them out as far as possible, hopefully
enabling new inlining possibilities in the next simplifier run.

The rewritten declaration is finally passed through `pullFunDecls`.
-/
partial def extend (decl : Decl) : CompilerM Decl := do
let newValue ← go decl.value |>.run {} |>.run' {} |>.run' {}
let decl := { decl with value := newValue }
decl.pullFunDecls
where
-- Visit every free variable of `e`: extend the current join point by it if
-- necessary, then substitute its replacement (if any).
goExpr (e : Expr) : ExtendM Expr :=
let visitor := fun fvar => do
extendByIfNecessary fvar
replaceFVar fvar
mapFVarM visitor e
go (code : Code) : ExtendM Code := do
match code with
| .let decl k =>
let decl ← decl.updateValue (← goExpr decl.value)
withNewCandidate decl.fvarId do
return Code.updateLet! code decl (← go k)
| .jp decl k =>
let decl ← withNewJpScope decl do
let value ← go decl.value
-- The parameters this join point was extended by are prepended to its own.
let additionalParams := (← get).fvarMap.find! decl.fvarId |>.toArray |>.map Prod.snd
let newType := additionalParams.foldr (init := decl.type) (fun val acc => .forallE val.binderName val.type acc .default)
decl.update newType (additionalParams ++ decl.params) value
mergeJpContextIfNecessary decl.fvarId
withNewCandidate decl.fvarId do
return Code.updateFun! code decl (← go k)
| .fun decl k =>
let decl ← withNewFunScope decl do
decl.updateValue (← go decl.value)
withNewCandidate decl.fvarId do
return Code.updateFun! code decl (← go k)
| .cases cs =>
extendByIfNecessary cs.discr
let discr ← replaceFVar cs.discr
let visitor := fun alt => do
withNewAltScope alt do
alt.mapCodeM go
let alts ← cs.alts.mapM visitor
return Code.updateCases! code cs.resultType discr alts
| .jmp fn args =>
let mut newArgs ← args.mapM goExpr
-- The free variables `fn` was extended by are passed in front of the original arguments.
let additionalArgs := (← get).fvarMap.find! fn |>.toArray |>.map Prod.fst
if let some currentJp := (← read).currentJp? then
-- Inside a join point, use the current join point's replacement parameter where one exists.
let translator := (← get).fvarMap.find! currentJp
let f := fun arg =>
if let some translated := translator.find? arg then
.fvar translated.fvarId
else
.fvar arg
newArgs := (additionalArgs.map f) ++ newArgs
else
newArgs := (additionalArgs.map .fvar) ++ newArgs
return Code.updateJmp! code fn newArgs
| .return var =>
extendByIfNecessary var
return Code.updateReturn! code (← replaceFVar var)
| .unreach .. => return code
end JoinPointContextExtender
/--
Find all `fun` declarations in `decl` that qualify as join points then replace
their definitions and call sites with `jp`/`jmp`.
The number of candidates found is reported on the `Compiler.findJoinPoints`
trace class.
-/
def Decl.findJoinPoints (decl : Decl) : CompilerM Decl := do
let findResult ← JoinPointFinder.find decl
trace[Compiler.findJoinPoints] "Found: {findResult.candidates.size} jp candidates"
JoinPointFinder.replace decl findResult
/--
The `findJoinPoints` pass: runs `Decl.findJoinPoints` on each declaration,
registered with phase `.base`.
-/
def findJoinPoints : Pass :=
.mkPerDeclaration `findJoinPoints Decl.findJoinPoints .base
-- Register the `Compiler.findJoinPoints` trace class used by `Decl.findJoinPoints`.
builtin_initialize
registerTraceClass `Compiler.findJoinPoints (inherited := true)
/--
Run `JoinPointContextExtender.extend` on `decl`, see its documentation for details.
-/
def Decl.extendJoinPointContext (decl : Decl) : CompilerM Decl := do
JoinPointContextExtender.extend decl
/--
The `extendJoinPointContext` pass: runs `Decl.extendJoinPointContext` on each
declaration, registered with phase `.mono`.
-/
def extendJoinPointContext : Pass :=
.mkPerDeclaration `extendJoinPointContext Decl.extendJoinPointContext .mono
-- Register the `Compiler.extendJoinPointContext` trace class.
builtin_initialize
registerTraceClass `Compiler.extendJoinPointContext (inherited := true)
end Lean.Compiler.LCNF