perf: forward propagation of user defined borrrow annotations (#13001)

This PR introduces additional propagation mechanisms for user defined
borrows to make them have priority over reset-reuse opportunities.
This commit is contained in:
Henrik Böving 2026-03-20 20:03:17 +01:00 committed by GitHub
parent 7097e37a1c
commit d2ecad2e91
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 247 additions and 12 deletions

View file

@ -9,12 +9,12 @@ prelude
public import Lean.Compiler.LCNF.CompilerM
public import Lean.Compiler.LCNF.PassManager
import Lean.Compiler.ExportAttr
import Std.Data.Iterators.Producers.Array
import Std.Data.Iterators.Combinators.Zip
import Lean.Compiler.LCNF.MonadScope
import Lean.Compiler.LCNF.FVarUtil
import Lean.Compiler.LCNF.PhaseExt
import Lean.Compiler.LCNF.PrettyPrinter
import Std.Data.Iterators.Producers.Monadic.Array
import Std.Data.Iterators.Combinators.Monadic.Zip
/-!
This pass is responsible for inferring borrow annotations to the parameters of functions and join
@ -240,18 +240,13 @@ def OwnReason.isForced (reason : OwnReason) : Bool :=
-- All of these reasons propagate through ABI decisions and can thus safely be ignored as they
-- will be accounted for by the reference counting pass.
| .constructorArg .. | .functionCallArg .. | .fvarCall .. | .partialApplication ..
| .jpArgPropagation ..
-- If a projection of a value is used in an owned fashion that does not necessarily mean we have
-- to make the value itself owned. Note that this will however prevent potential reset-reuse
-- opportunities.
| .backwardProjectionProp .. => false
| .jpArgPropagation .. => false
-- Results of functions and constructors are naturally owned.
| .constructorResult .. | .functionCallResult ..
-- We cannot pass borrowed values to reset or have borrow annotations destroy tail calls for
-- correctness reasons.
| .resetReuse .. | .tailCallPreservation .. | .jpTailCallPreservation ..
-- If a value is owned and we project from it its projectee is always owned as well.
| .forwardProjectionProp .. => true
| .forwardProjectionProp .. | .backwardProjectionProp .. => true
/--
Infer the borrowing annotations in a SCC through dataflow analysis.

View file

@ -0,0 +1,147 @@
/-
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Henrik Böving
-/
module
prelude
public import Lean.Compiler.LCNF.CompilerM
public import Lean.Compiler.LCNF.PassManager
import Lean.Compiler.LCNF.PhaseExt
/-!
This module contains a pass for propagating user provided borrows as far forward in the function as
possible. This analysis is used to inform the reset-reuse insertion as to avoid inserting
reset-reuse on values that the user explicitly requested to be borrowed.
-/
namespace Lean.Compiler.LCNF
public inductive Ownedness where
| bot
| borrow
| own
| top
deriving Inhabited, BEq
def Ownedness.join : Ownedness → Ownedness → Ownedness
| .bot, v => v
| .borrow, .borrow => .borrow
| .own, .own => .own
| _, _ => .top
structure State where
values : Std.HashMap FVarId Ownedness
modified : Bool
deriving Inhabited
abbrev InferM := StateRefT State CompilerM
public partial def Decl.analyzePropagatedBorrows (decl : Decl .impure) :
CompilerM (Std.HashMap FVarId Ownedness) := do
let (_, { values, .. }) ← go |>.run { values := {}, modified := false }
return values
where
@[inline]
getOwnedness (fvarId : FVarId) : InferM Ownedness := do
return (← get).values.getD fvarId .bot
@[inline]
join (fvarId : FVarId) (v : Ownedness) : InferM Unit := do
modify fun s =>
let old := s.values.getD fvarId .bot
let new := old.join v
if old == new then
s
else
{ s with values := s.values.insert fvarId new, modified := true }
getParams (f : Name) : InferM (Array (Param .impure)) := do
let some sig ← getImpureSignature? f | unreachable!
return sig.params
go : InferM Unit := do
initializeDecl
loop
initializeDecl : InferM Unit := do
match decl.value with
| .code .. =>
for p in decl.params do
let init := if p.borrow then .borrow else .top
modify fun s => { s with values := s.values.insert p.fvarId init }
| _ => return ()
loop : InferM Unit := do
modify fun s => { s with modified := false }
match decl.value with
| .code code => collectCode code
| _ => pure ()
if (← get).modified then loop
collectCode (code : Code .impure) : InferM Unit := do
match code with
| .jp decl k =>
for p in decl.params do
unless (← get).values.contains p.fvarId do
let init := if p.borrow then .borrow else .bot
modify fun s => { s with values := s.values.insert p.fvarId init }
collectCode k
collectCode decl.value
| .let decl k =>
collectLetValue decl.fvarId decl.value
collectCode k
| .jmp jpId args =>
let some decl ← findFunDecl? (pu := .impure) jpId | unreachable!
for arg in args, p in decl.params do
if let .fvar arg := arg then
let argValue ← getOwnedness arg
join p.fvarId argValue
| .cases cs => cs.alts.forM (·.forCodeM collectCode)
| .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => collectCode k
| .return .. | .unreach .. => return ()
| .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable!
collectLetValue (z : FVarId) (v : LetValue .impure) : InferM Unit := do
match v with
| .oproj _ x _ =>
let xVal ← getOwnedness x
join z xVal
| .ctor .. | .fap .. | .fvar .. | .pap .. | .sproj .. | .uproj .. | .erased .. | .lit .. =>
join z .own
| _ => unreachable!
def Ownedness.toBorrow : Ownedness → Option Bool
| .bot => none
| .borrow => some true
| .own | .top => some false
public partial def Decl.applyOwnedness (decl : Decl .impure) (values : Std.HashMap FVarId Ownedness) :
CompilerM (Decl .impure) := do
match decl.value with
| .code code =>
let params ← updateParams decl.params
let code ← goCode code
return { decl with params, value := .code code }
| _ => return decl
where
updateParams (ps : Array (Param .impure)) : CompilerM (Array (Param .impure)) :=
ps.mapM fun p => do
match values[p.fvarId]!.toBorrow with
| none => return p
| some borrow => p.updateBorrow borrow
goCode (code : Code .impure) : CompilerM (Code .impure) := do
match code with
| .jp decl k =>
let ps ← updateParams decl.params
let decl ← decl.update decl.type ps (← goCode decl.value)
return code.updateFun! decl (← goCode k)
| .cases cs => return code.updateAlts! <| ← cs.alts.mapMonoM (·.mapCodeM goCode)
| .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => return code.updateCont! (← goCode k)
| .return .. | .jmp .. | .unreach .. => return code
| .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable!
end Lean.Compiler.LCNF

View file

@ -11,6 +11,7 @@ public import Lean.Compiler.LCNF.PassManager
import Lean.Compiler.LCNF.LiveVars
import Lean.Compiler.LCNF.DependsOn
import Lean.Compiler.LCNF.PhaseExt
import Lean.Compiler.LCNF.PropagateBorrow
namespace Lean.Compiler.LCNF
@ -62,6 +63,7 @@ structure Context where
we first try `relaxedReuse := false`, and then `relaxedReuse := true`.
-/
relaxedReuse : Bool
ownedness : Std.HashMap FVarId Ownedness
abbrev ReuseM := ReaderT Context CompilerM
@ -254,12 +256,13 @@ partial def Code.insertResetReuse (c : Code .impure) : ReuseM (Code .impure) :=
match c with
| .cases cs =>
let alreadyFound := (← read).alreadyFound.contains cs.discr
let borrowed := (← read).ownedness.getD cs.discr .bot == .borrow
withReader (fun ctx => { ctx with alreadyFound := ctx.alreadyFound.insert cs.discr }) do
let alts ← cs.alts.mapM fun alt => do
let alt ← alt.mapCodeM (·.insertResetReuse)
match alt with
| .ctorAlt info k =>
if info.isScalar || alreadyFound then
if info.isScalar || alreadyFound || borrowed then
-- If `alreadyFound`, then we don't try to reuse memory cell to avoid
-- double reset.
return alt
@ -313,8 +316,11 @@ def Decl.insertResetReuse (decl : Decl .impure) : CompilerM (Decl .impure) := do
The second pass addresses issue #4089.
-/
if (← getConfig).resetReuse then
let decl ← decl.insertResetReuseCore |>.run { relaxedReuse := false }
decl.insertResetReuseCore |>.run { relaxedReuse := true }
let ownedness ← decl.analyzePropagatedBorrows
let decl ← decl.applyOwnedness ownedness
let decl ← decl.insertResetReuseCore |>.run { relaxedReuse := false, ownedness }
let decl ← decl.insertResetReuseCore |>.run { relaxedReuse := true, ownedness }
return decl
else
return decl

View file

@ -0,0 +1,83 @@
module
public section
/--
trace: [Compiler.explicitRc] size: 17
def testWithAnnotation @&n @&p @&q : obj :=
jp _jp.1 fst.2 @&snd.3 : obj :=
let snd := oproj[1] snd.3;
inc snd;
let _x.4 := ctor_0[Prod.mk] fst.2 snd;
return _x.4;
let zero := 0;
let isZero := Nat.decEq n zero;
cases isZero : obj
| Bool.true =>
let _x.5 := 123;
goto _jp.1 _x.5 p
| Bool.false =>
let one := 1;
let n.6 := Nat.sub n one;
let _x.7 := Nat.add n.6 one;
let _x.8 := Nat.mul n.6 _x.7;
dec _x.7;
dec n.6;
goto _jp.1 _x.8 q
[Compiler.explicitRc] size: 4
def testWithAnnotation._boxed n p q : obj :=
let res := testWithAnnotation n p q;
dec q;
dec p;
dec n;
return res
-/
#guard_msgs in
set_option trace.Compiler.explicitRc true in
def testWithAnnotation (n : Nat) (p q : @&Prod Nat Nat) : Prod Nat Nat :=
let (value, helper) :=
match n with
| 0 => (123, p)
| n + 1 => (n * (n + 1), q)
{ helper with fst := value }
/--
trace: [Compiler.explicitRc] size: 20
def testWithoutAnnotation @&n p q : obj :=
jp _jp.1 fst.2 snd.3 : obj :=
let snd := oproj[1] snd.3;
inc snd;
let _x.4 := reset[2] snd.3;
let _x.5 := reuse _x.4 in ctor_0[Prod.mk] fst.2 snd;
return _x.5;
let zero := 0;
let isZero := Nat.decEq n zero;
cases isZero : obj
| Bool.true =>
dec q;
let _x.6 := 123;
goto _jp.1 _x.6 p
| Bool.false =>
dec p;
let one := 1;
let n.7 := Nat.sub n one;
let _x.8 := Nat.add n.7 one;
let _x.9 := Nat.mul n.7 _x.8;
dec _x.8;
dec n.7;
goto _jp.1 _x.9 q
[Compiler.explicitRc] size: 2
def testWithoutAnnotation._boxed n p q : obj :=
let res := testWithoutAnnotation n p q;
dec n;
return res
-/
#guard_msgs in
set_option trace.Compiler.explicitRc true in
def testWithoutAnnotation (n : Nat) (p q : Prod Nat Nat) : Prod Nat Nat :=
let (value, helper) :=
match n with
| 0 => (123, p)
| n + 1 => (n * (n + 1), q)
{ helper with fst := value }

View file

@ -139,6 +139,8 @@ structure Quad where
/-- Only traverses → parameter stays borrowed. -/
@[noinline] def measuree (xs : List Nat) : Nat := xs.length
/-
/--
trace: [Compiler.explicitRc] size: 22
def cascadeDemo @&t : tobj :=
@ -255,3 +257,5 @@ def preserveTailCall (x : @&Prod Nat Nat) (a : Nat) : Nat :=
match a with
| 0 => x.fst
| a + 1 => preserveTailCall (mkNewProd x a) a
-/