feat: auto-generated congruence theorems for Sym.simp (#11985)
This PR implements support for auto-generated congruence theorems in `Sym.simp`, enabling simplification of functions with complex argument dependencies such as proof arguments and `Decidable` instances. Previously, `Sym.simp` used basic congruence lemmas (`congrArg`, `congrFun`, `congrFun'`, `congr`) to construct proofs when simplifying function arguments. This approach is efficient for simple cases but cannot handle functions with dependent proof arguments or `Decidable` instances that depend on earlier arguments. The new `congrThm` function applies pre-generated congruence theorems (similar to the main simplifier) to handle these complex cases.
This commit is contained in:
parent
d68de2e018
commit
1cd6db1579
3 changed files with 129 additions and 5 deletions
|
|
@ -6,6 +6,8 @@ Authors: Leonardo de Moura
|
|||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.Simp.SimpM
|
||||
import Lean.Meta.SynthInstance
|
||||
import Lean.Meta.Tactic.Simp.Types
|
||||
import Lean.Meta.Sym.AlphaShareBuilder
|
||||
import Lean.Meta.Sym.InferType
|
||||
import Lean.Meta.Sym.Simp.Result
|
||||
|
|
@ -211,12 +213,126 @@ where
|
|||
| _ => unreachable!
|
||||
|
||||
/--
|
||||
Simplify arguments using a pre-generated congruence theorem.
|
||||
Used for functions with proof or `Decidable` arguments.
|
||||
Helper function used at `congrThm`. The idea is to initialize `argResults` lazily
|
||||
when we get the first non-`.rfl` result.
|
||||
-/
|
||||
def congrThm (_e : Expr) (_ : CongrTheorem) : SimpM Result := do
|
||||
-- **TODO**
|
||||
return .rfl
|
||||
def pushResult (argResults : Array Result) (numEqs : Nat) (result : Result) : Array Result :=
|
||||
match result with
|
||||
| .rfl .. => if argResults.size > 0 then argResults.push result else argResults
|
||||
| .step .. =>
|
||||
if argResults.size < numEqs then
|
||||
Array.replicate numEqs .rfl |>.push result
|
||||
else
|
||||
argResults.push result
|
||||
|
||||
/--
|
||||
Simplifies arguments of a function application using a pre-generated congruence theorem.
|
||||
|
||||
This strategy is used for functions that have complex argument dependencies, particularly
|
||||
those with proof arguments or `Decidable` instances. Unlike `congrFixedPrefix` and
|
||||
`congrInterlaced`, which construct proofs on-the-fly using basic congruence lemmas
|
||||
(`congrArg`, `congrFun`, `congrFun'`, `congr`), this function applies a specialized congruence theorem
|
||||
that was pre-generated for the specific function being simplified.
|
||||
|
||||
See type `CongrArgKind`.
|
||||
|
||||
**Algorithm**:
|
||||
1. Recursively simplify all `.eq` arguments (via `simpEqArgs`)
|
||||
2. If all simplifications return `.rfl`, the overall result is `.rfl`
|
||||
3. Otherwise, construct the final proof by:
|
||||
- Starting with the congruence theorem's proof term
|
||||
- Applying original arguments and their simplification results
|
||||
- Re-synthesizing subsingleton instances when their dependencies change
|
||||
- Removing unnecessary casts from the result
|
||||
|
||||
**Key examples**:
|
||||
|
||||
1. `ite`: Has type `{α : Sort u} → (c : Prop) → [Decidable c] → α → α → α`
|
||||
- Argument kinds: `[.fixed, .eq, .subsingletonInst, .eq, .eq]`
|
||||
- When simplifying `ite (x > 0) a b`, if `x > 0` simplifies to `true`, we must
|
||||
re-synthesize `[Decidable true]` because the original `[Decidable (x > 0)]`
|
||||
instance is no longer type-correct
|
||||
|
||||
2. `GetElem.getElem`: Has type
|
||||
```
|
||||
{coll : Type u} → {idx : Type v} → {elem : Type w} → {valid : coll → idx → Prop} →
|
||||
[GetElem coll idx elem valid] → (xs : coll) → (i : idx) → valid xs i → elem
|
||||
```
|
||||
- The proof argument `valid xs i` depends on earlier arguments `xs` and `i`
|
||||
- When `xs` or `i` are simplified, the proof is adjusted in the `rhs` of the auto-generated
|
||||
theorem.
|
||||
-/
|
||||
def congrThm (e : Expr) (thm : CongrTheorem) : SimpM Result := do
|
||||
let argKinds := thm.argKinds
|
||||
if e.getAppNumArgs != argKinds.size then
|
||||
-- **TODO**: over/under-applied
|
||||
return .rfl
|
||||
/-
|
||||
Constructs the non-`rfl` result. `argResults` contains the result for arguments with kind `.eq`.
|
||||
There is at least one non-`rfl` result in `argResults`.
|
||||
-/
|
||||
let mkNonRflResult (argResults : Array Result) : SimpM Result := do
|
||||
let mut proof := thm.proof
|
||||
let mut type := thm.type
|
||||
let mut j := 0 -- index at argResults
|
||||
let mut subst := #[]
|
||||
let args := e.getAppArgs
|
||||
for arg in args, kind in argKinds do
|
||||
proof := mkApp proof arg
|
||||
type := type.bindingBody!
|
||||
match kind with
|
||||
| .fixed => subst := subst.push arg
|
||||
| .cast => subst := subst.push arg
|
||||
| .subsingletonInst =>
|
||||
subst := subst.push arg
|
||||
let clsNew := type.bindingDomain!.instantiateRev subst
|
||||
let instNew ← if (← isDefEqI (← inferType arg) clsNew) then
|
||||
pure arg
|
||||
else
|
||||
let .some val ← trySynthInstance clsNew | return .rfl
|
||||
pure val
|
||||
proof := mkApp proof instNew
|
||||
subst := subst.push instNew
|
||||
type := type.bindingBody!
|
||||
| .eq =>
|
||||
subst := subst.push arg
|
||||
match argResults[j]! with
|
||||
| .rfl _ =>
|
||||
let h ← mkEqRefl arg
|
||||
proof := mkApp2 proof arg h
|
||||
subst := subst.push arg |>.push h
|
||||
| .step arg' h _ =>
|
||||
proof := mkApp2 proof arg' h
|
||||
subst := subst.push arg' |>.push h
|
||||
type := type.bindingBody!.bindingBody!
|
||||
j := j + 1
|
||||
| _ => unreachable!
|
||||
let_expr Eq _ _ rhs := type | unreachable!
|
||||
let rhs := rhs.instantiateRev subst
|
||||
let hasCast := argKinds.any (· matches .cast)
|
||||
let rhs ← if hasCast then Simp.removeUnnecessaryCasts rhs else pure rhs
|
||||
let rhs ← share rhs
|
||||
return .step rhs proof
|
||||
/-
|
||||
Recursively simplifies arguments of kind `.eq`. The array `argResults` is initialized lazily
|
||||
as soon as the simplifier returns a non-`rfl` result for some arguments.
|
||||
`numEqs` is the number of `.eq` arguments found so far.
|
||||
-/
|
||||
let rec simpEqArgs (e : Expr) (i : Nat) (numEqs : Nat) (argResults : Array Result) : SimpM Result := do
|
||||
match e with
|
||||
| .app f a =>
|
||||
match argKinds[i]! with
|
||||
| .subsingletonInst
|
||||
| .fixed => simpEqArgs f (i-1) numEqs argResults
|
||||
| .cast => simpEqArgs f (i-1) numEqs argResults
|
||||
| .eq => simpEqArgs f (i-1) (numEqs+1) (pushResult argResults numEqs (← simp a))
|
||||
| _ => unreachable!
|
||||
| _ =>
|
||||
if argResults.isEmpty then
|
||||
return .rfl
|
||||
else
|
||||
mkNonRflResult argResults.reverse
|
||||
simpEqArgs e (argKinds.size - 1) 0 #[]
|
||||
|
||||
/--
|
||||
Main entry point for simplifying function application arguments.
|
||||
|
|
|
|||
|
|
@ -149,6 +149,7 @@ inductive Result where
|
|||
Simplified to `e'` with proof `proof : e = e'`.
|
||||
If `done = true`, skip recursive simplification of `e'`. -/
|
||||
| step (e' : Expr) (proof : Expr) (done : Bool := false)
|
||||
deriving Inhabited
|
||||
|
||||
private opaque MethodsRefPointed : NonemptyType.{0}
|
||||
def MethodsRef : Type := MethodsRefPointed.type
|
||||
|
|
|
|||
|
|
@ -31,3 +31,10 @@ example : ∀ x, 0 + x + 0 = x := by
|
|||
|
||||
example : ∀ x, 0 + x + 0 = x := by
|
||||
sym_simp [Nat.add_zero, Nat.zero_add, eq_self, forall_true]
|
||||
|
||||
example (p q : Prop) (hp : p) : if x + 0 = x then p else q := by
|
||||
sym_simp [Nat.add_zero, eq_self, if_true]
|
||||
exact hp
|
||||
|
||||
example (as : Array Int) (i : Nat) (h : 0 + i < as.size) : as[0 + i] = as[i] := by
|
||||
sym_simp [Nat.zero_add, eq_self]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue