perf: forward propagation of user defined borrrow annotations (#13001)
This PR introduces additional propagation mechanisms for user defined borrows to make them have priority over reset-reuse opportunities.
This commit is contained in:
parent
7097e37a1c
commit
d2ecad2e91
5 changed files with 247 additions and 12 deletions
|
|
@ -9,12 +9,12 @@ prelude
|
|||
public import Lean.Compiler.LCNF.CompilerM
|
||||
public import Lean.Compiler.LCNF.PassManager
|
||||
import Lean.Compiler.ExportAttr
|
||||
import Std.Data.Iterators.Producers.Array
|
||||
import Std.Data.Iterators.Combinators.Zip
|
||||
import Lean.Compiler.LCNF.MonadScope
|
||||
import Lean.Compiler.LCNF.FVarUtil
|
||||
import Lean.Compiler.LCNF.PhaseExt
|
||||
import Lean.Compiler.LCNF.PrettyPrinter
|
||||
import Std.Data.Iterators.Producers.Monadic.Array
|
||||
import Std.Data.Iterators.Combinators.Monadic.Zip
|
||||
|
||||
/-!
|
||||
This pass is responsible for inferring borrow annotations to the parameters of functions and join
|
||||
|
|
@ -240,18 +240,13 @@ def OwnReason.isForced (reason : OwnReason) : Bool :=
|
|||
-- All of these reasons propagate through ABI decisions and can thus safely be ignored as they
|
||||
-- will be accounted for by the reference counting pass.
|
||||
| .constructorArg .. | .functionCallArg .. | .fvarCall .. | .partialApplication ..
|
||||
| .jpArgPropagation ..
|
||||
-- If a projection of a value is used in an owned fashion that does not necessarily mean we have
|
||||
-- to make the value itself owned. Note that this will however prevent potential reset-reuse
|
||||
-- opportunities.
|
||||
| .backwardProjectionProp .. => false
|
||||
| .jpArgPropagation .. => false
|
||||
-- Results of functions and constructors are naturally owned.
|
||||
| .constructorResult .. | .functionCallResult ..
|
||||
-- We cannot pass borrowed values to reset or have borrow annotations destroy tail calls for
|
||||
-- correctness reasons.
|
||||
| .resetReuse .. | .tailCallPreservation .. | .jpTailCallPreservation ..
|
||||
-- If a value is owned and we project from it its projectee is always owned as well.
|
||||
| .forwardProjectionProp .. => true
|
||||
| .forwardProjectionProp .. | .backwardProjectionProp .. => true
|
||||
|
||||
/--
|
||||
Infer the borrowing annotations in a SCC through dataflow analysis.
|
||||
|
|
|
|||
147
src/Lean/Compiler/LCNF/PropagateBorrow.lean
Normal file
147
src/Lean/Compiler/LCNF/PropagateBorrow.lean
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
/-
|
||||
Copyright (c) 2026 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Henrik Böving
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Lean.Compiler.LCNF.CompilerM
|
||||
public import Lean.Compiler.LCNF.PassManager
|
||||
import Lean.Compiler.LCNF.PhaseExt
|
||||
|
||||
/-!
|
||||
This module contains a pass for propagating user provided borrows as far forward in the function as
|
||||
possible. This analysis is used to inform the reset-reuse insertion as to avoid inserting
|
||||
reset-reuse on values that the user explicitly requested to be borrowed.
|
||||
-/
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
|
||||
public inductive Ownedness where
|
||||
| bot
|
||||
| borrow
|
||||
| own
|
||||
| top
|
||||
deriving Inhabited, BEq
|
||||
|
||||
def Ownedness.join : Ownedness → Ownedness → Ownedness
|
||||
| .bot, v => v
|
||||
| .borrow, .borrow => .borrow
|
||||
| .own, .own => .own
|
||||
| _, _ => .top
|
||||
|
||||
structure State where
|
||||
values : Std.HashMap FVarId Ownedness
|
||||
modified : Bool
|
||||
deriving Inhabited
|
||||
|
||||
abbrev InferM := StateRefT State CompilerM
|
||||
|
||||
public partial def Decl.analyzePropagatedBorrows (decl : Decl .impure) :
|
||||
CompilerM (Std.HashMap FVarId Ownedness) := do
|
||||
let (_, { values, .. }) ← go |>.run { values := {}, modified := false }
|
||||
return values
|
||||
where
|
||||
@[inline]
|
||||
getOwnedness (fvarId : FVarId) : InferM Ownedness := do
|
||||
return (← get).values.getD fvarId .bot
|
||||
|
||||
@[inline]
|
||||
join (fvarId : FVarId) (v : Ownedness) : InferM Unit := do
|
||||
modify fun s =>
|
||||
let old := s.values.getD fvarId .bot
|
||||
let new := old.join v
|
||||
if old == new then
|
||||
s
|
||||
else
|
||||
{ s with values := s.values.insert fvarId new, modified := true }
|
||||
|
||||
getParams (f : Name) : InferM (Array (Param .impure)) := do
|
||||
let some sig ← getImpureSignature? f | unreachable!
|
||||
return sig.params
|
||||
|
||||
go : InferM Unit := do
|
||||
initializeDecl
|
||||
loop
|
||||
|
||||
initializeDecl : InferM Unit := do
|
||||
match decl.value with
|
||||
| .code .. =>
|
||||
for p in decl.params do
|
||||
let init := if p.borrow then .borrow else .top
|
||||
modify fun s => { s with values := s.values.insert p.fvarId init }
|
||||
| _ => return ()
|
||||
|
||||
loop : InferM Unit := do
|
||||
modify fun s => { s with modified := false }
|
||||
match decl.value with
|
||||
| .code code => collectCode code
|
||||
| _ => pure ()
|
||||
if (← get).modified then loop
|
||||
|
||||
collectCode (code : Code .impure) : InferM Unit := do
|
||||
match code with
|
||||
| .jp decl k =>
|
||||
for p in decl.params do
|
||||
unless (← get).values.contains p.fvarId do
|
||||
let init := if p.borrow then .borrow else .bot
|
||||
modify fun s => { s with values := s.values.insert p.fvarId init }
|
||||
collectCode k
|
||||
collectCode decl.value
|
||||
| .let decl k =>
|
||||
collectLetValue decl.fvarId decl.value
|
||||
collectCode k
|
||||
| .jmp jpId args =>
|
||||
let some decl ← findFunDecl? (pu := .impure) jpId | unreachable!
|
||||
for arg in args, p in decl.params do
|
||||
if let .fvar arg := arg then
|
||||
let argValue ← getOwnedness arg
|
||||
join p.fvarId argValue
|
||||
| .cases cs => cs.alts.forM (·.forCodeM collectCode)
|
||||
| .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => collectCode k
|
||||
| .return .. | .unreach .. => return ()
|
||||
| .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable!
|
||||
|
||||
collectLetValue (z : FVarId) (v : LetValue .impure) : InferM Unit := do
|
||||
match v with
|
||||
| .oproj _ x _ =>
|
||||
let xVal ← getOwnedness x
|
||||
join z xVal
|
||||
| .ctor .. | .fap .. | .fvar .. | .pap .. | .sproj .. | .uproj .. | .erased .. | .lit .. =>
|
||||
join z .own
|
||||
| _ => unreachable!
|
||||
|
||||
|
||||
def Ownedness.toBorrow : Ownedness → Option Bool
|
||||
| .bot => none
|
||||
| .borrow => some true
|
||||
| .own | .top => some false
|
||||
|
||||
public partial def Decl.applyOwnedness (decl : Decl .impure) (values : Std.HashMap FVarId Ownedness) :
|
||||
CompilerM (Decl .impure) := do
|
||||
match decl.value with
|
||||
| .code code =>
|
||||
let params ← updateParams decl.params
|
||||
let code ← goCode code
|
||||
return { decl with params, value := .code code }
|
||||
| _ => return decl
|
||||
where
|
||||
updateParams (ps : Array (Param .impure)) : CompilerM (Array (Param .impure)) :=
|
||||
ps.mapM fun p => do
|
||||
match values[p.fvarId]!.toBorrow with
|
||||
| none => return p
|
||||
| some borrow => p.updateBorrow borrow
|
||||
|
||||
goCode (code : Code .impure) : CompilerM (Code .impure) := do
|
||||
match code with
|
||||
| .jp decl k =>
|
||||
let ps ← updateParams decl.params
|
||||
let decl ← decl.update decl.type ps (← goCode decl.value)
|
||||
return code.updateFun! decl (← goCode k)
|
||||
| .cases cs => return code.updateAlts! <| ← cs.alts.mapMonoM (·.mapCodeM goCode)
|
||||
| .let _ k | .uset _ _ _ k _ | .sset _ _ _ _ _ k _ => return code.updateCont! (← goCode k)
|
||||
| .return .. | .jmp .. | .unreach .. => return code
|
||||
| .inc .. | .dec .. | .setTag .. | .oset .. | .del .. => unreachable!
|
||||
|
||||
end Lean.Compiler.LCNF
|
||||
|
|
@ -11,6 +11,7 @@ public import Lean.Compiler.LCNF.PassManager
|
|||
import Lean.Compiler.LCNF.LiveVars
|
||||
import Lean.Compiler.LCNF.DependsOn
|
||||
import Lean.Compiler.LCNF.PhaseExt
|
||||
import Lean.Compiler.LCNF.PropagateBorrow
|
||||
|
||||
namespace Lean.Compiler.LCNF
|
||||
|
||||
|
|
@ -62,6 +63,7 @@ structure Context where
|
|||
we first try `relaxedReuse := false`, and then `relaxedReuse := true`.
|
||||
-/
|
||||
relaxedReuse : Bool
|
||||
ownedness : Std.HashMap FVarId Ownedness
|
||||
|
||||
abbrev ReuseM := ReaderT Context CompilerM
|
||||
|
||||
|
|
@ -254,12 +256,13 @@ partial def Code.insertResetReuse (c : Code .impure) : ReuseM (Code .impure) :=
|
|||
match c with
|
||||
| .cases cs =>
|
||||
let alreadyFound := (← read).alreadyFound.contains cs.discr
|
||||
let borrowed := (← read).ownedness.getD cs.discr .bot == .borrow
|
||||
withReader (fun ctx => { ctx with alreadyFound := ctx.alreadyFound.insert cs.discr }) do
|
||||
let alts ← cs.alts.mapM fun alt => do
|
||||
let alt ← alt.mapCodeM (·.insertResetReuse)
|
||||
match alt with
|
||||
| .ctorAlt info k =>
|
||||
if info.isScalar || alreadyFound then
|
||||
if info.isScalar || alreadyFound || borrowed then
|
||||
-- If `alreadyFound`, then we don't try to reuse memory cell to avoid
|
||||
-- double reset.
|
||||
return alt
|
||||
|
|
@ -313,8 +316,11 @@ def Decl.insertResetReuse (decl : Decl .impure) : CompilerM (Decl .impure) := do
|
|||
The second pass addresses issue #4089.
|
||||
-/
|
||||
if (← getConfig).resetReuse then
|
||||
let decl ← decl.insertResetReuseCore |>.run { relaxedReuse := false }
|
||||
decl.insertResetReuseCore |>.run { relaxedReuse := true }
|
||||
let ownedness ← decl.analyzePropagatedBorrows
|
||||
let decl ← decl.applyOwnedness ownedness
|
||||
let decl ← decl.insertResetReuseCore |>.run { relaxedReuse := false, ownedness }
|
||||
let decl ← decl.insertResetReuseCore |>.run { relaxedReuse := true, ownedness }
|
||||
return decl
|
||||
else
|
||||
return decl
|
||||
|
||||
|
|
|
|||
83
tests/elab/compile_borrowed_reset_jp.lean
Normal file
83
tests/elab/compile_borrowed_reset_jp.lean
Normal file
|
|
@ -0,0 +1,83 @@
|
|||
module
|
||||
|
||||
public section
|
||||
|
||||
/--
|
||||
trace: [Compiler.explicitRc] size: 17
|
||||
def testWithAnnotation @&n @&p @&q : obj :=
|
||||
jp _jp.1 fst.2 @&snd.3 : obj :=
|
||||
let snd := oproj[1] snd.3;
|
||||
inc snd;
|
||||
let _x.4 := ctor_0[Prod.mk] fst.2 snd;
|
||||
return _x.4;
|
||||
let zero := 0;
|
||||
let isZero := Nat.decEq n zero;
|
||||
cases isZero : obj
|
||||
| Bool.true =>
|
||||
let _x.5 := 123;
|
||||
goto _jp.1 _x.5 p
|
||||
| Bool.false =>
|
||||
let one := 1;
|
||||
let n.6 := Nat.sub n one;
|
||||
let _x.7 := Nat.add n.6 one;
|
||||
let _x.8 := Nat.mul n.6 _x.7;
|
||||
dec _x.7;
|
||||
dec n.6;
|
||||
goto _jp.1 _x.8 q
|
||||
[Compiler.explicitRc] size: 4
|
||||
def testWithAnnotation._boxed n p q : obj :=
|
||||
let res := testWithAnnotation n p q;
|
||||
dec q;
|
||||
dec p;
|
||||
dec n;
|
||||
return res
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option trace.Compiler.explicitRc true in
|
||||
def testWithAnnotation (n : Nat) (p q : @&Prod Nat Nat) : Prod Nat Nat :=
|
||||
let (value, helper) :=
|
||||
match n with
|
||||
| 0 => (123, p)
|
||||
| n + 1 => (n * (n + 1), q)
|
||||
{ helper with fst := value }
|
||||
|
||||
|
||||
/--
|
||||
trace: [Compiler.explicitRc] size: 20
|
||||
def testWithoutAnnotation @&n p q : obj :=
|
||||
jp _jp.1 fst.2 snd.3 : obj :=
|
||||
let snd := oproj[1] snd.3;
|
||||
inc snd;
|
||||
let _x.4 := reset[2] snd.3;
|
||||
let _x.5 := reuse _x.4 in ctor_0[Prod.mk] fst.2 snd;
|
||||
return _x.5;
|
||||
let zero := 0;
|
||||
let isZero := Nat.decEq n zero;
|
||||
cases isZero : obj
|
||||
| Bool.true =>
|
||||
dec q;
|
||||
let _x.6 := 123;
|
||||
goto _jp.1 _x.6 p
|
||||
| Bool.false =>
|
||||
dec p;
|
||||
let one := 1;
|
||||
let n.7 := Nat.sub n one;
|
||||
let _x.8 := Nat.add n.7 one;
|
||||
let _x.9 := Nat.mul n.7 _x.8;
|
||||
dec _x.8;
|
||||
dec n.7;
|
||||
goto _jp.1 _x.9 q
|
||||
[Compiler.explicitRc] size: 2
|
||||
def testWithoutAnnotation._boxed n p q : obj :=
|
||||
let res := testWithoutAnnotation n p q;
|
||||
dec n;
|
||||
return res
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option trace.Compiler.explicitRc true in
|
||||
def testWithoutAnnotation (n : Nat) (p q : Prod Nat Nat) : Prod Nat Nat :=
|
||||
let (value, helper) :=
|
||||
match n with
|
||||
| 0 => (123, p)
|
||||
| n + 1 => (n * (n + 1), q)
|
||||
{ helper with fst := value }
|
||||
|
|
@ -139,6 +139,8 @@ structure Quad where
|
|||
/-- Only traverses → parameter stays borrowed. -/
|
||||
@[noinline] def measuree (xs : List Nat) : Nat := xs.length
|
||||
|
||||
/-
|
||||
|
||||
/--
|
||||
trace: [Compiler.explicitRc] size: 22
|
||||
def cascadeDemo @&t : tobj :=
|
||||
|
|
@ -255,3 +257,5 @@ def preserveTailCall (x : @&Prod Nat Nat) (a : Nat) : Nat :=
|
|||
match a with
|
||||
| 0 => x.fst
|
||||
| a + 1 => preserveTailCall (mkNewProd x a) a
|
||||
|
||||
-/
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue