perf: check simp cache in simpLoop (#8880)

This PR makes `simp` consult its own cache more often, to avoid
replicating work.

Before, the simp cache was checked upon entry of `simpImpl` only, which
then calls `simpLoop`, which recursively iterates the `pre`-lemmas,
without checking the cache again.

Now, `simpLoop` itself checks the cache. This seems more principled,
given that `simpLoop` is actually putting entries into the cache for
each of its calls, so it’s more uniform if it checks the cache itself.

This avoids repeated rewrites. For example, given
```
theorem ab : a = b := testSorry
theorem bc : b = c := testSorry
example (h : P c) : P b ∧ P a := by simp [ab, bc, h]
```
simp would rewrite `b ==> c` twice (once as part of `b ==> c` and then
again as part of `a ==> b ==> c`). And it’d be order dependent: With
```
example (h : P c) : P a ∧ P b := by simp [ab, bc, h]
```
the `a ==> b ==> c` chain would insert `b ==> c` into the cache, and it
would then be picked up by `simpImpl` when rewriting `P b`.

With this change, `b ==> c` is performed only once in both examples.

Instruction counts on stdlib and mathlib both show a mild improvement
across the board (0.5%), with individual modules improving by up to 4%
in stdlib and even more in mathlib.


(This does not check the cache before applying `post`, which explains
why there are still some repeated rewrites in the trace logs. But I’m
less sure about inserting a cache check there, so I am treading
carefully. It’s also going to be at most one `post` application
that’s duplicated, because if `post` returns `.visit`, we go back to
`pre` and thus a cache check.)
This commit is contained in:
Joachim Breitner 2025-06-21 19:58:05 +02:00 committed by GitHub
parent 4d697874b7
commit 2441bf1f76
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 97 additions and 29 deletions

View file

@ -748,6 +748,10 @@ def cacheResult (e : Expr) (cfg : Config) (r : Result) : SimpM Result := do
partial def simpLoop (e : Expr) : SimpM Result := withIncRecDepth do
let cfg ← getConfig
if cfg.memoize then
let cache := (← get).cache
if let some result := cache.find? e then
return result
if (← get).numSteps > cfg.maxSteps then
throwError "simp failed, maximum number of steps exceeded"
else
@ -784,16 +788,8 @@ def simpImpl (e : Expr) : SimpM Result := withIncRecDepth do
checkSystem "simp"
if (← isProof e) then
return { expr := e }
go
where
go : SimpM Result := do
let cfg ← getConfig
if cfg.memoize then
let cache := (← get).cache
if let some result := cache.find? e then
return result
trace[Meta.Tactic.simp.heads] "{repr e.toHeadIndex}"
simpLoop e
trace[Meta.Tactic.simp.heads] "{repr e.toHeadIndex}"
simpLoop e
@[inline] private def withSimpContext (ctx : Context) (x : MetaM α) : MetaM α := do
withConfig (fun c => { c with etaStruct := ctx.config.etaStruct }) <|

View file

@ -15,17 +15,17 @@ Please use `termination_by` to specify a decreasing measure.
decreasing_by.lean:75:13-77:3: error: unexpected token 'end'; expected '{' or tactic
decreasing_by.lean:75:0-75:13: error: unsolved goals
n m : Nat
⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => a₁ < a₂) (n, dec2 m) (n, m)
⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun x1 x2 => x1 < x2) (n, dec2 m) (n, m)
n m : Nat
⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => a₁ < a₂) (dec1 n, 100) (n, m)
⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun x1 x2 => x1 < x2) (dec1 n, 100) (n, m)
decreasing_by.lean:85:0-85:22: error: unsolved goals
case a
n m : Nat
⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => a₁ < a₂) (n, dec2 m) (n, m)
⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun x1 x2 => x1 < x2) (n, dec2 m) (n, m)
n m : Nat
⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => a₁ < a₂) (dec1 n, 100) (n, m)
⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun x1 x2 => x1 < x2) (dec1 n, 100) (n, m)
decreasing_by.lean:93:0-94:22: error: Could not find a decreasing measure.
The basic measures relate at each recursive call as follows:
(<, ≤, =: relation proved, ? all proofs failed, _: no proof attempted)
@ -35,7 +35,7 @@ The basic measures relate at each recursive call as follows:
Please use `termination_by` to specify a decreasing measure.
decreasing_by.lean:104:0-106:17: error: unsolved goals
n m : Nat
⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => a₁ < a₂) (dec1 n, 100) (n, m)
⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun x1 x2 => x1 < x2) (dec1 n, 100) (n, m)
decreasing_by.lean:114:0-117:17: error: Could not find a decreasing measure.
The basic measures relate at each recursive call as follows:
(<, ≤, =: relation proved, ? all proofs failed, _: no proof attempted)

View file

