diff --git a/src/Lean/Compiler/IR/ElimDeadBranches.lean b/src/Lean/Compiler/IR/ElimDeadBranches.lean index e74c3eb05c..4445eec7ae 100644 --- a/src/Lean/Compiler/IR/ElimDeadBranches.lean +++ b/src/Lean/Compiler/IR/ElimDeadBranches.lean @@ -165,6 +165,7 @@ structure InterpContext where structure InterpState where assignments : Array Assignment funVals : PArray Value -- we take snapshots during fixpoint computations + visitedJps : Array (Std.HashSet JoinPointId) abbrev M := ReaderT InterpContext (StateM InterpState) @@ -223,11 +224,18 @@ def updateCurrFnSummary (v : Value) : M Unit := do let currFnIdx := ctx.currFnIdx modify fun s => { s with funVals := s.funVals.modify currFnIdx (fun v' => widening ctx.env v v') } +def markJPVisited (j : JoinPointId) : M Bool := do + let currFnIdx := (← read).currFnIdx + modifyGet fun s => + ⟨!(s.visitedJps[currFnIdx]!.contains j), + { s with visitedJps := s.visitedJps.modify currFnIdx fun a => a.insert j }⟩ + /-- Return true if the assignment of at least one parameter has been updated. -/ -def updateJPParamsAssignment (ys : Array Param) (xs : Array Arg) : M Bool := do +def updateJPParamsAssignment (j : JoinPointId) (ys : Array Param) (xs : Array Arg) : M Bool := do let ctx ← read let currFnIdx := ctx.currFnIdx - ys.size.foldM (init := false) fun i _ r => do + let isFirstVisit ← markJPVisited j + ys.size.foldM (init := isFirstVisit) fun i _ r => do let y := ys[i] let x := xs[i]! let yVal ← findVarValue y.x @@ -272,7 +280,7 @@ partial def interpFnBody : FnBody → M Unit let ctx ← read let ys := (ctx.lctx.getJPParams j).get! let b := (ctx.lctx.getJPBody j).get! - let updated ← updateJPParamsAssignment ys xs + let updated ← updateJPParamsAssignment j ys xs if updated then -- We must reset the value of nested join-point parameters since they depend on `ys` values resetNestedJPParams b @@ -283,7 +291,8 @@ partial def interpFnBody : FnBody → M Unit def inferStep : M Bool := do let ctx ← read - modify fun s => { s with assignments := ctx.decls.map fun _ => {} } + modify fun s => { s with assignments := ctx.decls.map fun _ => {}, + visitedJps := ctx.decls.map fun _ => {} } ctx.decls.size.foldM (init := false) fun idx _ modified => do match ctx.decls[idx] with | .fdecl (xs := ys) (body := b) .. => do @@ -332,8 +341,9 @@ def elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do let env := s.env let assignments : Array Assignment := decls.map fun _ => {} let funVals := mkPArray decls.size Value.bot + let visitedJps := decls.map fun _ => {} let ctx : InterpContext := { decls := decls, env := env } - let s : InterpState := { assignments := assignments, funVals := funVals } + let s : InterpState := { assignments, funVals, visitedJps } let (_, s) := (inferMain ctx).run s let funVals := s.funVals let assignments := s.assignments