diff --git a/src/Lean/Compiler/IR.lean b/src/Lean/Compiler/IR.lean index c65a6e6643..f70aa6b158 100644 --- a/src/Lean/Compiler/IR.lean +++ b/src/Lean/Compiler/IR.lean @@ -47,8 +47,6 @@ def compile (decls : Array Decl) : CompilerM (Array Decl) := do logDecls `init decls checkDecls decls let mut decls := decls - decls := decls.map Decl.pushProj - logDecls `push_proj decls if compiler.reuse.get (← getOptions) then decls := decls.map (Decl.insertResetReuse (← getEnv)) logDecls `reset_reuse decls diff --git a/src/Lean/Compiler/LCNF/Basic.lean b/src/Lean/Compiler/LCNF/Basic.lean index c7d242096c..e7f509686f 100644 --- a/src/Lean/Compiler/LCNF/Basic.lean +++ b/src/Lean/Compiler/LCNF/Basic.lean @@ -421,7 +421,7 @@ inductive Code (pu : Purity) where | return (fvarId : FVarId) | unreach (type : Expr) | uset (var : FVarId) (i : Nat) (y : FVarId) (k : Code pu) (h : pu = .impure := by purity_tac) - | sset (var : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) (k : Code pu) (h : pu = .impure := by purity_tac) + | sset (var : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) (k : Code pu) (h : pu = .impure := by purity_tac) deriving Inhabited end @@ -497,10 +497,13 @@ inductive CodeDecl (pu : Purity) where | let (decl : LetDecl pu) | fun (decl : FunDecl pu) (h : pu = .pure := by purity_tac) | jp (decl : FunDecl pu) + | uset (var : FVarId) (i : Nat) (y : FVarId) (h : pu = .impure := by purity_tac) + | sset (var : FVarId) (i : Nat) (offset : Nat) (y : FVarId) (ty : Expr) (h : pu = .impure := by purity_tac) deriving Inhabited def CodeDecl.fvarId : CodeDecl pu → FVarId | .let decl | .fun decl _ | .jp decl => decl.fvarId + | .uset var .. | .sset var .. => var def attachCodeDecls (decls : Array (CodeDecl pu)) (code : Code pu) : Code pu := go decls.size code @@ -511,6 +514,8 @@ where | .let decl => go (i-1) (.let decl code) | .fun decl _ => go (i-1) (.fun decl code) | .jp decl => go (i-1) (.jp decl code) + | .uset var idx y _ => go (i-1) (.uset var idx y code) + | .sset var idx offset y ty _ => go (i-1) (.sset var idx offset y ty code) else code @@ -1072,6 +1077,13 @@ end @[inline] def collectUsedAtExpr (s : FVarIdHashSet) (e : Expr) : FVarIdHashSet := collectType e s +def CodeDecl.collectUsed (codeDecl : CodeDecl pu) (s : FVarIdHashSet := ∅) : FVarIdHashSet := + match codeDecl with + | .let decl => collectLetValue decl.value <| collectType decl.type s + | .jp decl | .fun decl _ => decl.collectUsed s + | .sset var _ _ y ty _ => s.insert var |>.insert y |> collectType ty + | .uset var _ y _ => s.insert var |>.insert y + /-- Traverse the given block of potentially mutually recursive functions and mark a declaration `f` as recursive if there is an application diff --git a/src/Lean/Compiler/LCNF/CompilerM.lean b/src/Lean/Compiler/LCNF/CompilerM.lean index c813afceaa..39a16c4a36 100644 --- a/src/Lean/Compiler/LCNF/CompilerM.lean +++ b/src/Lean/Compiler/LCNF/CompilerM.lean @@ -149,6 +149,7 @@ def eraseCodeDecl (decl : CodeDecl pu) : CompilerM Unit := do match decl with | .let decl => eraseLetDecl decl | .jp decl | .fun decl _ => eraseFunDecl decl + | .sset .. | .uset .. => return () /-- Erase all free variables occurring in `decls` from the local context. diff --git a/src/Lean/Compiler/LCNF/DependsOn.lean b/src/Lean/Compiler/LCNF/DependsOn.lean index d80231a5e4..3c95c93e1d 100644 --- a/src/Lean/Compiler/LCNF/DependsOn.lean +++ b/src/Lean/Compiler/LCNF/DependsOn.lean @@ -57,6 +57,8 @@ def CodeDecl.dependsOn (decl : CodeDecl pu) (s : FVarIdSet) : Bool := match decl with | .let decl => decl.dependsOn s | .jp decl | .fun decl _ => decl.dependsOn s + | .uset var _ y _ => s.contains var || s.contains y + | .sset var _ _ y ty _ => s.contains var || s.contains y || (typeDepOn ty s) /-- Return `true` is `c` depends on a free variable in `s`. diff --git a/src/Lean/Compiler/LCNF/FVarUtil.lean b/src/Lean/Compiler/LCNF/FVarUtil.lean index 2658df1c99..d5a031a0f6 100644 --- a/src/Lean/Compiler/LCNF/FVarUtil.lean +++ b/src/Lean/Compiler/LCNF/FVarUtil.lean @@ -190,11 +190,15 @@ instance : TraverseFVar (CodeDecl pu) where | .fun decl _ => return .fun (← mapFVarM f decl) | .jp decl => return .jp (← mapFVarM f decl) | .let decl => return .let (← mapFVarM f decl) + | .uset var i y _ => return .uset (← f var) i (← f y) + | .sset var i offset y ty _ => return .sset (← f var) i offset (← f y) (← mapFVarM f ty) forFVarM f decl := match decl with | .fun decl _ => forFVarM f decl | .jp decl => forFVarM f decl | .let decl => forFVarM f decl + | .uset var i y _ => do f var; f y + | .sset var i offset y ty _ => do f var; f y; forFVarM f ty instance : TraverseFVar (Alt pu) where mapFVarM f alt := do diff --git a/src/Lean/Compiler/LCNF/Internalize.lean b/src/Lean/Compiler/LCNF/Internalize.lean index 271b81943d..8ea9d06323 100644 --- a/src/Lean/Compiler/LCNF/Internalize.lean +++ b/src/Lean/Compiler/LCNF/Internalize.lean @@ -158,6 +158,17 @@ partial def internalizeCodeDecl (decl : CodeDecl pu) : InternalizeM pu (CodeDecl | .let decl => return .let (← internalizeLetDecl decl) | .fun decl _ => return .fun (← internalizeFunDecl decl) | .jp decl => return .jp (← internalizeFunDecl decl) + | .uset var i y _ => + -- Something weird should be happening if these become erased... + let .fvar var ← normFVar var | unreachable! + let .fvar y ← normFVar y | unreachable! + return .uset var i y + | .sset var i offset y ty _ => + let .fvar var ← normFVar var | unreachable! + let .fvar y ← normFVar y | unreachable! + let ty ← normExpr ty + return .sset var i offset y ty + end Internalize diff --git a/src/Lean/Compiler/LCNF/LCtx.lean b/src/Lean/Compiler/LCNF/LCtx.lean index dd05fd6489..86e6df70c2 100644 --- a/src/Lean/Compiler/LCNF/LCtx.lean +++ b/src/Lean/Compiler/LCNF/LCtx.lean @@ -77,6 +77,7 @@ mutual | .let decl k => eraseCode k <| lctx.eraseLetDecl decl | .jp decl k | .fun decl k _ => eraseCode k <| eraseFunDecl lctx decl | .cases c => eraseAlts c.alts lctx + | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => eraseCode k lctx | _ => lctx end diff --git a/src/Lean/Compiler/LCNF/Passes.lean b/src/Lean/Compiler/LCNF/Passes.lean index 7fd78cfc6d..441346d428 100644 --- a/src/Lean/Compiler/LCNF/Passes.lean +++ b/src/Lean/Compiler/LCNF/Passes.lean @@ -20,6 +20,7 @@ public import Lean.Compiler.LCNF.ExtractClosed public import Lean.Compiler.LCNF.Visibility public import Lean.Compiler.LCNF.Simp public import Lean.Compiler.LCNF.ToImpure +public import Lean.Compiler.LCNF.PushProj public section @@ -141,6 +142,7 @@ def builtinPassManager : PassManager := { ] impurePasses := #[ saveImpure, -- End of impure phase + pushProj (occurrence := 0), inferVisibility (phase := .impure), ] } diff --git a/src/Lean/Compiler/LCNF/PushProj.lean b/src/Lean/Compiler/LCNF/PushProj.lean new file mode 100644 index 0000000000..067300e51a --- /dev/null +++ b/src/Lean/Compiler/LCNF/PushProj.lean @@ -0,0 +1,158 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Henrik Böving +-/ +module + +prelude +public import Lean.Compiler.LCNF.CompilerM +public import Lean.Compiler.LCNF.PassManager +import Lean.Compiler.LCNF.Internalize + +namespace Lean.Compiler.LCNF + +/-! +This pass pushes projections into directly neighboring nested case statements. + +Suppose we have an LCNF pure input that looks as follows: +``` +cases a with +| alt1 p1 p2 => + cases b with + | alt2 p3 p4 => + ... + | alt3 p5 p6 => + ... +| ... +``` +ToImpure will convert this into: +``` +cases a with +| alt1 p1 p2 => + let p1 := proj[0] a; + let p2 := proj[1] a; + cases b with + | alt2 p3 p4 => + let p3 := proj[0] b; + let p4 := proj[1] b; + ... + | alt3 p5 p6 => + let p5 := proj[0] b; + let p6 := proj[1] b; + ... +| ... +``` +Let's assume `p1` is used in both `alt2` and `alt3` and `p2` is used only in `alt3` then this pass +will convert the code into: +``` +cases a with +| alt1 p1 p2 => + cases b with + | alt2 p3 p4 => + let p1 := proj[0] a; + let p3 := proj[0] b; + let p4 := proj[1] b; + ... + | alt3 p5 p6 => + let p1 := proj[0] a; + let p2 := proj[1] a; + let p5 := proj[0] b; + let p6 := proj[1] b; + ... +| ... +``` +This helps to avoid loading memory that is not actually required in all branches. + +Note that unlike `floatLetIn`, this pass is willing to duplicate projections that are being pushed +around. + + +TODO: we suspect it might also help with reuse analysis, check this. This pass was ported from IR to +LCNF. +-/ + +mutual + +partial def Cases.pushProjs (c : Cases .impure) (decls : Array (CodeDecl .impure)) : + CompilerM (Code .impure) := do + let altsUsed := c.alts.map (·.getCode.collectUsed) + let ctxUsed := ({} : FVarIdHashSet) |>.insert c.discr + let (bs, alts) ← go decls c.alts altsUsed #[] ctxUsed + let alts ← alts.mapM (·.mapCodeM Code.pushProj) + let c := c.updateAlts alts + return attachCodeDecls bs (.cases c) +where + /- + Here: + - `decls` are the declarations that are still under consideration for being pushed into `alts` + - `alts` are the alternatives of the current case arms, + - `altsUsed` contains the used fvars per arm, both these sets and `alts` will be updated as we push + things into them + - `ctx` is the set of declarations that we decided not to push into any alt already + - `ctxUsed` fulfills the same purpose as `altsUsed` for `ctx` + -/ + go (decls : Array (CodeDecl .impure)) (alts : Array (Alt .impure)) (altsUsed : Array FVarIdHashSet) + (ctx : Array (CodeDecl .impure)) (ctxUsed : FVarIdHashSet) : + CompilerM (Array (CodeDecl .impure) × Array (Alt .impure)) := + if decls.isEmpty then + return (ctx.reverse, alts) + else + let b := decls.back! + let bs := decls.pop + let done := return (bs.push b ++ ctx.reverse, alts) + let skip := go bs alts altsUsed (ctx.push b) (b.collectUsed ctxUsed) + let push (fvar : FVarId) : CompilerM (Array (CodeDecl .impure) × Array (Alt .impure)) := do + if !ctxUsed.contains fvar then + let alts ← alts.mapIdxM fun i alt => alt.mapCodeM fun k => do + if altsUsed[i]!.contains fvar then + return attachCodeDecls #[b] k + else + return k + let altsUsed := altsUsed.map fun used => + if used.contains fvar then + b.collectUsed used + else + used + go bs alts altsUsed ctx ctxUsed + else + skip + match b with + | .let decl => + match decl.value with + | .uproj .. | .oproj .. | .sproj .. => push decl.fvarId + -- TODO | .isShared .. => skip + | _ => done + | _ => done + +partial def Code.pushProj (code : Code .impure) : CompilerM (Code .impure) := do + go code #[] +where + go (c : Code .impure) (decls : Array (CodeDecl .impure)) : CompilerM (Code .impure) := do + match c with + | .let decl k => go k (decls.push (.let decl)) + | .jp decl k => + let decl ← decl.updateValue (← decl.value.pushProj) + go k (decls.push (.jp decl)) + | .uset var i y k _ => + go k (decls.push (.uset var i y)) + | .sset var i offset y ty k _ => + go k (decls.push (.sset var i offset y ty)) + | .cases c => c.pushProjs decls + | .jmp .. | .return .. | .unreach .. => + return attachCodeDecls decls c + +end + +def Decl.pushProj (decl : Decl .impure) : CompilerM (Decl .impure) := do + let value ← decl.value.mapCodeM (·.pushProj) + let decl := { decl with value } + decl.internalize + +public def pushProj (occurrence : Nat) : Pass := + Pass.mkPerDeclaration `pushProj .impure Decl.pushProj occurrence + +builtin_initialize + registerTraceClass `Compiler.pushProj (inherited := true) + +end Lean.Compiler.LCNF diff --git a/tests/lean/run/compiler_push_proj.lean b/tests/lean/run/compiler_push_proj.lean new file mode 100644 index 0000000000..489bedaf7b --- /dev/null +++ b/tests/lean/run/compiler_push_proj.lean @@ -0,0 +1,123 @@ +/-! This does some basic unit tests for the pushProj pass in LCNF -/ + + +/-- +trace: [Compiler.pushProj] size: 5 + def test1 a : tobj := + cases a : tobj + | Option.none => + let _x.1 : tagged := 0; + return _x.1 + | Option.some => + let val.2 : tobj := proj[0] a; + return val.2 +-/ +#guard_msgs in +set_option pp.letVarTypes true in +set_option trace.Compiler.pushProj true in +def test1 (a : Option Nat) : Nat := + match a with + | some a => a + | none => 0 + + +/-- +trace: [Compiler.pushProj] size: 10 + def test2 a b : tobj := + cases a : tobj + | Option.none => + return a + | Option.some => + cases b : tobj + | Option.none => + return a + | Option.some => + let val.1 : tobj := proj[0] a; + let val.2 : tobj := proj[0] b; + let _x.3 : tobj := Nat.add val.1 val.2; + let _x.4 : obj := ctor_1[Option.some] _x.3; + return _x.4 +-/ +#guard_msgs in +set_option pp.letVarTypes true in +set_option trace.Compiler.pushProj true in +def test2 (a b : Option Nat) : Option Nat := + match a with + | some a => + match b with + | some b => some (a + b) + | none => some a + | none => none + +/-- +trace: [Compiler.pushProj] size: 14 + def test3 a b : tobj := + cases a : tobj + | Option.none => + return a + | Option.some => + cases b : tobj + | Option.none => + let val.1 : tobj := proj[0] a; + let _x.2 : tagged := 1; + let _x.3 : tobj := Nat.add val.1 _x.2; + let _x.4 : obj := ctor_1[Option.some] _x.3; + return _x.4 + | Option.some => + let val.5 : tobj := proj[0] a; + let val.6 : tobj := proj[0] b; + let _x.7 : tobj := Nat.add val.5 val.6; + let _x.8 : obj := ctor_1[Option.some] _x.7; + return _x.8 +-/ +#guard_msgs in +set_option pp.letVarTypes true in +set_option trace.Compiler.pushProj true in +def test3 (a b : Option Nat) : Option Nat := + match a with + | some a => + match b with + | some b => some (a + b) + | none => some (a + 1) + | none => none + +/-- +trace: [Compiler.pushProj] size: 18 + def test4 a b c : tobj := + cases a : tobj + | Option.none => + return a + | Option.some => + cases b : tobj + | Option.none => + let val.1 : tobj := proj[0] a; + let _x.2 : tagged := 1; + let _x.3 : tobj := Nat.add val.1 _x.2; + let _x.4 : obj := ctor_1[Option.some] _x.3; + return _x.4 + | Option.some => + cases c : tobj + | Bool.false => + let _x.5 : tagged := ctor_0[Option.none] ; + return _x.5 + | Bool.true => + let val.6 : tobj := proj[0] a; + let val.7 : tobj := proj[0] b; + let _x.8 : tobj := Nat.add val.6 val.7; + let _x.9 : obj := ctor_1[Option.some] _x.8; + return _x.9 +-/ +#guard_msgs in +set_option pp.letVarTypes true in +set_option trace.Compiler.pushProj true in +def test4 (a b : Option Nat) (c : Bool) : Option Nat := + match a with + | some a => + match b with + | some b => + match c with + | true => some (a + b) + | false => none + | none => some (a + 1) + | none => none +