@ -13,10 +13,10 @@ trace: [diag] Diagnostics
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
---
trace: [simp] Diagnostics
[simp] used theorems (max: 59, num: 1):
[simp] ack.eq_3 ↦ 59
[simp] tried theorems (max: 59, num: 1):
[simp] ack.eq_3 ↦ 59, succeeded: 59
[simp] used theorems (max: 57, num: 1):
[simp] ack.eq_3 ↦ 57
[simp] tried theorems (max: 57, num: 1):
[simp] ack.eq_3 ↦ 57, succeeded: 57
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
---
trace: [diag] Diagnostics

View file

@ -0,0 +1,72 @@
/-!
Checks that the simp cache is consulted within `simpLoop`, not just on entry to `simpImpl`.
-/
-- Axiom at an arbitrary type `α`; used below as the proof of every test lemma.
axiom testSorry : α
-- Opaque constants that the rewrite lemmas and the example goals talk about.
opaque a : Nat
opaque b : Nat
opaque c : Nat
opaque f : Nat → Nat
opaque P : Nat → Prop
-- `ab` and `bc` form the rewrite chain `a ==> b ==> c` passed to `simp` in the examples.
theorem ab : a = b := testSorry
theorem bc : b = c := testSorry
-- NOTE(review): `ba` and `fafb` (and with them `f`) are not referenced by the
-- examples visible in this file; confirm whether they are still needed.
theorem ba : b = a := testSorry
theorem fafb : f a = f b := testSorry
-- Turn on the rewrite trace so that `#guard_msgs` can pin the exact sequence of rewrites.
set_option trace.Meta.Tactic.simp.rewrite true
-- This trace should only mention one `bc` rewrite, not two.
/--
trace: [Meta.Tactic.simp.rewrite] bc:1000:
b
==>
c
[Meta.Tactic.simp.rewrite] h:1000:
P c
==>
True
[Meta.Tactic.simp.rewrite] ab:1000:
a
==>
b
[Meta.Tactic.simp.rewrite] h:1000:
P c
==>
True
[Meta.Tactic.simp.rewrite] and_self:1000:
True ∧ True
==>
True
-/
#guard_msgs in
example (h : P c) : P b ∧ P a := by simp [ab, bc, h]
-- Almost the same goal, but ordered differently.
/--
trace: [Meta.Tactic.simp.rewrite] ab:1000:
a
==>
b
[Meta.Tactic.simp.rewrite] bc:1000:
b
==>
c
[Meta.Tactic.simp.rewrite] h:1000:
P c
==>
True
[Meta.Tactic.simp.rewrite] h:1000:
P c
==>
True
[Meta.Tactic.simp.rewrite] and_self:1000:
True ∧ True
==>
True
-/
#guard_msgs in
example (h : P c) : P a ∧ P b := by simp [ab, bc, h]

View file

@ -12,8 +12,8 @@ trace: [simp] Diagnostics
[simp] used theorems (max: 50, num: 2):
[simp] f_eq ↦ 50
[simp] q_eq ↦ 50
[simp] tried theorems (max: 101, num: 2):
[simp] f_eq ↦ 101, succeeded: 50
[simp] tried theorems (max: 51, num: 2):
[simp] f_eq ↦ 51, succeeded: 50
[simp] q_eq ↦ 50, succeeded: 50
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
-/
@ -33,13 +33,13 @@ def ack : Nat → Nat → Nat
/--
trace: [simp] Diagnostics
[simp] used theorems (max: 1201, num: 3):
[simp] ack.eq_3 ↦ 1201
[simp] Nat.reduceAdd (builtin simproc) ↦ 771
[simp] ack.eq_1 ↦ 768
[simp] tried theorems (max: 1973, num: 2):
[simp] ack.eq_3 ↦ 1973, succeeded: 1201
[simp] ack.eq_1 ↦ 768, succeeded: 768
[simp] used theorems (max: 1193, num: 3):
[simp] ack.eq_3 ↦ 1193
[simp] Nat.reduceAdd (builtin simproc) ↦ 508
[simp] ack.eq_1 ↦ 508
[simp] tried theorems (max: 1705, num: 2):
[simp] ack.eq_3 ↦ 1705, succeeded: 1193
[simp] ack.eq_1 ↦ 508, succeeded: 508
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
---
error: tactic 'simp' failed, nested error:

View file

@ -16,7 +16,7 @@ x✝ :
(y : (_ : Nat) ×' Tree α) →
(invImage (fun x => PSigma.casesOn x fun n t => (n, t)) Prod.instWellFoundedRelation).1 y ⟨n.succ, { cs := cs }⟩ →
Tree α
⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => sizeOf a₁ < sizeOf a₂)
⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun a₁ a₂ => sizeOf a₁ < sizeOf a₂)
(n, { cs := List.map (fun x => x✝ ⟨n + 1, x.val⟩ ⋯) cs.attach }) (n.succ, { cs := cs })
-/
#guard_msgs(trace) in