chore: port join point optimizations to LetExpr

This commit is contained in:
Henrik Böving 2022-10-31 20:43:08 +01:00 committed by Leonardo de Moura
parent 695f972ff2
commit c5a99bda2b

View file

@ -10,9 +10,6 @@ import Lean.Compiler.LCNF.FVarUtil
import Lean.Compiler.LCNF.ScopeM
import Lean.Compiler.LCNF.InferType
set_option warningAsError false
#exit
namespace Lean.Compiler.LCNF
namespace JoinPointFinder
@ -75,9 +72,15 @@ private def modifyCandidates (f : HashMap FVarId CandidateInfo → HashMap FVarI
modify (fun state => {state with candidates := f state.candidates })
/--
Remove all join point candidates contained in `e`.
Remove all join point candidates contained in `a`.
-/
private partial def removeCandidatesContainedIn (e : Expr) : FindM Unit := do
private partial def removeCandidatesInArg (a : Arg) : FindM Unit := do
forFVarM eraseCandidate a
/--
Remove all join point candidates contained in `a`.
-/
private partial def removeCandidatesInLetExpr (e : LetExpr) : FindM Unit := do
forFVarM eraseCandidate e
/--
@ -134,12 +137,12 @@ partial def find (decl : Decl) : CompilerM FindState := do
where
go : Code → FindM Unit
| .let decl k => do
match k, decl.value, decl.value.getAppFn with
| .return valId, .app .., .fvar fvarId =>
decl.value.getAppArgs.forM removeCandidatesContainedIn
match k, decl.value with
| .return valId, .fvar fvarId args =>
args.forM removeCandidatesInArg
if let some candidateInfo ← findCandidate? fvarId then
-- Erase candidate that are not fully applied or applied outside of tail position
if valId != decl.fvarId || decl.value.getAppNumArgs != candidateInfo.arity then
if valId != decl.fvarId || args.size != candidateInfo.arity then
eraseCandidate fvarId
-- Out of scope join point candidate handling
else if let some upperCandidate ← read then
@ -147,8 +150,8 @@ where
addDependency fvarId upperCandidate
else
eraseCandidate fvarId
| _, _, _ =>
removeCandidatesContainedIn decl.value
| _, _ =>
removeCandidatesInLetExpr decl.value
go k
| .fun decl k => do
withReader (fun _ => some decl.fvarId) do
@ -160,7 +163,7 @@ where
| .jp decl k => do
go decl.value
go k
| .jmp _ args => args.forM removeCandidatesContainedIn
| .jmp _ args => args.forM removeCandidatesInArg
| .return val => eraseCandidate val
| .cases c => do
eraseCandidate c.discr
@ -180,17 +183,17 @@ where
go (code : Code) : ReplaceM Code := do
match code with
| .let decl k =>
match k, decl.value, decl.value.getAppFn with
| .return valId, .app .., (.fvar fvarId) =>
match k, decl.value with
| .return valId, .fvar fvarId args =>
if valId == decl.fvarId then
if (← read).contains fvarId then
eraseLetDecl decl
return .jmp fvarId decl.value.getAppArgs
return .jmp fvarId args
else
return code
else
return code
| _, _, _ => return Code.updateLet! code decl (← go k)
| _, _ => return Code.updateLet! code decl (← go k)
| .fun decl k =>
if let some replacement := (← read).find? decl.fvarId then
let newDecl := { decl with
@ -392,12 +395,10 @@ where
goFVar (fvar : FVarId) : ExtendM FVarId := do
extendByIfNecessary fvar
replaceFVar fvar
goExpr (e : Expr) : ExtendM Expr :=
mapFVarM goFVar e
go (code : Code) : ExtendM Code := do
match code with
| .let decl k =>
let decl ← decl.updateValue (← goExpr decl.value)
let decl ← decl.updateValue (← mapFVarM goFVar decl.value)
withNewCandidate decl.fvarId do
return Code.updateLet! code decl (← go k)
| .jp decl k =>
@ -423,7 +424,7 @@ where
let alts ← cs.alts.mapM visitor
return Code.updateCases! code cs.resultType discr alts
| .jmp fn args =>
let mut newArgs ← args.mapM goExpr
let mut newArgs ← args.mapM (mapFVarM goFVar)
let additionalArgs := (← get).fvarMap.find! fn |>.toArray |>.map Prod.fst
if let some _currentJp := (← read).currentJp? then
let f := fun arg => do
@ -456,7 +457,8 @@ State for `ReduceAnalysisM`.
-/
structure AnalysisState where
/--
Lists of names of arguments of jmps to join points to find duplicates.
A map, that for each join point id contains a map from all (so far)
duplicated argument ids to the respective duplicate value
-/
jpJmpArgs : FVarIdMap FVarSubst := {}
@ -543,13 +545,13 @@ where
let mut newArgs := knownArgs
for (param, arg) in decl.params.zip args do
if let some knownVal := newArgs.find? param.fvarId then
if arg != knownVal then
if arg.toExpr != knownVal then
newArgs := newArgs.erase param.fvarId
modify fun s => { s with jpJmpArgs := s.jpJmpArgs.insert fn newArgs }
else
let folder := fun acc (param, arg) => do
if (← allFVarM (isInJpScope fn) arg) then
return acc.insert param.fvarId arg
return acc.insert param.fvarId arg.toExpr
else
return acc
let interestingArgs ← decl.params.zip args |>.foldlM (init := {}) folder