diff --git a/src/Lean/Meta/Sym/Simp/App.lean b/src/Lean/Meta/Sym/Simp/App.lean index 12c2b09c55..f285e79205 100644 --- a/src/Lean/Meta/Sym/Simp/App.lean +++ b/src/Lean/Meta/Sym/Simp/App.lean @@ -38,106 +38,6 @@ by their argument structure, allowing us to choose the most efficient proof stra inference on proof terms, which can be arbitrarily complex, and often destroys sharing. -/ -/-- -Reduces `type` to weak head normal form and verifies it is a `forall` expression. -If `type` is already a `forall`, returns it unchanged (avoiding unnecessary work). -The result is shared via `share` to maintain maximal sharing invariants. --/ -def whnfToForall (type : Expr) : SymM Expr := do - if type.isForall then return type - let type ← whnfD type - unless type.isForall do throwError "function type expected{indentD type}" - share type - -/-- -Returns the type of an expression `e`. If `n > 0`, then `e` is an application -with at least `n` arguments. This function assumes the `n` trailing arguments are non-dependent. -Given `e` of the form `f a₁ a₂ ... aₙ`, the type of `e` is computed by -inferring the type of `f` and traversing the forall telescope. - -We use this function to implement `congrFixedPrefix`. Recall that `inferType` is cached. -This function tries to maximize the likelihood of a cache hit. For example, -suppose `e` is `@HAdd.hAdd Nat Nat Nat instAdd 5` and `n = 1`. It is much more likely that -`@HAdd.hAdd Nat Nat Nat instAdd` is already in the cache than -`@HAdd.hAdd Nat Nat Nat instAdd 5`. --/ -def getFnType (e : Expr) (n : Nat) : SymM Expr := do - match n with - | 0 => inferType e - | n+1 => - let type ← getFnType e.appFn! n - let .forallE _ _ β _ ← whnfToForall type | unreachable! - return β - -/-- -Simplify arguments of a function application with a fixed prefix structure. -Recursively simplifies the trailing `suffixSize` arguments, leaving the first -`prefixSize` arguments unchanged. 
- -For a function with `CongrInfo.fixedPrefix prefixSize suffixSize`, the arguments -are structured as: -``` -f a₁ ... aₚ b₁ ... bₛ - └───────┘ └───────┘ - prefix suffix (rewritable) -``` - -The prefix arguments (types, instances) should -not be rewritten directly. Only the suffix arguments are recursively simplified. - -**Performance optimization**: We avoid calling `inferType` on applied expressions -like `f a₁ ... aₚ b₁` or `f a₁ ... aₚ b₁ ... bₛ`, which would have poor cache hit rates. -Instead, we infer the type of the function prefix `f a₁ ... aₚ` -(e.g., `@HAdd.hAdd Nat Nat Nat instAdd`) which is probably shared across many applications, -then traverse the forall telescope to extract argument and result types as needed. - -The helper `go` returns `Result × Expr` where the `Expr` is the function type at that -position. However, the type is only meaningful (non-`default`) when `Result` is -`.step`, since we only need types for constructing congruence proofs. This avoids -unnecessary type inference when no rewriting occurs. --/ -def congrFixedPrefix (e : Expr) (prefixSize : Nat) (suffixSize : Nat) : SimpM Result := do - let numArgs := e.getAppNumArgs - if numArgs ≤ prefixSize then - -- Nothing to be done - return .rfl - else if numArgs > prefixSize + suffixSize then - -- **TODO**: over-applied case - return .rfl - else - return (← go suffixSize e).1 -where - go (i : Nat) (e : Expr) : SimpM (Result × Expr) := do - if i == 0 then - return (.rfl, default) - else - let .app f a := e | unreachable! - let (hf, fType) ← go (i-1) f - match hf, (← simp a) with - | .rfl _, .rfl _ => return (.rfl, default) - | .step f' hf _, .rfl _ => - let .forallE _ α β _ ← whnfToForall fType | unreachable! - let e' ← mkAppS f' a - let u ← getLevel α - let v ← getLevel β - let h := mkApp6 (mkConst ``congrFun' [u, v]) α β f f' hf a - return (.step e' h, β) - | .rfl _, .step a' ha _ => - let fType ← getFnType f (i-1) - let .forallE _ α β _ ← whnfToForall fType | unreachable! 
- let e' ← mkAppS f a' - let u ← getLevel α - let v ← getLevel β - let h := mkApp6 (mkConst ``congrArg [u, v]) α β a a' f ha - return (.step e' h, β) - | .step f' hf _, .step a' ha _ => - let .forallE _ α β _ ← whnfToForall fType | unreachable! - let e' ← mkAppS f' a' - let u ← getLevel α - let v ← getLevel β - let h := mkApp8 (mkConst ``congr [u, v]) α β f f' a a' hf ha - return (.step e' h, β) - /-- Helper function for constructing a congruence proof using `congrFun'`, `congrArg`, `congr`. For the dependent case, use `mkCongrFun` @@ -182,19 +82,168 @@ def mkCongrFun (e : Expr) (f a : Expr) (f' : Expr) (hf : Expr) (_ : e = .app f a return .step e' h /-- -Simplify arguments of a function application with interlaced rewritable/fixed arguments. +Handles simplification of over-applied function terms. + +When a function has more arguments than expected by its `CongrInfo`, we need to handle +the "extra" arguments separately. This function peels off `numArgs` trailing applications, +simplifies the remaining function using `simpFn`, then rebuilds the term by simplifying +and re-applying the trailing arguments. + +**Over-application** occurs when: +- A function with `fixedPrefix prefixSize suffixSize` is applied to more than `prefixSize + suffixSize` arguments +- A function with `interlaced` rewritable mask is applied to more than `mask.size` arguments +- A function with a congruence theorem is applied to more than the theorem expects + +**Example**: If `f` has `CongrInfo.fixedPrefix 2 3` (expects 5 arguments) but we see `f a₁ a₂ a₃ a₄ a₅ b₁ b₂`, +then `numArgs = 2` (the extra arguments) and we: +1. Recursively simplify `f a₁ a₂ a₃ a₄ a₅` using the fixed prefix strategy (via `simpFn`) +2. Simplify each extra argument `b₁` and `b₂` +3. 
Rebuild the term using either `mkCongr` (for non-dependent arrows) or `mkCongrFun` (for dependent functions) + +**Parameters**: +- `e`: The over-applied expression to simplify +- `numArgs`: Number of excess arguments to peel off +- `simpFn`: Strategy for simplifying the function after peeling (e.g., `simpFixedPrefix`, `simpInterlaced`, or `simpUsingCongrThm`) + +**Note**: This is a fallback path without specialized optimizations. The common case (correct number of arguments) +is handled more efficiently by the specific strategies. +-/ +def simpOverApplied (e : Expr) (numArgs : Nat) (simpFn : Expr → SimpM Result) : SimpM Result := do + let rec visit (e : Expr) (i : Nat) : SimpM Result := do + if i == 0 then + simpFn e + else + let i := i - 1 + match h : e with + | .app f a => + let fr ← visit f i + let .forallE _ α β _ ← whnfD (← inferType f) | unreachable! + if !β.hasLooseBVars then + if (← isProp α) then + mkCongr e f a fr .rfl h + else + mkCongr e f a fr (← simp a) h + else match fr with + | .rfl _ => return .rfl + | .step f' hf _ => mkCongrFun e f a f' hf h + | _ => unreachable! + visit e numArgs + +/-- +Reduces `type` to weak head normal form and verifies it is a `forall` expression. +If `type` is already a `forall`, returns it unchanged (avoiding unnecessary work). +The result is shared via `share` to maintain maximal sharing invariants. +-/ +def whnfToForall (type : Expr) : SymM Expr := do + if type.isForall then return type + let type ← whnfD type + unless type.isForall do throwError "function type expected{indentD type}" + share type + +/-- +Returns the type of an expression `e`. If `n > 0`, then `e` is an application +with at least `n` arguments. This function assumes the `n` trailing arguments are non-dependent. +Given `e` of the form `f a₁ a₂ ... aₙ`, the type of `e` is computed by +inferring the type of `f` and traversing the forall telescope. + +We use this function to implement `simpFixedPrefix`. Recall that `inferType` is cached. 
+This function tries to maximize the likelihood of a cache hit. For example, +suppose `e` is `@HAdd.hAdd Nat Nat Nat instAdd 5` and `n = 1`. It is much more likely that +`@HAdd.hAdd Nat Nat Nat instAdd` is already in the cache than +`@HAdd.hAdd Nat Nat Nat instAdd 5`. +-/ +def getFnType (e : Expr) (n : Nat) : SymM Expr := do + match n with + | 0 => inferType e + | n+1 => + let type ← getFnType e.appFn! n + let .forallE _ _ β _ ← whnfToForall type | unreachable! + return β + +/-- +Simplifies arguments of a function application with a fixed prefix structure. +Recursively simplifies the trailing `suffixSize` arguments, leaving the first +`prefixSize` arguments unchanged. + +For a function with `CongrInfo.fixedPrefix prefixSize suffixSize`, the arguments +are structured as: +``` +f a₁ ... aₚ b₁ ... bₛ + └───────┘ └───────┘ + prefix suffix (rewritable) +``` + +The prefix arguments (types, instances) should +not be rewritten directly. Only the suffix arguments are recursively simplified. + +**Performance optimization**: We avoid calling `inferType` on applied expressions +like `f a₁ ... aₚ b₁` or `f a₁ ... aₚ b₁ ... bₛ`, which would have poor cache hit rates. +Instead, we infer the type of the function prefix `f a₁ ... aₚ` +(e.g., `@HAdd.hAdd Nat Nat Nat instAdd`) which is probably shared across many applications, +then traverse the forall telescope to extract argument and result types as needed. + +The helper `go` returns `Result × Expr` where the `Expr` is the function type at that +position. However, the type is only meaningful (non-`default`) when `Result` is +`.step`, since we only need types for constructing congruence proofs. This avoids +unnecessary type inference when no rewriting occurs. 
+-/ +def simpFixedPrefix (e : Expr) (prefixSize : Nat) (suffixSize : Nat) : SimpM Result := do + let numArgs := e.getAppNumArgs + if numArgs ≤ prefixSize then + -- Nothing to be done + return .rfl + else if numArgs > prefixSize + suffixSize then + simpOverApplied e (numArgs - prefixSize - suffixSize) (main suffixSize) + else + main (numArgs - prefixSize) e +where + main (n : Nat) (e : Expr) : SimpM Result := do + return (← go n e).1 + + go (i : Nat) (e : Expr) : SimpM (Result × Expr) := do + if i == 0 then + return (.rfl, default) + else + let .app f a := e | unreachable! + let (hf, fType) ← go (i-1) f + match hf, (← simp a) with + | .rfl _, .rfl _ => return (.rfl, default) + | .step f' hf _, .rfl _ => + let .forallE _ α β _ ← whnfToForall fType | unreachable! + let e' ← mkAppS f' a + let u ← getLevel α + let v ← getLevel β + let h := mkApp6 (mkConst ``congrFun' [u, v]) α β f f' hf a + return (.step e' h, β) + | .rfl _, .step a' ha _ => + let fType ← getFnType f (i-1) + let .forallE _ α β _ ← whnfToForall fType | unreachable! + let e' ← mkAppS f a' + let u ← getLevel α + let v ← getLevel β + let h := mkApp6 (mkConst ``congrArg [u, v]) α β a a' f ha + return (.step e' h, β) + | .step f' hf _, .step a' ha _ => + let .forallE _ α β _ ← whnfToForall fType | unreachable! + let e' ← mkAppS f' a' + let u ← getLevel α + let v ← getLevel β + let h := mkApp8 (mkConst ``congr [u, v]) α β f f' a a' hf ha + return (.step e' h, β) + +/-- +Simplifies arguments of a function application with interlaced rewritable/fixed arguments. Uses `rewritable[i]` to determine whether argument `i` should be simplified. For rewritable arguments, calls `simp` and uses `congrFun'`, `congrArg`, and `congr`; for fixed arguments, uses `congrFun` to propagate changes from earlier arguments. 
-/ -def congrInterlaced (e : Expr) (rewritable : Array Bool) : SimpM Result := do +def simpInterlaced (e : Expr) (rewritable : Array Bool) : SimpM Result := do let numArgs := e.getAppNumArgs if h : numArgs = 0 then -- Nothing to be done return .rfl else if h : numArgs > rewritable.size then - -- **TODO**: over-applied case - return .rfl + simpOverApplied e (numArgs - rewritable.size) (go rewritable.size · (Nat.le_refl _)) else go numArgs e (by omega) where @@ -262,11 +311,8 @@ See type `CongrArgKind`. - When `xs` or `i` are simplified, the proof is adjusted in the `rhs` of the auto-generated theorem. -/ -def congrThm (e : Expr) (thm : CongrTheorem) : SimpM Result := do +def simpUsingCongrThm (e : Expr) (thm : CongrTheorem) : SimpM Result := do let argKinds := thm.argKinds - if e.getAppNumArgs != argKinds.size then - -- **TODO**: over/under-applied - return .rfl /- Constructs the non-`rfl` result. `argResults` contains the result for arguments with kind `.eq`. There is at least one non-`rfl` result in `argResults`. @@ -332,7 +378,17 @@ def congrThm (e : Expr) (thm : CongrTheorem) : SimpM Result := do return .rfl else mkNonRflResult argResults.reverse - simpEqArgs e (argKinds.size - 1) 0 #[] + let numArgs := e.getAppNumArgs + if numArgs > argKinds.size then + simpOverApplied e (numArgs - argKinds.size) (simpEqArgs · (argKinds.size - 1) 0 #[]) + else if numArgs < argKinds.size then + /- + **Note**: under-applied case. This can be optimized, but this case is so + rare that it is not worth doing. We just reuse `simpOverApplied`, peeling off every argument. + -/ + simpOverApplied e e.getAppNumArgs (fun _ => return .rfl) + else + simpEqArgs e (argKinds.size - 1) 0 #[] /-- Main entry point for simplifying function application arguments. 
@@ -342,8 +398,8 @@ public def simpAppArgs (e : Expr) : SimpM Result := do let f := e.getAppFn match (← getCongrInfo f) with | .none => return .rfl - | .fixedPrefix prefixSize suffixSize => congrFixedPrefix e prefixSize suffixSize - | .interlaced rewritable => congrInterlaced e rewritable - | .congrTheorem thm => congrThm e thm + | .fixedPrefix prefixSize suffixSize => simpFixedPrefix e prefixSize suffixSize + | .interlaced rewritable => simpInterlaced e rewritable + | .congrTheorem thm => simpUsingCongrThm e thm end Lean.Meta.Sym.Simp diff --git a/tests/lean/run/sym_simp_1.lean b/tests/lean/run/sym_simp_1.lean index e153f03e4e..e010fd33a9 100644 --- a/tests/lean/run/sym_simp_1.lean +++ b/tests/lean/run/sym_simp_1.lean @@ -38,3 +38,88 @@ example (p q : Prop) (hp : p) : if x + 0 = x then p else q := by example (as : Array Int) (i : Nat) (h : 0 + i < as.size) : as[0 + i] = as[i] := by sym_simp [Nat.zero_add, eq_self] + +/-- trace: ⊢ Nat.add 0 = id -/ +#guard_msgs in +example : Nat.add (0 + 0) = id := by + sym_simp [Nat.zero_add] + trace_state + funext + simp + +/-- +trace: a : Nat +β✝ : Type +f : β✝ → Prop +h : HEq a = f +⊢ HEq a = f +-/ +#guard_msgs in +example (h : HEq a = f) : HEq (α := Nat) (0 + a) = f := by + sym_simp [Nat.zero_add] + trace_state + exact h + +/-- +trace: a b : Nat +f : Nat → Nat +h : f a = b +⊢ id f a = b +-/ +#guard_msgs in +example (f : Nat → Nat) (h : f a = b) : id f (0 + a) = b := by + sym_simp [Nat.zero_add] + trace_state + exact h + +def f (_ : α) {β : Type} (b : β) : β := b + +/-- +trace: a : Nat +g : Nat → Nat +⊢ f 0 g a = g a +-/ +#guard_msgs in +example (g : Nat → Nat) : f (0 + 0) g (0 + a) = g a := by + sym_simp [Nat.zero_add] + trace_state + rfl + +def f' (_ : α) (b : β) := b + +/-- +trace: a : Nat +g : Nat → Nat +⊢ f' 0 g a = g a +-/ +#guard_msgs in +example (g : Nat → Nat) : f' (0 + 0) g (0 + a) = g a := by + sym_simp [Nat.zero_add] + trace_state + rfl + +/-- +trace: a b : Nat +as : Array (Nat → Nat) +i : Nat +x✝ : i < as.size +h : 
as[i] a = b +⊢ as[i] a = b +-/ +#guard_msgs in +example (as : Array (Nat → Nat)) (i : Nat) (_ : i < as.size) (h : as[i] a = b) : as[0 + i] (0 + a) = b := by + sym_simp [Nat.zero_add] + trace_state + exact h + +/-- +trace: c a : Nat +g : Nat → Nat +h : ite (c > 0) a = g +⊢ ite (c > 0) a = g +-/ +#guard_msgs in +example (h : ite (c > 0) a = g) : ite (c > 0) (0 + a) = g := by + sym_simp [Nat.zero_add] + trace_state + exact h