fix: correctly handle join points with no params in the IR elim_dead_branches pass (#8015)
This PR fixes the IR elim_dead_branches pass to correctly handle join points with no params, which currently get considered unreachable. I was not able to find an easy repro of this with the old compiler, but it occurs when bootstrapping Lean with the new compiler.
This commit is contained in:
parent
32fe2391b9
commit
f163758bcf
1 changed files with 15 additions and 5 deletions
|
|
@ -165,6 +165,7 @@ structure InterpContext where
|
|||
structure InterpState where
|
||||
assignments : Array Assignment
|
||||
funVals : PArray Value -- we take snapshots during fixpoint computations
|
||||
visitedJps : Array (Std.HashSet JoinPointId)
|
||||
|
||||
abbrev M := ReaderT InterpContext (StateM InterpState)
|
||||
|
||||
|
|
@ -223,11 +224,18 @@ def updateCurrFnSummary (v : Value) : M Unit := do
|
|||
let currFnIdx := ctx.currFnIdx
|
||||
modify fun s => { s with funVals := s.funVals.modify currFnIdx (fun v' => widening ctx.env v v') }
|
||||
|
||||
def markJPVisited (j : JoinPointId) : M Bool := do
|
||||
let currFnIdx := (← read).currFnIdx
|
||||
modifyGet fun s =>
|
||||
⟨!(s.visitedJps[currFnIdx]!.contains j),
|
||||
{ s with visitedJps := s.visitedJps.modify currFnIdx fun a => a.insert j }⟩
|
||||
|
||||
/-- Return true if the assignment of at least one parameter has been updated. -/
|
||||
def updateJPParamsAssignment (ys : Array Param) (xs : Array Arg) : M Bool := do
|
||||
def updateJPParamsAssignment (j : JoinPointId) (ys : Array Param) (xs : Array Arg) : M Bool := do
|
||||
let ctx ← read
|
||||
let currFnIdx := ctx.currFnIdx
|
||||
ys.size.foldM (init := false) fun i _ r => do
|
||||
let isFirstVisit ← markJPVisited j
|
||||
ys.size.foldM (init := isFirstVisit) fun i _ r => do
|
||||
let y := ys[i]
|
||||
let x := xs[i]!
|
||||
let yVal ← findVarValue y.x
|
||||
|
|
@ -272,7 +280,7 @@ partial def interpFnBody : FnBody → M Unit
|
|||
let ctx ← read
|
||||
let ys := (ctx.lctx.getJPParams j).get!
|
||||
let b := (ctx.lctx.getJPBody j).get!
|
||||
let updated ← updateJPParamsAssignment ys xs
|
||||
let updated ← updateJPParamsAssignment j ys xs
|
||||
if updated then
|
||||
-- We must reset the value of nested join-point parameters since they depend on `ys` values
|
||||
resetNestedJPParams b
|
||||
|
|
@ -283,7 +291,8 @@ partial def interpFnBody : FnBody → M Unit
|
|||
|
||||
def inferStep : M Bool := do
|
||||
let ctx ← read
|
||||
modify fun s => { s with assignments := ctx.decls.map fun _ => {} }
|
||||
modify fun s => { s with assignments := ctx.decls.map fun _ => {},
|
||||
visitedJps := ctx.decls.map fun _ => {} }
|
||||
ctx.decls.size.foldM (init := false) fun idx _ modified => do
|
||||
match ctx.decls[idx] with
|
||||
| .fdecl (xs := ys) (body := b) .. => do
|
||||
|
|
@ -332,8 +341,9 @@ def elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do
|
|||
let env := s.env
|
||||
let assignments : Array Assignment := decls.map fun _ => {}
|
||||
let funVals := mkPArray decls.size Value.bot
|
||||
let visitedJps := decls.map fun _ => {}
|
||||
let ctx : InterpContext := { decls := decls, env := env }
|
||||
let s : InterpState := { assignments := assignments, funVals := funVals }
|
||||
let s : InterpState := { assignments, funVals, visitedJps }
|
||||
let (_, s) := (inferMain ctx).run s
|
||||
let funVals := s.funVals
|
||||
let assignments := s.assignments
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue