diff --git a/src/Lean/Meta/Tactic/Simp/SimpTheorems.lean b/src/Lean/Meta/Tactic/Simp/SimpTheorems.lean index 143f226246..92df535ef2 100644 --- a/src/Lean/Meta/Tactic/Simp/SimpTheorems.lean +++ b/src/Lean/Meta/Tactic/Simp/SimpTheorems.lean @@ -422,7 +422,20 @@ def SimpTheorems.addDeclToUnfold (d : SimpTheorems) (declName : Name) : MetaM Si let mut d := d for eqn in eqns do d ← SimpTheorems.addConst d eqn - if hasSmartUnfoldingDecl (← getEnv) declName then + /- + Even if a function has equation theorems, + we also store it in the `toUnfold` set in the following two cases: + 1- It was defined by structural recursion and has a smart-unfolding associated declaration. + 2- It is non-recursive. + + Reason: `unfoldPartialApp := true` or conditional equations may not apply. + + Remark: In the future, we are planning to disable this + behavior unless `unfoldPartialApp := true`. + Moreover, users will have to use `f.eq_def` if they want to force the definition to be + unfolded. + -/ + if hasSmartUnfoldingDecl (← getEnv) declName || !(← isRecursiveDefinition declName) then d := d.addDeclToUnfoldCore declName return d else diff --git a/tests/lean/run/unfoldPartialRegression.lean b/tests/lean/run/unfoldPartialRegression.lean new file mode 100644 index 0000000000..4e0c0d3bde --- /dev/null +++ b/tests/lean/run/unfoldPartialRegression.lean @@ -0,0 +1,47 @@ +universe u + +class Zero (α : Type u) where + zero : α + +instance (priority := 300) Zero.toOfNat0 {α} [Zero α] : OfNat α (nat_lit 0) where + ofNat := ‹Zero α›.1 + +class One (α : Type u) where + one : α + +instance (priority := 300) One.toOfNat1 {α} [One α] : OfNat α (nat_lit 1) where + ofNat := ‹One α›.1 +instance (priority := 200) One.ofOfNat1 {α} [OfNat α (nat_lit 1)] : One α where + one := 1 + +@[match_pattern] def bit0 {α : Type u} [Add α] (a : α) : α := a + a + +@[match_pattern] def bit1 {α : Type u} [One α] [Add α] (a : α) : α := bit0 a + 1 + +class AddZeroClass (M : Type u) extends Zero M, Add M where + zero_add : ∀ a : M, 0 + a = a + add_zero : ∀ a : M, a + 0 = a + +open AddZeroClass + +theorem bit0_zero {M} [AddZeroClass M] : bit0 (0 : M) = 0 := + add_zero _ + +def bit (b : Bool) : Nat → Nat := + cond b bit1 bit0 + +-- This is `Nat.bit_mod_two` from `Mathlib.Data.Nat.Bitwise`. +-- Here it works fine: +example (a : Bool) (x : Nat) : + bit a x % 2 = if a then 1 else 0 := by + simp (config := { unfoldPartialApp := true }) only [bit, bit1, bit0, ← Nat.mul_two, Bool.cond_eq_ite] + split <;> simp [Nat.add_mod] + +-- Now prove one more theorem +theorem bit1_zero {M} [AddZeroClass M] [One M] : bit1 (0 : M) = 1 := by rw [bit1, bit0_zero, zero_add] + +-- Now try again: +example (a : Bool) (x : Nat) : + bit a x % 2 = if a then 1 else 0 := by + simp (config := { unfoldPartialApp := true }) only [bit, bit1, bit0, ← Nat.mul_two, Bool.cond_eq_ite] + split <;> simp [Nat.add_mod] -- fails