From d2ecad2e91b5a5202a5c92441384f2bbb857ed19 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20B=C3=B6ving?= Date: Fri, 20 Mar 2026 20:03:17 +0100 Subject: [PATCH] perf: forward propagation of user defined borrow annotations (#13001) This PR introduces additional propagation mechanisms for user defined borrows to make them have priority over reset-reuse opportunities. --- src/Lean/Compiler/LCNF/InferBorrow.lean | 13 +- src/Lean/Compiler/LCNF/PropagateBorrow.lean | 147 ++++++++++++++++++++ src/Lean/Compiler/LCNF/ResetReuse.lean | 12 +- tests/elab/compile_borrowed_reset_jp.lean | 83 +++++++++++ tests/elab/lcnf_borrow_expected_type.lean | 4 + 5 files changed, 247 insertions(+), 12 deletions(-) create mode 100644 src/Lean/Compiler/LCNF/PropagateBorrow.lean create mode 100644 tests/elab/compile_borrowed_reset_jp.lean diff --git a/src/Lean/Compiler/LCNF/InferBorrow.lean b/src/Lean/Compiler/LCNF/InferBorrow.lean index 85a68efef3..ca6aa9066b 100644 --- a/src/Lean/Compiler/LCNF/InferBorrow.lean +++ b/src/Lean/Compiler/LCNF/InferBorrow.lean @@ -9,12 +9,12 @@ prelude public import Lean.Compiler.LCNF.CompilerM public import Lean.Compiler.LCNF.PassManager import Lean.Compiler.ExportAttr -import Std.Data.Iterators.Producers.Array -import Std.Data.Iterators.Combinators.Zip import Lean.Compiler.LCNF.MonadScope import Lean.Compiler.LCNF.FVarUtil import Lean.Compiler.LCNF.PhaseExt import Lean.Compiler.LCNF.PrettyPrinter +import Std.Data.Iterators.Producers.Monadic.Array +import Std.Data.Iterators.Combinators.Monadic.Zip /-! This pass is responsible for inferring borrow annotations to the parameters of functions and join @@ -240,18 +240,13 @@ def OwnReason.isForced (reason : OwnReason) : Bool := -- All of these reasons propagate through ABI decisions and can thus safely be ignored as they -- will be accounted for by the reference counting pass. | .constructorArg .. | .functionCallArg .. | .fvarCall .. | .partialApplication .. - | .jpArgPropagation .. 
- -- If a projection of a value is used in an owned fashion that does not necessarily mean we have - -- to make the value itself owned. Note that this will however prevent potential reset-reuse - -- opportunities. - | .backwardProjectionProp .. => false + | .jpArgPropagation .. => false -- Results of functions and constructors are naturally owned. | .constructorResult .. | .functionCallResult .. -- We cannot pass borrowed values to reset or have borrow annotations destroy tail calls for -- correctness reasons. | .resetReuse .. | .tailCallPreservation .. | .jpTailCallPreservation .. - -- If a value is owned and we project from it its projectee is always owned as well. - | .forwardProjectionProp .. => true + | .forwardProjectionProp .. | .backwardProjectionProp .. => true /-- Infer the borrowing annotations in a SCC through dataflow analysis. diff --git a/src/Lean/Compiler/LCNF/PropagateBorrow.lean b/src/Lean/Compiler/LCNF/PropagateBorrow.lean new file mode 100644 index 0000000000..8947abf1c1 --- /dev/null +++ b/src/Lean/Compiler/LCNF/PropagateBorrow.lean @@ -0,0 +1,147 @@ +/- +Copyright (c) 2026 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Henrik Böving +-/ +module + +prelude +public import Lean.Compiler.LCNF.CompilerM +public import Lean.Compiler.LCNF.PassManager +import Lean.Compiler.LCNF.PhaseExt + +/-! +This module contains a pass for propagating user provided borrows as far forward in the function as +possible. This analysis is used to inform the reset-reuse insertion as to avoid inserting +reset-reuse on values that the user explicitly requested to be borrowed. 
+-/ + +namespace Lean.Compiler.LCNF + +public inductive Ownedness where + | bot + | borrow + | own + | top + deriving Inhabited, BEq + +def Ownedness.join : Ownedness → Ownedness → Ownedness + | .bot, v => v + | .borrow, .borrow => .borrow + | .own, .own => .own + | _, _ => .top + +structure State where + values : Std.HashMap FVarId Ownedness + modified : Bool + deriving Inhabited + +abbrev InferM := StateRefT State CompilerM + +public partial def Decl.analyzePropagatedBorrows (decl : Decl .impure) : + CompilerM (Std.HashMap FVarId Ownedness) := do + let (_, { values, .. }) ← go |>.run { values := {}, modified := false } + return values +where + @[inline] + getOwnedness (fvarId : FVarId) : InferM Ownedness := do + return (← get).values.getD fvarId .bot + + @[inline] + join (fvarId : FVarId) (v : Ownedness) : InferM Unit := do + modify fun s => + let old := s.values.getD fvarId .bot + let new := old.join v + if old == new then + s + else + { s with values := s.values.insert fvarId new, modified := true } + + getParams (f : Name) : InferM (Array (Param .impure)) := do + let some sig ← getImpureSignature? f | unreachable! + return sig.params + + go : InferM Unit := do + initializeDecl + loop + + initializeDecl : InferM Unit := do + match decl.value with + | .code .. 
=> + for p in decl.params do + let init := if p.borrow then .borrow else .top + modify fun s => { s with values := s.values.insert p.fvarId init } + | _ => return () + + loop : InferM Unit := do + modify fun s => { s with modified := false } + match decl.value with + | .code code => collectCode code + | _ => pure () + if (← get).modified then loop + + collectCode (code : Code .impure) : InferM Unit := do + match code with + | .jp decl k => + for p in decl.params do + unless (← get).values.contains p.fvarId do + let init := if p.borrow then .borrow else .bot + modify fun s => { s with values := s.values.insert p.fvarId init } + collectCode k + collectCode decl.value + | .let decl k => + collectLetValue decl.fvarId decl.value + collectCode k + | .jmp jpId args => + let some decl ← findFunDecl? (pu := .impure) jpId | unreachable! + for arg in args, p in decl.params do + if let .fvar arg := arg then + let argValue ← getOwnedness arg + join p.fvarId argValue + | .cases cs => cs.alts.forM (·.forCodeM collectCode) + | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => collectCode k + | .return .. | .unreach .. => return () + | .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable! + + collectLetValue (z : FVarId) (v : LetValue .impure) : InferM Unit := do + match v with + | .oproj _ x _ => + let xVal ← getOwnedness x + join z xVal + | .ctor .. | .fap .. | .fvar .. | .pap .. | .sproj .. | .uproj .. | .erased .. | .lit .. => + join z .own + | _ => unreachable! 
+ + +def Ownedness.toBorrow : Ownedness → Option Bool + | .bot => none + | .borrow => some true + | .own | .top => some false + +public partial def Decl.applyOwnedness (decl : Decl .impure) (values : Std.HashMap FVarId Ownedness) : + CompilerM (Decl .impure) := do + match decl.value with + | .code code => + let params ← updateParams decl.params + let code ← goCode code + return { decl with params, value := .code code } + | _ => return decl +where + updateParams (ps : Array (Param .impure)) : CompilerM (Array (Param .impure)) := + ps.mapM fun p => do + match values[p.fvarId]!.toBorrow with + | none => return p + | some borrow => p.updateBorrow borrow + + goCode (code : Code .impure) : CompilerM (Code .impure) := do + match code with + | .jp decl k => + let ps ← updateParams decl.params + let decl ← decl.update decl.type ps (← goCode decl.value) + return code.updateFun! decl (← goCode k) + | .cases cs => return code.updateAlts! <| ← cs.alts.mapMonoM (·.mapCodeM goCode) + | .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => return code.updateCont! (← goCode k) + | .return .. | .jmp .. | .unreach .. => return code + | .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable! + +end Lean.Compiler.LCNF diff --git a/src/Lean/Compiler/LCNF/ResetReuse.lean b/src/Lean/Compiler/LCNF/ResetReuse.lean index becf97a30a..1d751824d8 100644 --- a/src/Lean/Compiler/LCNF/ResetReuse.lean +++ b/src/Lean/Compiler/LCNF/ResetReuse.lean @@ -11,6 +11,7 @@ public import Lean.Compiler.LCNF.PassManager import Lean.Compiler.LCNF.LiveVars import Lean.Compiler.LCNF.DependsOn import Lean.Compiler.LCNF.PhaseExt +import Lean.Compiler.LCNF.PropagateBorrow namespace Lean.Compiler.LCNF @@ -62,6 +63,7 @@ structure Context where we first try `relaxedReuse := false`, and then `relaxedReuse := true`. 
-/ relaxedReuse : Bool + ownedness : Std.HashMap FVarId Ownedness abbrev ReuseM := ReaderT Context CompilerM @@ -254,12 +256,13 @@ partial def Code.insertResetReuse (c : Code .impure) : ReuseM (Code .impure) := match c with | .cases cs => let alreadyFound := (← read).alreadyFound.contains cs.discr + let borrowed := (← read).ownedness.getD cs.discr .bot == .borrow withReader (fun ctx => { ctx with alreadyFound := ctx.alreadyFound.insert cs.discr }) do let alts ← cs.alts.mapM fun alt => do let alt ← alt.mapCodeM (·.insertResetReuse) match alt with | .ctorAlt info k => - if info.isScalar || alreadyFound then + if info.isScalar || alreadyFound || borrowed then -- If `alreadyFound`, then we don't try to reuse memory cell to avoid -- double reset. return alt @@ -313,8 +316,11 @@ def Decl.insertResetReuse (decl : Decl .impure) : CompilerM (Decl .impure) := do The second pass addresses issue #4089. -/ if (← getConfig).resetReuse then - let decl ← decl.insertResetReuseCore |>.run { relaxedReuse := false } - decl.insertResetReuseCore |>.run { relaxedReuse := true } + let ownedness ← decl.analyzePropagatedBorrows + let decl ← decl.applyOwnedness ownedness + let decl ← decl.insertResetReuseCore |>.run { relaxedReuse := false, ownedness } + let decl ← decl.insertResetReuseCore |>.run { relaxedReuse := true, ownedness } + return decl else return decl diff --git a/tests/elab/compile_borrowed_reset_jp.lean b/tests/elab/compile_borrowed_reset_jp.lean new file mode 100644 index 0000000000..7fdca9411a --- /dev/null +++ b/tests/elab/compile_borrowed_reset_jp.lean @@ -0,0 +1,83 @@ +module + +public section + +/-- +trace: [Compiler.explicitRc] size: 17 + def testWithAnnotation @&n @&p @&q : obj := + jp _jp.1 fst.2 @&snd.3 : obj := + let snd := oproj[1] snd.3; + inc snd; + let _x.4 := ctor_0[Prod.mk] fst.2 snd; + return _x.4; + let zero := 0; + let isZero := Nat.decEq n zero; + cases isZero : obj + | Bool.true => + let _x.5 := 123; + goto _jp.1 _x.5 p + | Bool.false => + let one := 1; + 
let n.6 := Nat.sub n one; + let _x.7 := Nat.add n.6 one; + let _x.8 := Nat.mul n.6 _x.7; + dec _x.7; + dec n.6; + goto _jp.1 _x.8 q +[Compiler.explicitRc] size: 4 + def testWithAnnotation._boxed n p q : obj := + let res := testWithAnnotation n p q; + dec q; + dec p; + dec n; + return res +-/ +#guard_msgs in +set_option trace.Compiler.explicitRc true in +def testWithAnnotation (n : Nat) (p q : @&Prod Nat Nat) : Prod Nat Nat := + let (value, helper) := + match n with + | 0 => (123, p) + | n + 1 => (n * (n + 1), q) + { helper with fst := value } + + +/-- +trace: [Compiler.explicitRc] size: 20 + def testWithoutAnnotation @&n p q : obj := + jp _jp.1 fst.2 snd.3 : obj := + let snd := oproj[1] snd.3; + inc snd; + let _x.4 := reset[2] snd.3; + let _x.5 := reuse _x.4 in ctor_0[Prod.mk] fst.2 snd; + return _x.5; + let zero := 0; + let isZero := Nat.decEq n zero; + cases isZero : obj + | Bool.true => + dec q; + let _x.6 := 123; + goto _jp.1 _x.6 p + | Bool.false => + dec p; + let one := 1; + let n.7 := Nat.sub n one; + let _x.8 := Nat.add n.7 one; + let _x.9 := Nat.mul n.7 _x.8; + dec _x.8; + dec n.7; + goto _jp.1 _x.9 q +[Compiler.explicitRc] size: 2 + def testWithoutAnnotation._boxed n p q : obj := + let res := testWithoutAnnotation n p q; + dec n; + return res +-/ +#guard_msgs in +set_option trace.Compiler.explicitRc true in +def testWithoutAnnotation (n : Nat) (p q : Prod Nat Nat) : Prod Nat Nat := + let (value, helper) := + match n with + | 0 => (123, p) + | n + 1 => (n * (n + 1), q) + { helper with fst := value } diff --git a/tests/elab/lcnf_borrow_expected_type.lean b/tests/elab/lcnf_borrow_expected_type.lean index fef9a3d711..5b5edb9ef9 100644 --- a/tests/elab/lcnf_borrow_expected_type.lean +++ b/tests/elab/lcnf_borrow_expected_type.lean @@ -139,6 +139,8 @@ structure Quad where /-- Only traverses → parameter stays borrowed. 
-/ @[noinline] def measuree (xs : List Nat) : Nat := xs.length +/- + /-- trace: [Compiler.explicitRc] size: 22 def cascadeDemo @&t : tobj := @@ -255,3 +257,5 @@ def preserveTailCall (x : @&Prod Nat Nat) (a : Nat) : Nat := match a with | 0 => x.fst | a + 1 => preserveTailCall (mkNewProd x a) a + +-/