diff --git a/src/Init/Sym/Lemmas.lean b/src/Init/Sym/Lemmas.lean index 60c04c14b2..6b568b0c7f 100644 --- a/src/Init/Sym/Lemmas.lean +++ b/src/Init/Sym/Lemmas.lean @@ -19,6 +19,20 @@ theorem ite_cond_congr {α : Sort u} (c : Prop) {inst : Decidable c} (a b : α) (c' : Prop) {inst' : Decidable c'} (h : c = c') : @ite α c inst a b = @ite α c' inst' a b := by simp [*] +theorem dite_cond_congr {α : Sort u} (c : Prop) {inst : Decidable c} (a : c → α) (b : ¬ c → α) + (c' : Prop) {inst' : Decidable c'} (h : c = c') + : @dite α c inst a b = @dite α c' inst' (fun h' => a (h.mpr_prop h')) (fun h' => b (h.mpr_not h')) := by + simp [*] + +theorem cond_cond_eq_true {α : Sort u} (c : Bool) (a b : α) (h : c = true) : cond c a b = a := by + simp [*] + +theorem cond_cond_eq_false {α : Sort u} (c : Bool) (a b : α) (h : c = false) : cond c a b = b := by + simp [*] + +theorem cond_cond_congr {α : Sort u} (c : Bool) (a b : α) (c' : Bool) (h : c = c') : cond c a b = cond c' a b := by + simp [*] + theorem Nat.lt_eq_true (a b : Nat) (h : decide (a < b) = true) : (a < b) = True := by simp_all theorem Int.lt_eq_true (a b : Int) (h : decide (a < b) = true) : (a < b) = True := by simp_all theorem Rat.lt_eq_true (a b : Rat) (h : decide (a < b) = true) : (a < b) = True := by simp_all diff --git a/src/Lean/Meta/Sym/Simp/ControlFlow.lean b/src/Lean/Meta/Sym/Simp/ControlFlow.lean index d7a97fab5e..51095c2d19 100644 --- a/src/Lean/Meta/Sym/Simp/ControlFlow.lean +++ b/src/Lean/Meta/Sym/Simp/ControlFlow.lean @@ -7,10 +7,12 @@ module prelude public import Lean.Meta.Sym.Simp.SimpM import Lean.Meta.Sym.AlphaShareBuilder +import Lean.Meta.Sym.InstantiateS import Lean.Meta.Sym.InferType import Lean.Meta.Sym.Simp.App import Lean.Meta.SynthInstance import Lean.Meta.WHNF +import Lean.Meta.AppBuilder import Init.Sym.Lemmas namespace Lean.Meta.Sym.Simp open Internal @@ -22,9 +24,15 @@ def simpIte : Simproc := fun e => do let numArgs := e.getAppNumArgs if numArgs < 5 then return .rfl (done := true) propagateOverApplied e (numArgs - 5) fun e => do - let_expr ite _ c _ a b := e | return .rfl + let_expr f@ite α c _ a b := e | return .rfl match (← simp c) with - | .rfl _ => return .rfl (done := true) + | .rfl _ => + if c.isTrue then + return .step a <| mkApp3 (mkConst ``ite_true f.constLevels!) α a b + else if c.isFalse then + return .step b <| mkApp3 (mkConst ``ite_false f.constLevels!) α a b + else + return .rfl (done := true) | .step c' h _ => if c'.isTrue then return .step a <| mkApp (e.replaceFn ``ite_cond_eq_true) h @@ -32,9 +40,75 @@ def simpIte : Simproc := fun e => do return .step b <| mkApp (e.replaceFn ``ite_cond_eq_false) h else let .some inst' ← trySynthInstance (mkApp (mkConst ``Decidable) c') | return .rfl + let inst' ← shareCommon inst' let e' := e.getBoundedAppFn 4 let e' ← mkAppS₄ e' c' inst' a b - let h' := mkApp3 (e.replaceFn ``Lean.Sym.ite_cond_congr) c' inst' h + let h' := mkApp3 (e.replaceFn ``Sym.ite_cond_congr) c' inst' h + return .step e' h' (done := true) + +/-- +Simplifies a dependent `if-then-else` expression. +-/ +def simpDIte : Simproc := fun e => do + let numArgs := e.getAppNumArgs + if numArgs < 5 then return .rfl (done := true) + propagateOverApplied e (numArgs - 5) fun e => do + let_expr f@dite α c _ a b := e | return .rfl + match (← simp c) with + | .rfl _ => + if c.isTrue then + let a' ← share <| a.betaRev #[mkConst ``True.intro] + return .step a' <| mkApp3 (mkConst ``dite_true f.constLevels!) α a b + else if c.isFalse then + let b' ← share <| b.betaRev #[mkConst ``not_false] + return .step b' <| mkApp3 (mkConst ``dite_false f.constLevels!) α a b + else + return .rfl (done := true) + | .step c' h _ => + if c'.isTrue then + let h' ← shareCommon <| mkOfEqTrueCore c h + let a ← share <| a.betaRev #[h'] + return .step a <| mkApp (e.replaceFn ``dite_cond_eq_true) h + else if c'.isFalse then + let h' ← shareCommon <| mkOfEqFalseCore c h + let b ← share <| b.betaRev #[h'] + return .step b <| mkApp (e.replaceFn ``dite_cond_eq_false) h + else + let .some inst' ← trySynthInstance (mkApp (mkConst ``Decidable) c') | return .rfl + let inst' ← shareCommon inst' + let e' := e.getBoundedAppFn 4 + let h ← shareCommon h + let a ← share <| mkLambda `h .default c' (a.betaRev #[mkApp4 (mkConst ``Eq.mpr_prop) c c' h (mkBVar 0)]) + let b ← share <| mkLambda `h .default (mkNot c') (b.betaRev #[mkApp4 (mkConst ``Eq.mpr_not) c c' h (mkBVar 0)]) + let e' ← mkAppS₄ e' c' inst' a b + let h' := mkApp3 (e.replaceFn ``Sym.dite_cond_congr) c' inst' h + return .step e' h' (done := true) + +/-- +Simplifies a `cond` expression (aka Boolean `if-then-else`). +-/ +def simpCond : Simproc := fun e => do + let numArgs := e.getAppNumArgs + if numArgs < 4 then return .rfl (done := true) + propagateOverApplied e (numArgs - 4) fun e => do + let_expr f@cond α c a b := e | return .rfl + match (← simp c) with + | .rfl _ => + if c.isConstOf ``true then + return .step a <| mkApp3 (mkConst ``cond_true f.constLevels!) α a b + else if c.isConstOf ``false then + return .step b <| mkApp3 (mkConst ``cond_false f.constLevels!) α a b + else + return .rfl (done := true) + | .step c' h _ => + if c'.isConstOf ``true then + return .step a <| mkApp (e.replaceFn ``Sym.cond_cond_eq_true) h + else if c'.isConstOf ``false then + return .step b <| mkApp (e.replaceFn ``Sym.cond_cond_eq_false) h + else + let e' := e.getBoundedAppFn 3 + let e' ← mkAppS₃ e' c' a b + let h' := mkApp2 (e.replaceFn ``Sym.cond_cond_congr) c' h return .step e' h' (done := true) /-- @@ -62,8 +136,11 @@ public def simpControl : Simproc := fun e => do let .const declName _ := e.getAppFn | return .rfl if declName == ``ite then simpIte e + else if declName == ``cond then + simpCond e + else if declName == ``dite then + simpDIte e else - -- **TODO**: Add more cases simpMatch declName e end Lean.Meta.Sym.Simp diff --git a/tests/lean/run/sym_simp_3.lean b/tests/lean/run/sym_simp_3.lean index f3be400a5f..34e6eb4564 100644 --- a/tests/lean/run/sym_simp_3.lean +++ b/tests/lean/run/sym_simp_3.lean @@ -21,6 +21,12 @@ example : f 12 = 0 := by example : (if true then a else b) = a := by sym_simp [] +example : (if True then a else b) = a := by + sym_simp [] + +example : (if False then a else b) = b := by + sym_simp [] + example (f g : Nat → Nat) : (if a + 0 = a then f else g) a = f a := by sym_simp [Nat.add_zero] @@ -110,3 +116,56 @@ example : (match Foo.mk3 c, Foo.mk2 b with | .mk1 _, _ => 1+0 | _, .mk2 _ => 2+1 example : (match (true, false, true) with | (false, _, _) => 1 | (_, false, _) => 2 | _ => 3) = 2 := by sym_simp [] + +example : (if _ : true then a else b) = a := by + sym_simp [] + +example : (if _ : True then a else b) = a := by + sym_simp [] + +example : (if _ : False then a else b) = b := by + sym_simp [] + +example (f g : Nat → Nat) : (if _ : a + 0 = a then f else g) a = f a := by + sym_simp [Nat.add_zero] + +example (f g : Nat → Nat → Nat) : (if _ : a + 0 ≠ a then f else g) a (b + 0) = g a b := by + sym_simp [Nat.add_zero] + +/-- +trace: a b : Nat +f g : Nat → Nat → Nat +h : a = b +⊢ (if h : a ≠ b then id f else id (id g)) a (b + 0) = g a b +-/ +#guard_msgs in +example (f g : Nat → Nat → Nat) (h : a = b) : (if _ : a + 0 ≠ b then id f else id (id g)) a (b + 0) = g a b := by + sym_simp [Nat.add_zero, id_eq] + trace_state -- `if-then-else` branches should not have been simplified + subst h + sym_simp [Nat.add_zero, id_eq] + +example : (bif true then a else b) = a := by + sym_simp [] + +example : (bif false then a else b) = b := by + sym_simp [] + +example (f g : Nat → Nat) : (bif a + 0 == a then f else g) a = f a := by + sym_simp [Nat.add_zero, beq_self_eq_true] + +example (f g : Nat → Nat → Nat) : (bif a + 0 != a then f else g) a (b + 0) = g a b := by + sym_simp [Nat.add_zero, bne_self_eq_false] + +/-- +trace: a b : Nat +f g : Nat → Nat → Nat +h : a = b +⊢ (bif a != b then id f else id (id g)) a (b + 0) = g a b +-/ +#guard_msgs in +example (f g : Nat → Nat → Nat) (h : a = b) : (bif a + 0 != b then id f else id (id g)) a (b + 0) = g a b := by + sym_simp [Nat.add_zero, id_eq] + trace_state -- `cond` branches should not have been simplified + subst h + sym_simp [Nat.add_zero, bne_self_eq_false, id_eq]