perf: use memcmp for ByteArray equality (#13235)
This PR uses `std::memcmp` for `ByteArray` `BEq` and `DecidableEq`. Implementation is done in the same way as `String` but adapted to scalar arrays.
This commit is contained in:
parent
861f722844
commit
097f3ebdbc
4 changed files with 82 additions and 3 deletions
|
|
@ -20,12 +20,20 @@ universe u
|
|||
|
||||
namespace ByteArray
|
||||
|
||||
deriving instance BEq for ByteArray
|
||||
@[extern "lean_sarray_dec_eq"]
|
||||
def beq (lhs rhs : @& ByteArray) : Bool :=
|
||||
lhs.data == rhs.data
|
||||
|
||||
instance : BEq ByteArray where
|
||||
beq := beq
|
||||
|
||||
attribute [ext] ByteArray
|
||||
|
||||
instance : DecidableEq ByteArray :=
|
||||
fun _ _ => decidable_of_decidable_of_iff ByteArray.ext_iff.symm
|
||||
@[extern "lean_sarray_dec_eq"]
|
||||
def decEq (lhs rhs : @& ByteArray) : Decidable (lhs = rhs) :=
|
||||
decidable_of_decidable_of_iff ByteArray.ext_iff.symm
|
||||
|
||||
instance : DecidableEq ByteArray := decEq
|
||||
|
||||
instance : Inhabited ByteArray where
|
||||
default := empty
|
||||
|
|
|
|||
|
|
@ -976,6 +976,13 @@ static inline void lean_sarray_set_size(u_lean_obj_arg o, size_t sz) {
|
|||
}
|
||||
static inline uint8_t* lean_sarray_cptr(lean_object * o) { return lean_to_sarray(o)->m_data; }
|
||||
|
||||
LEAN_EXPORT bool lean_sarray_eq_cold(b_lean_obj_arg a1, b_lean_obj_arg a2);
|
||||
static inline bool lean_sarray_eq(b_lean_obj_arg a1, b_lean_obj_arg a2) {
|
||||
assert(lean_sarray_elem_size(a1) == lean_sarray_elem_size(a2));
|
||||
return a1 == a2 || (lean_sarray_size(a1) == lean_sarray_size(a2) && lean_sarray_eq_cold(a1, a2));
|
||||
}
|
||||
static inline uint8_t lean_sarray_dec_eq(b_lean_obj_arg a1, b_lean_obj_arg a2) { return lean_sarray_eq(a1, a2); }
|
||||
|
||||
/* Remark: expand sarray API after we add better support in the compiler */
|
||||
|
||||
/* ByteArray (special case of Array of Scalars) */
|
||||
|
|
|
|||
|
|
@ -2078,6 +2078,11 @@ extern "C" LEAN_EXPORT bool lean_string_eq_cold(b_lean_obj_arg s1, b_lean_obj_ar
|
|||
return std::memcmp(lean_string_cstr(s1), lean_string_cstr(s2), lean_string_size(s1)) == 0;
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT bool lean_sarray_eq_cold(b_lean_obj_arg a1, b_lean_obj_arg a2) {
|
||||
size_t len = lean_sarray_elem_size(a1) * lean_sarray_size(a1);
|
||||
return std::memcmp(lean_sarray_cptr(a1), lean_sarray_cptr(a2), len) == 0;
|
||||
}
|
||||
|
||||
bool string_eq(object * s1, char const * s2) {
|
||||
if (lean_string_size(s1) != strlen(s2) + 1)
|
||||
return false;
|
||||
|
|
|
|||
59
tests/elab/bytearray_eq.lean
Normal file
59
tests/elab/bytearray_eq.lean
Normal file
|
|
@ -0,0 +1,59 @@
|
|||
module
|
||||
|
||||
/-! test native equality on ByteArray -/
|
||||
|
||||
def mk (xs : Array UInt8) : ByteArray := ⟨xs⟩
|
||||
|
||||
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval mk #[] == mk #[]
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval mk #[1, 2, 3, 4, 5] == mk #[1, 2, 3, 4, 5]
|
||||
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval mk #[1] != mk #[]
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval mk #[] != mk #[1]
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval mk #[1, 2, 3, 4, 5] != mk #[0, 2, 3, 4, 5]
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval mk #[0, 2, 3, 4, 5] != mk #[1, 2, 3, 4, 5]
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval mk #[1, 2, 3, 4, 5] != mk #[1, 2, 3, 0, 5]
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval mk #[1, 2, 3, 0, 5] != mk #[1, 2, 3, 4, 5]
|
||||
|
||||
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval mk #[] = mk #[]
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval mk #[1, 2, 3, 4, 5] = mk #[1, 2, 3, 4, 5]
|
||||
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval mk #[1] ≠ mk #[]
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval mk #[] ≠ mk #[1]
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval mk #[1, 2, 3, 4, 5] ≠ mk #[0, 2, 3, 4, 5]
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval mk #[0, 2, 3, 4, 5] ≠ mk #[1, 2, 3, 4, 5]
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval mk #[1, 2, 3, 4, 5] ≠ mk #[1, 2, 3, 0, 5]
|
||||
/-- info: true -/
|
||||
#guard_msgs in
|
||||
#eval mk #[1, 2, 3, 0, 5] ≠ mk #[1, 2, 3, 4, 5]
|
||||
Loading…
Add table
Reference in a new issue