From 86175bea00f2d963a8c1367ac8ae67dd70ab32ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20B=C3=B6ving?= Date: Mon, 23 Mar 2026 19:10:50 +0100 Subject: [PATCH] perf: teach borrow inference about arrays (#13064) This PR informs the borrow inference that if an `Array` is borrowed and we index into it, the value we obtain is effectively a borrowed value as well. This helps improve the ABI of operations that recurse on linked structures containing arrays such as tries or persistent hash maps. --- src/Lean/Compiler/LCNF/ExplicitRC.lean | 1 + src/Lean/Compiler/LCNF/InferBorrow.lean | 10 +++ src/Lean/Compiler/LCNF/PropagateBorrow.lean | 19 ++++- tests/elab/compile_borrowed_reset_jp.lean | 83 +++++++++++++++++++ .../elab/compile_recursive_array_access.lean | 51 ++++++++++++ 5 files changed, 161 insertions(+), 3 deletions(-) create mode 100644 tests/elab/compile_recursive_array_access.lean diff --git a/src/Lean/Compiler/LCNF/ExplicitRC.lean b/src/Lean/Compiler/LCNF/ExplicitRC.lean index ff01955d58..d4863b2ec2 100644 --- a/src/Lean/Compiler/LCNF/ExplicitRC.lean +++ b/src/Lean/Compiler/LCNF/ExplicitRC.lean @@ -97,6 +97,7 @@ partial def collectCode (code : Code .impure) : M Unit := do match decl.value with | .oproj _ parent => addDerivedValue parent decl.fvarId + -- Keep in sync with PropagateBorrow, InferBorrow | .fap ``Array.getInternal args => if let .fvar parent := args[1]! then addDerivedValue parent decl.fvarId diff --git a/src/Lean/Compiler/LCNF/InferBorrow.lean b/src/Lean/Compiler/LCNF/InferBorrow.lean index e7eafa5d0d..a95055d2da 100644 --- a/src/Lean/Compiler/LCNF/InferBorrow.lean +++ b/src/Lean/Compiler/LCNF/InferBorrow.lean @@ -373,6 +373,16 @@ where | .oproj _ x _ => if ← isOwned x then ownFVar z (.forwardProjectionProp z) if ← isOwned z then ownFVar x (.backwardProjectionProp z) + -- Keep in sync with ExplicitRC, PropagateBorrow + | .fap ``Array.getInternal args => + if let .fvar parent := args[1]! then + if ← isOwned parent then ownFVar z (.forwardProjectionProp z) + | .fap ``Array.get!Internal args => + if let .fvar parent := args[2]! then + if ← isOwned parent then ownFVar z (.forwardProjectionProp z) + | .fap ``Array.uget args => + if let .fvar parent := args[1]! then + if ← isOwned parent then ownFVar z (.forwardProjectionProp z) | .fap f args => let ps ← getParamInfo (.decl f) ownFVar z (.functionCallResult z) diff --git a/src/Lean/Compiler/LCNF/PropagateBorrow.lean b/src/Lean/Compiler/LCNF/PropagateBorrow.lean index 8947abf1c1..99d6d07970 100644 --- a/src/Lean/Compiler/LCNF/PropagateBorrow.lean +++ b/src/Lean/Compiler/LCNF/PropagateBorrow.lean @@ -105,9 +105,22 @@ where collectLetValue (z : FVarId) (v : LetValue .impure) : InferM Unit := do match v with - | .oproj _ x _ => - let xVal ← getOwnedness x - join z xVal + | .oproj _ parent _ => + let parentVal ← getOwnedness parent + join z parentVal + -- Keep in sync with ExplicitRC, InferBorrow + | .fap ``Array.getInternal args => + if let .fvar parent := args[1]! then + let parentVal ← getOwnedness parent + join z parentVal + | .fap ``Array.get!Internal args => + if let .fvar parent := args[2]! then + let parentVal ← getOwnedness parent + join z parentVal + | .fap ``Array.uget args => + if let .fvar parent := args[1]! then + let parentVal ← getOwnedness parent + join z parentVal | .ctor .. | .fap .. | .fvar .. | .pap .. | .sproj .. | .uproj .. | .erased .. | .lit .. => join z .own | _ => unreachable! diff --git a/tests/elab/compile_borrowed_reset_jp.lean b/tests/elab/compile_borrowed_reset_jp.lean index 7fdca9411a..a930fb304e 100644 --- a/tests/elab/compile_borrowed_reset_jp.lean +++ b/tests/elab/compile_borrowed_reset_jp.lean @@ -81,3 +81,86 @@ def testWithoutAnnotation (n : Nat) (p q : Prod Nat Nat) : Prod Nat Nat := | 0 => (123, p) | n + 1 => (n * (n + 1), q) { helper with fst := value } + +/-- +trace: [Compiler.inferBorrow] own _x.28: result of ctor call _x.28 +[Compiler.inferBorrow] own _x.30: result of ctor call _x.30 +[Compiler.inferBorrow] own n: argument to constructor call _x.30 +[Compiler.inferBorrow] own _x.29: result of function call _x.29 +[Compiler.inferBorrow] size: 2 + def testArrayWithAnnotation._closed_0 : obj := + let _x.1 := 0; + let _x.2 := ctor_0[Prod.mk] _x.1 _x.1; + return _x.2 +[Compiler.inferBorrow] size: 4 + def testArrayWithAnnotation n @&ps : obj := + let _x.1 := testArrayWithAnnotation._closed_0; + let pair := Array.get!Internal ◾ _x.1 ps n; + let snd := oproj[1] pair; + let _x.2 := ctor_0[Prod.mk] n snd; + return _x.2 +-/ +#guard_msgs in +set_option trace.Compiler.inferBorrow true in +def testArrayWithAnnotation (n : Nat) (ps : @&Array (Nat × Nat)) : Nat × Nat := + let pair := ps[n]! + { pair with fst := n } + +/-- +trace: [Compiler.inferBorrow] own _x.28: used in reset reuse _x.28 +[Compiler.inferBorrow] own _x.29: used in reset reuse _x.28 +[Compiler.inferBorrow] own n: argument to constructor call _x.28 +[Compiler.inferBorrow] own pair: used in reset reuse _x.29 +[Compiler.inferBorrow] own snd: fwd projection propagation snd +[Compiler.inferBorrow] own _x.27: result of function call _x.27 +[Compiler.inferBorrow] size: 5 + def testArrayWithoutAnnotation n @&ps : obj := + let _x.1 := testArrayWithAnnotation._closed_0; + let pair := Array.get!Internal ◾ _x.1 ps n; + let snd := oproj[1] pair; + let _x.2 := reset[2] pair; + let _x.3 := reuse _x.2 in ctor_0[Prod.mk] n snd; + return _x.3 +-/ +#guard_msgs in +set_option trace.Compiler.inferBorrow true in +def testArrayWithoutAnnotation (n : Nat) (ps : Array (Nat × Nat)) : Nat × Nat := + let pair := ps[n]! + { pair with fst := n } + +/-- +warning: declaration uses `sorry` +--- +trace: [Compiler.inferBorrow] own _x.11: result of ctor call _x.11 +[Compiler.inferBorrow] own n: argument to constructor call _x.11 +[Compiler.inferBorrow] size: 3 + def testArrayWithAnnotation' n @&ps : obj := + let pair := Array.getInternal ◾ ps n ◾; + let snd := oproj[1] pair; + let _x.1 := ctor_0[Prod.mk] n snd; + return _x.1 +-/ +#guard_msgs in +set_option trace.Compiler.inferBorrow true in +def testArrayWithAnnotation' (n : Nat) (ps : @&Array (Nat × Nat)) : Nat × Nat := + let pair := ps[n]'sorry + { pair with fst := n } + +/-- +warning: declaration uses `sorry` +--- +trace: [Compiler.inferBorrow] own _x.13: result of ctor call _x.13 +[Compiler.inferBorrow] own _x.12: result of function call _x.12 +[Compiler.inferBorrow] size: 4 + def testArrayWithAnnotation'' n @&ps : obj := + let pair := Array.uget ◾ ps n ◾; + let snd := oproj[1] pair; + let _x.1 := USize.toNat n; + let _x.2 := ctor_0[Prod.mk] _x.1 snd; + return _x.2 +-/ +#guard_msgs in +set_option trace.Compiler.inferBorrow true in +def testArrayWithAnnotation'' (n : USize) (ps : @&Array (Nat × Nat)) : Nat × Nat := + let pair := ps[n]'sorry + { pair with fst := n.toNat } diff --git a/tests/elab/compile_recursive_array_access.lean b/tests/elab/compile_recursive_array_access.lean new file mode 100644 index 0000000000..cb9c0b09b8 --- /dev/null +++ b/tests/elab/compile_recursive_array_access.lean @@ -0,0 +1,51 @@ +/-! +This test asserts that the compiler is able to handle compilation of functions that recurse through +nested arrays in a way that does not unnecessarily remove borrow annotations. +-/ + + +inductive NAryTree where + | tip (x : String) + | node (ys : Array NAryTree) + deriving Inhabited + +/-- +trace: [Compiler.explicitRc] size: 19 + def followPath @&tree @&path : obj := + cases tree : obj + | NAryTree.tip => + cases path : obj + | List.nil => + let x.1 := oproj[0] tree; + inc x.1; + return x.1 + | _ => + let _x.2 := instInhabitedNAryTree.default._closed_0; + return _x.2 + | NAryTree.node => + cases path : obj + | List.cons => + let ys.3 := oproj[0] tree; + let head.4 := oproj[0] path; + let tail.5 := oproj[1] path; + let _x.6 := instInhabitedNAryTree.default; + let _x.7 := Array.get!InternalBorrowed ◾ _x.6 ys.3 head.4; + let _x.8 := followPath _x.7 tail.5; + return _x.8 + | _ => + let _x.9 := instInhabitedNAryTree.default._closed_0; + return _x.9 +[Compiler.explicitRc] size: 3 + def followPath._boxed tree path : obj := + let res := followPath tree path; + dec path; + dec tree; + return res +-/ +#guard_msgs in +set_option trace.Compiler.explicitRc true in +def followPath (tree : NAryTree) (path : List Nat) : String := + match tree, path with + | .tip x, [] => x + | .node ys, idx :: path => followPath ys[idx]! path + | _, _ => default