perf: improve simp cache behavior for well-behaved dischargers (#4044)
See comment at `Methods.wellBehavedDischarge`. The default discharger is now well-behaved.
This commit is contained in:
parent
d9ea092585
commit
e1b7984836
7 changed files with 128 additions and 68 deletions
|
|
@ -146,7 +146,9 @@ It tries to rewrite an expression using the elim and move lemmas.
|
|||
On failure, it calls the splitting procedure heuristic.
|
||||
-/
|
||||
partial def upwardAndElim (up : SimpTheorems) (e : Expr) : SimpM Simp.Step := do
|
||||
let r ← withDischarger prove do
|
||||
-- Remark: we set `wellBehavedDischarge := false` because `prove` may access arbitrary elements in the local context.
|
||||
-- See comment at `Methods.wellBehavedDischarge`
|
||||
let r ← withDischarger prove (wellBehavedDischarge := false) do
|
||||
Simp.rewrite? e up.post up.erased (tag := "squash") (rflOnly := false)
|
||||
let r := r.getD { expr := e }
|
||||
let r ← r.mkEqTrans (← splittingProcedure r.expr)
|
||||
|
|
|
|||
|
|
@ -244,28 +244,26 @@ def getSimpLetCase (n : Name) (t : Expr) (b : Expr) : MetaM SimpLetCase := do
|
|||
|
||||
/--
|
||||
We use `withNewlemmas` whenever updating the local context.
|
||||
We use `withFreshCache` because the local context affects `simp` rewrites
|
||||
even when `contextual := false`.
|
||||
For example, the `discharger` may inspect the current local context. The default
|
||||
discharger does that when applying equational theorems, and the user may
|
||||
use `(discharger := assumption)` or `(discharger := omega)`.
|
||||
If the `wishFreshCache` introduces performance issues, we can design a better solution
|
||||
for the default discharger which is used most of the time.
|
||||
-/
|
||||
def withNewLemmas {α} (xs : Array Expr) (f : SimpM α) : SimpM α := withFreshCache do
|
||||
def withNewLemmas {α} (xs : Array Expr) (f : SimpM α) : SimpM α := do
|
||||
if (← getConfig).contextual then
|
||||
let mut s ← getSimpTheorems
|
||||
let mut updated := false
|
||||
for x in xs do
|
||||
if (← isProof x) then
|
||||
s ← s.addTheorem (.fvar x.fvarId!) x
|
||||
updated := true
|
||||
if updated then
|
||||
withTheReader Context (fun ctx => { ctx with simpTheorems := s }) f
|
||||
else
|
||||
f
|
||||
else
|
||||
withFreshCache do
|
||||
let mut s ← getSimpTheorems
|
||||
let mut updated := false
|
||||
for x in xs do
|
||||
if (← isProof x) then
|
||||
s ← s.addTheorem (.fvar x.fvarId!) x
|
||||
updated := true
|
||||
if updated then
|
||||
withTheReader Context (fun ctx => { ctx with simpTheorems := s }) f
|
||||
else
|
||||
f
|
||||
else if (← getMethods).wellBehavedDischarge then
|
||||
-- See comment at `Methods.wellBehavedDischarge` to understand why
|
||||
-- we don't have to reset the cache
|
||||
f
|
||||
else
|
||||
withFreshCache do f
|
||||
|
||||
def simpProj (e : Expr) : SimpM Result := do
|
||||
match (← reduceProj? e) with
|
||||
|
|
@ -654,12 +652,12 @@ where
|
|||
trace[Meta.Tactic.simp.heads] "{repr e.toHeadIndex}"
|
||||
simpLoop e
|
||||
|
||||
@[inline] def withSimpConfig (ctx : Context) (x : MetaM α) : MetaM α :=
|
||||
@[inline] def withSimpContext (ctx : Context) (x : MetaM α) : MetaM α :=
|
||||
withConfig (fun c => { c with etaStruct := ctx.config.etaStruct }) <| withReducible x
|
||||
|
||||
def main (e : Expr) (ctx : Context) (stats : Stats := {}) (methods : Methods := {}) : MetaM (Result × Stats) := do
|
||||
let ctx := { ctx with config := (← ctx.config.updateArith) }
|
||||
withSimpConfig ctx do
|
||||
let ctx := { ctx with config := (← ctx.config.updateArith), lctxInitIndices := (← getLCtx).numIndices }
|
||||
withSimpContext ctx do
|
||||
let (r, s) ← simpMain e methods.toMethodsRef ctx |>.run { stats with }
|
||||
trace[Meta.Tactic.simp.numSteps] "{s.numSteps}"
|
||||
return (r, { s with })
|
||||
|
|
@ -676,7 +674,7 @@ where
|
|||
throw ex
|
||||
|
||||
def dsimpMain (e : Expr) (ctx : Context) (stats : Stats := {}) (methods : Methods := {}) : MetaM (Expr × Stats) := do
|
||||
withSimpConfig ctx do
|
||||
withSimpContext ctx do
|
||||
let (r, s) ← dsimpMain e methods.toMethodsRef ctx |>.run { stats with }
|
||||
pure (r, { s with })
|
||||
where
|
||||
|
|
@ -698,7 +696,7 @@ def simp (e : Expr) (ctx : Simp.Context) (simprocs : SimprocsArray := #[]) (disc
|
|||
(stats : Stats := {}) : MetaM (Simp.Result × Stats) := do profileitM Exception "simp" (← getOptions) do
|
||||
match discharge? with
|
||||
| none => Simp.main e ctx stats (methods := Simp.mkDefaultMethodsCore simprocs)
|
||||
| some d => Simp.main e ctx stats (methods := Simp.mkMethods simprocs d)
|
||||
| some d => Simp.main e ctx stats (methods := Simp.mkMethods simprocs d (wellBehavedDischarge := false))
|
||||
|
||||
def dsimp (e : Expr) (ctx : Simp.Context) (simprocs : SimprocsArray := #[])
|
||||
(stats : Stats := {}) : MetaM (Expr × Stats) := do profileitM Exception "dsimp" (← getOptions) do
|
||||
|
|
|
|||
|
|
@ -407,6 +407,7 @@ def mkSEvalMethods : CoreM Methods := do
|
|||
dpre := dpreDefault #[s]
|
||||
dpost := dpostDefault #[s]
|
||||
discharge? := dischargeGround
|
||||
wellBehavedDischarge := true
|
||||
}
|
||||
|
||||
def mkSEvalContext : CoreM Context := do
|
||||
|
|
@ -494,10 +495,16 @@ where
|
|||
| .forallE _ d b _ => (d.isEq || d.isHEq || b.hasLooseBVar 0) && go b
|
||||
| _ => e.isFalse
|
||||
|
||||
def dischargeUsingAssumption? (e : Expr) : SimpM (Option Expr) := do
|
||||
private def dischargeUsingAssumption? (e : Expr) : SimpM (Option Expr) := do
|
||||
let lctxInitIndices := (← readThe Simp.Context).lctxInitIndices
|
||||
let contextual := (← getConfig).contextual
|
||||
(← getLCtx).findDeclRevM? fun localDecl => do
|
||||
if localDecl.isImplementationDetail then
|
||||
return none
|
||||
-- The following test is needed to ensure `dischargeUsingAssumption?` is a
|
||||
-- well-behaved discharger. See comment at `Methods.wellBehavedDischarge`
|
||||
else if !contextual && localDecl.index >= lctxInitIndices then
|
||||
return none
|
||||
else if (← isDefEq e localDecl.type) then
|
||||
return some localDecl.toExpr
|
||||
else
|
||||
|
|
@ -546,16 +553,17 @@ def dischargeDefault? (e : Expr) : SimpM (Option Expr) := do
|
|||
|
||||
abbrev Discharge := Expr → SimpM (Option Expr)
|
||||
|
||||
def mkMethods (s : SimprocsArray) (discharge? : Discharge) : Methods := {
|
||||
def mkMethods (s : SimprocsArray) (discharge? : Discharge) (wellBehavedDischarge : Bool) : Methods := {
|
||||
pre := preDefault s
|
||||
post := postDefault s
|
||||
dpre := dpreDefault s
|
||||
dpost := dpostDefault s
|
||||
discharge? := discharge?
|
||||
discharge?
|
||||
wellBehavedDischarge
|
||||
}
|
||||
|
||||
def mkDefaultMethodsCore (simprocs : SimprocsArray) : Methods :=
|
||||
mkMethods simprocs dischargeDefault?
|
||||
mkMethods simprocs dischargeDefault? (wellBehavedDischarge := true)
|
||||
|
||||
def mkDefaultMethods : CoreM Methods := do
|
||||
if simprocs.get (← getOptions) then
|
||||
|
|
|
|||
|
|
@ -90,7 +90,12 @@ structure Context where
|
|||
-/
|
||||
parent? : Option Expr := none
|
||||
dischargeDepth : UInt32 := 0
|
||||
|
||||
/--
|
||||
Number of indices in the local context when starting `simp`.
|
||||
We use this information to decide which assumptions we can use without
|
||||
invalidating the cache.
|
||||
-/
|
||||
lctxInitIndices : Nat := 0
|
||||
deriving Inhabited
|
||||
|
||||
def Context.isDeclToUnfold (ctx : Context) (declName : Name) : Bool :=
|
||||
|
|
@ -249,11 +254,18 @@ structure Simprocs where
|
|||
deriving Inhabited
|
||||
|
||||
structure Methods where
|
||||
pre : Simproc := fun _ => return .continue
|
||||
post : Simproc := fun e => return .done { expr := e }
|
||||
dpre : DSimproc := fun _ => return .continue
|
||||
dpost : DSimproc := fun e => return .done e
|
||||
pre : Simproc := fun _ => return .continue
|
||||
post : Simproc := fun e => return .done { expr := e }
|
||||
dpre : DSimproc := fun _ => return .continue
|
||||
dpost : DSimproc := fun e => return .done e
|
||||
discharge? : Expr → SimpM (Option Expr) := fun _ => return none
|
||||
/--
|
||||
`wellBehavedDischarge` must **not** be set to `true` IF `discharge?`
|
||||
access local declarations with index >= `Context.lctxInitIndices` when
|
||||
`contextual := false`.
|
||||
Reason: it would prevent us from aggressively caching `simp` results.
|
||||
-/
|
||||
wellBehavedDischarge : Bool := true
|
||||
deriving Inhabited
|
||||
|
||||
unsafe def Methods.toMethodsRefImpl (m : Methods) : MethodsRef :=
|
||||
|
|
@ -307,8 +319,9 @@ Save current cache, reset it, execute `x`, and then restore original cache.
|
|||
modify fun s => { s with cache := {} }
|
||||
try x finally modify fun s => { s with cache := cacheSaved }
|
||||
|
||||
@[inline] def withDischarger (discharge? : Expr → SimpM (Option Expr)) (x : SimpM α) : SimpM α :=
|
||||
withFreshCache <| withReader (fun r => { MethodsRef.toMethods r with discharge? }.toMethodsRef) x
|
||||
@[inline] def withDischarger (discharge? : Expr → SimpM (Option Expr)) (wellBehavedDischarge : Bool) (x : SimpM α) : SimpM α :=
|
||||
withFreshCache <|
|
||||
withReader (fun r => { MethodsRef.toMethods r with discharge?, wellBehavedDischarge }.toMethodsRef) x
|
||||
|
||||
def recordTriedSimpTheorem (thmId : Origin) : SimpM Unit := do
|
||||
modifyDiag fun { usedThmCounter, triedThmCounter, congrThmCounter } =>
|
||||
|
|
|
|||
|
|
@ -19,7 +19,20 @@ def overlap : Nat → Nat
|
|||
| n+1 => overlap n
|
||||
|
||||
example : (if (n = 0 → False) then overlap (n+1) else overlap (n+1)) = overlap n := by
|
||||
simp only [overlap]
|
||||
simp (config := { contextual := true }) only [overlap]
|
||||
guard_target =ₛ (if (n = 0 → False) then overlap n else overlap (n+1)) = overlap n
|
||||
sorry
|
||||
|
||||
example : (if (n = 0 → False) then overlap (n+1) else overlap (n+1)) = overlap n := by
|
||||
-- The following tactic should because the default discharger only uses assumptions available
|
||||
-- when `simp` was invoked unless `contextual := true`
|
||||
fail_if_success simp only [overlap]
|
||||
guard_target =ₛ (if (n = 0 → False) then overlap (n+1) else overlap (n+1)) = overlap n
|
||||
sorry
|
||||
|
||||
example : (if (n = 0 → False) then overlap (n+1) else overlap (n+1)) = overlap n := by
|
||||
-- assumption is not a well-behaved discharger, and the following should still work as expected
|
||||
simp (discharger := assumption) only [overlap]
|
||||
guard_target =ₛ (if (n = 0 → False) then overlap n else overlap (n+1)) = overlap n
|
||||
sorry
|
||||
|
||||
|
|
|
|||
|
|
@ -50,37 +50,42 @@ example : ack 4 4 = x := by
|
|||
set_option diagnostics true in
|
||||
simp [ack.eq_2, ack.eq_1, ack.eq_3]
|
||||
|
||||
/--
|
||||
info: [simp] used theorems (max: 22, num: 5):
|
||||
ack.eq_3 ↦ 22
|
||||
⏎
|
||||
Nat.reduceAdd (builtin simproc) ↦ 14
|
||||
⏎
|
||||
ack.eq_1 ↦ 11
|
||||
⏎
|
||||
ack.eq_2 ↦ 4
|
||||
⏎
|
||||
Nat.zero_add ↦ 1[simp] tried theorems (max: 38, num: 4):
|
||||
ack.eq_3 ↦ 38, succeeded: 22
|
||||
⏎
|
||||
ack.eq_1 ↦ 11, succeeded: 11
|
||||
⏎
|
||||
ack.eq_2 ↦ 4, succeeded: 4
|
||||
⏎
|
||||
Nat.zero_add ↦ 1, succeeded: 1[reduction] unfolded reducible declarations (max: 7, num: 1):
|
||||
outParam ↦ 7use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
|
||||
---
|
||||
error: tactic 'simp' failed, nested error:
|
||||
(deterministic) timeout at `whnf`, maximum number of heartbeats (500) has been reached
|
||||
use `set_option maxHeartbeats <num>` to set the limit
|
||||
use `set_option diagnostics true` to get diagnostic information
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option maxHeartbeats 500 in
|
||||
example : ack 4 4 = x := by
|
||||
set_option diagnostics true in
|
||||
set_option diagnostics.threshold 0 in
|
||||
simp [ack.eq_2, ack.eq_1, ack.eq_3]
|
||||
-- TODO: In the following test we just want to check whether we
|
||||
-- diagnostics for `simp` when there is a failure. However, the
|
||||
-- actual counters make the test very unstable since small
|
||||
-- changes to Lean affect heartbeat consumption, and consequently
|
||||
-- the number of rewrites tried.
|
||||
-- /--
|
||||
-- info: [simp] used theorems (max: 22, num: 5):
|
||||
-- ack.eq_3 ↦ 22
|
||||
-- ⏎
|
||||
-- Nat.reduceAdd (builtin simproc) ↦ 14
|
||||
-- ⏎
|
||||
-- ack.eq_1 ↦ 11
|
||||
-- ⏎
|
||||
-- ack.eq_2 ↦ 4
|
||||
-- ⏎
|
||||
-- Nat.zero_add ↦ 1[simp] tried theorems (max: 38, num: 4):
|
||||
-- ack.eq_3 ↦ 38, succeeded: 22
|
||||
-- ⏎
|
||||
-- ack.eq_1 ↦ 11, succeeded: 11
|
||||
-- ⏎
|
||||
-- ack.eq_2 ↦ 4, succeeded: 4
|
||||
-- ⏎
|
||||
-- Nat.zero_add ↦ 1, succeeded: 1[reduction] unfolded reducible declarations (max: 7, num: 1):
|
||||
-- outParam ↦ 7use `set_option diagnostics.threshold <num>` to control threshold for reporting counters
|
||||
-- ---
|
||||
-- error: tactic 'simp' failed, nested error:
|
||||
-- (deterministic) timeout at `whnf`, maximum number of heartbeats (500) has been reached
|
||||
-- use `set_option maxHeartbeats <num>` to set the limit
|
||||
-- use `set_option diagnostics true` to get diagnostic information
|
||||
-- -/
|
||||
-- #guard_msgs in
|
||||
-- set_option maxHeartbeats 500 in
|
||||
-- example : ack 4 4 = x := by
|
||||
-- set_option diagnostics true in
|
||||
-- set_option diagnostics.threshold 0 in
|
||||
-- simp [ack.eq_2, ack.eq_1, ack.eq_3]
|
||||
|
||||
@[reducible] def h (x : Nat) :=
|
||||
match x with
|
||||
|
|
|
|||
21
tests/lean/run/simp_cache_perf_issue.lean
Normal file
21
tests/lean/run/simp_cache_perf_issue.lean
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
@[congr]
|
||||
theorem exists_prop_congr {p p' : Prop} {q q' : p → Prop} (hq : ∀ h, q h ↔ q' h) (hp : p ↔ p') :
|
||||
Exists q ↔ ∃ h : p', q' (hp.2 h) := sorry
|
||||
|
||||
set_option maxHeartbeats 1000 in
|
||||
example (x : Nat) :
|
||||
∃ (h : x = x)
|
||||
(h : x = x)
|
||||
(h : x = x)
|
||||
(h : x = x)
|
||||
(h : x = x)
|
||||
(h : x = x)
|
||||
(h : x = x)
|
||||
(h : x = x)
|
||||
(h : x = x)
|
||||
(h : x = x)
|
||||
(h : x = x)
|
||||
(h : x = x)
|
||||
(h : x = x), True := by
|
||||
simp only
|
||||
sorry
|
||||
Loading…
Add table
Reference in a new issue