From a63bafcc5c4958d7482d5145ae199a333f3d263f Mon Sep 17 00:00:00 2001 From: Mario Carneiro Date: Thu, 22 Jun 2017 18:51:25 -0400 Subject: [PATCH] refactor(init/data/nat/bitwise): change definitions to avoid WF The type-correctness of binary_rec_eq (the statement, not the proof) depends on unfolding the embedded well-founded definition of mod. This definition avoids it by using two simpler functions bodd and div2 that reduce well in the kernel. --- library/init/data/nat/bitwise.lean | 150 ++++++++++++++++++++--------- 1 file changed, 103 insertions(+), 47 deletions(-) diff --git a/library/init/data/nat/bitwise.lean b/library/init/data/nat/bitwise.lean index 886c3e7a32..2a03148c0b 100644 --- a/library/init/data/nat/bitwise.lean +++ b/library/init/data/nat/bitwise.lean @@ -11,15 +11,78 @@ universe u namespace nat + def bodd_div2 : ℕ → bool × ℕ + | 0 := (ff, 0) + | (succ n) := + match bodd_div2 n with + | (ff, m) := (tt, m) + | (tt, m) := (ff, succ m) + end + + def div2 (n : ℕ) : ℕ := (bodd_div2 n).2 + + def bodd (n : ℕ) : bool := (bodd_div2 n).1 + + @[simp] lemma bodd_zero : bodd 0 = ff := rfl + @[simp] lemma bodd_one : bodd 1 = tt := rfl + @[simp] lemma bodd_two : bodd 2 = ff := rfl + + @[simp] def bodd_succ (n : ℕ) : bodd (succ n) = bnot (bodd n) := + by unfold bodd bodd_div2; cases bodd_div2 n; cases fst; refl + + @[simp] def bodd_add (m n : ℕ) : bodd (m + n) = bxor (bodd m) (bodd n) := + begin + induction n with n IH, + { simp, cases bodd m; refl }, + { simp [IH], cases bodd m; cases bodd n; refl } + end + + @[simp] def bodd_mul (m n : ℕ) : bodd (m * n) = bodd m && bodd n := + begin + induction n with n IH, + { simp, cases bodd m; refl }, + { simp [mul_succ, IH], cases bodd m; cases bodd n; refl } + end + + lemma mod_two_of_bodd (n : ℕ) : n % 2 = cond (bodd n) 1 0 := + begin + have := congr_arg bodd (mod_add_div n 2), + simp [bnot] at this, + rw [show ∀ b, ff && b = ff, by intros; cases b; refl, + show ∀ b, bxor b ff = b, by intros; cases b; refl] at this, + rw -this, + cases mod_two_eq_zero_or_one n; rw a; refl + end + + @[simp] lemma div2_zero : div2 0 = 0 := rfl + @[simp] lemma div2_one : div2 1 = 0 := rfl + @[simp] lemma div2_two : div2 2 = 1 := rfl + + @[simp] lemma div2_succ (n : ℕ) : div2 (succ n) = cond (bodd n) (succ (div2 n)) (div2 n) := + by unfold bodd div2 bodd_div2; cases bodd_div2 n; cases fst; refl + + theorem bodd_add_div2 : ∀ n, cond (bodd n) 1 0 + 2 * div2 n = n + | 0 := rfl + | (succ n) := begin + simp, + refine eq.trans _ (congr_arg succ (bodd_add_div2 n)), + cases bodd n; simp [cond, bnot], + { rw add_comm; refl }, + { rw [succ_mul, add_comm 1] } + end + + theorem div2_val (n) : div2 n = n / 2 := + by refine eq_of_mul_eq_mul_left dec_trivial + (nat.add_left_cancel (eq.trans _ (mod_add_div n 2).symm)); + rw [mod_two_of_bodd, bodd_add_div2] + def shiftl : ℕ → ℕ → ℕ | m 0 := m | m (n+1) := 2 * shiftl m n def shiftr : ℕ → ℕ → ℕ | m 0 := m - | m (n+1) := shiftr m n / 2 - - def bodd (n : ℕ) : bool := n % 2 = 1 + | m (n+1) := div2 (shiftr m n) def test_bit (m n : ℕ) : bool := bodd (shiftr m n) @@ -32,71 +95,63 @@ namespace nat lemma bit_val (b n) : bit b n = 2 * n + cond b 1 0 := by { cases b, apply bit0_val, apply bit1_val } - lemma mod_two_of_bodd (n : nat) : n % 2 = cond (bodd n) 1 0 := - match by apply_instance : ∀ d, n % 2 = cond (@to_bool (n % 2 = 1) d) 1 0 with - | is_true h := h - | is_false h := (mod_two_eq_zero_or_one _).resolve_right h - end - - lemma bit_decomp (n : nat) : bit (bodd n) (shiftr n 1) = n := - (bit_val _ _).trans $ (add_comm _ _).trans $ - eq.trans (by rw mod_two_of_bodd; refl) (mod_add_div n 2) + lemma bit_decomp (n : nat) : bit (bodd n) (div2 n) = n := + (bit_val _ _).trans $ (add_comm _ _).trans $ bodd_add_div2 _ lemma bit_cases_on {C : nat → Sort u} (n) (h : ∀ b n, C (bit b n)) : C n := by rw -bit_decomp n; apply h + @[simp] lemma bit_zero : bit ff 0 = 0 := rfl + lemma bodd_bit (b n) : bodd (bit b n) = b := - begin - rw bit_val, dsimp [bodd], - rw [add_comm, add_mul_mod_self_left, mod_eq_of_lt]; - cases b; exact dec_trivial - end + by rw bit_val; simp; cases b; cases bodd n; refl - lemma shiftr1_bit (b n) : shiftr (bit b n) 1 = n := - begin - rw bit_val, dsimp [shiftr], - rw [add_comm, add_mul_div_left, div_eq_of_lt, zero_add]; + lemma div2_bit (b n) : div2 (bit b n) = n := + by rw [bit_val, div2_val, add_comm, add_mul_div_left, div_eq_of_lt, zero_add]; cases b; exact dec_trivial - end - def shiftl_add (m n) : ∀ k, shiftl m (n + k) = shiftl (shiftl m n) k + lemma shiftl_add (m n) : ∀ k, shiftl m (n + k) = shiftl (shiftl m n) k | 0 := rfl | (k+1) := congr_arg ((*) 2) (shiftl_add k) - def shiftr_add (m n) : ∀ k, shiftr m (n + k) = shiftr (shiftr m n) k + lemma shiftr_add (m n) : ∀ k, shiftr m (n + k) = shiftr (shiftr m n) k | 0 := rfl - | (k+1) := congr_arg (/ 2) (shiftr_add k) + | (k+1) := congr_arg div2 (shiftr_add k) - def shiftl_eq_mul_pow (m) : ∀ n, shiftl m n = m * 2 ^ n + lemma shiftl_eq_mul_pow (m) : ∀ n, shiftl m n = m * 2 ^ n | 0 := (mul_one _).symm | (k+1) := (congr_arg ((*) 2) (shiftl_eq_mul_pow k)).trans $ by simp [pow_succ] - def one_shiftl (n) : shiftl 1 n = 2 ^ n := + lemma one_shiftl (n) : shiftl 1 n = 2 ^ n := (shiftl_eq_mul_pow _ _).trans (one_mul _) - def zero_shiftl (n) : shiftl 0 n = 0 := + @[simp] lemma zero_shiftl (n) : shiftl 0 n = 0 := (shiftl_eq_mul_pow _ _).trans (zero_mul _) - def shiftr_eq_div_pow (m) : ∀ n, shiftr m n = m / 2 ^ n + lemma shiftr_eq_div_pow (m) : ∀ n, shiftr m n = m / 2 ^ n | 0 := (nat.div_one _).symm - | (k+1) := (congr_arg (/ 2) (shiftr_eq_div_pow k)).trans $ - by dsimp; rw [nat.div_div_eq_div_mul]; refl + | (k+1) := (congr_arg div2 (shiftr_eq_div_pow k)).trans $ + by dsimp; rw [div2_val, nat.div_div_eq_div_mul]; refl - def zero_shiftr (n) : shiftr 0 n = 0 := + @[simp] lemma zero_shiftr (n) : shiftr 0 n = 0 := (shiftr_eq_div_pow _ _).trans (nat.zero_div _) - def test_bit_zero (b n) : test_bit (bit b n) 0 = b := bodd_bit _ _ + lemma test_bit_zero (b n) : test_bit (bit b n) 0 = b := bodd_bit _ _ - def test_bit_succ (m b n) : test_bit (bit b n) (succ m) = test_bit n m := - have bodd (shiftr (shiftr (bit b n) 1) m) = bodd (shiftr n m), by rw shiftr1_bit, + lemma test_bit_succ (m b n) : test_bit (bit b n) (succ m) = test_bit n m := + have bodd (shiftr (shiftr (bit b n) 1) m) = bodd (shiftr n m), + by dsimp [shiftr]; rw div2_bit, by rw [-shiftr_add, add_comm] at this; exact this def binary_rec {C : nat → Sort u} (f : ∀ b n, C n → C (bit b n)) (z : C 0) : Π n, C n - | n := if n0 : n = 0 then by rw n0; exact z else let n' := shiftr n 1 in - have n' < n, from (div_lt_iff_lt_mul _ _ dec_trivial).2 $ - by have := nat.mul_lt_mul_of_pos_left (dec_trivial : 1 < 2) - (lt_of_le_of_ne (zero_le _) (ne.symm n0)); - rwa mul_one at this, + | n := if n0 : n = 0 then by rw n0; exact z else let n' := div2 n in + have n' < n, begin + change div2 n < n, rw div2_val, + apply (div_lt_iff_lt_mul _ _ (succ_pos 1)).2, + have := nat.mul_lt_mul_of_pos_left (lt_succ_self 1) + (lt_of_le_of_ne (zero_le _) (ne.symm n0)), + rwa mul_one at this + end, by rw [-show bit (bodd n) n' = n, from bit_decomp n]; exact f (bodd n) n' (binary_rec n') @@ -116,6 +171,10 @@ namespace nat def ldiff : ℕ → ℕ → ℕ := bitwise (λ a b, a && bnot b) def lxor : ℕ → ℕ → ℕ := bitwise bxor + lemma binary_rec_zero {C : nat → Sort u} (f : ∀ b n, C n → C (bit b n)) (z) : + binary_rec f z 0 = z := + by {rw [binary_rec.equations._eqn_1], refl} + lemma binary_rec_eq {C : nat → Sort u} {f : ∀ b n, C n → C (bit b n)} {z} (h : f ff 0 z = z) (b n) : binary_rec f z (bit b n) = f b n (binary_rec f z n) := @@ -123,16 +182,13 @@ namespace nat rw [binary_rec.equations._eqn_1], cases (by apply_instance : decidable (bit b n = 0)) with b0 b0; dsimp [dite], { generalize (binary_rec._main._pack._proof_2 (bit b n)) e, - rw [bodd_bit, shiftr1_bit], intro e, refl }, + rw [bodd_bit, div2_bit], intro e, refl }, { generalize (binary_rec._main._pack._proof_1 (bit b n) b0) e, - have bf := bodd_bit b n, have n0 := shiftr1_bit b n, - rw b0 at bf n0, rw [-show ff = b, from bf, -show 0 = n, from n0], intro e, - exact h.symm }, + have bf := bodd_bit b n, have n0 := div2_bit b n, + rw b0 at bf n0, simp at bf n0, rw [-bf, -n0, binary_rec_zero], + exact λe, h.symm }, end - lemma binary_rec_zero {C : nat → Sort u} (f : ∀ b n, C n → C (bit b n)) (z) : - binary_rec f z 0 = z := by {rw [binary_rec.equations._eqn_1], refl} - lemma bitwise_bit_aux {f : bool → bool → bool} (h : f ff ff = ff) : @binary_rec (λ_, ℕ) (λ b n _, bit (f ff b) (cond (f ff tt) n 0))