refactor: port push_proj to LCNF (#12294)

This PR ports the `push_proj` pass from IR to LCNF. Notably, the IR
implementation cannot be deleted yet, as the pass is still used later on.
This commit is contained in:
Henrik Böving 2026-02-03 20:21:45 +01:00 committed by GitHub
parent 00c1f0d3a9
commit 7ba76bd33e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
10 changed files with 315 additions and 3 deletions

View file

@ -47,8 +47,6 @@ def compile (decls : Array Decl) : CompilerM (Array Decl) := do
logDecls `init decls
checkDecls decls
let mut decls := decls
decls := decls.map Decl.pushProj
logDecls `push_proj decls
if compiler.reuse.get (← getOptions) then
decls := decls.map (Decl.insertResetReuse (← getEnv))
logDecls `reset_reuse decls

View file

@ -421,7 +421,7 @@ inductive Code (pu : Purity) where
| return (fvarId : FVarId)
| unreach (type : Expr)
| uset (var : FVarId) (i : Nat) (y : FVarId) (k : Code pu) (h : pu = .impure := by purity_tac)
| sset (var : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) (k : Code pu) (h : pu = .impure := by purity_tac)
| sset (var : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) (k : Code pu) (h : pu = .impure := by purity_tac)
deriving Inhabited
end
@ -497,10 +497,13 @@ inductive CodeDecl (pu : Purity) where
| let (decl : LetDecl pu)
| fun (decl : FunDecl pu) (h : pu = .pure := by purity_tac)
| jp (decl : FunDecl pu)
| uset (var : FVarId) (i : Nat) (y : FVarId) (h : pu = .impure := by purity_tac)
| sset (var : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) (h : pu = .impure := by purity_tac)
deriving Inhabited
/--
The free variable associated with a code declaration: the `fvarId` of the wrapped declaration
for `let`/`fun`/`jp`, and the `var` field for `uset`/`sset`.
-/
def CodeDecl.fvarId : CodeDecl pu → FVarId
  | .let letDecl => letDecl.fvarId
  | .fun funDecl _ => funDecl.fvarId
  | .jp funDecl => funDecl.fvarId
  | .uset var .. => var
  | .sset var .. => var
def attachCodeDecls (decls : Array (CodeDecl pu)) (code : Code pu) : Code pu :=
go decls.size code
@ -511,6 +514,8 @@ where
| .let decl => go (i-1) (.let decl code)
| .fun decl _ => go (i-1) (.fun decl code)
| .jp decl => go (i-1) (.jp decl code)
| .uset var idx y _ => go (i-1) (.uset var idx y code)
| .sset var idx offset y ty _ => go (i-1) (.sset var idx offset y ty code)
else
code
@ -1072,6 +1077,13 @@ end
-- Adds to `s` whatever `collectType` collects from the expression `e`; this is only an
-- argument-flipped wrapper around `collectType`.
@[inline] def collectUsedAtExpr (s : FVarIdHashSet) (e : Expr) : FVarIdHashSet :=
  s |> collectType e
/--
Extends `s` (empty by default) with the free variables used by `codeDecl`:
- `let`: those collected from the declaration's type and value,
- `jp`/`fun`: those collected by the function declaration's own `collectUsed`,
- `uset`/`sset`: both `var` and `y`, plus, for `sset`, those collected from `ty`.
-/
def CodeDecl.collectUsed (codeDecl : CodeDecl pu) (s : FVarIdHashSet := ∅) : FVarIdHashSet :=
  match codeDecl with
  | .let letDecl =>
    let withType := collectType letDecl.type s
    collectLetValue letDecl.value withType
  | .jp funDecl => funDecl.collectUsed s
  | .fun funDecl _ => funDecl.collectUsed s
  | .uset var _ y _ => (s.insert var).insert y
  | .sset var _ _ y ty _ => collectType ty ((s.insert var).insert y)
/--
Traverse the given block of potentially mutually recursive functions
and mark a declaration `f` as recursive if there is an application

View file

@ -149,6 +149,7 @@ def eraseCodeDecl (decl : CodeDecl pu) : CompilerM Unit := do
match decl with
| .let decl => eraseLetDecl decl
| .jp decl | .fun decl _ => eraseFunDecl decl
| .sset .. | .uset .. => return ()
/--
Erase all free variables occurring in `decls` from the local context.

View file

@ -57,6 +57,8 @@ def CodeDecl.dependsOn (decl : CodeDecl pu) (s : FVarIdSet) : Bool :=
match decl with
| .let decl => decl.dependsOn s
| .jp decl | .fun decl _ => decl.dependsOn s
| .uset var _ y _ => s.contains var || s.contains y
| .sset var _ _ y ty _ => s.contains var || s.contains y || (typeDepOn ty s)
/--
Return `true` if `c` depends on a free variable in `s`.

View file

@ -190,11 +190,15 @@ instance : TraverseFVar (CodeDecl pu) where
| .fun decl _ => return .fun (← mapFVarM f decl)
| .jp decl => return .jp (← mapFVarM f decl)
| .let decl => return .let (← mapFVarM f decl)
| .uset var i y _ => return .uset (← f var) i (← f y)
| .sset var i offset y ty _ => return .sset (← f var) i offset (← f y) (← mapFVarM f ty)
forFVarM f decl :=
match decl with
| .fun decl _ => forFVarM f decl
| .jp decl => forFVarM f decl
| .let decl => forFVarM f decl
| .uset var i y _ => do f var; f y
| .sset var i offset y ty _ => do f var; f y; forFVarM f ty
instance : TraverseFVar (Alt pu) where
mapFVarM f alt := do

View file

@ -158,6 +158,17 @@ partial def internalizeCodeDecl (decl : CodeDecl pu) : InternalizeM pu (CodeDecl
| .let decl => return .let (← internalizeLetDecl decl)
| .fun decl _ => return .fun (← internalizeFunDecl decl)
| .jp decl => return .jp (← internalizeFunDecl decl)
| .uset var i y _ =>
-- Something weird should be happening if these become erased...
let .fvar var ← normFVar var | unreachable!
let .fvar y ← normFVar y | unreachable!
return .uset var i y
| .sset var i offset y ty _ =>
let .fvar var ← normFVar var | unreachable!
let .fvar y ← normFVar y | unreachable!
let ty ← normExpr ty
return .sset var i offset y ty
end Internalize

View file

@ -77,6 +77,7 @@ mutual
| .let decl k => eraseCode k <| lctx.eraseLetDecl decl
| .jp decl k | .fun decl k _ => eraseCode k <| eraseFunDecl lctx decl
| .cases c => eraseAlts c.alts lctx
| .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => eraseCode k lctx
| _ => lctx
end

View file

@ -20,6 +20,7 @@ public import Lean.Compiler.LCNF.ExtractClosed
public import Lean.Compiler.LCNF.Visibility
public import Lean.Compiler.LCNF.Simp
public import Lean.Compiler.LCNF.ToImpure
public import Lean.Compiler.LCNF.PushProj
public section
@ -141,6 +142,7 @@ def builtinPassManager : PassManager := {
]
impurePasses := #[
saveImpure, -- End of impure phase
pushProj (occurrence := 0),
inferVisibility (phase := .impure),
]
}

View file

@ -0,0 +1,158 @@
/-
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Henrik Böving
-/
module
prelude
public import Lean.Compiler.LCNF.CompilerM
public import Lean.Compiler.LCNF.PassManager
import Lean.Compiler.LCNF.Internalize
namespace Lean.Compiler.LCNF
/-!
This pass pushes projections into directly neighboring nested case statements.
Suppose we have an LCNF pure input that looks as follows:
```
cases a with
| alt1 p1 p2 =>
cases b with
| alt2 p3 p4 =>
...
| alt3 p5 p6 =>
...
| ...
```
ToImpure will convert this into:
```
cases a with
| alt1 p1 p2 =>
let p1 := proj[0] a;
let p2 := proj[1] a;
cases b with
| alt2 p3 p4 =>
let p3 := proj[0] b;
let p4 := proj[1] b;
...
| alt3 p5 p6 =>
let p5 := proj[0] b;
let p6 := proj[1] b;
...
| ...
```
Let's assume `p1` is used in both `alt2` and `alt3`, and `p2` is used only in `alt3`. Then this pass
will convert the code into:
```
cases a with
| alt1 p1 p2 =>
cases b with
| alt2 p3 p4 =>
let p1 := proj[0] a;
let p3 := proj[0] b;
let p4 := proj[1] b;
...
| alt3 p5 p6 =>
let p1 := proj[0] a;
let p2 := proj[1] a;
let p5 := proj[0] b;
let p6 := proj[1] b;
...
| ...
```
This helps to avoid loading memory that is not actually required in all branches.
Note that unlike `floatLetIn`, this pass is willing to duplicate projections that are being pushed
around.
TODO: we suspect it might also help with reuse analysis, check this. This pass was ported from IR to
LCNF.
-/
mutual
/--
Given a `cases` statement `c` and the declarations `decls` that directly precede it, push the
trailing projections of `decls` into those alternatives of `c` that use them, then run
`Code.pushProj` on every resulting alternative. Declarations that were not pushed are re-attached
in front of the `cases`.
-/
partial def Cases.pushProjs (c : Cases .impure) (decls : Array (CodeDecl .impure)) :
    CompilerM (Code .impure) := do
  -- Free variables used by each alternative, index-aligned with `c.alts`.
  let altsUsed := c.alts.map (·.getCode.collectUsed)
  -- The discriminant is used by the `cases` itself, so a projection defining it is never pushed.
  let ctxUsed := ({} : FVarIdHashSet) |>.insert c.discr
  let (bs, alts) ← go decls c.alts altsUsed #[] ctxUsed
  -- Recurse into the alternatives only after pushing, so that projections which were just moved
  -- into an alternative are considered again for `cases` statements nested inside it.
  let alts ← alts.mapM (·.mapCodeM Code.pushProj)
  let c := c.updateAlts alts
  return attachCodeDecls bs (.cases c)
where
  /-
  Worklist over `decls`, processed back to front (starting with the declaration closest to the
  `cases`). Here:
  - `decls` are the declarations that are still under consideration for being pushed into `alts`,
  - `alts` are the alternatives of the current `cases`,
  - `altsUsed` contains the used fvars per alternative; both these sets and `alts` are updated as
    we push things into them,
  - `ctx` holds the declarations that we already decided not to push into any alternative
    (accumulated in reverse order, hence the `ctx.reverse` when returning),
  - `ctxUsed` fulfills the same purpose as `altsUsed` for `ctx`.

  Returns the declarations that stay in front of the `cases` together with the updated alternatives.
  -/
  go (decls : Array (CodeDecl .impure)) (alts : Array (Alt .impure)) (altsUsed : Array FVarIdHashSet)
      (ctx : Array (CodeDecl .impure)) (ctxUsed : FVarIdHashSet) :
      CompilerM (Array (CodeDecl .impure) × Array (Alt .impure)) :=
    if decls.isEmpty then
      return (ctx.reverse, alts)
    else
      let b := decls.back!
      let bs := decls.pop
      -- `done`: stop here; `b` and every declaration before it stay in front of the `cases`.
      let done := return (bs.push b ++ ctx.reverse, alts)
      -- `skip`: keep `b` in front of the `cases` (recording what it uses) and continue with the
      -- declarations before it.
      let skip := go bs alts altsUsed (ctx.push b) (b.collectUsed ctxUsed)
      -- `push`: move `b`, which defines `fvar`, into the alternatives, unless something that
      -- stays in front of the `cases` uses `fvar`.
      let push (fvar : FVarId) : CompilerM (Array (CodeDecl .impure) × Array (Alt .impure)) := do
        if !ctxUsed.contains fvar then
          -- Only alternatives that use `fvar` receive a copy of `b`; the others are unchanged.
          let alts ← alts.mapIdxM fun i alt => alt.mapCodeM fun k => do
            if altsUsed[i]!.contains fvar then
              return attachCodeDecls #[b] k
            else
              return k
          -- An alternative that received `b` now also uses everything `b` uses.
          let altsUsed := altsUsed.map fun used =>
            if used.contains fvar then
              b.collectUsed used
            else
              used
          go bs alts altsUsed ctx ctxUsed
        else
          skip
      match b with
      | .let decl =>
        match decl.value with
        | .uproj .. | .oproj .. | .sproj .. => push decl.fvarId
        -- TODO | .isShared .. => skip
        -- Any other `let` value ends the search.
        | _ => done
      -- So does any non-`let` declaration.
      | _ => done
/--
Walks `code`, buffering the declarations it passes (`let`, `jp`, `uset`, `sset`). On reaching a
`cases`, the buffer is handed to `Cases.pushProjs`; on reaching `jmp`, `return` or `unreach`, the
buffered declarations are re-attached unchanged. Join point bodies are processed recursively
before the join point is buffered.
-/
partial def Code.pushProj (code : Code .impure) : CompilerM (Code .impure) :=
  go code #[]
where
  go (c : Code .impure) (decls : Array (CodeDecl .impure)) : CompilerM (Code .impure) := do
    match c with
    | .cases cs => cs.pushProjs decls
    | .jmp .. => return attachCodeDecls decls c
    | .return .. => return attachCodeDecls decls c
    | .unreach .. => return attachCodeDecls decls c
    | .let letDecl rest => go rest (decls.push (.let letDecl))
    | .jp jpDecl rest =>
      let body ← jpDecl.value.pushProj
      let jpDecl ← jpDecl.updateValue body
      go rest (decls.push (.jp jpDecl))
    | .uset var i y rest _ => go rest (decls.push (.uset var i y))
    | .sset var i offset y ty rest _ => go rest (decls.push (.sset var i offset y ty))
end
/--
Runs `Code.pushProj` on the code of `decl` and internalizes the result.
NOTE(review): the internalization is presumably needed because pushing may copy the same
projection into several alternatives — confirm against `Decl.internalize`.
-/
def Decl.pushProj (decl : Decl .impure) : CompilerM (Decl .impure) := do
  let newValue ← decl.value.mapCodeM fun code => code.pushProj
  let updated := { decl with value := newValue }
  updated.internalize
/--
The `pushProj` pass for the impure phase, built from `Decl.pushProj` via `Pass.mkPerDeclaration`.
`occurrence` is forwarded to `Pass.mkPerDeclaration` unchanged.
-/
public def pushProj (occurrence : Nat) : Pass :=
  .mkPerDeclaration `pushProj .impure Decl.pushProj occurrence
-- Registers the trace class that is enabled with `set_option trace.Compiler.pushProj true`.
builtin_initialize
  registerTraceClass `Compiler.pushProj (inherited := true)
end Lean.Compiler.LCNF

View file

@ -0,0 +1,123 @@
/-! This does some basic unit tests for the pushProj pass in LCNF -/
/--
trace: [Compiler.pushProj] size: 5
def test1 a : tobj :=
cases a : tobj
| Option.none =>
let _x.1 : tagged := 0;
return _x.1
| Option.some =>
let val.2 : tobj := proj[0] a;
return val.2
-/
#guard_msgs in
set_option pp.letVarTypes true in
set_option trace.Compiler.pushProj true in
-- Single `match`: in the expected trace the projection of `a` appears only in the
-- `Option.some` arm.
def test1 (a : Option Nat) : Nat :=
  match a with
  | some a => a
  | none => 0
/--
trace: [Compiler.pushProj] size: 10
def test2 a b : tobj :=
cases a : tobj
| Option.none =>
return a
| Option.some =>
cases b : tobj
| Option.none =>
return a
| Option.some =>
let val.1 : tobj := proj[0] a;
let val.2 : tobj := proj[0] b;
let _x.3 : tobj := Nat.add val.1 val.2;
let _x.4 : obj := ctor_1[Option.some] _x.3;
return _x.4
-/
#guard_msgs in
set_option pp.letVarTypes true in
set_option trace.Compiler.pushProj true in
-- Nested `match` where only one inner arm needs the payload of `a`: in the expected trace
-- `proj[0] a` appears only inside the inner `Option.some` arm.
def test2 (a b : Option Nat) : Option Nat :=
  match a with
  | some a =>
    match b with
    | some b => some (a + b)
    | none => some a
  | none => none
/--
trace: [Compiler.pushProj] size: 14
def test3 a b : tobj :=
cases a : tobj
| Option.none =>
return a
| Option.some =>
cases b : tobj
| Option.none =>
let val.1 : tobj := proj[0] a;
let _x.2 : tagged := 1;
let _x.3 : tobj := Nat.add val.1 _x.2;
let _x.4 : obj := ctor_1[Option.some] _x.3;
return _x.4
| Option.some =>
let val.5 : tobj := proj[0] a;
let val.6 : tobj := proj[0] b;
let _x.7 : tobj := Nat.add val.5 val.6;
let _x.8 : obj := ctor_1[Option.some] _x.7;
return _x.8
-/
#guard_msgs in
set_option pp.letVarTypes true in
set_option trace.Compiler.pushProj true in
-- Nested `match` where both inner arms need the payload of `a`: in the expected trace
-- `proj[0] a` is duplicated into both inner arms.
def test3 (a b : Option Nat) : Option Nat :=
  match a with
  | some a =>
    match b with
    | some b => some (a + b)
    | none => some (a + 1)
  | none => none
/--
trace: [Compiler.pushProj] size: 18
def test4 a b c : tobj :=
cases a : tobj
| Option.none =>
return a
| Option.some =>
cases b : tobj
| Option.none =>
let val.1 : tobj := proj[0] a;
let _x.2 : tagged := 1;
let _x.3 : tobj := Nat.add val.1 _x.2;
let _x.4 : obj := ctor_1[Option.some] _x.3;
return _x.4
| Option.some =>
cases c : tobj
| Bool.false =>
let _x.5 : tagged := ctor_0[Option.none] ;
return _x.5
| Bool.true =>
let val.6 : tobj := proj[0] a;
let val.7 : tobj := proj[0] b;
let _x.8 : tobj := Nat.add val.6 val.7;
let _x.9 : obj := ctor_1[Option.some] _x.8;
return _x.9
-/
#guard_msgs in
set_option pp.letVarTypes true in
set_option trace.Compiler.pushProj true in
-- Three levels of nesting: in the expected trace the projections of `a` and `b` end up in the
-- innermost `Bool.true` arm, and the `Bool.false` arm contains no projection at all.
def test4 (a b : Option Nat) (c : Bool) : Option Nat :=
  match a with
  | some a =>
    match b with
    | some b =>
      match c with
      | true => some (a + b)
      | false => none
    | none => some (a + 1)
  | none => none