refactor: port IR elim_dead_vars to LCNF (#12356)

This PR moves the IR elim_dead_vars pass to LCNF. It cannot delete the
pass yet as it is still used
in later IR passes.
This commit is contained in:
Henrik Böving 2026-02-06 18:01:59 +01:00 committed by GitHub
parent 85899ddd17
commit 32fb1ccf1c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 118 additions and 60 deletions

View file

@ -11,7 +11,6 @@ public import Lean.Compiler.IR.Basic
public import Lean.Compiler.IR.Format
public import Lean.Compiler.IR.CompilerM
public import Lean.Compiler.IR.PushProj
public import Lean.Compiler.IR.ElimDeadVars
public import Lean.Compiler.IR.SimpCase
public import Lean.Compiler.IR.NormIds
public import Lean.Compiler.IR.Checker
@ -41,8 +40,6 @@ def compile (decls : Array Decl) : CompilerM (Array Decl) := do
logDecls `init decls
checkDecls decls
let mut decls := decls
decls := decls.map Decl.elimDead
logDecls `elim_dead decls
decls := decls.map Decl.simpCase
logDecls `simp_case decls
decls := decls.map Decl.normalizeIds

View file

@ -6,67 +6,82 @@ Authors: Leonardo de Moura
module
prelude
public import Lean.Compiler.LCNF.CompilerM
public import Lean.Compiler.LCNF.PassManager
public section
/-!
This module implements a pass that does a syntactic use-def check for all let/fun/jp bindings and
removes them if they are unused. Note that in impure mode not all unused let bindings can be removed safely
so we opt for a safe subset.
-/
namespace Lean.Compiler.LCNF
abbrev UsedLocalDecls := FVarIdHashSet
public abbrev UsedLocalDecls := FVarIdHashSet
/--
Collect set of (let) free variables in a LCNF value.
This code exploits the LCNF property that local declarations do not occur in types.
-/
def collectLocalDeclsArg (s : UsedLocalDecls) (arg : Arg .pure) : UsedLocalDecls :=
def collectLocalDeclsArg (s : UsedLocalDecls) (arg : Arg pu) : UsedLocalDecls :=
match arg with
| .fvar fvarId => s.insert fvarId
-- Locally declared variables do not occur in types.
| .type _ | .erased => s
| .type _ _ | .erased => s
def collectLocalDeclsArgs (s : UsedLocalDecls) (args : Array (Arg .pure)) : UsedLocalDecls :=
def collectLocalDeclsArgs (s : UsedLocalDecls) (args : Array (Arg pu)) : UsedLocalDecls :=
args.foldl (init := s) collectLocalDeclsArg
def collectLocalDeclsLetValue (s : UsedLocalDecls) (e : LetValue .pure) : UsedLocalDecls :=
def collectLocalDeclsLetValue (s : UsedLocalDecls) (e : LetValue pu) : UsedLocalDecls :=
match e with
| .erased | .lit .. => s
| .proj _ _ fvarId => s.insert fvarId
| .const _ _ args => collectLocalDeclsArgs s args
| .fvar fvarId args => collectLocalDeclsArgs (s.insert fvarId) args
namespace ElimDead
| .proj _ _ fvarId _ | .reset _ fvarId _ | .sproj _ _ fvarId _ | .uproj _ fvarId _
| .oproj _ fvarId _ => s.insert fvarId
| .const _ _ args _ => collectLocalDeclsArgs s args
| .fvar fvarId args | .reuse fvarId _ _ args _ => collectLocalDeclsArgs (s.insert fvarId) args
| .fap _ args _ | .pap _ args _ | .ctor _ args _ => collectLocalDeclsArgs s args
abbrev M := StateRefT UsedLocalDecls CompilerM
private abbrev collectArgM (arg : Arg .pure) : M Unit :=
abbrev collectArgM (arg : Arg pu) : M Unit :=
modify (collectLocalDeclsArg · arg)
private abbrev collectLetValueM (e : LetValue .pure) : M Unit :=
abbrev collectLetValueM (e : LetValue pu) : M Unit :=
modify (collectLocalDeclsLetValue · e)
private abbrev collectFVarM (fvarId : FVarId) : M Unit :=
abbrev collectFVarM (fvarId : FVarId) : M Unit :=
modify (·.insert fvarId)
def LetValue.safeToElim (val : LetValue pu) : Bool :=
match pu with
| .pure => true
| .impure =>
match val with
| .ctor .. | .reset .. | .reuse .. | .oproj .. | .uproj .. | .sproj .. | .lit .. | .pap ..
-- TODO | .box .. | .unbox .. | .isShared ..
| .erased .. => true
-- 0-ary full applications are considered constants
| .fap _ args => args.isEmpty
| .fvar .. => false
mutual
partial def visitFunDecl (funDecl : FunDecl .pure) : M (FunDecl .pure) := do
let value ← elimDead funDecl.value
partial def visitFunDecl (funDecl : FunDecl pu) : M (FunDecl pu) := do
let value ← funDecl.value.elimDead
funDecl.updateValue value
partial def elimDead (code : Code .pure) : M (Code .pure) := do
partial def Code.elimDead (code : Code pu) : M (Code pu) := do
match code with
| .let decl k =>
let k ← elimDead k
if (← get).contains decl.fvarId then
let k ← k.elimDead
if (← get).contains decl.fvarId || !decl.value.safeToElim then
/- Remark: we don't need to collect `decl.type` because LCNF local declarations do not occur in types. -/
collectLetValueM decl.value
return code.updateCont! k
else
eraseLetDecl decl
return k
| .fun decl k | .jp decl k =>
let k ← elimDead k
| .fun decl k _ | .jp decl k =>
let k ← k.elimDead
if (← get).contains decl.fvarId then
let decl ← visitFunDecl decl
return code.updateFun! decl k
@ -74,22 +89,26 @@ partial def elimDead (code : Code .pure) : M (Code .pure) := do
eraseFunDecl decl
return k
| .cases c =>
let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← elimDead alt.getCode)
let alts ← c.alts.mapMonoM fun alt => return alt.updateCode (← alt.getCode.elimDead)
collectFVarM c.discr
return code.updateAlts! alts
| .return fvarId => collectFVarM fvarId; return code
| .jmp fvarId args => collectFVarM fvarId; args.forM collectArgM; return code
| .unreach .. => return code
| .uset var _ y k _ | .sset var _ _ y _ k _ =>
collectFVarM var
collectFVarM y
return code.updateCont! (← k.elimDead)
end
end ElimDead
def Decl.elimDead (decl : Decl pu) : CompilerM (Decl pu) := do
return { decl with value := (← decl.value.mapCodeM fun code => code.elimDead.run' {}) }
-- TODO: Generalize this to arbitrary phases, keep in mind that in impure elim dead is not as easy though
def Code.elimDead (code : Code .pure) : CompilerM (Code .pure) :=
ElimDead.elimDead code |>.run' {}
public def elimDeadVars (phase : Phase) (occurrence : Nat) : Pass :=
Pass.mkPerDeclaration `elimDeadVars phase Decl.elimDead occurrence
def Decl.elimDead (decl : Decl .pure) : CompilerM (Decl .pure) := do
return { decl with value := (← decl.value.mapCodeM Code.elimDead) }
builtin_initialize
registerTraceClass `Compiler.elimDeadVars (inherited := true)
end Lean.Compiler.LCNF

View file

@ -145,6 +145,7 @@ def builtinPassManager : PassManager := {
saveImpure, -- End of impure phase
pushProj (occurrence := 0),
insertResetReuse,
elimDeadVars (phase := .impure) (occurrence := 0),
inferVisibility (phase := .impure),
]
}

View file

@ -0,0 +1,44 @@
prelude
import Init.Data.Option
/-! This test asserts that the `elimDeadVars` pass is able to eliminate dead projections in the
impure phase correctly. -/
/--
trace: [Compiler.saveMono] size: 5
def isNone x : Bool :=
cases x : Bool
| Option.none =>
let _x.1 := true;
return _x.1
| Option.some val.2 =>
let _x.3 := false;
return _x.3
[Compiler.pushProj] size: 6
def isNone x : UInt8 :=
cases x : UInt8
| Option.none =>
let _x.1 := 1;
return _x.1
| Option.some =>
let val.2 := proj[0] x;
let _x.3 := 0;
return _x.3
[Compiler.elimDeadVars] size: 5
def isNone x : UInt8 :=
cases x : UInt8
| Option.none =>
let _x.1 := 1;
return _x.1
| Option.some =>
let _x.2 := 0;
return _x.2
-/
#guard_msgs in
set_option trace.Compiler.saveMono true in
set_option trace.Compiler.pushProj true in
set_option trace.Compiler.elimDeadVars true in
def isNone (x : Option Nat) : Bool :=
match x with
| some _ => false
| none => true

View file

@ -11,7 +11,7 @@ unsafe def tst1 : MetaM Unit := do
#eval tst1
set_option trace.compiler.ir.init true
set_option trace.Compiler.saveMono true
def sefFn (e : Expr) (f : Expr) : Expr :=
match e with
| .app _ a => e.updateApp! f a

View file

@ -1,26 +1,23 @@
[Compiler.IR] [init]
def sefFn (x_1 : obj) (x_2 : obj) : obj :=
case x_1 : obj of
Lean.Expr.app._impl →
let x_3 : u64 := sproj[2, 0] x_1;
let x_4 : obj := proj[0] x_1;
let x_5 : obj := proj[1] x_1;
block_6 (x_7 : u8) :=
case x_7 : u8 of
Bool.false →
let x_8 : obj := Lean.Expr.app._override x_2 x_5;
ret x_8
Bool.true →
ret x_1;
let x_9 : usize := ptrAddrUnsafe ◾ x_4;
let x_10 : usize := ptrAddrUnsafe ◾ x_2;
let x_11 : u8 := USize.decEq x_9 x_10;
case x_11 : u8 of
Bool.false →
jmp block_6 x_11
Bool.true →
let x_12 : usize := ptrAddrUnsafe ◾ x_5;
let x_13 : u8 := USize.decEq x_12 x_12;
jmp block_6 x_13
default →
ret x_1
[Compiler.saveMono] size: 17
def sefFn e f : Expr :=
cases e : Expr
| Lean.Expr.app._impl data fn.1 arg.2 =>
jp _jp.3 _y.4 : Expr :=
cases _y.4 : Expr
| Bool.false =>
let _x.5 := Expr.app._override f arg.2;
return _x.5
| Bool.true =>
return e;
let _x.6 := ptrAddrUnsafe ◾ fn.1;
let _x.7 := ptrAddrUnsafe ◾ f;
let _x.8 := USize.decEq _x.6 _x.7;
cases _x.8 : Expr
| Bool.false =>
goto _jp.3 _x.8
| Bool.true =>
let _x.9 := ptrAddrUnsafe ◾ arg.2;
let _x.10 := USize.decEq _x.9 _x.9;
goto _jp.3 _x.10
| _ =>
return e