From 2441bf1f76471592fefd434c96189df191cbd90a Mon Sep 17 00:00:00 2001 From: Joachim Breitner Date: Sat, 21 Jun 2025 19:58:05 +0200 Subject: [PATCH] perf: check simp cache in simpLoop (#8880) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR makes `simp` consult its own cache more often, to avoid replicating work. Before, the simp cache was checked upon entry to `simpImpl` only, which then calls `simpLoop`, which recursively iterates the `pre`-lemmas, without checking the cache again. Now, `simpLoop` itself checks the cache. This seems more principled, given that `simpLoop` is actually putting entries into the cache for each of its calls, so it’s more uniform if it checks the cache itself. This avoids repeated rewrites. For example, given ``` theorem ab : a = b := testSorry theorem bc : b = c := testSorry example (h : P c) : P b ∧ P a := by simp [ab, bc, h] ``` simp would rewrite `b ==> c` twice (once as part of `b ==> c` and then again as part of `a ==> b ==> c`). And it’d be order-dependent: with ``` example (h : P c) : P a ∧ P b := by simp [ab, bc, h] ``` the `a ==> b ==> c` chain would insert `b ==> c` into the cache, which would then be picked up by `simpImpl` when rewriting `P b`. With this change, `b ==> c` is performed only once in both examples. Instruction counts on stdlib and mathlib both show a mild improvement across the board (0.5%), with individual modules improving by up to 4% in stdlib and even more in mathlib. (This does not check the cache before applying `post`, which explains why there are still some repeated rewrites in the trace logs. But I’m less sure about inserting a cache check there, so I am treading carefully. It’s also going to be at most one `post` application that’s duplicated, because if `post` returns `.visit`, we go back to `pre` and thus a cache check.) 
--- src/Lean/Meta/Tactic/Simp/Main.lean | 16 ++--- tests/lean/decreasing_by.lean.expected.out | 10 +-- tests/lean/run/ack.lean | 8 +-- tests/lean/run/simpCacheTest.lean | 72 ++++++++++++++++++++++ tests/lean/run/simpDiag.lean | 18 +++--- tests/lean/run/wf_preprocess_leak.lean | 2 +- 6 files changed, 97 insertions(+), 29 deletions(-) create mode 100644 tests/lean/run/simpCacheTest.lean diff --git a/src/Lean/Meta/Tactic/Simp/Main.lean b/src/Lean/Meta/Tactic/Simp/Main.lean index 83533a58b4..6348afa9b9 100644 --- a/src/Lean/Meta/Tactic/Simp/Main.lean +++ b/src/Lean/Meta/Tactic/Simp/Main.lean @@ -748,6 +748,10 @@ def cacheResult (e : Expr) (cfg : Config) (r : Result) : SimpM Result := do partial def simpLoop (e : Expr) : SimpM Result := withIncRecDepth do let cfg ← getConfig + if cfg.memoize then + let cache := (← get).cache + if let some result := cache.find? e then + return result if (← get).numSteps > cfg.maxSteps then throwError "simp failed, maximum number of steps exceeded" else @@ -784,16 +788,8 @@ def simpImpl (e : Expr) : SimpM Result := withIncRecDepth do checkSystem "simp" if (← isProof e) then return { expr := e } - go -where - go : SimpM Result := do - let cfg ← getConfig - if cfg.memoize then - let cache := (← get).cache - if let some result := cache.find? e then - return result - trace[Meta.Tactic.simp.heads] "{repr e.toHeadIndex}" - simpLoop e + trace[Meta.Tactic.simp.heads] "{repr e.toHeadIndex}" + simpLoop e @[inline] private def withSimpContext (ctx : Context) (x : MetaM α) : MetaM α := do withConfig (fun c => { c with etaStruct := ctx.config.etaStruct }) <| diff --git a/tests/lean/decreasing_by.lean.expected.out b/tests/lean/decreasing_by.lean.expected.out index 474a642aa8..5dce42f8f7 100644 --- a/tests/lean/decreasing_by.lean.expected.out +++ b/tests/lean/decreasing_by.lean.expected.out @@ -15,17 +15,17 @@ Please use `termination_by` to specify a decreasing measure. 
decreasing_by.lean:75:13-77:3: error: unexpected token 'end'; expected '{' or tactic decreasing_by.lean:75:0-75:13: error: unsolved goals n m : Nat -⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => a₁ < a₂) (n, dec2 m) (n, m) +⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun x1 x2 => x1 < x2) (n, dec2 m) (n, m) n m : Nat -⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => a₁ < a₂) (dec1 n, 100) (n, m) +⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun x1 x2 => x1 < x2) (dec1 n, 100) (n, m) decreasing_by.lean:85:0-85:22: error: unsolved goals case a n m : Nat -⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => a₁ < a₂) (n, dec2 m) (n, m) +⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun x1 x2 => x1 < x2) (n, dec2 m) (n, m) n m : Nat -⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => a₁ < a₂) (dec1 n, 100) (n, m) +⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun x1 x2 => x1 < x2) (dec1 n, 100) (n, m) decreasing_by.lean:93:0-94:22: error: Could not find a decreasing measure. The basic measures relate at each recursive call as follows: (<, ≤, =: relation proved, ? all proofs failed, _: no proof attempted) @@ -35,7 +35,7 @@ The basic measures relate at each recursive call as follows: Please use `termination_by` to specify a decreasing measure. decreasing_by.lean:104:0-106:17: error: unsolved goals n m : Nat -⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => a₁ < a₂) (dec1 n, 100) (n, m) +⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun x1 x2 => x1 < x2) (dec1 n, 100) (n, m) decreasing_by.lean:114:0-117:17: error: Could not find a decreasing measure. The basic measures relate at each recursive call as follows: (<, ≤, =: relation proved, ? 
all proofs failed, _: no proof attempted) diff --git a/tests/lean/run/ack.lean b/tests/lean/run/ack.lean index fdbfde718b..27d21a2d5e 100644 --- a/tests/lean/run/ack.lean +++ b/tests/lean/run/ack.lean @@ -13,10 +13,10 @@ trace: [diag] Diagnostics use `set_option diagnostics.threshold ` to control threshold for reporting counters --- trace: [simp] Diagnostics - [simp] used theorems (max: 59, num: 1): - [simp] ack.eq_3 ↦ 59 - [simp] tried theorems (max: 59, num: 1): - [simp] ack.eq_3 ↦ 59, succeeded: 59 + [simp] used theorems (max: 57, num: 1): + [simp] ack.eq_3 ↦ 57 + [simp] tried theorems (max: 57, num: 1): + [simp] ack.eq_3 ↦ 57, succeeded: 57 use `set_option diagnostics.threshold ` to control threshold for reporting counters --- trace: [diag] Diagnostics diff --git a/tests/lean/run/simpCacheTest.lean b/tests/lean/run/simpCacheTest.lean new file mode 100644 index 0000000000..cf1f1018e1 --- /dev/null +++ b/tests/lean/run/simpCacheTest.lean @@ -0,0 +1,72 @@ +/-! +Checks that the simp cache is consulted within `simpLoop`, not just in `simpMain` +-/ + + +axiom testSorry : α +opaque a : Nat +opaque b : Nat +opaque c : Nat +opaque f : Nat → Nat +opaque P : Nat → Prop +theorem ab : a = b := testSorry +theorem bc : b = c := testSorry +theorem ba : b = a := testSorry +theorem fafb : f a = f b := testSorry + +set_option trace.Meta.Tactic.simp.rewrite true + + +-- This trace should only mention one `bc` rewrite, not two. + +/-- +trace: [Meta.Tactic.simp.rewrite] bc:1000: + b + ==> + c +[Meta.Tactic.simp.rewrite] h:1000: + P c + ==> + True +[Meta.Tactic.simp.rewrite] ab:1000: + a + ==> + b +[Meta.Tactic.simp.rewrite] h:1000: + P c + ==> + True +[Meta.Tactic.simp.rewrite] and_self:1000: + True ∧ True + ==> + True +-/ +#guard_msgs in +example (h : P c) : P b ∧ P a := by simp [ab, bc, h] + +-- Almost the same goal, but ordered differently. 
+ +/-- +trace: [Meta.Tactic.simp.rewrite] ab:1000: + a + ==> + b +[Meta.Tactic.simp.rewrite] bc:1000: + b + ==> + c +[Meta.Tactic.simp.rewrite] h:1000: + P c + ==> + True +[Meta.Tactic.simp.rewrite] h:1000: + P c + ==> + True +[Meta.Tactic.simp.rewrite] and_self:1000: + True ∧ True + ==> + True +-/ +#guard_msgs in +example (h : P c) : P a ∧ P b := by simp [ab, bc, h] diff --git a/tests/lean/run/simpDiag.lean b/tests/lean/run/simpDiag.lean index 23260790fc..faf8f0a830 100644 --- a/tests/lean/run/simpDiag.lean +++ b/tests/lean/run/simpDiag.lean @@ -12,8 +12,8 @@ trace: [simp] Diagnostics [simp] used theorems (max: 50, num: 2): [simp] f_eq ↦ 50 [simp] q_eq ↦ 50 - [simp] tried theorems (max: 101, num: 2): - [simp] f_eq ↦ 101, succeeded: 50 + [simp] tried theorems (max: 51, num: 2): + [simp] f_eq ↦ 51, succeeded: 50 [simp] q_eq ↦ 50, succeeded: 50 use `set_option diagnostics.threshold ` to control threshold for reporting counters -/ @@ -33,13 +33,13 @@ def ack : Nat → Nat → Nat /-- trace: [simp] Diagnostics - [simp] used theorems (max: 1201, num: 3): - [simp] ack.eq_3 ↦ 1201 - [simp] Nat.reduceAdd (builtin simproc) ↦ 771 - [simp] ack.eq_1 ↦ 768 - [simp] tried theorems (max: 1973, num: 2): - [simp] ack.eq_3 ↦ 1973, succeeded: 1201 - [simp] ack.eq_1 ↦ 768, succeeded: 768 + [simp] used theorems (max: 1193, num: 3): + [simp] ack.eq_3 ↦ 1193 + [simp] Nat.reduceAdd (builtin simproc) ↦ 508 + [simp] ack.eq_1 ↦ 508 + [simp] tried theorems (max: 1705, num: 2): + [simp] ack.eq_3 ↦ 1705, succeeded: 1193 + [simp] ack.eq_1 ↦ 508, succeeded: 508 use `set_option diagnostics.threshold ` to control threshold for reporting counters --- error: tactic 'simp' failed, nested error: diff --git a/tests/lean/run/wf_preprocess_leak.lean b/tests/lean/run/wf_preprocess_leak.lean index e6e5854753..999a5c852d 100644 --- a/tests/lean/run/wf_preprocess_leak.lean +++ b/tests/lean/run/wf_preprocess_leak.lean @@ -16,7 +16,7 @@ x✝ : (y : (_ : Nat) ×' Tree α) → (invImage (fun x => PSigma.casesOn x fun n t 
=> (n, t)) Prod.instWellFoundedRelation).1 y ⟨n.succ, { cs := cs }⟩ → Tree α -⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => sizeOf a₁ < sizeOf a₂) +⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun a₁ a₂ => sizeOf a₁ < sizeOf a₂) (n, { cs := List.map (fun x => x✝ ⟨n + 1, x.val⟩ ⋯) cs.attach }) (n.succ, { cs := cs }) -/ #guard_msgs(trace) in