From 9f4db470c4f6c85e932fadd37b5e7913a98a4cf1 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 22 Mar 2026 17:22:36 -0700 Subject: [PATCH] feat: add permutation theorem support to `Sym.simp` (#13046) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR prevents `Sym.simp` from looping on permutation theorems like `∀ x y, x + y = y + x`. - Add `perm : Bool` field to `Theorem` - Add `isPerm` that checks if LHS and RHS have the same structure with pattern variables (de Bruijn indices) rearranged via a consistent bijection. Uses `ReaderT` (offset for binder entry), `StateT` (forward/backward maps), `ExceptT` (failure). - Compute `perm` in `mkTheoremFromDecl` / `mkTheoremFromExpr` - In `Theorem.rewrite`, when `perm` is true, only apply the rewrite if the result is strictly less than the input (using `acLt`) - Tests include the classic AC normalization stress test with `add_comm`, `add_assoc`, `add_left_comm` --------- Co-authored-by: Claude Opus 4.6 (1M context) --- src/Lean/Meta/Sym/Simp/Rewrite.lean | 7 ++ src/Lean/Meta/Sym/Simp/Theorems.lean | 54 +++++++++- tests/elab/sym_simp_cd.lean | 153 +++++++++++---------------- tests/elab/sym_simp_perm1.lean | 37 +++++++ 4 files changed, 156 insertions(+), 95 deletions(-) create mode 100644 tests/elab/sym_simp_perm1.lean diff --git a/src/Lean/Meta/Sym/Simp/Rewrite.lean b/src/Lean/Meta/Sym/Simp/Rewrite.lean index c8e0626141..dfff5a047a 100644 --- a/src/Lean/Meta/Sym/Simp/Rewrite.lean +++ b/src/Lean/Meta/Sym/Simp/Rewrite.lean @@ -9,6 +9,7 @@ public import Lean.Meta.Sym.Simp.Simproc public import Lean.Meta.Sym.Simp.Theorems public import Lean.Meta.Sym.Simp.App public import Lean.Meta.Sym.Simp.Discharger +import Lean.Meta.ACLt import Lean.Meta.Sym.InstantiateS import Lean.Meta.Sym.InstantiateMVarsS import Init.Data.Range.Polymorphic.Iterators @@ -71,10 +72,16 @@ public def Theorem.rewrite (thm : Theorem) (e : Expr) (d : Discharger := dischar let expr ← instantiateRevBetaS rhs args.toArray if isSameExpr e expr then return mkRflResultCD isCD + else if !(← checkPerm thm.perm e expr) then + return mkRflResultCD isCD else return .step expr proof (contextDependent := isCD) else return .rfl +where + checkPerm (perm : Bool) (e result : Expr) : MetaM Bool := do + if !perm then return true + acLt result e public def Theorems.rewrite (thms : Theorems) (d : Discharger := dischargeNone) : Simproc := fun e => do -- Track `cd` across all attempted theorems. If theorem A fails with cd=true diff --git a/src/Lean/Meta/Sym/Simp/Theorems.lean b/src/Lean/Meta/Sym/Simp/Theorems.lean index c31e404ba2..b78cb20443 100644 --- a/src/Lean/Meta/Sym/Simp/Theorems.lean +++ b/src/Lean/Meta/Sym/Simp/Theorems.lean @@ -10,6 +10,7 @@ public import Lean.Meta.DiscrTree import Lean.Meta.Sym.Simp.DiscrTree import Lean.Meta.AppBuilder import Lean.ExtraModUses +import Init.Omega public section namespace Lean.Meta.Sym.Simp @@ -26,6 +27,10 @@ structure Theorem where pattern : Pattern /-- Right-hand side of the equation. -/ rhs : Expr + /-- If `true`, the theorem is a permutation rule (e.g., `x + y = y + x`). + Rewriting is only applied when the result is strictly less than the input + (using `acLt`), preventing infinite loops. -/ + perm : Bool := false deriving Inhabited instance : BEq Theorem where @@ -45,6 +50,49 @@ def Theorems.getMatch (thms : Theorems) (e : Expr) : Array Theorem := def Theorems.getMatchWithExtra (thms : Theorems) (e : Expr) : Array (Theorem × Nat) := Sym.getMatchWithExtra thms.thms e +/-- +Check whether `lhs` and `rhs` (with `numVars` pattern variables represented as `.bvar` indices +`≥ 0` before any binder entry) are permutations of each other — same structure with only +pattern variable indices rearranged via a consistent bijection. + +Bvars with index `< offset` are "local" (introduced by binders inside the pattern) and must +match exactly. Bvars with index `≥ offset` are pattern variables and may be permuted, +but the mapping must be a bijection. + +Simplified compared to `Meta.simp`'s `isPerm`: +- Uses de Bruijn indices instead of metavariables +- No `.proj` (folded into applications) or `.letE` (zeta-expanded) cases +-/ +private abbrev IsPermM := ReaderT Nat $ StateT (Array (Option Nat)) $ Except Unit + +private partial def isPermAux (a b : Expr) : IsPermM Unit := do + match a, b with + | .bvar i, .bvar j => + let offset ← read + if i < offset && j < offset then + unless i == j do throw () + else if i >= offset && j >= offset then + let pi := i - offset + let pj := j - offset + let fwd ← get + if h : pi >= fwd.size then throw () else + match fwd[pi] with + | none => + -- Check injectivity: pj must not already be a target of another mapping + if fwd.contains (some pj) then throw () + set (fwd.set pi (some pj)) + | some pj' => unless pj == pj' do throw () + else throw () + | .app f₁ a₁, .app f₂ a₂ => isPermAux f₁ f₂; isPermAux a₁ a₂ + | .mdata _ s, t => isPermAux s t + | s, .mdata _ t => isPermAux s t + | .forallE _ d₁ b₁ _, .forallE _ d₂ b₂ _ => isPermAux d₁ d₂; withReader (· + 1) (isPermAux b₁ b₂) + | .lam _ d₁ b₁ _, .lam _ d₂ b₂ _ => isPermAux d₁ d₂; withReader (· + 1) (isPermAux b₁ b₂) + | s, t => unless s == t do throw () + +def isPerm (numVars : Nat) (lhs rhs : Expr) : Bool := + ((isPermAux lhs rhs).run 0 |>.run (Array.replicate numVars none)) matches .ok _ + /-- Describes how a theorem's conclusion was adapted to an equality for use in `Sym.simp`. -/ private inductive EqAdaptation where /-- Already an equality `lhs = rhs`. Proof is used as-is. -/ @@ -99,13 +147,15 @@ where def mkTheoremFromDecl (declName : Name) : MetaM Theorem := do let (pattern, (rhs, adaptation)) ← mkPatternFromDeclWithKey declName selectEqKey let expr ← wrapProof pattern.varTypes.size (mkConst declName) adaptation - return { expr, pattern, rhs } + let perm := isPerm pattern.varTypes.size pattern.pattern rhs + return { expr, pattern, rhs, perm } /-- Create a `Theorem` from a proof expression. Handles equalities, `¬`, `↔`, and propositions. -/ def mkTheoremFromExpr (e : Expr) : MetaM Theorem := do let (pattern, (rhs, adaptation)) ← mkPatternFromExprWithKey e [] selectEqKey let expr ← wrapProof pattern.varTypes.size e adaptation - return { expr, pattern, rhs } + let perm := isPerm pattern.varTypes.size pattern.pattern rhs + return { expr, pattern, rhs, perm } /-- Environment extension storing a set of `Sym.Simp` theorems. diff --git a/tests/elab/sym_simp_cd.lean b/tests/elab/sym_simp_cd.lean index 15049eae15..cb0d552e0a 100644 --- a/tests/elab/sym_simp_cd.lean +++ b/tests/elab/sym_simp_cd.lean @@ -42,153 +42,120 @@ example : 2 + 3 = 5 := by -- and lands in the transient cache. On the second invocation, the transient cache is -- cleared, so there should be NO persistent cache hit for the overall expression. -- Only context-independent sub-expressions (literals, fvars) get persistent cache hits. -theorem Nat.add_comm_of_pos (a b : Nat) (_h : 0 < a) : a + b = b + a := Nat.add_comm a b +opaque f : Nat → Nat +axiom f_idem (a : Nat) (_h : 0 < a) : f (f a) = f a set_option trace.sym.simp.debug.cache true in /-- -trace: [sym.simp.debug.cache] persistent cache hit: 2 -[sym.simp.debug.cache] persistent cache hit: n -[sym.simp.debug.cache] transient cache hit: 2 + n +trace: [sym.simp.debug.cache] persistent cache hit: f n +[sym.simp.debug.cache] transient cache hit: f (f n) +[sym.simp.debug.cache] persistent cache hit: f n [sym.simp.debug.cache] second traversal -[sym.simp.debug.cache] persistent cache hit: n -[sym.simp.debug.cache] persistent cache hit: 2 -[sym.simp.debug.cache] persistent cache hit: 2 -[sym.simp.debug.cache] persistent cache hit: n -[sym.simp.debug.cache] transient cache hit: 2 + n +[sym.simp.debug.cache] persistent cache hit: f n +[sym.simp.debug.cache] persistent cache hit: f n +[sym.simp.debug.cache] transient cache hit: f (f n) +[sym.simp.debug.cache] persistent cache hit: f n -/ #guard_msgs in -example (n : Nat) (h : 0 < n) : n + 2 = 2 + n := by - sym_simp_twice [Nat.add_comm_of_pos] +example (n : Nat) (h : 0 < n) : f (f n) = f (f (f n)) := by + sym_simp_twice [f_idem] -- Test 3: Congruence — cd propagates through function application. --- `n + 2` rewrites context-dependently (cd=true), `3 + 4` evaluates ground (cd=false). --- The congruence combines both, so the overall result is cd=true. --- On second traversal: ground sub-expressions (`3 + 4`, `7`) hit persistent cache, --- but cd-tainted expressions (`2 + n`, `2 + n + 7`) are only in transient. set_option trace.sym.simp.debug.cache true in /-- -trace: [sym.simp.debug.cache] persistent cache hit: 2 -[sym.simp.debug.cache] persistent cache hit: n -[sym.simp.debug.cache] transient cache hit: 2 + n -[sym.simp.debug.cache] transient cache hit: (2 + n) * 7 +trace: [sym.simp.debug.cache] persistent cache hit: f n +[sym.simp.debug.cache] persistent cache hit: f n +[sym.simp.debug.cache] persistent cache hit: f n + 7 [sym.simp.debug.cache] second traversal -[sym.simp.debug.cache] persistent cache hit: n -[sym.simp.debug.cache] persistent cache hit: 2 -[sym.simp.debug.cache] persistent cache hit: 2 -[sym.simp.debug.cache] persistent cache hit: n +[sym.simp.debug.cache] persistent cache hit: f n +[sym.simp.debug.cache] persistent cache hit: f n [sym.simp.debug.cache] persistent cache hit: 3 + 4 -[sym.simp.debug.cache] transient cache hit: 2 + n -[sym.simp.debug.cache] persistent cache hit: 7 -[sym.simp.debug.cache] transient cache hit: (2 + n) * 7 +[sym.simp.debug.cache] persistent cache hit: f n + 7 +[sym.simp.debug.cache] persistent cache hit: f n + 7 -/ #guard_msgs in -example (n : Nat) (h : 0 < n) : (n + 2) * (3 + 4) = (2 + n) * 7 := by - sym_simp_twice [Nat.add_comm_of_pos] +example (n : Nat) (h : 0 < n) : f (f n) + (3 + 4) = f n + 7 := by + sym_simp_twice [f_idem] --- Similar to previous test, but `Nat.add_comm_of_pos` is not applicable, but discharger must return `cd := true`. +-- Similar to previous test, but `f_idem` is not applicable (no hypothesis), but discharger must return `cd := true`. set_option trace.sym.simp.debug.cache true in /-- -trace: [sym.simp.debug.cache] transient cache hit: n + 2 -[sym.simp.debug.cache] transient cache hit: (n + 2) * 7 +trace: [sym.simp.debug.cache] transient cache hit: f (f n) +[sym.simp.debug.cache] transient cache hit: f (f n) + 7 [sym.simp.debug.cache] second traversal -[sym.simp.debug.cache] persistent cache hit: n -[sym.simp.debug.cache] persistent cache hit: 2 +[sym.simp.debug.cache] persistent cache hit: f n [sym.simp.debug.cache] persistent cache hit: 3 + 4 -[sym.simp.debug.cache] transient cache hit: n + 2 +[sym.simp.debug.cache] transient cache hit: f (f n) [sym.simp.debug.cache] persistent cache hit: 7 -[sym.simp.debug.cache] transient cache hit: (n + 2) * 7 +[sym.simp.debug.cache] transient cache hit: f (f n) + 7 -/ #guard_msgs in -example (n : Nat) : (n + 2) * (3 + 4) = (n + 2) * 7 := by - sym_simp_twice [Nat.add_comm_of_pos] +example (n : Nat) : f (f n) + (3 + 4) = f (f n) + 7 := by + sym_simp_twice [f_idem] -- Test 4: Arrow — cd propagates through implication. --- The hypothesis `n + 2 = 2 + n` is simplified context-dependently to `True`. --- `True → True` simplifies to `True`. The whole result is cd=true. --- `True` hits persistent cache; `2 + n` is only in transient. set_option trace.sym.simp.debug.cache true in /-- -trace: [sym.simp.debug.cache] persistent cache hit: 2 -[sym.simp.debug.cache] persistent cache hit: n -[sym.simp.debug.cache] transient cache hit: 2 + n +trace: [sym.simp.debug.cache] persistent cache hit: f n +[sym.simp.debug.cache] persistent cache hit: f n [sym.simp.debug.cache] second traversal -[sym.simp.debug.cache] persistent cache hit: n -[sym.simp.debug.cache] persistent cache hit: 2 -[sym.simp.debug.cache] persistent cache hit: 2 -[sym.simp.debug.cache] persistent cache hit: n -[sym.simp.debug.cache] transient cache hit: 2 + n +[sym.simp.debug.cache] persistent cache hit: f n +[sym.simp.debug.cache] persistent cache hit: f n +[sym.simp.debug.cache] persistent cache hit: f n [sym.simp.debug.cache] persistent cache hit: True -/ #guard_msgs in set_option linter.unusedVariables false in -example (n : Nat) (h : 0 < n) : (n + 2 = 2 + n) → True := by - sym_simp_twice [Nat.add_comm_of_pos] +example (n : Nat) (h : 0 < n) : (f (f n) = f n) → True := by + sym_simp_twice [f_idem] -- Test 5: Lambda — cd propagates through funext. --- Body `n + 2` is simplified context-dependently inside the binder. --- `withFreshTransientCache` clears the transient cache on binder entry. --- The lambda result `fun x => 2 + n` is only in transient. set_option trace.sym.simp.debug.cache true in /-- -trace: [sym.simp.debug.cache] persistent cache hit: 2 -[sym.simp.debug.cache] persistent cache hit: n -[sym.simp.debug.cache] persistent cache hit: 2 -[sym.simp.debug.cache] persistent cache hit: n -[sym.simp.debug.cache] transient cache hit: fun x => 2 + n +trace: [sym.simp.debug.cache] persistent cache hit: f n +[sym.simp.debug.cache] persistent cache hit: f n +[sym.simp.debug.cache] persistent cache hit: fun x => f n [sym.simp.debug.cache] second traversal -[sym.simp.debug.cache] persistent cache hit: n -[sym.simp.debug.cache] persistent cache hit: 2 -[sym.simp.debug.cache] persistent cache hit: 2 -[sym.simp.debug.cache] persistent cache hit: n -[sym.simp.debug.cache] persistent cache hit: 2 -[sym.simp.debug.cache] persistent cache hit: n -[sym.simp.debug.cache] transient cache hit: fun x => 2 + n +[sym.simp.debug.cache] persistent cache hit: f n +[sym.simp.debug.cache] persistent cache hit: f n +[sym.simp.debug.cache] persistent cache hit: fun x => f n +[sym.simp.debug.cache] persistent cache hit: fun x => f n -/ #guard_msgs in -example (n : Nat) (_h : 0 < n) : (fun _ : Nat => n + 2) = (fun _ : Nat => 2 + n) := by - sym_simp_twice [Nat.add_comm_of_pos] +example (n : Nat) (_h : 0 < n) : (fun _ : Nat => f (f n)) = (fun _ : Nat => f n) := by + sym_simp_twice [f_idem] -- Test 6: Control flow — cd propagates through `ite` condition. --- The condition `n + 2 = 2 + n` is simplified context-dependently. --- The `ite` result inherits cd, and `1` (ground) is in persistent cache. set_option trace.sym.simp.debug.cache true in /-- -trace: [sym.simp.debug.cache] persistent cache hit: 2 -[sym.simp.debug.cache] persistent cache hit: n -[sym.simp.debug.cache] transient cache hit: 2 + n +trace: [sym.simp.debug.cache] persistent cache hit: f n +[sym.simp.debug.cache] persistent cache hit: f n [sym.simp.debug.cache] persistent cache hit: 1 [sym.simp.debug.cache] second traversal -[sym.simp.debug.cache] persistent cache hit: n -[sym.simp.debug.cache] persistent cache hit: 2 -[sym.simp.debug.cache] persistent cache hit: 2 -[sym.simp.debug.cache] persistent cache hit: n -[sym.simp.debug.cache] transient cache hit: 2 + n +[sym.simp.debug.cache] persistent cache hit: f n +[sym.simp.debug.cache] persistent cache hit: f n +[sym.simp.debug.cache] persistent cache hit: f n [sym.simp.debug.cache] persistent cache hit: 1 [sym.simp.debug.cache] persistent cache hit: 1 -/ #guard_msgs in -example (n : Nat) (h : 0 < n) : (if n + 2 = 2 + n then 1 else 0) = 1 := by - sym_simp_twice [Nat.add_comm_of_pos] +example (n : Nat) (h : 0 < n) : (if f (f n) = f n then 1 else 0) = 1 := by + sym_simp_twice [f_idem] -- Test 7: Dependent forall — body cd under binder with `withFreshTransientCache`. --- Simplifying `∀ (m : Nat), n + 2 = 2 + n` enters a binder (for `m`). --- The transient cache is cleared on binder entry (`withFreshTransientCache`). --- The body uses a cd rewrite, so the overall result is cd=true. --- After "second traversal": `Nat` (the binder type) hits persistent cache. set_option trace.sym.simp.debug.cache true in /-- -trace: [sym.simp.debug.cache] persistent cache hit: 2 -[sym.simp.debug.cache] persistent cache hit: n -[sym.simp.debug.cache] transient cache hit: 2 + n +trace: [sym.simp.debug.cache] persistent cache hit: f n +[sym.simp.debug.cache] transient cache hit: f (f n) +[sym.simp.debug.cache] persistent cache hit: f n [sym.simp.debug.cache] second traversal [sym.simp.debug.cache] persistent cache hit: Nat -[sym.simp.debug.cache] persistent cache hit: n -[sym.simp.debug.cache] persistent cache hit: 2 -[sym.simp.debug.cache] persistent cache hit: 2 -[sym.simp.debug.cache] persistent cache hit: n -[sym.simp.debug.cache] transient cache hit: 2 + n +[sym.simp.debug.cache] persistent cache hit: f n +[sym.simp.debug.cache] persistent cache hit: f n +[sym.simp.debug.cache] transient cache hit: f (f n) +[sym.simp.debug.cache] persistent cache hit: f n -/ #guard_msgs in set_option linter.unusedVariables false in -example (n : Nat) (h : 0 < n) : ∀ (_ : Nat), n + 2 = 2 + n := by - sym_simp_twice [Nat.add_comm_of_pos] +example (n : Nat) (h : 0 < n) : ∀ (_ : Nat), f (f n) = f (f (f n)) := by + sym_simp_twice [f_idem] diff --git a/tests/elab/sym_simp_perm1.lean b/tests/elab/sym_simp_perm1.lean new file mode 100644 index 0000000000..883d3ac6b9 --- /dev/null +++ b/tests/elab/sym_simp_perm1.lean @@ -0,0 +1,37 @@ +import Lean + +/-! Tests for permutation theorem support in `Sym.simp` -/ + +-- Nat.add_comm is a permutation theorem: x + y = y + x +-- Without perm support, `simp` with this theorem would loop. + +register_sym_simp commSimp where + post := ground >> rewrite [Nat.add_comm] + +-- This should terminate: Nat.add_comm is detected as perm, +-- and only applied when result < input. +example (x y : Nat) : x + y = y + x := by + sym => + simp commSimp + +-- Combining perm with non-perm theorems +register_sym_simp commZeroSimp where + post := ground >> rewrite [Nat.add_comm, Nat.zero_add, Nat.add_zero] + +example (x y : Nat) : 0 + (x + y) = y + x := by + sym => + simp commZeroSimp + +-- Verify perm doesn't interfere with non-perm theorems +register_sym_simp nonPermSimp where + post := ground >> rewrite [Nat.zero_add] + +example (x : Nat) : 0 + x = x := by + sym => + simp nonPermSimp + +register_sym_simp simple where + post := ground + +example (x y z w : Nat) : x + y + z + w = w + (z + y) + x := by + sym => simp simple [Nat.add_comm, Nat.add_assoc, Nat.add_left_comm]