feat: add simprocs for cond and dependent if-then-else in Sym.simp (#12040)

This PR adds simprocs for simplifying `cond` and dependent
`if-then-else` in `Sym.simp`.
This commit is contained in:
Leonardo de Moura 2026-01-18 17:35:09 -08:00 committed by GitHub
parent c3726bdf05
commit ea9c7cf2ae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 154 additions and 4 deletions

View file

@ -19,6 +19,20 @@ theorem ite_cond_congr {α : Sort u} (c : Prop) {inst : Decidable c} (a b : α)
(c' : Prop) {inst' : Decidable c'} (h : c = c') : @ite α c inst a b = @ite α c' inst' a b := by
simp [*]
theorem dite_cond_congr {α : Sort u} (c : Prop) {inst : Decidable c} (a : c → α) (b : ¬ c → α)
(c' : Prop) {inst' : Decidable c'} (h : c = c')
: @dite α c inst a b = @dite α c' inst' (fun h' => a (h.mpr_prop h')) (fun h' => b (h.mpr_not h')) := by
simp [*]
theorem cond_cond_eq_true {α : Sort u} (c : Bool) (a b : α) (h : c = true) : cond c a b = a := by
simp [*]
theorem cond_cond_eq_false {α : Sort u} (c : Bool) (a b : α) (h : c = false) : cond c a b = b := by
simp [*]
theorem cond_cond_congr {α : Sort u} (c : Bool) (a b : α) (c' : Bool) (h : c = c') : cond c a b = cond c' a b := by
simp [*]
theorem Nat.lt_eq_true (a b : Nat) (h : decide (a < b) = true) : (a < b) = True := by simp_all
theorem Int.lt_eq_true (a b : Int) (h : decide (a < b) = true) : (a < b) = True := by simp_all
theorem Rat.lt_eq_true (a b : Rat) (h : decide (a < b) = true) : (a < b) = True := by simp_all

View file

@ -7,10 +7,12 @@ module
prelude
public import Lean.Meta.Sym.Simp.SimpM
import Lean.Meta.Sym.AlphaShareBuilder
import Lean.Meta.Sym.InstantiateS
import Lean.Meta.Sym.InferType
import Lean.Meta.Sym.Simp.App
import Lean.Meta.SynthInstance
import Lean.Meta.WHNF
import Lean.Meta.AppBuilder
import Init.Sym.Lemmas
namespace Lean.Meta.Sym.Simp
open Internal
@ -22,9 +24,15 @@ def simpIte : Simproc := fun e => do
let numArgs := e.getAppNumArgs
if numArgs < 5 then return .rfl (done := true)
propagateOverApplied e (numArgs - 5) fun e => do
let_expr ite _ c _ a b := e | return .rfl
let_expr f@ite α c _ a b := e | return .rfl
match (← simp c) with
| .rfl _ => return .rfl (done := true)
| .rfl _ =>
if c.isTrue then
return .step a <| mkApp3 (mkConst ``ite_true f.constLevels!) α a b
else if c.isFalse then
return .step b <| mkApp3 (mkConst ``ite_false f.constLevels!) α a b
else
return .rfl (done := true)
| .step c' h _ =>
if c'.isTrue then
return .step a <| mkApp (e.replaceFn ``ite_cond_eq_true) h
@ -32,9 +40,75 @@ def simpIte : Simproc := fun e => do
return .step b <| mkApp (e.replaceFn ``ite_cond_eq_false) h
else
let .some inst' ← trySynthInstance (mkApp (mkConst ``Decidable) c') | return .rfl
let inst' ← shareCommon inst'
let e' := e.getBoundedAppFn 4
let e' ← mkAppS₄ e' c' inst' a b
let h' := mkApp3 (e.replaceFn ``Lean.Sym.ite_cond_congr) c' inst' h
let h' := mkApp3 (e.replaceFn ``Sym.ite_cond_congr) c' inst' h
return .step e' h' (done := true)
/--
Simplifies a dependent `if-then-else` expression.
-/
def simpDIte : Simproc := fun e => do
let numArgs := e.getAppNumArgs
if numArgs < 5 then return .rfl (done := true)
propagateOverApplied e (numArgs - 5) fun e => do
let_expr f@dite α c _ a b := e | return .rfl
match (← simp c) with
| .rfl _ =>
if c.isTrue then
let a' ← share <| a.betaRev #[mkConst ``True.intro]
return .step a' <| mkApp3 (mkConst ``dite_true f.constLevels!) α a b
else if c.isFalse then
let b' ← share <| b.betaRev #[mkConst ``not_false]
return .step b' <| mkApp3 (mkConst ``dite_false f.constLevels!) α a b
else
return .rfl (done := true)
| .step c' h _ =>
if c'.isTrue then
let h' ← shareCommon <| mkOfEqTrueCore c h
let a ← share <| a.betaRev #[h']
return .step a <| mkApp (e.replaceFn ``dite_cond_eq_true) h
else if c'.isFalse then
let h' ← shareCommon <| mkOfEqFalseCore c h
let b ← share <| b.betaRev #[h']
return .step b <| mkApp (e.replaceFn ``dite_cond_eq_false) h
else
let .some inst' ← trySynthInstance (mkApp (mkConst ``Decidable) c') | return .rfl
let inst' ← shareCommon inst'
let e' := e.getBoundedAppFn 4
let h ← shareCommon h
let a ← share <| mkLambda `h .default c' (a.betaRev #[mkApp4 (mkConst ``Eq.mpr_prop) c c' h (mkBVar 0)])
let b ← share <| mkLambda `h .default (mkNot c') (b.betaRev #[mkApp4 (mkConst ``Eq.mpr_not) c c' h (mkBVar 0)])
let e' ← mkAppS₄ e' c' inst' a b
let h' := mkApp3 (e.replaceFn ``Sym.dite_cond_congr) c' inst' h
return .step e' h' (done := true)
/--
Simplifies a `cond` expression (aka Boolean `if-then-else`).
-/
def simpCond : Simproc := fun e => do
let numArgs := e.getAppNumArgs
if numArgs < 4 then return .rfl (done := true)
propagateOverApplied e (numArgs - 4) fun e => do
let_expr f@cond α c a b := e | return .rfl
match (← simp c) with
| .rfl _ =>
if c.isConstOf ``true then
return .step a <| mkApp3 (mkConst ``cond_true f.constLevels!) α a b
else if c.isConstOf ``false then
return .step b <| mkApp3 (mkConst ``cond_false f.constLevels!) α a b
else
return .rfl (done := true)
| .step c' h _ =>
if c'.isConstOf ``true then
return .step a <| mkApp (e.replaceFn ``Sym.cond_cond_eq_true) h
else if c'.isConstOf ``false then
return .step b <| mkApp (e.replaceFn ``Sym.cond_cond_eq_false) h
else
let e' := e.getBoundedAppFn 3
let e' ← mkAppS₃ e' c' a b
let h' := mkApp2 (e.replaceFn ``Sym.cond_cond_congr) c' h
return .step e' h' (done := true)
/--
@ -62,8 +136,11 @@ public def simpControl : Simproc := fun e => do
let .const declName _ := e.getAppFn | return .rfl
if declName == ``ite then
simpIte e
else if declName == ``cond then
simpCond e
else if declName == ``dite then
simpDIte e
else
-- **TODO**: Add more cases
simpMatch declName e
end Lean.Meta.Sym.Simp

