feat: Nat/Fin/UInt instances of bitwise classes

This commit is contained in:
Joe Hendrix 2021-03-01 20:01:33 -08:00 committed by Leonardo de Moura
parent 2831dd6872
commit 9bd60c7519
10 changed files with 249 additions and 35 deletions

View file

@ -62,6 +62,15 @@ def land : Fin n → Fin n → Fin n
def lor : Fin n → Fin n → Fin n
| ⟨a, h⟩, ⟨b, _⟩ => ⟨(Nat.lor a b) % n, mlt h⟩
def xor : Fin n → Fin n → Fin n
| ⟨a, h⟩, ⟨b, _⟩ => ⟨(Nat.xor a b) % n, mlt h⟩
def shiftLeft : Fin n → Fin n → Fin n
| ⟨a, h⟩, ⟨b, _⟩ => ⟨(a <<< b) % n, mlt h⟩
def shiftRight : Fin n → Fin n → Fin n
| ⟨a, h⟩, ⟨b, _⟩ => ⟨(a >>> b) % n, mlt h⟩
instance : Add (Fin n) where
add := Fin.add
@ -77,6 +86,17 @@ instance : Mod (Fin n) where
instance : Div (Fin n) where
div := Fin.div
instance : AndOp (Fin n) where
and := Fin.land
instance : OrOp (Fin n) where
or := Fin.lor
instance : Xor (Fin n) where
xor := Fin.xor
instance : ShiftLeft (Fin n) where
shiftLeft := Fin.shiftLeft
instance : ShiftRight (Fin n) where
shiftRight := Fin.shiftRight
instance : HMod (Fin n) Nat (Fin n) where
hMod := Fin.modn

View file

@ -30,5 +30,21 @@ partial def bitwise (f : Bool → Bool → Bool) (n m : Nat) : Nat :=
def land : Nat → Nat → Nat := bitwise and
@[extern "lean_nat_lor"]
def lor : Nat → Nat → Nat := bitwise or
@[extern "lean_nat_lxor"]
def xor : Nat → Nat → Nat := bitwise bne
@[extern "lean_nat_shiftl"]
def shiftLeft : Nat → Nat → Nat
| n, 0 => n
| n, succ m => shiftLeft (2*n) m
@[extern "lean_nat_shiftr"]
def shiftRight : Nat → Nat → Nat
| n, 0 => n
| n, succ m => shiftRight n m / 2
instance : AndOp Nat := ⟨Nat.land⟩
instance : OrOp Nat := ⟨Nat.lor⟩
instance : Xor Nat := ⟨Nat.xor⟩
instance : ShiftLeft Nat := ⟨Nat.shiftLeft⟩
instance : ShiftRight Nat := ⟨Nat.shiftRight⟩
end Nat

View file

@ -30,6 +30,12 @@ def UInt8.modn (a : UInt8) (n : @& Nat) : UInt8 := ⟨a.val % n⟩
def UInt8.land (a b : UInt8) : UInt8 := ⟨Fin.land a.val b.val⟩
@[extern c inline "#1 | #2"]
def UInt8.lor (a b : UInt8) : UInt8 := ⟨Fin.lor a.val b.val⟩
@[extern c inline "#1 ^ #2"]
def UInt8.xor (a b : UInt8) : UInt8 := ⟨Fin.xor a.val b.val⟩
@[extern c inline "#1 << #2"]
def UInt8.shiftLeft (a b : UInt8) : UInt8 := ⟨a.val <<< b.val⟩
@[extern c inline "#1 >> #2"]
def UInt8.shiftRight (a b : UInt8) : UInt8 := ⟨a.val >>> b.val⟩
def UInt8.lt (a b : UInt8) : Prop := a.val < b.val
def UInt8.le (a b : UInt8) : Prop := a.val ≤ b.val
@ -43,10 +49,15 @@ instance : Div UInt8 := ⟨UInt8.div⟩
instance : HasLess UInt8 := ⟨UInt8.lt⟩
instance : HasLessEq UInt8 := ⟨UInt8.le⟩
@[extern c inline "#1 << #2"]
constant UInt8.shiftLeft (a b : UInt8) : UInt8
@[extern c inline "#1 >> #2"]
constant UInt8.shiftRight (a b : UInt8) : UInt8
@[extern c inline "~ #1"]
def UInt8.complement (a:UInt8) : UInt8 := 0-(a+1)
instance : Complement UInt8 := ⟨UInt8.complement⟩
instance : AndOp UInt8 := ⟨UInt8.land⟩
instance : OrOp UInt8 := ⟨UInt8.lor⟩
instance : Xor UInt8 := ⟨UInt8.xor⟩
instance : ShiftLeft UInt8 := ⟨UInt8.shiftLeft⟩
instance : ShiftRight UInt8 := ⟨UInt8.shiftRight⟩
set_option bootstrap.genMatcherCode false in
@[extern c inline "#1 < #2"]
@ -84,6 +95,12 @@ def UInt16.modn (a : UInt16) (n : @& Nat) : UInt16 := ⟨a.val % n⟩
def UInt16.land (a b : UInt16) : UInt16 := ⟨Fin.land a.val b.val⟩
@[extern c inline "#1 | #2"]
def UInt16.lor (a b : UInt16) : UInt16 := ⟨Fin.lor a.val b.val⟩
@[extern c inline "#1 ^ #2"]
def UInt16.xor (a b : UInt16) : UInt16 := ⟨Fin.xor a.val b.val⟩
@[extern c inline "#1 << #2"]
def UInt16.shiftLeft (a b : UInt16) : UInt16 := ⟨a.val <<< b.val⟩
@[extern c inline "#1 >> #2"]
def UInt16.shiftRight (a b : UInt16) : UInt16 := ⟨a.val >>> b.val⟩
def UInt16.lt (a b : UInt16) : Prop := a.val < b.val
def UInt16.le (a b : UInt16) : Prop := a.val ≤ b.val
@ -98,10 +115,15 @@ instance : Div UInt16 := ⟨UInt16.div⟩
instance : HasLess UInt16 := ⟨UInt16.lt⟩
instance : HasLessEq UInt16 := ⟨UInt16.le⟩
@[extern c inline "#1 << #2"]
constant UInt16.shiftLeft (a b : UInt16) : UInt16
@[extern c inline "#1 >> #2"]
constant UInt16.shiftRight (a b : UInt16) : UInt16
@[extern c inline "~ #1"]
def UInt16.complement (a:UInt16) : UInt16 := 0-(a+1)
instance : Complement UInt16 := ⟨UInt16.complement⟩
instance : AndOp UInt16 := ⟨UInt16.land⟩
instance : OrOp UInt16 := ⟨UInt16.lor⟩
instance : Xor UInt16 := ⟨UInt16.xor⟩
instance : ShiftLeft UInt16 := ⟨UInt16.shiftLeft⟩
instance : ShiftRight UInt16 := ⟨UInt16.shiftRight⟩
set_option bootstrap.genMatcherCode false in
@[extern c inline "#1 < #2"]
@ -139,6 +161,12 @@ def UInt32.modn (a : UInt32) (n : @& Nat) : UInt32 := ⟨a.val % n⟩
def UInt32.land (a b : UInt32) : UInt32 := ⟨Fin.land a.val b.val⟩
@[extern c inline "#1 | #2"]
def UInt32.lor (a b : UInt32) : UInt32 := ⟨Fin.lor a.val b.val⟩
@[extern c inline "#1 ^ #2"]
def UInt32.xor (a b : UInt32) : UInt32 := ⟨Fin.xor a.val b.val⟩
@[extern c inline "#1 << #2"]
def UInt32.shiftLeft (a b : UInt32) : UInt32 := ⟨a.val <<< b.val⟩
@[extern c inline "#1 >> #2"]
def UInt32.shiftRight (a b : UInt32) : UInt32 := ⟨a.val >>> b.val⟩
@[extern c inline "((uint8_t)#1)"]
def UInt32.toUInt8 (a : UInt32) : UInt8 := a.toNat.toUInt8
@[extern c inline "((uint16_t)#1)"]
@ -154,10 +182,15 @@ instance : Mod UInt32 := ⟨UInt32.mod⟩
instance : HMod UInt32 Nat UInt32 := ⟨UInt32.modn⟩
instance : Div UInt32 := ⟨UInt32.div⟩
@[extern c inline "#1 << #2"]
constant UInt32.shiftLeft (a b : UInt32) : UInt32
@[extern c inline "#1 >> #2"]
constant UInt32.shiftRight (a b : UInt32) : UInt32
@[extern c inline "~ #1"]
def UInt32.complement (a:UInt32) : UInt32 := 0-(a+1)
instance : Complement UInt32 := ⟨UInt32.complement⟩
instance : AndOp UInt32 := ⟨UInt32.land⟩
instance : OrOp UInt32 := ⟨UInt32.lor⟩
instance : Xor UInt32 := ⟨UInt32.xor⟩
instance : ShiftLeft UInt32 := ⟨UInt32.shiftLeft⟩
instance : ShiftRight UInt32 := ⟨UInt32.shiftRight⟩
@[extern "lean_uint64_of_nat"]
def UInt64.ofNat (n : @& Nat) : UInt64 := ⟨Fin.ofNat n⟩
@ -180,6 +213,12 @@ def UInt64.modn (a : UInt64) (n : @& Nat) : UInt64 := ⟨a.val % n⟩
def UInt64.land (a b : UInt64) : UInt64 := ⟨Fin.land a.val b.val⟩
@[extern c inline "#1 | #2"]
def UInt64.lor (a b : UInt64) : UInt64 := ⟨Fin.lor a.val b.val⟩
@[extern c inline "#1 ^ #2"]
def UInt64.xor (a b : UInt64) : UInt64 := ⟨Fin.xor a.val b.val⟩
@[extern c inline "#1 << #2"]
def UInt64.shiftLeft (a b : UInt64) : UInt64 := ⟨a.val <<< b.val⟩
@[extern c inline "#1 >> #2"]
def UInt64.shiftRight (a b : UInt64) : UInt64 := ⟨a.val >>> b.val⟩
def UInt64.lt (a b : UInt64) : Prop := a.val < b.val
def UInt64.le (a b : UInt64) : Prop := a.val ≤ b.val
@[extern c inline "((uint8_t)#1)"]
@ -191,12 +230,6 @@ def UInt64.toUInt32 (a : UInt64) : UInt32 := a.toNat.toUInt32
@[extern c inline "((uint64_t)#1)"]
def UInt32.toUInt64 (a : UInt32) : UInt64 := a.toNat.toUInt64
-- TODO(Leo): give reference implementation for shiftLeft and shiftRight, and define them for other UInt types
@[extern c inline "#1 << #2"]
constant UInt64.shiftLeft (a b : UInt64) : UInt64
@[extern c inline "#1 >> #2"]
constant UInt64.shiftRight (a b : UInt64) : UInt64
instance : OfNat UInt64 n := ⟨UInt64.ofNat n⟩
instance : Add UInt64 := ⟨UInt64.add⟩
instance : Sub UInt64 := ⟨UInt64.sub⟩
@ -207,6 +240,16 @@ instance : Div UInt64 := ⟨UInt64.div⟩
instance : HasLess UInt64 := ⟨UInt64.lt⟩
instance : HasLessEq UInt64 := ⟨UInt64.le⟩
@[extern c inline "~ #1"]
def UInt64.complement (a:UInt64) : UInt64 := 0-(a+1)
instance : Complement UInt64 := ⟨UInt64.complement⟩
instance : AndOp UInt64 := ⟨UInt64.land⟩
instance : OrOp UInt64 := ⟨UInt64.lor⟩
instance : Xor UInt64 := ⟨UInt64.xor⟩
instance : ShiftLeft UInt64 := ⟨UInt64.shiftLeft⟩
instance : ShiftRight UInt64 := ⟨UInt64.shiftRight⟩
@[extern c inline "(uint64_t)#1"]
def Bool.toUInt64 (b : Bool) : UInt64 := if b then 1 else 0
@ -249,6 +292,12 @@ def USize.modn (a : USize) (n : @& Nat) : USize := ⟨a.val % n⟩
def USize.land (a b : USize) : USize := ⟨Fin.land a.val b.val⟩
@[extern c inline "#1 | #2"]
def USize.lor (a b : USize) : USize := ⟨Fin.lor a.val b.val⟩
@[extern c inline "#1 ^ #2"]
def USize.xor (a b : USize) : USize := ⟨Fin.xor a.val b.val⟩
@[extern c inline "#1 << #2"]
def USize.shiftLeft (a b : USize) : USize := ⟨a.val <<< b.val⟩
@[extern c inline "#1 >> #2"]
def USize.shiftRight (a b : USize) : USize := ⟨a.val >>> b.val⟩
@[extern c inline "#1"]
def UInt32.toUSize (a : UInt32) : USize := a.toNat.toUSize
@[extern c inline "((size_t)#1)"]
@ -256,11 +305,6 @@ def UInt64.toUSize (a : UInt64) : USize := a.toNat.toUSize
@[extern c inline "(uint32_t)#1"]
def USize.toUInt32 (a : USize) : UInt32 := a.toNat.toUInt32
-- TODO(Leo): give reference implementation for shiftLeft and shiftRight, and define them for other UInt types
@[extern c inline "#1 << #2"]
constant USize.shiftLeft (a b : USize) : USize
@[extern c inline "#1 >> #2"]
constant USize.shiftRight (a b : USize) : USize
def USize.lt (a b : USize) : Prop := a.val < b.val
def USize.le (a b : USize) : Prop := a.val ≤ b.val
@ -274,6 +318,16 @@ instance : Div USize := ⟨USize.div⟩
instance : HasLess USize := ⟨USize.lt⟩
instance : HasLessEq USize := ⟨USize.le⟩
@[extern c inline "~ #1"]
def USize.complement (a:USize) : USize := 0-(a+1)
instance : Complement USize := ⟨USize.complement⟩
instance : AndOp USize := ⟨USize.land⟩
instance : OrOp USize := ⟨USize.lor⟩
instance : Xor USize := ⟨USize.xor⟩
instance : ShiftLeft USize := ⟨USize.shiftLeft⟩
instance : ShiftRight USize := ⟨USize.shiftRight⟩
set_option bootstrap.genMatcherCode false in
@[extern c inline "#1 < #2"]
def USize.decLt (a b : USize) : Decidable (a < b) :=

View file

@ -1430,6 +1430,8 @@ static inline lean_obj_res lean_nat_lxor(b_lean_obj_arg a1, b_lean_obj_arg a2) {
}
}
lean_obj_res lean_nat_shiftl(b_lean_obj_arg a1, b_lean_obj_arg a2);
lean_obj_res lean_nat_shiftr(b_lean_obj_arg a1, b_lean_obj_arg a2);
lean_obj_res lean_nat_pow(b_lean_obj_arg a1, b_lean_obj_arg a2);
/* Integers */

View file

@ -224,18 +224,18 @@ public:
\brief Return the position of the most significant bit.
Return 0 if the number is negative
*/
unsigned log2() const;
size_t log2() const;
/**
\brief log2(-n)
Return 0 if the number is nonegative
*/
unsigned mlog2() const;
size_t mlog2() const;
bool perfect_square() const { return mpz_perfect_square_p(m_val); }
bool is_power_of_two() const { return is_pos() && mpz_popcount(m_val) == 1; }
bool is_power_of_two(unsigned & shift) const;
bool is_power_of_two(size_t& shift) const;
/**
\brief Return largest k s.t. n is a multiple of 2^k
*/

View file

@ -38,27 +38,27 @@ size_t mpz::get_size_t() const {
return static_cast<size_t>(mpz_getlimbn(m_val, 0));
}
unsigned mpz::log2() const {
size_t mpz::log2() const {
if (is_nonpos())
return 0;
unsigned r = mpz_sizeinbase(m_val, 2);
size_t r = mpz_sizeinbase(m_val, 2);
lean_assert(r > 0);
return r - 1;
}
unsigned mpz::mlog2() const {
size_t mpz::mlog2() const {
if (is_nonneg())
return 0;
mpz * _this = const_cast<mpz*>(this);
_this->neg();
lean_assert(is_pos());
unsigned r = mpz_sizeinbase(m_val, 2);
size_t r = mpz_sizeinbase(m_val, 2);
_this->neg();
lean_assert(is_neg());
return r - 1;
}
bool mpz::is_power_of_two(unsigned & shift) const {
bool mpz::is_power_of_two(size_t& shift) const {
if (is_nonpos())
return false;
if (mpz_popcount(m_val) == 1) {

View file

@ -1257,7 +1257,7 @@ extern "C" object * lean_nat_big_lor(object * a1, object * a2) {
return mpz_to_nat(mpz_value(a1) | mpz_value(a2));
}
extern "C" object * lean_nat_big_lxor(object * a1, object * a2) {
extern "C" object * lean_nat_big_xor(object * a1, object * a2) {
lean_assert(!lean_is_scalar(a1) || !lean_is_scalar(a2));
if (lean_is_scalar(a1))
return mpz_to_nat(mpz::of_size_t(lean_unbox(a1)) ^ mpz_value(a2));
@ -1267,8 +1267,46 @@ extern "C" object * lean_nat_big_lxor(object * a1, object * a2) {
return mpz_to_nat(mpz_value(a1) ^ mpz_value(a2));
}
extern "C" lean_obj_res lean_nat_pow(b_lean_obj_arg a1, b_lean_obj_arg a2) {
extern "C" lean_obj_res lean_nat_shiftl(b_lean_obj_arg a1, b_lean_obj_arg a2) {
// Special case for shifted value is 0.
if (lean_is_scalar(a1) && lean_unbox(a1) == 0) {
return lean_box(0);
}
auto a = lean_is_scalar(a1)
? mpz::of_size_t(lean_unbox(a1))
: mpz_value(a1);
if (!lean_is_scalar(a2) || lean_unbox(a2) > UINT_MAX) {
lean_panic("Nat.shiftl exponent is too big");
}
mpz r;
mul2k(r, a, lean_unbox(a2));
return mpz_to_nat(r);
}
extern "C" lean_obj_res lean_nat_shiftr(b_lean_obj_arg a1, b_lean_obj_arg a2) {
if (!lean_is_scalar(a2)) {
return lean_box(0); // This large of an exponent must be 0.
}
auto a = lean_is_scalar(a1)
? mpz::of_size_t(lean_unbox(a1))
: mpz_value(a1);
size_t s = lean_unbox(a2);
// If the shift amount is large, then we fail if it is not large
// enough to zero out all the bits.
if (s > UINT_MAX) {
if (a.log2() >= s) {
lean_panic("Nat.shiftr exponent is too big");
} else {
return lean_box(0);
}
}
mpz r;
div2k(r, a, s);
return mpz_to_nat(r);
}
extern "C" lean_obj_res lean_nat_pow(b_lean_obj_arg a1, b_lean_obj_arg a2) {
if (!lean_is_scalar(a2) || lean_unbox(a2) > UINT_MAX) {
lean_internal_panic("Nat.pow exponent is too big");
}
if (lean_is_scalar(a1))

View file

@ -1,3 +1,3 @@
infixl:67 " <<< " => nonexistant
infixl:67 " <>< " => nonexistant
#eval (1 <<< 11 : UInt64)
#eval (1 <>< 11 : UInt64)

49
tests/lean/bitwise.lean Normal file
View file

@ -0,0 +1,49 @@
#eval "Nat"
#eval 0x17 &&& 0xf == 0x7
#eval 0x17 ||| 0xf == 0x1f
#eval 0x17 ^^^ 0xf == 0x18
#eval 0x12 <<< 4 == 0x120
#eval 0x12 >>> 4 == 0x1
-- Expected failure
-- #eval 1 ^ (2 ^ 32)
-- Edge case testing
#eval 0 <<< 2^32 == 0
-- Expected failures
--#eval 1 <<< 2^32
--#eval (1 <<< 2^31 <<< 2^31 >>> 2^32)
#eval "UInt8"
#eval 0x117 &&& (0x1ff : UInt8) == 0x17
#eval 0x17 ||| (0x10f : UInt8) == 0x1f
#eval 0x17 ^^^ (0x10f : UInt8) == 0x18
#eval (0x12 : UInt8) <<< 4 == 0x120
#eval (0x12 : UInt8) >>> 4 == 0x1
#eval ~~~(0x12 : UInt8) == 0xed
#eval "UInt16"
#eval 0x117 &&& (0x101ff : UInt16) == 0x117
#eval 0x17 ||| (0x1010f : UInt16) == 0x011f
#eval 0x17 ^^^ (0x1010f : UInt16) == 0x0118
#eval (0x12 : UInt16) <<< 4 == 0x120
#eval (0x12 : UInt16) >>> 4 == 0x1
#eval ~~~(0x12 : UInt16) == 0xffed
#eval "UInt32"
#eval 0x117 &&& (0x101ff : UInt32) == 0x117
#eval 0x17 ||| (0x1010f : UInt32) == 0x1011f
#eval 0x17 ^^^ (0x1010f : UInt32) == 0x10118
#eval (0x12 : UInt32) <<< 4 == 0x120
#eval (0x12 : UInt32) >>> 4 == 0x1
#eval ~~~(0x12 : UInt32) == 0xffffffed
#eval "UInt64"
#eval 0x117 &&& (0x101ff : UInt64) == 0x117
#eval 0x17 ||| (0x1010f : UInt64) == 0x1011f
#eval 0x17 ^^^ (0x1010f : UInt64) == 0x10118
#eval (0x12 : UInt64) <<< 4 == 0x120
#eval (0x12 : UInt64) >>> 4 == 0x1
#eval ~~~(0x12 : UInt64) == 0xffffffffffffffed

View file

@ -0,0 +1,35 @@
"Nat"
true
true
true
true
true
true
"UInt8"
true
true
true
true
true
true
"UInt16"
true
true
true
true
true
true
"UInt32"
true
true
true
true
true
true
"UInt64"
true
true
true
true
true
true