perf: check simp cache in simpLoop (#8880)
This PR makes `simp` consult its own cache more often, to avoid replicating work. Before, the simp cache was checked upon entry of `simpImpl` only, which then calls `simpLoop`, which recursively iterates the `pre`-lemmas, without checking the cache again. Now, `simpLoop` itself checks the cache. This seems more principled, given that `simpLoop` is actually putting entries into the cache for each of its calls, so it’s more uniform if it checks the cache itself. This avoids repeated rewrites. For example given ``` theorem ab : a = b := testSorry theorem bc : b = c := testSorry example (h : P c) : P b ∧ P a := by simp [ab, bc, h] ``` simp would rewrite `b ==> c` twice (once as part of `b ==> c` and then again as part of `a ==> b ==> c`). And it’d be order dependent: With ``` example (h : P c) : P a ∧ P b := by simp [ab, bc, h] ``` the `a ==> b ==> c` chain would insert `b ==> c` into the cache, and picked up by `simpImpl` when rewriting `P b`. With this change, `b ==> c` is performed only once in both examples. Instruction counts on stdlib and mathlib both show a mild improvement across the board (0.5%), with individual modules improving by up to 4% in stdlib and even more in mathlib. (This does not check the cache before applying `post`, which explains where there are still some repeated rewrites in the trace logs. But I’m less sure about inserting a cache check here and so I am treading carefully here. It’s also going to be at most one `post` application that’s duplicated, because if `post` returns `.visit`, we go back to `pre` and thus a cache check.)
This commit is contained in:
parent
4d697874b7
commit
2441bf1f76
6 changed files with 97 additions and 29 deletions
|
|
@ -748,6 +748,10 @@ def cacheResult (e : Expr) (cfg : Config) (r : Result) : SimpM Result := do
|
|||
|
||||
partial def simpLoop (e : Expr) : SimpM Result := withIncRecDepth do
|
||||
let cfg ← getConfig
|
||||
if cfg.memoize then
|
||||
let cache := (← get).cache
|
||||
if let some result := cache.find? e then
|
||||
return result
|
||||
if (← get).numSteps > cfg.maxSteps then
|
||||
throwError "simp failed, maximum number of steps exceeded"
|
||||
else
|
||||
|
|
@ -784,16 +788,8 @@ def simpImpl (e : Expr) : SimpM Result := withIncRecDepth do
|
|||
checkSystem "simp"
|
||||
if (← isProof e) then
|
||||
return { expr := e }
|
||||
go
|
||||
where
|
||||
go : SimpM Result := do
|
||||
let cfg ← getConfig
|
||||
if cfg.memoize then
|
||||
let cache := (← get).cache
|
||||
if let some result := cache.find? e then
|
||||
return result
|
||||
trace[Meta.Tactic.simp.heads] "{repr e.toHeadIndex}"
|
||||
simpLoop e
|
||||
trace[Meta.Tactic.simp.heads] "{repr e.toHeadIndex}"
|
||||
simpLoop e
|
||||
|
||||
@[inline] private def withSimpContext (ctx : Context) (x : MetaM α) : MetaM α := do
|
||||
withConfig (fun c => { c with etaStruct := ctx.config.etaStruct }) <|
|
||||
|
|
|
|||
|
|
@ -15,17 +15,17 @@ Please use `termination_by` to specify a decreasing measure.
|
|||
decreasing_by.lean:75:13-77:3: error: unexpected token 'end'; expected '{' or tactic
|
||||
decreasing_by.lean:75:0-75:13: error: unsolved goals
|
||||
n m : Nat
|
||||
⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => a₁ < a₂) (n, dec2 m) (n, m)
|
||||
⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun x1 x2 => x1 < x2) (n, dec2 m) (n, m)
|
||||
|
||||
n m : Nat
|
||||
⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => a₁ < a₂) (dec1 n, 100) (n, m)
|
||||
⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun x1 x2 => x1 < x2) (dec1 n, 100) (n, m)
|
||||
decreasing_by.lean:85:0-85:22: error: unsolved goals
|
||||
case a
|
||||
n m : Nat
|
||||
⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => a₁ < a₂) (n, dec2 m) (n, m)
|
||||
⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun x1 x2 => x1 < x2) (n, dec2 m) (n, m)
|
||||
|
||||
n m : Nat
|
||||
⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => a₁ < a₂) (dec1 n, 100) (n, m)
|
||||
⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun x1 x2 => x1 < x2) (dec1 n, 100) (n, m)
|
||||
decreasing_by.lean:93:0-94:22: error: Could not find a decreasing measure.
|
||||
The basic measures relate at each recursive call as follows:
|
||||
(<, ≤, =: relation proved, ? all proofs failed, _: no proof attempted)
|
||||
|
|
@ -35,7 +35,7 @@ The basic measures relate at each recursive call as follows:
|
|||
Please use `termination_by` to specify a decreasing measure.
|
||||
decreasing_by.lean:104:0-106:17: error: unsolved goals
|
||||
n m : Nat
|
||||
⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => a₁ < a₂) (dec1 n, 100) (n, m)
|
||||
⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun x1 x2 => x1 < x2) (dec1 n, 100) (n, m)
|
||||
decreasing_by.lean:114:0-117:17: error: Could not find a decreasing measure.
|
||||
The basic measures relate at each recursive call as follows:
|
||||
(<, ≤, =: relation proved, ? all proofs failed, _: no proof attempted)
|
||||
|
|
|
|||
|
|
@ -13,10 +13,10 @@ trace: [diag] Diagnostics
|
|||
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
|
||||
---
|
||||
trace: [simp] Diagnostics
|
||||
[simp] used theorems (max: 59, num: 1):
|
||||
[simp] ack.eq_3 ↦ 59
|
||||
[simp] tried theorems (max: 59, num: 1):
|
||||
[simp] ack.eq_3 ↦ 59, succeeded: 59
|
||||
[simp] used theorems (max: 57, num: 1):
|
||||
[simp] ack.eq_3 ↦ 57
|
||||
[simp] tried theorems (max: 57, num: 1):
|
||||
[simp] ack.eq_3 ↦ 57, succeeded: 57
|
||||
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
|
||||
---
|
||||
trace: [diag] Diagnostics
|
||||
|
|
|
|||
72
tests/lean/run/simpCacheTest.lean
Normal file
72
tests/lean/run/simpCacheTest.lean
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
/-!
|
||||
Checks that the simp cache is consulted within `simpLoop`, not just in `simpMain`
|
||||
-/
|
||||
|
||||
|
||||
axiom testSorry : α
|
||||
opaque a : Nat
|
||||
opaque b : Nat
|
||||
opaque c : Nat
|
||||
opaque f : Nat → Nat
|
||||
opaque P : Nat → Prop
|
||||
theorem ab : a = b := testSorry
|
||||
theorem bc : b = c := testSorry
|
||||
theorem ba : b = a := testSorry
|
||||
theorem fafb : f a = f b := testSorry
|
||||
|
||||
set_option trace.Meta.Tactic.simp.rewrite true
|
||||
|
||||
|
||||
-- This trace should only mention one `bc` rewrite, not two.
|
||||
|
||||
/--
|
||||
trace: [Meta.Tactic.simp.rewrite] bc:1000:
|
||||
b
|
||||
==>
|
||||
c
|
||||
[Meta.Tactic.simp.rewrite] h:1000:
|
||||
P c
|
||||
==>
|
||||
True
|
||||
[Meta.Tactic.simp.rewrite] ab:1000:
|
||||
a
|
||||
==>
|
||||
b
|
||||
[Meta.Tactic.simp.rewrite] h:1000:
|
||||
P c
|
||||
==>
|
||||
True
|
||||
[Meta.Tactic.simp.rewrite] and_self:1000:
|
||||
True ∧ True
|
||||
==>
|
||||
True
|
||||
-/
|
||||
#guard_msgs in
|
||||
example (h : P c) : P b ∧ P a := by simp [ab, bc, h]
|
||||
|
||||
-- Almost the same goal, but ordered differently.
|
||||
|
||||
/--
|
||||
trace: [Meta.Tactic.simp.rewrite] ab:1000:
|
||||
a
|
||||
==>
|
||||
b
|
||||
[Meta.Tactic.simp.rewrite] bc:1000:
|
||||
b
|
||||
==>
|
||||
c
|
||||
[Meta.Tactic.simp.rewrite] h:1000:
|
||||
P c
|
||||
==>
|
||||
True
|
||||
[Meta.Tactic.simp.rewrite] h:1000:
|
||||
P c
|
||||
==>
|
||||
True
|
||||
[Meta.Tactic.simp.rewrite] and_self:1000:
|
||||
True ∧ True
|
||||
==>
|
||||
True
|
||||
-/
|
||||
#guard_msgs in
|
||||
example (h : P c) : P a ∧ P b := by simp [ab, bc, h]
|
||||
|
|
@ -12,8 +12,8 @@ trace: [simp] Diagnostics
|
|||
[simp] used theorems (max: 50, num: 2):
|
||||
[simp] f_eq ↦ 50
|
||||
[simp] q_eq ↦ 50
|
||||
[simp] tried theorems (max: 101, num: 2):
|
||||
[simp] f_eq ↦ 101, succeeded: 50
|
||||
[simp] tried theorems (max: 51, num: 2):
|
||||
[simp] f_eq ↦ 51, succeeded: 50
|
||||
[simp] q_eq ↦ 50, succeeded: 50
|
||||
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
|
||||
-/
|
||||
|
|
@ -33,13 +33,13 @@ def ack : Nat → Nat → Nat
|
|||
|
||||
/--
|
||||
trace: [simp] Diagnostics
|
||||
[simp] used theorems (max: 1201, num: 3):
|
||||
[simp] ack.eq_3 ↦ 1201
|
||||
[simp] Nat.reduceAdd (builtin simproc) ↦ 771
|
||||
[simp] ack.eq_1 ↦ 768
|
||||
[simp] tried theorems (max: 1973, num: 2):
|
||||
[simp] ack.eq_3 ↦ 1973, succeeded: 1201
|
||||
[simp] ack.eq_1 ↦ 768, succeeded: 768
|
||||
[simp] used theorems (max: 1193, num: 3):
|
||||
[simp] ack.eq_3 ↦ 1193
|
||||
[simp] Nat.reduceAdd (builtin simproc) ↦ 508
|
||||
[simp] ack.eq_1 ↦ 508
|
||||
[simp] tried theorems (max: 1705, num: 2):
|
||||
[simp] ack.eq_3 ↦ 1705, succeeded: 1193
|
||||
[simp] ack.eq_1 ↦ 508, succeeded: 508
|
||||
use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
|
||||
---
|
||||
error: tactic 'simp' failed, nested error:
|
||||
|
|
|
|||
|
|
@ -16,7 +16,7 @@ x✝ :
|
|||
(y : (_ : Nat) ×' Tree α) →
|
||||
(invImage (fun x => PSigma.casesOn x fun n t => (n, t)) Prod.instWellFoundedRelation).1 y ⟨n.succ, { cs := cs }⟩ →
|
||||
Tree α
|
||||
⊢ Prod.Lex (fun a₁ a₂ => a₁ < a₂) (fun a₁ a₂ => sizeOf a₁ < sizeOf a₂)
|
||||
⊢ Prod.Lex (fun x1 x2 => x1 < x2) (fun a₁ a₂ => sizeOf a₁ < sizeOf a₂)
|
||||
(n, { cs := List.map (fun x => x✝ ⟨n + 1, x.val⟩ ⋯) cs.attach }) (n.succ, { cs := cs })
|
||||
-/
|
||||
#guard_msgs(trace) in
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue