feat: add simpControl simproc for if-then-else simplification (#12035)
This PR adds `simpControl`, a simproc that handles control-flow expressions such as `if-then-else`. It simplifies conditions while avoiding unnecessary work on branches that won't be taken. The key behavior of `simpControl`: - Simplifies the condition of `if-then-else` expressions - If the condition reduces to `True` or `False`, returns the appropriate branch, and continue simplifying. - If the condition simplifies to a new expression, rebuilds the `if-then-else` with the simplified condition (synthesizing a new `Decidable` instance), and mark it as "done". That is, simplifier main loop will not visit branches. - Does **not** visit branches unless the condition becomes `True` or `False` This is useful for symbolic simplification where we want to avoid wasting effort simplifying branches that may be eliminated after the condition is resolved. This PR also fixes a bug in `Sym/Simp/EvalGround.lean`, and adds some helper functions.
This commit is contained in:
parent
5457a227ba
commit
f63ddd67a2
8 changed files with 158 additions and 5 deletions
|
|
@ -15,6 +15,10 @@ namespace Lean.Sym
|
|||
|
||||
theorem ne_self (a : α) : (a ≠ a) = False := by simp
|
||||
|
||||
theorem ite_cond_congr {α : Sort u} (c : Prop) {inst : Decidable c} (a b : α)
|
||||
(c' : Prop) {inst' : Decidable c'} (h : c = c') : @ite α c inst a b = @ite α c' inst' a b := by
|
||||
simp [*]
|
||||
|
||||
theorem Nat.lt_eq_true (a b : Nat) (h : decide (a < b) = true) : (a < b) = True := by simp_all
|
||||
theorem Int.lt_eq_true (a b : Int) (h : decide (a < b) = true) : (a < b) = True := by simp_all
|
||||
theorem Rat.lt_eq_true (a b : Rat) (h : decide (a < b) = true) : (a < b) = True := by simp_all
|
||||
|
|
|
|||
|
|
@ -2386,4 +2386,27 @@ def eagerReflBoolTrue : Expr :=
|
|||
def eagerReflBoolFalse : Expr :=
|
||||
mkApp2 (mkConst ``eagerReduce [0]) (mkApp3 (mkConst ``Eq [1]) (mkConst ``Bool) (mkConst ``Bool.false) (mkConst ``Bool.false)) reflBoolFalse
|
||||
|
||||
/--
|
||||
Replaces the head constant in a function application chain with a different constant.
|
||||
|
||||
Given an expression that is either a constant or a function application chain,
|
||||
replaces the head constant with `declName` while preserving all arguments and universe levels.
|
||||
|
||||
**Examples**:
|
||||
- `f.replaceFn g` → `g` (where `f` is a constant)
|
||||
- `(f a b c).replaceFn g` → `g a b c`
|
||||
- `(@f.{u, v} a b).replaceFn g` → `@g.{u, v} a b`
|
||||
|
||||
**Panics**: If the expression is neither a constant nor a function application.
|
||||
|
||||
**Use case**: Useful for substituting one function for another while maintaining
|
||||
the same application structure, such as replacing a theorem with a related theorem
|
||||
that has the same type and universe parameters.
|
||||
-/
|
||||
def Expr.replaceFn (e : Expr) (declName : Name) : Expr :=
|
||||
match e with
|
||||
| .app f a => mkApp (f.replaceFn declName) a
|
||||
| .const _ us => mkConst declName us
|
||||
| _ => panic! "function application or constant expected"
|
||||
|
||||
end Lean
|
||||
|
|
|
|||
|
|
@ -177,4 +177,16 @@ def mkHaveS (x : Name) (t : Expr) (v : Expr) (b : Expr) : m Expr := do
|
|||
else
|
||||
mkLetS n newType newVal newBody nondep
|
||||
|
||||
def mkAppS₂ (f a₁ a₂ : Expr) : m Expr := do
|
||||
mkAppS (← mkAppS f a₁) a₂
|
||||
|
||||
def mkAppS₃ (f a₁ a₂ a₃ : Expr) : m Expr := do
|
||||
mkAppS (← mkAppS₂ f a₁ a₂) a₃
|
||||
|
||||
def mkAppS₄ (f a₁ a₂ a₃ a₄ : Expr) : m Expr := do
|
||||
mkAppS (← mkAppS₃ f a₁ a₂ a₃) a₄
|
||||
|
||||
def mkAppS₅ (f a₁ a₂ a₃ a₄ a₅ : Expr) : m Expr := do
|
||||
mkAppS (← mkAppS₄ f a₁ a₂ a₃ a₄) a₅
|
||||
|
||||
end Lean.Meta.Sym.Internal
|
||||
|
|
|
|||
|
|
@ -20,3 +20,4 @@ public import Lean.Meta.Sym.Simp.Forall
|
|||
public import Lean.Meta.Sym.Simp.Debug
|
||||
public import Lean.Meta.Sym.Simp.EvalGround
|
||||
public import Lean.Meta.Sym.Simp.Discharger
|
||||
public import Lean.Meta.Sym.Simp.ControlFlow
|
||||
|
|
|
|||
|
|
@ -70,7 +70,7 @@ Returns a proof using `congrFun`
|
|||
congrFun.{u, v} {α : Sort u} {β : α → Sort v} {f g : (x : α) → β x} (h : f = g) (a : α) : f a = g a
|
||||
```
|
||||
-/
|
||||
def mkCongrFun (e : Expr) (f a : Expr) (f' : Expr) (hf : Expr) (_ : e = .app f a) : SymM Result := do
|
||||
def mkCongrFun (e : Expr) (f a : Expr) (f' : Expr) (hf : Expr) (_ : e = .app f a) (done := false) : SymM Result := do
|
||||
let .forallE x _ βx _ ← whnfD (← inferType f)
|
||||
| throwError "failed to build congruence proof, function expected{indentExpr f}"
|
||||
let α ← inferType a
|
||||
|
|
@ -79,7 +79,7 @@ def mkCongrFun (e : Expr) (f a : Expr) (f' : Expr) (hf : Expr) (_ : e = .app f a
|
|||
let β := Lean.mkLambda x .default α βx
|
||||
let e' ← mkAppS f' a
|
||||
let h := mkApp6 (mkConst ``congrFun [u, v]) α β f f' hf a
|
||||
return .step e' h
|
||||
return .step e' h done
|
||||
|
||||
/--
|
||||
Handles simplification of over-applied function terms.
|
||||
|
|
@ -129,6 +129,43 @@ public def simpOverApplied (e : Expr) (numArgs : Nat) (simpFn : Expr → SimpM R
|
|||
| _ => unreachable!
|
||||
visit e numArgs
|
||||
|
||||
/--
|
||||
Handles over-applied function expressions by simplifying only the base function and
|
||||
propagating changes through extra arguments WITHOUT simplifying them.
|
||||
|
||||
Unlike `simpOverApplied`, this function does not simplify the extra arguments themselves.
|
||||
It only uses congruence (`mkCongrFun`) to propagate changes when the base function is simplified.
|
||||
|
||||
**Algorithm**:
|
||||
1. Peel off `numArgs` extra arguments from `e`
|
||||
2. Apply `simpFn` to simplify the base function
|
||||
3. If the base changed, propagate the change through each extra argument using `mkCongrFun`
|
||||
4. Return `.rfl` if the base function was not simplified
|
||||
|
||||
**Parameters**:
|
||||
- `e`: The over-applied expression
|
||||
- `numArgs`: Number of excess arguments to peel off
|
||||
- `simpFn`: Strategy for simplifying the base function after peeling
|
||||
|
||||
**Contrast with `simpOverApplied`**:
|
||||
- `simpOverApplied`: Fully simplifies both base and extra arguments
|
||||
- `propagateOverApplied`: Only simplifies base, appends extra arguments unchanged
|
||||
-/
|
||||
public def propagateOverApplied (e : Expr) (numArgs : Nat) (simpFn : Expr → SimpM Result) : SimpM Result := do
|
||||
let rec visit (e : Expr) (i : Nat) : SimpM Result := do
|
||||
if i == 0 then
|
||||
simpFn e
|
||||
else
|
||||
let i := i - 1
|
||||
match h : e with
|
||||
| .app f a =>
|
||||
let r ← visit f i
|
||||
match r with
|
||||
| .rfl _ => return r
|
||||
| .step f' hf done => mkCongrFun e f a f' hf h done
|
||||
| _ => unreachable!
|
||||
visit e numArgs
|
||||
|
||||
/--
|
||||
Reduces `type` to weak head normal form and verifies it is a `forall` expression.
|
||||
If `type` is already a `forall`, returns it unchanged (avoiding unnecessary work).
|
||||
|
|
|
|||
49
src/Lean/Meta/Sym/Simp/ControlFlow.lean
Normal file
49
src/Lean/Meta/Sym/Simp/ControlFlow.lean
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
/-
|
||||
Copyright (c) 2026 Amazon.com, Inc. or its affiliates. All Rights Reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Leonardo de Moura
|
||||
-/
|
||||
module
|
||||
prelude
|
||||
public import Lean.Meta.Sym.Simp.SimpM
|
||||
import Lean.Meta.Sym.AlphaShareBuilder
|
||||
import Lean.Meta.Sym.Simp.App
|
||||
import Lean.Meta.SynthInstance
|
||||
import Lean.Expr
|
||||
import Init.Sym.Lemmas
|
||||
namespace Lean.Meta.Sym.Simp
|
||||
open Internal
|
||||
|
||||
def simpIte : Simproc := fun e => do
|
||||
let numArgs := e.getAppNumArgs
|
||||
if numArgs < 5 then return .rfl (done := true)
|
||||
propagateOverApplied e (numArgs - 5) fun e => do
|
||||
let_expr ite _ c _ a b := e | return .rfl
|
||||
match (← simp c) with
|
||||
| .rfl _ => return .rfl (done := true)
|
||||
| .step c' h _ =>
|
||||
if c'.isTrue then
|
||||
return .step a <| mkApp (e.replaceFn ``ite_cond_eq_true) h
|
||||
else if c'.isFalse then
|
||||
return .step b <| mkApp (e.replaceFn ``ite_cond_eq_false) h
|
||||
else
|
||||
let .some inst' ← trySynthInstance (mkApp (mkConst ``Decidable) c') | return .rfl
|
||||
let e' := e.getBoundedAppFn 4
|
||||
let e' ← mkAppS₄ e' c' inst' a b
|
||||
let h' := mkApp3 (e.replaceFn ``Lean.Sym.ite_cond_congr) c' inst' h
|
||||
return .step e' h' (done := true)
|
||||
|
||||
/--
|
||||
Simplifies control-flow expressions such as `if-then-else` and `match` expressions.
|
||||
It visits only the conditions and discriminants.
|
||||
-/
|
||||
public def simpControl : Simproc := fun e => do
|
||||
if !e.isApp then return .rfl
|
||||
let .const declName _ := e.getAppFn | return .rfl
|
||||
if declName == ``ite then
|
||||
simpIte e
|
||||
else
|
||||
-- **TODO**: Add more cases
|
||||
return .rfl
|
||||
|
||||
end Lean.Meta.Sym.Simp
|
||||
|
|
@ -508,7 +508,8 @@ def evalGE (α : Expr) (a b : Expr) : SimpM Result :=
|
|||
def evalEq (α : Expr) (a b : Expr) : SimpM Result :=
|
||||
if isSameExpr a b then do
|
||||
let e ← share <| mkConst ``True
|
||||
return .step e (mkApp2 (mkConst ``eq_self [1]) α a) (done := true)
|
||||
let u ← getLevel α
|
||||
return .step e (mkApp2 (mkConst ``eq_self [u]) α a) (done := true)
|
||||
else match_expr α with
|
||||
| Nat => evalBinPred getNatValue? (mkConst ``Nat.eq_eq_true) (mkConst ``Nat.eq_eq_false) (. = .) a b
|
||||
| Int => evalBinPred getIntValue? (mkConst ``Int.eq_eq_true) (mkConst ``Int.eq_eq_false) (. = .) a b
|
||||
|
|
@ -528,7 +529,8 @@ def evalEq (α : Expr) (a b : Expr) : SimpM Result :=
|
|||
def evalNe (α : Expr) (a b : Expr) : SimpM Result :=
|
||||
if isSameExpr a b then do
|
||||
let e ← share <| mkConst ``False
|
||||
return .step e (mkApp2 (mkConst ``ne_self [1]) α a) (done := true)
|
||||
let u ← getLevel α
|
||||
return .step e (mkApp2 (mkConst ``ne_self [u]) α a) (done := true)
|
||||
else match_expr α with
|
||||
| Nat => evalBinPred getNatValue? (mkConst ``Nat.ne_eq_true) (mkConst ``Nat.ne_eq_false) (. ≠ .) a b
|
||||
| Int => evalBinPred getIntValue? (mkConst ``Int.ne_eq_true) (mkConst ``Int.ne_eq_false) (. ≠ .) a b
|
||||
|
|
|
|||
|
|
@ -3,7 +3,10 @@ open Lean Meta Elab Tactic
|
|||
|
||||
elab "sym_simp" "[" declNames:ident,* "]" : tactic => do
|
||||
let rewrite ← Sym.mkSimprocFor (← declNames.getElems.mapM fun s => realizeGlobalConstNoOverload s.raw) Sym.Simp.dischargeSimpSelf
|
||||
let methods : Sym.Simp.Methods := { post := Sym.Simp.evalGround.andThen rewrite }
|
||||
let methods : Sym.Simp.Methods := {
|
||||
pre := Sym.Simp.simpControl
|
||||
post := Sym.Simp.evalGround.andThen rewrite
|
||||
}
|
||||
liftMetaTactic1 <| Sym.simpWith (Sym.simp · methods)
|
||||
|
||||
example : (1-1) + x*1 + (2-1)*0 = x := by
|
||||
|
|
@ -14,3 +17,25 @@ axiom fax : x > 10 → f x = 0
|
|||
|
||||
example : f 12 = 0 := by
|
||||
sym_simp [fax]
|
||||
|
||||
example : (if true then a else b) = a := by
|
||||
sym_simp []
|
||||
|
||||
example (f g : Nat → Nat) : (if a + 0 = a then f else g) a = f a := by
|
||||
sym_simp [Nat.add_zero]
|
||||
|
||||
example (f g : Nat → Nat → Nat) : (if a + 0 ≠ a then f else g) a (b + 0) = g a b := by
|
||||
sym_simp [Nat.add_zero]
|
||||
|
||||
/--
|
||||
trace: a b : Nat
|
||||
f g : Nat → Nat → Nat
|
||||
h : a = b
|
||||
⊢ (if a ≠ b then id f else id (id g)) a (b + 0) = g a b
|
||||
-/
|
||||
#guard_msgs in
|
||||
example (f g : Nat → Nat → Nat) (h : a = b) : (if a + 0 ≠ b then id f else id (id g)) a (b + 0) = g a b := by
|
||||
sym_simp [Nat.add_zero, id_eq]
|
||||
trace_state -- `if-then-else` branches should not have been simplified
|
||||
subst h
|
||||
sym_simp [Nat.add_zero, id_eq]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue