From 9bd60c75195562d3aeb28574a8b189e160ef6da1 Mon Sep 17 00:00:00 2001 From: Joe Hendrix Date: Mon, 1 Mar 2021 20:01:33 -0800 Subject: [PATCH] feat: Nat/Fin/UInt instances of bitwise classes --- src/Init/Data/Fin/Basic.lean | 20 ++++++ src/Init/Data/Nat/Bitwise.lean | 16 +++++ src/Init/Data/UInt.lean | 100 +++++++++++++++++++++------ src/include/lean/lean.h | 2 + src/include/lean/mpz.h | 6 +- src/runtime/mpz.cpp | 10 +-- src/runtime/object.cpp | 42 ++++++++++- tests/lean/277a.lean | 4 +- tests/lean/bitwise.lean | 49 +++++++++++++ tests/lean/bitwise.lean.expected.out | 35 ++++++++++ 10 files changed, 249 insertions(+), 35 deletions(-) create mode 100644 tests/lean/bitwise.lean create mode 100644 tests/lean/bitwise.lean.expected.out diff --git a/src/Init/Data/Fin/Basic.lean b/src/Init/Data/Fin/Basic.lean index 8bc19067ff..611bb8ed08 100644 --- a/src/Init/Data/Fin/Basic.lean +++ b/src/Init/Data/Fin/Basic.lean @@ -62,6 +62,15 @@ def land : Fin n → Fin n → Fin n def lor : Fin n → Fin n → Fin n | ⟨a, h⟩, ⟨b, _⟩ => ⟨(Nat.lor a b) % n, mlt h⟩ +def xor : Fin n → Fin n → Fin n + | ⟨a, h⟩, ⟨b, _⟩ => ⟨(Nat.xor a b) % n, mlt h⟩ + +def shiftLeft : Fin n → Fin n → Fin n + | ⟨a, h⟩, ⟨b, _⟩ => ⟨(a <<< b) % n, mlt h⟩ + +def shiftRight : Fin n → Fin n → Fin n + | ⟨a, h⟩, ⟨b, _⟩ => ⟨(a >>> b) % n, mlt h⟩ + instance : Add (Fin n) where add := Fin.add @@ -77,6 +86,17 @@ instance : Mod (Fin n) where instance : Div (Fin n) where div := Fin.div +instance : AndOp (Fin n) where + and := Fin.land +instance : OrOp (Fin n) where + or := Fin.lor +instance : Xor (Fin n) where + xor := Fin.xor +instance : ShiftLeft (Fin n) where + shiftLeft := Fin.shiftLeft +instance : ShiftRight (Fin n) where + shiftRight := Fin.shiftRight + instance : HMod (Fin n) Nat (Fin n) where hMod := Fin.modn diff --git a/src/Init/Data/Nat/Bitwise.lean b/src/Init/Data/Nat/Bitwise.lean index 3ffe5e8be8..4d014d5c09 100644 --- a/src/Init/Data/Nat/Bitwise.lean +++ b/src/Init/Data/Nat/Bitwise.lean @@ -30,5 +30,21 @@ partial def bitwise (f : Bool → Bool → Bool) (n m : Nat) : Nat := def land : Nat → Nat → Nat := bitwise and @[extern "lean_nat_lor"] def lor : Nat → Nat → Nat := bitwise or +@[extern "lean_nat_lxor"] +def xor : Nat → Nat → Nat := bitwise bne +@[extern "lean_nat_shiftl"] +def shiftLeft : Nat → Nat → Nat + | n, 0 => n + | n, succ m => shiftLeft (2*n) m +@[extern "lean_nat_shiftr"] +def shiftRight : Nat → Nat → Nat + | n, 0 => n + | n, succ m => shiftRight n m / 2 + +instance : AndOp Nat := ⟨Nat.land⟩ +instance : OrOp Nat := ⟨Nat.lor⟩ +instance : Xor Nat := ⟨Nat.xor⟩ +instance : ShiftLeft Nat := ⟨Nat.shiftLeft⟩ +instance : ShiftRight Nat := ⟨Nat.shiftRight⟩ end Nat diff --git a/src/Init/Data/UInt.lean b/src/Init/Data/UInt.lean index 44e15a4299..ccc9e5056d 100644 --- a/src/Init/Data/UInt.lean +++ b/src/Init/Data/UInt.lean @@ -30,6 +30,12 @@ def UInt8.modn (a : UInt8) (n : @& Nat) : UInt8 := ⟨a.val % n⟩ def UInt8.land (a b : UInt8) : UInt8 := ⟨Fin.land a.val b.val⟩ @[extern c inline "#1 | #2"] def UInt8.lor (a b : UInt8) : UInt8 := ⟨Fin.lor a.val b.val⟩ +@[extern c inline "#1 ^ #2"] +def UInt8.xor (a b : UInt8) : UInt8 := ⟨Fin.xor a.val b.val⟩ +@[extern c inline "#1 << #2"] +def UInt8.shiftLeft (a b : UInt8) : UInt8 := ⟨a.val <<< b.val⟩ +@[extern c inline "#1 >> #2"] +def UInt8.shiftRight (a b : UInt8) : UInt8 := ⟨a.val >>> b.val⟩ def UInt8.lt (a b : UInt8) : Prop := a.val < b.val def UInt8.le (a b : UInt8) : Prop := a.val ≤ b.val @@ -43,10 +49,15 @@ instance : Div UInt8 := ⟨UInt8.div⟩ instance : HasLess UInt8 := ⟨UInt8.lt⟩ instance : HasLessEq UInt8 := ⟨UInt8.le⟩ -@[extern c inline "#1 << #2"] -constant UInt8.shiftLeft (a b : UInt8) : UInt8 -@[extern c inline "#1 >> #2"] -constant UInt8.shiftRight (a b : UInt8) : UInt8 +@[extern c inline "~ #1"] +def UInt8.complement (a:UInt8) : UInt8 := 0-(a+1) + +instance : Complement UInt8 := ⟨UInt8.complement⟩ +instance : AndOp UInt8 := ⟨UInt8.land⟩ +instance : OrOp UInt8 := ⟨UInt8.lor⟩ +instance : Xor UInt8 := ⟨UInt8.xor⟩ +instance : ShiftLeft UInt8 := ⟨UInt8.shiftLeft⟩ +instance : ShiftRight UInt8 := ⟨UInt8.shiftRight⟩ set_option bootstrap.genMatcherCode false in @[extern c inline "#1 < #2"] @@ -84,6 +95,12 @@ def UInt16.modn (a : UInt16) (n : @& Nat) : UInt16 := ⟨a.val % n⟩ def UInt16.land (a b : UInt16) : UInt16 := ⟨Fin.land a.val b.val⟩ @[extern c inline "#1 | #2"] def UInt16.lor (a b : UInt16) : UInt16 := ⟨Fin.lor a.val b.val⟩ +@[extern c inline "#1 ^ #2"] +def UInt16.xor (a b : UInt16) : UInt16 := ⟨Fin.xor a.val b.val⟩ +@[extern c inline "#1 << #2"] +def UInt16.shiftLeft (a b : UInt16) : UInt16 := ⟨a.val <<< b.val⟩ +@[extern c inline "#1 >> #2"] +def UInt16.shiftRight (a b : UInt16) : UInt16 := ⟨a.val >>> b.val⟩ def UInt16.lt (a b : UInt16) : Prop := a.val < b.val def UInt16.le (a b : UInt16) : Prop := a.val ≤ b.val @@ -98,10 +115,15 @@ instance : Div UInt16 := ⟨UInt16.div⟩ instance : HasLess UInt16 := ⟨UInt16.lt⟩ instance : HasLessEq UInt16 := ⟨UInt16.le⟩ -@[extern c inline "#1 << #2"] -constant UInt16.shiftLeft (a b : UInt16) : UInt16 -@[extern c inline "#1 >> #2"] -constant UInt16.shiftRight (a b : UInt16) : UInt16 +@[extern c inline "~ #1"] +def UInt16.complement (a:UInt16) : UInt16 := 0-(a+1) + +instance : Complement UInt16 := ⟨UInt16.complement⟩ +instance : AndOp UInt16 := ⟨UInt16.land⟩ +instance : OrOp UInt16 := ⟨UInt16.lor⟩ +instance : Xor UInt16 := ⟨UInt16.xor⟩ +instance : ShiftLeft UInt16 := ⟨UInt16.shiftLeft⟩ +instance : ShiftRight UInt16 := ⟨UInt16.shiftRight⟩ set_option bootstrap.genMatcherCode false in @[extern c inline "#1 < #2"] @@ -139,6 +161,12 @@ def UInt32.modn (a : UInt32) (n : @& Nat) : UInt32 := ⟨a.val % n⟩ def UInt32.land (a b : UInt32) : UInt32 := ⟨Fin.land a.val b.val⟩ @[extern c inline "#1 | #2"] def UInt32.lor (a b : UInt32) : UInt32 := ⟨Fin.lor a.val b.val⟩ +@[extern c inline "#1 ^ #2"] +def UInt32.xor (a b : UInt32) : UInt32 := ⟨Fin.xor a.val b.val⟩ +@[extern c inline "#1 << #2"] +def UInt32.shiftLeft (a b : UInt32) : UInt32 := ⟨a.val <<< b.val⟩ +@[extern c inline "#1 >> #2"] +def UInt32.shiftRight (a b : UInt32) : UInt32 := ⟨a.val >>> b.val⟩ @[extern c inline "((uint8_t)#1)"] def UInt32.toUInt8 (a : UInt32) : UInt8 := a.toNat.toUInt8 @[extern c inline "((uint16_t)#1)"] @@ -154,10 +182,15 @@ instance : Mod UInt32 := ⟨UInt32.mod⟩ instance : HMod UInt32 Nat UInt32 := ⟨UInt32.modn⟩ instance : Div UInt32 := ⟨UInt32.div⟩ -@[extern c inline "#1 << #2"] -constant UInt32.shiftLeft (a b : UInt32) : UInt32 -@[extern c inline "#1 >> #2"] -constant UInt32.shiftRight (a b : UInt32) : UInt32 +@[extern c inline "~ #1"] +def UInt32.complement (a:UInt32) : UInt32 := 0-(a+1) + +instance : Complement UInt32 := ⟨UInt32.complement⟩ +instance : AndOp UInt32 := ⟨UInt32.land⟩ +instance : OrOp UInt32 := ⟨UInt32.lor⟩ +instance : Xor UInt32 := ⟨UInt32.xor⟩ +instance : ShiftLeft UInt32 := ⟨UInt32.shiftLeft⟩ +instance : ShiftRight UInt32 := ⟨UInt32.shiftRight⟩ @[extern "lean_uint64_of_nat"] def UInt64.ofNat (n : @& Nat) : UInt64 := ⟨Fin.ofNat n⟩ @@ -180,6 +213,12 @@ def UInt64.modn (a : UInt64) (n : @& Nat) : UInt64 := ⟨a.val % n⟩ def UInt64.land (a b : UInt64) : UInt64 := ⟨Fin.land a.val b.val⟩ @[extern c inline "#1 | #2"] def UInt64.lor (a b : UInt64) : UInt64 := ⟨Fin.lor a.val b.val⟩ +@[extern c inline "#1 ^ #2"] +def UInt64.xor (a b : UInt64) : UInt64 := ⟨Fin.xor a.val b.val⟩ +@[extern c inline "#1 << #2"] +def UInt64.shiftLeft (a b : UInt64) : UInt64 := ⟨a.val <<< b.val⟩ +@[extern c inline "#1 >> #2"] +def UInt64.shiftRight (a b : UInt64) : UInt64 := ⟨a.val >>> b.val⟩ def UInt64.lt (a b : UInt64) : Prop := a.val < b.val def UInt64.le (a b : UInt64) : Prop := a.val ≤ b.val @[extern c inline "((uint8_t)#1)"] @@ -191,12 +230,6 @@ def UInt64.toUInt32 (a : UInt64) : UInt32 := a.toNat.toUInt32 @[extern c inline "((uint64_t)#1)"] def UInt32.toUInt64 (a : UInt32) : UInt64 := a.toNat.toUInt64 --- TODO(Leo): give reference implementation for shiftLeft and shiftRight, and define them for other UInt types -@[extern c inline "#1 << #2"] -constant UInt64.shiftLeft (a b : UInt64) : UInt64 -@[extern c inline "#1 >> #2"] -constant UInt64.shiftRight (a b : UInt64) : UInt64 - instance : OfNat UInt64 n := ⟨UInt64.ofNat n⟩ instance : Add UInt64 := ⟨UInt64.add⟩ instance : Sub UInt64 := ⟨UInt64.sub⟩ @@ -207,6 +240,16 @@ instance : Div UInt64 := ⟨UInt64.div⟩ instance : HasLess UInt64 := ⟨UInt64.lt⟩ instance : HasLessEq UInt64 := ⟨UInt64.le⟩ +@[extern c inline "~ #1"] +def UInt64.complement (a:UInt64) : UInt64 := 0-(a+1) + +instance : Complement UInt64 := ⟨UInt64.complement⟩ +instance : AndOp UInt64 := ⟨UInt64.land⟩ +instance : OrOp UInt64 := ⟨UInt64.lor⟩ +instance : Xor UInt64 := ⟨UInt64.xor⟩ +instance : ShiftLeft UInt64 := ⟨UInt64.shiftLeft⟩ +instance : ShiftRight UInt64 := ⟨UInt64.shiftRight⟩ + @[extern c inline "(uint64_t)#1"] def Bool.toUInt64 (b : Bool) : UInt64 := if b then 1 else 0 @@ -249,6 +292,12 @@ def USize.modn (a : USize) (n : @& Nat) : USize := ⟨a.val % n⟩ def USize.land (a b : USize) : USize := ⟨Fin.land a.val b.val⟩ @[extern c inline "#1 | #2"] def USize.lor (a b : USize) : USize := ⟨Fin.lor a.val b.val⟩ +@[extern c inline "#1 ^ #2"] +def USize.xor (a b : USize) : USize := ⟨Fin.xor a.val b.val⟩ +@[extern c inline "#1 << #2"] +def USize.shiftLeft (a b : USize) : USize := ⟨a.val <<< b.val⟩ +@[extern c inline "#1 >> #2"] +def USize.shiftRight (a b : USize) : USize := ⟨a.val >>> b.val⟩ @[extern c inline "#1"] def UInt32.toUSize (a : UInt32) : USize := a.toNat.toUSize @[extern c inline "((size_t)#1)"] @@ -256,11 +305,6 @@ def UInt64.toUSize (a : UInt64) : USize := a.toNat.toUSize @[extern c inline "(uint32_t)#1"] def USize.toUInt32 (a : USize) : UInt32 := a.toNat.toUInt32 --- TODO(Leo): give reference implementation for shiftLeft and shiftRight, and define them for other UInt types -@[extern c inline "#1 << #2"] -constant USize.shiftLeft (a b : USize) : USize -@[extern c inline "#1 >> #2"] -constant USize.shiftRight (a b : USize) : USize def USize.lt (a b : USize) : Prop := a.val < b.val def USize.le (a b : USize) : Prop := a.val ≤ b.val @@ -274,6 +318,16 @@ instance : Div USize := ⟨USize.div⟩ instance : HasLess USize := ⟨USize.lt⟩ instance : HasLessEq USize := ⟨USize.le⟩ +@[extern c inline "~ #1"] +def USize.complement (a:USize) : USize := 0-(a+1) + +instance : Complement USize := ⟨USize.complement⟩ +instance : AndOp USize := ⟨USize.land⟩ +instance : OrOp USize := ⟨USize.lor⟩ +instance : Xor USize := ⟨USize.xor⟩ +instance : ShiftLeft USize := ⟨USize.shiftLeft⟩ +instance : ShiftRight USize := ⟨USize.shiftRight⟩ + set_option bootstrap.genMatcherCode false in @[extern c inline "#1 < #2"] def USize.decLt (a b : USize) : Decidable (a < b) := diff --git a/src/include/lean/lean.h b/src/include/lean/lean.h index 6af5a9749b..8ed6c87039 100644 --- a/src/include/lean/lean.h +++ b/src/include/lean/lean.h @@ -1430,6 +1430,8 @@ static inline lean_obj_res lean_nat_lxor(b_lean_obj_arg a1, b_lean_obj_arg a2) { } } +lean_obj_res lean_nat_shiftl(b_lean_obj_arg a1, b_lean_obj_arg a2); +lean_obj_res lean_nat_shiftr(b_lean_obj_arg a1, b_lean_obj_arg a2); lean_obj_res lean_nat_pow(b_lean_obj_arg a1, b_lean_obj_arg a2); /* Integers */ diff --git a/src/include/lean/mpz.h b/src/include/lean/mpz.h index 3a534afb6e..fa9d779611 100644 --- a/src/include/lean/mpz.h +++ b/src/include/lean/mpz.h @@ -224,18 +224,18 @@ public: \brief Return the position of the most significant bit. Return 0 if the number is negative */ - unsigned log2() const; + size_t log2() const; /** \brief log2(-n) Return 0 if the number is nonegative */ - unsigned mlog2() const; + size_t mlog2() const; bool perfect_square() const { return mpz_perfect_square_p(m_val); } bool is_power_of_two() const { return is_pos() && mpz_popcount(m_val) == 1; } - bool is_power_of_two(unsigned & shift) const; + bool is_power_of_two(size_t& shift) const; /** \brief Return largest k s.t. n is a multiple of 2^k */ diff --git a/src/runtime/mpz.cpp b/src/runtime/mpz.cpp index 7d2986cce1..35efa4b7a6 100644 --- a/src/runtime/mpz.cpp +++ b/src/runtime/mpz.cpp @@ -38,27 +38,27 @@ size_t mpz::get_size_t() const { return static_cast(mpz_getlimbn(m_val, 0)); } -unsigned mpz::log2() const { +size_t mpz::log2() const { if (is_nonpos()) return 0; - unsigned r = mpz_sizeinbase(m_val, 2); + size_t r = mpz_sizeinbase(m_val, 2); lean_assert(r > 0); return r - 1; } -unsigned mpz::mlog2() const { +size_t mpz::mlog2() const { if (is_nonneg()) return 0; mpz * _this = const_cast(this); _this->neg(); lean_assert(is_pos()); - unsigned r = mpz_sizeinbase(m_val, 2); + size_t r = mpz_sizeinbase(m_val, 2); _this->neg(); lean_assert(is_neg()); return r - 1; } -bool mpz::is_power_of_two(unsigned & shift) const { +bool mpz::is_power_of_two(size_t& shift) const { if (is_nonpos()) return false; if (mpz_popcount(m_val) == 1) { diff --git a/src/runtime/object.cpp b/src/runtime/object.cpp index 283736ae18..047f38186c 100644 --- a/src/runtime/object.cpp +++ b/src/runtime/object.cpp @@ -1257,7 +1257,7 @@ extern "C" object * lean_nat_big_lor(object * a1, object * a2) { return mpz_to_nat(mpz_value(a1) | mpz_value(a2)); } -extern "C" object * lean_nat_big_lxor(object * a1, object * a2) { +extern "C" object * lean_nat_big_xor(object * a1, object * a2) { lean_assert(!lean_is_scalar(a1) || !lean_is_scalar(a2)); if (lean_is_scalar(a1)) return mpz_to_nat(mpz::of_size_t(lean_unbox(a1)) ^ mpz_value(a2)); @@ -1267,8 +1267,46 @@ extern "C" object * lean_nat_big_lxor(object * a1, object * a2) { return mpz_to_nat(mpz_value(a1) ^ mpz_value(a2)); } -extern "C" lean_obj_res lean_nat_pow(b_lean_obj_arg a1, b_lean_obj_arg a2) { +extern "C" lean_obj_res lean_nat_shiftl(b_lean_obj_arg a1, b_lean_obj_arg a2) { + // Special case for shifted value is 0. + if (lean_is_scalar(a1) && lean_unbox(a1) == 0) { + return lean_box(0); + } + auto a = lean_is_scalar(a1) + ? mpz::of_size_t(lean_unbox(a1)) + : mpz_value(a1); + if (!lean_is_scalar(a2) || lean_unbox(a2) > UINT_MAX) { + lean_panic("Nat.shiftl exponent is too big"); + } + mpz r; + mul2k(r, a, lean_unbox(a2)); + return mpz_to_nat(r); +} + +extern "C" lean_obj_res lean_nat_shiftr(b_lean_obj_arg a1, b_lean_obj_arg a2) { if (!lean_is_scalar(a2)) { + return lean_box(0); // This large of an exponent must be 0. + } + auto a = lean_is_scalar(a1) + ? mpz::of_size_t(lean_unbox(a1)) + : mpz_value(a1); + size_t s = lean_unbox(a2); + // If the shift amount is large, then we fail if it is not large + // enough to zero out all the bits. + if (s > UINT_MAX) { + if (a.log2() >= s) { + lean_panic("Nat.shiftr exponent is too big"); + } else { + return lean_box(0); + } + } + mpz r; + div2k(r, a, s); + return mpz_to_nat(r); +} + +extern "C" lean_obj_res lean_nat_pow(b_lean_obj_arg a1, b_lean_obj_arg a2) { + if (!lean_is_scalar(a2) || lean_unbox(a2) > UINT_MAX) { lean_internal_panic("Nat.pow exponent is too big"); } if (lean_is_scalar(a1)) diff --git a/tests/lean/277a.lean b/tests/lean/277a.lean index 3a6d7948b6..b6b0a01e40 100644 --- a/tests/lean/277a.lean +++ b/tests/lean/277a.lean @@ -1,3 +1,3 @@ -infixl:67 " <<< " => nonexistant +infixl:67 " <>< " => nonexistant -#eval (1 <<< 11 : UInt64) +#eval (1 <>< 11 : UInt64) diff --git a/tests/lean/bitwise.lean b/tests/lean/bitwise.lean new file mode 100644 index 0000000000..6a19ed0db5 --- /dev/null +++ b/tests/lean/bitwise.lean @@ -0,0 +1,49 @@ +#eval "Nat" +#eval 0x17 &&& 0xf == 0x7 +#eval 0x17 ||| 0xf == 0x1f +#eval 0x17 ^^^ 0xf == 0x18 +#eval 0x12 <<< 4 == 0x120 +#eval 0x12 >>> 4 == 0x1 + +-- Expected failure +-- #eval 1 ^ (2 ^ 32) + +-- Edge case testing +#eval 0 <<< 2^32 == 0 + +-- Expected failures +--#eval 1 <<< 2^32 +--#eval (1 <<< 2^31 <<< 2^31 >>> 2^32) + +#eval "UInt8" +#eval 0x117 &&& (0x1ff : UInt8) == 0x17 +#eval 0x17 ||| (0x10f : UInt8) == 0x1f +#eval 0x17 ^^^ (0x10f : UInt8) == 0x18 +#eval (0x12 : UInt8) <<< 4 == 0x120 +#eval (0x12 : UInt8) >>> 4 == 0x1 +#eval ~~~(0x12 : UInt8) == 0xed + +#eval "UInt16" +#eval 0x117 &&& (0x101ff : UInt16) == 0x117 +#eval 0x17 ||| (0x1010f : UInt16) == 0x011f +#eval 0x17 ^^^ (0x1010f : UInt16) == 0x0118 +#eval (0x12 : UInt16) <<< 4 == 0x120 +#eval (0x12 : UInt16) >>> 4 == 0x1 +#eval ~~~(0x12 : UInt16) == 0xffed + +#eval "UInt32" +#eval 0x117 &&& (0x101ff : UInt32) == 0x117 +#eval 0x17 ||| (0x1010f : UInt32) == 0x1011f +#eval 0x17 ^^^ (0x1010f : UInt32) == 0x10118 +#eval (0x12 : UInt32) <<< 4 == 0x120 +#eval (0x12 : UInt32) >>> 4 == 0x1 +#eval ~~~(0x12 : UInt32) == 0xffffffed + +#eval "UInt64" +#eval 0x117 &&& (0x101ff : UInt64) == 0x117 +#eval 0x17 ||| (0x1010f : UInt64) == 0x1011f +#eval 0x17 ^^^ (0x1010f : UInt64) == 0x10118 +#eval (0x12 : UInt64) <<< 4 == 0x120 +#eval (0x12 : UInt64) >>> 4 == 0x1 +#eval ~~~(0x12 : UInt64) == 0xffffffffffffffed + diff --git a/tests/lean/bitwise.lean.expected.out b/tests/lean/bitwise.lean.expected.out new file mode 100644 index 0000000000..62cf3d64cc --- /dev/null +++ b/tests/lean/bitwise.lean.expected.out @@ -0,0 +1,35 @@ +"Nat" +true +true +true +true +true +true +"UInt8" +true +true +true +true +true +true +"UInt16" +true +true +true +true +true +true +"UInt32" +true +true +true +true +true +true +"UInt64" +true +true +true +true +true +true