lean4-htt/tests/elab/sym_simp_cd.lean
Leonardo de Moura d2907b5c96
feat: add contextDependent to Sym.simp Result with two-tier cache (#12996)
This PR adds per-result `contextDependent` tracking to `Sym.Simp.Result`
and splits the simplifier cache into persistent (context-independent)
and transient (context-dependent, cleared on binder entry). This
replaces the coarse `wellBehavedMethods` flag.

Key changes:
- Add `contextDependent : Bool := false` to `Result.rfl` and
`Result.step`
- Split `State.cache` into `persistentCache` and `transientCache`
- Remove `wellBehavedMethods` from `Methods`
- Replace `withoutModifyingCacheIfNotWellBehaved` with
`withFreshTransientCache`
- Change `DischargeResult` to an inductive (`.failed`/`.solved`)
- Add `dischargeAssumption` (context-dependent discharger for testing)
- Add `sym.simp.debug.cache` trace class
- Propagate `contextDependent` through all combinators (congruence,
transitivity, control flow, arrows, rewriting)
- Add `mkRflResult`/`mkRflResultCD` to avoid dynamic allocation of rfl
results
- Fix `isRfl` to ignore `contextDependent` (was silently broken by the
extra field)

Propagation invariant: when combining sub-results, `cd` is the
disjunction of ALL sub-results' flags — including `.rfl` results. If
`simp` returned `.rfl (contextDependent := true)`, it means `simp` might
take a completely different code path in another local context, so all
downstream results must be marked context-dependent.

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-20 00:22:08 +00:00

194 lines
8.6 KiB
Text

/-
Test for `contextDependent` two-tier cache in Sym.simp.
Uses `dischargeAssumption` (context-dependent) to verify:
- Context-independent results land in persistent cache and survive across invocations.
- Context-dependent results land in transient cache and are re-computed on second invocation.
-/
import Lean
open Lean Elab Tactic Meta
/--
`sym_simp_twice [decl, ...]` runs `Sym.Simp.simp` on the target of the main goal twice in a row.
The simplifier state returned by the first run is passed to the second run, and only the
result of the second run is used to update the goal. A `"second traversal"` message is emitted
on the `sym.simp.debug.cache` trace class between the two runs.
-/
elab "sym_simp_twice" "[" declNames:ident,* "]" : tactic => do
  -- Resolve each identifier in the bracketed list to a global constant name.
  let declConsts ← declNames.getElems.mapM fun declId => realizeGlobalConstNoOverload declId.raw
  -- Rewriting simproc for those declarations; `dischargeAssumption` is passed as the discharger.
  let rewriteProc ← Sym.mkSimprocFor declConsts Sym.Simp.dischargeAssumption
  let simpMethods : Sym.Simp.Methods := {
    pre  := Sym.Simp.simpControl.andThen Sym.Simp.simpArrowTelescope
    post := Sym.Simp.evalGround.andThen rewriteProc
  }
  liftMetaTactic1 fun mainGoal => Sym.SymM.run do
    let goal ← Sym.preprocessMVar mainGoal
    let goalDecl ← goal.getDecl
    let target := goalDecl.type
    -- Run 1: starts from the default state, so the cache is built from scratch.
    -- Its simplification result is discarded; only the final state is kept.
    let (_, firstState) ← Sym.Simp.SimpM.run (Sym.Simp.simp target) simpMethods
    trace[sym.simp.debug.cache] "second traversal"
    -- Run 2: resumes from `firstState`; the persistent cache carries over,
    -- the transient cache is cleared.
    let (secondResult, _) ← Sym.Simp.SimpM.run (Sym.Simp.simp target) simpMethods (s := firstState)
    let goalResult ← secondResult.toSimpGoalResult goal
    goalResult.toOption
-- Test 1: Ground evaluation is context-independent.
-- The rule list passed to `sym_simp_twice` is empty, so no rewrite declarations are supplied.
-- The second invocation should hit the persistent cache for the whole expression:
-- the expected output has no cache-hit lines before the "second traversal" marker and a
-- single persistent hit, for `2 + 3 = 5`, after it.
set_option trace.sym.simp.debug.cache true in
/--
trace: [sym.simp.debug.cache] second traversal
[sym.simp.debug.cache] persistent cache hit: 2 + 3 = 5
-/
#guard_msgs in
example : 2 + 3 = 5 := by
sym_simp_twice []
-- Test 2: Conditional rewrite using a hypothesis is context-dependent.
-- `dischargeAssumption` uses local hypothesis `h : 0 < n`, so the result is context-dependent
-- and lands in the transient cache. On the second invocation, the transient cache is
-- cleared, so there should be NO persistent cache hit for the overall expression.
-- Only context-independent sub-expressions (literals, fvars) get persistent cache hits.
--
-- Fixture shared by the tests below: commutativity of `Nat` addition with an extra
-- hypothesis `0 < a`. The proof (`Nat.add_comm a b`) does not use `_h`; the hypothesis
-- only gives the rewrite a side condition that has to be discharged.
theorem Nat.add_comm_of_pos (a b : Nat) (_h : 0 < a) : a + b = b + a := Nat.add_comm a b
set_option trace.sym.simp.debug.cache true in
/--
trace: [sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
[sym.simp.debug.cache] second traversal
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
-/
#guard_msgs in
example (n : Nat) (h : 0 < n) : n + 2 = 2 + n := by
sym_simp_twice [Nat.add_comm_of_pos]
-- Test 3: Congruence — cd propagates through function application.
-- `n + 2` rewrites context-dependently (cd=true), `3 + 4` evaluates ground (cd=false).
-- The congruence combines both, so the overall result is cd=true.
-- On second traversal: ground sub-expressions (`3 + 4`, `7`) hit persistent cache,
-- but cd-tainted expressions (`2 + n`, `(2 + n) * 7`) are only in transient.
set_option trace.sym.simp.debug.cache true in
/--
trace: [sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
[sym.simp.debug.cache] transient cache hit: (2 + n) * 7
[sym.simp.debug.cache] second traversal
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 3 + 4
[sym.simp.debug.cache] transient cache hit: 2 + n
[sym.simp.debug.cache] persistent cache hit: 7
[sym.simp.debug.cache] transient cache hit: (2 + n) * 7
-/
#guard_msgs in
example (n : Nat) (h : 0 < n) : (n + 2) * (3 + 4) = (2 + n) * 7 := by
sym_simp_twice [Nat.add_comm_of_pos]
-- Variant of Test 3 in which `Nat.add_comm_of_pos` is not applicable: the goal has no
-- hypothesis `0 < n`, so `n + 2` is left unchanged. The discharger must nevertheless
-- return `cd := true`, so `n + 2` and `(n + 2) * 7` are expected only as transient
-- cache hits, while the ground sub-expressions (`3 + 4`, `7`) hit the persistent cache.
set_option trace.sym.simp.debug.cache true in
/--
trace: [sym.simp.debug.cache] transient cache hit: n + 2
[sym.simp.debug.cache] transient cache hit: (n + 2) * 7
[sym.simp.debug.cache] second traversal
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: 3 + 4
[sym.simp.debug.cache] transient cache hit: n + 2
[sym.simp.debug.cache] persistent cache hit: 7
[sym.simp.debug.cache] transient cache hit: (n + 2) * 7
-/
#guard_msgs in
example (n : Nat) : (n + 2) * (3 + 4) = (n + 2) * 7 := by
sym_simp_twice [Nat.add_comm_of_pos]
-- Test 4: Arrow — cd propagates through implication.
-- The hypothesis `n + 2 = 2 + n` is simplified context-dependently to `True`.
-- `True → True` simplifies to `True`. The whole result is cd=true.
-- `True` hits persistent cache; `2 + n` is only in transient.
-- NOTE(review): `linter.unusedVariables` is disabled for the example, presumably so that a
-- warning about `h` does not end up in the messages checked by `#guard_msgs` — confirm.
set_option trace.sym.simp.debug.cache true in
/--
trace: [sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
[sym.simp.debug.cache] second traversal
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
[sym.simp.debug.cache] persistent cache hit: True
-/
#guard_msgs in
set_option linter.unusedVariables false in
example (n : Nat) (h : 0 < n) : (n + 2 = 2 + n) → True := by
sym_simp_twice [Nat.add_comm_of_pos]
-- Test 5: Lambda — cd propagates through funext.
-- Body `n + 2` is simplified context-dependently inside the binder.
-- `withFreshTransientCache` clears the transient cache on binder entry.
-- The lambda result `fun x => 2 + n` is only in transient.
-- In the expected output there is no transient hit for `2 + n` itself (unlike Tests 2-4);
-- only `2` and `n` are hit (persistently) and the lambda `fun x => 2 + n` is hit transiently.
-- This is consistent with each binder body starting from an empty transient cache.
set_option trace.sym.simp.debug.cache true in
/--
trace: [sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: fun x => 2 + n
[sym.simp.debug.cache] second traversal
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: fun x => 2 + n
-/
#guard_msgs in
example (n : Nat) (_h : 0 < n) : (fun _ : Nat => n + 2) = (fun _ : Nat => 2 + n) := by
sym_simp_twice [Nat.add_comm_of_pos]
-- Test 6: Control flow — cd propagates through `ite` condition.
-- The condition `n + 2 = 2 + n` is simplified context-dependently.
-- The `ite` result inherits cd, and `1` (ground) is in persistent cache.
-- Accordingly, the expected output lists `2 + n` only as a transient hit, while the
-- literal `1` appears as a persistent hit both before and after the "second traversal" marker.
set_option trace.sym.simp.debug.cache true in
/--
trace: [sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
[sym.simp.debug.cache] persistent cache hit: 1
[sym.simp.debug.cache] second traversal
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
[sym.simp.debug.cache] persistent cache hit: 1
[sym.simp.debug.cache] persistent cache hit: 1
-/
#guard_msgs in
example (n : Nat) (h : 0 < n) : (if n + 2 = 2 + n then 1 else 0) = 1 := by
sym_simp_twice [Nat.add_comm_of_pos]
-- Test 7: Dependent forall — body cd under binder with `withFreshTransientCache`.
-- Simplifying `∀ (m : Nat), n + 2 = 2 + n` enters a binder (for `m`).
-- The transient cache is cleared on binder entry (`withFreshTransientCache`).
-- The body uses a cd rewrite, so the overall result is cd=true.
-- After "second traversal": `Nat` (the binder type) hits persistent cache.
-- (The example below writes the bound variable as `_` instead of `m`; it is unused in the body.)
-- NOTE(review): `linter.unusedVariables` is disabled for the example, presumably so that a
-- warning about `h` does not end up in the messages checked by `#guard_msgs` — confirm.
set_option trace.sym.simp.debug.cache true in
/--
trace: [sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
[sym.simp.debug.cache] second traversal
[sym.simp.debug.cache] persistent cache hit: Nat
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: 2
[sym.simp.debug.cache] persistent cache hit: n
[sym.simp.debug.cache] transient cache hit: 2 + n
-/
#guard_msgs in
set_option linter.unusedVariables false in
example (n : Nat) (h : 0 < n) : ∀ (_ : Nat), n + 2 = 2 + n := by
sym_simp_twice [Nat.add_comm_of_pos]