chore: port join point optimizations to LetExpr
This commit is contained in:
parent
695f972ff2
commit
c5a99bda2b
1 changed files with 25 additions and 23 deletions
|
|
@ -10,9 +10,6 @@ import Lean.Compiler.LCNF.FVarUtil
|
|||
import Lean.Compiler.LCNF.ScopeM
|
||||
import Lean.Compiler.LCNF.InferType
|
||||
|
||||
set_option warningAsError false
|
||||
#exit
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
|
||||
namespace JoinPointFinder
|
||||
|
|
@ -75,9 +72,15 @@ private def modifyCandidates (f : HashMap FVarId CandidateInfo → HashMap FVarI
|
|||
modify (fun state => {state with candidates := f state.candidates })
|
||||
|
||||
/--
|
||||
Remove all join point candidates contained in `e`.
|
||||
Remove all join point candidates contained in `a`.
|
||||
-/
|
||||
private partial def removeCandidatesContainedIn (e : Expr) : FindM Unit := do
|
||||
private partial def removeCandidatesInArg (a : Arg) : FindM Unit := do
|
||||
forFVarM eraseCandidate a
|
||||
|
||||
/--
|
||||
Remove all join point candidates contained in `a`.
|
||||
-/
|
||||
private partial def removeCandidatesInLetExpr (e : LetExpr) : FindM Unit := do
|
||||
forFVarM eraseCandidate e
|
||||
|
||||
/--
|
||||
|
|
@ -134,12 +137,12 @@ partial def find (decl : Decl) : CompilerM FindState := do
|
|||
where
|
||||
go : Code → FindM Unit
|
||||
| .let decl k => do
|
||||
match k, decl.value, decl.value.getAppFn with
|
||||
| .return valId, .app .., .fvar fvarId =>
|
||||
decl.value.getAppArgs.forM removeCandidatesContainedIn
|
||||
match k, decl.value with
|
||||
| .return valId, .fvar fvarId args =>
|
||||
args.forM removeCandidatesInArg
|
||||
if let some candidateInfo ← findCandidate? fvarId then
|
||||
-- Erase candidate that are not fully applied or applied outside of tail position
|
||||
if valId != decl.fvarId || decl.value.getAppNumArgs != candidateInfo.arity then
|
||||
if valId != decl.fvarId || args.size != candidateInfo.arity then
|
||||
eraseCandidate fvarId
|
||||
-- Out of scope join point candidate handling
|
||||
else if let some upperCandidate ← read then
|
||||
|
|
@ -147,8 +150,8 @@ where
|
|||
addDependency fvarId upperCandidate
|
||||
else
|
||||
eraseCandidate fvarId
|
||||
| _, _, _ =>
|
||||
removeCandidatesContainedIn decl.value
|
||||
| _, _ =>
|
||||
removeCandidatesInLetExpr decl.value
|
||||
go k
|
||||
| .fun decl k => do
|
||||
withReader (fun _ => some decl.fvarId) do
|
||||
|
|
@ -160,7 +163,7 @@ where
|
|||
| .jp decl k => do
|
||||
go decl.value
|
||||
go k
|
||||
| .jmp _ args => args.forM removeCandidatesContainedIn
|
||||
| .jmp _ args => args.forM removeCandidatesInArg
|
||||
| .return val => eraseCandidate val
|
||||
| .cases c => do
|
||||
eraseCandidate c.discr
|
||||
|
|
@ -180,17 +183,17 @@ where
|
|||
go (code : Code) : ReplaceM Code := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
match k, decl.value, decl.value.getAppFn with
|
||||
| .return valId, .app .., (.fvar fvarId) =>
|
||||
match k, decl.value with
|
||||
| .return valId, .fvar fvarId args =>
|
||||
if valId == decl.fvarId then
|
||||
if (← read).contains fvarId then
|
||||
eraseLetDecl decl
|
||||
return .jmp fvarId decl.value.getAppArgs
|
||||
return .jmp fvarId args
|
||||
else
|
||||
return code
|
||||
else
|
||||
return code
|
||||
| _, _, _ => return Code.updateLet! code decl (← go k)
|
||||
| _, _ => return Code.updateLet! code decl (← go k)
|
||||
| .fun decl k =>
|
||||
if let some replacement := (← read).find? decl.fvarId then
|
||||
let newDecl := { decl with
|
||||
|
|
@ -392,12 +395,10 @@ where
|
|||
goFVar (fvar : FVarId) : ExtendM FVarId := do
|
||||
extendByIfNecessary fvar
|
||||
replaceFVar fvar
|
||||
goExpr (e : Expr) : ExtendM Expr :=
|
||||
mapFVarM goFVar e
|
||||
go (code : Code) : ExtendM Code := do
|
||||
match code with
|
||||
| .let decl k =>
|
||||
let decl ← decl.updateValue (← goExpr decl.value)
|
||||
let decl ← decl.updateValue (← mapFVarM goFVar decl.value)
|
||||
withNewCandidate decl.fvarId do
|
||||
return Code.updateLet! code decl (← go k)
|
||||
| .jp decl k =>
|
||||
|
|
@ -423,7 +424,7 @@ where
|
|||
let alts ← cs.alts.mapM visitor
|
||||
return Code.updateCases! code cs.resultType discr alts
|
||||
| .jmp fn args =>
|
||||
let mut newArgs ← args.mapM goExpr
|
||||
let mut newArgs ← args.mapM (mapFVarM goFVar)
|
||||
let additionalArgs := (← get).fvarMap.find! fn |>.toArray |>.map Prod.fst
|
||||
if let some _currentJp := (← read).currentJp? then
|
||||
let f := fun arg => do
|
||||
|
|
@ -456,7 +457,8 @@ State for `ReduceAnalysisM`.
|
|||
-/
|
||||
structure AnalysisState where
|
||||
/--
|
||||
Lists of names of arguments of jmps to join points to find duplicates.
|
||||
A map, that for each join point id contains a map from all (so far)
|
||||
duplicated argument ids to the respective duplicate value
|
||||
-/
|
||||
jpJmpArgs : FVarIdMap FVarSubst := {}
|
||||
|
||||
|
|
@ -543,13 +545,13 @@ where
|
|||
let mut newArgs := knownArgs
|
||||
for (param, arg) in decl.params.zip args do
|
||||
if let some knownVal := newArgs.find? param.fvarId then
|
||||
if arg != knownVal then
|
||||
if arg.toExpr != knownVal then
|
||||
newArgs := newArgs.erase param.fvarId
|
||||
modify fun s => { s with jpJmpArgs := s.jpJmpArgs.insert fn newArgs }
|
||||
else
|
||||
let folder := fun acc (param, arg) => do
|
||||
if (← allFVarM (isInJpScope fn) arg) then
|
||||
return acc.insert param.fvarId arg
|
||||
return acc.insert param.fvarId arg.toExpr
|
||||
else
|
||||
return acc
|
||||
let interestingArgs ← decl.params.zip args |>.foldlM (init := {}) folder
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue