From 1cd6db15792cf7f7b5dba4f9f7b3a9cd8a1002cd Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 12 Jan 2026 19:00:39 -0800 Subject: [PATCH] feat: auto-generated congruence theorems for `Sym.simp` (#11985) This PR implements support for auto-generated congruence theorems in `Sym.simp`, enabling simplification of functions with complex argument dependencies such as proof arguments and `Decidable` instances. Previously, `Sym.simp` used basic congruence lemmas (`congrArg`, `congrFun`, `congrFun'`, `congr`) to construct proofs when simplifying function arguments. This approach is efficient for simple cases but cannot handle functions with dependent proof arguments or `Decidable` instances that depend on earlier arguments. The new `congrThm` function applies pre-generated congruence theorems (similar to the main simplifier) to handle these complex cases. --- src/Lean/Meta/Sym/Simp/App.lean | 126 ++++++++++++++++++++++++++++-- src/Lean/Meta/Sym/Simp/SimpM.lean | 1 + tests/lean/run/sym_simp_1.lean | 7 ++ 3 files changed, 129 insertions(+), 5 deletions(-) diff --git a/src/Lean/Meta/Sym/Simp/App.lean b/src/Lean/Meta/Sym/Simp/App.lean index 0756c7e056..12c2b09c55 100644 --- a/src/Lean/Meta/Sym/Simp/App.lean +++ b/src/Lean/Meta/Sym/Simp/App.lean @@ -6,6 +6,8 @@ Authors: Leonardo de Moura module prelude public import Lean.Meta.Sym.Simp.SimpM +import Lean.Meta.SynthInstance +import Lean.Meta.Tactic.Simp.Types import Lean.Meta.Sym.AlphaShareBuilder import Lean.Meta.Sym.InferType import Lean.Meta.Sym.Simp.Result @@ -211,12 +213,126 @@ where | _ => unreachable! /-- -Simplify arguments using a pre-generated congruence theorem. -Used for functions with proof or `Decidable` arguments. +Helper function used at `congrThm`. The idea is to initialize `argResults` lazily +when we get the first non-`.rfl` result. -/ -def congrThm (_e : Expr) (_ : CongrTheorem) : SimpM Result := do - -- **TODO** - return .rfl +def pushResult (argResults : Array Result) (numEqs : Nat) (result : Result) : Array Result := + match result with + | .rfl .. => if argResults.size > 0 then argResults.push result else argResults + | .step .. => + if argResults.size < numEqs then + Array.replicate numEqs .rfl |>.push result + else + argResults.push result + +/-- +Simplifies arguments of a function application using a pre-generated congruence theorem. + +This strategy is used for functions that have complex argument dependencies, particularly +those with proof arguments or `Decidable` instances. Unlike `congrFixedPrefix` and +`congrInterlaced`, which construct proofs on-the-fly using basic congruence lemmas +(`congrArg`, `congrFun`, `congrFun'`, `congr`), this function applies a specialized congruence theorem +that was pre-generated for the specific function being simplified. + +See type `CongrArgKind`. + +**Algorithm**: +1. Recursively simplify all `.eq` arguments (via `simpEqArgs`) +2. If all simplifications return `.rfl`, the overall result is `.rfl` +3. Otherwise, construct the final proof by: + - Starting with the congruence theorem's proof term + - Applying original arguments and their simplification results + - Re-synthesizing subsingleton instances when their dependencies change + - Removing unnecessary casts from the result + +**Key examples**: + +1. `ite`: Has type `{α : Sort u} → (c : Prop) → [Decidable c] → α → α → α` + - Argument kinds: `[.fixed, .eq, .subsingletonInst, .eq, .eq]` + - When simplifying `ite (x > 0) a b`, if `x > 0` simplifies to `true`, we must + re-synthesize `[Decidable true]` because the original `[Decidable (x > 0)]` + instance is no longer type-correct + +2. `GetElem.getElem`: Has type + ``` + {coll : Type u} → {idx : Type v} → {elem : Type w} → {valid : coll → idx → Prop} → + [GetElem coll idx elem valid] → (xs : coll) → (i : idx) → valid xs i → elem + ``` + - The proof argument `valid xs i` depends on earlier arguments `xs` and `i` + - When `xs` or `i` are simplified, the proof is adjusted in the `rhs` of the auto-generated + theorem. +-/ +def congrThm (e : Expr) (thm : CongrTheorem) : SimpM Result := do + let argKinds := thm.argKinds + if e.getAppNumArgs != argKinds.size then + -- **TODO**: over/under-applied + return .rfl + /- + Constructs the non-`rfl` result. `argResults` contains the result for arguments with kind `.eq`. + There is at least one non-`rfl` result in `argResults`. + -/ + let mkNonRflResult (argResults : Array Result) : SimpM Result := do + let mut proof := thm.proof + let mut type := thm.type + let mut j := 0 -- index at argResults + let mut subst := #[] + let args := e.getAppArgs + for arg in args, kind in argKinds do + proof := mkApp proof arg + type := type.bindingBody! + match kind with + | .fixed => subst := subst.push arg + | .cast => subst := subst.push arg + | .subsingletonInst => + subst := subst.push arg + let clsNew := type.bindingDomain!.instantiateRev subst + let instNew ← if (← isDefEqI (← inferType arg) clsNew) then + pure arg + else + let .some val ← trySynthInstance clsNew | return .rfl + pure val + proof := mkApp proof instNew + subst := subst.push instNew + type := type.bindingBody! + | .eq => + subst := subst.push arg + match argResults[j]! with + | .rfl _ => + let h ← mkEqRefl arg + proof := mkApp2 proof arg h + subst := subst.push arg |>.push h + | .step arg' h _ => + proof := mkApp2 proof arg' h + subst := subst.push arg' |>.push h + type := type.bindingBody!.bindingBody! + j := j + 1 + | _ => unreachable! + let_expr Eq _ _ rhs := type | unreachable! + let rhs := rhs.instantiateRev subst + let hasCast := argKinds.any (· matches .cast) + let rhs ← if hasCast then Simp.removeUnnecessaryCasts rhs else pure rhs + let rhs ← share rhs + return .step rhs proof + /- + Recursively simplifies arguments of kind `.eq`. The array `argResults` is initialized lazily + as soon as the simplifier returns a non-`rfl` result for some arguments. + `numEqs` is the number of `.eq` arguments found so far. + -/ + let rec simpEqArgs (e : Expr) (i : Nat) (numEqs : Nat) (argResults : Array Result) : SimpM Result := do + match e with + | .app f a => + match argKinds[i]! with + | .subsingletonInst + | .fixed => simpEqArgs f (i-1) numEqs argResults + | .cast => simpEqArgs f (i-1) numEqs argResults + | .eq => simpEqArgs f (i-1) (numEqs+1) (pushResult argResults numEqs (← simp a)) + | _ => unreachable! + | _ => + if argResults.isEmpty then + return .rfl + else + mkNonRflResult argResults.reverse + simpEqArgs e (argKinds.size - 1) 0 #[] /-- Main entry point for simplifying function application arguments. diff --git a/src/Lean/Meta/Sym/Simp/SimpM.lean b/src/Lean/Meta/Sym/Simp/SimpM.lean index f44b2dfbf9..4d365d7eb0 100644 --- a/src/Lean/Meta/Sym/Simp/SimpM.lean +++ b/src/Lean/Meta/Sym/Simp/SimpM.lean @@ -149,6 +149,7 @@ inductive Result where Simplified to `e'` with proof `proof : e = e'`. If `done = true`, skip recursive simplification of `e'`. -/ | step (e' : Expr) (proof : Expr) (done : Bool := false) + deriving Inhabited private opaque MethodsRefPointed : NonemptyType.{0} def MethodsRef : Type := MethodsRefPointed.type diff --git a/tests/lean/run/sym_simp_1.lean b/tests/lean/run/sym_simp_1.lean index 82a6ace8e1..e153f03e4e 100644 --- a/tests/lean/run/sym_simp_1.lean +++ b/tests/lean/run/sym_simp_1.lean @@ -31,3 +31,10 @@ example : ∀ x, 0 + x + 0 = x := by example : ∀ x, 0 + x + 0 = x := by sym_simp [Nat.add_zero, Nat.zero_add, eq_self, forall_true] + +example (p q : Prop) (hp : p) : if x + 0 = x then p else q := by + sym_simp [Nat.add_zero, eq_self, if_true] + exact hp + +example (as : Array Int) (i : Nat) (h : 0 + i < as.size) : as[0 + i] = as[i] := by + sym_simp [Nat.zero_add, eq_self]