perf: use memcmp for ByteArray equality (#13235)

This PR uses `std::memcmp` for `ByteArray` `BEq` and `DecidableEq`.

Implementation is done in the same way as `String` but adapted to scalar
arrays.
This commit is contained in:
Henrik Böving 2026-04-01 17:30:03 +02:00 committed by GitHub
parent 861f722844
commit 097f3ebdbc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 82 additions and 3 deletions

View file

@ -20,12 +20,20 @@ universe u
namespace ByteArray
deriving instance BEq for ByteArray
@[extern "lean_sarray_dec_eq"]
def beq (lhs rhs : @& ByteArray) : Bool :=
lhs.data == rhs.data
instance : BEq ByteArray where
beq := beq
attribute [ext] ByteArray
instance : DecidableEq ByteArray :=
fun _ _ => decidable_of_decidable_of_iff ByteArray.ext_iff.symm
@[extern "lean_sarray_dec_eq"]
def decEq (lhs rhs : @& ByteArray) : Decidable (lhs = rhs) :=
decidable_of_decidable_of_iff ByteArray.ext_iff.symm
instance : DecidableEq ByteArray := decEq
instance : Inhabited ByteArray where
default := empty

View file

@ -976,6 +976,13 @@ static inline void lean_sarray_set_size(u_lean_obj_arg o, size_t sz) {
}
static inline uint8_t* lean_sarray_cptr(lean_object * o) { return lean_to_sarray(o)->m_data; }
LEAN_EXPORT bool lean_sarray_eq_cold(b_lean_obj_arg a1, b_lean_obj_arg a2);
static inline bool lean_sarray_eq(b_lean_obj_arg a1, b_lean_obj_arg a2) {
assert(lean_sarray_elem_size(a1) == lean_sarray_elem_size(a2));
return a1 == a2 || (lean_sarray_size(a1) == lean_sarray_size(a2) && lean_sarray_eq_cold(a1, a2));
}
static inline uint8_t lean_sarray_dec_eq(b_lean_obj_arg a1, b_lean_obj_arg a2) { return lean_sarray_eq(a1, a2); }
/* Remark: expand sarray API after we add better support in the compiler */
/* ByteArray (special case of Array of Scalars) */

View file

@ -2078,6 +2078,11 @@ extern "C" LEAN_EXPORT bool lean_string_eq_cold(b_lean_obj_arg s1, b_lean_obj_ar
return std::memcmp(lean_string_cstr(s1), lean_string_cstr(s2), lean_string_size(s1)) == 0;
}
extern "C" LEAN_EXPORT bool lean_sarray_eq_cold(b_lean_obj_arg a1, b_lean_obj_arg a2) {
size_t len = lean_sarray_elem_size(a1) * lean_sarray_size(a1);
return std::memcmp(lean_sarray_cptr(a1), lean_sarray_cptr(a2), len) == 0;
}
bool string_eq(object * s1, char const * s2) {
if (lean_string_size(s1) != strlen(s2) + 1)
return false;

View file

@ -0,0 +1,59 @@
module
/-! test native equality on ByteArray -/
def mk (xs : Array UInt8) : ByteArray := ⟨xs⟩
/-- info: true -/
#guard_msgs in
#eval mk #[] == mk #[]
/-- info: true -/
#guard_msgs in
#eval mk #[1, 2, 3, 4, 5] == mk #[1, 2, 3, 4, 5]
/-- info: true -/
#guard_msgs in
#eval mk #[1] != mk #[]
/-- info: true -/
#guard_msgs in
#eval mk #[] != mk #[1]
/-- info: true -/
#guard_msgs in
#eval mk #[1, 2, 3, 4, 5] != mk #[0, 2, 3, 4, 5]
/-- info: true -/
#guard_msgs in
#eval mk #[0, 2, 3, 4, 5] != mk #[1, 2, 3, 4, 5]
/-- info: true -/
#guard_msgs in
#eval mk #[1, 2, 3, 4, 5] != mk #[1, 2, 3, 0, 5]
/-- info: true -/
#guard_msgs in
#eval mk #[1, 2, 3, 0, 5] != mk #[1, 2, 3, 4, 5]
/-- info: true -/
#guard_msgs in
#eval mk #[] = mk #[]
/-- info: true -/
#guard_msgs in
#eval mk #[1, 2, 3, 4, 5] = mk #[1, 2, 3, 4, 5]
/-- info: true -/
#guard_msgs in
#eval mk #[1] ≠ mk #[]
/-- info: true -/
#guard_msgs in
#eval mk #[] ≠ mk #[1]
/-- info: true -/
#guard_msgs in
#eval mk #[1, 2, 3, 4, 5] ≠ mk #[0, 2, 3, 4, 5]
/-- info: true -/
#guard_msgs in
#eval mk #[0, 2, 3, 4, 5] ≠ mk #[1, 2, 3, 4, 5]
/-- info: true -/
#guard_msgs in
#eval mk #[1, 2, 3, 4, 5] ≠ mk #[1, 2, 3, 0, 5]
/-- info: true -/
#guard_msgs in
#eval mk #[1, 2, 3, 0, 5] ≠ mk #[1, 2, 3, 4, 5]