feat: handle over/under-applied functions in Sym.simp (#11999)

This PR adds support for simplifying the arguments of over-applied and
under-applied function application terms in `Sym.simp`, completing the
implementation for all three congruence strategies (fixed prefix,
interlaced, and congruence theorems).
This commit is contained in:
Leonardo de Moura 2026-01-13 17:40:42 -08:00 committed by GitHub
parent c24df9e8d6
commit 3dfd125337
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 253 additions and 112 deletions

View file

@ -38,106 +38,6 @@ by their argument structure, allowing us to choose the most efficient proof stra
inference on proof terms, which can be arbitrarily complex, and often destroys sharing.
-/
/--
Reduces `type` to weak head normal form and verifies it is a `forall` expression.
If `type` is already a `forall`, returns it unchanged (avoiding unnecessary work).
The result is shared via `share` to maintain maximal sharing invariants.
-/
def whnfToForall (type : Expr) : SymM Expr := do
if type.isForall then return type
let type ← whnfD type
unless type.isForall do throwError "function type expected{indentD type}"
share type
/--
Returns the type of an expression `e`. If `n > 0`, then `e` is an application
with at least `n` arguments. This function assumes the `n` trailing arguments are non-dependent.
Given `e` of the form `f a₁ a₂ ... aₙ`, the type of `e` is computed by
inferring the type of `f` and traversing the forall telescope.
We use this function to implement `congrFixedPrefix`. Recall that `inferType` is cached.
This function tries to maximize the likelihood of a cache hit. For example,
suppose `e` is `@HAdd.hAdd Nat Nat Nat instAdd 5` and `n = 1`. It is much more likely that
`@HAdd.hAdd Nat Nat Nat instAdd` is already in the cache than
`@HAdd.hAdd Nat Nat Nat instAdd 5`.
-/
def getFnType (e : Expr) (n : Nat) : SymM Expr := do
match n with
| 0 => inferType e
| n+1 =>
let type ← getFnType e.appFn! n
let .forallE _ _ β _ ← whnfToForall type | unreachable!
return β
/--
Simplify arguments of a function application with a fixed prefix structure.
Recursively simplifies the trailing `suffixSize` arguments, leaving the first
`prefixSize` arguments unchanged.
For a function with `CongrInfo.fixedPrefix prefixSize suffixSize`, the arguments
are structured as:
```
f a₁ ... aₚ b₁ ... bₛ
└───────┘ └───────┘
prefix suffix (rewritable)
```
The prefix arguments (types, instances) should
not be rewritten directly. Only the suffix arguments are recursively simplified.
**Performance optimization**: We avoid calling `inferType` on applied expressions
like `f a₁ ... aₚ b₁` or `f a₁ ... aₚ b₁ ... bₛ`, which would have poor cache hit rates.
Instead, we infer the type of the function prefix `f a₁ ... aₚ`
(e.g., `@HAdd.hAdd Nat Nat Nat instAdd`) which is probably shared across many applications,
then traverse the forall telescope to extract argument and result types as needed.
The helper `go` returns `Result × Expr` where the `Expr` is the function type at that
position. However, the type is only meaningful (non-`default`) when `Result` is
`.step`, since we only need types for constructing congruence proofs. This avoids
unnecessary type inference when no rewriting occurs.
-/
def congrFixedPrefix (e : Expr) (prefixSize : Nat) (suffixSize : Nat) : SimpM Result := do
let numArgs := e.getAppNumArgs
if numArgs ≤ prefixSize then
-- Nothing to be done
return .rfl
else if numArgs > prefixSize + suffixSize then
-- **TODO**: over-applied case
return .rfl
else
return (← go suffixSize e).1
where
go (i : Nat) (e : Expr) : SimpM (Result × Expr) := do
if i == 0 then
return (.rfl, default)
else
let .app f a := e | unreachable!
let (hf, fType) ← go (i-1) f
match hf, (← simp a) with
| .rfl _, .rfl _ => return (.rfl, default)
| .step f' hf _, .rfl _ =>
let .forallE _ α β _ ← whnfToForall fType | unreachable!
let e' ← mkAppS f' a
let u ← getLevel α
let v ← getLevel β
let h := mkApp6 (mkConst ``congrFun' [u, v]) α β f f' hf a
return (.step e' h, β)
| .rfl _, .step a' ha _ =>
let fType ← getFnType f (i-1)
let .forallE _ α β _ ← whnfToForall fType | unreachable!
let e' ← mkAppS f a'
let u ← getLevel α
let v ← getLevel β
let h := mkApp6 (mkConst ``congrArg [u, v]) α β a a' f ha
return (.step e' h, β)
| .step f' hf _, .step a' ha _ =>
let .forallE _ α β _ ← whnfToForall fType | unreachable!
let e' ← mkAppS f' a'
let u ← getLevel α
let v ← getLevel β
let h := mkApp8 (mkConst ``congr [u, v]) α β f f' a a' hf ha
return (.step e' h, β)
/--
Helper function for constructing a congruence proof using `congrFun'`, `congrArg`, `congr`.
For the dependent case, use `mkCongrFun`
@ -182,19 +82,168 @@ def mkCongrFun (e : Expr) (f a : Expr) (f' : Expr) (hf : Expr) (_ : e = .app f a
return .step e' h
/--
Simplify arguments of a function application with interlaced rewritable/fixed arguments.
Handles simplification of over-applied function terms.
When a function has more arguments than expected by its `CongrInfo`, we need to handle
the "extra" arguments separately. This function peels off `numArgs` trailing applications,
simplifies the remaining function using `simpFn`, then rebuilds the term by simplifying
and re-applying the trailing arguments.
**Over-application** occurs when:
- A function with `fixedPrefix prefixSize suffixSize` is applied to more than `prefixSize + suffixSize` arguments
- A function with `interlaced` rewritable mask is applied to more than `mask.size` arguments
- A function with a congruence theorem is applied to more than the theorem expects
**Example**: If `f` has `CongrInfo.fixedPrefix 2 3` (expects 5 arguments) but we see `f a₁ a₂ a₃ a₄ a₅ b₁ b₂`,
then `numArgs = 2` (the extra arguments) and we:
1. Recursively simplify `f a₁ a₂ a₃ a₄ a₅` using the fixed prefix strategy (via `simpFn`)
2. Simplify each extra argument `b₁` and `b₂`
3. Rebuild the term using either `mkCongr` (for non-dependent arrows) or `mkCongrFun` (for dependent functions)
**Parameters**:
- `e`: The over-applied expression to simplify
- `numArgs`: Number of excess arguments to peel off
- `simpFn`: Strategy for simplifying the function after peeling (e.g., `simpFixedPrefix`, `simpInterlaced`, or `simpUsingCongrThm`)
**Note**: This is a fallback path without specialized optimizations. The common case (correct number of arguments)
is handled more efficiently by the specific strategies.
-/
def simpOverApplied (e : Expr) (numArgs : Nat) (simpFn : Expr → SimpM Result) : SimpM Result := do
let rec visit (e : Expr) (i : Nat) : SimpM Result := do
if i == 0 then
simpFn e
else
let i := i - 1
match h : e with
| .app f a =>
let fr ← visit f i
let .forallE _ α β _ ← whnfD (← inferType f) | unreachable!
if !β.hasLooseBVars then
if (← isProp α) then
mkCongr e f a fr .rfl h
else
mkCongr e f a fr (← simp a) h
else match fr with
| .rfl _ => return .rfl
| .step f' hf _ => mkCongrFun e f a f' hf h
| _ => unreachable!
visit e numArgs
/--
Reduces `type` to weak head normal form and verifies it is a `forall` expression.
If `type` is already a `forall`, returns it unchanged (avoiding unnecessary work).
The result is shared via `share` to maintain maximal sharing invariants.
-/
def whnfToForall (type : Expr) : SymM Expr := do
if type.isForall then return type
let type ← whnfD type
unless type.isForall do throwError "function type expected{indentD type}"
share type
/--
Returns the type of an expression `e`. If `n > 0`, then `e` is an application
with at least `n` arguments. This function assumes the `n` trailing arguments are non-dependent.
Given `e` of the form `f a₁ a₂ ... aₙ`, the type of `e` is computed by
inferring the type of `f` and traversing the forall telescope.
We use this function to implement `congrFixedPrefix`. Recall that `inferType` is cached.
This function tries to maximize the likelihood of a cache hit. For example,
suppose `e` is `@HAdd.hAdd Nat Nat Nat instAdd 5` and `n = 1`. It is much more likely that
`@HAdd.hAdd Nat Nat Nat instAdd` is already in the cache than
`@HAdd.hAdd Nat Nat Nat instAdd 5`.
-/
def getFnType (e : Expr) (n : Nat) : SymM Expr := do
match n with
| 0 => inferType e
| n+1 =>
let type ← getFnType e.appFn! n
let .forallE _ _ β _ ← whnfToForall type | unreachable!
return β
/--
Simplifies arguments of a function application with a fixed prefix structure.
Recursively simplifies the trailing `suffixSize` arguments, leaving the first
`prefixSize` arguments unchanged.
For a function with `CongrInfo.fixedPrefix prefixSize suffixSize`, the arguments
are structured as:
```
f a₁ ... aₚ b₁ ... bₛ
└───────┘ └───────┘
prefix suffix (rewritable)
```
The prefix arguments (types, instances) should
not be rewritten directly. Only the suffix arguments are recursively simplified.
**Performance optimization**: We avoid calling `inferType` on applied expressions
like `f a₁ ... aₚ b₁` or `f a₁ ... aₚ b₁ ... bₛ`, which would have poor cache hit rates.
Instead, we infer the type of the function prefix `f a₁ ... aₚ`
(e.g., `@HAdd.hAdd Nat Nat Nat instAdd`) which is probably shared across many applications,
then traverse the forall telescope to extract argument and result types as needed.
The helper `go` returns `Result × Expr` where the `Expr` is the function type at that
position. However, the type is only meaningful (non-`default`) when `Result` is
`.step`, since we only need types for constructing congruence proofs. This avoids
unnecessary type inference when no rewriting occurs.
-/
def simpFixedPrefix (e : Expr) (prefixSize : Nat) (suffixSize : Nat) : SimpM Result := do
let numArgs := e.getAppNumArgs
if numArgs ≤ prefixSize then
-- Nothing to be done
return .rfl
else if numArgs > prefixSize + suffixSize then
simpOverApplied e (numArgs - prefixSize - suffixSize) (main suffixSize)
else
main (numArgs - prefixSize) e
where
main (n : Nat) (e : Expr) : SimpM Result := do
return (← go n e).1
go (i : Nat) (e : Expr) : SimpM (Result × Expr) := do
if i == 0 then
return (.rfl, default)
else
let .app f a := e | unreachable!
let (hf, fType) ← go (i-1) f
match hf, (← simp a) with
| .rfl _, .rfl _ => return (.rfl, default)
| .step f' hf _, .rfl _ =>
let .forallE _ α β _ ← whnfToForall fType | unreachable!
let e' ← mkAppS f' a
let u ← getLevel α
let v ← getLevel β
let h := mkApp6 (mkConst ``congrFun' [u, v]) α β f f' hf a
return (.step e' h, β)
| .rfl _, .step a' ha _ =>
let fType ← getFnType f (i-1)
let .forallE _ α β _ ← whnfToForall fType | unreachable!
let e' ← mkAppS f a'
let u ← getLevel α
let v ← getLevel β
let h := mkApp6 (mkConst ``congrArg [u, v]) α β a a' f ha
return (.step e' h, β)
| .step f' hf _, .step a' ha _ =>
let .forallE _ α β _ ← whnfToForall fType | unreachable!
let e' ← mkAppS f' a'
let u ← getLevel α
let v ← getLevel β
let h := mkApp8 (mkConst ``congr [u, v]) α β f f' a a' hf ha
return (.step e' h, β)
/--
Simplifies arguments of a function application with interlaced rewritable/fixed arguments.
Uses `rewritable[i]` to determine whether argument `i` should be simplified.
For rewritable arguments, calls `simp` and uses `congrFun'`, `congrArg`, and `congr`; for fixed arguments,
uses `congrFun` to propagate changes from earlier arguments.
-/
def congrInterlaced (e : Expr) (rewritable : Array Bool) : SimpM Result := do
def simpInterlaced (e : Expr) (rewritable : Array Bool) : SimpM Result := do
let numArgs := e.getAppNumArgs
if h : numArgs = 0 then
-- Nothing to be done
return .rfl
else if h : numArgs > rewritable.size then
-- **TODO**: over-applied case
return .rfl
simpOverApplied e (numArgs - rewritable.size) (go rewritable.size · (Nat.le_refl _))
else
go numArgs e (by omega)
where
@ -262,11 +311,8 @@ See type `CongrArgKind`.
- When `xs` or `i` are simplified, the proof is adjusted in the `rhs` of the auto-generated
theorem.
-/
def congrThm (e : Expr) (thm : CongrTheorem) : SimpM Result := do
def simpUsingCongrThm (e : Expr) (thm : CongrTheorem) : SimpM Result := do
let argKinds := thm.argKinds
if e.getAppNumArgs != argKinds.size then
-- **TODO**: over/under-applied
return .rfl
/-
Constructs the non-`rfl` result. `argResults` contains the result for arguments with kind `.eq`.
There is at least one non-`rfl` result in `argResults`.
@ -332,7 +378,17 @@ def congrThm (e : Expr) (thm : CongrTheorem) : SimpM Result := do
return .rfl
else
mkNonRflResult argResults.reverse
simpEqArgs e (argKinds.size - 1) 0 #[]
let numArgs := e.getAppNumArgs
if numArgs > argKinds.size then
simpOverApplied e (numArgs - argKinds.size) (simpEqArgs · (argKinds.size - 1) 0 #[])
else if numArgs < argKinds.size then
/-
**Note**: under-applied case. This can be optimized, but this case is so
rare that it is not worth doing it. We just reuse `simpOverApplied`
-/
simpOverApplied e e.getAppNumArgs (fun _ => return .rfl)
else
simpEqArgs e (argKinds.size - 1) 0 #[]
/--
Main entry point for simplifying function application arguments.
@ -342,8 +398,8 @@ public def simpAppArgs (e : Expr) : SimpM Result := do
let f := e.getAppFn
match (← getCongrInfo f) with
| .none => return .rfl
| .fixedPrefix prefixSize suffixSize => congrFixedPrefix e prefixSize suffixSize
| .interlaced rewritable => congrInterlaced e rewritable
| .congrTheorem thm => congrThm e thm
| .fixedPrefix prefixSize suffixSize => simpFixedPrefix e prefixSize suffixSize
| .interlaced rewritable => simpInterlaced e rewritable
| .congrTheorem thm => simpUsingCongrThm e thm
end Lean.Meta.Sym.Simp

View file

@ -38,3 +38,88 @@ example (p q : Prop) (hp : p) : if x + 0 = x then p else q := by
example (as : Array Int) (i : Nat) (h : 0 + i < as.size) : as[0 + i] = as[i] := by
sym_simp [Nat.zero_add, eq_self]
/-- trace: ⊢ Nat.add 0 = id -/
#guard_msgs in
example : Nat.add (0 + 0) = id := by
sym_simp [Nat.zero_add]
trace_state
funext
simp
/--
trace: a : Nat
β✝ : Type
f : β✝ → Prop
h : HEq a = f
⊢ HEq a = f
-/
#guard_msgs in
example (h : HEq a = f) : HEq (α := Nat) (0 + a) = f := by
sym_simp [Nat.zero_add]
trace_state
exact h
/--
trace: a b : Nat
f : Nat → Nat
h : f a = b
⊢ id f a = b
-/
#guard_msgs in
example (f : Nat → Nat) (h : f a = b) : id f (0 + a) = b := by
sym_simp [Nat.zero_add]
trace_state
exact h
def f (_ : α) {β : Type} (b : β) : β := b
/--
trace: a : Nat
g : Nat → Nat
⊢ f 0 g a = g a
-/
#guard_msgs in
example (g : Nat → Nat) : f (0 + 0) g (0 + a) = g a := by
sym_simp [Nat.zero_add]
trace_state
rfl
def f' (_ : α) (b : β) := b
/--
trace: a : Nat
g : Nat → Nat
⊢ f' 0 g a = g a
-/
#guard_msgs in
example (g : Nat → Nat) : f' (0 + 0) g (0 + a) = g a := by
sym_simp [Nat.zero_add]
trace_state
rfl
/--
trace: a b : Nat
as : Array (Nat → Nat)
i : Nat
x✝ : i < as.size
h : as[i] a = b
⊢ as[i] a = b
-/
#guard_msgs in
example (as : Array (Nat → Nat)) (i : Nat) (_ : i < as.size) (h : as[i] a = b) : as[0 + i] (0 + a) = b := by
sym_simp [Nat.zero_add]
trace_state
exact h
/--
trace: c a : Nat
g : Nat → Nat
h : ite (c > 0) a = g
⊢ ite (c > 0) a = g
-/
#guard_msgs in
example (h : ite (c > 0) a = g) : ite (c > 0) (0 + a) = g := by
sym_simp [Nat.zero_add]
trace_state
exact h