refactor: port push_proj to LCNF (#12294)
This PR ports the `push_proj` pass from IR to LCNF. Notably it cannot delete it from IR yet as the pass is still used later on.
This commit is contained in:
parent
00c1f0d3a9
commit
7ba76bd33e
10 changed files with 315 additions and 3 deletions
|
|
@ -47,8 +47,6 @@ def compile (decls : Array Decl) : CompilerM (Array Decl) := do
|
|||
logDecls `init decls
|
||||
checkDecls decls
|
||||
let mut decls := decls
|
||||
decls := decls.map Decl.pushProj
|
||||
logDecls `push_proj decls
|
||||
if compiler.reuse.get (← getOptions) then
|
||||
decls := decls.map (Decl.insertResetReuse (← getEnv))
|
||||
logDecls `reset_reuse decls
|
||||
|
|
|
|||
|
|
@ -421,7 +421,7 @@ inductive Code (pu : Purity) where
|
|||
| return (fvarId : FVarId)
|
||||
| unreach (type : Expr)
|
||||
| uset (var : FVarId) (i : Nat) (y : FVarId) (k : Code pu) (h : pu = .impure := by purity_tac)
|
||||
| sset (var : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) (k : Code pu) (h : pu = .impure := by purity_tac)
|
||||
| sset (var : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) (k : Code pu) (h : pu = .impure := by purity_tac)
|
||||
deriving Inhabited
|
||||
|
||||
end
|
||||
|
|
@ -497,10 +497,13 @@ inductive CodeDecl (pu : Purity) where
|
|||
| let (decl : LetDecl pu)
|
||||
| fun (decl : FunDecl pu) (h : pu = .pure := by purity_tac)
|
||||
| jp (decl : FunDecl pu)
|
||||
| uset (var : FVarId) (i : Nat) (y : FVarId) (h : pu = .impure := by purity_tac)
|
||||
| sset (var : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) (h : pu = .impure := by purity_tac)
|
||||
deriving Inhabited
|
||||
|
||||
def CodeDecl.fvarId : CodeDecl pu → FVarId
|
||||
| .let decl | .fun decl _ | .jp decl => decl.fvarId
|
||||
| .uset var .. | .sset var .. => var
|
||||
|
||||
def attachCodeDecls (decls : Array (CodeDecl pu)) (code : Code pu) : Code pu :=
|
||||
go decls.size code
|
||||
|
|
@ -511,6 +514,8 @@ where
|
|||
| .let decl => go (i-1) (.let decl code)
|
||||
| .fun decl _ => go (i-1) (.fun decl code)
|
||||
| .jp decl => go (i-1) (.jp decl code)
|
||||
| .uset var idx y _ => go (i-1) (.uset var idx y code)
|
||||
| .sset var idx offset y ty _ => go (i-1) (.sset var idx offset y ty code)
|
||||
else
|
||||
code
|
||||
|
||||
|
|
@ -1072,6 +1077,13 @@ end
|
|||
@[inline] def collectUsedAtExpr (s : FVarIdHashSet) (e : Expr) : FVarIdHashSet :=
|
||||
collectType e s
|
||||
|
||||
def CodeDecl.collectUsed (codeDecl : CodeDecl pu) (s : FVarIdHashSet := ∅) : FVarIdHashSet :=
|
||||
match codeDecl with
|
||||
| .let decl => collectLetValue decl.value <| collectType decl.type s
|
||||
| .jp decl | .fun decl _ => decl.collectUsed s
|
||||
| .sset var _ _ y ty _ => s.insert var |>.insert y |> collectType ty
|
||||
| .uset var _ y _ => s.insert var |>.insert y
|
||||
|
||||
/--
|
||||
Traverse the given block of potentially mutually recursive functions
|
||||
and mark a declaration `f` as recursive if there is an application
|
||||
|
|
|
|||
|
|
@ -149,6 +149,7 @@ def eraseCodeDecl (decl : CodeDecl pu) : CompilerM Unit := do
|
|||
match decl with
|
||||
| .let decl => eraseLetDecl decl
|
||||
| .jp decl | .fun decl _ => eraseFunDecl decl
|
||||
| .sset .. | .uset .. => return ()
|
||||
|
||||
/--
|
||||
Erase all free variables occurring in `decls` from the local context.
|
||||
|
|
|
|||
|
|
@ -57,6 +57,8 @@ def CodeDecl.dependsOn (decl : CodeDecl pu) (s : FVarIdSet) : Bool :=
|
|||
match decl with
|
||||
| .let decl => decl.dependsOn s
|
||||
| .jp decl | .fun decl _ => decl.dependsOn s
|
||||
| .uset var _ y _ => s.contains var || s.contains y
|
||||
| .sset var _ _ y ty _ => s.contains var || s.contains y || (typeDepOn ty s)
|
||||
|
||||
/--
|
||||
Return `true` is `c` depends on a free variable in `s`.
|
||||
|
|
|
|||
|
|
@ -190,11 +190,15 @@ instance : TraverseFVar (CodeDecl pu) where
|
|||
| .fun decl _ => return .fun (← mapFVarM f decl)
|
||||
| .jp decl => return .jp (← mapFVarM f decl)
|
||||
| .let decl => return .let (← mapFVarM f decl)
|
||||
| .uset var i y _ => return .uset (← f var) i (← f y)
|
||||
| .sset var i offset y ty _ => return .sset (← f var) i offset (← f y) (← mapFVarM f ty)
|
||||
forFVarM f decl :=
|
||||
match decl with
|
||||
| .fun decl _ => forFVarM f decl
|
||||
| .jp decl => forFVarM f decl
|
||||
| .let decl => forFVarM f decl
|
||||
| .uset var i y _ => do f var; f y
|
||||
| .sset var i offset y ty _ => do f var; f y; forFVarM f ty
|
||||
|
||||
instance : TraverseFVar (Alt pu) where
|
||||
mapFVarM f alt := do
|
||||
|
|
|
|||
|
|
@ -158,6 +158,17 @@ partial def internalizeCodeDecl (decl : CodeDecl pu) : InternalizeM pu (CodeDecl
|
|||
| .let decl => return .let (← internalizeLetDecl decl)
|
||||
| .fun decl _ => return .fun (← internalizeFunDecl decl)
|
||||
| .jp decl => return .jp (← internalizeFunDecl decl)
|
||||
| .uset var i y _ =>
|
||||
-- Something weird should be happening if these become erased...
|
||||
let .fvar var ← normFVar var | unreachable!
|
||||
let .fvar y ← normFVar y | unreachable!
|
||||
return .uset var i y
|
||||
| .sset var i offset y ty _ =>
|
||||
let .fvar var ← normFVar var | unreachable!
|
||||
let .fvar y ← normFVar y | unreachable!
|
||||
let ty ← normExpr ty
|
||||
return .sset var i offset y ty
|
||||
|
||||
|
||||
end Internalize
|
||||
|
||||
|
|
|
|||
|
|
@ -77,6 +77,7 @@ mutual
|
|||
| .let decl k => eraseCode k <| lctx.eraseLetDecl decl
|
||||
| .jp decl k | .fun decl k _ => eraseCode k <| eraseFunDecl lctx decl
|
||||
| .cases c => eraseAlts c.alts lctx
|
||||
| .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => eraseCode k lctx
|
||||
| _ => lctx
|
||||
end
|
||||
|
||||
|
|
|
|||
|
|
@ -20,6 +20,7 @@ public import Lean.Compiler.LCNF.ExtractClosed
|
|||
public import Lean.Compiler.LCNF.Visibility
|
||||
public import Lean.Compiler.LCNF.Simp
|
||||
public import Lean.Compiler.LCNF.ToImpure
|
||||
public import Lean.Compiler.LCNF.PushProj
|
||||
|
||||
public section
|
||||
|
||||
|
|
@ -141,6 +142,7 @@ def builtinPassManager : PassManager := {
|
|||
]
|
||||
impurePasses := #[
|
||||
saveImpure, -- End of impure phase
|
||||
pushProj (occurrence := 0),
|
||||
inferVisibility (phase := .impure),
|
||||
]
|
||||
}
|
||||
|
|
|
|||
158
src/Lean/Compiler/LCNF/PushProj.lean
Normal file
158
src/Lean/Compiler/LCNF/PushProj.lean
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
/-
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Henrik Böving
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Lean.Compiler.LCNF.CompilerM
|
||||
public import Lean.Compiler.LCNF.PassManager
|
||||
import Lean.Compiler.LCNF.Internalize
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
|
||||
/-!
|
||||
This pass pushes projections into directly neighboring nested case statements.
|
||||
|
||||
Suppose we have an LCNF pure input that looks as follows:
|
||||
```
|
||||
cases a with
|
||||
| alt1 p1 p2 =>
|
||||
cases b with
|
||||
| alt2 p3 p4 =>
|
||||
...
|
||||
| alt3 p5 p6 =>
|
||||
...
|
||||
| ...
|
||||
```
|
||||
ToImpure will convert this into:
|
||||
```
|
||||
cases a with
|
||||
| alt1 p1 p2 =>
|
||||
let p1 := proj[0] a;
|
||||
let p2 := proj[1] a;
|
||||
cases b with
|
||||
| alt2 p3 p4 =>
|
||||
let p3 := proj[0] b;
|
||||
let p4 := proj[1] b;
|
||||
...
|
||||
| alt3 p5 p6 =>
|
||||
let p5 := proj[0] b;
|
||||
let p6 := proj[1] b;
|
||||
...
|
||||
| ...
|
||||
```
|
||||
Let's assume `p1` is used in both `alt2` and `alt3` and `p2` is used only in `alt3` then this pass
|
||||
will convert the code into:
|
||||
```
|
||||
cases a with
|
||||
| alt1 p1 p2 =>
|
||||
cases b with
|
||||
| alt2 p3 p4 =>
|
||||
let p1 := proj[0] a;
|
||||
let p3 := proj[0] b;
|
||||
let p4 := proj[1] b;
|
||||
...
|
||||
| alt3 p5 p6 =>
|
||||
let p1 := proj[0] a;
|
||||
let p2 := proj[1] a;
|
||||
let p5 := proj[0] b;
|
||||
let p6 := proj[1] b;
|
||||
...
|
||||
| ...
|
||||
```
|
||||
This helps to avoid loading memory that is not actually required in all branches.
|
||||
|
||||
Note that unlike `floatLetIn`, this pass is willing to duplicate projections that are being pushed
|
||||
around.
|
||||
|
||||
|
||||
TODO: we suspect it might also help with reuse analysis, check this. This pass was ported from IR to
|
||||
LCNF.
|
||||
-/
|
||||
|
||||
mutual
|
||||
|
||||
partial def Cases.pushProjs (c : Cases .impure) (decls : Array (CodeDecl .impure)) :
|
||||
CompilerM (Code .impure) := do
|
||||
let altsUsed := c.alts.map (·.getCode.collectUsed)
|
||||
let ctxUsed := ({} : FVarIdHashSet) |>.insert c.discr
|
||||
let (bs, alts) ← go decls c.alts altsUsed #[] ctxUsed
|
||||
let alts ← alts.mapM (·.mapCodeM Code.pushProj)
|
||||
let c := c.updateAlts alts
|
||||
return attachCodeDecls bs (.cases c)
|
||||
where
|
||||
/-
|
||||
Here:
|
||||
- `decls` are the declarations that are still under consideration for being pushed into `alts`
|
||||
- `alts` are the alternatives of the current case arms,
|
||||
- `altsUsed` contains the used fvars per arm, both these sets and `alts` will be updated as we push
|
||||
things into them
|
||||
- `ctx` is the set of declarations that we decided not to push into any alt already
|
||||
- `ctxUsed` fulfills the same purpose as `altsUsed` for `ctx`
|
||||
-/
|
||||
go (decls : Array (CodeDecl .impure)) (alts : Array (Alt .impure)) (altsUsed : Array FVarIdHashSet)
|
||||
(ctx : Array (CodeDecl .impure)) (ctxUsed : FVarIdHashSet) :
|
||||
CompilerM (Array (CodeDecl .impure) × Array (Alt .impure)) :=
|
||||
if decls.isEmpty then
|
||||
return (ctx.reverse, alts)
|
||||
else
|
||||
let b := decls.back!
|
||||
let bs := decls.pop
|
||||
let done := return (bs.push b ++ ctx.reverse, alts)
|
||||
let skip := go bs alts altsUsed (ctx.push b) (b.collectUsed ctxUsed)
|
||||
let push (fvar : FVarId) : CompilerM (Array (CodeDecl .impure) × Array (Alt .impure)) := do
|
||||
if !ctxUsed.contains fvar then
|
||||
let alts ← alts.mapIdxM fun i alt => alt.mapCodeM fun k => do
|
||||
if altsUsed[i]!.contains fvar then
|
||||
return attachCodeDecls #[b] k
|
||||
else
|
||||
return k
|
||||
let altsUsed := altsUsed.map fun used =>
|
||||
if used.contains fvar then
|
||||
b.collectUsed used
|
||||
else
|
||||
used
|
||||
go bs alts altsUsed ctx ctxUsed
|
||||
else
|
||||
skip
|
||||
match b with
|
||||
| .let decl =>
|
||||
match decl.value with
|
||||
| .uproj .. | .oproj .. | .sproj .. => push decl.fvarId
|
||||
-- TODO | .isShared .. => skip
|
||||
| _ => done
|
||||
| _ => done
|
||||
|
||||
partial def Code.pushProj (code : Code .impure) : CompilerM (Code .impure) := do
|
||||
go code #[]
|
||||
where
|
||||
go (c : Code .impure) (decls : Array (CodeDecl .impure)) : CompilerM (Code .impure) := do
|
||||
match c with
|
||||
| .let decl k => go k (decls.push (.let decl))
|
||||
| .jp decl k =>
|
||||
let decl ← decl.updateValue (← decl.value.pushProj)
|
||||
go k (decls.push (.jp decl))
|
||||
| .uset var i y k _ =>
|
||||
go k (decls.push (.uset var i y))
|
||||
| .sset var i offset y ty k _ =>
|
||||
go k (decls.push (.sset var i offset y ty))
|
||||
| .cases c => c.pushProjs decls
|
||||
| .jmp .. | .return .. | .unreach .. =>
|
||||
return attachCodeDecls decls c
|
||||
|
||||
end
|
||||
|
||||
def Decl.pushProj (decl : Decl .impure) : CompilerM (Decl .impure) := do
|
||||
let value ← decl.value.mapCodeM (·.pushProj)
|
||||
let decl := { decl with value }
|
||||
decl.internalize
|
||||
|
||||
public def pushProj (occurrence : Nat) : Pass :=
|
||||
Pass.mkPerDeclaration `pushProj .impure Decl.pushProj occurrence
|
||||
|
||||
builtin_initialize
|
||||
registerTraceClass `Compiler.pushProj (inherited := true)
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
123
tests/lean/run/compiler_push_proj.lean
Normal file
123
tests/lean/run/compiler_push_proj.lean
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
/-! This does some basic unit tests for the pushProj pass in LCNF -/
|
||||
|
||||
|
||||
/--
|
||||
trace: [Compiler.pushProj] size: 5
|
||||
def test1 a : tobj :=
|
||||
cases a : tobj
|
||||
| Option.none =>
|
||||
let _x.1 : tagged := 0;
|
||||
return _x.1
|
||||
| Option.some =>
|
||||
let val.2 : tobj := proj[0] a;
|
||||
return val.2
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option pp.letVarTypes true in
|
||||
set_option trace.Compiler.pushProj true in
|
||||
def test1 (a : Option Nat) : Nat :=
|
||||
match a with
|
||||
| some a => a
|
||||
| none => 0
|
||||
|
||||
|
||||
/--
|
||||
trace: [Compiler.pushProj] size: 10
|
||||
def test2 a b : tobj :=
|
||||
cases a : tobj
|
||||
| Option.none =>
|
||||
return a
|
||||
| Option.some =>
|
||||
cases b : tobj
|
||||
| Option.none =>
|
||||
return a
|
||||
| Option.some =>
|
||||
let val.1 : tobj := proj[0] a;
|
||||
let val.2 : tobj := proj[0] b;
|
||||
let _x.3 : tobj := Nat.add val.1 val.2;
|
||||
let _x.4 : obj := ctor_1[Option.some] _x.3;
|
||||
return _x.4
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option pp.letVarTypes true in
|
||||
set_option trace.Compiler.pushProj true in
|
||||
def test2 (a b : Option Nat) : Option Nat :=
|
||||
match a with
|
||||
| some a =>
|
||||
match b with
|
||||
| some b => some (a + b)
|
||||
| none => some a
|
||||
| none => none
|
||||
|
||||
/--
|
||||
trace: [Compiler.pushProj] size: 14
|
||||
def test3 a b : tobj :=
|
||||
cases a : tobj
|
||||
| Option.none =>
|
||||
return a
|
||||
| Option.some =>
|
||||
cases b : tobj
|
||||
| Option.none =>
|
||||
let val.1 : tobj := proj[0] a;
|
||||
let _x.2 : tagged := 1;
|
||||
let _x.3 : tobj := Nat.add val.1 _x.2;
|
||||
let _x.4 : obj := ctor_1[Option.some] _x.3;
|
||||
return _x.4
|
||||
| Option.some =>
|
||||
let val.5 : tobj := proj[0] a;
|
||||
let val.6 : tobj := proj[0] b;
|
||||
let _x.7 : tobj := Nat.add val.5 val.6;
|
||||
let _x.8 : obj := ctor_1[Option.some] _x.7;
|
||||
return _x.8
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option pp.letVarTypes true in
|
||||
set_option trace.Compiler.pushProj true in
|
||||
def test3 (a b : Option Nat) : Option Nat :=
|
||||
match a with
|
||||
| some a =>
|
||||
match b with
|
||||
| some b => some (a + b)
|
||||
| none => some (a + 1)
|
||||
| none => none
|
||||
|
||||
/--
|
||||
trace: [Compiler.pushProj] size: 18
|
||||
def test4 a b c : tobj :=
|
||||
cases a : tobj
|
||||
| Option.none =>
|
||||
return a
|
||||
| Option.some =>
|
||||
cases b : tobj
|
||||
| Option.none =>
|
||||
let val.1 : tobj := proj[0] a;
|
||||
let _x.2 : tagged := 1;
|
||||
let _x.3 : tobj := Nat.add val.1 _x.2;
|
||||
let _x.4 : obj := ctor_1[Option.some] _x.3;
|
||||
return _x.4
|
||||
| Option.some =>
|
||||
cases c : tobj
|
||||
| Bool.false =>
|
||||
let _x.5 : tagged := ctor_0[Option.none] ;
|
||||
return _x.5
|
||||
| Bool.true =>
|
||||
let val.6 : tobj := proj[0] a;
|
||||
let val.7 : tobj := proj[0] b;
|
||||
let _x.8 : tobj := Nat.add val.6 val.7;
|
||||
let _x.9 : obj := ctor_1[Option.some] _x.8;
|
||||
return _x.9
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option pp.letVarTypes true in
|
||||
set_option trace.Compiler.pushProj true in
|
||||
def test4 (a b : Option Nat) (c : Bool) : Option Nat :=
|
||||
match a with
|
||||
| some a =>
|
||||
match b with
|
||||
| some b =>
|
||||
match c with
|
||||
| true => some (a + b)
|
||||
| false => none
|
||||
| none => some (a + 1)
|
||||
| none => none
|
||||
|
||||
Loading…
Add table
Reference in a new issue