feat: add permutation theorem support to Sym.simp (#13046)
This PR prevents `Sym.simp` from looping on permutation theorems like `∀ x y, x + y = y + x`.

- Add a `perm : Bool` field to `Theorem`.
- Add `isPerm`, which checks whether the LHS and RHS have the same structure with pattern variables (de Bruijn indices) rearranged via a consistent bijection. It uses `ReaderT` (offset, incremented on binder entry), `StateT` (forward map from LHS to RHS variables; injectivity is checked against the targets already assigned), and `Except` (failure).
- Compute `perm` in `mkTheoremFromDecl` / `mkTheoremFromExpr`.
- In `Theorem.rewrite`, when `perm` is true, only apply the rewrite if the result is strictly less than the input (using `acLt`).
- Tests include the classic AC normalization stress test with `add_comm`, `add_assoc`, `add_left_comm`.

---------

Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
8ae39633d1
commit
9f4db470c4
4 changed files with 156 additions and 95 deletions
|
|
@ -9,6 +9,7 @@ public import Lean.Meta.Sym.Simp.Simproc
|
|||
public import Lean.Meta.Sym.Simp.Theorems
|
||||
public import Lean.Meta.Sym.Simp.App
|
||||
public import Lean.Meta.Sym.Simp.Discharger
|
||||
import Lean.Meta.ACLt
|
||||
import Lean.Meta.Sym.InstantiateS
|
||||
import Lean.Meta.Sym.InstantiateMVarsS
|
||||
import Init.Data.Range.Polymorphic.Iterators
|
||||
|
|
@ -71,10 +72,16 @@ public def Theorem.rewrite (thm : Theorem) (e : Expr) (d : Discharger := dischar
|
|||
let expr ← instantiateRevBetaS rhs args.toArray
|
||||
if isSameExpr e expr then
|
||||
return mkRflResultCD isCD
|
||||
else if !(← checkPerm thm.perm e expr) then
|
||||
return mkRflResultCD isCD
|
||||
else
|
||||
return .step expr proof (contextDependent := isCD)
|
||||
else
|
||||
return .rfl
|
||||
where
|
||||
checkPerm (perm : Bool) (e result : Expr) : MetaM Bool := do
  -- Decides whether a rewrite from `e` to `result` may be kept.
  -- Non-permutation theorems are always accepted. A permutation theorem is accepted
  -- only when `acLt result e` holds, i.e. the rewritten term is strictly smaller than
  -- the input, so repeatedly applying the same rule cannot cycle.
  if perm then
    acLt result e
  else
    return true
|
||||
|
||||
public def Theorems.rewrite (thms : Theorems) (d : Discharger := dischargeNone) : Simproc := fun e => do
|
||||
-- Track `cd` across all attempted theorems. If theorem A fails with cd=true
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ public import Lean.Meta.DiscrTree
|
|||
import Lean.Meta.Sym.Simp.DiscrTree
|
||||
import Lean.Meta.AppBuilder
|
||||
import Lean.ExtraModUses
|
||||
import Init.Omega
|
||||
public section
|
||||
namespace Lean.Meta.Sym.Simp
|
||||
|
||||
|
|
@ -26,6 +27,10 @@ structure Theorem where
|
|||
pattern : Pattern
|
||||
/-- Right-hand side of the equation. -/
|
||||
rhs : Expr
|
||||
/-- If `true`, the theorem is a permutation rule (e.g., `x + y = y + x`).
|
||||
Rewriting is only applied when the result is strictly less than the input
|
||||
(using `acLt`), preventing infinite loops. -/
|
||||
perm : Bool := false
|
||||
deriving Inhabited
|
||||
|
||||
instance : BEq Theorem where
|
||||
|
|
@ -45,6 +50,49 @@ def Theorems.getMatch (thms : Theorems) (e : Expr) : Array Theorem :=
|
|||
def Theorems.getMatchWithExtra (thms : Theorems) (e : Expr) : Array (Theorem × Nat) :=
|
||||
Sym.getMatchWithExtra thms.thms e
|
||||
|
||||
/--
Monad used by `isPermAux`.

- `ReaderT Nat`: the current `offset`, i.e. the number of binders (`fun` / `∀`) entered so far.
- `StateT (Array (Option Nat))`: the partial map from LHS pattern variables to RHS pattern
  variables. Entry `k` is `some m` once LHS variable `k` has been paired with RHS variable `m`.
- `Except Unit`: failure, meaning the two expressions are not permutations of each other.
-/
private abbrev IsPermM := ReaderT Nat <| StateT (Array (Option Nat)) <| Except Unit

/--
Structural comparison behind `isPerm`. Returns normally when `a` and `b` agree up to a
consistent, injective renaming of pattern variables, and throws otherwise.

Bvars with index `< offset` are "local" (introduced by binders inside the pattern) and must
match exactly. Bvars with index `≥ offset` are pattern variables and may be permuted,
but the mapping must be a bijection.
-/
private partial def isPermAux (a b : Expr) : IsPermM Unit := do
  match a, b with
  | .bvar i, .bvar j =>
    let offset ← read
    if i < offset || j < offset then
      -- At least one side is a local bound variable. Equal indices force both sides to be
      -- local, so this single test also rejects the mixed local / pattern-variable case.
      unless i == j do throw ()
    else
      -- Both sides are pattern variables: strip the binder offset to get their positions.
      let lhsVar := i - offset
      let rhsVar := j - offset
      let assignment ← get
      if h : lhsVar < assignment.size then
        match assignment[lhsVar] with
        | some prev =>
          -- Already paired: the pairing must be the same everywhere.
          unless rhsVar == prev do throw ()
        | none =>
          -- Check injectivity: `rhsVar` must not already be the target of another variable.
          if assignment.contains (some rhsVar) then throw ()
          set (assignment.set lhsVar (some rhsVar))
      else
        -- The LHS index does not denote one of the `numVars` pattern variables.
        throw ()
  | .app f₁ a₁, .app f₂ a₂ =>
    isPermAux f₁ f₂
    isPermAux a₁ a₂
  -- Metadata is transparent on either side.
  | .mdata _ s, t => isPermAux s t
  | s, .mdata _ t => isPermAux s t
  -- Entering a binder shifts every pattern variable by one, hence the incremented offset.
  | .forallE _ d₁ b₁ _, .forallE _ d₂ b₂ _ =>
    isPermAux d₁ d₂
    withReader (· + 1) (isPermAux b₁ b₂)
  | .lam _ d₁ b₁ _, .lam _ d₂ b₂ _ =>
    isPermAux d₁ d₂
    withReader (· + 1) (isPermAux b₁ b₂)
  -- Every remaining pair of expressions must be identical.
  | s, t => unless s == t do throw ()

/--
Check whether `lhs` and `rhs` (with `numVars` pattern variables represented as `.bvar` indices
`≥ 0` before any binder entry) are permutations of each other — same structure with only
pattern variable indices rearranged via a consistent bijection.

Simplified compared to `Meta.simp`'s `isPerm`:
- Uses de Bruijn indices instead of metavariables
- No `.proj` (folded into applications) or `.letE` (zeta-expanded) cases
-/
def isPerm (numVars : Nat) (lhs rhs : Expr) : Bool :=
  -- Start at binder depth 0 with every pattern variable still unpaired.
  let unassigned : Array (Option Nat) := Array.replicate numVars none
  match ((isPermAux lhs rhs).run 0).run unassigned with
  | .ok _ => true
  | .error _ => false
|
||||
|
||||
/-- Describes how a theorem's conclusion was adapted to an equality for use in `Sym.simp`. -/
|
||||
private inductive EqAdaptation where
|
||||
/-- Already an equality `lhs = rhs`. Proof is used as-is. -/
|
||||
|
|
@ -99,13 +147,15 @@ where
|
|||
def mkTheoremFromDecl (declName : Name) : MetaM Theorem := do
  -- Builds a `Theorem` from the declaration `declName`: the pattern and right-hand side come
  -- from `mkPatternFromDeclWithKey`, and the proof term is adapted by `wrapProof`.
  let (pattern, (rhs, adaptation)) ← mkPatternFromDeclWithKey declName selectEqKey
  let numVars := pattern.varTypes.size
  let expr ← wrapProof numVars (mkConst declName) adaptation
  -- Record whether the rule merely permutes its pattern variables (e.g. `x + y = y + x`);
  -- the `perm` field is what the rewriter consults to avoid looping on such rules.
  let perm := isPerm numVars pattern.pattern rhs
  return { expr, pattern, rhs, perm }
|
||||
|
||||
/--
Create a `Theorem` from a proof expression. Handles equalities, `¬`, `↔`, and propositions.

The `perm` field of the result is computed with `isPerm` from the pattern and the
right-hand side.
-/
def mkTheoremFromExpr (e : Expr) : MetaM Theorem := do
  let (pattern, (rhs, adaptation)) ← mkPatternFromExprWithKey e [] selectEqKey
  let numVars := pattern.varTypes.size
  let expr ← wrapProof numVars e adaptation
  -- Record whether the rule merely permutes its pattern variables (e.g. `x + y = y + x`).
  let perm := isPerm numVars pattern.pattern rhs
  return { expr, pattern, rhs, perm }
|
||||
|
||||
/--
|
||||
Environment extension storing a set of `Sym.Simp` theorems.
|
||||
|
|
|
|||
|
|
@ -42,153 +42,120 @@ example : 2 + 3 = 5 := by
|
|||
-- and lands in the transient cache. On the second invocation, the transient cache is
|
||||
-- cleared, so there should be NO persistent cache hit for the overall expression.
|
||||
-- Only context-independent sub-expressions (literals, fvars) get persistent cache hits.
|
||||
theorem Nat.add_comm_of_pos (a b : Nat) (_h : 0 < a) : a + b = b + a := Nat.add_comm a b
|
||||
opaque f : Nat → Nat
|
||||
axiom f_idem (a : Nat) (_h : 0 < a) : f (f a) = f a
|
||||
|
||||
set_option trace.sym.simp.debug.cache true in
|
||||
/--
|
||||
trace: [sym.simp.debug.cache] persistent cache hit: 2
|
||||
[sym.simp.debug.cache] persistent cache hit: n
|
||||
[sym.simp.debug.cache] transient cache hit: 2 + n
|
||||
trace: [sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] transient cache hit: f (f n)
|
||||
[sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] second traversal
|
||||
[sym.simp.debug.cache] persistent cache hit: n
|
||||
[sym.simp.debug.cache] persistent cache hit: 2
|
||||
[sym.simp.debug.cache] persistent cache hit: 2
|
||||
[sym.simp.debug.cache] persistent cache hit: n
|
||||
[sym.simp.debug.cache] transient cache hit: 2 + n
|
||||
[sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] transient cache hit: f (f n)
|
||||
[sym.simp.debug.cache] persistent cache hit: f n
|
||||
-/
|
||||
#guard_msgs in
|
||||
example (n : Nat) (h : 0 < n) : n + 2 = 2 + n := by
|
||||
sym_simp_twice [Nat.add_comm_of_pos]
|
||||
example (n : Nat) (h : 0 < n) : f (f n) = f (f (f n)) := by
|
||||
sym_simp_twice [f_idem]
|
||||
|
||||
-- Test 3: Congruence — cd propagates through function application.
|
||||
-- `n + 2` rewrites context-dependently (cd=true), `3 + 4` evaluates ground (cd=false).
|
||||
-- The congruence combines both, so the overall result is cd=true.
|
||||
-- On second traversal: ground sub-expressions (`3 + 4`, `7`) hit persistent cache,
|
||||
-- but cd-tainted expressions (`2 + n`, `2 + n + 7`) are only in transient.
|
||||
set_option trace.sym.simp.debug.cache true in
|
||||
/--
|
||||
trace: [sym.simp.debug.cache] persistent cache hit: 2
|
||||
[sym.simp.debug.cache] persistent cache hit: n
|
||||
[sym.simp.debug.cache] transient cache hit: 2 + n
|
||||
[sym.simp.debug.cache] transient cache hit: (2 + n) * 7
|
||||
trace: [sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] persistent cache hit: f n + 7
|
||||
[sym.simp.debug.cache] second traversal
|
||||
[sym.simp.debug.cache] persistent cache hit: n
|
||||
[sym.simp.debug.cache] persistent cache hit: 2
|
||||
[sym.simp.debug.cache] persistent cache hit: 2
|
||||
[sym.simp.debug.cache] persistent cache hit: n
|
||||
[sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] persistent cache hit: 3 + 4
|
||||
[sym.simp.debug.cache] transient cache hit: 2 + n
|
||||
[sym.simp.debug.cache] persistent cache hit: 7
|
||||
[sym.simp.debug.cache] transient cache hit: (2 + n) * 7
|
||||
[sym.simp.debug.cache] persistent cache hit: f n + 7
|
||||
[sym.simp.debug.cache] persistent cache hit: f n + 7
|
||||
-/
|
||||
#guard_msgs in
|
||||
example (n : Nat) (h : 0 < n) : (n + 2) * (3 + 4) = (2 + n) * 7 := by
|
||||
sym_simp_twice [Nat.add_comm_of_pos]
|
||||
example (n : Nat) (h : 0 < n) : f (f n) + (3 + 4) = f n + 7 := by
|
||||
sym_simp_twice [f_idem]
|
||||
|
||||
-- Similar to previous test, but `Nat.add_comm_of_pos` is not applicable, but discharger must return `cd := true`.
|
||||
-- Similar to previous test, but `f_idem` is not applicable (no hypothesis), but discharger must return `cd := true`.
|
||||
set_option trace.sym.simp.debug.cache true in
|
||||
/--
|
||||
trace: [sym.simp.debug.cache] transient cache hit: n + 2
|
||||
[sym.simp.debug.cache] transient cache hit: (n + 2) * 7
|
||||
trace: [sym.simp.debug.cache] transient cache hit: f (f n)
|
||||
[sym.simp.debug.cache] transient cache hit: f (f n) + 7
|
||||
[sym.simp.debug.cache] second traversal
|
||||
[sym.simp.debug.cache] persistent cache hit: n
|
||||
[sym.simp.debug.cache] persistent cache hit: 2
|
||||
[sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] persistent cache hit: 3 + 4
|
||||
[sym.simp.debug.cache] transient cache hit: n + 2
|
||||
[sym.simp.debug.cache] transient cache hit: f (f n)
|
||||
[sym.simp.debug.cache] persistent cache hit: 7
|
||||
[sym.simp.debug.cache] transient cache hit: (n + 2) * 7
|
||||
[sym.simp.debug.cache] transient cache hit: f (f n) + 7
|
||||
-/
|
||||
#guard_msgs in
|
||||
example (n : Nat) : (n + 2) * (3 + 4) = (n + 2) * 7 := by
|
||||
sym_simp_twice [Nat.add_comm_of_pos]
|
||||
example (n : Nat) : f (f n) + (3 + 4) = f (f n) + 7 := by
|
||||
sym_simp_twice [f_idem]
|
||||
|
||||
-- Test 4: Arrow — cd propagates through implication.
|
||||
-- The hypothesis `n + 2 = 2 + n` is simplified context-dependently to `True`.
|
||||
-- `True → True` simplifies to `True`. The whole result is cd=true.
|
||||
-- `True` hits persistent cache; `2 + n` is only in transient.
|
||||
set_option trace.sym.simp.debug.cache true in
|
||||
/--
|
||||
trace: [sym.simp.debug.cache] persistent cache hit: 2
|
||||
[sym.simp.debug.cache] persistent cache hit: n
|
||||
[sym.simp.debug.cache] transient cache hit: 2 + n
|
||||
trace: [sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] second traversal
|
||||
[sym.simp.debug.cache] persistent cache hit: n
|
||||
[sym.simp.debug.cache] persistent cache hit: 2
|
||||
[sym.simp.debug.cache] persistent cache hit: 2
|
||||
[sym.simp.debug.cache] persistent cache hit: n
|
||||
[sym.simp.debug.cache] transient cache hit: 2 + n
|
||||
[sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] persistent cache hit: True
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option linter.unusedVariables false in
|
||||
example (n : Nat) (h : 0 < n) : (n + 2 = 2 + n) → True := by
|
||||
sym_simp_twice [Nat.add_comm_of_pos]
|
||||
example (n : Nat) (h : 0 < n) : (f (f n) = f n) → True := by
|
||||
sym_simp_twice [f_idem]
|
||||
|
||||
-- Test 5: Lambda — cd propagates through funext.
|
||||
-- Body `n + 2` is simplified context-dependently inside the binder.
|
||||
-- `withFreshTransientCache` clears the transient cache on binder entry.
|
||||
-- The lambda result `fun x => 2 + n` is only in transient.
|
||||
set_option trace.sym.simp.debug.cache true in
|
||||
/--
|
||||
trace: [sym.simp.debug.cache] persistent cache hit: 2
|
||||
[sym.simp.debug.cache] persistent cache hit: n
|
||||
[sym.simp.debug.cache] persistent cache hit: 2
|
||||
[sym.simp.debug.cache] persistent cache hit: n
|
||||
[sym.simp.debug.cache] transient cache hit: fun x => 2 + n
|
||||
trace: [sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] persistent cache hit: fun x => f n
|
||||
[sym.simp.debug.cache] second traversal
|
||||
[sym.simp.debug.cache] persistent cache hit: n
|
||||
[sym.simp.debug.cache] persistent cache hit: 2
|
||||
[sym.simp.debug.cache] persistent cache hit: 2
|
||||
[sym.simp.debug.cache] persistent cache hit: n
|
||||
[sym.simp.debug.cache] persistent cache hit: 2
|
||||
[sym.simp.debug.cache] persistent cache hit: n
|
||||
[sym.simp.debug.cache] transient cache hit: fun x => 2 + n
|
||||
[sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] persistent cache hit: fun x => f n
|
||||
[sym.simp.debug.cache] persistent cache hit: fun x => f n
|
||||
-/
|
||||
#guard_msgs in
|
||||
example (n : Nat) (_h : 0 < n) : (fun _ : Nat => n + 2) = (fun _ : Nat => 2 + n) := by
|
||||
sym_simp_twice [Nat.add_comm_of_pos]
|
||||
example (n : Nat) (_h : 0 < n) : (fun _ : Nat => f (f n)) = (fun _ : Nat => f n) := by
|
||||
sym_simp_twice [f_idem]
|
||||
|
||||
-- Test 6: Control flow — cd propagates through `ite` condition.
|
||||
-- The condition `n + 2 = 2 + n` is simplified context-dependently.
|
||||
-- The `ite` result inherits cd, and `1` (ground) is in persistent cache.
|
||||
set_option trace.sym.simp.debug.cache true in
|
||||
/--
|
||||
trace: [sym.simp.debug.cache] persistent cache hit: 2
|
||||
[sym.simp.debug.cache] persistent cache hit: n
|
||||
[sym.simp.debug.cache] transient cache hit: 2 + n
|
||||
trace: [sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] persistent cache hit: 1
|
||||
[sym.simp.debug.cache] second traversal
|
||||
[sym.simp.debug.cache] persistent cache hit: n
|
||||
[sym.simp.debug.cache] persistent cache hit: 2
|
||||
[sym.simp.debug.cache] persistent cache hit: 2
|
||||
[sym.simp.debug.cache] persistent cache hit: n
|
||||
[sym.simp.debug.cache] transient cache hit: 2 + n
|
||||
[sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] persistent cache hit: 1
|
||||
[sym.simp.debug.cache] persistent cache hit: 1
|
||||
-/
|
||||
#guard_msgs in
|
||||
example (n : Nat) (h : 0 < n) : (if n + 2 = 2 + n then 1 else 0) = 1 := by
|
||||
sym_simp_twice [Nat.add_comm_of_pos]
|
||||
example (n : Nat) (h : 0 < n) : (if f (f n) = f n then 1 else 0) = 1 := by
|
||||
sym_simp_twice [f_idem]
|
||||
|
||||
-- Test 7: Dependent forall — body cd under binder with `withFreshTransientCache`.
|
||||
-- Simplifying `∀ (m : Nat), n + 2 = 2 + n` enters a binder (for `m`).
|
||||
-- The transient cache is cleared on binder entry (`withFreshTransientCache`).
|
||||
-- The body uses a cd rewrite, so the overall result is cd=true.
|
||||
-- After "second traversal": `Nat` (the binder type) hits persistent cache.
|
||||
set_option trace.sym.simp.debug.cache true in
|
||||
/--
|
||||
trace: [sym.simp.debug.cache] persistent cache hit: 2
|
||||
[sym.simp.debug.cache] persistent cache hit: n
|
||||
[sym.simp.debug.cache] transient cache hit: 2 + n
|
||||
trace: [sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] transient cache hit: f (f n)
|
||||
[sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] second traversal
|
||||
[sym.simp.debug.cache] persistent cache hit: Nat
|
||||
[sym.simp.debug.cache] persistent cache hit: n
|
||||
[sym.simp.debug.cache] persistent cache hit: 2
|
||||
[sym.simp.debug.cache] persistent cache hit: 2
|
||||
[sym.simp.debug.cache] persistent cache hit: n
|
||||
[sym.simp.debug.cache] transient cache hit: 2 + n
|
||||
[sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] persistent cache hit: f n
|
||||
[sym.simp.debug.cache] transient cache hit: f (f n)
|
||||
[sym.simp.debug.cache] persistent cache hit: f n
|
||||
-/
|
||||
#guard_msgs in
|
||||
set_option linter.unusedVariables false in
|
||||
example (n : Nat) (h : 0 < n) : ∀ (_ : Nat), n + 2 = 2 + n := by
|
||||
sym_simp_twice [Nat.add_comm_of_pos]
|
||||
example (n : Nat) (h : 0 < n) : ∀ (_ : Nat), f (f n) = f (f (f n)) := by
|
||||
sym_simp_twice [f_idem]
|
||||
|
|
|
|||
37
tests/elab/sym_simp_perm1.lean
Normal file
37
tests/elab/sym_simp_perm1.lean
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
import Lean
|
||||
|
||||
/-! Tests for permutation theorem support in `Sym.simp` -/
|
||||
|
||||
-- `Nat.add_comm` (`x + y = y + x`) is a permutation theorem.
-- Without perm support, `simp` with this theorem would loop.
register_sym_simp commSimp where
  post := ground >> rewrite [Nat.add_comm]

-- Must terminate: `Nat.add_comm` is detected as a permutation rule
-- and is only applied when the result is smaller than the input.
example (x y : Nat) : x + y = y + x := by
  sym => simp commSimp

-- A permutation theorem combined with ordinary (non-permutation) theorems.
register_sym_simp commZeroSimp where
  post := ground >> rewrite [Nat.add_comm, Nat.zero_add, Nat.add_zero]

example (x y : Nat) : 0 + (x + y) = y + x := by
  sym => simp commZeroSimp

-- Perm handling must not interfere with a set containing only non-permutation theorems.
register_sym_simp nonPermSimp where
  post := ground >> rewrite [Nat.zero_add]

example (x : Nat) : 0 + x = x := by
  sym => simp nonPermSimp

-- AC normalization with commutativity, associativity and left-commutativity
-- supplied as extra theorems at the call site.
register_sym_simp simple where
  post := ground

example (x y z w : Nat) : x + y + z + w = w + (z + y) + x := by
  sym => simp simple [Nat.add_comm, Nat.add_assoc, Nat.add_left_comm]
|
||||
Loading…
Add table
Reference in a new issue