feat(runtime): add primitive hash functions
This commit is contained in:
parent
a46e27a3d7
commit
ee050431e0
5 changed files with 81 additions and 45 deletions
|
|
@ -12,17 +12,13 @@ class Hashable (α : Type u) :=
|
|||
|
||||
export Hashable (hash)
|
||||
|
||||
-- TODO: mark as builtin and opaque
|
||||
def mixHash (u₁ u₂ : USize) : USize :=
|
||||
default USize
|
||||
|
||||
-- TODO: mark as builtin
|
||||
protected def String.hash (s : String) : USize :=
|
||||
default USize
|
||||
@[extern cpp "lean::usize_mix_hash"]
|
||||
constant mixHash (u₁ u₂ : USize) : USize := default _
|
||||
|
||||
@[extern cpp "lean::string_hash"]
|
||||
protected constant String.hash (s : String) : USize := default _
|
||||
instance : Hashable String := ⟨String.hash⟩
|
||||
|
||||
-- TODO: add builtin
|
||||
protected def Nat.hash (n : Nat) : USize :=
|
||||
USize.ofNat n
|
||||
|
||||
|
|
|
|||
|
|
@ -209,6 +209,8 @@ public:
|
|||
friend void mul2k(mpz & a, mpz const & b, unsigned k) { mpz_mul_2exp(a.m_val, b.m_val, k); }
|
||||
// a <- b / 2^k
|
||||
friend void div2k(mpz & a, mpz const & b, unsigned k) { mpz_tdiv_q_2exp(a.m_val, b.m_val, k); }
|
||||
// a <- b % 2^k
|
||||
friend void mod2k(mpz & a, mpz const & b, unsigned k) { mpz_tdiv_r_2exp(a.m_val, b.m_val, k); }
|
||||
|
||||
/**
|
||||
\brief Return the position of the most significant bit.
|
||||
|
|
|
|||
|
|
@ -1383,6 +1383,54 @@ bool int_big_lt(object * a1, object * a2) {
|
|||
}
|
||||
}
|
||||
|
||||
// =======================================
|
||||
// UInt
|
||||
|
||||
uint8 uint8_of_big_nat(b_obj_arg a) {
|
||||
mpz r;
|
||||
mod2k(r, mpz_value(a), 8);
|
||||
return static_cast<uint8>(r.get_unsigned_int());
|
||||
}
|
||||
uint16 uint16_of_big_nat(b_obj_arg a) {
|
||||
mpz r;
|
||||
mod2k(r, mpz_value(a), 16);
|
||||
return static_cast<uint16>(r.get_unsigned_int());
|
||||
}
|
||||
uint32 uint32_of_big_nat(b_obj_arg a) {
|
||||
mpz r;
|
||||
mod2k(r, mpz_value(a), 32);
|
||||
return static_cast<uint32>(r.get_unsigned_int());
|
||||
}
|
||||
uint64 uint64_of_big_nat(b_obj_arg a) {
|
||||
mpz r;
|
||||
mod2k(r, mpz_value(a), 64);
|
||||
if (sizeof(void*) == 8) {
|
||||
// 64 bit
|
||||
return static_cast<uint64>(r.get_unsigned_long_int());
|
||||
} else {
|
||||
// 32 bit
|
||||
mpz l;
|
||||
mod2k(l, r, 32);
|
||||
mpz h;
|
||||
div2k(h, r, 32);
|
||||
return (static_cast<uint64>(h.get_unsigned_int()) << 32) + static_cast<uint64>(l.get_unsigned_int());
|
||||
}
|
||||
}
|
||||
|
||||
usize usize_of_big_nat(b_obj_arg a) {
|
||||
if (sizeof(void*) == 8)
|
||||
return uint64_of_big_nat(a);
|
||||
else
|
||||
return uint32_of_big_nat(a);
|
||||
}
|
||||
|
||||
usize usize_mix_hash(usize a1, usize a2) {
|
||||
if (sizeof(void*) == 8)
|
||||
return hash(static_cast<uint64>(a1), static_cast<uint64>(a2));
|
||||
else
|
||||
return hash(static_cast<uint32>(a1), static_cast<uint32>(a2));
|
||||
}
|
||||
|
||||
// =======================================
|
||||
// Strings
|
||||
|
||||
|
|
@ -1591,8 +1639,8 @@ uint32 string_utf8_get(b_obj_arg s, b_obj_arg i0) {
|
|||
|
||||
/* The reference implementation is:
|
||||
```
|
||||
def utf8_next (s : @& String) (p : @& Pos) : Ppos :=
|
||||
let c := utf8_get s p in
|
||||
def next (s : @& String) (p : @& Pos) : Ppos :=
|
||||
let c := get s p in
|
||||
p + csize c
|
||||
```
|
||||
*/
|
||||
|
|
@ -1697,6 +1745,12 @@ obj_res string_utf8_set(obj_arg s, b_obj_arg i0, uint32 c) {
|
|||
return mk_string(new_s);
|
||||
}
|
||||
|
||||
usize string_hash(b_obj_arg s) {
|
||||
usize sz = string_size(s) - 1;
|
||||
char const * str = string_cstr(s);
|
||||
return hash_str(sz, str, 11);
|
||||
}
|
||||
|
||||
// =======================================
|
||||
// array functions for generated code
|
||||
|
||||
|
|
|
|||
|
|
@ -1209,10 +1209,12 @@ inline bool string_ne(b_obj_arg s1, b_obj_arg s2) { return !string_eq(s1, s2); }
|
|||
bool string_lt(b_obj_arg s1, b_obj_arg s2);
|
||||
inline uint8 string_dec_eq(b_obj_arg s1, b_obj_arg s2) { return string_eq(s1, s2); }
|
||||
inline uint8 string_dec_lt(b_obj_arg s1, b_obj_arg s2) { return string_lt(s1, s2); }
|
||||
usize string_hash(b_obj_arg);
|
||||
|
||||
// =======================================
|
||||
// uint8
|
||||
inline uint8 uint8_of_nat(b_obj_arg a) { return is_scalar(a) ? static_cast<uint8>(unbox(a)) : 0; }
|
||||
uint8 uint8_of_big_nat(b_obj_arg a);
|
||||
inline uint8 uint8_of_nat(b_obj_arg a) { return is_scalar(a) ? static_cast<uint8>(unbox(a)) : uint8_of_big_nat(a); }
|
||||
inline obj_res uint8_to_nat(uint8 a) { return mk_nat_obj(static_cast<usize>(a)); }
|
||||
inline uint8 uint8_add(uint8 a1, uint8 a2) { return a1+a2; }
|
||||
inline uint8 uint8_sub(uint8 a1, uint8 a2) { return a1-a2; }
|
||||
|
|
@ -1233,7 +1235,8 @@ inline uint8 uint8_dec_le(uint8 a1, uint8 a2) { return a1 <= a2; }
|
|||
|
||||
// =======================================
|
||||
// uint16
|
||||
inline uint16 uint16_of_nat(b_obj_arg a) { return is_scalar(a) ? static_cast<uint16>(unbox(a)) : 0; }
|
||||
uint16 uint16_of_big_nat(b_obj_arg a);
|
||||
inline uint16 uint16_of_nat(b_obj_arg a) { return is_scalar(a) ? static_cast<uint16>(unbox(a)) : uint16_of_big_nat(a); }
|
||||
inline obj_res uint16_to_nat(uint16 a) { return mk_nat_obj(static_cast<usize>(a)); }
|
||||
inline uint16 uint16_add(uint16 a1, uint16 a2) { return a1+a2; }
|
||||
inline uint16 uint16_sub(uint16 a1, uint16 a2) { return a1-a2; }
|
||||
|
|
@ -1254,22 +1257,8 @@ inline uint8 uint16_dec_le(uint16 a1, uint16 a2) { return a1 <= a2; }
|
|||
|
||||
// =======================================
|
||||
// uint32
|
||||
inline uint32 uint32_of_nat(b_obj_arg a) {
|
||||
if (is_scalar(a)) {
|
||||
usize v = unbox(a);
|
||||
if (v < std::numeric_limits<uint32>::max())
|
||||
return v;
|
||||
else
|
||||
return 0;
|
||||
} else if (sizeof(void*) == 4) {
|
||||
// 32-bit
|
||||
mpz const & m = mpz_value(a);
|
||||
return m.is_unsigned_int() ? mpz_value(a).get_unsigned_int() : 0;
|
||||
} else {
|
||||
// 64-bit
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
uint32 uint32_of_big_nat(b_obj_arg a);
|
||||
inline uint32 uint32_of_nat(b_obj_arg a) { return is_scalar(a) ? static_cast<uint32>(unbox(a)) : uint32_of_big_nat(a); }
|
||||
inline obj_res uint32_to_nat(uint32 a) { return mk_nat_obj(static_cast<usize>(a)); }
|
||||
inline uint32 uint32_add(uint32 a1, uint32 a2) { return a1+a2; }
|
||||
inline uint32 uint32_sub(uint32 a1, uint32 a2) { return a1-a2; }
|
||||
|
|
@ -1295,14 +1284,8 @@ inline uint8 uint32_dec_le(uint32 a1, uint32 a2) { return a1 <= a2; }
|
|||
|
||||
// =======================================
|
||||
// uint64
|
||||
inline uint64 uint64_of_nat(b_obj_arg a) {
|
||||
if (is_scalar(a)) {
|
||||
return unbox(a);
|
||||
} else {
|
||||
// TODO(Leo):
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
uint64 uint64_of_big_nat(b_obj_arg a);
|
||||
inline uint64 uint64_of_nat(b_obj_arg a) { return is_scalar(a) ? static_cast<uint64>(unbox(a)) : uint64_of_big_nat(a); }
|
||||
inline obj_res uint64_to_nat(uint64 a) { return mk_nat_obj(a); }
|
||||
inline uint64 uint64_add(uint64 a1, uint64 a2) { return a1+a2; }
|
||||
inline uint64 uint64_sub(uint64 a1, uint64 a2) { return a1-a2; }
|
||||
|
|
@ -1324,14 +1307,8 @@ inline uint8 uint64_dec_le(uint64 a1, uint64 a2) { return a1 <= a2; }
|
|||
|
||||
// =======================================
|
||||
// usize
|
||||
inline usize usize_of_nat(b_obj_arg a) {
|
||||
if (is_scalar(a)) {
|
||||
return unbox(a);
|
||||
} else {
|
||||
// TODO(Leo):
|
||||
return 0;
|
||||
}
|
||||
}
|
||||
usize usize_of_big_nat(b_obj_arg a);
|
||||
inline usize usize_of_nat(b_obj_arg a) { return is_scalar(a) ? unbox(a) : usize_of_big_nat(a); }
|
||||
inline obj_res usize_to_nat(usize a) {
|
||||
return mk_nat_obj(a);
|
||||
}
|
||||
|
|
@ -1352,7 +1329,7 @@ inline usize usize_modn(usize a1, b_obj_arg a2) {
|
|||
inline uint8 usize_dec_eq(usize a1, usize a2) { return a1 == a2; }
|
||||
inline uint8 usize_dec_lt(usize a1, usize a2) { return a1 < a2; }
|
||||
inline uint8 usize_dec_le(usize a1, usize a2) { return a1 <= a2; }
|
||||
|
||||
usize usize_mix_hash(usize a1, usize a2);
|
||||
// =======================================
|
||||
// array functions for generated code
|
||||
static_assert(sizeof(unsigned long) == sizeof(size_t), "we assume that `unsigned long` and `size_t` have the same size");
|
||||
|
|
|
|||
7
tests/playground/hash.lean
Normal file
7
tests/playground/hash.lean
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
import init.data.hashable
|
||||
|
||||
def main (xs : List String): IO Unit :=
|
||||
do IO.println $ hash xs.head,
|
||||
IO.println $ hash xs.head.toNat,
|
||||
IO.println $ mixHash 1 2,
|
||||
pure ()
|
||||
Loading…
Add table
Reference in a new issue