From e1b798483608d61a5abd5f39cb60969bb35df0cb Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 1 May 2024 21:57:44 +0200 Subject: [PATCH] perf: improve `simp` cache behavior for well-behaved dischargers (#4044) See comment at `Methods.wellBehavedDischarge`. The default discharger is now well-behaved. --- src/Lean/Elab/Tactic/NormCast.lean | 4 +- src/Lean/Meta/Tactic/Simp/Main.lean | 46 ++++++++-------- src/Lean/Meta/Tactic/Simp/Rewrite.lean | 16 ++++-- src/Lean/Meta/Tactic/Simp/Types.lean | 27 ++++++--- tests/lean/run/3943.lean | 15 ++++- tests/lean/run/simpDiag.lean | 67 ++++++++++++----------- tests/lean/run/simp_cache_perf_issue.lean | 21 +++++++ 7 files changed, 128 insertions(+), 68 deletions(-) create mode 100644 tests/lean/run/simp_cache_perf_issue.lean diff --git a/src/Lean/Elab/Tactic/NormCast.lean b/src/Lean/Elab/Tactic/NormCast.lean index d9a630b4b7..aaf5373ca7 100644 --- a/src/Lean/Elab/Tactic/NormCast.lean +++ b/src/Lean/Elab/Tactic/NormCast.lean @@ -146,7 +146,9 @@ It tries to rewrite an expression using the elim and move lemmas. On failure, it calls the splitting procedure heuristic. -/ partial def upwardAndElim (up : SimpTheorems) (e : Expr) : SimpM Simp.Step := do - let r ← withDischarger prove do + -- Remark: we set `wellBehavedDischarge := false` because `prove` may access arbitrary elements in the local context. + -- See comment at `Methods.wellBehavedDischarge` + let r ← withDischarger prove (wellBehavedDischarge := false) do Simp.rewrite? 
e up.post up.erased (tag := "squash") (rflOnly := false) let r := r.getD { expr := e } let r ← r.mkEqTrans (← splittingProcedure r.expr) diff --git a/src/Lean/Meta/Tactic/Simp/Main.lean b/src/Lean/Meta/Tactic/Simp/Main.lean index 8ee496f282..99d5d5ba42 100644 --- a/src/Lean/Meta/Tactic/Simp/Main.lean +++ b/src/Lean/Meta/Tactic/Simp/Main.lean @@ -244,28 +244,26 @@ def getSimpLetCase (n : Name) (t : Expr) (b : Expr) : MetaM SimpLetCase := do /-- We use `withNewlemmas` whenever updating the local context. -We use `withFreshCache` because the local context affects `simp` rewrites -even when `contextual := false`. -For example, the `discharger` may inspect the current local context. The default -discharger does that when applying equational theorems, and the user may -use `(discharger := assumption)` or `(discharger := omega)`. -If the `wishFreshCache` introduces performance issues, we can design a better solution -for the default discharger which is used most of the time. -/ -def withNewLemmas {α} (xs : Array Expr) (f : SimpM α) : SimpM α := withFreshCache do +def withNewLemmas {α} (xs : Array Expr) (f : SimpM α) : SimpM α := do if (← getConfig).contextual then - let mut s ← getSimpTheorems - let mut updated := false - for x in xs do - if (← isProof x) then - s ← s.addTheorem (.fvar x.fvarId!) x - updated := true - if updated then - withTheReader Context (fun ctx => { ctx with simpTheorems := s }) f - else - f - else + withFreshCache do + let mut s ← getSimpTheorems + let mut updated := false + for x in xs do + if (← isProof x) then + s ← s.addTheorem (.fvar x.fvarId!) x + updated := true + if updated then + withTheReader Context (fun ctx => { ctx with simpTheorems := s }) f + else + f + else if (← getMethods).wellBehavedDischarge then + -- See comment at `Methods.wellBehavedDischarge` to understand why + -- we don't have to reset the cache f + else + withFreshCache do f def simpProj (e : Expr) : SimpM Result := do match (← reduceProj? 
e) with @@ -654,12 +652,12 @@ where trace[Meta.Tactic.simp.heads] "{repr e.toHeadIndex}" simpLoop e -@[inline] def withSimpConfig (ctx : Context) (x : MetaM α) : MetaM α := +@[inline] def withSimpContext (ctx : Context) (x : MetaM α) : MetaM α := withConfig (fun c => { c with etaStruct := ctx.config.etaStruct }) <| withReducible x def main (e : Expr) (ctx : Context) (stats : Stats := {}) (methods : Methods := {}) : MetaM (Result × Stats) := do - let ctx := { ctx with config := (← ctx.config.updateArith) } - withSimpConfig ctx do + let ctx := { ctx with config := (← ctx.config.updateArith), lctxInitIndices := (← getLCtx).numIndices } + withSimpContext ctx do let (r, s) ← simpMain e methods.toMethodsRef ctx |>.run { stats with } trace[Meta.Tactic.simp.numSteps] "{s.numSteps}" return (r, { s with }) @@ -676,7 +674,7 @@ where throw ex def dsimpMain (e : Expr) (ctx : Context) (stats : Stats := {}) (methods : Methods := {}) : MetaM (Expr × Stats) := do - withSimpConfig ctx do + withSimpContext ctx do let (r, s) ← dsimpMain e methods.toMethodsRef ctx |>.run { stats with } pure (r, { s with }) where @@ -698,7 +696,7 @@ def simp (e : Expr) (ctx : Simp.Context) (simprocs : SimprocsArray := #[]) (disc (stats : Stats := {}) : MetaM (Simp.Result × Stats) := do profileitM Exception "simp" (← getOptions) do match discharge? 
with | none => Simp.main e ctx stats (methods := Simp.mkDefaultMethodsCore simprocs) - | some d => Simp.main e ctx stats (methods := Simp.mkMethods simprocs d) + | some d => Simp.main e ctx stats (methods := Simp.mkMethods simprocs d (wellBehavedDischarge := false)) def dsimp (e : Expr) (ctx : Simp.Context) (simprocs : SimprocsArray := #[]) (stats : Stats := {}) : MetaM (Expr × Stats) := do profileitM Exception "dsimp" (← getOptions) do diff --git a/src/Lean/Meta/Tactic/Simp/Rewrite.lean b/src/Lean/Meta/Tactic/Simp/Rewrite.lean index 6f484502f0..bb8d9045e1 100644 --- a/src/Lean/Meta/Tactic/Simp/Rewrite.lean +++ b/src/Lean/Meta/Tactic/Simp/Rewrite.lean @@ -407,6 +407,7 @@ def mkSEvalMethods : CoreM Methods := do dpre := dpreDefault #[s] dpost := dpostDefault #[s] discharge? := dischargeGround + wellBehavedDischarge := true } def mkSEvalContext : CoreM Context := do @@ -494,10 +495,16 @@ where | .forallE _ d b _ => (d.isEq || d.isHEq || b.hasLooseBVar 0) && go b | _ => e.isFalse -def dischargeUsingAssumption? (e : Expr) : SimpM (Option Expr) := do +private def dischargeUsingAssumption? (e : Expr) : SimpM (Option Expr) := do + let lctxInitIndices := (← readThe Simp.Context).lctxInitIndices + let contextual := (← getConfig).contextual (← getLCtx).findDeclRevM? fun localDecl => do if localDecl.isImplementationDetail then return none + -- The following test is needed to ensure `dischargeUsingAssumption?` is a + -- well-behaved discharger. See comment at `Methods.wellBehavedDischarge` + else if !contextual && localDecl.index >= lctxInitIndices then + return none else if (← isDefEq e localDecl.type) then return some localDecl.toExpr else @@ -546,16 +553,17 @@ def dischargeDefault? (e : Expr) : SimpM (Option Expr) := do abbrev Discharge := Expr → SimpM (Option Expr) -def mkMethods (s : SimprocsArray) (discharge? : Discharge) : Methods := { +def mkMethods (s : SimprocsArray) (discharge? 
: Discharge) (wellBehavedDischarge : Bool) : Methods := { pre := preDefault s post := postDefault s dpre := dpreDefault s dpost := dpostDefault s - discharge? := discharge? + discharge? + wellBehavedDischarge } def mkDefaultMethodsCore (simprocs : SimprocsArray) : Methods := - mkMethods simprocs dischargeDefault? + mkMethods simprocs dischargeDefault? (wellBehavedDischarge := true) def mkDefaultMethods : CoreM Methods := do if simprocs.get (← getOptions) then diff --git a/src/Lean/Meta/Tactic/Simp/Types.lean b/src/Lean/Meta/Tactic/Simp/Types.lean index b1325fd54b..e2c7981aef 100644 --- a/src/Lean/Meta/Tactic/Simp/Types.lean +++ b/src/Lean/Meta/Tactic/Simp/Types.lean @@ -90,7 +90,12 @@ structure Context where -/ parent? : Option Expr := none dischargeDepth : UInt32 := 0 - + /-- + Number of indices in the local context when starting `simp`. + We use this information to decide which assumptions we can use without + invalidating the cache. + -/ + lctxInitIndices : Nat := 0 deriving Inhabited def Context.isDeclToUnfold (ctx : Context) (declName : Name) : Bool := @@ -249,11 +254,18 @@ structure Simprocs where deriving Inhabited structure Methods where - pre : Simproc := fun _ => return .continue - post : Simproc := fun e => return .done { expr := e } - dpre : DSimproc := fun _ => return .continue - dpost : DSimproc := fun e => return .done e + pre : Simproc := fun _ => return .continue + post : Simproc := fun e => return .done { expr := e } + dpre : DSimproc := fun _ => return .continue + dpost : DSimproc := fun e => return .done e discharge? : Expr → SimpM (Option Expr) := fun _ => return none + /-- + `wellBehavedDischarge` must **not** be set to `true` IF `discharge?` + accesses local declarations with index >= `Context.lctxInitIndices` when + `contextual := false`. + Reason: it would prevent us from aggressively caching `simp` results. 
+ -/ + wellBehavedDischarge : Bool := true deriving Inhabited unsafe def Methods.toMethodsRefImpl (m : Methods) : MethodsRef := @@ -307,8 +319,9 @@ Save current cache, reset it, execute `x`, and then restore original cache. modify fun s => { s with cache := {} } try x finally modify fun s => { s with cache := cacheSaved } -@[inline] def withDischarger (discharge? : Expr → SimpM (Option Expr)) (x : SimpM α) : SimpM α := - withFreshCache <| withReader (fun r => { MethodsRef.toMethods r with discharge? }.toMethodsRef) x +@[inline] def withDischarger (discharge? : Expr → SimpM (Option Expr)) (wellBehavedDischarge : Bool) (x : SimpM α) : SimpM α := + withFreshCache <| + withReader (fun r => { MethodsRef.toMethods r with discharge?, wellBehavedDischarge }.toMethodsRef) x def recordTriedSimpTheorem (thmId : Origin) : SimpM Unit := do modifyDiag fun { usedThmCounter, triedThmCounter, congrThmCounter } => diff --git a/tests/lean/run/3943.lean b/tests/lean/run/3943.lean index 62b4c63def..3d8fba43d1 100644 --- a/tests/lean/run/3943.lean +++ b/tests/lean/run/3943.lean @@ -19,7 +19,20 @@ def overlap : Nat → Nat | n+1 => overlap n example : (if (n = 0 → False) then overlap (n+1) else overlap (n+1)) = overlap n := by - simp only [overlap] + simp (config := { contextual := true }) only [overlap] + guard_target =ₛ (if (n = 0 → False) then overlap n else overlap (n+1)) = overlap n + sorry + +example : (if (n = 0 → False) then overlap (n+1) else overlap (n+1)) = overlap n := by + -- The following tactic should fail because the default discharger only uses assumptions available + -- when `simp` was invoked unless `contextual := true` + fail_if_success simp only [overlap] + guard_target =ₛ (if (n = 0 → False) then overlap (n+1) else overlap (n+1)) = overlap n + sorry + +example : (if (n = 0 → False) then overlap (n+1) else overlap (n+1)) = overlap n := by + -- assumption is not a well-behaved discharger, and the following should still work as expected + simp (discharger := assumption) only 
[overlap] guard_target =ₛ (if (n = 0 → False) then overlap n else overlap (n+1)) = overlap n sorry diff --git a/tests/lean/run/simpDiag.lean b/tests/lean/run/simpDiag.lean index e5f0e4b36a..1fb98d33c0 100644 --- a/tests/lean/run/simpDiag.lean +++ b/tests/lean/run/simpDiag.lean @@ -50,37 +50,42 @@ example : ack 4 4 = x := by set_option diagnostics true in simp [ack.eq_2, ack.eq_1, ack.eq_3] -/-- -info: [simp] used theorems (max: 22, num: 5): - ack.eq_3 ↦ 22 - ⏎ - Nat.reduceAdd (builtin simproc) ↦ 14 - ⏎ - ack.eq_1 ↦ 11 - ⏎ - ack.eq_2 ↦ 4 - ⏎ - Nat.zero_add ↦ 1[simp] tried theorems (max: 38, num: 4): - ack.eq_3 ↦ 38, succeeded: 22 - ⏎ - ack.eq_1 ↦ 11, succeeded: 11 - ⏎ - ack.eq_2 ↦ 4, succeeded: 4 - ⏎ - Nat.zero_add ↦ 1, succeeded: 1[reduction] unfolded reducible declarations (max: 7, num: 1): - outParam ↦ 7use `set_option diagnostics.threshold <num>` to control threshold for reporting counters ---- -error: tactic 'simp' failed, nested error: -(deterministic) timeout at `whnf`, maximum number of heartbeats (500) has been reached -use `set_option maxHeartbeats <num>` to set the limit -use `set_option diagnostics true` to get diagnostic information --/ -#guard_msgs in -set_option maxHeartbeats 500 in -example : ack 4 4 = x := by - set_option diagnostics true in - set_option diagnostics.threshold 0 in - simp [ack.eq_2, ack.eq_1, ack.eq_3] +-- TODO: In the following test we just want to check whether we +-- produce diagnostics for `simp` when there is a failure. However, the +-- actual counters make the test very unstable since small +-- changes to Lean affect heartbeat consumption, and consequently +-- the number of rewrites tried. 
+-- /-- +-- info: [simp] used theorems (max: 22, num: 5): +-- ack.eq_3 ↦ 22 +-- ⏎ +-- Nat.reduceAdd (builtin simproc) ↦ 14 +-- ⏎ +-- ack.eq_1 ↦ 11 +-- ⏎ +-- ack.eq_2 ↦ 4 +-- ⏎ +-- Nat.zero_add ↦ 1[simp] tried theorems (max: 38, num: 4): +-- ack.eq_3 ↦ 38, succeeded: 22 +-- ⏎ +-- ack.eq_1 ↦ 11, succeeded: 11 +-- ⏎ +-- ack.eq_2 ↦ 4, succeeded: 4 +-- ⏎ +-- Nat.zero_add ↦ 1, succeeded: 1[reduction] unfolded reducible declarations (max: 7, num: 1): +-- outParam ↦ 7use `set_option diagnostics.threshold <num>` to control threshold for reporting counters +-- --- +-- error: tactic 'simp' failed, nested error: +-- (deterministic) timeout at `whnf`, maximum number of heartbeats (500) has been reached +-- use `set_option maxHeartbeats <num>` to set the limit +-- use `set_option diagnostics true` to get diagnostic information +-- -/ +-- #guard_msgs in +-- set_option maxHeartbeats 500 in +-- example : ack 4 4 = x := by +-- set_option diagnostics true in +-- set_option diagnostics.threshold 0 in +-- simp [ack.eq_2, ack.eq_1, ack.eq_3] @[reducible] def h (x : Nat) := match x with diff --git a/tests/lean/run/simp_cache_perf_issue.lean b/tests/lean/run/simp_cache_perf_issue.lean new file mode 100644 index 0000000000..58923b6d3e --- /dev/null +++ b/tests/lean/run/simp_cache_perf_issue.lean @@ -0,0 +1,21 @@ +@[congr] +theorem exists_prop_congr {p p' : Prop} {q q' : p → Prop} (hq : ∀ h, q h ↔ q' h) (hp : p ↔ p') : + Exists q ↔ ∃ h : p', q' (hp.2 h) := sorry + +set_option maxHeartbeats 1000 in +example (x : Nat) : + ∃ (h : x = x) + (h : x = x) + (h : x = x) + (h : x = x) + (h : x = x) + (h : x = x) + (h : x = x) + (h : x = x) + (h : x = x) + (h : x = x) + (h : x = x) + (h : x = x) + (h : x = x), True := by + simp only + sorry