View file

@ -21,6 +21,12 @@ example : f 12 = 0 := by
example : (if true then a else b) = a := by
sym_simp []
example : (if True then a else b) = a := by
sym_simp []
example : (if False then a else b) = b := by
sym_simp []
example (f g : Nat → Nat) : (if a + 0 = a then f else g) a = f a := by
sym_simp [Nat.add_zero]
@ -110,3 +116,56 @@ example : (match Foo.mk3 c, Foo.mk2 b with | .mk1 _, _ => 1+0 | _, .mk2 _ => 2+1
example : (match (true, false, true) with | (false, _, _) => 1 | (_, false, _) => 2 | _ => 3) = 2 := by
sym_simp []
example : (if _ : true then a else b) = a := by
sym_simp []
example : (if _ : True then a else b) = a := by
sym_simp []
example : (if _ : False then a else b) = b := by
sym_simp []
example (f g : Nat → Nat) : (if _ : a + 0 = a then f else g) a = f a := by
sym_simp [Nat.add_zero]
example (f g : Nat → Nat → Nat) : (if _ : a + 0 ≠ a then f else g) a (b + 0) = g a b := by
sym_simp [Nat.add_zero]
/--
trace: a b : Nat
f g : Nat → Nat → Nat
h : a = b
⊢ (if h : a ≠ b then id f else id (id g)) a (b + 0) = g a b
-/
#guard_msgs in
example (f g : Nat → Nat → Nat) (h : a = b) : (if _ : a + 0 ≠ b then id f else id (id g)) a (b + 0) = g a b := by
sym_simp [Nat.add_zero, id_eq]
trace_state -- `if-then-else` branches should not have been simplified
subst h
sym_simp [Nat.add_zero, id_eq]
example : (bif true then a else b) = a := by
sym_simp []
example : (bif false then a else b) = b := by
sym_simp []
example (f g : Nat → Nat) : (bif a + 0 == a then f else g) a = f a := by
sym_simp [Nat.add_zero, beq_self_eq_true]
example (f g : Nat → Nat → Nat) : (bif a + 0 != a then f else g) a (b + 0) = g a b := by
sym_simp [Nat.add_zero, bne_self_eq_false]
/--
trace: a b : Nat
f g : Nat → Nat → Nat
h : a = b
⊢ (bif a != b then id f else id (id g)) a (b + 0) = g a b
-/
#guard_msgs in
example (f g : Nat → Nat → Nat) (h : a = b) : (bif a + 0 != b then id f else id (id g)) a (b + 0) = g a b := by
sym_simp [Nat.add_zero, id_eq]
trace_state -- `cond` branches should not have been simplified
subst h
sym_simp [Nat.add_zero, bne_self_eq_false, id_eq]