chore: remove IR elim dead branches (#11576)

This PR removes the old ElimDeadBranches pass and shifts the new one
past lambda lifting.

The reason for dropping the old one is its general unsoundness and the
fact that we want to do refactorings on the IR part. The reason for
shifting the current pass past lambda lifting, is that its analysis is
imprecise in the presence of local function symbols. I experimented with
the exact placement for a while and it seems like it is optimal here.
Overall we observe a slight regression in the amount of C code
generated, likely because we don't propagate information into lambdas
before lifting them anymore. But generally measure a slight performance
improvement in general.
This commit is contained in:
Henrik Böving 2025-12-11 11:39:02 +01:00 committed by GitHub
parent e7f4fc9baf
commit b8c53b1d29
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 18 additions and 352 deletions

View file

@ -21,7 +21,6 @@ public import Lean.Compiler.IR.Boxing
public import Lean.Compiler.IR.RC
public import Lean.Compiler.IR.ExpandResetReuse
public import Lean.Compiler.IR.UnboxResult
public import Lean.Compiler.IR.ElimDeadBranches
public import Lean.Compiler.IR.EmitC
public import Lean.Compiler.IR.Sorry
public import Lean.Compiler.IR.ToIR
@ -46,8 +45,7 @@ register_builtin_option compiler.reuse : Bool := {
def compile (decls : Array Decl) : CompilerM (Array Decl) := do
logDecls `init decls
checkDecls decls
let mut decls ← elimDeadBranches decls
logDecls `elim_dead_branches decls
let mut decls := decls
decls := decls.map Decl.pushProj
logDecls `push_proj decls
if compiler.reuse.get (← getOptions) then

View file

@ -1,344 +0,0 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Leonardo de Moura
-/
module
prelude
public import Lean.Compiler.IR.CompilerM
public section
namespace Lean.IR.UnreachableBranches
/-- Value used in the abstract interpreter -/
inductive Value where
| bot -- undefined
| top -- any value
| ctor (i : CtorInfo) (vs : Array Value)
| choice (vs : List Value)
deriving Inhabited, BEq, Repr
protected partial def Value.toFormat : Value → Format
| Value.bot => "⊥"
| Value.top => ""
| Value.ctor info vs =>
if vs.isEmpty then
format info.name
else
Format.paren <| format info.name ++ Std.Format.join (vs.toList.map fun v => " " ++ Value.toFormat v)
| Value.choice vs =>
Format.paren <| Std.Format.joinSep (vs.map Value.toFormat) " | "
instance : ToFormat Value where
format := Value.toFormat
instance : ToString Value where
toString v := toString (format v)
namespace Value
partial def addChoice (merge : Value → Value → Value) : List Value → Value → List Value
| [], v => [v]
| v₁@(ctor i₁ _) :: cs, v₂@(ctor i₂ _) =>
if i₁ == i₂ then merge v₁ v₂ :: cs
else v₁ :: addChoice merge cs v₂
| _, _ => panic! "invalid addChoice"
partial def merge (v₁ v₂ : Value) : Value :=
match v₁, v₂ with
| bot, v => v
| v, bot => v
| top, _ => top
| _, top => top
| v₁@(ctor i₁ vs₁), v₂@(ctor i₂ vs₂) =>
if i₁ == i₂ then ctor i₁ <| vs₁.size.fold (init := #[]) fun i _ r => r.push (merge vs₁[i] vs₂[i]!)
else choice [v₁, v₂]
| choice vs₁, choice vs₂ => choice <| vs₁.foldl (addChoice merge) vs₂
| choice vs, v => choice <| addChoice merge vs v
| v, choice vs => choice <| addChoice merge vs v
/--
In `truncate`, we approximate a value as `top` if depth > `truncateMaxDepth`.
TODO: add option to control this parameter.
-/
def truncateMaxDepth := 8
/--
Make sure constructors of recursive inductive datatypes can only occur once in each path.
Values at depth > truncateMaxDepth are also approximated at `top`.
We use this function this function to implement a simple widening operation for our abstract
interpreter.
Recall the widening functions is used to ensure termination in abstract interpreters.
-/
partial def truncate (env : Environment) (v : Value) (s : NameSet) : Value :=
go v s truncateMaxDepth
where
go (v : Value) (s : NameSet) (depth : Nat) : Value :=
match depth with
| 0 => top
| depth+1 =>
match v, s with
| ctor i vs, found =>
let I := i.name.getPrefix
if found.contains I then
top
else
let cont (found' : NameSet) : Value :=
ctor i (vs.map fun v => go v found' depth)
match env.find? I with
| some (ConstantInfo.inductInfo d) =>
if d.isRec then cont (found.insert I)
else cont found
| _ => cont found
| choice vs, found =>
let newVs := vs.map fun v => go v found depth
if newVs.elem top then top
else choice newVs
| v, _ => v
/-- Widening operator that guarantees termination in our abstract interpreter. -/
def widening (env : Environment) (v₁ v₂ : Value) : Value :=
truncate env (merge v₁ v₂) {}
end Value
abbrev FunctionSummaries := PHashMap FunId Value
private abbrev declLt (a b : FunId × Value) :=
Name.quickLt a.1 b.1
private abbrev sortEntries (entries : Array (FunId × Value)) : Array (FunId × Value) :=
entries.qsort declLt
private abbrev findAtSorted? (entries : Array (FunId × Value)) (fid : FunId) : Option Value :=
if let some (_, value) := entries.binSearch (fid, default) declLt then
some value
else
none
builtin_initialize functionSummariesExt : SimplePersistentEnvExtension (FunId × Value) FunctionSummaries ←
registerSimplePersistentEnvExtension {
addImportedFn := fun _ => {}
addEntryFn := fun s ⟨e, n⟩ => s.insert e n
exportEntriesFnEx? := some fun _ s _ => fun
-- preserved for non-modules, make non-persistent at some point?
| .private => sortEntries s.toArray
| _ => #[]
asyncMode := .sync -- compilation is non-parallel anyway
replay? := some <| SimplePersistentEnvExtension.replayOfFilter (!·.contains ·.1) (fun s ⟨e, n⟩ => s.insert e n)
}
def addFunctionSummary (env : Environment) (fid : FunId) (v : Value) : Environment :=
functionSummariesExt.addEntry env (fid, v)
def getFunctionSummary? (env : Environment) (fid : FunId) : Option Value :=
match env.getModuleIdxFor? fid with
| some modIdx => findAtSorted? (functionSummariesExt.getModuleEntries env modIdx) fid
| none => functionSummariesExt.getState env |>.find? fid
abbrev Assignment := Std.HashMap VarId Value
structure InterpContext where
currFnIdx : Nat := 0
decls : Array Decl
env : Environment
lctx : LocalContext := {}
structure InterpState where
assignments : Array Assignment
funVals : PArray Value -- we take snapshots during fixpoint computations
visitedJps : Array (Std.HashSet JoinPointId)
abbrev M := ReaderT InterpContext (StateM InterpState)
open Value
def findVarValue (x : VarId) : M Value := do
let ctx ← read
let s ← get
let assignment := s.assignments[ctx.currFnIdx]!
return assignment.getD x bot
def findArgValue (arg : Arg) : M Value :=
match arg with
| .var x => findVarValue x
| .erased => pure top
def updateVarAssignment (x : VarId) (v : Value) : M Unit := do
let v' ← findVarValue x
let ctx ← read
modify fun s => { s with assignments := s.assignments.modify ctx.currFnIdx fun a => a.insert x (merge v v') }
def resetVarAssignment (x : VarId) : M Unit := do
let ctx ← read
modify fun s => { s with assignments := s.assignments.modify ctx.currFnIdx fun a => a.insert x Value.bot }
def resetParamAssignment (y : Param) : M Unit :=
resetVarAssignment y.x
partial def projValue : Value → Nat → Value
| ctor _ vs, i => vs.getD i bot
| choice vs, i => vs.foldl (fun r v => merge r (projValue v i)) bot
| v, _ => v
def interpExpr : Expr → M Value
| Expr.ctor i ys => return ctor i (← ys.mapM fun y => findArgValue y)
| Expr.proj i x => return projValue (← findVarValue x) i
| Expr.fap fid _ => do
let ctx ← read
match getFunctionSummary? ctx.env fid with
| some v => pure v
| none => do
let s ← get
match ctx.decls.findIdx? (fun decl => decl.name == fid) with
| some idx => pure s.funVals[idx]!
| none => pure top
| _ => pure top
partial def containsCtor : Value → CtorInfo → Bool
| top, _ => true
| ctor i _, j => i == j
| choice vs, j => vs.any fun v => containsCtor v j
| _, _ => false
def updateCurrFnSummary (v : Value) : M Unit := do
let ctx ← read
let currFnIdx := ctx.currFnIdx
modify fun s => { s with funVals := s.funVals.modify currFnIdx (fun v' => widening ctx.env v v') }
def markJPVisited (j : JoinPointId) : M Bool := do
let currFnIdx := (← read).currFnIdx
modifyGet fun s =>
⟨!(s.visitedJps[currFnIdx]!.contains j),
{ s with visitedJps := s.visitedJps.modify currFnIdx fun a => a.insert j }⟩
/-- Return true if the assignment of at least one parameter has been updated. -/
def updateJPParamsAssignment (j : JoinPointId) (ys : Array Param) (xs : Array Arg) : M Bool := do
let ctx ← read
let currFnIdx := ctx.currFnIdx
let isFirstVisit ← markJPVisited j
ys.size.foldM (init := isFirstVisit) fun i _ r => do
let y := ys[i]
let x := xs[i]!
let yVal ← findVarValue y.x
let xVal ← findArgValue x
let newVal := merge yVal xVal
if newVal == yVal then
pure r
else
modify fun s => { s with assignments := s.assignments.modify currFnIdx fun a => a.insert y.x newVal }
pure true
private partial def resetNestedJPParams : FnBody → M Unit
| FnBody.jdecl _ ys _ k => do
ys.forM resetParamAssignment
/- Remark we don't need to reset the parameters of join points
nested in `b` since they will be reset if this JP is used. -/
resetNestedJPParams k
| FnBody.case _ _ _ alts =>
alts.forM fun alt => match alt with
| Alt.ctor _ b => resetNestedJPParams b
| Alt.default b => resetNestedJPParams b
| e => do unless e.isTerminal do resetNestedJPParams e.body
partial def interpFnBody : FnBody → M Unit
| FnBody.vdecl x _ e b => do
let v ← interpExpr e
updateVarAssignment x v
interpFnBody b
| FnBody.jdecl j ys v b =>
withReader (fun ctx => { ctx with lctx := ctx.lctx.addJP j ys v }) do
interpFnBody b
| FnBody.case _ x _ alts => do
let v ← findVarValue x
alts.forM fun alt => do
match alt with
| Alt.ctor i b => if containsCtor v i then interpFnBody b
| Alt.default b => interpFnBody b
| FnBody.ret x => do
let v ← findArgValue x
updateCurrFnSummary v
| FnBody.jmp j xs => do
let ctx ← read
let ys := (ctx.lctx.getJPParams j).get!
let b := (ctx.lctx.getJPBody j).get!
let updated ← updateJPParamsAssignment j ys xs
if updated then
-- We must reset the value of nested join-point parameters since they depend on `ys` values
resetNestedJPParams b
interpFnBody b
| e => do
unless e.isTerminal do
interpFnBody e.body
def inferStep : M Bool := do
let ctx ← read
modify fun s => { s with assignments := ctx.decls.map fun _ => {},
visitedJps := ctx.decls.map fun _ => {} }
ctx.decls.size.foldM (init := false) fun idx _ modified => do
match ctx.decls[idx] with
| .fdecl (xs := ys) (body := b) .. => do
let s ← get
let currVals := s.funVals[idx]!
withReader (fun ctx => { ctx with currFnIdx := idx }) do
ys.forM fun y => updateVarAssignment y.x top
interpFnBody b
let s ← get
let newVals := s.funVals[idx]!
pure (modified || currVals != newVals)
| .extern .. => do
let currVals := (← get).funVals[idx]!
updateCurrFnSummary .top
let newVals := (← get).funVals[idx]!
pure (modified || currVals != newVals)
partial def inferMain : M Unit := do
let modified ← inferStep
if modified then inferMain else pure ()
partial def elimDeadAux (assignment : Assignment) : FnBody → FnBody
| FnBody.vdecl x t e b => FnBody.vdecl x t e (elimDeadAux assignment b)
| FnBody.jdecl j ys v b => FnBody.jdecl j ys (elimDeadAux assignment v) (elimDeadAux assignment b)
| FnBody.case tid x xType alts =>
let v := assignment.getD x bot
let alts := alts.map fun alt =>
match alt with
| Alt.ctor i b => Alt.ctor i <| if containsCtor v i then elimDeadAux assignment b else FnBody.unreachable
| Alt.default b => Alt.default (elimDeadAux assignment b)
FnBody.case tid x xType alts
| e =>
if e.isTerminal then e
else
let (instr, b) := e.split
let b := elimDeadAux assignment b
instr.setBody b
partial def elimDead (assignment : Assignment) (d : Decl) : Decl :=
match d with
| .fdecl (body := b) .. => d.updateBody! <| elimDeadAux assignment b
| other => other
end UnreachableBranches
open UnreachableBranches
def elimDeadBranches (decls : Array Decl) : CompilerM (Array Decl) := do
let env ← getEnv
let assignments : Array Assignment := decls.map fun _ => {}
let funVals := mkPArray decls.size Value.bot
let visitedJps := decls.map fun _ => {}
let ctx : InterpContext := { decls := decls, env := env }
let s : InterpState := { assignments, funVals, visitedJps }
let (_, s) := (inferMain ctx).run s
let funVals := s.funVals
let assignments := s.assignments
modifyEnv fun env =>
decls.size.fold (init := env) fun i _ env =>
addFunctionSummary env decls[i].name funVals[i]!
return decls.mapIdx fun i decl => elimDead assignments[i]! decl
builtin_initialize registerTraceClass `compiler.ir.elim_dead_branches (inherited := true)
end Lean.IR

View file

@ -113,10 +113,10 @@ def builtinPassManager : PassManager := {
commonJoinPointArgs,
simp (occurrence := 4) (phase := .mono),
floatLetIn (phase := .mono) (occurrence := 2),
elimDeadBranches,
lambdaLifting,
extendJoinPointContext (phase := .mono) (occurrence := 1),
simp (occurrence := 5) (phase := .mono),
elimDeadBranches,
cse (occurrence := 2) (phase := .mono),
saveMono, -- End of mono phase
inferVisibility (phase := .mono),

View file

@ -85,6 +85,18 @@ def getDeclInfo? (declName : Name) : CoreM (Option ConstantInfo) := do
let env ← getEnv
return env.find? (mkUnsafeRecName declName) <|> env.find? declName
def declIsNotUnsafe (declName : Name) : CoreM Bool := do
let env ← getEnv
let some info := env.find? declName | return true
if info.isUnsafe then
return false
else
if info matches .opaqueInfo .. then
-- check if its a partial def
return env.find? (Compiler.mkUnsafeRecName declName) |>.isNone
else
return true
/--
Convert the given declaration from the Lean environment into `Decl`.
The steps for this are roughly:
@ -97,7 +109,7 @@ The steps for this are roughly:
def toDecl (declName : Name) : CompilerM Decl := do
let declName := if let some name := isUnsafeRecName? declName then name else declName
let some info ← getDeclInfo? declName | throwError "declaration `{.ofConstName declName}` not found"
let safe := !info.isPartial && !info.isUnsafe
let safe ← declIsNotUnsafe declName
let env ← getEnv
let inlineAttr? := getInlineAttribute? env declName
let paramsFromTypeBinders (expr : Expr) : CompilerM (Array Param) := do

View file

@ -6,13 +6,13 @@ class Semiring (α : Type u) where
/--
trace: [Compiler.IR] [result]
def instSemiringUInt8 : obj :=
let x_1 : obj := pap instSemiringUInt8._lam_0._boxed;
ret x_1
def instSemiringUInt8._lam_0 (x_1 : @& tobj) (x_2 : u8) : u8 :=
let x_3 : u8 := UInt8.ofNat x_1;
let x_4 : u8 := UInt8.mul x_3 x_2;
ret x_4
def instSemiringUInt8 : obj :=
let x_1 : obj := pap instSemiringUInt8._lam_0._boxed;
ret x_1
def instSemiringUInt8._lam_0._boxed (x_1 : tobj) (x_2 : tagged) : tagged :=
let x_3 : u8 := unbox x_2;
let x_4 : u8 := instSemiringUInt8._lam_0 x_1 x_3;