diff --git a/src/Lean/Compiler/LCNF/JoinPoints.lean b/src/Lean/Compiler/LCNF/JoinPoints.lean index d2288555af..76b31f6fd7 100644 --- a/src/Lean/Compiler/LCNF/JoinPoints.lean +++ b/src/Lean/Compiler/LCNF/JoinPoints.lean @@ -10,9 +10,6 @@ import Lean.Compiler.LCNF.FVarUtil import Lean.Compiler.LCNF.ScopeM import Lean.Compiler.LCNF.InferType -set_option warningAsError false -#exit - namespace Lean.Compiler.LCNF namespace JoinPointFinder @@ -75,9 +72,15 @@ private def modifyCandidates (f : HashMap FVarId CandidateInfo → HashMap FVarI modify (fun state => {state with candidates := f state.candidates }) /-- -Remove all join point candidates contained in `e`. +Remove all join point candidates contained in `a`. -/ -private partial def removeCandidatesContainedIn (e : Expr) : FindM Unit := do +private partial def removeCandidatesInArg (a : Arg) : FindM Unit := do + forFVarM eraseCandidate a + +/-- +Remove all join point candidates contained in `a`. +-/ +private partial def removeCandidatesInLetExpr (e : LetExpr) : FindM Unit := do forFVarM eraseCandidate e /-- @@ -134,12 +137,12 @@ partial def find (decl : Decl) : CompilerM FindState := do where go : Code → FindM Unit | .let decl k => do - match k, decl.value, decl.value.getAppFn with - | .return valId, .app .., .fvar fvarId => - decl.value.getAppArgs.forM removeCandidatesContainedIn + match k, decl.value with + | .return valId, .fvar fvarId args => + args.forM removeCandidatesInArg if let some candidateInfo ← findCandidate? fvarId then -- Erase candidate that are not fully applied or applied outside of tail position - if valId != decl.fvarId || decl.value.getAppNumArgs != candidateInfo.arity then + if valId != decl.fvarId || args.size != candidateInfo.arity then eraseCandidate fvarId -- Out of scope join point candidate handling else if let some upperCandidate ← read then @@ -147,8 +150,8 @@ where addDependency fvarId upperCandidate else eraseCandidate fvarId - | _, _, _ => - removeCandidatesContainedIn decl.value + | _, _ => + removeCandidatesInLetExpr decl.value go k | .fun decl k => do withReader (fun _ => some decl.fvarId) do @@ -160,7 +163,7 @@ where | .jp decl k => do go decl.value go k - | .jmp _ args => args.forM removeCandidatesContainedIn + | .jmp _ args => args.forM removeCandidatesInArg | .return val => eraseCandidate val | .cases c => do eraseCandidate c.discr @@ -180,17 +183,17 @@ where go (code : Code) : ReplaceM Code := do match code with | .let decl k => - match k, decl.value, decl.value.getAppFn with - | .return valId, .app .., (.fvar fvarId) => + match k, decl.value with + | .return valId, .fvar fvarId args => if valId == decl.fvarId then if (← read).contains fvarId then eraseLetDecl decl - return .jmp fvarId decl.value.getAppArgs + return .jmp fvarId args else return code else return code - | _, _, _ => return Code.updateLet! code decl (← go k) + | _, _ => return Code.updateLet! code decl (← go k) | .fun decl k => if let some replacement := (← read).find? decl.fvarId then let newDecl := { decl with @@ -392,12 +395,10 @@ where goFVar (fvar : FVarId) : ExtendM FVarId := do extendByIfNecessary fvar replaceFVar fvar - goExpr (e : Expr) : ExtendM Expr := - mapFVarM goFVar e go (code : Code) : ExtendM Code := do match code with | .let decl k => - let decl ← decl.updateValue (← goExpr decl.value) + let decl ← decl.updateValue (← mapFVarM goFVar decl.value) withNewCandidate decl.fvarId do return Code.updateLet! code decl (← go k) | .jp decl k => @@ -423,7 +424,7 @@ where let alts ← cs.alts.mapM visitor return Code.updateCases! code cs.resultType discr alts | .jmp fn args => - let mut newArgs ← args.mapM goExpr + let mut newArgs ← args.mapM (mapFVarM goFVar) let additionalArgs := (← get).fvarMap.find! fn |>.toArray |>.map Prod.fst if let some _currentJp := (← read).currentJp? then let f := fun arg => do @@ -456,7 +457,8 @@ State for `ReduceAnalysisM`. -/ structure AnalysisState where /-- - Lists of names of arguments of jmps to join points to find duplicates. + A map, that for each join point id contains a map from all (so far) + duplicated argument ids to the respective duplicate value -/ jpJmpArgs : FVarIdMap FVarSubst := {} @@ -543,13 +545,13 @@ where let mut newArgs := knownArgs for (param, arg) in decl.params.zip args do if let some knownVal := newArgs.find? param.fvarId then - if arg != knownVal then + if arg.toExpr != knownVal then newArgs := newArgs.erase param.fvarId modify fun s => { s with jpJmpArgs := s.jpJmpArgs.insert fn newArgs } else let folder := fun acc (param, arg) => do if (← allFVarM (isInJpScope fn) arg) then - return acc.insert param.fvarId arg + return acc.insert param.fvarId arg.toExpr else return acc let interestingArgs ← decl.params.zip args |>.foldlM (init := {}) folder