feat: add simprocs for cond and dependent if-then-else in Sym.simp (#12040)
This PR adds simprocs for simplifying `cond` and dependent `if-then-else` in `Sym.simp`.
This commit is contained in:
parent
c3726bdf05
commit
ea9c7cf2ae
3 changed files with 154 additions and 4 deletions
|
|
@ -19,6 +19,20 @@ theorem ite_cond_congr {α : Sort u} (c : Prop) {inst : Decidable c} (a b : α)
|
|||
(c' : Prop) {inst' : Decidable c'} (h : c = c') : @ite α c inst a b = @ite α c' inst' a b := by
|
||||
simp [*]
|
||||
|
||||
theorem dite_cond_congr {α : Sort u} (c : Prop) {inst : Decidable c} (a : c → α) (b : ¬ c → α)
|
||||
(c' : Prop) {inst' : Decidable c'} (h : c = c')
|
||||
: @dite α c inst a b = @dite α c' inst' (fun h' => a (h.mpr_prop h')) (fun h' => b (h.mpr_not h')) := by
|
||||
simp [*]
|
||||
|
||||
theorem cond_cond_eq_true {α : Sort u} (c : Bool) (a b : α) (h : c = true) : cond c a b = a := by
|
||||
simp [*]
|
||||
|
||||
theorem cond_cond_eq_false {α : Sort u} (c : Bool) (a b : α) (h : c = false) : cond c a b = b := by
|
||||
simp [*]
|
||||
|
||||
theorem cond_cond_congr {α : Sort u} (c : Bool) (a b : α) (c' : Bool) (h : c = c') : cond c a b = cond c' a b := by
|
||||
simp [*]
|
||||
|
||||
theorem Nat.lt_eq_true (a b : Nat) (h : decide (a < b) = true) : (a < b) = True := by simp_all
|
||||
theorem Int.lt_eq_true (a b : Int) (h : decide (a < b) = true) : (a < b) = True := by simp_all
|
||||
theorem Rat.lt_eq_true (a b : Rat) (h : decide (a < b) = true) : (a < b) = True := by simp_all
|
||||
|
|
|
|||
|
|
@ -7,10 +7,12 @@ module
|
|||
prelude
|
||||
public import Lean.Meta.Sym.Simp.SimpM
|
||||
import Lean.Meta.Sym.AlphaShareBuilder
|
||||
import Lean.Meta.Sym.InstantiateS
|
||||
import Lean.Meta.Sym.InferType
|
||||
import Lean.Meta.Sym.Simp.App
|
||||
import Lean.Meta.SynthInstance
|
||||
import Lean.Meta.WHNF
|
||||
import Lean.Meta.AppBuilder
|
||||
import Init.Sym.Lemmas
|
||||
namespace Lean.Meta.Sym.Simp
|
||||
open Internal
|
||||
|
|
@ -22,9 +24,15 @@ def simpIte : Simproc := fun e => do
|
|||
let numArgs := e.getAppNumArgs
|
||||
if numArgs < 5 then return .rfl (done := true)
|
||||
propagateOverApplied e (numArgs - 5) fun e => do
|
||||
let_expr ite _ c _ a b := e | return .rfl
|
||||
let_expr f@ite α c _ a b := e | return .rfl
|
||||
match (← simp c) with
|
||||
| .rfl _ => return .rfl (done := true)
|
||||
| .rfl _ =>
|
||||
if c.isTrue then
|
||||
return .step a <| mkApp3 (mkConst ``ite_true f.constLevels!) α a b
|
||||
else if c.isFalse then
|
||||
return .step b <| mkApp3 (mkConst ``ite_false f.constLevels!) α a b
|
||||
else
|
||||
return .rfl (done := true)
|
||||
| .step c' h _ =>
|
||||
if c'.isTrue then
|
||||
return .step a <| mkApp (e.replaceFn ``ite_cond_eq_true) h
|
||||
|
|
@ -32,9 +40,75 @@ def simpIte : Simproc := fun e => do
|
|||
return .step b <| mkApp (e.replaceFn ``ite_cond_eq_false) h
|
||||
else
|
||||
let .some inst' ← trySynthInstance (mkApp (mkConst ``Decidable) c') | return .rfl
|
||||
let inst' ← shareCommon inst'
|
||||
let e' := e.getBoundedAppFn 4
|
||||
let e' ← mkAppS₄ e' c' inst' a b
|
||||
let h' := mkApp3 (e.replaceFn ``Lean.Sym.ite_cond_congr) c' inst' h
|
||||
let h' := mkApp3 (e.replaceFn ``Sym.ite_cond_congr) c' inst' h
|
||||
return .step e' h' (done := true)
|
||||
|
||||
/--
|
||||
Simplifies a dependent `if-then-else` expression.
|
||||
-/
|
||||
def simpDIte : Simproc := fun e => do
|
||||
let numArgs := e.getAppNumArgs
|
||||
if numArgs < 5 then return .rfl (done := true)
|
||||
propagateOverApplied e (numArgs - 5) fun e => do
|
||||
let_expr f@dite α c _ a b := e | return .rfl
|
||||
match (← simp c) with
|
||||
| .rfl _ =>
|
||||
if c.isTrue then
|
||||
let a' ← share <| a.betaRev #[mkConst ``True.intro]
|
||||
return .step a' <| mkApp3 (mkConst ``dite_true f.constLevels!) α a b
|
||||
else if c.isFalse then
|
||||
let b' ← share <| b.betaRev #[mkConst ``not_false]
|
||||
return .step b' <| mkApp3 (mkConst ``dite_false f.constLevels!) α a b
|
||||
else
|
||||
return .rfl (done := true)
|
||||
| .step c' h _ =>
|
||||
if c'.isTrue then
|
||||
let h' ← shareCommon <| mkOfEqTrueCore c h
|
||||
let a ← share <| a.betaRev #[h']
|
||||
return .step a <| mkApp (e.replaceFn ``dite_cond_eq_true) h
|
||||
else if c'.isFalse then
|
||||
let h' ← shareCommon <| mkOfEqFalseCore c h
|
||||
let b ← share <| b.betaRev #[h']
|
||||
return .step b <| mkApp (e.replaceFn ``dite_cond_eq_false) h
|
||||
else
|
||||
let .some inst' ← trySynthInstance (mkApp (mkConst ``Decidable) c') | return .rfl
|
||||
let inst' ← shareCommon inst'
|
||||
let e' := e.getBoundedAppFn 4
|
||||
let h ← shareCommon h
|
||||
let a ← share <| mkLambda `h .default c' (a.betaRev #[mkApp4 (mkConst ``Eq.mpr_prop) c c' h (mkBVar 0)])
|
||||
let b ← share <| mkLambda `h .default (mkNot c') (b.betaRev #[mkApp4 (mkConst ``Eq.mpr_not) c c' h (mkBVar 0)])
|
||||
let e' ← mkAppS₄ e' c' inst' a b
|
||||
let h' := mkApp3 (e.replaceFn ``Sym.dite_cond_congr) c' inst' h
|
||||
return .step e' h' (done := true)
|
||||
|
||||
/--
|
||||
Simplifies a `cond` expression (aka Boolean `if-then-else`).
|
||||
-/
|
||||
def simpCond : Simproc := fun e => do
|
||||
let numArgs := e.getAppNumArgs
|
||||
if numArgs < 4 then return .rfl (done := true)
|
||||
propagateOverApplied e (numArgs - 4) fun e => do
|
||||
let_expr f@cond α c a b := e | return .rfl
|
||||
match (← simp c) with
|
||||
| .rfl _ =>
|
||||
if c.isConstOf ``true then
|
||||
return .step a <| mkApp3 (mkConst ``cond_true f.constLevels!) α a b
|
||||
else if c.isConstOf ``false then
|
||||
return .step b <| mkApp3 (mkConst ``cond_false f.constLevels!) α a b
|
||||
else
|
||||
return .rfl (done := true)
|
||||
| .step c' h _ =>
|
||||
if c'.isConstOf ``true then
|
||||
return .step a <| mkApp (e.replaceFn ``Sym.cond_cond_eq_true) h
|
||||
else if c'.isConstOf ``false then
|
||||
return .step b <| mkApp (e.replaceFn ``Sym.cond_cond_eq_false) h
|
||||
else
|
||||
let e' := e.getBoundedAppFn 3
|
||||
let e' ← mkAppS₃ e' c' a b
|
||||
let h' := mkApp2 (e.replaceFn ``Sym.cond_cond_congr) c' h
|
||||
return .step e' h' (done := true)
|
||||
|
||||
/--
|
||||
|
|
@ -62,8 +136,11 @@ public def simpControl : Simproc := fun e => do
|
|||
let .const declName _ := e.getAppFn | return .rfl
|
||||
if declName == ``ite then
|
||||
simpIte e
|
||||
else if declName == ``cond then
|
||||
simpCond e
|
||||
else if declName == ``dite then
|
||||
simpDIte e
|
||||
else
|
||||
-- **TODO**: Add more cases
|
||||
simpMatch declName e
|
||||
|
||||
end Lean.Meta.Sym.Simp
|
||||
|
|
|
|||
|
|
@ -21,6 +21,12 @@ example : f 12 = 0 := by
|
|||
example : (if true then a else b) = a := by
|
||||
sym_simp []
|
||||
|
||||
example : (if True then a else b) = a := by
|
||||
sym_simp []
|
||||
|
||||
example : (if False then a else b) = b := by
|
||||
sym_simp []
|
||||
|
||||
example (f g : Nat → Nat) : (if a + 0 = a then f else g) a = f a := by
|
||||
sym_simp [Nat.add_zero]
|
||||
|
||||
|
|
@ -110,3 +116,56 @@ example : (match Foo.mk3 c, Foo.mk2 b with | .mk1 _, _ => 1+0 | _, .mk2 _ => 2+1
|
|||
|
||||
example : (match (true, false, true) with | (false, _, _) => 1 | (_, false, _) => 2 | _ => 3) = 2 := by
|
||||
sym_simp []
|
||||
|
||||
example : (if _ : true then a else b) = a := by
|
||||
sym_simp []
|
||||
|
||||
example : (if _ : True then a else b) = a := by
|
||||
sym_simp []
|
||||
|
||||
example : (if _ : False then a else b) = b := by
|
||||
sym_simp []
|
||||
|
||||
example (f g : Nat → Nat) : (if _ : a + 0 = a then f else g) a = f a := by
|
||||
sym_simp [Nat.add_zero]
|
||||
|
||||
example (f g : Nat → Nat → Nat) : (if _ : a + 0 ≠ a then f else g) a (b + 0) = g a b := by
|
||||
sym_simp [Nat.add_zero]
|
||||
|
||||
/--
|
||||
trace: a b : Nat
|
||||
f g : Nat → Nat → Nat
|
||||
h : a = b
|
||||
⊢ (if h : a ≠ b then id f else id (id g)) a (b + 0) = g a b
|
||||
-/
|
||||
#guard_msgs in
|
||||
example (f g : Nat → Nat → Nat) (h : a = b) : (if _ : a + 0 ≠ b then id f else id (id g)) a (b + 0) = g a b := by
|
||||
sym_simp [Nat.add_zero, id_eq]
|
||||
trace_state -- `if-then-else` branches should not have been simplified
|
||||
subst h
|
||||
sym_simp [Nat.add_zero, id_eq]
|
||||
|
||||
example : (bif true then a else b) = a := by
|
||||
sym_simp []
|
||||
|
||||
example : (bif false then a else b) = b := by
|
||||
sym_simp []
|
||||
|
||||
example (f g : Nat → Nat) : (bif a + 0 == a then f else g) a = f a := by
|
||||
sym_simp [Nat.add_zero, beq_self_eq_true]
|
||||
|
||||
example (f g : Nat → Nat → Nat) : (bif a + 0 != a then f else g) a (b + 0) = g a b := by
|
||||
sym_simp [Nat.add_zero, bne_self_eq_false]
|
||||
|
||||
/--
|
||||
trace: a b : Nat
|
||||
f g : Nat → Nat → Nat
|
||||
h : a = b
|
||||
⊢ (bif a != b then id f else id (id g)) a (b + 0) = g a b
|
||||
-/
|
||||
#guard_msgs in
|
||||
example (f g : Nat → Nat → Nat) (h : a = b) : (bif a + 0 != b then id f else id (id g)) a (b + 0) = g a b := by
|
||||
sym_simp [Nat.add_zero, id_eq]
|
||||
trace_state -- `cond` branches should not have been simplified
|
||||
subst h
|
||||
sym_simp [Nat.add_zero, bne_self_eq_false, id_eq]
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue