From ca941249b98a2da3959e0b2b7d9311c647795778 Mon Sep 17 00:00:00 2001 From: Scott Morrison Date: Tue, 20 Feb 2024 06:59:52 +1100 Subject: [PATCH] chore: upstream Std.BitVec.* (#3400) Co-authored-by: Leonardo de Moura --- src/Init/Data.lean | 1 + src/Init/Data/Array/Lemmas.lean | 2 +- src/Init/Data/BitVec.lean | 5 + src/Init/Data/BitVec/Basic.lean | 535 +++++++++++ src/Init/Data/BitVec/Bitblast.lean | 173 ++++ src/Init/Data/BitVec/Folds.lean | 59 ++ src/Init/Data/BitVec/Lemmas.lean | 534 +++++++++++ src/Init/Data/Fin.lean | 1 + src/Init/Data/Fin/Basic.lean | 2 +- src/Init/Data/Fin/Lemmas.lean | 830 ++++++++++++++++++ src/Init/Data/Int/Bitwise.lean | 2 +- src/Init/Data/Nat/Bitwise.lean | 55 +- src/Init/Data/Nat/Bitwise/Basic.lean | 63 ++ src/Init/Data/Nat/Bitwise/Lemmas.lean | 503 +++++++++++ .../inWordCompletion.lean.expected.out | 6 +- tests/lean/run/etaStructProofIrrelIssue.lean | 3 - tests/lean/run/ext1.lean | 1 - 17 files changed, 2714 insertions(+), 61 deletions(-) create mode 100644 src/Init/Data/BitVec.lean create mode 100644 src/Init/Data/BitVec/Basic.lean create mode 100644 src/Init/Data/BitVec/Bitblast.lean create mode 100644 src/Init/Data/BitVec/Folds.lean create mode 100644 src/Init/Data/BitVec/Lemmas.lean create mode 100644 src/Init/Data/Fin/Lemmas.lean create mode 100644 src/Init/Data/Nat/Bitwise/Basic.lean create mode 100644 src/Init/Data/Nat/Bitwise/Lemmas.lean diff --git a/src/Init/Data.lean b/src/Init/Data.lean index 610686effd..754822953f 100644 --- a/src/Init/Data.lean +++ b/src/Init/Data.lean @@ -7,6 +7,7 @@ prelude import Init.Data.Basic import Init.Data.Nat import Init.Data.Bool +import Init.Data.BitVec import Init.Data.Cast import Init.Data.Char import Init.Data.String diff --git a/src/Init/Data/Array/Lemmas.lean b/src/Init/Data/Array/Lemmas.lean index 75423b6c0d..a16c78280c 100644 --- a/src/Init/Data/Array/Lemmas.lean +++ b/src/Init/Data/Array/Lemmas.lean @@ -4,7 +4,7 @@ Released under Apache 2.0 license as described in the file LICENSE. Authors: Mario Carneiro -/ prelude -import Init.Data.Nat +import Init.Data.Nat.MinMax import Init.Data.List.Lemmas import Init.Data.Fin.Basic import Init.Data.Array.Mem diff --git a/src/Init/Data/BitVec.lean b/src/Init/Data/BitVec.lean new file mode 100644 index 0000000000..76049b62d3 --- /dev/null +++ b/src/Init/Data/BitVec.lean @@ -0,0 +1,5 @@ +prelude +import Init.Data.BitVec.Basic +import Init.Data.BitVec.Bitblast +import Init.Data.BitVec.Folds +import Init.Data.BitVec.Lemmas diff --git a/src/Init/Data/BitVec/Basic.lean b/src/Init/Data/BitVec/Basic.lean new file mode 100644 index 0000000000..74adceacbb --- /dev/null +++ b/src/Init/Data/BitVec/Basic.lean @@ -0,0 +1,535 @@ +/- +Copyright (c) 2022 by the authors listed in the file AUTHORS and their +institutional affiliations. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Joe Hendrix, Wojciech Nawrocki, Leonardo de Moura, Mario Carneiro, Alex Keizer +-/ +prelude +import Init.Data.Fin.Basic +import Init.Data.Nat.Bitwise.Lemmas +import Init.Data.Nat.Power2 + +namespace Std + +/-! +We define bitvectors. We choose the `Fin` representation over others for its relative efficiency +(Lean has special support for `Nat`), alignment with `UIntXY` types which are also represented +with `Fin`, and the fact that bitwise operations on `Fin` are already defined. Some other possible +representations are `List Bool`, `{ l : List Bool // l.length = w }`, `Fin w → Bool`. + +We define many of the bitvector operations from the +[`QF_BV` logic](https://smtlib.cs.uiowa.edu/logics-all.shtml#QF_BV). +of SMT-LIBv2. +-/ + +/-- +A bitvector of the specified width. This is represented as the underlying `Nat` number +in both the runtime and the kernel, inheriting all the special support for `Nat`. +-/ +structure BitVec (w : Nat) where + /-- Construct a `BitVec w` from a number less than `2^w`. + O(1), because we use `Fin` as the internal representation of a bitvector. -/ + ofFin :: + /-- Interpret a bitvector as a number less than `2^w`. + O(1), because we use `Fin` as the internal representation of a bitvector. -/ + toFin : Fin (2^w) + deriving DecidableEq + +namespace BitVec + +/-- `cast eq i` embeds `i` into an equal `BitVec` type. -/ +@[inline] def cast (eq : n = m) (i : BitVec n) : BitVec m := + .ofFin (Fin.cast (congrArg _ eq) i.toFin) + +/-- The `BitVec` with value `i mod 2^n`. Treated as an operation on bitvectors, +this is truncation of the high bits when downcasting and zero-extension when upcasting. -/ +protected def ofNat (n : Nat) (i : Nat) : BitVec n where + toFin := Fin.ofNat' i (Nat.two_pow_pos n) + +instance : NatCast (BitVec w) := ⟨BitVec.ofNat w⟩ + +/-- Given a bitvector `a`, return the underlying `Nat`. This is O(1) because `BitVec` is a +(zero-cost) wrapper around a `Nat`. -/ +protected def toNat (a : BitVec n) : Nat := a.toFin.val + +/-- Return the bound in terms of toNat. -/ +theorem isLt (x : BitVec w) : x.toNat < 2^w := x.toFin.isLt + +/-- Return the `i`-th least significant bit or `false` if `i ≥ w`. -/ +@[inline] def getLsb (x : BitVec w) (i : Nat) : Bool := x.toNat.testBit i + +/-- Return the `i`-th most significant bit or `false` if `i ≥ w`. -/ +@[inline] def getMsb (x : BitVec w) (i : Nat) : Bool := i < w && getLsb x (w-1-i) + +/-- Return most-significant bit in bitvector. -/ +@[inline] protected def msb (a : BitVec n) : Bool := getMsb a 0 + +/-- Interpret the bitvector as an integer stored in two's complement form. -/ +protected def toInt (a : BitVec n) : Int := + if a.msb then Int.ofNat a.toNat - Int.ofNat (2^n) else a.toNat + +/-- Return a bitvector `0` of size `n`. This is the bitvector with all zero bits. -/ +protected def zero (n : Nat) : BitVec n := ⟨0, Nat.two_pow_pos n⟩ + +instance : Inhabited (BitVec n) where default := .zero n + +instance instOfNat : OfNat (BitVec n) i where ofNat := .ofNat n i + +/-- Notation for bit vector literals. `i#n` is a shorthand for `BitVec.ofNat n i`. -/ +scoped syntax:max term:max noWs "#" noWs term:max : term +macro_rules | `($i#$n) => `(BitVec.ofNat $n $i) + +/- Support for `i#n` notation in patterns. -/ +attribute [match_pattern] BitVec.ofNat + +/-- Unexpander for bit vector literals. -/ +@[app_unexpander BitVec.ofNat] def unexpandBitVecOfNat : Lean.PrettyPrinter.Unexpander + | `($(_) $n $i) => `($i#$n) + | _ => throw () + +/-- Convert bitvector into a fixed-width hex number. -/ +protected def toHex {n : Nat} (x : BitVec n) : String := + let s := (Nat.toDigits 16 x.toNat).asString + let t := (List.replicate ((n+3) / 4 - s.length) '0').asString + t ++ s + +instance : Repr (BitVec n) where reprPrec a _ := "0x" ++ (a.toHex : Format) ++ "#" ++ repr n + +instance : ToString (BitVec n) where toString a := toString (repr a) + +/-- Theorem for normalizing the bit vector literal representation. -/ +-- TODO: This needs more usage data to assess which direction the simp should go. +@[simp] theorem ofNat_eq_ofNat : @OfNat.ofNat (BitVec n) i _ = BitVec.ofNat n i := rfl +@[simp] theorem natCast_eq_ofNat : Nat.cast x = x#w := rfl + +/-- +Addition for bit vectors. This can be interpreted as either signed or unsigned addition +modulo `2^n`. + +SMT-Lib name: `bvadd`. +-/ +protected def add (x y : BitVec n) : BitVec n where toFin := x.toFin + y.toFin +instance : Add (BitVec n) := ⟨BitVec.add⟩ + +/-- +Subtraction for bit vectors. This can be interpreted as either signed or unsigned subtraction +modulo `2^n`. +-/ +protected def sub (x y : BitVec n) : BitVec n where toFin := x.toFin - y.toFin +instance : Sub (BitVec n) := ⟨BitVec.sub⟩ + +/-- +Negation for bit vectors. This can be interpreted as either signed or unsigned negation +modulo `2^n`. + +SMT-Lib name: `bvneg`. +-/ +protected def neg (x : BitVec n) : BitVec n := .sub 0 x +instance : Neg (BitVec n) := ⟨.neg⟩ + +/-- Bit vector of size `n` where all bits are `1`s -/ +def allOnes (n : Nat) : BitVec n := -1 + +/-- +Return the absolute value of a signed bitvector. +-/ +protected def abs (s : BitVec n) : BitVec n := if s.msb then .neg s else s + +/-- +Multiplication for bit vectors. This can be interpreted as either signed or unsigned negation +modulo `2^n`. + +SMT-Lib name: `bvmul`. +-/ +protected def mul (x y : BitVec n) : BitVec n := ofFin <| x.toFin * y.toFin +instance : Mul (BitVec n) := ⟨.mul⟩ + +/-- +Unsigned division for bit vectors using the Lean convention where division by zero returns zero. +-/ +def udiv (x y : BitVec n) : BitVec n := ofFin <| x.toFin / y.toFin +instance : Div (BitVec n) := ⟨.udiv⟩ + +/-- +Unsigned modulo for bit vectors. + +SMT-Lib name: `bvurem`. +-/ +def umod (x y : BitVec n) : BitVec n := ofFin <| x.toFin % y.toFin +instance : Mod (BitVec n) := ⟨.umod⟩ + +/-- +Unsigned division for bit vectors using the +[SMT-Lib convention](http://smtlib.cs.uiowa.edu/theories-FixedSizeBitVectors.shtml) +where division by zero returns the `allOnes` bitvector. + +SMT-Lib name: `bvudiv`. +-/ +def smtUDiv (x y : BitVec n) : BitVec n := if y = 0 then -1 else .udiv x y + +/-- +Signed t-division for bit vectors using the Lean convention where division +by zero returns zero. + +```lean +sdiv 7#4 2 = 3#4 +sdiv (-9#4) 2 = -4#4 +sdiv 5#4 -2 = -2#4 +sdiv (-7#4) (-2) = 3#4 +``` +-/ +def sdiv (s t : BitVec n) : BitVec n := + match s.msb, t.msb with + | false, false => udiv s t + | false, true => .neg (udiv s (.neg t)) + | true, false => .neg (udiv (.neg s) t) + | true, true => udiv (.neg s) (.neg t) + +/-- +Signed division for bit vectors using SMTLIB rules for division by zero. + +Specifically, `smtSDiv x 0 = if x >= 0 then -1 else 1` + +SMT-Lib name: `bvsdiv`. +-/ +def smtSDiv (s t : BitVec n) : BitVec n := + match s.msb, t.msb with + | false, false => smtUDiv s t + | false, true => .neg (smtUDiv s (.neg t)) + | true, false => .neg (smtUDiv (.neg s) t) + | true, true => smtUDiv (.neg s) (.neg t) + +/-- +Remainder for signed division rounding to zero. + +SMT_Lib name: `bvsrem`. +-/ +def srem (s t : BitVec n) : BitVec n := + match s.msb, t.msb with + | false, false => umod s t + | false, true => umod s (.neg t) + | true, false => .neg (umod (.neg s) t) + | true, true => .neg (umod (.neg s) (.neg t)) + +/-- +Remainder for signed division rounded to negative infinity. + +SMT_Lib name: `bvsmod`. +-/ +def smod (s t : BitVec m) : BitVec m := + match s.msb, t.msb with + | false, false => .umod s t + | false, true => + let u := .umod s (.neg t) + (if u = BitVec.ofNat m 0 then u else .add u t) + | true, false => + let u := .umod (.neg s) t + (if u = BitVec.ofNat m 0 then u else .sub t u) + | true, true => .neg (.umod (.neg s) (.neg t)) + +/-- +Unsigned less-than for bit vectors. + +SMT-Lib name: `bvult`. +-/ +protected def ult (x y : BitVec n) : Bool := x.toFin < y.toFin +instance : LT (BitVec n) where lt x y := x.toFin < y.toFin +instance (x y : BitVec n) : Decidable (x < y) := + inferInstanceAs (Decidable (x.toFin < y.toFin)) + +/-- +Unsigned less-than-or-equal-to for bit vectors. + +SMT-Lib name: `bvule`. +-/ +protected def ule (x y : BitVec n) : Bool := x.toFin ≤ y.toFin + +instance : LE (BitVec n) where le x y := x.toFin ≤ y.toFin +instance (x y : BitVec n) : Decidable (x ≤ y) := + inferInstanceAs (Decidable (x.toFin ≤ y.toFin)) + +/-- +Signed less-than for bit vectors. + +```lean +BitVec.slt 6#4 7 = true +BitVec.slt 7#4 8 = false +``` +SMT-Lib name: `bvslt`. +-/ +protected def slt (x y : BitVec n) : Bool := x.toInt < y.toInt + +/-- +Signed less-than-or-equal-to for bit vectors. + +SMT-Lib name: `bvsle`. +-/ +protected def sle (x y : BitVec n) : Bool := x.toInt ≤ y.toInt + +/-- +Bitwise AND for bit vectors. + +```lean +0b1010#4 &&& 0b0110#4 = 0b0010#4 +``` + +SMT-Lib name: `bvand`. +-/ +protected def and (x y : BitVec n) : BitVec n where toFin := + ⟨x.toNat &&& y.toNat, Nat.and_lt_two_pow x.toNat y.isLt⟩ +instance : AndOp (BitVec w) := ⟨.and⟩ + +/-- +Bitwise OR for bit vectors. + +```lean +0b1010#4 ||| 0b0110#4 = 0b1110#4 +``` + +SMT-Lib name: `bvor`. +-/ +protected def or (x y : BitVec n) : BitVec n where toFin := + ⟨x.toNat ||| y.toNat, Nat.or_lt_two_pow x.isLt y.isLt⟩ +instance : OrOp (BitVec w) := ⟨.or⟩ + +/-- + Bitwise XOR for bit vectors. + +```lean +0b1010#4 ^^^ 0b0110#4 = 0b1100#4 +``` + +SMT-Lib name: `bvxor`. +-/ +protected def xor (x y : BitVec n) : BitVec n where toFin := + ⟨x.toNat ^^^ y.toNat, Nat.xor_lt_two_pow x.isLt y.isLt⟩ +instance : Xor (BitVec w) := ⟨.xor⟩ + +/-- +Bitwise NOT for bit vectors. + +```lean +~~~(0b0101#4) == 0b1010 +``` +SMT-Lib name: `bvnot`. +-/ +protected def not (x : BitVec n) : BitVec n := + allOnes n ^^^ x +instance : Complement (BitVec w) := ⟨.not⟩ + +/-- The `BitVec` with value `(2^n + (i mod 2^n)) mod 2^n`. -/ +protected def ofInt (n : Nat) (i : Int) : BitVec n := + match i with + | Int.ofNat a => .ofNat n a + | Int.negSucc a => ~~~.ofNat n a + +instance : IntCast (BitVec w) := ⟨BitVec.ofInt w⟩ + +/-- +Left shift for bit vectors. The low bits are filled with zeros. As a numeric operation, this is +equivalent to `a * 2^s`, modulo `2^n`. + +SMT-Lib name: `bvshl` except this operator uses a `Nat` shift value. +-/ +protected def shiftLeft (a : BitVec n) (s : Nat) : BitVec n := .ofNat n (a.toNat <<< s) +instance : HShiftLeft (BitVec w) Nat (BitVec w) := ⟨.shiftLeft⟩ + +/-- +(Logical) right shift for bit vectors. The high bits are filled with zeros. +As a numeric operation, this is equivalent to `a / 2^s`, rounding down. + +SMT-Lib name: `bvlshr` except this operator uses a `Nat` shift value. +-/ +def ushiftRight (a : BitVec n) (s : Nat) : BitVec n := + ⟨a.toNat >>> s, by + let ⟨a, lt⟩ := a + simp only [BitVec.toNat, Nat.shiftRight_eq_div_pow, Nat.div_lt_iff_lt_mul (Nat.two_pow_pos s)] + rw [←Nat.mul_one a] + exact Nat.mul_lt_mul_of_lt_of_le' lt (Nat.two_pow_pos s) (Nat.le_refl 1)⟩ + +instance : HShiftRight (BitVec w) Nat (BitVec w) := ⟨.ushiftRight⟩ + +/-- +Arithmetic right shift for bit vectors. The high bits are filled with the +most-significant bit. +As a numeric operation, this is equivalent to `a.toInt >>> s`. + +SMT-Lib name: `bvashr` except this operator uses a `Nat` shift value. +-/ +def sshiftRight (a : BitVec n) (s : Nat) : BitVec n := .ofInt n (a.toInt >>> s) + +instance {n} : HShiftLeft (BitVec m) (BitVec n) (BitVec m) := ⟨fun x y => x <<< y.toNat⟩ +instance {n} : HShiftRight (BitVec m) (BitVec n) (BitVec m) := ⟨fun x y => x >>> y.toNat⟩ + +/-- +Rotate left for bit vectors. All the bits of `x` are shifted to higher positions, with the top `n` +bits wrapping around to fill the low bits. + +```lean +rotateLeft 0b0011#4 3 = 0b1001 +``` +SMT-Lib name: `rotate_left` except this operator uses a `Nat` shift amount. +-/ +def rotateLeft (x : BitVec w) (n : Nat) : BitVec w := x <<< n ||| x >>> (w - n) + +/-- +Rotate right for bit vectors. All the bits of `x` are shifted to lower positions, with the +bottom `n` bits wrapping around to fill the high bits. + +```lean +rotateRight 0b01001#5 1 = 0b10100 +``` +SMT-Lib name: `rotate_right` except this operator uses a `Nat` shift amount. +-/ +def rotateRight (x : BitVec w) (n : Nat) : BitVec w := x >>> n ||| x <<< (w - n) + +/-- +A version of `zeroExtend` that requires a proof, but is a noop. +-/ +def zeroExtend' {n w : Nat} (le : n ≤ w) (x : BitVec n) : BitVec w := + ⟨x.toNat, by + apply Nat.lt_of_lt_of_le x.isLt + exact Nat.pow_le_pow_of_le_right (by trivial) le⟩ + +/-- +`shiftLeftZeroExtend x n` returns `zeroExtend (w+n) x <<< n` without +needing to compute `x % 2^(2+n)`. +-/ +def shiftLeftZeroExtend (msbs : BitVec w) (m : Nat) : BitVec (w+m) := + let shiftLeftLt {x : Nat} (p : x < 2^w) (m : Nat) : x <<< m < 2^(w+m) := by + simp [Nat.shiftLeft_eq, Nat.pow_add] + apply Nat.mul_lt_mul_of_pos_right p + exact (Nat.two_pow_pos m) + ⟨msbs.toNat <<< m, shiftLeftLt msbs.isLt m⟩ + +/-- +Concatenation of bitvectors. This uses the "big endian" convention that the more significant +input is on the left, so `0xAB#8 ++ 0xCD#8 = 0xABCD#16`. + +SMT-Lib name: `concat`. +-/ +def append (msbs : BitVec n) (lsbs : BitVec m) : BitVec (n+m) := + shiftLeftZeroExtend msbs m ||| zeroExtend' (Nat.le_add_left m n) lsbs + +instance : HAppend (BitVec w) (BitVec v) (BitVec (w + v)) := ⟨.append⟩ + +/-- +Extraction of bits `start` to `start + len - 1` from a bit vector of size `n` to yield a +new bitvector of size `len`. If `start + len > n`, then the vector will be zero-padded in the +high bits. +-/ +def extractLsb' (start len : Nat) (a : BitVec n) : BitVec len := .ofNat _ (a.toNat >>> start) + +/-- +Extraction of bits `hi` (inclusive) down to `lo` (inclusive) from a bit vector of size `n` to +yield a new bitvector of size `hi - lo + 1`. + +SMT-Lib name: `extract`. +-/ +def extractLsb (hi lo : Nat) (a : BitVec n) : BitVec (hi - lo + 1) := extractLsb' lo _ a + +-- TODO: write this using multiplication +/-- `replicate i x` concatenates `i` copies of `x` into a new vector of length `w*i`. -/ +def replicate : (i : Nat) → BitVec w → BitVec (w*i) + | 0, _ => 0 + | n+1, x => + have hEq : w + w*n = w*(n + 1) := by + rw [Nat.mul_add, Nat.add_comm, Nat.mul_one] + hEq ▸ (x ++ replicate n x) + +/-- Fills a bitvector with `w` copies of the bit `b`. -/ +def fill (w : Nat) (b : Bool) : BitVec w := bif b then -1 else 0 + +/-- +Zero extend vector `x` of length `w` by adding zeros in the high bits until it has length `v`. +If `v < w` then it truncates the high bits instead. + +SMT-Lib name: `zero_extend`. +-/ +def zeroExtend (v : Nat) (x : BitVec w) : BitVec v := + if h : w ≤ v then + zeroExtend' h x + else + .ofNat v x.toNat + +/-- +Truncate the high bits of bitvector `x` of length `w`, resulting in a vector of length `v`. +If `v > w` then it zero-extends the vector instead. +-/ +abbrev truncate := @zeroExtend + +/-- +Sign extend a vector of length `w`, extending with `i` additional copies of the most significant +bit in `x`. If `x` is an empty vector, then the sign is treated as zero. + +SMT-Lib name: `sign_extend`. +-/ +def signExtend (v : Nat) (x : BitVec w) : BitVec v := .ofInt v x.toInt + +/-! We add simp-lemmas that rewrite bitvector operations into the equivalent notation -/ +@[simp] theorem append_eq (x : BitVec w) (y : BitVec v) : BitVec.append x y = x ++ y := rfl +@[simp] theorem shiftLeft_eq (x : BitVec w) (n : Nat) : BitVec.shiftLeft x n = x <<< n := rfl +@[simp] theorem ushiftRight_eq (x : BitVec w) (n : Nat) : BitVec.ushiftRight x n = x >>> n := rfl +@[simp] theorem not_eq (x : BitVec w) : BitVec.not x = ~~~x := rfl +@[simp] theorem and_eq (x y : BitVec w) : BitVec.and x y = x &&& y := rfl +@[simp] theorem or_eq (x y : BitVec w) : BitVec.or x y = x ||| y := rfl +@[simp] theorem xor_eq (x y : BitVec w) : BitVec.xor x y = x ^^^ y := rfl +@[simp] theorem neg_eq (x : BitVec w) : BitVec.neg x = -x := rfl +@[simp] theorem add_eq (x y : BitVec w) : BitVec.add x y = x + y := rfl +@[simp] theorem sub_eq (x y : BitVec w) : BitVec.sub x y = x - y := rfl +@[simp] theorem mul_eq (x y : BitVec w) : BitVec.mul x y = x * y := rfl +@[simp] theorem zero_eq : BitVec.zero n = 0#n := rfl + +@[simp] theorem cast_ofNat {n m : Nat} (h : n = m) (x : Nat) : + cast h (BitVec.ofNat n x) = BitVec.ofNat m x := by + subst h; rfl + +@[simp] theorem cast_cast {n m k : Nat} (h₁ : n = m) (h₂ : m = k) (x : BitVec n) : + cast h₂ (cast h₁ x) = cast (h₁ ▸ h₂) x := + rfl + +@[simp] theorem cast_eq {n : Nat} (h : n = n) (x : BitVec n) : + cast h x = x := + rfl + +/-- Turn a `Bool` into a bitvector of length `1` -/ +def ofBool (b : Bool) : BitVec 1 := cond b 1 0 + +@[simp] theorem ofBool_false : ofBool false = 0 := by trivial +@[simp] theorem ofBool_true : ofBool true = 1 := by trivial + +/-- The empty bitvector -/ +abbrev nil : BitVec 0 := 0 + +/-! +### Cons and Concat +We give special names to the operations of adding a single bit to either end of a bitvector. +We follow the precedent of `Vector.cons`/`Vector.concat` both for the name, and for the decision +to have the resulting size be `n + 1` for both operations (rather than `1 + n`, which would be the +result of appending a single bit to the front in the naive implementation). +-/ + +/-- Append a single bit to the end of a bitvector, using big endian order (see `append`). + That is, the new bit is the least significant bit. -/ +def concat {n} (msbs : BitVec n) (lsb : Bool) : BitVec (n+1) := msbs ++ (ofBool lsb) + +/-- Prepend a single bit to the front of a bitvector, using big endian order (see `append`). + That is, the new bit is the most significant bit. -/ +def cons {n} (msb : Bool) (lsbs : BitVec n) : BitVec (n+1) := + ((ofBool msb) ++ lsbs).cast (Nat.add_comm ..) + +/-- All empty bitvectors are equal -/ +instance : Subsingleton (BitVec 0) where + allEq := by intro ⟨0, _⟩ ⟨0, _⟩; rfl + +/-- Every bitvector of length 0 is equal to `nil`, i.e., there is only one empty bitvector -/ +theorem eq_nil : ∀ (x : BitVec 0), x = nil + | ofFin ⟨0, _⟩ => rfl + +theorem append_ofBool (msbs : BitVec w) (lsb : Bool) : + msbs ++ ofBool lsb = concat msbs lsb := + rfl + +theorem ofBool_append (msb : Bool) (lsbs : BitVec w) : + ofBool msb ++ lsbs = (cons msb lsbs).cast (Nat.add_comm ..) := + rfl diff --git a/src/Init/Data/BitVec/Bitblast.lean b/src/Init/Data/BitVec/Bitblast.lean new file mode 100644 index 0000000000..85583b5123 --- /dev/null +++ b/src/Init/Data/BitVec/Bitblast.lean @@ -0,0 +1,173 @@ +/- +Copyright (c) 2023 by the authors listed in the file AUTHORS and their +institutional affiliations. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Harun Khan, Abdalrhman M Mohamed, Joe Hendrix +-/ +prelude +import Init.Data.BitVec.Folds + +/-! +# Bitblasting of bitvectors + +This module provides theorems for showing the equivalence between BitVec operations using +the `Fin 2^n` representation and Boolean vectors. It is still under development, but +intended to provide a path for converting SAT and SMT solver proofs about BitVectors +as vectors of bits into proofs about Lean `BitVec` values. + +The module is named for the bit-blasting operation in an SMT solver that converts bitvector +expressions into expressions about individual bits in each vector. + +## Main results +* `x + y : BitVec w` is `(adc x y false).2`. + + +## Future work +All other operations are to be PR'ed later and are already proved in +https://github.com/mhk119/lean-smt/blob/bitvec/Smt/Data/Bitwise.lean. + +-/ + +open Nat Bool + +/-! ### Preliminaries -/ + +namespace Std.BitVec + +private theorem testBit_limit {x i : Nat} (x_lt_succ : x < 2^(i+1)) : + testBit x i = decide (x ≥ 2^i) := by + cases xi : testBit x i with + | true => + simp [testBit_implies_ge xi] + | false => + simp + cases Nat.lt_or_ge x (2^i) with + | inl x_lt => + exact x_lt + | inr x_ge => + have ⟨j, ⟨j_ge, jp⟩⟩ := ge_two_pow_implies_high_bit_true x_ge + cases Nat.lt_or_eq_of_le j_ge with + | inr x_eq => + simp [x_eq, jp] at xi + | inl x_lt => + exfalso + apply Nat.lt_irrefl + calc x < 2^(i+1) := x_lt_succ + _ ≤ 2 ^ j := Nat.pow_le_pow_of_le_right Nat.zero_lt_two x_lt + _ ≤ x := testBit_implies_ge jp + +private theorem mod_two_pow_succ (x i : Nat) : + x % 2^(i+1) = 2^i*(x.testBit i).toNat + x % (2 ^ i):= by + apply Nat.eq_of_testBit_eq + intro j + simp only [Nat.mul_add_lt_is_or, testBit_or, testBit_mod_two_pow, testBit_shiftLeft, + Nat.testBit_bool_to_nat, Nat.sub_eq_zero_iff_le, Nat.mod_lt, Nat.two_pow_pos, + testBit_mul_pow_two] + rcases Nat.lt_trichotomy i j with i_lt_j | i_eq_j | j_lt_i + · have i_le_j : i ≤ j := Nat.le_of_lt i_lt_j + have not_j_le_i : ¬(j ≤ i) := Nat.not_le_of_lt i_lt_j + have not_j_lt_i : ¬(j < i) := Nat.not_lt_of_le i_le_j + have not_j_lt_i_succ : ¬(j < i + 1) := + Nat.not_le_of_lt (Nat.succ_lt_succ i_lt_j) + simp [i_le_j, not_j_le_i, not_j_lt_i, not_j_lt_i_succ] + · simp [i_eq_j] + · have j_le_i : j ≤ i := Nat.le_of_lt j_lt_i + have j_le_i_succ : j < i + 1 := Nat.succ_le_succ j_le_i + have not_j_ge_i : ¬(j ≥ i) := Nat.not_le_of_lt j_lt_i + simp [j_lt_i, j_le_i, not_j_ge_i, j_le_i_succ] + +private theorem mod_two_pow_lt (x i : Nat) : x % 2 ^ i < 2^i := Nat.mod_lt _ (Nat.two_pow_pos _) + +/-! ### Addition -/ + +/-- carry w x y c returns true if the `w` carry bit is true when computing `x + y + c`. -/ +def carry (w x y : Nat) (c : Bool) : Bool := decide (x % 2^w + y % 2^w + c.toNat ≥ 2^w) + +@[simp] theorem carry_zero : carry 0 x y c = c := by + cases c <;> simp [carry, mod_one] + +/-- At least two out of three booleans are true. -/ +abbrev atLeastTwo (a b c : Bool) : Bool := a && b || a && c || b && c + +/-- Carry function for bitwise addition. -/ +def adcb (x y c : Bool) : Bool × Bool := (atLeastTwo x y c, Bool.xor x (Bool.xor y c)) + +/-- Bitwise addition implemented via a ripple carry adder. -/ +def adc (x y : BitVec w) : Bool → Bool × BitVec w := + iunfoldr fun (i : Fin w) c => adcb (x.getLsb i) (y.getLsb i) c + +theorem adc_overflow_limit (x y i : Nat) (c : Bool) : x % 2^i + (y % 2^i + c.toNat) < 2^(i+1) := by + have : c.toNat ≤ 1 := Bool.toNat_le_one c + rw [Nat.pow_succ] + omega + +theorem carry_succ (w x y : Nat) (c : Bool) : + carry (succ w) x y c = atLeastTwo (x.testBit w) (y.testBit w) (carry w x y c) := by + simp only [carry, mod_two_pow_succ, atLeastTwo] + simp only [Nat.pow_succ'] + generalize testBit x w = xh + generalize testBit y w = yh + have sum_bnd : x%2^w + (y%2^w + c.toNat) < 2*2^w := by + simp only [← Nat.pow_succ'] + exact adc_overflow_limit x y w c + cases xh <;> cases yh <;> (simp; omega) + +theorem getLsb_add_add_bool {i : Nat} (i_lt : i < w) (x y : BitVec w) (c : Bool) : + getLsb (x + y + zeroExtend w (ofBool c)) i = + Bool.xor (getLsb x i) (Bool.xor (getLsb y i) (carry i x.toNat y.toNat c)) := by + let ⟨x, x_lt⟩ := x + let ⟨y, y_lt⟩ := y + simp only [getLsb, toNat_add, toNat_zeroExtend, i_lt, toNat_ofFin, toNat_ofBool, + Nat.mod_add_mod, Nat.add_mod_mod] + apply Eq.trans + rw [← Nat.div_add_mod x (2^i), ← Nat.div_add_mod y (2^i)] + simp only + [ Nat.testBit_mod_two_pow, + Nat.testBit_mul_two_pow_add_eq, + i_lt, + decide_True, + Bool.true_and, + Nat.add_assoc, + Nat.add_left_comm (_%_) (_ * _) _, + testBit_limit (adc_overflow_limit x y i c) + ] + simp [testBit_to_div_mod, carry, Nat.add_assoc] + +theorem getLsb_add {i : Nat} (i_lt : i < w) (x y : BitVec w) : + getLsb (x + y) i = + Bool.xor (getLsb x i) (Bool.xor (getLsb y i) (carry i x.toNat y.toNat false)) := by + simpa using getLsb_add_add_bool i_lt x y false + +theorem adc_spec (x y : BitVec w) (c : Bool) : + adc x y c = (carry w x.toNat y.toNat c, x + y + zeroExtend w (ofBool c)) := by + simp only [adc] + apply iunfoldr_replace + (fun i => carry i x.toNat y.toNat c) + (x + y + zeroExtend w (ofBool c)) + c + case init => + simp [carry, Nat.mod_one] + cases c <;> rfl + case step => + intro ⟨i, lt⟩ + simp only [adcb, Prod.mk.injEq, carry_succ] + apply And.intro + case left => + rw [testBit_toNat, testBit_toNat] + case right => + simp [getLsb_add_add_bool lt] + +theorem add_eq_adc (w : Nat) (x y : BitVec w) : x + y = (adc x y false).snd := by + simp [adc_spec] + +/-! ### add -/ + +/-- Adding a bitvector to its own complement yields the all ones bitpattern -/ +@[simp] theorem add_not_self (x : BitVec w) : x + ~~~x = allOnes w := by + rw [add_eq_adc, adc, iunfoldr_replace (fun _ => false) (allOnes w)] + · rfl + · simp [adcb, atLeastTwo] + +/-- Subtracting `x` from the all ones bitvector is equivalent to taking its complement -/ +theorem allOnes_sub_eq_not (x : BitVec w) : allOnes w - x = ~~~x := by + rw [← add_not_self x, BitVec.add_comm, add_sub_cancel] diff --git a/src/Init/Data/BitVec/Folds.lean b/src/Init/Data/BitVec/Folds.lean new file mode 100644 index 0000000000..3dcf8a4b7e --- /dev/null +++ b/src/Init/Data/BitVec/Folds.lean @@ -0,0 +1,59 @@ +/- +Copyright (c) 2023 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Joe Hendrix +-/ +prelude +import Init.Data.BitVec.Lemmas +import Init.Data.Nat.Lemmas +import Init.Data.Fin.Iterate + +namespace Std.BitVec + +/-- +iunfoldr is an iterative operation that applies a function `f` repeatedly. + +It produces a sequence of state values `[s_0, s_1 .. s_w]` and a bitvector +`v` where `f i s_i = (s_{i+1}, b_i)` and `b_i` is bit `i`th least-significant bit +in `v` (e.g., `getLsb v i = b_i`). + +Theorems involving `iunfoldr` can be eliminated using `iunfoldr_replace` below. +-/ +def iunfoldr (f : Fin w -> α → α × Bool) (s : α) : α × BitVec w := + Fin.hIterate (fun i => α × BitVec i) (s, nil) fun i q => + (fun p => ⟨p.fst, cons p.snd q.snd⟩) (f i q.fst) + +theorem iunfoldr.fst_eq + {f : Fin w → α → α × Bool} (state : Nat → α) (s : α) + (init : s = state 0) + (ind : ∀(i : Fin w), (f i (state i.val)).fst = state (i.val+1)) : + (iunfoldr f s).fst = state w := by + unfold iunfoldr + apply Fin.hIterate_elim (fun i (p : α × BitVec i) => p.fst = state i) + case init => + exact init + case step => + intro i ⟨s, v⟩ p + simp_all [ind i] + +private theorem iunfoldr.eq_test + {f : Fin w → α → α × Bool} (state : Nat → α) (value : BitVec w) (a : α) + (init : state 0 = a) + (step : ∀(i : Fin w), f i (state i.val) = (state (i.val+1), value.getLsb i.val)) : + iunfoldr f a = (state w, BitVec.truncate w value) := by + apply Fin.hIterate_eq (fun i => ((state i, BitVec.truncate i value) : α × BitVec i)) + case init => + simp only [init, eq_nil] + case step => + intro i + simp_all [truncate_succ] + +/-- +Correctness theorem for `iunfoldr`. +-/ +theorem iunfoldr_replace + {f : Fin w → α → α × Bool} (state : Nat → α) (value : BitVec w) (a : α) + (init : state 0 = a) + (step : ∀(i : Fin w), f i (state i.val) = (state (i.val+1), value.getLsb i.val)) : + iunfoldr f a = (state w, value) := by + simp [iunfoldr.eq_test state value a init step] diff --git a/src/Init/Data/BitVec/Lemmas.lean b/src/Init/Data/BitVec/Lemmas.lean new file mode 100644 index 0000000000..44024d42f7 --- /dev/null +++ b/src/Init/Data/BitVec/Lemmas.lean @@ -0,0 +1,534 @@ +/- +Copyright (c) 2023 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Joe Hendrix +-/ +prelude +import Init.Data.Bool +import Init.Data.BitVec.Basic +import Init.Data.Fin.Lemmas +import Init.Data.Nat.Lemmas + +namespace Std.BitVec + +/-- +This normalized a bitvec using `ofFin` to `ofNat`. +-/ +theorem ofFin_eq_ofNat : @BitVec.ofFin w (Fin.mk x lt) = BitVec.ofNat w x := by + simp only [BitVec.ofNat, Fin.ofNat', lt, Nat.mod_eq_of_lt] + +/-- Prove equality of bitvectors in terms of nat operations. -/ +theorem eq_of_toNat_eq {n} : ∀ {i j : BitVec n}, i.toNat = j.toNat → i = j + | ⟨_, _⟩, ⟨_, _⟩, rfl => rfl + +@[simp] theorem val_toFin (x : BitVec w) : x.toFin.val = x.toNat := rfl + +theorem toNat_eq (x y : BitVec n) : x = y ↔ x.toNat = y.toNat := + Iff.intro (congrArg BitVec.toNat) eq_of_toNat_eq + +theorem toNat_lt (x : BitVec n) : x.toNat < 2^n := x.toFin.2 + +theorem testBit_toNat (x : BitVec w) : x.toNat.testBit i = x.getLsb i := rfl + +@[simp] theorem getLsb_ofFin (x : Fin (2^n)) (i : Nat) : + getLsb (BitVec.ofFin x) i = x.val.testBit i := rfl + +@[simp] theorem getLsb_ge (x : BitVec w) (i : Nat) (ge : i ≥ w) : getLsb x i = false := by + let ⟨x, x_lt⟩ := x + simp + apply Nat.testBit_lt_two_pow + have p : 2^w ≤ 2^i := Nat.pow_le_pow_of_le_right (by omega) ge + omega + +theorem lt_of_getLsb (x : BitVec w) (i : Nat) : getLsb x i = true → i < w := by + if h : i < w then + simp [h] + else + simp [Nat.ge_of_not_lt h] + +-- We choose `eq_of_getLsb_eq` as the `@[ext]` theorem for `BitVec` +-- somewhat arbitrarily over `eq_of_getMsg_eq`. +@[ext] theorem eq_of_getLsb_eq {x y : BitVec w} + (pred : ∀(i : Fin w), x.getLsb i.val = y.getLsb i.val) : x = y := by + apply eq_of_toNat_eq + apply Nat.eq_of_testBit_eq + intro i + if i_lt : i < w then + exact pred ⟨i, i_lt⟩ + else + have p : i ≥ w := Nat.le_of_not_gt i_lt + simp [testBit_toNat, getLsb_ge _ _ p] + +theorem eq_of_getMsb_eq {x y : BitVec w} + (pred : ∀(i : Fin w), x.getMsb i = y.getMsb i.val) : x = y := by + simp only [getMsb] at pred + apply eq_of_getLsb_eq + intro ⟨i, i_lt⟩ + if w_zero : w = 0 then + simp [w_zero] + else + have w_pos := Nat.pos_of_ne_zero w_zero + have r : i ≤ w - 1 := by + simp [Nat.le_sub_iff_add_le w_pos, Nat.add_succ] + exact i_lt + have q_lt : w - 1 - i < w := by + simp only [Nat.sub_sub] + apply Nat.sub_lt w_pos + simp [Nat.succ_add] + have q := pred ⟨w - 1 - i, q_lt⟩ + simpa [q_lt, Nat.sub_sub_self, r] using q + +theorem eq_of_toFin_eq : ∀ {x y : BitVec w}, x.toFin = y.toFin → x = y + | ⟨_, _⟩, ⟨_, _⟩, rfl => rfl + +@[simp] theorem toNat_ofBool (b : Bool) : (ofBool b).toNat = b.toNat := by + cases b <;> rfl + +theorem ofNat_one (n : Nat) : BitVec.ofNat 1 n = BitVec.ofBool (n % 2 = 1) := by + rcases (Nat.mod_two_eq_zero_or_one n) with h | h <;> simp [h, BitVec.ofNat, Fin.ofNat'] + +theorem ofBool_eq_iff_eq : ∀(b b' : Bool), BitVec.ofBool b = BitVec.ofBool b' ↔ b = b' := by + decide + +@[simp] theorem toNat_ofFin (x : Fin (2^n)) : (BitVec.ofFin x).toNat = x.val := rfl + +@[simp] theorem toNat_ofNat (x w : Nat) : (x#w).toNat = x % 2^w := by + simp [BitVec.toNat, BitVec.ofNat, Fin.ofNat'] + +-- Remark: we don't use `[simp]` here because simproc` subsumes it for literals. +-- If `x` and `n` are not literals, applying this theorem eagerly may not be a good idea. +theorem getLsb_ofNat (n : Nat) (x : Nat) (i : Nat) : + getLsb (x#n) i = (i < n && x.testBit i) := by + simp [getLsb, BitVec.ofNat, Fin.val_ofNat'] + +@[deprecated toNat_ofNat] theorem toNat_zero (n : Nat) : (0#n).toNat = 0 := by trivial + +@[simp] theorem toNat_mod_cancel (x : BitVec n) : x.toNat % (2^n) = x.toNat := + Nat.mod_eq_of_lt x.isLt + +private theorem lt_two_pow_of_le {x m n : Nat} (lt : x < 2 ^ m) (le : m ≤ n) : x < 2 ^ n := + Nat.lt_of_lt_of_le lt (Nat.pow_le_pow_of_le_right (by trivial : 0 < 2) le) + +@[simp] theorem ofNat_toNat (m : Nat) (x : BitVec n) : x.toNat#m = truncate m x := by + let ⟨x, lt_n⟩ := x + unfold truncate + unfold zeroExtend + if h : n ≤ m then + unfold zeroExtend' + have lt_m : x < 2 ^ m := lt_two_pow_of_le lt_n h + simp [h, lt_m, Nat.mod_eq_of_lt, BitVec.toNat, BitVec.ofNat, Fin.ofNat'] + else + simp [h] + + +/-! ### msb -/ + +theorem msb_eq_decide (x : BitVec (Nat.succ w)) : BitVec.msb x = decide (2 ^ w ≤ x.toNat) := by + simp only [BitVec.msb, getMsb, Nat.zero_lt_succ, + decide_True, getLsb, Nat.testBit, Nat.succ_sub_succ_eq_sub, + Nat.sub_zero, Nat.and_one_is_mod, Bool.true_and, Nat.shiftRight_eq_div_pow] + rcases (Nat.lt_or_ge (BitVec.toNat x) (2 ^ w)) with h | h + · simp [Nat.div_eq_of_lt h, h] + · simp only [h] + rw [Nat.div_eq_sub_div (Nat.two_pow_pos w) h, Nat.div_eq_of_lt] + · decide + · have : BitVec.toNat x < 2^w + 2^w := by simpa [Nat.pow_succ, Nat.mul_two] using x.isLt + omega + +/-! ### cast -/ + +@[simp] theorem toNat_cast (h : w = v) (x : BitVec w) : (cast h x).toNat = x.toNat := rfl +@[simp] theorem toFin_cast (h : w = v) (x : BitVec w) : + (cast h x).toFin = x.toFin.cast (by rw [h]) := + rfl + +@[simp] theorem getLsb_cast (h : w = v) (x : BitVec w) : (cast h x).getLsb i = x.getLsb i := by + subst h; simp + +@[simp] theorem getMsb_cast (h : w = v) (x : BitVec w) : (cast h x).getMsb i = x.getMsb i := by + subst h; simp +@[simp] theorem msb_cast (h : w = v) (x : BitVec w) : (cast h x).msb = x.msb := by + simp [BitVec.msb] + +/-! ### zeroExtend and truncate -/ + +@[simp] theorem toNat_zeroExtend' {m n : Nat} (p : m ≤ n) (x : BitVec m) : + (zeroExtend' p x).toNat = x.toNat := by + unfold zeroExtend' + simp [p, x.isLt, Nat.mod_eq_of_lt] + +theorem toNat_zeroExtend (i : Nat) (x : BitVec n) : + BitVec.toNat (zeroExtend i x) = x.toNat % 2^i := by + let ⟨x, lt_n⟩ := x + simp only [zeroExtend] + if n_le_i : n ≤ i then + have x_lt_two_i : x < 2 ^ i := lt_two_pow_of_le lt_n n_le_i + simp [n_le_i, Nat.mod_eq_of_lt, x_lt_two_i] + else + simp [n_le_i, toNat_ofNat] + +@[simp] theorem zeroExtend_eq (x : BitVec n) : zeroExtend n x = x := by + apply eq_of_toNat_eq + let ⟨x, lt_n⟩ := x + simp [truncate, zeroExtend] + +@[simp] theorem zeroExtend_zero (m n : Nat) : zeroExtend m (0#n) = 0#m := by + apply eq_of_toNat_eq + simp [toNat_zeroExtend] + +@[simp] theorem truncate_eq (x : BitVec n) : truncate n x = x := zeroExtend_eq x + +@[simp] theorem toNat_truncate (x : BitVec n) : (truncate i x).toNat = x.toNat % 2^i := + toNat_zeroExtend i x + +@[simp] theorem getLsb_zeroExtend' (ge : m ≥ n) (x : BitVec n) (i : Nat) : + getLsb (zeroExtend' ge x) i = getLsb x i := by + simp [getLsb, toNat_zeroExtend'] + +@[simp] theorem getLsb_zeroExtend (m : Nat) (x : BitVec n) (i : Nat) : + getLsb (zeroExtend m x) i = (decide (i < m) && getLsb x i) := by + simp [getLsb, toNat_zeroExtend, Nat.testBit_mod_two_pow] + +@[simp] theorem getLsb_truncate (m : Nat) (x : BitVec n) (i : Nat) : + getLsb (truncate m x) i = (decide (i < m) && getLsb x i) := + getLsb_zeroExtend m x i + +/-! ## extractLsb -/ + +@[simp] +protected theorem extractLsb_ofFin {n} (x : Fin (2^n)) (hi lo : Nat) : + extractLsb hi lo (@BitVec.ofFin n x) = .ofNat (hi-lo+1) (x.val >>> lo) := rfl + +@[simp] +protected theorem extractLsb_ofNat (x n : Nat) (hi lo : Nat) : + extractLsb hi lo x#n = .ofNat (hi - lo + 1) ((x % 2^n) >>> lo) := by + apply eq_of_getLsb_eq + intro ⟨i, _lt⟩ + simp [BitVec.ofNat] + +@[simp] theorem extractLsb'_toNat (s m : Nat) (x : BitVec n) : + (extractLsb' s m x).toNat = (x.toNat >>> s) % 2^m := rfl + +@[simp] theorem extractLsb_toNat (hi lo : Nat) (x : BitVec n) : + (extractLsb hi lo x).toNat = (x.toNat >>> lo) % 2^(hi-lo+1) := rfl + +@[simp] theorem getLsb_extract (hi lo : Nat) (x : BitVec n) (i : Nat) : + getLsb (extractLsb hi lo x) i = (i ≤ (hi-lo) && getLsb x (lo+i)) := by + unfold getLsb + simp [Nat.lt_succ] + +/-! ### allOnes -/ + +private theorem allOnes_def : + allOnes v = .ofFin (⟨0, Nat.two_pow_pos v⟩ - ⟨1 % 2^v, Nat.mod_lt _ (Nat.two_pow_pos v)⟩) := by + rfl + +@[simp] theorem toNat_allOnes : (allOnes v).toNat = 2^v - 1 := by + simp only [allOnes_def, toNat_ofFin, Fin.coe_sub, Nat.zero_add] + by_cases h : v = 0 + · subst h + rfl + · rw [Nat.mod_eq_of_lt (Nat.one_lt_two_pow h), Nat.mod_eq_of_lt] + exact Nat.pred_lt_self (Nat.two_pow_pos v) + +@[simp] theorem getLsb_allOnes : (allOnes v).getLsb i = decide (i < v) := by + simp only [allOnes_def, getLsb_ofFin, Fin.coe_sub, Nat.zero_add, Nat.testBit_mod_two_pow] + if h : i < v then + simp only [h, decide_True, Bool.true_and] + match i, v, h with + | i, (v + 1), h => + rw [Nat.mod_eq_of_lt (by simp), Nat.testBit_two_pow_sub_one] + simp [h] + else + simp [h] + +@[simp] theorem negOne_eq_allOnes : -1#w = allOnes w := + rfl + +/-! ### or -/ + +@[simp] theorem toNat_or (x y : BitVec v) : + BitVec.toNat (x ||| y) = BitVec.toNat x ||| BitVec.toNat y := rfl + +@[simp] theorem toFin_or (x y : BitVec v) : + BitVec.toFin (x ||| y) = BitVec.toFin x ||| BitVec.toFin y := by + simp only [HOr.hOr, OrOp.or, BitVec.or, Fin.lor, val_toFin, Fin.mk.injEq] + exact (Nat.mod_eq_of_lt <| Nat.or_lt_two_pow x.isLt y.isLt).symm + + +@[simp] theorem getLsb_or {x y : BitVec v} : (x ||| y).getLsb i = (x.getLsb i || y.getLsb i) := by + rw [← testBit_toNat, getLsb, getLsb] + simp + +/-! ### and -/ + +@[simp] theorem toNat_and (x y : BitVec v) : + BitVec.toNat (x &&& y) = BitVec.toNat x &&& BitVec.toNat y := rfl + +@[simp] theorem toFin_and (x y : BitVec v) : + BitVec.toFin (x &&& y) = BitVec.toFin x &&& BitVec.toFin y := by + simp only [HAnd.hAnd, AndOp.and, BitVec.and, Fin.land, val_toFin, Fin.mk.injEq] + exact (Nat.mod_eq_of_lt <| Nat.and_lt_two_pow _ y.isLt).symm + +@[simp] theorem getLsb_and {x y : BitVec v} : (x &&& y).getLsb i = (x.getLsb i && y.getLsb i) := by + rw [← testBit_toNat, getLsb, getLsb] + simp + +/-! ### xor -/ + +@[simp] theorem toNat_xor (x y : BitVec v) : + BitVec.toNat (x ^^^ y) = BitVec.toNat x ^^^ BitVec.toNat y := rfl + +@[simp] theorem toFin_xor (x y : BitVec v) : + BitVec.toFin (x ^^^ y) = BitVec.toFin x ^^^ BitVec.toFin y := by + simp only [HXor.hXor, Xor.xor, BitVec.xor, Fin.xor, val_toFin, Fin.mk.injEq] + exact (Nat.mod_eq_of_lt <| Nat.xor_lt_two_pow x.isLt y.isLt).symm + +@[simp] theorem getLsb_xor {x y : BitVec v} : + (x ^^^ y).getLsb i = (xor (x.getLsb i) (y.getLsb i)) := by + rw [← testBit_toNat, getLsb, getLsb] + simp + +/-! ### not -/ + +theorem not_def {x : BitVec v} : ~~~x = allOnes v ^^^ x := rfl + +@[simp] theorem toNat_not {x : BitVec v} : (~~~x).toNat = 2^v - 1 - x.toNat := by + rw [Nat.sub_sub, Nat.add_comm, not_def, toNat_xor] + apply Nat.eq_of_testBit_eq + intro i + simp only [toNat_allOnes, Nat.testBit_xor, Nat.testBit_two_pow_sub_one] + match h : BitVec.toNat x with + | 0 => simp + | y+1 => + rw [Nat.succ_eq_add_one] at h + rw [← h] + rw [Nat.testBit_two_pow_sub_succ (toNat_lt _)] + · cases w : decide (i < v) + · simp at w + simp [w] + rw [Nat.testBit_lt_two_pow] + calc BitVec.toNat x < 2 ^ v := toNat_lt _ + _ ≤ 2 ^ i := Nat.pow_le_pow_of_le_right Nat.zero_lt_two w + · simp + +@[simp] theorem toFin_not (x : BitVec w) : + (~~~x).toFin = x.toFin.rev := by + apply Fin.val_inj.mp + simp only [val_toFin, toNat_not, Fin.val_rev] + omega + +@[simp] theorem getLsb_not {x : BitVec v} : (~~~x).getLsb i = (decide (i < v) && ! x.getLsb i) := by + by_cases h' : i < v <;> simp_all [not_def] + +/-! ### shiftLeft -/ + +@[simp] theorem toNat_shiftLeft {x : BitVec v} : + BitVec.toNat (x <<< n) = BitVec.toNat x <<< n % 2^v := + BitVec.toNat_ofNat _ _ + +@[simp] theorem toFin_shiftLeft {n : Nat} (x : BitVec w) : + BitVec.toFin (x <<< n) = Fin.ofNat' (x.toNat <<< n) (Nat.two_pow_pos w) := rfl + +@[simp] theorem getLsb_shiftLeft (x : BitVec m) (n) : + getLsb (x <<< n) i = (decide (i < m) && !decide (i < n) && getLsb x (i - n)) := by + rw [← testBit_toNat, getLsb] + simp only [toNat_shiftLeft, Nat.testBit_mod_two_pow, Nat.testBit_shiftLeft, ge_iff_le] + -- This step could be a case bashing tactic. + cases h₁ : decide (i < m) <;> cases h₂ : decide (n ≤ i) <;> cases h₃ : decide (i < n) + all_goals { simp_all <;> omega } + +theorem shiftLeftZeroExtend_eq {x : BitVec w} : + shiftLeftZeroExtend x n = zeroExtend (w+n) x <<< n := by + apply eq_of_toNat_eq + rw [shiftLeftZeroExtend, zeroExtend] + split + · simp + rw [Nat.mod_eq_of_lt] + rw [Nat.shiftLeft_eq, Nat.pow_add] + exact Nat.mul_lt_mul_of_pos_right (BitVec.toNat_lt x) (Nat.two_pow_pos _) + · omega + +@[simp] theorem getLsb_shiftLeftZeroExtend (x : BitVec m) (n : Nat) : + getLsb (shiftLeftZeroExtend x n) i = ((! decide (i < n)) && getLsb x (i - n)) := by + rw [shiftLeftZeroExtend_eq] + simp only [getLsb_shiftLeft, getLsb_zeroExtend] + cases h₁ : decide (i < n) <;> cases h₂ : decide (i - n < m + n) <;> cases h₃ : decide (i < m + n) + <;> simp_all + <;> (rw [getLsb_ge]; omega) + +/-! ### ushiftRight -/ + +@[simp] theorem toNat_ushiftRight (x : BitVec n) (i : Nat) : + (x >>> i).toNat = x.toNat >>> i := rfl + +@[simp] theorem getLsb_ushiftRight (x : BitVec n) (i j : Nat) : + getLsb (x >>> i) j = getLsb x (i+j) := by + unfold getLsb ; simp + +/-! ### append -/ + +theorem append_def (x : BitVec v) (y : BitVec w) : + x ++ y = (shiftLeftZeroExtend x w ||| zeroExtend' (Nat.le_add_left w v) y) := rfl + +@[simp] theorem toNat_append (x : BitVec m) (y : BitVec n) : + (x ++ y).toNat = x.toNat <<< n ||| y.toNat := + rfl + +@[simp] theorem getLsb_append {v : BitVec n} {w : BitVec m} : + getLsb (v ++ w) i = bif i < m then getLsb w i else getLsb v (i - m) := by + simp [append_def] + by_cases h : i < m + · simp [h] + · simp [h]; simp_all + +/-! ### rev -/ + +theorem getLsb_rev (x : BitVec w) (i : Fin w) : + x.getLsb i.rev = x.getMsb i := by + simp [getLsb, getMsb] + congr 1 + omega + +theorem getMsb_rev (x : BitVec w) (i : Fin w) : + x.getMsb i.rev = x.getLsb i := by + simp only [← getLsb_rev] + simp only [Fin.rev] + congr + omega + +/-! ### cons -/ + +@[simp] theorem toNat_cons (b : Bool) (x : BitVec w) : + (cons b x).toNat = (b.toNat <<< w) ||| x.toNat := by + let ⟨x, _⟩ := x + simp [cons, toNat_append, toNat_ofBool] + +@[simp] theorem getLsb_cons (b : Bool) {n} (x : BitVec n) (i : Nat) : + getLsb (cons b x) i = if i = n then b else getLsb x i := by + simp only [getLsb, toNat_cons, Nat.testBit_or] + rw [Nat.testBit_shiftLeft] + rcases Nat.lt_trichotomy i n with i_lt_n | i_eq_n | n_lt_i + · have p1 : ¬(n ≤ i) := by omega + have p2 : i ≠ n := by omega + simp [p1, p2] + · simp [i_eq_n, testBit_toNat] + cases b <;> trivial + · have p1 : i ≠ n := by omega + have p2 : i - n ≠ 0 := by omega + simp [p1, p2, Nat.testBit_bool_to_nat] + +theorem truncate_succ (x : BitVec w) : + truncate (i+1) x = cons (getLsb x i) (truncate i x) := by + apply eq_of_getLsb_eq + intro j + simp only [getLsb_truncate, getLsb_cons, j.isLt, decide_True, Bool.true_and] + if j_eq : j.val = i then + simp [j_eq] + else + have j_lt : j.val < i := Nat.lt_of_le_of_ne (Nat.le_of_succ_le_succ j.isLt) j_eq + simp [j_eq, j_lt] + +/-! ### add -/ + +theorem add_def {n} (x y : BitVec n) : x + y = .ofNat n (x.toNat + y.toNat) := rfl + +/-- +Definition of bitvector addition as a nat. +-/ +@[simp] theorem toNat_add (x y : BitVec w) : (x + y).toNat = (x.toNat + y.toNat) % 2^w := rfl +@[simp] theorem toFin_add (x y : BitVec w) : (x + y).toFin = toFin x + toFin y := rfl +@[simp] theorem ofFin_add (x : Fin (2^n)) (y : BitVec n) : + .ofFin x + y = .ofFin (x + y.toFin) := rfl +@[simp] theorem add_ofFin (x : BitVec n) (y : Fin (2^n)) : + x + .ofFin y = .ofFin (x.toFin + y) := rfl +@[simp] theorem ofNat_add_ofNat {n} (x y : Nat) : x#n + y#n = (x + y)#n := by + apply eq_of_toNat_eq ; simp [BitVec.ofNat] + +protected theorem add_assoc (x y z : BitVec n) : x + y + z = x + (y + z) := by + apply eq_of_toNat_eq ; simp [Nat.add_assoc] + +protected theorem add_comm (x y : BitVec n) : x + y = y + x := by + simp [add_def, Nat.add_comm] + +@[simp] protected theorem add_zero (x : BitVec n) : x + 0#n = x := by simp [add_def] + +@[simp] protected theorem zero_add (x : BitVec n) : 0#n + x = x := by simp [add_def] + + +/-! ### sub/neg -/ + +theorem sub_def {n} (x y : BitVec n) : x - y = .ofNat n (x.toNat + (2^n - y.toNat)) := by rfl + +@[simp] theorem toNat_sub {n} (x y : BitVec n) : + (x - y).toNat = ((x.toNat + (2^n - y.toNat)) % 2^n) := rfl +@[simp] theorem toFin_sub (x y : BitVec n) : (x - y).toFin = toFin x - toFin y := rfl + +@[simp] theorem ofFin_sub (x : Fin (2^n)) (y : BitVec n) : .ofFin x - y = .ofFin (x - y.toFin) := + rfl +@[simp] theorem sub_ofFin (x : BitVec n) (y : Fin (2^n)) : x - .ofFin y = .ofFin (x.toFin - y) := + rfl +-- Remark: we don't use `[simp]` here because simproc` subsumes it for literals. +-- If `x` and `n` are not literals, applying this theorem eagerly may not be a good idea. +theorem ofNat_sub_ofNat {n} (x y : Nat) : x#n - y#n = .ofNat n (x + (2^n - y % 2^n)) := by + apply eq_of_toNat_eq ; simp [BitVec.ofNat] + +@[simp] protected theorem sub_zero (x : BitVec n) : x - (0#n) = x := by apply eq_of_toNat_eq ; simp + +@[simp] protected theorem sub_self (x : BitVec n) : x - x = 0#n := by + apply eq_of_toNat_eq + simp only [toNat_sub] + rw [Nat.add_sub_of_le] + · simp + · exact Nat.le_of_lt x.isLt + +@[simp] theorem toNat_neg (x : BitVec n) : (- x).toNat = (2^n - x.toNat) % 2^n := by + simp [Neg.neg, BitVec.neg] + +theorem sub_toAdd {n} (x y : BitVec n) : x - y = x + - y := by + apply eq_of_toNat_eq + simp + +@[simp] theorem neg_zero (n:Nat) : -0#n = 0#n := by apply eq_of_toNat_eq ; simp + +theorem add_sub_cancel (x y : BitVec w) : x + y - y = x := by + apply eq_of_toNat_eq + have y_toNat_le := Nat.le_of_lt y.toNat_lt + rw [toNat_sub, toNat_add, Nat.mod_add_mod, Nat.add_assoc, ← Nat.add_sub_assoc y_toNat_le, + Nat.add_sub_cancel_left, Nat.add_mod_right, toNat_mod_cancel] + +/-! ### mul -/ + +theorem mul_def {n} {x y : BitVec n} : x * y = (ofFin <| x.toFin * y.toFin) := by rfl + +theorem toNat_mul (x y : BitVec n) : (x * y).toNat = (x.toNat * y.toNat) % 2 ^ n := rfl +@[simp] theorem toFin_mul (x y : BitVec n) : (x * y).toFin = (x.toFin * y.toFin) := rfl + +/-! ### le and lt -/ + +theorem le_def (x y : BitVec n) : + x ≤ y ↔ x.toNat ≤ y.toNat := Iff.rfl + +@[simp] theorem le_ofFin (x : BitVec n) (y : Fin (2^n)) : + x ≤ BitVec.ofFin y ↔ x.toFin ≤ y := Iff.rfl +@[simp] theorem ofFin_le (x : Fin (2^n)) (y : BitVec n) : + BitVec.ofFin x ≤ y ↔ x ≤ y.toFin := Iff.rfl +@[simp] theorem ofNat_le_ofNat {n} (x y : Nat) : (x#n) ≤ (y#n) ↔ x % 2^n ≤ y % 2^n := by + simp [le_def] + +theorem lt_def (x y : BitVec n) : + x < y ↔ x.toNat < y.toNat := Iff.rfl + +@[simp] theorem lt_ofFin (x : BitVec n) (y : Fin (2^n)) : + x < BitVec.ofFin y ↔ x.toFin < y := Iff.rfl +@[simp] theorem ofFin_lt (x : Fin (2^n)) (y : BitVec n) : + BitVec.ofFin x < y ↔ x < y.toFin := Iff.rfl +@[simp] theorem ofNat_lt_ofNat {n} (x y : Nat) : (x#n) < (y#n) ↔ x % 2^n < y % 2^n := by + simp [lt_def] + +protected theorem lt_of_le_ne (x y : BitVec n) (h1 : x <= y) (h2 : ¬ x = y) : x < y := by + revert h1 h2 + let ⟨x, lt⟩ := x + let ⟨y, lt⟩ := y + simp + exact Nat.lt_of_le_of_ne diff --git a/src/Init/Data/Fin.lean b/src/Init/Data/Fin.lean index b33d44e608..3b5a8ecc71 100644 --- a/src/Init/Data/Fin.lean +++ b/src/Init/Data/Fin.lean @@ -8,3 +8,4 @@ import Init.Data.Fin.Basic import Init.Data.Fin.Log2 import Init.Data.Fin.Iterate import Init.Data.Fin.Fold +import Init.Data.Fin.Lemmas diff --git a/src/Init/Data/Fin/Basic.lean b/src/Init/Data/Fin/Basic.lean index e5b44aab68..8ad2997ba8 100644 --- a/src/Init/Data/Fin/Basic.lean +++ b/src/Init/Data/Fin/Basic.lean @@ -5,7 +5,7 @@ Author: Leonardo de Moura, Robert Y. Lewis, Keeley Hoek, Mario Carneiro -/ prelude import Init.Data.Nat.Div -import Init.Data.Nat.Bitwise +import Init.Data.Nat.Bitwise.Basic import Init.Coe open Nat diff --git a/src/Init/Data/Fin/Lemmas.lean b/src/Init/Data/Fin/Lemmas.lean new file mode 100644 index 0000000000..63cecad719 --- /dev/null +++ b/src/Init/Data/Fin/Lemmas.lean @@ -0,0 +1,830 @@ +/- +Copyright (c) 2022 Mario Carneiro. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Mario Carneiro +-/ +prelude +import Init.Data.Fin.Basic +import Init.Data.Nat.Lemmas +import Init.Ext +import Init.ByCases +import Init.Conv +import Init.Omega + +namespace Fin + +/-- If you actually have an element of `Fin n`, then the `n` is always positive -/ +theorem size_pos (i : Fin n) : 0 < n := Nat.lt_of_le_of_lt (Nat.zero_le _) i.2 + +theorem mod_def (a m : Fin n) : a % m = Fin.mk (a % m) (Nat.lt_of_le_of_lt (Nat.mod_le _ _) a.2) := + rfl + +theorem mul_def (a b : Fin n) : a * b = Fin.mk ((a * b) % n) (Nat.mod_lt _ a.size_pos) := rfl + +theorem sub_def (a b : Fin n) : a - b = Fin.mk ((a + (n - b)) % n) (Nat.mod_lt _ a.size_pos) := rfl + +theorem size_pos' : ∀ [Nonempty (Fin n)], 0 < n | ⟨i⟩ => i.size_pos + +@[simp] theorem is_lt (a : Fin n) : (a : Nat) < n := a.2 + +theorem pos_iff_nonempty {n : Nat} : 0 < n ↔ Nonempty (Fin n) := + ⟨fun h => ⟨⟨0, h⟩⟩, fun ⟨i⟩ => i.pos⟩ + +/-! ### coercions and constructions -/ + +@[simp] protected theorem eta (a : Fin n) (h : a < n) : (⟨a, h⟩ : Fin n) = a := rfl + +@[ext] theorem ext {a b : Fin n} (h : (a : Nat) = b) : a = b := eq_of_val_eq h + +theorem val_inj {a b : Fin n} : a.1 = b.1 ↔ a = b := ⟨Fin.eq_of_val_eq, Fin.val_eq_of_eq⟩ + +theorem ext_iff {a b : Fin n} : a = b ↔ a.1 = b.1 := val_inj.symm + +theorem val_ne_iff {a b : Fin n} : a.1 ≠ b.1 ↔ a ≠ b := not_congr val_inj + +theorem exists_iff {p : Fin n → Prop} : (∃ i, p i) ↔ ∃ i h, p ⟨i, h⟩ := + ⟨fun ⟨⟨i, hi⟩, hpi⟩ => ⟨i, hi, hpi⟩, fun ⟨i, hi, hpi⟩ => ⟨⟨i, hi⟩, hpi⟩⟩ + +theorem forall_iff {p : Fin n → Prop} : (∀ i, p i) ↔ ∀ i h, p ⟨i, h⟩ := + ⟨fun h i hi => h ⟨i, hi⟩, fun h ⟨i, hi⟩ => h i hi⟩ + +protected theorem mk.inj_iff {n a b : Nat} {ha : a < n} {hb : b < n} : + (⟨a, ha⟩ : Fin n) = ⟨b, hb⟩ ↔ a = b := ext_iff + +theorem val_mk {m n : Nat} (h : m < n) : (⟨m, h⟩ : Fin n).val = m := rfl + +theorem eq_mk_iff_val_eq {a : Fin n} {k : Nat} {hk : k < n} : + a = ⟨k, hk⟩ ↔ (a : Nat) = k := ext_iff + +theorem mk_val (i : Fin n) : (⟨i, i.isLt⟩ : Fin n) = i := Fin.eta .. + +@[simp] theorem val_ofNat' (a : Nat) (is_pos : n > 0) : + (Fin.ofNat' a is_pos).val = a % n := rfl + +@[deprecated ofNat'_zero_val] theorem ofNat'_zero_val : (Fin.ofNat' 0 h).val = 0 := Nat.zero_mod _ + +@[simp] theorem mod_val (a b : Fin n) : (a % b).val = a.val % b.val := + rfl + +@[simp] theorem div_val (a b : Fin n) : (a / b).val = a.val / b.val := + rfl + +@[simp] theorem modn_val (a : Fin n) (b : Nat) : (a.modn b).val = a.val % b := + rfl + +theorem ite_val {n : Nat} {c : Prop} [Decidable c] {x : c → Fin n} (y : ¬c → Fin n) : + (if h : c then x h else y h).val = if h : c then (x h).val else (y h).val := by + by_cases c <;> simp [*] + +theorem dite_val {n : Nat} {c : Prop} [Decidable c] {x y : Fin n} : + (if c then x else y).val = if c then x.val else y.val := by + by_cases c <;> simp [*] + +/-! ### order -/ + +theorem le_def {a b : Fin n} : a ≤ b ↔ a.1 ≤ b.1 := .rfl + +theorem lt_def {a b : Fin n} : a < b ↔ a.1 < b.1 := .rfl + +theorem lt_iff_val_lt_val {a b : Fin n} : a < b ↔ a.val < b.val := Iff.rfl + +@[simp] protected theorem not_le {a b : Fin n} : ¬ a ≤ b ↔ b < a := Nat.not_le + +@[simp] protected theorem not_lt {a b : Fin n} : ¬ a < b ↔ b ≤ a := Nat.not_lt + +protected theorem ne_of_lt {a b : Fin n} (h : a < b) : a ≠ b := Fin.ne_of_val_ne (Nat.ne_of_lt h) + +protected theorem ne_of_gt {a b : Fin n} (h : a < b) : b ≠ a := Fin.ne_of_val_ne (Nat.ne_of_gt h) + +protected theorem le_of_lt {a b : Fin n} (h : a < b) : a ≤ b := Nat.le_of_lt h + +theorem is_le (i : Fin (n + 1)) : i ≤ n := Nat.le_of_lt_succ i.is_lt + +@[simp] theorem is_le' {a : Fin n} : a ≤ n := Nat.le_of_lt a.is_lt + +theorem mk_lt_of_lt_val {b : Fin n} {a : Nat} (h : a < b) : + (⟨a, Nat.lt_trans h b.is_lt⟩ : Fin n) < b := h + +theorem mk_le_of_le_val {b : Fin n} {a : Nat} (h : a ≤ b) : + (⟨a, Nat.lt_of_le_of_lt h b.is_lt⟩ : Fin n) ≤ b := h + +@[simp] theorem mk_le_mk {x y : Nat} {hx hy} : (⟨x, hx⟩ : Fin n) ≤ ⟨y, hy⟩ ↔ x ≤ y := .rfl + +@[simp] theorem mk_lt_mk {x y : Nat} {hx hy} : (⟨x, hx⟩ : Fin n) < ⟨y, hy⟩ ↔ x < y := .rfl + +@[simp] theorem val_zero (n : Nat) : (0 : Fin (n + 1)).1 = 0 := rfl + +@[simp] theorem mk_zero : (⟨0, Nat.succ_pos n⟩ : Fin (n + 1)) = 0 := rfl + +@[simp] theorem zero_le (a : Fin (n + 1)) : 0 ≤ a := Nat.zero_le a.val + +theorem zero_lt_one : (0 : Fin (n + 2)) < 1 := Nat.zero_lt_one + +@[simp] theorem not_lt_zero (a : Fin (n + 1)) : ¬a < 0 := nofun + +theorem pos_iff_ne_zero {a : Fin (n + 1)} : 0 < a ↔ a ≠ 0 := by + rw [lt_def, val_zero, Nat.pos_iff_ne_zero, ← val_ne_iff]; rfl + +theorem eq_zero_or_eq_succ {n : Nat} : ∀ i : Fin (n + 1), i = 0 ∨ ∃ j : Fin n, i = j.succ + | 0 => .inl rfl + | ⟨j + 1, h⟩ => .inr ⟨⟨j, Nat.lt_of_succ_lt_succ h⟩, rfl⟩ + +theorem eq_succ_of_ne_zero {n : Nat} {i : Fin (n + 1)} (hi : i ≠ 0) : ∃ j : Fin n, i = j.succ := + (eq_zero_or_eq_succ i).resolve_left hi + +@[simp] theorem val_rev (i : Fin n) : rev i = n - (i + 1) := rfl + +@[simp] theorem rev_rev (i : Fin n) : rev (rev i) = i := ext <| by + rw [val_rev, val_rev, ← Nat.sub_sub, Nat.sub_sub_self (by exact i.2), Nat.add_sub_cancel] + +@[simp] theorem rev_le_rev {i j : Fin n} : rev i ≤ rev j ↔ j ≤ i := by + simp only [le_def, val_rev, Nat.sub_le_sub_iff_left (Nat.succ_le.2 j.is_lt)] + exact Nat.succ_le_succ_iff + +@[simp] theorem rev_inj {i j : Fin n} : rev i = rev j ↔ i = j := + ⟨fun h => by simpa using congrArg rev h, congrArg _⟩ + +theorem rev_eq {n a : Nat} (i : Fin (n + 1)) (h : n = a + i) : + rev i = ⟨a, Nat.lt_succ_of_le (h ▸ Nat.le_add_right ..)⟩ := by + ext; dsimp + conv => lhs; congr; rw [h] + rw [Nat.add_assoc, Nat.add_sub_cancel] + +@[simp] theorem rev_lt_rev {i j : Fin n} : rev i < rev j ↔ j < i := by + rw [← Fin.not_le, ← Fin.not_le, rev_le_rev] + +@[simp] theorem val_last (n : Nat) : last n = n := rfl + +theorem le_last (i : Fin (n + 1)) : i ≤ last n := Nat.le_of_lt_succ i.is_lt + +theorem last_pos : (0 : Fin (n + 2)) < last (n + 1) := Nat.succ_pos _ + +theorem eq_last_of_not_lt {i : Fin (n + 1)} (h : ¬(i : Nat) < n) : i = last n := + ext <| Nat.le_antisymm (le_last i) (Nat.not_lt.1 h) + +theorem val_lt_last {i : Fin (n + 1)} : i ≠ last n → (i : Nat) < n := + Decidable.not_imp_comm.1 eq_last_of_not_lt + +@[simp] theorem rev_last (n : Nat) : rev (last n) = 0 := ext <| by simp + +@[simp] theorem rev_zero (n : Nat) : rev 0 = last n := by + rw [← rev_rev (last _), rev_last] + +/-! ### addition, numerals, and coercion from Nat -/ + +@[simp] theorem val_one (n : Nat) : (1 : Fin (n + 2)).val = 1 := rfl + +@[simp] theorem mk_one : (⟨1, Nat.succ_lt_succ (Nat.succ_pos n)⟩ : Fin (n + 2)) = (1 : Fin _) := rfl + +theorem subsingleton_iff_le_one : Subsingleton (Fin n) ↔ n ≤ 1 := by + (match n with | 0 | 1 | n+2 => ?_) <;> try simp + · exact ⟨nofun⟩ + · exact ⟨fun ⟨0, _⟩ ⟨0, _⟩ => rfl⟩ + · exact iff_of_false (fun h => Fin.ne_of_lt zero_lt_one (h.elim ..)) (of_decide_eq_false rfl) + +instance subsingleton_zero : Subsingleton (Fin 0) := subsingleton_iff_le_one.2 (by decide) + +instance subsingleton_one : Subsingleton (Fin 1) := subsingleton_iff_le_one.2 (by decide) + +theorem fin_one_eq_zero (a : Fin 1) : a = 0 := Subsingleton.elim a 0 + +theorem add_def (a b : Fin n) : a + b = Fin.mk ((a + b) % n) (Nat.mod_lt _ a.size_pos) := rfl + +theorem val_add (a b : Fin n) : (a + b).val = (a.val + b.val) % n := rfl + +theorem val_add_one_of_lt {n : Nat} {i : Fin n.succ} (h : i < last _) : (i + 1).1 = i + 1 := by + match n with + | 0 => cases h + | n+1 => rw [val_add, val_one, Nat.mod_eq_of_lt (by exact Nat.succ_lt_succ h)] + +@[simp] theorem last_add_one : ∀ n, last n + 1 = 0 + | 0 => rfl + | n + 1 => by ext; rw [val_add, val_zero, val_last, val_one, Nat.mod_self] + +theorem val_add_one {n : Nat} (i : Fin (n + 1)) : + ((i + 1 : Fin (n + 1)) : Nat) = if i = last _ then (0 : Nat) else i + 1 := by + match Nat.eq_or_lt_of_le (le_last i) with + | .inl h => cases Fin.eq_of_val_eq h; simp + | .inr h => simpa [Fin.ne_of_lt h] using val_add_one_of_lt h + +@[simp] theorem val_two {n : Nat} : (2 : Fin (n + 3)).val = 2 := rfl + +theorem add_one_pos (i : Fin (n + 1)) (h : i < Fin.last n) : (0 : Fin (n + 1)) < i + 1 := by + match n with + | 0 => cases h + | n+1 => + rw [Fin.lt_def, val_last, ← Nat.add_lt_add_iff_right] at h + rw [Fin.lt_def, val_add, val_zero, val_one, Nat.mod_eq_of_lt h] + exact Nat.zero_lt_succ _ + +theorem one_pos : (0 : Fin (n + 2)) < 1 := Nat.succ_pos 0 + +theorem zero_ne_one : (0 : Fin (n + 2)) ≠ 1 := Fin.ne_of_lt one_pos + +/-! ### succ and casts into larger Fin types -/ + +@[simp] theorem val_succ (j : Fin n) : (j.succ : Nat) = j + 1 := rfl + +@[simp] theorem succ_pos (a : Fin n) : (0 : Fin (n + 1)) < a.succ := by + simp [Fin.lt_def, Nat.succ_pos] + +@[simp] theorem succ_le_succ_iff {a b : Fin n} : a.succ ≤ b.succ ↔ a ≤ b := Nat.succ_le_succ_iff + +@[simp] theorem succ_lt_succ_iff {a b : Fin n} : a.succ < b.succ ↔ a < b := Nat.succ_lt_succ_iff + +@[simp] theorem succ_inj {a b : Fin n} : a.succ = b.succ ↔ a = b := by + refine ⟨fun h => ext ?_, congrArg _⟩ + apply Nat.le_antisymm <;> exact succ_le_succ_iff.1 (h ▸ Nat.le_refl _) + +theorem succ_ne_zero {n} : ∀ k : Fin n, Fin.succ k ≠ 0 + | ⟨k, _⟩, heq => Nat.succ_ne_zero k <| ext_iff.1 heq + +@[simp] theorem succ_zero_eq_one : Fin.succ (0 : Fin (n + 1)) = 1 := rfl + +/-- Version of `succ_one_eq_two` to be used by `dsimp` -/ +@[simp] theorem succ_one_eq_two : Fin.succ (1 : Fin (n + 2)) = 2 := rfl + +@[simp] theorem succ_mk (n i : Nat) (h : i < n) : + Fin.succ ⟨i, h⟩ = ⟨i + 1, Nat.succ_lt_succ h⟩ := rfl + +theorem mk_succ_pos (i : Nat) (h : i < n) : + (0 : Fin (n + 1)) < ⟨i.succ, Nat.add_lt_add_right h 1⟩ := by + rw [lt_def, val_zero]; exact Nat.succ_pos i + +theorem one_lt_succ_succ (a : Fin n) : (1 : Fin (n + 2)) < a.succ.succ := by + let n+1 := n + rw [← succ_zero_eq_one, succ_lt_succ_iff]; exact succ_pos a + +@[simp] theorem add_one_lt_iff {n : Nat} {k : Fin (n + 2)} : k + 1 < k ↔ k = last _ := by + simp only [lt_def, val_add, val_last, ext_iff] + let ⟨k, hk⟩ := k + match Nat.eq_or_lt_of_le (Nat.le_of_lt_succ hk) with + | .inl h => cases h; simp [Nat.succ_pos] + | .inr hk' => simp [Nat.ne_of_lt hk', Nat.mod_eq_of_lt (Nat.succ_lt_succ hk'), Nat.le_succ] + +@[simp] theorem add_one_le_iff {n : Nat} : ∀ {k : Fin (n + 1)}, k + 1 ≤ k ↔ k = last _ := by + match n with + | 0 => + intro (k : Fin 1) + exact iff_of_true (Subsingleton.elim (α := Fin 1) (k+1) _ ▸ Nat.le_refl _) (fin_one_eq_zero ..) + | n + 1 => + intro (k : Fin (n+2)) + rw [← add_one_lt_iff, lt_def, le_def, Nat.lt_iff_le_and_ne, and_iff_left] + rw [val_add_one] + split <;> simp [*, (Nat.succ_ne_zero _).symm, Nat.ne_of_gt (Nat.lt_succ_self _)] + +@[simp] theorem last_le_iff {n : Nat} {k : Fin (n + 1)} : last n ≤ k ↔ k = last n := by + rw [ext_iff, Nat.le_antisymm_iff, le_def, and_iff_right (by apply le_last)] + +@[simp] theorem lt_add_one_iff {n : Nat} {k : Fin (n + 1)} : k < k + 1 ↔ k < last n := by + rw [← Decidable.not_iff_not]; simp + +@[simp] theorem le_zero_iff {n : Nat} {k : Fin (n + 1)} : k ≤ 0 ↔ k = 0 := + ⟨fun h => Fin.eq_of_val_eq <| Nat.eq_zero_of_le_zero h, (· ▸ Nat.le_refl _)⟩ + +theorem succ_succ_ne_one (a : Fin n) : Fin.succ (Fin.succ a) ≠ 1 := + Fin.ne_of_gt (one_lt_succ_succ a) + +@[simp] theorem coe_castLT (i : Fin m) (h : i.1 < n) : (castLT i h : Nat) = i := rfl + +@[simp] theorem castLT_mk (i n m : Nat) (hn : i < n) (hm : i < m) : castLT ⟨i, hn⟩ hm = ⟨i, hm⟩ := + rfl + +@[simp] theorem coe_castLE (h : n ≤ m) (i : Fin n) : (castLE h i : Nat) = i := rfl + +@[simp] theorem castLE_mk (i n m : Nat) (hn : i < n) (h : n ≤ m) : + castLE h ⟨i, hn⟩ = ⟨i, Nat.lt_of_lt_of_le hn h⟩ := rfl + +@[simp] theorem castLE_zero {n m : Nat} (h : n.succ ≤ m.succ) : castLE h 0 = 0 := by simp [ext_iff] + +@[simp] theorem castLE_succ {m n : Nat} (h : m + 1 ≤ n + 1) (i : Fin m) : + castLE h i.succ = (castLE (Nat.succ_le_succ_iff.mp h) i).succ := by simp [ext_iff] + +@[simp] theorem castLE_castLE {k m n} (km : k ≤ m) (mn : m ≤ n) (i : Fin k) : + Fin.castLE mn (Fin.castLE km i) = Fin.castLE (Nat.le_trans km mn) i := + Fin.ext (by simp only [coe_castLE]) + +@[simp] theorem castLE_comp_castLE {k m n} (km : k ≤ m) (mn : m ≤ n) : + Fin.castLE mn ∘ Fin.castLE km = Fin.castLE (Nat.le_trans km mn) := + funext (castLE_castLE km mn) + +@[simp] theorem coe_cast (h : n = m) (i : Fin n) : (cast h i : Nat) = i := rfl + +@[simp] theorem cast_last {n' : Nat} {h : n + 1 = n' + 1} : cast h (last n) = last n' := + ext (by rw [coe_cast, val_last, val_last, Nat.succ.inj h]) + +@[simp] theorem cast_mk (h : n = m) (i : Nat) (hn : i < n) : cast h ⟨i, hn⟩ = ⟨i, h ▸ hn⟩ := rfl + +@[simp] theorem cast_trans {k : Nat} (h : n = m) (h' : m = k) {i : Fin n} : + cast h' (cast h i) = cast (Eq.trans h h') i := rfl + +theorem castLE_of_eq {m n : Nat} (h : m = n) {h' : m ≤ n} : castLE h' = Fin.cast h := rfl + +@[simp] theorem coe_castAdd (m : Nat) (i : Fin n) : (castAdd m i : Nat) = i := rfl + +@[simp] theorem castAdd_zero : (castAdd 0 : Fin n → Fin (n + 0)) = cast rfl := rfl + +theorem castAdd_lt {m : Nat} (n : Nat) (i : Fin m) : (castAdd n i : Nat) < m := by simp + +@[simp] theorem castAdd_mk (m : Nat) (i : Nat) (h : i < n) : + castAdd m ⟨i, h⟩ = ⟨i, Nat.lt_add_right m h⟩ := rfl + +@[simp] theorem castAdd_castLT (m : Nat) (i : Fin (n + m)) (hi : i.val < n) : + castAdd m (castLT i hi) = i := rfl + +@[simp] theorem castLT_castAdd (m : Nat) (i : Fin n) : + castLT (castAdd m i) (castAdd_lt m i) = i := rfl + +/-- For rewriting in the reverse direction, see `Fin.cast_castAdd_left`. -/ +theorem castAdd_cast {n n' : Nat} (m : Nat) (i : Fin n') (h : n' = n) : + castAdd m (Fin.cast h i) = Fin.cast (congrArg (. + m) h) (castAdd m i) := ext rfl + +theorem cast_castAdd_left {n n' m : Nat} (i : Fin n') (h : n' + m = n + m) : + cast h (castAdd m i) = castAdd m (cast (Nat.add_right_cancel h) i) := rfl + +@[simp] theorem cast_castAdd_right {n m m' : Nat} (i : Fin n) (h : n + m' = n + m) : + cast h (castAdd m' i) = castAdd m i := rfl + +theorem castAdd_castAdd {m n p : Nat} (i : Fin m) : + castAdd p (castAdd n i) = cast (Nat.add_assoc ..).symm (castAdd (n + p) i) := rfl + +/-- The cast of the successor is the successor of the cast. See `Fin.succ_cast_eq` for rewriting in +the reverse direction. -/ +@[simp] theorem cast_succ_eq {n' : Nat} (i : Fin n) (h : n.succ = n'.succ) : + cast h i.succ = (cast (Nat.succ.inj h) i).succ := rfl + +theorem succ_cast_eq {n' : Nat} (i : Fin n) (h : n = n') : + (cast h i).succ = cast (by rw [h]) i.succ := rfl + +@[simp] theorem coe_castSucc (i : Fin n) : (Fin.castSucc i : Nat) = i := rfl + +@[simp] theorem castSucc_mk (n i : Nat) (h : i < n) : castSucc ⟨i, h⟩ = ⟨i, Nat.lt.step h⟩ := rfl + +@[simp] theorem cast_castSucc {n' : Nat} {h : n + 1 = n' + 1} {i : Fin n} : + cast h (castSucc i) = castSucc (cast (Nat.succ.inj h) i) := rfl + +theorem castSucc_lt_succ (i : Fin n) : Fin.castSucc i < i.succ := + lt_def.2 <| by simp only [coe_castSucc, val_succ, Nat.lt_succ_self] + +theorem le_castSucc_iff {i : Fin (n + 1)} {j : Fin n} : i ≤ Fin.castSucc j ↔ i < j.succ := by + simpa [lt_def, le_def] using Nat.succ_le_succ_iff.symm + +theorem castSucc_lt_iff_succ_le {n : Nat} {i : Fin n} {j : Fin (n + 1)} : + Fin.castSucc i < j ↔ i.succ ≤ j := .rfl + +@[simp] theorem succ_last (n : Nat) : (last n).succ = last n.succ := rfl + +@[simp] theorem succ_eq_last_succ {n : Nat} (i : Fin n.succ) : + i.succ = last (n + 1) ↔ i = last n := by rw [← succ_last, succ_inj] + +@[simp] theorem castSucc_castLT (i : Fin (n + 1)) (h : (i : Nat) < n) : + castSucc (castLT i h) = i := rfl + +@[simp] theorem castLT_castSucc {n : Nat} (a : Fin n) (h : (a : Nat) < n) : + castLT (castSucc a) h = a := rfl + +@[simp] theorem castSucc_lt_castSucc_iff {a b : Fin n} : + Fin.castSucc a < Fin.castSucc b ↔ a < b := .rfl + +theorem castSucc_inj {a b : Fin n} : castSucc a = castSucc b ↔ a = b := by simp [ext_iff] + +theorem castSucc_lt_last (a : Fin n) : castSucc a < last n := a.is_lt + +@[simp] theorem castSucc_zero : castSucc (0 : Fin (n + 1)) = 0 := rfl + +@[simp] theorem castSucc_one {n : Nat} : castSucc (1 : Fin (n + 2)) = 1 := rfl + +/-- `castSucc i` is positive when `i` is positive -/ +theorem castSucc_pos {i : Fin (n + 1)} (h : 0 < i) : 0 < castSucc i := by + simpa [lt_def] using h + +@[simp] theorem castSucc_eq_zero_iff (a : Fin (n + 1)) : castSucc a = 0 ↔ a = 0 := by simp [ext_iff] + +theorem castSucc_ne_zero_iff (a : Fin (n + 1)) : castSucc a ≠ 0 ↔ a ≠ 0 := + not_congr <| castSucc_eq_zero_iff a + +theorem castSucc_fin_succ (n : Nat) (j : Fin n) : + castSucc (Fin.succ j) = Fin.succ (castSucc j) := by simp [Fin.ext_iff] + +@[simp] +theorem coeSucc_eq_succ {a : Fin n} : castSucc a + 1 = a.succ := by + cases n + · exact a.elim0 + · simp [ext_iff, add_def, Nat.mod_eq_of_lt (Nat.succ_lt_succ a.is_lt)] + +theorem lt_succ {a : Fin n} : castSucc a < a.succ := by + rw [castSucc, lt_def, coe_castAdd, val_succ]; exact Nat.lt_succ_self a.val + +theorem exists_castSucc_eq {n : Nat} {i : Fin (n + 1)} : (∃ j, castSucc j = i) ↔ i ≠ last n := + ⟨fun ⟨j, hj⟩ => hj ▸ Fin.ne_of_lt j.castSucc_lt_last, + fun hi => ⟨i.castLT <| Fin.val_lt_last hi, rfl⟩⟩ + +theorem succ_castSucc {n : Nat} (i : Fin n) : i.castSucc.succ = castSucc i.succ := rfl + +@[simp] theorem coe_addNat (m : Nat) (i : Fin n) : (addNat i m : Nat) = i + m := rfl + +@[simp] theorem addNat_one {i : Fin n} : addNat i 1 = i.succ := rfl + +theorem le_coe_addNat (m : Nat) (i : Fin n) : m ≤ addNat i m := + Nat.le_add_left _ _ + +@[simp] theorem addNat_mk (n i : Nat) (hi : i < m) : + addNat ⟨i, hi⟩ n = ⟨i + n, Nat.add_lt_add_right hi n⟩ := rfl + +@[simp] theorem cast_addNat_zero {n n' : Nat} (i : Fin n) (h : n + 0 = n') : + cast h (addNat i 0) = cast ((Nat.add_zero _).symm.trans h) i := rfl + +/-- For rewriting in the reverse direction, see `Fin.cast_addNat_left`. -/ +theorem addNat_cast {n n' m : Nat} (i : Fin n') (h : n' = n) : + addNat (cast h i) m = cast (congrArg (. + m) h) (addNat i m) := rfl + +theorem cast_addNat_left {n n' m : Nat} (i : Fin n') (h : n' + m = n + m) : + cast h (addNat i m) = addNat (cast (Nat.add_right_cancel h) i) m := rfl + +@[simp] theorem cast_addNat_right {n m m' : Nat} (i : Fin n) (h : n + m' = n + m) : + cast h (addNat i m') = addNat i m := + ext <| (congrArg ((· + ·) (i : Nat)) (Nat.add_left_cancel h) : _) + +@[simp] theorem coe_natAdd (n : Nat) {m : Nat} (i : Fin m) : (natAdd n i : Nat) = n + i := rfl + +@[simp] theorem natAdd_mk (n i : Nat) (hi : i < m) : + natAdd n ⟨i, hi⟩ = ⟨n + i, Nat.add_lt_add_left hi n⟩ := rfl + +theorem le_coe_natAdd (m : Nat) (i : Fin n) : m ≤ natAdd m i := Nat.le_add_right .. + +theorem natAdd_zero {n : Nat} : natAdd 0 = cast (Nat.zero_add n).symm := by ext; simp + +/-- For rewriting in the reverse direction, see `Fin.cast_natAdd_right`. -/ +theorem natAdd_cast {n n' : Nat} (m : Nat) (i : Fin n') (h : n' = n) : + natAdd m (cast h i) = cast (congrArg _ h) (natAdd m i) := rfl + +theorem cast_natAdd_right {n n' m : Nat} (i : Fin n') (h : m + n' = m + n) : + cast h (natAdd m i) = natAdd m (cast (Nat.add_left_cancel h) i) := rfl + +@[simp] theorem cast_natAdd_left {n m m' : Nat} (i : Fin n) (h : m' + n = m + n) : + cast h (natAdd m' i) = natAdd m i := + ext <| (congrArg (· + (i : Nat)) (Nat.add_right_cancel h) : _) + +theorem castAdd_natAdd (p m : Nat) {n : Nat} (i : Fin n) : + castAdd p (natAdd m i) = cast (Nat.add_assoc ..).symm (natAdd m (castAdd p i)) := rfl + +theorem natAdd_castAdd (p m : Nat) {n : Nat} (i : Fin n) : + natAdd m (castAdd p i) = cast (Nat.add_assoc ..) (castAdd p (natAdd m i)) := rfl + +theorem natAdd_natAdd (m n : Nat) {p : Nat} (i : Fin p) : + natAdd m (natAdd n i) = cast (Nat.add_assoc ..) (natAdd (m + n) i) := + ext <| (Nat.add_assoc ..).symm + +@[simp] +theorem cast_natAdd_zero {n n' : Nat} (i : Fin n) (h : 0 + n = n') : + cast h (natAdd 0 i) = cast ((Nat.zero_add _).symm.trans h) i := + ext <| Nat.zero_add _ + +@[simp] +theorem cast_natAdd (n : Nat) {m : Nat} (i : Fin m) : + cast (Nat.add_comm ..) (natAdd n i) = addNat i n := ext <| Nat.add_comm .. + +@[simp] +theorem cast_addNat {n : Nat} (m : Nat) (i : Fin n) : + cast (Nat.add_comm ..) (addNat i m) = natAdd m i := ext <| Nat.add_comm .. + +@[simp] theorem natAdd_last {m n : Nat} : natAdd n (last m) = last (n + m) := rfl + +theorem natAdd_castSucc {m n : Nat} {i : Fin m} : natAdd n (castSucc i) = castSucc (natAdd n i) := + rfl + +theorem rev_castAdd (k : Fin n) (m : Nat) : rev (castAdd m k) = addNat (rev k) m := ext <| by + rw [val_rev, coe_castAdd, coe_addNat, val_rev, Nat.sub_add_comm (Nat.succ_le_of_lt k.is_lt)] + +theorem rev_addNat (k : Fin n) (m : Nat) : rev (addNat k m) = castAdd m (rev k) := by + rw [← rev_rev (castAdd ..), rev_castAdd, rev_rev] + +theorem rev_castSucc (k : Fin n) : rev (castSucc k) = succ (rev k) := k.rev_castAdd 1 + +theorem rev_succ (k : Fin n) : rev (succ k) = castSucc (rev k) := k.rev_addNat 1 + +/-! ### pred -/ + +@[simp] theorem coe_pred (j : Fin (n + 1)) (h : j ≠ 0) : (j.pred h : Nat) = j - 1 := rfl + +@[simp] theorem succ_pred : ∀ (i : Fin (n + 1)) (h : i ≠ 0), (i.pred h).succ = i + | ⟨0, h⟩, hi => by simp only [mk_zero, ne_eq, not_true] at hi + | ⟨n + 1, h⟩, hi => rfl + +@[simp] +theorem pred_succ (i : Fin n) {h : i.succ ≠ 0} : i.succ.pred h = i := by + cases i + rfl + +theorem pred_eq_iff_eq_succ {n : Nat} (i : Fin (n + 1)) (hi : i ≠ 0) (j : Fin n) : + i.pred hi = j ↔ i = j.succ := + ⟨fun h => by simp only [← h, Fin.succ_pred], fun h => by simp only [h, Fin.pred_succ]⟩ + +theorem pred_mk_succ (i : Nat) (h : i < n + 1) : + Fin.pred ⟨i + 1, Nat.add_lt_add_right h 1⟩ (ne_of_val_ne (Nat.ne_of_gt (mk_succ_pos i h))) = + ⟨i, h⟩ := by + simp only [ext_iff, coe_pred, Nat.add_sub_cancel] + +@[simp] theorem pred_mk_succ' (i : Nat) (h₁ : i + 1 < n + 1 + 1) (h₂) : + Fin.pred ⟨i + 1, h₁⟩ h₂ = ⟨i, Nat.lt_of_succ_lt_succ h₁⟩ := pred_mk_succ i _ + +-- This is not a simp theorem by default, because `pred_mk_succ` is nicer when it applies. +theorem pred_mk {n : Nat} (i : Nat) (h : i < n + 1) (w) : Fin.pred ⟨i, h⟩ w = + ⟨i - 1, Nat.sub_lt_right_of_lt_add (Nat.pos_iff_ne_zero.2 (Fin.val_ne_of_ne w)) h⟩ := + rfl + +@[simp] theorem pred_le_pred_iff {n : Nat} {a b : Fin n.succ} {ha : a ≠ 0} {hb : b ≠ 0} : + a.pred ha ≤ b.pred hb ↔ a ≤ b := by rw [← succ_le_succ_iff, succ_pred, succ_pred] + +@[simp] theorem pred_lt_pred_iff {n : Nat} {a b : Fin n.succ} {ha : a ≠ 0} {hb : b ≠ 0} : + a.pred ha < b.pred hb ↔ a < b := by rw [← succ_lt_succ_iff, succ_pred, succ_pred] + +@[simp] theorem pred_inj : + ∀ {a b : Fin (n + 1)} {ha : a ≠ 0} {hb : b ≠ 0}, a.pred ha = b.pred hb ↔ a = b + | ⟨0, _⟩, _, ha, _ => by simp only [mk_zero, ne_eq, not_true] at ha + | ⟨i + 1, _⟩, ⟨0, _⟩, _, hb => by simp only [mk_zero, ne_eq, not_true] at hb + | ⟨i + 1, hi⟩, ⟨j + 1, hj⟩, ha, hb => by simp [ext_iff] + +@[simp] theorem pred_one {n : Nat} : + Fin.pred (1 : Fin (n + 2)) (Ne.symm (Fin.ne_of_lt one_pos)) = 0 := rfl + +theorem pred_add_one (i : Fin (n + 2)) (h : (i : Nat) < n + 1) : + pred (i + 1) (Fin.ne_of_gt (add_one_pos _ (lt_def.2 h))) = castLT i h := by + rw [ext_iff, coe_pred, coe_castLT, val_add, val_one, Nat.mod_eq_of_lt, Nat.add_sub_cancel] + exact Nat.add_lt_add_right h 1 + +@[simp] theorem coe_subNat (i : Fin (n + m)) (h : m ≤ i) : (i.subNat m h : Nat) = i - m := rfl + +@[simp] theorem subNat_mk {i : Nat} (h₁ : i < n + m) (h₂ : m ≤ i) : + subNat m ⟨i, h₁⟩ h₂ = ⟨i - m, Nat.sub_lt_right_of_lt_add h₂ h₁⟩ := rfl + +@[simp] theorem pred_castSucc_succ (i : Fin n) : + pred (castSucc i.succ) (Fin.ne_of_gt (castSucc_pos i.succ_pos)) = castSucc i := rfl + +@[simp] theorem addNat_subNat {i : Fin (n + m)} (h : m ≤ i) : addNat (subNat m i h) m = i := + ext <| Nat.sub_add_cancel h + +@[simp] theorem subNat_addNat (i : Fin n) (m : Nat) (h : m ≤ addNat i m := le_coe_addNat m i) : + subNat m (addNat i m) h = i := ext <| Nat.add_sub_cancel i m + +@[simp] theorem natAdd_subNat_cast {i : Fin (n + m)} (h : n ≤ i) : + natAdd n (subNat n (cast (Nat.add_comm ..) i) h) = i := by simp [← cast_addNat]; rfl + +/-! ### recursion and induction principles -/ + +/-- Define `motive n i` by induction on `i : Fin n` interpreted as `(0 : Fin (n - i)).succ.succ…`. +This function has two arguments: `zero n` defines `0`-th element `motive (n+1) 0` of an +`(n+1)`-tuple, and `succ n i` defines `(i+1)`-st element of `(n+1)`-tuple based on `n`, `i`, and +`i`-th element of `n`-tuple. -/ +-- FIXME: Performance review +@[elab_as_elim] def succRec {motive : ∀ n, Fin n → Sort _} + (zero : ∀ n, motive n.succ (0 : Fin (n + 1))) + (succ : ∀ n i, motive n i → motive n.succ i.succ) : ∀ {n : Nat} (i : Fin n), motive n i + | 0, i => i.elim0 + | Nat.succ n, ⟨0, _⟩ => by rw [mk_zero]; exact zero n + | Nat.succ _, ⟨Nat.succ i, h⟩ => succ _ _ (succRec zero succ ⟨i, Nat.lt_of_succ_lt_succ h⟩) + +/-- Define `motive n i` by induction on `i : Fin n` interpreted as `(0 : Fin (n - i)).succ.succ…`. +This function has two arguments: +`zero n` defines the `0`-th element `motive (n+1) 0` of an `(n+1)`-tuple, and +`succ n i` defines the `(i+1)`-st element of an `(n+1)`-tuple based on `n`, `i`, +and the `i`-th element of an `n`-tuple. + +A version of `Fin.succRec` taking `i : Fin n` as the first argument. -/ +-- FIXME: Performance review +@[elab_as_elim] def succRecOn {n : Nat} (i : Fin n) {motive : ∀ n, Fin n → Sort _} + (zero : ∀ n, motive (n + 1) 0) (succ : ∀ n i, motive n i → motive (Nat.succ n) i.succ) : + motive n i := i.succRec zero succ + +@[simp] theorem succRecOn_zero {motive : ∀ n, Fin n → Sort _} {zero succ} (n) : + @Fin.succRecOn (n + 1) 0 motive zero succ = zero n := by + cases n <;> rfl + +@[simp] theorem succRecOn_succ {motive : ∀ n, Fin n → Sort _} {zero succ} {n} (i : Fin n) : + @Fin.succRecOn (n + 1) i.succ motive zero succ = succ n i (Fin.succRecOn i zero succ) := by + cases i; rfl + +/-- Define `motive i` by induction on `i : Fin (n + 1)` via induction on the underlying `Nat` value. +This function has two arguments: `zero` handles the base case on `motive 0`, +and `succ` defines the inductive step using `motive i.castSucc`. +-/ +-- FIXME: Performance review +@[elab_as_elim] def induction {motive : Fin (n + 1) → Sort _} (zero : motive 0) + (succ : ∀ i : Fin n, motive (castSucc i) → motive i.succ) : + ∀ i : Fin (n + 1), motive i + | ⟨0, hi⟩ => by rwa [Fin.mk_zero] + | ⟨i+1, hi⟩ => succ ⟨i, Nat.lt_of_succ_lt_succ hi⟩ (induction zero succ ⟨i, Nat.lt_of_succ_lt hi⟩) + +@[simp] theorem induction_zero {motive : Fin (n + 1) → Sort _} (zero : motive 0) + (hs : ∀ i : Fin n, motive (castSucc i) → motive i.succ) : + (induction zero hs : ∀ i : Fin (n + 1), motive i) 0 = zero := rfl + +@[simp] theorem induction_succ {motive : Fin (n + 1) → Sort _} (zero : motive 0) + (succ : ∀ i : Fin n, motive (castSucc i) → motive i.succ) (i : Fin n) : + induction (motive := motive) zero succ i.succ = succ i (induction zero succ (castSucc i)) := rfl + +/-- Define `motive i` by induction on `i : Fin (n + 1)` via induction on the underlying `Nat` value. +This function has two arguments: `zero` handles the base case on `motive 0`, +and `succ` defines the inductive step using `motive i.castSucc`. + +A version of `Fin.induction` taking `i : Fin (n + 1)` as the first argument. +-/ +-- FIXME: Performance review +@[elab_as_elim] def inductionOn (i : Fin (n + 1)) {motive : Fin (n + 1) → Sort _} (zero : motive 0) + (succ : ∀ i : Fin n, motive (castSucc i) → motive i.succ) : motive i := induction zero succ i + +/-- Define `f : Π i : Fin n.succ, motive i` by separately handling the cases `i = 0` and +`i = j.succ`, `j : Fin n`. -/ +@[elab_as_elim] def cases {motive : Fin (n + 1) → Sort _} + (zero : motive 0) (succ : ∀ i : Fin n, motive i.succ) : + ∀ i : Fin (n + 1), motive i := induction zero fun i _ => succ i + +@[simp] theorem cases_zero {n} {motive : Fin (n + 1) → Sort _} {zero succ} : + @Fin.cases n motive zero succ 0 = zero := rfl + +@[simp] theorem cases_succ {n} {motive : Fin (n + 1) → Sort _} {zero succ} (i : Fin n) : + @Fin.cases n motive zero succ i.succ = succ i := rfl + +@[simp] theorem cases_succ' {n} {motive : Fin (n + 1) → Sort _} {zero succ} + {i : Nat} (h : i + 1 < n + 1) : + @Fin.cases n motive zero succ ⟨i.succ, h⟩ = succ ⟨i, Nat.lt_of_succ_lt_succ h⟩ := rfl + +theorem forall_fin_succ {P : Fin (n + 1) → Prop} : (∀ i, P i) ↔ P 0 ∧ ∀ i : Fin n, P i.succ := + ⟨fun H => ⟨H 0, fun _ => H _⟩, fun ⟨H0, H1⟩ i => Fin.cases H0 H1 i⟩ + +theorem exists_fin_succ {P : Fin (n + 1) → Prop} : (∃ i, P i) ↔ P 0 ∨ ∃ i : Fin n, P i.succ := + ⟨fun ⟨i, h⟩ => Fin.cases Or.inl (fun i hi => Or.inr ⟨i, hi⟩) i h, fun h => + (h.elim fun h => ⟨0, h⟩) fun ⟨i, hi⟩ => ⟨i.succ, hi⟩⟩ + +theorem forall_fin_one {p : Fin 1 → Prop} : (∀ i, p i) ↔ p 0 := + ⟨fun h => h _, fun h i => Subsingleton.elim i 0 ▸ h⟩ + +theorem exists_fin_one {p : Fin 1 → Prop} : (∃ i, p i) ↔ p 0 := + ⟨fun ⟨i, h⟩ => Subsingleton.elim i 0 ▸ h, fun h => ⟨_, h⟩⟩ + +theorem forall_fin_two {p : Fin 2 → Prop} : (∀ i, p i) ↔ p 0 ∧ p 1 := + forall_fin_succ.trans <| and_congr_right fun _ => forall_fin_one + +theorem exists_fin_two {p : Fin 2 → Prop} : (∃ i, p i) ↔ p 0 ∨ p 1 := + exists_fin_succ.trans <| or_congr_right exists_fin_one + +theorem fin_two_eq_of_eq_zero_iff : ∀ {a b : Fin 2}, (a = 0 ↔ b = 0) → a = b := by + simp only [forall_fin_two]; decide + +/-- +Define `motive i` by reverse induction on `i : Fin (n + 1)` via induction on the underlying `Nat` +value. This function has two arguments: `last` handles the base case on `motive (Fin.last n)`, +and `cast` defines the inductive step using `motive i.succ`, inducting downwards. +-/ +@[elab_as_elim] def reverseInduction {motive : Fin (n + 1) → Sort _} (last : motive (Fin.last n)) + (cast : ∀ i : Fin n, motive i.succ → motive (castSucc i)) (i : Fin (n + 1)) : motive i := + if hi : i = Fin.last n then _root_.cast (congrArg motive hi.symm) last + else + let j : Fin n := ⟨i, Nat.lt_of_le_of_ne (Nat.le_of_lt_succ i.2) fun h => hi (Fin.ext h)⟩ + cast _ (reverseInduction last cast j.succ) +termination_by n + 1 - i +decreasing_by decreasing_with + -- FIXME: we put the proof down here to avoid getting a dummy `have` in the definition + exact Nat.add_sub_add_right .. ▸ Nat.sub_lt_sub_left i.2 (Nat.lt_succ_self i) + +@[simp] theorem reverseInduction_last {n : Nat} {motive : Fin (n + 1) → Sort _} {zero succ} : + (reverseInduction zero succ (Fin.last n) : motive (Fin.last n)) = zero := by + rw [reverseInduction]; simp; rfl + +@[simp] theorem reverseInduction_castSucc {n : Nat} {motive : Fin (n + 1) → Sort _} {zero succ} + (i : Fin n) : reverseInduction (motive := motive) zero succ (castSucc i) = + succ i (reverseInduction zero succ i.succ) := by + rw [reverseInduction, dif_neg (Fin.ne_of_lt (Fin.castSucc_lt_last i))]; rfl + +/-- Define `f : Π i : Fin n.succ, motive i` by separately handling the cases `i = Fin.last n` and +`i = j.castSucc`, `j : Fin n`. -/ +@[elab_as_elim] def lastCases {n : Nat} {motive : Fin (n + 1) → Sort _} (last : motive (Fin.last n)) + (cast : ∀ i : Fin n, motive (castSucc i)) (i : Fin (n + 1)) : motive i := + reverseInduction last (fun i _ => cast i) i + +@[simp] theorem lastCases_last {n : Nat} {motive : Fin (n + 1) → Sort _} {last cast} : + (Fin.lastCases last cast (Fin.last n) : motive (Fin.last n)) = last := + reverseInduction_last .. + +@[simp] theorem lastCases_castSucc {n : Nat} {motive : Fin (n + 1) → Sort _} {last cast} + (i : Fin n) : (Fin.lastCases last cast (Fin.castSucc i) : motive (Fin.castSucc i)) = cast i := + reverseInduction_castSucc .. + +/-- Define `f : Π i : Fin (m + n), motive i` by separately handling the cases `i = castAdd n i`, +`j : Fin m` and `i = natAdd m j`, `j : Fin n`. -/ +@[elab_as_elim] def addCases {m n : Nat} {motive : Fin (m + n) → Sort u} + (left : ∀ i, motive (castAdd n i)) (right : ∀ i, motive (natAdd m i)) + (i : Fin (m + n)) : motive i := + if hi : (i : Nat) < m then (castAdd_castLT n i hi) ▸ (left (castLT i hi)) + else (natAdd_subNat_cast (Nat.le_of_not_lt hi)) ▸ (right _) + +@[simp] theorem addCases_left {m n : Nat} {motive : Fin (m + n) → Sort _} {left right} (i : Fin m) : + addCases (motive := motive) left right (Fin.castAdd n i) = left i := by + rw [addCases, dif_pos (castAdd_lt _ _)]; rfl + +@[simp] +theorem addCases_right {m n : Nat} {motive : Fin (m + n) → Sort _} {left right} (i : Fin n) : + addCases (motive := motive) left right (natAdd m i) = right i := by + have : ¬(natAdd m i : Nat) < m := Nat.not_lt.2 (le_coe_natAdd ..) + rw [addCases, dif_neg this]; exact eq_of_heq <| (eqRec_heq _ _).trans (by congr 1; simp) + +/-! ### add -/ + +@[simp] theorem ofNat'_add (x : Nat) (lt : 0 < n) (y : Fin n) : + Fin.ofNat' x lt + y = Fin.ofNat' (x + y.val) lt := by + apply Fin.eq_of_val_eq + simp [Fin.ofNat', Fin.add_def] + +@[simp] theorem add_ofNat' (x : Fin n) (y : Nat) (lt : 0 < n) : + x + Fin.ofNat' y lt = Fin.ofNat' (x.val + y) lt := by + apply Fin.eq_of_val_eq + simp [Fin.ofNat', Fin.add_def] + +/-! ### sub -/ + +protected theorem coe_sub (a b : Fin n) : ((a - b : Fin n) : Nat) = (a + (n - b)) % n := by + cases a; cases b; rfl + +@[simp] theorem ofNat'_sub (x : Nat) (lt : 0 < n) (y : Fin n) : + Fin.ofNat' x lt - y = Fin.ofNat' (x + (n - y.val)) lt := by + apply Fin.eq_of_val_eq + simp [Fin.ofNat', Fin.sub_def] + +@[simp] theorem sub_ofNat' (x : Fin n) (y : Nat) (lt : 0 < n) : + x - Fin.ofNat' y lt = Fin.ofNat' (x.val + (n - y % n)) lt := by + apply Fin.eq_of_val_eq + simp [Fin.ofNat', Fin.sub_def] + +private theorem _root_.Nat.mod_eq_sub_of_lt_two_mul {x n} (h₁ : n ≤ x) (h₂ : x < 2 * n) : + x % n = x - n := by + rw [Nat.mod_eq, if_pos (by omega), Nat.mod_eq_of_lt (by omega)] + +theorem coe_sub_iff_le {a b : Fin n} : (↑(a - b) : Nat) = a - b ↔ b ≤ a := by + rw [sub_def, le_def] + dsimp only + if h : n ≤ a + (n - b) then + rw [Nat.mod_eq_sub_of_lt_two_mul h] + all_goals omega + else + rw [Nat.mod_eq_of_lt] + all_goals omega + +theorem coe_sub_iff_lt {a b : Fin n} : (↑(a - b) : Nat) = n + a - b ↔ a < b := by + rw [sub_def, lt_def] + dsimp only + if h : n ≤ a + (n - b) then + rw [Nat.mod_eq_sub_of_lt_two_mul h] + all_goals omega + else + rw [Nat.mod_eq_of_lt] + all_goals omega + +/-! ### mul -/ + +theorem val_mul {n : Nat} : ∀ a b : Fin n, (a * b).val = a.val * b.val % n + | ⟨_, _⟩, ⟨_, _⟩ => rfl + +theorem coe_mul {n : Nat} : ∀ a b : Fin n, ((a * b : Fin n) : Nat) = a * b % n + | ⟨_, _⟩, ⟨_, _⟩ => rfl + +protected theorem mul_one (k : Fin (n + 1)) : k * 1 = k := by + match n with + | 0 => exact Subsingleton.elim (α := Fin 1) .. + | n+1 => simp [ext_iff, mul_def, Nat.mod_eq_of_lt (is_lt k)] + +protected theorem mul_comm (a b : Fin n) : a * b = b * a := + ext <| by rw [mul_def, mul_def, Nat.mul_comm] + +protected theorem one_mul (k : Fin (n + 1)) : (1 : Fin (n + 1)) * k = k := by + rw [Fin.mul_comm, Fin.mul_one] + +protected theorem mul_zero (k : Fin (n + 1)) : k * 0 = 0 := by simp [ext_iff, mul_def] + +protected theorem zero_mul (k : Fin (n + 1)) : (0 : Fin (n + 1)) * k = 0 := by + simp [ext_iff, mul_def] + +end Fin + +namespace USize + +@[simp] theorem lt_def {a b : USize} : a < b ↔ a.toNat < b.toNat := .rfl + +@[simp] theorem le_def {a b : USize} : a ≤ b ↔ a.toNat ≤ b.toNat := .rfl + +@[simp] theorem zero_toNat : (0 : USize).toNat = 0 := Nat.zero_mod _ + +@[simp] theorem mod_toNat (a b : USize) : (a % b).toNat = a.toNat % b.toNat := + Fin.mod_val .. + +@[simp] theorem div_toNat (a b : USize) : (a / b).toNat = a.toNat / b.toNat := + Fin.div_val .. + +@[simp] theorem modn_toNat (a : USize) (b : Nat) : (a.modn b).toNat = a.toNat % b := + Fin.modn_val .. + +theorem mod_lt (a b : USize) (h : 0 < b) : a % b < b := USize.modn_lt _ (by simp at h; exact h) + +theorem toNat.inj : ∀ {a b : USize}, a.toNat = b.toNat → a = b + | ⟨_, _⟩, ⟨_, _⟩, rfl => rfl + +end USize diff --git a/src/Init/Data/Int/Bitwise.lean b/src/Init/Data/Int/Bitwise.lean index c894daa44c..2bcce0a8e8 100644 --- a/src/Init/Data/Int/Bitwise.lean +++ b/src/Init/Data/Int/Bitwise.lean @@ -5,7 +5,7 @@ Authors: Mario Carneiro -/ prelude import Init.Data.Int.Basic -import Init.Data.Nat.Bitwise +import Init.Data.Nat.Bitwise.Basic namespace Int diff --git a/src/Init/Data/Nat/Bitwise.lean b/src/Init/Data/Nat/Bitwise.lean index eecafaeb07..5827803a53 100644 --- a/src/Init/Data/Nat/Bitwise.lean +++ b/src/Init/Data/Nat/Bitwise.lean @@ -1,54 +1,3 @@ -/- -Copyright (c) 2019 Microsoft Corporation. All rights reserved. -Released under Apache 2.0 license as described in the file LICENSE. -Authors: Leonardo de Moura --/ prelude -import Init.Data.Nat.Basic -import Init.Data.Nat.Div -import Init.Coe - -namespace Nat - -theorem bitwise_rec_lemma {n : Nat} (hNe : n ≠ 0) : n / 2 < n := - Nat.div_lt_self (Nat.zero_lt_of_ne_zero hNe) (Nat.lt_succ_self _) - -def bitwise (f : Bool → Bool → Bool) (n m : Nat) : Nat := - if n = 0 then - if f false true then m else 0 - else if m = 0 then - if f true false then n else 0 - else - let n' := n / 2 - let m' := m / 2 - let b₁ := n % 2 = 1 - let b₂ := m % 2 = 1 - let r := bitwise f n' m' - if f b₁ b₂ then - r+r+1 - else - r+r -decreasing_by apply bitwise_rec_lemma; assumption - -@[extern "lean_nat_land"] -def land : @& Nat → @& Nat → Nat := bitwise and -@[extern "lean_nat_lor"] -def lor : @& Nat → @& Nat → Nat := bitwise or -@[extern "lean_nat_lxor"] -def xor : @& Nat → @& Nat → Nat := bitwise bne -@[extern "lean_nat_shiftl"] -def shiftLeft : @& Nat → @& Nat → Nat - | n, 0 => n - | n, succ m => shiftLeft (2*n) m -@[extern "lean_nat_shiftr"] -def shiftRight : @& Nat → @& Nat → Nat - | n, 0 => n - | n, succ m => shiftRight n m / 2 - -instance : AndOp Nat := ⟨Nat.land⟩ -instance : OrOp Nat := ⟨Nat.lor⟩ -instance : Xor Nat := ⟨Nat.xor⟩ -instance : ShiftLeft Nat := ⟨Nat.shiftLeft⟩ -instance : ShiftRight Nat := ⟨Nat.shiftRight⟩ - -end Nat +import Init.Data.Nat.Bitwise.Basic +import Init.Data.Nat.Bitwise.Lemmas diff --git a/src/Init/Data/Nat/Bitwise/Basic.lean b/src/Init/Data/Nat/Bitwise/Basic.lean new file mode 100644 index 0000000000..57e6bb3515 --- /dev/null +++ b/src/Init/Data/Nat/Bitwise/Basic.lean @@ -0,0 +1,63 @@ +/- +Copyright (c) 2019 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Init.Data.Nat.Basic +import Init.Data.Nat.Div +import Init.Coe + +namespace Nat + +theorem bitwise_rec_lemma {n : Nat} (hNe : n ≠ 0) : n / 2 < n := + Nat.div_lt_self (Nat.zero_lt_of_ne_zero hNe) (Nat.lt_succ_self _) + +def bitwise (f : Bool → Bool → Bool) (n m : Nat) : Nat := + if n = 0 then + if f false true then m else 0 + else if m = 0 then + if f true false then n else 0 + else + let n' := n / 2 + let m' := m / 2 + let b₁ := n % 2 = 1 + let b₂ := m % 2 = 1 + let r := bitwise f n' m' + if f b₁ b₂ then + r+r+1 + else + r+r +decreasing_by apply bitwise_rec_lemma; assumption + +@[extern "lean_nat_land"] +def land : @& Nat → @& Nat → Nat := bitwise and +@[extern "lean_nat_lor"] +def lor : @& Nat → @& Nat → Nat := bitwise or +@[extern "lean_nat_lxor"] +def xor : @& Nat → @& Nat → Nat := bitwise bne +@[extern "lean_nat_shiftl"] +def shiftLeft : @& Nat → @& Nat → Nat + | n, 0 => n + | n, succ m => shiftLeft (2*n) m +@[extern "lean_nat_shiftr"] +def shiftRight : @& Nat → @& Nat → Nat + | n, 0 => n + | n, succ m => shiftRight n m / 2 + +instance : AndOp Nat := ⟨Nat.land⟩ +instance : OrOp Nat := ⟨Nat.lor⟩ +instance : Xor Nat := ⟨Nat.xor⟩ +instance : ShiftLeft Nat := ⟨Nat.shiftLeft⟩ +instance : ShiftRight Nat := ⟨Nat.shiftRight⟩ + +/-! +### testBit +We define an operation for testing individual bits in the binary representation +of a number. +-/ + +/-- `testBit m n` returns whether the `(n+1)` least significant bit is `1` or `0`-/ +def testBit (m n : Nat) : Bool := (m >>> n) &&& 1 != 0 + +end Nat diff --git a/src/Init/Data/Nat/Bitwise/Lemmas.lean b/src/Init/Data/Nat/Bitwise/Lemmas.lean new file mode 100644 index 0000000000..23e325368a --- /dev/null +++ b/src/Init/Data/Nat/Bitwise/Lemmas.lean @@ -0,0 +1,503 @@ +/- +Copyright (c) 2023 by the authors listed in the file AUTHORS and their +institutional affiliations. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Joe Hendrix +-/ + +prelude +import Init.Data.Bool +import Init.Data.Nat.Bitwise.Basic +import Init.Data.Nat.Lemmas +import Init.TacticsExtra +import Init.Omega + +/- +This module defines properties of the bitwise operations on Natural numbers. + +It is primarily intended to support the bitvector library. +-/ + +namespace Nat + +@[local simp] +private theorem one_div_two : 1/2 = 0 := by trivial + +private theorem two_pow_succ_sub_succ_div_two : (2 ^ (n+1) - (x + 1)) / 2 = 2^n - (x/2 + 1) := by + if h : x + 1 ≤ 2 ^ (n + 1) then + apply fun x => (Nat.sub_eq_of_eq_add x).symm + apply Eq.trans _ + apply Nat.add_mul_div_left _ _ Nat.zero_lt_two + rw [← Nat.sub_add_comm h] + rw [Nat.add_sub_assoc (by omega)] + rw [Nat.pow_succ'] + rw [Nat.mul_add_div Nat.zero_lt_two] + simp [show (2 * (x / 2 + 1) - (x + 1)) / 2 = 0 by omega] + else + rw [Nat.pow_succ'] at * + omega + +private theorem two_pow_succ_sub_one_div_two : (2 ^ (n+1) - 1) / 2 = 2^n - 1 := + two_pow_succ_sub_succ_div_two + +private theorem two_mul_sub_one {n : Nat} (n_pos : n > 0) : (2*n - 1) % 2 = 1 := by + match n with + | 0 => contradiction + | n + 1 => simp [Nat.mul_succ, Nat.mul_add_mod, mod_eq_of_lt] + +/-! ### Preliminaries -/ + +/-- +An induction principal that works on divison by two. +-/ +noncomputable def div2Induction {motive : Nat → Sort u} + (n : Nat) (ind : ∀(n : Nat), (n > 0 → motive (n/2)) → motive n) : motive n := by + induction n using Nat.strongInductionOn with + | ind n hyp => + apply ind + intro n_pos + if n_eq : n = 0 then + simp [n_eq] at n_pos + else + apply hyp + exact Nat.div_lt_self n_pos (Nat.le_refl _) + +@[simp] theorem zero_and (x : Nat) : 0 &&& x = 0 := by rfl + +@[simp] theorem and_zero (x : Nat) : x &&& 0 = 0 := by + simp only [HAnd.hAnd, AndOp.and, land] + unfold bitwise + simp + +@[simp] theorem and_one_is_mod (x : Nat) : x &&& 1 = x % 2 := by + if xz : x = 0 then + simp [xz, zero_and] + else + have andz := and_zero (x/2) + simp only [HAnd.hAnd, AndOp.and, land] at andz + simp only [HAnd.hAnd, AndOp.and, land] + unfold bitwise + cases mod_two_eq_zero_or_one x with | _ p => + simp [xz, p, andz, one_div_two, mod_eq_of_lt] + +/-! ### testBit -/ + +@[simp] theorem zero_testBit (i : Nat) : testBit 0 i = false := by + simp only [testBit, zero_shiftRight, zero_and, bne_self_eq_false] + +@[simp] theorem testBit_zero (x : Nat) : testBit x 0 = decide (x % 2 = 1) := by + cases mod_two_eq_zero_or_one x with | _ p => simp [testBit, p] + +@[simp] theorem testBit_succ (x i : Nat) : testBit x (succ i) = testBit (x/2) i := by + unfold testBit + simp [shiftRight_succ_inside] + +theorem testBit_to_div_mod {x : Nat} : testBit x i = decide (x / 2^i % 2 = 1) := by + induction i generalizing x with + | zero => + unfold testBit + cases mod_two_eq_zero_or_one x with | _ xz => simp [xz] + | succ i hyp => + simp [hyp, Nat.div_div_eq_div_mul, Nat.pow_succ'] + +theorem ne_zero_implies_bit_true {x : Nat} (xnz : x ≠ 0) : ∃ i, testBit x i := by + induction x using div2Induction with + | ind x hyp => + have x_pos : x > 0 := Nat.pos_of_ne_zero xnz + match mod_two_eq_zero_or_one x with + | Or.inl mod2_eq => + rw [←div_add_mod x 2] at xnz + simp only [mod2_eq, ne_eq, Nat.mul_eq_zero, Nat.add_zero, false_or] at xnz + have ⟨d, dif⟩ := hyp x_pos xnz + apply Exists.intro (d+1) + simp_all + | Or.inr mod2_eq => + apply Exists.intro 0 + simp_all + +theorem ne_implies_bit_diff {x y : Nat} (p : x ≠ y) : ∃ i, testBit x i ≠ testBit y i := by + induction y using Nat.div2Induction generalizing x with + | ind y hyp => + cases Nat.eq_zero_or_pos y with + | inl yz => + simp only [yz, Nat.zero_testBit, Bool.eq_false_iff] + simp only [yz] at p + have ⟨i,ip⟩ := ne_zero_implies_bit_true p + apply Exists.intro i + simp [ip] + | inr ypos => + if lsb_diff : x % 2 = y % 2 then + rw [←Nat.div_add_mod x 2, ←Nat.div_add_mod y 2] at p + simp only [ne_eq, lsb_diff, Nat.add_right_cancel_iff, + Nat.zero_lt_succ, Nat.mul_left_cancel_iff] at p + have ⟨i, ieq⟩ := hyp ypos p + apply Exists.intro (i+1) + simpa + else + apply Exists.intro 0 + simp only [testBit_zero] + revert lsb_diff + cases mod_two_eq_zero_or_one x with | _ p => + cases mod_two_eq_zero_or_one y with | _ q => + simp [p,q] + +/-- +`eq_of_testBit_eq` allows proving two natural numbers are equal +if their bits are all equal. +-/ +theorem eq_of_testBit_eq {x y : Nat} (pred : ∀i, testBit x i = testBit y i) : x = y := by + if h : x = y then + exact h + else + let ⟨i,eq⟩ := ne_implies_bit_diff h + have p := pred i + contradiction + +theorem ge_two_pow_implies_high_bit_true {x : Nat} (p : x ≥ 2^n) : ∃ i, i ≥ n ∧ testBit x i := by + induction x using div2Induction generalizing n with + | ind x hyp => + have x_pos : x > 0 := Nat.lt_of_lt_of_le (Nat.two_pow_pos n) p + have x_ne_zero : x ≠ 0 := Nat.ne_of_gt x_pos + match n with + | zero => + let ⟨j, jp⟩ := ne_zero_implies_bit_true x_ne_zero + exact Exists.intro j (And.intro (Nat.zero_le _) jp) + | succ n => + have x_ge_n : x / 2 ≥ 2 ^ n := by + simpa [le_div_iff_mul_le, ← Nat.pow_succ'] using p + have ⟨j, jp⟩ := @hyp x_pos n x_ge_n + apply Exists.intro (j+1) + apply And.intro + case left => + exact (Nat.succ_le_succ jp.left) + case right => + simpa using jp.right + +theorem testBit_implies_ge {x : Nat} (p : testBit x i = true) : x ≥ 2^i := by + simp only [testBit_to_div_mod] at p + apply Decidable.by_contra + intro not_ge + have x_lt : x < 2^i := Nat.lt_of_not_le not_ge + simp [div_eq_of_lt x_lt] at p + +theorem testBit_lt_two_pow {x i : Nat} (lt : x < 2^i) : x.testBit i = false := by + match p : x.testBit i with + | false => trivial + | true => + exfalso + exact Nat.not_le_of_gt lt (testBit_implies_ge p) + +theorem lt_pow_two_of_testBit (x : Nat) (p : ∀i, i ≥ n → testBit x i = false) : x < 2^n := by + apply Decidable.by_contra + intro not_lt + have x_ge_n := Nat.ge_of_not_lt not_lt + have ⟨i, ⟨i_ge_n, test_true⟩⟩ := ge_two_pow_implies_high_bit_true x_ge_n + have test_false := p _ i_ge_n + simp only [test_true] at test_false + +/-! ### testBit -/ + +private theorem succ_mod_two : succ x % 2 = 1 - x % 2 := by + induction x with + | zero => + trivial + | succ x hyp => + have p : 2 ≤ x + 2 := Nat.le_add_left _ _ + simp [Nat.mod_eq (x+2) 2, p, hyp] + cases Nat.mod_two_eq_zero_or_one x with | _ p => simp [p] + +private theorem testBit_succ_zero : testBit (x + 1) 0 = not (testBit x 0) := by + simp [testBit_to_div_mod, succ_mod_two] + cases Nat.mod_two_eq_zero_or_one x with | _ p => + simp [p] + +theorem testBit_two_pow_add_eq (x i : Nat) : testBit (2^i + x) i = not (testBit x i) := by + simp [testBit_to_div_mod, add_div_left, Nat.two_pow_pos, succ_mod_two] + cases mod_two_eq_zero_or_one (x / 2 ^ i) with + | _ p => simp [p] + +theorem testBit_mul_two_pow_add_eq (a b i : Nat) : + testBit (2^i*a + b) i = Bool.xor (a%2 = 1) (testBit b i) := by + match a with + | 0 => simp + | a+1 => + simp [Nat.mul_succ, Nat.add_assoc, + testBit_mul_two_pow_add_eq a, + testBit_two_pow_add_eq, + Nat.succ_mod_two] + cases mod_two_eq_zero_or_one a with + | _ p => simp [p] + +theorem testBit_two_pow_add_gt {i j : Nat} (j_lt_i : j < i) (x : Nat) : + testBit (2^i + x) j = testBit x j := by + have i_def : i = j + (i-j) := (Nat.add_sub_cancel' (Nat.le_of_lt j_lt_i)).symm + rw [i_def] + simp only [testBit_to_div_mod, Nat.pow_add, + Nat.add_comm x, Nat.mul_add_div (Nat.two_pow_pos _)] + match i_sub_j_eq : i - j with + | 0 => + exfalso + rw [Nat.sub_eq_zero_iff_le] at i_sub_j_eq + exact Nat.not_le_of_gt j_lt_i i_sub_j_eq + | d+1 => + simp [pow_succ, Nat.mul_comm _ 2, Nat.mul_add_mod] + +@[simp] theorem testBit_mod_two_pow (x j i : Nat) : + testBit (x % 2^j) i = (decide (i < j) && testBit x i) := by + induction x using Nat.strongInductionOn generalizing j i with + | ind x hyp => + rw [mod_eq] + rcases Nat.lt_or_ge x (2^j) with x_lt_j | x_ge_j + · have not_j_le_x := Nat.not_le_of_gt x_lt_j + simp [not_j_le_x] + rcases Nat.lt_or_ge i j with i_lt_j | i_ge_j + · simp [i_lt_j] + · have x_lt : x < 2^i := + calc x < 2^j := x_lt_j + _ ≤ 2^i := Nat.pow_le_pow_of_le_right Nat.zero_lt_two i_ge_j + simp [Nat.testBit_lt_two_pow x_lt] + · generalize y_eq : x - 2^j = y + have x_eq : x = y + 2^j := Nat.eq_add_of_sub_eq x_ge_j y_eq + simp only [Nat.two_pow_pos, x_eq, Nat.le_add_left, true_and, ite_true] + have y_lt_x : y < x := by + simp [x_eq] + exact Nat.lt_add_of_pos_right (Nat.two_pow_pos j) + simp only [hyp y y_lt_x] + if i_lt_j : i < j then + rw [ Nat.add_comm _ (2^_), testBit_two_pow_add_gt i_lt_j] + else + simp [i_lt_j] + +theorem testBit_one_zero : testBit 1 0 = true := by trivial + +theorem testBit_two_pow_sub_succ (h₂ : x < 2 ^ n) (i : Nat) : + testBit (2^n - (x + 1)) i = (decide (i < n) && ! testBit x i) := by + induction i generalizing n x with + | zero => + simp only [testBit_zero, zero_eq, Bool.and_eq_true, decide_eq_true_eq, + Bool.not_eq_true'] + match n with + | 0 => simp + | n+1 => + -- just logic + omega: + simp only [zero_lt_succ, decide_True, Bool.true_and] + rw [Nat.pow_succ', ← decide_not, decide_eq_decide] + rw [Nat.pow_succ'] at h₂ + omega + | succ i ih => + simp only [testBit_succ] + match n with + | 0 => + simp only [pow_zero, succ_sub_succ_eq_sub, Nat.zero_sub, Nat.zero_div, zero_testBit] + rw [decide_eq_false] <;> simp + | n+1 => + rw [Nat.two_pow_succ_sub_succ_div_two, ih] + · simp [Nat.succ_lt_succ_iff] + · rw [Nat.pow_succ'] at h₂ + omega + +@[simp] theorem testBit_two_pow_sub_one (n i : Nat) : testBit (2^n-1) i = decide (i < n) := by + rw [testBit_two_pow_sub_succ] + · simp + · exact Nat.two_pow_pos _ + +theorem testBit_bool_to_nat (b : Bool) (i : Nat) : + testBit (Bool.toNat b) i = (decide (i = 0) && b) := by + cases b <;> cases i <;> + simp [testBit_to_div_mod, Nat.pow_succ, Nat.mul_comm _ 2, + ←Nat.div_div_eq_div_mul _ 2, one_div_two, + Nat.mod_eq_of_lt] + +/-! ### bitwise -/ + +theorem testBit_bitwise + (false_false_axiom : f false false = false) (x y i : Nat) +: (bitwise f x y).testBit i = f (x.testBit i) (y.testBit i) := by + induction i using Nat.strongInductionOn generalizing x y with + | ind i hyp => + unfold bitwise + if x_zero : x = 0 then + cases p : f false true <;> + cases yi : testBit y i <;> + simp [x_zero, p, yi, false_false_axiom] + else if y_zero : y = 0 then + simp [x_zero, y_zero] + cases p : f true false <;> + cases xi : testBit x i <;> + simp [p, xi, false_false_axiom] + else + simp only [x_zero, y_zero, ←Nat.two_mul] + cases i with + | zero => + cases p : f (decide (x % 2 = 1)) (decide (y % 2 = 1)) <;> + simp [p, Nat.mul_add_mod, mod_eq_of_lt] + | succ i => + have hyp_i := hyp i (Nat.le_refl (i+1)) + cases p : f (decide (x % 2 = 1)) (decide (y % 2 = 1)) <;> + simp [p, one_div_two, hyp_i, Nat.mul_add_div] + +/-! ### bitwise -/ + +@[local simp] +private theorem eq_0_of_lt_one (x : Nat) : x < 1 ↔ x = 0 := + Iff.intro + (fun p => + match x with + | 0 => Eq.refl 0 + | _+1 => False.elim (not_lt_zero _ (Nat.lt_of_succ_lt_succ p))) + (fun p => by simp [p, Nat.zero_lt_succ]) + +private theorem eq_0_of_lt (x : Nat) : x < 2^ 0 ↔ x = 0 := eq_0_of_lt_one x + +@[local simp] +private theorem zero_lt_pow (n : Nat) : 0 < 2^n := by + induction n + case zero => simp [eq_0_of_lt] + case succ n hyp => simpa [pow_succ] + +private theorem div_two_le_of_lt_two {m n : Nat} (p : m < 2 ^ succ n) : m / 2 < 2^n := by + simp [div_lt_iff_lt_mul Nat.zero_lt_two] + exact p + +/-- This provides a bound on bitwise operations. -/ +theorem bitwise_lt_two_pow (left : x < 2^n) (right : y < 2^n) : (Nat.bitwise f x y) < 2^n := by + induction n generalizing x y with + | zero => + simp only [eq_0_of_lt] at left right + unfold bitwise + simp [left, right] + | succ n hyp => + unfold bitwise + if x_zero : x = 0 then + simp only [x_zero, if_pos] + by_cases p : f false true = true <;> simp [p, right] + else if y_zero : y = 0 then + simp only [x_zero, y_zero, if_neg, if_pos] + by_cases p : f true false = true <;> simp [p, left] + else + simp only [x_zero, y_zero, if_neg] + have hyp1 := hyp (div_two_le_of_lt_two left) (div_two_le_of_lt_two right) + by_cases p : f (decide (x % 2 = 1)) (decide (y % 2 = 1)) = true <;> + simp [p, pow_succ, mul_succ, Nat.add_assoc] + case pos => + apply lt_of_succ_le + simp only [← Nat.succ_add] + apply Nat.add_le_add <;> exact hyp1 + case neg => + apply Nat.add_lt_add <;> exact hyp1 + +/-! ### and -/ + +@[simp] theorem testBit_and (x y i : Nat) : (x &&& y).testBit i = (x.testBit i && y.testBit i) := by + simp [HAnd.hAnd, AndOp.and, land, testBit_bitwise ] + +theorem and_lt_two_pow (x : Nat) {y n : Nat} (right : y < 2^n) : (x &&& y) < 2^n := by + apply lt_pow_two_of_testBit + intro i i_ge_n + have yf : testBit y i = false := by + apply Nat.testBit_lt_two_pow + apply Nat.lt_of_lt_of_le right + exact pow_le_pow_of_le_right Nat.zero_lt_two i_ge_n + simp [testBit_and, yf] + +@[simp] theorem and_pow_two_is_mod (x n : Nat) : x &&& (2^n-1) = x % 2^n := by + apply eq_of_testBit_eq + intro i + simp only [testBit_and, testBit_mod_two_pow] + cases testBit x i <;> simp + +theorem and_pow_two_identity {x : Nat} (lt : x < 2^n) : x &&& 2^n-1 = x := by + rw [and_pow_two_is_mod] + apply Nat.mod_eq_of_lt lt + +/-! ### lor -/ + +@[simp] theorem or_zero (x : Nat) : 0 ||| x = x := by + simp only [HOr.hOr, OrOp.or, lor] + unfold bitwise + simp [@eq_comm _ 0] + +@[simp] theorem zero_or (x : Nat) : x ||| 0 = x := by + simp only [HOr.hOr, OrOp.or, lor] + unfold bitwise + simp [@eq_comm _ 0] + +@[simp] theorem testBit_or (x y i : Nat) : (x ||| y).testBit i = (x.testBit i || y.testBit i) := by + simp [HOr.hOr, OrOp.or, lor, testBit_bitwise ] + +theorem or_lt_two_pow {x y n : Nat} (left : x < 2^n) (right : y < 2^n) : x ||| y < 2^n := + bitwise_lt_two_pow left right + +/-! ### xor -/ + +@[simp] theorem testBit_xor (x y i : Nat) : + (x ^^^ y).testBit i = Bool.xor (x.testBit i) (y.testBit i) := by + simp [HXor.hXor, Xor.xor, xor, testBit_bitwise ] + +theorem xor_lt_two_pow {x y n : Nat} (left : x < 2^n) (right : y < 2^n) : x ^^^ y < 2^n := + bitwise_lt_two_pow left right + +/-! ### Arithmetic -/ + +theorem testBit_mul_pow_two_add (a : Nat) {b i : Nat} (b_lt : b < 2^i) (j : Nat) : + testBit (2 ^ i * a + b) j = + if j < i then + testBit b j + else + testBit a (j - i) := by + cases Nat.lt_or_ge j i with + | inl j_lt => + simp only [j_lt] + have i_ge := Nat.le_of_lt j_lt + have i_sub_j_nez : i-j ≠ 0 := Nat.sub_ne_zero_of_lt j_lt + have i_def : i = j + succ (pred (i-j)) := + calc i = j + (i-j) := (Nat.add_sub_cancel' i_ge).symm + _ = j + succ (pred (i-j)) := by + rw [← congrArg (j+·) (Nat.succ_pred i_sub_j_nez)] + rw [i_def] + simp only [testBit_to_div_mod, Nat.pow_add, Nat.mul_assoc] + simp only [Nat.mul_add_div (Nat.two_pow_pos _), Nat.mul_add_mod] + simp [Nat.pow_succ, Nat.mul_comm _ 2, Nat.mul_assoc, Nat.mul_add_mod] + | inr j_ge => + have j_def : j = i + (j-i) := (Nat.add_sub_cancel' j_ge).symm + simp only [ + testBit_to_div_mod, + Nat.not_lt_of_le, + j_ge, + ite_false] + simp [congrArg (2^·) j_def, Nat.pow_add, + ←Nat.div_div_eq_div_mul, + Nat.mul_add_div, + Nat.div_eq_of_lt b_lt, + Nat.two_pow_pos i] + +theorem testBit_mul_pow_two : + testBit (2 ^ i * a) j = (decide (j ≥ i) && testBit a (j-i)) := by + have gen := testBit_mul_pow_two_add a (Nat.two_pow_pos i) j + simp at gen + rw [gen] + cases Nat.lt_or_ge j i with + | _ p => simp [p, Nat.not_le_of_lt, Nat.not_lt_of_le] + +theorem mul_add_lt_is_or {b : Nat} (b_lt : b < 2^i) (a : Nat) : 2^i * a + b = 2^i * a ||| b := by + apply eq_of_testBit_eq + intro j + simp only [testBit_mul_pow_two_add _ b_lt, + testBit_or, testBit_mul_pow_two] + if j_lt : j < i then + simp [Nat.not_le_of_lt, j_lt] + else + have i_le : i ≤ j := Nat.le_of_not_lt j_lt + have b_lt_j := + calc b < 2 ^ i := b_lt + _ ≤ 2 ^ j := Nat.pow_le_pow_of_le_right Nat.zero_lt_two i_le + simp [i_le, j_lt, testBit_lt_two_pow, b_lt_j] + +/-! ### shiftLeft and shiftRight -/ + +@[simp] theorem testBit_shiftLeft (x : Nat) : testBit (x <<< i) j = + (decide (j ≥ i) && testBit x (j-i)) := by + simp [shiftLeft_eq, Nat.mul_comm _ (2^_), testBit_mul_pow_two] + +@[simp] theorem testBit_shiftRight (x : Nat) : testBit (x >>> i) j = testBit x (i+j) := by + simp [testBit, ←shiftRight_add] diff --git a/tests/lean/interactive/inWordCompletion.lean.expected.out b/tests/lean/interactive/inWordCompletion.lean.expected.out index 08fdc09894..86375d7238 100644 --- a/tests/lean/interactive/inWordCompletion.lean.expected.out +++ b/tests/lean/interactive/inWordCompletion.lean.expected.out @@ -3,7 +3,11 @@ {"items": [{"label": "gfabc", "kind": 3, "detail": "Nat → Nat"}, {"label": "gfacc", "kind": 3, "detail": "Nat → Nat"}, - {"label": "gfadc", "kind": 3, "detail": "Nat → Nat"}], + {"label": "gfadc", "kind": 3, "detail": "Nat → Nat"}, + {"label": "Std.BitVec.getLsb_ofNat", + "kind": 3, + "detail": + "∀ (n x i : Nat), Std.BitVec.getLsb (x#n) i = (decide (i < n) && Nat.testBit x i)"}], "isIncomplete": true} {"textDocument": {"uri": "file:///inWordCompletion.lean"}, "position": {"line": 13, "character": 14}} diff --git a/tests/lean/run/etaStructProofIrrelIssue.lean b/tests/lean/run/etaStructProofIrrelIssue.lean index 94e8d091b1..312b4bb796 100644 --- a/tests/lean/run/etaStructProofIrrelIssue.lean +++ b/tests/lean/run/etaStructProofIrrelIssue.lean @@ -1,6 +1,3 @@ -theorem Fin.ext_iff : (Fin.mk m h₁ : Fin k) = Fin.mk n h₂ ↔ m = n := - Fin.mk.injEq _ _ _ _ ▸ Iff.rfl - example (h : m = n) : (Fin.mk m h₁ : Fin k) = Fin.mk n h₂ := by apply Fin.ext_iff.2 exact h diff --git a/tests/lean/run/ext1.lean b/tests/lean/run/ext1.lean index 08b4c2be72..97dccbe9ca 100644 --- a/tests/lean/run/ext1.lean +++ b/tests/lean/run/ext1.lean @@ -44,7 +44,6 @@ example (f g : Nat × Nat → Nat) : f = g := by -- exact h ▸ rfl -- allow more specific ext theorems -declare_ext_theorems_for Fin @[ext high] theorem Fin.zero_ext (a b : Fin 0) : True → a = b := by cases a.isLt example (a b : Fin 0) : a = b := by ext; exact True.intro