diff --git a/src/Init/Sym/Lemmas.lean b/src/Init/Sym/Lemmas.lean index 9bab1d593d..60c04c14b2 100644 --- a/src/Init/Sym/Lemmas.lean +++ b/src/Init/Sym/Lemmas.lean @@ -15,6 +15,10 @@ namespace Lean.Sym theorem ne_self (a : α) : (a ≠ a) = False := by simp +theorem ite_cond_congr {α : Sort u} (c : Prop) {inst : Decidable c} (a b : α) + (c' : Prop) {inst' : Decidable c'} (h : c = c') : @ite α c inst a b = @ite α c' inst' a b := by + simp [*] + theorem Nat.lt_eq_true (a b : Nat) (h : decide (a < b) = true) : (a < b) = True := by simp_all theorem Int.lt_eq_true (a b : Int) (h : decide (a < b) = true) : (a < b) = True := by simp_all theorem Rat.lt_eq_true (a b : Rat) (h : decide (a < b) = true) : (a < b) = True := by simp_all diff --git a/src/Lean/Expr.lean b/src/Lean/Expr.lean index d9108e9720..a89c727952 100644 --- a/src/Lean/Expr.lean +++ b/src/Lean/Expr.lean @@ -2386,4 +2386,27 @@ def eagerReflBoolTrue : Expr := def eagerReflBoolFalse : Expr := mkApp2 (mkConst ``eagerReduce [0]) (mkApp3 (mkConst ``Eq [1]) (mkConst ``Bool) (mkConst ``Bool.false) (mkConst ``Bool.false)) reflBoolFalse +/-- +Replaces the head constant in a function application chain with a different constant. + +Given an expression that is either a constant or a function application chain, +replaces the head constant with `declName` while preserving all arguments and universe levels. + +**Examples**: +- `f.replaceFn g` → `g` (where `f` is a constant) +- `(f a b c).replaceFn g` → `g a b c` +- `(@f.{u, v} a b).replaceFn g` → `@g.{u, v} a b` + +**Panics**: If the expression is neither a constant nor a function application. + +**Use case**: Useful for substituting one function for another while maintaining +the same application structure, such as replacing a theorem with a related theorem +that has the same type and universe parameters. +-/ +def Expr.replaceFn (e : Expr) (declName : Name) : Expr := + match e with + | .app f a => mkApp (f.replaceFn declName) a + | .const _ us => mkConst declName us + | _ => panic! "function application or constant expected" + end Lean diff --git a/src/Lean/Meta/Sym/AlphaShareBuilder.lean b/src/Lean/Meta/Sym/AlphaShareBuilder.lean index 1ad4e2ca39..32acdb8c1e 100644 --- a/src/Lean/Meta/Sym/AlphaShareBuilder.lean +++ b/src/Lean/Meta/Sym/AlphaShareBuilder.lean @@ -177,4 +177,16 @@ def mkHaveS (x : Name) (t : Expr) (v : Expr) (b : Expr) : m Expr := do else mkLetS n newType newVal newBody nondep +def mkAppS₂ (f a₁ a₂ : Expr) : m Expr := do + mkAppS (← mkAppS f a₁) a₂ + +def mkAppS₃ (f a₁ a₂ a₃ : Expr) : m Expr := do + mkAppS (← mkAppS₂ f a₁ a₂) a₃ + +def mkAppS₄ (f a₁ a₂ a₃ a₄ : Expr) : m Expr := do + mkAppS (← mkAppS₃ f a₁ a₂ a₃) a₄ + +def mkAppS₅ (f a₁ a₂ a₃ a₄ a₅ : Expr) : m Expr := do + mkAppS (← mkAppS₄ f a₁ a₂ a₃ a₄) a₅ + end Lean.Meta.Sym.Internal diff --git a/src/Lean/Meta/Sym/Simp.lean b/src/Lean/Meta/Sym/Simp.lean index 3144ae84e6..1b3f11b926 100644 --- a/src/Lean/Meta/Sym/Simp.lean +++ b/src/Lean/Meta/Sym/Simp.lean @@ -20,3 +20,4 @@ public import Lean.Meta.Sym.Simp.Forall public import Lean.Meta.Sym.Simp.Debug public import Lean.Meta.Sym.Simp.EvalGround public import Lean.Meta.Sym.Simp.Discharger +public import Lean.Meta.Sym.Simp.ControlFlow diff --git a/src/Lean/Meta/Sym/Simp/App.lean b/src/Lean/Meta/Sym/Simp/App.lean index 3a311e2b97..b6a78e8464 100644 --- a/src/Lean/Meta/Sym/Simp/App.lean +++ b/src/Lean/Meta/Sym/Simp/App.lean @@ -70,7 +70,7 @@ Returns a proof using `congrFun` congrFun.{u, v} {α : Sort u} {β : α → Sort v} {f g : (x : α) → β x} (h : f = g) (a : α) : f a = g a ``` -/ -def mkCongrFun (e : Expr) (f a : Expr) (f' : Expr) (hf : Expr) (_ : e = .app f a) : SymM Result := do +def mkCongrFun (e : Expr) (f a : Expr) (f' : Expr) (hf : Expr) (_ : e = .app f a) (done := false) : SymM Result := do let .forallE x _ βx _ ← whnfD (← inferType f) | throwError "failed to build congruence proof, function expected{indentExpr f}" let α ← inferType a @@ -79,7 +79,7 @@ def mkCongrFun (e : Expr) (f a : Expr) (f' : Expr) (hf : Expr) (_ : e = .app f a let β := Lean.mkLambda x .default α βx let e' ← mkAppS f' a let h := mkApp6 (mkConst ``congrFun [u, v]) α β f f' hf a - return .step e' h + return .step e' h done /-- Handles simplification of over-applied function terms. @@ -129,6 +129,43 @@ public def simpOverApplied (e : Expr) (numArgs : Nat) (simpFn : Expr → SimpM R | _ => unreachable! visit e numArgs +/-- +Handles over-applied function expressions by simplifying only the base function and +propagating changes through extra arguments WITHOUT simplifying them. + +Unlike `simpOverApplied`, this function does not simplify the extra arguments themselves. +It only uses congruence (`mkCongrFun`) to propagate changes when the base function is simplified. + +**Algorithm**: +1. Peel off `numArgs` extra arguments from `e` +2. Apply `simpFn` to simplify the base function +3. If the base changed, propagate the change through each extra argument using `mkCongrFun` +4. Return `.rfl` if the base function was not simplified + +**Parameters**: +- `e`: The over-applied expression +- `numArgs`: Number of excess arguments to peel off +- `simpFn`: Strategy for simplifying the base function after peeling + +**Contrast with `simpOverApplied`**: +- `simpOverApplied`: Fully simplifies both base and extra arguments +- `propagateOverApplied`: Only simplifies base, appends extra arguments unchanged +-/ +public def propagateOverApplied (e : Expr) (numArgs : Nat) (simpFn : Expr → SimpM Result) : SimpM Result := do + let rec visit (e : Expr) (i : Nat) : SimpM Result := do + if i == 0 then + simpFn e + else + let i := i - 1 + match h : e with + | .app f a => + let r ← visit f i + match r with + | .rfl _ => return r + | .step f' hf done => mkCongrFun e f a f' hf h done + | _ => unreachable! + visit e numArgs + /-- Reduces `type` to weak head normal form and verifies it is a `forall` expression. If `type` is already a `forall`, returns it unchanged (avoiding unnecessary work). diff --git a/src/Lean/Meta/Sym/Simp/ControlFlow.lean b/src/Lean/Meta/Sym/Simp/ControlFlow.lean new file mode 100644 index 0000000000..0d2bf84c3d --- /dev/null +++ b/src/Lean/Meta/Sym/Simp/ControlFlow.lean @@ -0,0 +1,49 @@ +/- +Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +module +prelude +public import Lean.Meta.Sym.Simp.SimpM +import Lean.Meta.Sym.AlphaShareBuilder +import Lean.Meta.Sym.Simp.App +import Lean.Meta.SynthInstance +import Lean.Expr +import Init.Sym.Lemmas +namespace Lean.Meta.Sym.Simp +open Internal + +def simpIte : Simproc := fun e => do + let numArgs := e.getAppNumArgs + if numArgs < 5 then return .rfl (done := true) + propagateOverApplied e (numArgs - 5) fun e => do + let_expr ite _ c _ a b := e | return .rfl + match (← simp c) with + | .rfl _ => return .rfl (done := true) + | .step c' h _ => + if c'.isTrue then + return .step a <| mkApp (e.replaceFn ``ite_cond_eq_true) h + else if c'.isFalse then + return .step b <| mkApp (e.replaceFn ``ite_cond_eq_false) h + else + let .some inst' ← trySynthInstance (mkApp (mkConst ``Decidable) c') | return .rfl + let e' := e.getBoundedAppFn 4 + let e' ← mkAppS₄ e' c' inst' a b + let h' := mkApp3 (e.replaceFn ``Lean.Sym.ite_cond_congr) c' inst' h + return .step e' h' (done := true) + +/-- +Simplifies control-flow expressions such as `if-then-else` and `match` expressions. +It visits only the conditions and discriminants. +-/ +public def simpControl : Simproc := fun e => do + if !e.isApp then return .rfl + let .const declName _ := e.getAppFn | return .rfl + if declName == ``ite then + simpIte e + else + -- **TODO**: Add more cases + return .rfl + +end Lean.Meta.Sym.Simp diff --git a/src/Lean/Meta/Sym/Simp/EvalGround.lean b/src/Lean/Meta/Sym/Simp/EvalGround.lean index 913e418364..03d1cd8b0e 100644 --- a/src/Lean/Meta/Sym/Simp/EvalGround.lean +++ b/src/Lean/Meta/Sym/Simp/EvalGround.lean @@ -508,7 +508,8 @@ def evalGE (α : Expr) (a b : Expr) : SimpM Result := def evalEq (α : Expr) (a b : Expr) : SimpM Result := if isSameExpr a b then do let e ← share <| mkConst ``True - return .step e (mkApp2 (mkConst ``eq_self [1]) α a) (done := true) + let u ← getLevel α + return .step e (mkApp2 (mkConst ``eq_self [u]) α a) (done := true) else match_expr α with | Nat => evalBinPred getNatValue? (mkConst ``Nat.eq_eq_true) (mkConst ``Nat.eq_eq_false) (. = .) a b | Int => evalBinPred getIntValue? (mkConst ``Int.eq_eq_true) (mkConst ``Int.eq_eq_false) (. = .) a b @@ -528,7 +529,8 @@ def evalEq (α : Expr) (a b : Expr) : SimpM Result := def evalNe (α : Expr) (a b : Expr) : SimpM Result := if isSameExpr a b then do let e ← share <| mkConst ``False - return .step e (mkApp2 (mkConst ``ne_self [1]) α a) (done := true) + let u ← getLevel α + return .step e (mkApp2 (mkConst ``ne_self [u]) α a) (done := true) else match_expr α with | Nat => evalBinPred getNatValue? (mkConst ``Nat.ne_eq_true) (mkConst ``Nat.ne_eq_false) (. ≠ .) a b | Int => evalBinPred getIntValue? (mkConst ``Int.ne_eq_true) (mkConst ``Int.ne_eq_false) (. ≠ .) a b diff --git a/tests/lean/run/sym_simp_3.lean b/tests/lean/run/sym_simp_3.lean index 81c79fb7c9..5ca5fa776e 100644 --- a/tests/lean/run/sym_simp_3.lean +++ b/tests/lean/run/sym_simp_3.lean @@ -3,7 +3,10 @@ open Lean Meta Elab Tactic elab "sym_simp" "[" declNames:ident,* "]" : tactic => do let rewrite ← Sym.mkSimprocFor (← declNames.getElems.mapM fun s => realizeGlobalConstNoOverload s.raw) Sym.Simp.dischargeSimpSelf - let methods : Sym.Simp.Methods := { post := Sym.Simp.evalGround.andThen rewrite } + let methods : Sym.Simp.Methods := { + pre := Sym.Simp.simpControl + post := Sym.Simp.evalGround.andThen rewrite + } liftMetaTactic1 <| Sym.simpWith (Sym.simp · methods) example : (1-1) + x*1 + (2-1)*0 = x := by @@ -14,3 +17,25 @@ axiom fax : x > 10 → f x = 0 example : f 12 = 0 := by sym_simp [fax] + +example : (if true then a else b) = a := by + sym_simp [] + +example (f g : Nat → Nat) : (if a + 0 = a then f else g) a = f a := by + sym_simp [Nat.add_zero] + +example (f g : Nat → Nat → Nat) : (if a + 0 ≠ a then f else g) a (b + 0) = g a b := by + sym_simp [Nat.add_zero] + +/-- +trace: a b : Nat +f g : Nat → Nat → Nat +h : a = b +⊢ (if a ≠ b then id f else id (id g)) a (b + 0) = g a b +-/ +#guard_msgs in +example (f g : Nat → Nat → Nat) (h : a = b) : (if a + 0 ≠ b then id f else id (id g)) a (b + 0) = g a b := by + sym_simp [Nat.add_zero, id_eq] + trace_state -- `if-then-else` branches should not have been simplified + subst h + sym_simp [Nat.add_zero, id_eq]