feat: auto-generated congruence theorems for Sym.simp (#11985)

This PR implements support for auto-generated congruence theorems in
`Sym.simp`, enabling simplification of functions with complex argument
dependencies such as proof arguments and `Decidable` instances.

Previously, `Sym.simp` used basic congruence lemmas (`congrArg`,
`congrFun`, `congrFun'`, `congr`) to construct proofs when simplifying
function arguments. This approach is efficient for simple cases but
cannot handle functions with dependent proof arguments or `Decidable`
instances that depend on earlier arguments.

The new `congrThm` function applies pre-generated congruence theorems
(similar to the main simplifier) to handle these complex cases.
This commit is contained in:
Leonardo de Moura 2026-01-12 19:00:39 -08:00 committed by GitHub
parent d68de2e018
commit 1cd6db1579
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 129 additions and 5 deletions

View file

@ -6,6 +6,8 @@ Authors: Leonardo de Moura
module
prelude
public import Lean.Meta.Sym.Simp.SimpM
import Lean.Meta.SynthInstance
import Lean.Meta.Tactic.Simp.Types
import Lean.Meta.Sym.AlphaShareBuilder
import Lean.Meta.Sym.InferType
import Lean.Meta.Sym.Simp.Result
@ -211,12 +213,126 @@ where
| _ => unreachable!
/--
Simplify arguments using a pre-generated congruence theorem.
Used for functions with proof or `Decidable` arguments.
Helper function used at `congrThm`. The idea is to initialize `argResults` lazily
when we get the first non-`.rfl` result.
-/
def congrThm (_e : Expr) (_ : CongrTheorem) : SimpM Result := do
-- **TODO**
return .rfl
def pushResult (argResults : Array Result) (numEqs : Nat) (result : Result) : Array Result :=
match result with
| .rfl .. => if argResults.size > 0 then argResults.push result else argResults
| .step .. =>
if argResults.size < numEqs then
Array.replicate numEqs .rfl |>.push result
else
argResults.push result
/--
Simplifies arguments of a function application using a pre-generated congruence theorem.
This strategy is used for functions that have complex argument dependencies, particularly
those with proof arguments or `Decidable` instances. Unlike `congrFixedPrefix` and
`congrInterlaced`, which construct proofs on-the-fly using basic congruence lemmas
(`congrArg`, `congrFun`, `congrFun'`, `congr`), this function applies a specialized congruence theorem
that was pre-generated for the specific function being simplified.
See type `CongrArgKind`.
**Algorithm**:
1. Recursively simplify all `.eq` arguments (via `simpEqArgs`)
2. If all simplifications return `.rfl`, the overall result is `.rfl`
3. Otherwise, construct the final proof by:
- Starting with the congruence theorem's proof term
- Applying original arguments and their simplification results
- Re-synthesizing subsingleton instances when their dependencies change
- Removing unnecessary casts from the result
**Key examples**:
1. `ite`: Has type `{α : Sort u} → (c : Prop) → [Decidable c] → ααα`
- Argument kinds: `[.fixed, .eq, .subsingletonInst, .eq, .eq]`
- When simplifying `ite (x > 0) a b`, if `x > 0` simplifies to `true`, we must
re-synthesize `[Decidable true]` because the original `[Decidable (x > 0)]`
instance is no longer type-correct
2. `GetElem.getElem`: Has type
```
{coll : Type u} → {idx : Type v} → {elem : Type w} → {valid : coll → idx → Prop} →
[GetElem coll idx elem valid] → (xs : coll) → (i : idx) → valid xs i → elem
```
- The proof argument `valid xs i` depends on earlier arguments `xs` and `i`
- When `xs` or `i` are simplified, the proof is adjusted in the `rhs` of the auto-generated
theorem.
-/
def congrThm (e : Expr) (thm : CongrTheorem) : SimpM Result := do
let argKinds := thm.argKinds
if e.getAppNumArgs != argKinds.size then
-- **TODO**: over/under-applied
return .rfl
/-
Constructs the non-`rfl` result. `argResults` contains the result for arguments with kind `.eq`.
There is at least one non-`rfl` result in `argResults`.
-/
let mkNonRflResult (argResults : Array Result) : SimpM Result := do
let mut proof := thm.proof
let mut type := thm.type
let mut j := 0 -- index at argResults
let mut subst := #[]
let args := e.getAppArgs
for arg in args, kind in argKinds do
proof := mkApp proof arg
type := type.bindingBody!
match kind with
| .fixed => subst := subst.push arg
| .cast => subst := subst.push arg
| .subsingletonInst =>
subst := subst.push arg
let clsNew := type.bindingDomain!.instantiateRev subst
let instNew ← if (← isDefEqI (← inferType arg) clsNew) then
pure arg
else
let .some val ← trySynthInstance clsNew | return .rfl
pure val
proof := mkApp proof instNew
subst := subst.push instNew
type := type.bindingBody!
| .eq =>
subst := subst.push arg
match argResults[j]! with
| .rfl _ =>
let h ← mkEqRefl arg
proof := mkApp2 proof arg h
subst := subst.push arg |>.push h
| .step arg' h _ =>
proof := mkApp2 proof arg' h
subst := subst.push arg' |>.push h
type := type.bindingBody!.bindingBody!
j := j + 1
| _ => unreachable!
let_expr Eq _ _ rhs := type | unreachable!
let rhs := rhs.instantiateRev subst
let hasCast := argKinds.any (· matches .cast)
let rhs ← if hasCast then Simp.removeUnnecessaryCasts rhs else pure rhs
let rhs ← share rhs
return .step rhs proof
/-
Recursively simplifies arguments of kind `.eq`. The array `argResults` is initialized lazily
as soon as the simplifier returns a non-`rfl` result for some arguments.
`numEqs` is the number of `.eq` arguments found so far.
-/
let rec simpEqArgs (e : Expr) (i : Nat) (numEqs : Nat) (argResults : Array Result) : SimpM Result := do
match e with
| .app f a =>
match argKinds[i]! with
| .subsingletonInst
| .fixed => simpEqArgs f (i-1) numEqs argResults
| .cast => simpEqArgs f (i-1) numEqs argResults
| .eq => simpEqArgs f (i-1) (numEqs+1) (pushResult argResults numEqs (← simp a))
| _ => unreachable!
| _ =>
if argResults.isEmpty then
return .rfl
else
mkNonRflResult argResults.reverse
simpEqArgs e (argKinds.size - 1) 0 #[]
/--
Main entry point for simplifying function application arguments.

View file

@ -149,6 +149,7 @@ inductive Result where
Simplified to `e'` with proof `proof : e = e'`.
If `done = true`, skip recursive simplification of `e'`. -/
| step (e' : Expr) (proof : Expr) (done : Bool := false)
deriving Inhabited
private opaque MethodsRefPointed : NonemptyType.{0}
def MethodsRef : Type := MethodsRefPointed.type

View file

@ -31,3 +31,10 @@ example : ∀ x, 0 + x + 0 = x := by
example : ∀ x, 0 + x + 0 = x := by
sym_simp [Nat.add_zero, Nat.zero_add, eq_self, forall_true]
example (p q : Prop) (hp : p) : if x + 0 = x then p else q := by
sym_simp [Nat.add_zero, eq_self, if_true]
exact hp
example (as : Array Int) (i : Nat) (h : 0 + i < as.size) : as[0 + i] = as[i] := by
sym_simp [Nat.zero_add, eq_self]