fix: correctly handle join points with no params in the IR elim_dead_branches pass (#8015)

This PR fixes the IR elim_dead_branches pass to correctly handle join
points with no params, which currently get considered unreachable. I was
not able to find an easy repro of this with the old compiler, but it
occurs when bootstrapping Lean with the new compiler.
This commit is contained in:
Cameron Zwarich 2025-04-17 20:52:19 -07:00 committed by GitHub
parent 32fe2391b9
commit f163758bcf
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@ -165,6 +165,7 @@ structure InterpContext where
structure InterpState where
assignments : Array Assignment
funVals : PArray Value -- we take snapshots during fixpoint computations
visitedJps : Array (Std.HashSet JoinPointId)
abbrev M := ReaderT InterpContext (StateM InterpState)
@ -223,11 +224,18 @@ def updateCurrFnSummary (v : Value) : M Unit := do
let currFnIdx := ctx.currFnIdx
modify fun s => { s with funVals := s.funVals.modify currFnIdx (fun v' => widening ctx.env v v') }
def markJPVisited (j : JoinPointId) : M Bool := do
let currFnIdx := (← read).currFnIdx
modifyGet fun s =>
⟨!(s.visitedJps[currFnIdx]!.contains j),
{ s with visitedJps := s.visitedJps.modify currFnIdx fun a => a.insert j }⟩
/-- Return true if the assignment of at least one parameter has been updated. -/
def updateJPParamsAssignment (ys : Array Param) (xs : Array Arg) : M Bool := do
def updateJPParamsAssignment (j : JoinPointId) (ys : Array Param) (xs : Array Arg) : M Bool := do
let ctx ← read
let currFnIdx := ctx.currFnIdx
ys.size.foldM (init := false) fun i _ r => do
let isFirstVisit ← markJPVisited j
ys.size.foldM (init := isFirstVisit) fun i _ r => do
let y := ys[i]
let x := xs[i]!
let yVal ← findVarValue y.x
@ -272,7 +280,7 @@ partial def interpFnBody : FnBody → M Unit
let ctx ← read
let ys := (ctx.lctx.getJPParams j).get!
let b := (ctx.lctx.getJPBody j).get!
let updated ← updateJPParamsAssignment ys xs
let updated ← updateJPParamsAssignment j ys xs
if updated then
-- We must reset the value of nested join-point parameters since they depend on `ys` values
resetNestedJPParams b
@ -283,7 +291,8 @@ partial def interpFnBody : FnBody → M Unit
def inferStep : M Bool := do
let ctx ← read
modify fun s => { s with assignments := ctx.decls.map fun _ => {} }
modify fun s => { s with assignments := ctx.decls.map fun _ => {},
visitedJps := ctx.decls.map fun _ => {} }
ctx.decls.size.foldM (init := false) fun idx _ modified => do
match ctx.decls[idx] with
| .fdecl (xs := ys) (body := b) .. => do
@ -332,8 +341,9 @@ def elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do
let env := s.env
let assignments : Array Assignment := decls.map fun _ => {}
let funVals := mkPArray decls.size Value.bot
let visitedJps := decls.map fun _ => {}
let ctx : InterpContext := { decls := decls, env := env }
let s : InterpState := { assignments := assignments, funVals := funVals }
let s : InterpState := { assignments, funVals, visitedJps }
let (_, s) := (inferMain ctx).run s
let funVals := s.funVals
let assignments := s.assignments