feat: handle over/under-applied functions in Sym.simp (#11999)
This PR adds support for simplifying the arguments of over-applied and under-applied function application terms in `Sym.simp`, completing the implementation for all three congruence strategies (fixed prefix, interlaced, and congruence theorems).
This commit is contained in:
parent
c24df9e8d6
commit
3dfd125337
2 changed files with 253 additions and 112 deletions
|
|
@ -38,106 +38,6 @@ by their argument structure, allowing us to choose the most efficient proof stra
|
|||
inference on proof terms, which can be arbitrarily complex, and often destroys sharing.
|
||||
-/
|
||||
|
||||
/--
|
||||
Reduces `type` to weak head normal form and verifies it is a `forall` expression.
|
||||
If `type` is already a `forall`, returns it unchanged (avoiding unnecessary work).
|
||||
The result is shared via `share` to maintain maximal sharing invariants.
|
||||
-/
|
||||
def whnfToForall (type : Expr) : SymM Expr := do
|
||||
if type.isForall then return type
|
||||
let type ← whnfD type
|
||||
unless type.isForall do throwError "function type expected{indentD type}"
|
||||
share type
|
||||
|
||||
/--
|
||||
Returns the type of an expression `e`. If `n > 0`, then `e` is an application
|
||||
with at least `n` arguments. This function assumes the `n` trailing arguments are non-dependent.
|
||||
Given `e` of the form `f a₁ a₂ ... aₙ`, the type of `e` is computed by
|
||||
inferring the type of `f` and traversing the forall telescope.
|
||||
|
||||
We use this function to implement `congrFixedPrefix`. Recall that `inferType` is cached.
|
||||
This function tries to maximize the likelihood of a cache hit. For example,
|
||||
suppose `e` is `@HAdd.hAdd Nat Nat Nat instAdd 5` and `n = 1`. It is much more likely that
|
||||
`@HAdd.hAdd Nat Nat Nat instAdd` is already in the cache than
|
||||
`@HAdd.hAdd Nat Nat Nat instAdd 5`.
|
||||
-/
|
||||
def getFnType (e : Expr) (n : Nat) : SymM Expr := do
|
||||
match n with
|
||||
| 0 => inferType e
|
||||
| n+1 =>
|
||||
let type ← getFnType e.appFn! n
|
||||
let .forallE _ _ β _ ← whnfToForall type | unreachable!
|
||||
return β
|
||||
|
||||
/--
|
||||
Simplify arguments of a function application with a fixed prefix structure.
|
||||
Recursively simplifies the trailing `suffixSize` arguments, leaving the first
|
||||
`prefixSize` arguments unchanged.
|
||||
|
||||
For a function with `CongrInfo.fixedPrefix prefixSize suffixSize`, the arguments
|
||||
are structured as:
|
||||
```
|
||||
f a₁ ... aₚ b₁ ... bₛ
|
||||
└───────┘ └───────┘
|
||||
prefix suffix (rewritable)
|
||||
```
|
||||
|
||||
The prefix arguments (types, instances) should
|
||||
not be rewritten directly. Only the suffix arguments are recursively simplified.
|
||||
|
||||
**Performance optimization**: We avoid calling `inferType` on applied expressions
|
||||
like `f a₁ ... aₚ b₁` or `f a₁ ... aₚ b₁ ... bₛ`, which would have poor cache hit rates.
|
||||
Instead, we infer the type of the function prefix `f a₁ ... aₚ`
|
||||
(e.g., `@HAdd.hAdd Nat Nat Nat instAdd`) which is probably shared across many applications,
|
||||
then traverse the forall telescope to extract argument and result types as needed.
|
||||
|
||||
The helper `go` returns `Result × Expr` where the `Expr` is the function type at that
|
||||
position. However, the type is only meaningful (non-`default`) when `Result` is
|
||||
`.step`, since we only need types for constructing congruence proofs. This avoids
|
||||
unnecessary type inference when no rewriting occurs.
|
||||
-/
|
||||
def congrFixedPrefix (e : Expr) (prefixSize : Nat) (suffixSize : Nat) : SimpM Result := do
|
||||
let numArgs := e.getAppNumArgs
|
||||
if numArgs ≤ prefixSize then
|
||||
-- Nothing to be done
|
||||
return .rfl
|
||||
else if numArgs > prefixSize + suffixSize then
|
||||
-- **TODO**: over-applied case
|
||||
return .rfl
|
||||
else
|
||||
return (← go suffixSize e).1
|
||||
where
|
||||
go (i : Nat) (e : Expr) : SimpM (Result × Expr) := do
|
||||
if i == 0 then
|
||||
return (.rfl, default)
|
||||
else
|
||||
let .app f a := e | unreachable!
|
||||
let (hf, fType) ← go (i-1) f
|
||||
match hf, (← simp a) with
|
||||
| .rfl _, .rfl _ => return (.rfl, default)
|
||||
| .step f' hf _, .rfl _ =>
|
||||
let .forallE _ α β _ ← whnfToForall fType | unreachable!
|
||||
let e' ← mkAppS f' a
|
||||
let u ← getLevel α
|
||||
let v ← getLevel β
|
||||
let h := mkApp6 (mkConst ``congrFun' [u, v]) α β f f' hf a
|
||||
return (.step e' h, β)
|
||||
| .rfl _, .step a' ha _ =>
|
||||
let fType ← getFnType f (i-1)
|
||||
let .forallE _ α β _ ← whnfToForall fType | unreachable!
|
||||
let e' ← mkAppS f a'
|
||||
let u ← getLevel α
|
||||
let v ← getLevel β
|
||||
let h := mkApp6 (mkConst ``congrArg [u, v]) α β a a' f ha
|
||||
return (.step e' h, β)
|
||||
| .step f' hf _, .step a' ha _ =>
|
||||
let .forallE _ α β _ ← whnfToForall fType | unreachable!
|
||||
let e' ← mkAppS f' a'
|
||||
let u ← getLevel α
|
||||
let v ← getLevel β
|
||||
let h := mkApp8 (mkConst ``congr [u, v]) α β f f' a a' hf ha
|
||||
return (.step e' h, β)
|
||||
|
||||
/--
|
||||
Helper function for constructing a congruence proof using `congrFun'`, `congrArg`, `congr`.
|
||||
For the dependent case, use `mkCongrFun`
|
||||
|
|
@ -182,19 +82,168 @@ def mkCongrFun (e : Expr) (f a : Expr) (f' : Expr) (hf : Expr) (_ : e = .app f a
|
|||
return .step e' h
|
||||
|
||||
/--
|
||||
Simplify arguments of a function application with interlaced rewritable/fixed arguments.
|
||||
Handles simplification of over-applied function terms.
|
||||
|
||||
When a function has more arguments than expected by its `CongrInfo`, we need to handle
|
||||
the "extra" arguments separately. This function peels off `numArgs` trailing applications,
|
||||
simplifies the remaining function using `simpFn`, then rebuilds the term by simplifying
|
||||
and re-applying the trailing arguments.
|
||||
|
||||
**Over-application** occurs when:
|
||||
- A function with `fixedPrefix prefixSize suffixSize` is applied to more than `prefixSize + suffixSize` arguments
|
||||
- A function with `interlaced` rewritable mask is applied to more than `mask.size` arguments
|
||||
- A function with a congruence theorem is applied to more than the theorem expects
|
||||
|
||||
**Example**: If `f` has `CongrInfo.fixedPrefix 2 3` (expects 5 arguments) but we see `f a₁ a₂ a₃ a₄ a₅ b₁ b₂`,
|
||||
then `numArgs = 2` (the extra arguments) and we:
|
||||
1. Recursively simplify `f a₁ a₂ a₃ a₄ a₅` using the fixed prefix strategy (via `simpFn`)
|
||||
2. Simplify each extra argument `b₁` and `b₂`
|
||||
3. Rebuild the term using either `mkCongr` (for non-dependent arrows) or `mkCongrFun` (for dependent functions)
|
||||
|
||||
**Parameters**:
|
||||
- `e`: The over-applied expression to simplify
|
||||
- `numArgs`: Number of excess arguments to peel off
|
||||
- `simpFn`: Strategy for simplifying the function after peeling (e.g., `simpFixedPrefix`, `simpInterlaced`, or `simpUsingCongrThm`)
|
||||
|
||||
**Note**: This is a fallback path without specialized optimizations. The common case (correct number of arguments)
|
||||
is handled more efficiently by the specific strategies.
|
||||
-/
|
||||
def simpOverApplied (e : Expr) (numArgs : Nat) (simpFn : Expr → SimpM Result) : SimpM Result := do
|
||||
let rec visit (e : Expr) (i : Nat) : SimpM Result := do
|
||||
if i == 0 then
|
||||
simpFn e
|
||||
else
|
||||
let i := i - 1
|
||||
match h : e with
|
||||
| .app f a =>
|
||||
let fr ← visit f i
|
||||
let .forallE _ α β _ ← whnfD (← inferType f) | unreachable!
|
||||
if !β.hasLooseBVars then
|
||||
if (← isProp α) then
|
||||
mkCongr e f a fr .rfl h
|
||||
else
|
||||
mkCongr e f a fr (← simp a) h
|
||||
else match fr with
|
||||
| .rfl _ => return .rfl
|
||||
| .step f' hf _ => mkCongrFun e f a f' hf h
|
||||
| _ => unreachable!
|
||||
visit e numArgs
|
||||
|
||||
/--
|
||||
Reduces `type` to weak head normal form and verifies it is a `forall` expression.
|
||||
If `type` is already a `forall`, returns it unchanged (avoiding unnecessary work).
|
||||
The result is shared via `share` to maintain maximal sharing invariants.
|
||||
-/
|
||||
def whnfToForall (type : Expr) : SymM Expr := do
|
||||
if type.isForall then return type
|
||||
let type ← whnfD type
|
||||
unless type.isForall do throwError "function type expected{indentD type}"
|
||||
share type
|
||||
|
||||
/--
|
||||
Returns the type of an expression `e`. If `n > 0`, then `e` is an application
|
||||
with at least `n` arguments. This function assumes the `n` trailing arguments are non-dependent.
|
||||
Given `e` of the form `f a₁ a₂ ... aₙ`, the type of `e` is computed by
|
||||
inferring the type of `f` and traversing the forall telescope.
|
||||
|
||||
We use this function to implement `congrFixedPrefix`. Recall that `inferType` is cached.
|
||||
This function tries to maximize the likelihood of a cache hit. For example,
|
||||
suppose `e` is `@HAdd.hAdd Nat Nat Nat instAdd 5` and `n = 1`. It is much more likely that
|
||||
`@HAdd.hAdd Nat Nat Nat instAdd` is already in the cache than
|
||||
`@HAdd.hAdd Nat Nat Nat instAdd 5`.
|
||||
-/
|
||||
def getFnType (e : Expr) (n : Nat) : SymM Expr := do
|
||||
match n with
|
||||
| 0 => inferType e
|
||||
| n+1 =>
|
||||
let type ← getFnType e.appFn! n
|
||||
let .forallE _ _ β _ ← whnfToForall type | unreachable!
|
||||
return β
|
||||
|
||||
/--
|
||||
Simplifies arguments of a function application with a fixed prefix structure.
|
||||
Recursively simplifies the trailing `suffixSize` arguments, leaving the first
|
||||
`prefixSize` arguments unchanged.
|
||||
|
||||
For a function with `CongrInfo.fixedPrefix prefixSize suffixSize`, the arguments
|
||||
are structured as:
|
||||
```
|
||||
f a₁ ... aₚ b₁ ... bₛ
|
||||
└───────┘ └───────┘
|
||||
prefix suffix (rewritable)
|
||||
```
|
||||
|
||||
The prefix arguments (types, instances) should
|
||||
not be rewritten directly. Only the suffix arguments are recursively simplified.
|
||||
|
||||
**Performance optimization**: We avoid calling `inferType` on applied expressions
|
||||
like `f a₁ ... aₚ b₁` or `f a₁ ... aₚ b₁ ... bₛ`, which would have poor cache hit rates.
|
||||
Instead, we infer the type of the function prefix `f a₁ ... aₚ`
|
||||
(e.g., `@HAdd.hAdd Nat Nat Nat instAdd`) which is probably shared across many applications,
|
||||
then traverse the forall telescope to extract argument and result types as needed.
|
||||
|
||||
The helper `go` returns `Result × Expr` where the `Expr` is the function type at that
|
||||
position. However, the type is only meaningful (non-`default`) when `Result` is
|
||||
`.step`, since we only need types for constructing congruence proofs. This avoids
|
||||
unnecessary type inference when no rewriting occurs.
|
||||
-/
|
||||
def simpFixedPrefix (e : Expr) (prefixSize : Nat) (suffixSize : Nat) : SimpM Result := do
|
||||
let numArgs := e.getAppNumArgs
|
||||
if numArgs ≤ prefixSize then
|
||||
-- Nothing to be done
|
||||
return .rfl
|
||||
else if numArgs > prefixSize + suffixSize then
|
||||
simpOverApplied e (numArgs - prefixSize - suffixSize) (main suffixSize)
|
||||
else
|
||||
main (numArgs - prefixSize) e
|
||||
where
|
||||
main (n : Nat) (e : Expr) : SimpM Result := do
|
||||
return (← go n e).1
|
||||
|
||||
go (i : Nat) (e : Expr) : SimpM (Result × Expr) := do
|
||||
if i == 0 then
|
||||
return (.rfl, default)
|
||||
else
|
||||
let .app f a := e | unreachable!
|
||||
let (hf, fType) ← go (i-1) f
|
||||
match hf, (← simp a) with
|
||||
| .rfl _, .rfl _ => return (.rfl, default)
|
||||
| .step f' hf _, .rfl _ =>
|
||||
let .forallE _ α β _ ← whnfToForall fType | unreachable!
|
||||
let e' ← mkAppS f' a
|
||||
let u ← getLevel α
|
||||
let v ← getLevel β
|
||||
let h := mkApp6 (mkConst ``congrFun' [u, v]) α β f f' hf a
|
||||
return (.step e' h, β)
|
||||
| .rfl _, .step a' ha _ =>
|
||||
let fType ← getFnType f (i-1)
|
||||
let .forallE _ α β _ ← whnfToForall fType | unreachable!
|
||||
let e' ← mkAppS f a'
|
||||
let u ← getLevel α
|
||||
let v ← getLevel β
|
||||
let h := mkApp6 (mkConst ``congrArg [u, v]) α β a a' f ha
|
||||
return (.step e' h, β)
|
||||
| .step f' hf _, .step a' ha _ =>
|
||||
let .forallE _ α β _ ← whnfToForall fType | unreachable!
|
||||
let e' ← mkAppS f' a'
|
||||
let u ← getLevel α
|
||||
let v ← getLevel β
|
||||
let h := mkApp8 (mkConst ``congr [u, v]) α β f f' a a' hf ha
|
||||
return (.step e' h, β)
|
||||
|
||||
/--
|
||||
Simplifies arguments of a function application with interlaced rewritable/fixed arguments.
|
||||
Uses `rewritable[i]` to determine whether argument `i` should be simplified.
|
||||
For rewritable arguments, calls `simp` and uses `congrFun'`, `congrArg`, and `congr`; for fixed arguments,
|
||||
uses `congrFun` to propagate changes from earlier arguments.
|
||||
-/
|
||||
def congrInterlaced (e : Expr) (rewritable : Array Bool) : SimpM Result := do
|
||||
def simpInterlaced (e : Expr) (rewritable : Array Bool) : SimpM Result := do
|
||||
let numArgs := e.getAppNumArgs
|
||||
if h : numArgs = 0 then
|
||||
-- Nothing to be done
|
||||
return .rfl
|
||||
else if h : numArgs > rewritable.size then
|
||||
-- **TODO**: over-applied case
|
||||
return .rfl
|
||||
simpOverApplied e (numArgs - rewritable.size) (go rewritable.size · (Nat.le_refl _))
|
||||
else
|
||||
go numArgs e (by omega)
|
||||
where
|
||||
|
|
@ -262,11 +311,8 @@ See type `CongrArgKind`.
|
|||
- When `xs` or `i` are simplified, the proof is adjusted in the `rhs` of the auto-generated
|
||||
theorem.
|
||||
-/
|
||||
def congrThm (e : Expr) (thm : CongrTheorem) : SimpM Result := do
|
||||
def simpUsingCongrThm (e : Expr) (thm : CongrTheorem) : SimpM Result := do
|
||||
let argKinds := thm.argKinds
|
||||
if e.getAppNumArgs != argKinds.size then
|
||||
-- **TODO**: over/under-applied
|
||||
return .rfl
|
||||
/-
|
||||
Constructs the non-`rfl` result. `argResults` contains the result for arguments with kind `.eq`.
|
||||
There is at least one non-`rfl` result in `argResults`.
|
||||
|
|
@ -332,7 +378,17 @@ def congrThm (e : Expr) (thm : CongrTheorem) : SimpM Result := do
|
|||
return .rfl
|
||||
else
|
||||
mkNonRflResult argResults.reverse
|
||||
simpEqArgs e (argKinds.size - 1) 0 #[]
|
||||
let numArgs := e.getAppNumArgs
|
||||
if numArgs > argKinds.size then
|
||||
simpOverApplied e (numArgs - argKinds.size) (simpEqArgs · (argKinds.size - 1) 0 #[])
|
||||
else if numArgs < argKinds.size then
|
||||
/-
|
||||
**Note**: under-applied case. This can be optimized, but this case is so
|
||||
rare that it is not worth doing it. We just reuse `simpOverApplied`
|
||||
-/
|
||||
simpOverApplied e e.getAppNumArgs (fun _ => return .rfl)
|
||||
else
|
||||
simpEqArgs e (argKinds.size - 1) 0 #[]
|
||||
|
||||
/--
|
||||
Main entry point for simplifying function application arguments.
|
||||
|
|
@ -342,8 +398,8 @@ public def simpAppArgs (e : Expr) : SimpM Result := do
|
|||
let f := e.getAppFn
|
||||
match (← getCongrInfo f) with
|
||||
| .none => return .rfl
|
||||
| .fixedPrefix prefixSize suffixSize => congrFixedPrefix e prefixSize suffixSize
|
||||
| .interlaced rewritable => congrInterlaced e rewritable
|
||||
| .congrTheorem thm => congrThm e thm
|
||||
| .fixedPrefix prefixSize suffixSize => simpFixedPrefix e prefixSize suffixSize
|
||||
| .interlaced rewritable => simpInterlaced e rewritable
|
||||
| .congrTheorem thm => simpUsingCongrThm e thm
|
||||
|
||||
end Lean.Meta.Sym.Simp
|
||||
|
|
|
|||
|
|
@ -38,3 +38,88 @@ example (p q : Prop) (hp : p) : if x + 0 = x then p else q := by
|
|||
|
||||
example (as : Array Int) (i : Nat) (h : 0 + i < as.size) : as[0 + i] = as[i] := by
|
||||
sym_simp [Nat.zero_add, eq_self]
|
||||
|
||||
/-- trace: ⊢ Nat.add 0 = id -/
|
||||
#guard_msgs in
|
||||
example : Nat.add (0 + 0) = id := by
|
||||
sym_simp [Nat.zero_add]
|
||||
trace_state
|
||||
funext
|
||||
simp
|
||||
|
||||
/--
|
||||
trace: a : Nat
|
||||
β✝ : Type
|
||||
f : β✝ → Prop
|
||||
h : HEq a = f
|
||||
⊢ HEq a = f
|
||||
-/
|
||||
#guard_msgs in
|
||||
example (h : HEq a = f) : HEq (α := Nat) (0 + a) = f := by
|
||||
sym_simp [Nat.zero_add]
|
||||
trace_state
|
||||
exact h
|
||||
|
||||
/--
|
||||
trace: a b : Nat
|
||||
f : Nat → Nat
|
||||
h : f a = b
|
||||
⊢ id f a = b
|
||||
-/
|
||||
#guard_msgs in
|
||||
example (f : Nat → Nat) (h : f a = b) : id f (0 + a) = b := by
|
||||
sym_simp [Nat.zero_add]
|
||||
trace_state
|
||||
exact h
|
||||
|
||||
def f (_ : α) {β : Type} (b : β) : β := b
|
||||
|
||||
/--
|
||||
trace: a : Nat
|
||||
g : Nat → Nat
|
||||
⊢ f 0 g a = g a
|
||||
-/
|
||||
#guard_msgs in
|
||||
example (g : Nat → Nat) : f (0 + 0) g (0 + a) = g a := by
|
||||
sym_simp [Nat.zero_add]
|
||||
trace_state
|
||||
rfl
|
||||
|
||||
def f' (_ : α) (b : β) := b
|
||||
|
||||
/--
|
||||
trace: a : Nat
|
||||
g : Nat → Nat
|
||||
⊢ f' 0 g a = g a
|
||||
-/
|
||||
#guard_msgs in
|
||||
example (g : Nat → Nat) : f' (0 + 0) g (0 + a) = g a := by
|
||||
sym_simp [Nat.zero_add]
|
||||
trace_state
|
||||
rfl
|
||||
|
||||
/--
|
||||
trace: a b : Nat
|
||||
as : Array (Nat → Nat)
|
||||
i : Nat
|
||||
x✝ : i < as.size
|
||||
h : as[i] a = b
|
||||
⊢ as[i] a = b
|
||||
-/
|
||||
#guard_msgs in
|
||||
example (as : Array (Nat → Nat)) (i : Nat) (_ : i < as.size) (h : as[i] a = b) : as[0 + i] (0 + a) = b := by
|
||||
sym_simp [Nat.zero_add]
|
||||
trace_state
|
||||
exact h
|
||||
|
||||
/--
|
||||
trace: c a : Nat
|
||||
g : Nat → Nat
|
||||
h : ite (c > 0) a = g
|
||||
⊢ ite (c > 0) a = g
|
||||
-/
|
||||
#guard_msgs in
|
||||
example (h : ite (c > 0) a = g) : ite (c > 0) (0 + a) = g := by
|
||||
sym_simp [Nat.zero_add]
|
||||
trace_state
|
||||
exact h
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue