feat(runtime): add primitive hash functions

This commit is contained in:
Leonardo de Moura 2019-04-03 04:01:36 -07:00
parent a46e27a3d7
commit ee050431e0
5 changed files with 81 additions and 45 deletions

View file

@ -12,17 +12,13 @@ class Hashable (α : Type u) :=
export Hashable (hash)
-- TODO: mark as builtin and opaque
def mixHash (u₁ u₂ : USize) : USize :=
default USize
-- TODO: mark as builtin
protected def String.hash (s : String) : USize :=
default USize
-- `mixHash u₁ u₂` takes two `USize` values and returns a `USize`.
-- The `extern` attribute binds it to the C++ runtime function `lean::usize_mix_hash`.
-- NOTE(review): the `default _` term presumably only supplies a value for the
-- declaration on the Lean side; confirm it is never used by compiled code.
@[extern cpp "lean::usize_mix_hash"]
constant mixHash (u₁ u₂ : USize) : USize := default _
-- `String.hash s` returns a `USize` for the string `s`.
-- The `extern` attribute binds it to the C++ runtime function `lean::string_hash`.
-- NOTE(review): the `default _` term presumably only supplies a value for the
-- declaration on the Lean side; confirm it is never used by compiled code.
@[extern cpp "lean::string_hash"]
protected constant String.hash (s : String) : USize := default _
instance : Hashable String := ⟨String.hash⟩
-- TODO: add builtin
-- `Nat.hash n` is the `USize` obtained from `n` via `USize.ofNat`.
protected def Nat.hash (n : Nat) : USize :=
USize.ofNat n

View file

@ -209,6 +209,8 @@ public:
friend void mul2k(mpz & a, mpz const & b, unsigned k) { mpz_mul_2exp(a.m_val, b.m_val, k); }
// a <- b / 2^k
friend void div2k(mpz & a, mpz const & b, unsigned k) { mpz_tdiv_q_2exp(a.m_val, b.m_val, k); }
// a <- b % 2^k
// Implemented with GMP's truncating remainder (mpz_tdiv_r_2exp): for b >= 0 the
// result is the k least significant bits of b; per the GMP documentation the
// "tdiv" family rounds the quotient toward zero, so the remainder takes the sign of b.
friend void mod2k(mpz & a, mpz const & b, unsigned k) { mpz_tdiv_r_2exp(a.m_val, b.m_val, k); }
/**
\brief Return the position of the most significant bit.

View file

@ -1383,6 +1383,54 @@ bool int_big_lt(object * a1, object * a2) {
}
}
// =======================================
// UInt
/* Convert the mpz-backed number stored in `a` to `uint8`:
   the value is reduced modulo 2^8 and the remainder is returned. */
uint8 uint8_of_big_nat(b_obj_arg a) {
    mpz low_bits;
    // low_bits <- a % 2^8
    mod2k(low_bits, mpz_value(a), 8);
    auto const word = low_bits.get_unsigned_int();
    return static_cast<uint8>(word);
}
/* Convert the mpz-backed number stored in `a` to `uint16`:
   the value is reduced modulo 2^16 and the remainder is returned. */
uint16 uint16_of_big_nat(b_obj_arg a) {
    mpz low_bits;
    // low_bits <- a % 2^16
    mod2k(low_bits, mpz_value(a), 16);
    auto const word = low_bits.get_unsigned_int();
    return static_cast<uint16>(word);
}
/* Convert the mpz-backed number stored in `a` to `uint32`:
   the value is reduced modulo 2^32 and the remainder is returned. */
uint32 uint32_of_big_nat(b_obj_arg a) {
    mpz low_bits;
    // low_bits <- a % 2^32
    mod2k(low_bits, mpz_value(a), 32);
    auto const word = low_bits.get_unsigned_int();
    return static_cast<uint32>(word);
}
/* Convert the mpz-backed number stored in `a` to `uint64` by reducing it modulo 2^64.

   When pointers are 8 bytes wide the remainder is read in one step through
   `get_unsigned_long_int`. Otherwise it is split into two 32-bit halves that are
   read separately with `get_unsigned_int` and recombined. */
uint64 uint64_of_big_nat(b_obj_arg a) {
    mpz low64;
    // low64 <- a % 2^64
    mod2k(low64, mpz_value(a), 64);

    if (sizeof(void*) != 8) {
        // 32 bit: assemble the result from the upper and lower 32-bit words.
        mpz lower;
        mod2k(lower, low64, 32);
        mpz upper;
        div2k(upper, low64, 32);
        uint64 const hi_word = static_cast<uint64>(upper.get_unsigned_int());
        uint64 const lo_word = static_cast<uint64>(lower.get_unsigned_int());
        return (hi_word << 32) + lo_word;
    }

    // 64 bit
    return static_cast<uint64>(low64.get_unsigned_long_int());
}
/* Convert the big number stored in `a` to `usize`, delegating to the 64-bit
   conversion when pointers are 8 bytes wide and to the 32-bit conversion otherwise. */
usize usize_of_big_nat(b_obj_arg a) {
    if (sizeof(void*) != 8) {
        return uint32_of_big_nat(a);
    }
    return uint64_of_big_nat(a);
}
/* Mix two `usize` values by forwarding them to `hash`, cast to `uint64` when
   pointers are 8 bytes wide and to `uint32` otherwise. */
usize usize_mix_hash(usize a1, usize a2) {
    if (sizeof(void*) != 8) {
        return hash(static_cast<uint32>(a1), static_cast<uint32>(a2));
    }
    return hash(static_cast<uint64>(a1), static_cast<uint64>(a2));
}
// =======================================
// Strings
@ -1591,8 +1639,8 @@ uint32 string_utf8_get(b_obj_arg s, b_obj_arg i0) {
/* The reference implementation is:
```
def utf8_next (s : @& String) (p : @& Pos) : Ppos :=
let c := utf8_get s p in
def next (s : @& String) (p : @& Pos) : Ppos :=
let c := get s p in
p + csize c
```
*/
@ -1697,6 +1745,12 @@ obj_res string_utf8_set(obj_arg s, b_obj_arg i0, uint32 c) {
return mk_string(new_s);
}
/* Hash the string object `s` by passing its character buffer to `hash_str`
   with initial value 11. */
usize string_hash(b_obj_arg s) {
    char const * chars = string_cstr(s);
    // NOTE(review): the `- 1` presumably drops a terminating null byte that
    // `string_size` counts; confirm against `string_size`.
    usize len = string_size(s) - 1;
    return hash_str(len, chars, 11);
}
// =======================================
// array functions for generated code

View file

@ -1209,10 +1209,12 @@ inline bool string_ne(b_obj_arg s1, b_obj_arg s2) { return !string_eq(s1, s2); }
bool string_lt(b_obj_arg s1, b_obj_arg s2);
inline uint8 string_dec_eq(b_obj_arg s1, b_obj_arg s2) { return string_eq(s1, s2); }
inline uint8 string_dec_lt(b_obj_arg s1, b_obj_arg s2) { return string_lt(s1, s2); }
usize string_hash(b_obj_arg);
// =======================================
// uint8
inline uint8 uint8_of_nat(b_obj_arg a) { return is_scalar(a) ? static_cast<uint8>(unbox(a)) : 0; }
// Out-of-line conversion used when the argument is not a scalar object.
uint8 uint8_of_big_nat(b_obj_arg a);
// Convert `a` to uint8: a scalar object is unboxed and narrowed with static_cast
// (keeping its low 8 bits); any other object is handed to uint8_of_big_nat.
inline uint8 uint8_of_nat(b_obj_arg a) { return is_scalar(a) ? static_cast<uint8>(unbox(a)) : uint8_of_big_nat(a); }
inline obj_res uint8_to_nat(uint8 a) { return mk_nat_obj(static_cast<usize>(a)); }
inline uint8 uint8_add(uint8 a1, uint8 a2) { return a1+a2; }
inline uint8 uint8_sub(uint8 a1, uint8 a2) { return a1-a2; }
@ -1233,7 +1235,8 @@ inline uint8 uint8_dec_le(uint8 a1, uint8 a2) { return a1 <= a2; }
// =======================================
// uint16
inline uint16 uint16_of_nat(b_obj_arg a) { return is_scalar(a) ? static_cast<uint16>(unbox(a)) : 0; }
// Out-of-line conversion used when the argument is not a scalar object.
uint16 uint16_of_big_nat(b_obj_arg a);
// Convert `a` to uint16: a scalar object is unboxed and narrowed with static_cast
// (keeping its low 16 bits); any other object is handed to uint16_of_big_nat.
inline uint16 uint16_of_nat(b_obj_arg a) { return is_scalar(a) ? static_cast<uint16>(unbox(a)) : uint16_of_big_nat(a); }
inline obj_res uint16_to_nat(uint16 a) { return mk_nat_obj(static_cast<usize>(a)); }
inline uint16 uint16_add(uint16 a1, uint16 a2) { return a1+a2; }
inline uint16 uint16_sub(uint16 a1, uint16 a2) { return a1-a2; }
@ -1254,22 +1257,8 @@ inline uint8 uint16_dec_le(uint16 a1, uint16 a2) { return a1 <= a2; }
// =======================================
// uint32
inline uint32 uint32_of_nat(b_obj_arg a) {
if (is_scalar(a)) {
usize v = unbox(a);
if (v < std::numeric_limits<uint32>::max())
return v;
else
return 0;
} else if (sizeof(void*) == 4) {
// 32-bit
mpz const & m = mpz_value(a);
return m.is_unsigned_int() ? mpz_value(a).get_unsigned_int() : 0;
} else {
// 64-bit
return 0;
}
}
// Out-of-line conversion used when the argument is not a scalar object.
uint32 uint32_of_big_nat(b_obj_arg a);
// Convert `a` to uint32: a scalar object is unboxed and narrowed with static_cast
// (keeping its low 32 bits); any other object is handed to uint32_of_big_nat.
inline uint32 uint32_of_nat(b_obj_arg a) { return is_scalar(a) ? static_cast<uint32>(unbox(a)) : uint32_of_big_nat(a); }
inline obj_res uint32_to_nat(uint32 a) { return mk_nat_obj(static_cast<usize>(a)); }
inline uint32 uint32_add(uint32 a1, uint32 a2) { return a1+a2; }
inline uint32 uint32_sub(uint32 a1, uint32 a2) { return a1-a2; }
@ -1295,14 +1284,8 @@ inline uint8 uint32_dec_le(uint32 a1, uint32 a2) { return a1 <= a2; }
// =======================================
// uint64
inline uint64 uint64_of_nat(b_obj_arg a) {
if (is_scalar(a)) {
return unbox(a);
} else {
// TODO(Leo):
return 0;
}
}
// Out-of-line conversion used when the argument is not a scalar object.
uint64 uint64_of_big_nat(b_obj_arg a);
// Convert `a` to uint64: a scalar object is unboxed and cast to uint64;
// any other object is handed to uint64_of_big_nat.
inline uint64 uint64_of_nat(b_obj_arg a) { return is_scalar(a) ? static_cast<uint64>(unbox(a)) : uint64_of_big_nat(a); }
inline obj_res uint64_to_nat(uint64 a) { return mk_nat_obj(a); }
inline uint64 uint64_add(uint64 a1, uint64 a2) { return a1+a2; }
inline uint64 uint64_sub(uint64 a1, uint64 a2) { return a1-a2; }
@ -1324,14 +1307,8 @@ inline uint8 uint64_dec_le(uint64 a1, uint64 a2) { return a1 <= a2; }
// =======================================
// usize
inline usize usize_of_nat(b_obj_arg a) {
if (is_scalar(a)) {
return unbox(a);
} else {
// TODO(Leo):
return 0;
}
}
// Out-of-line conversion used when the argument is not a scalar object.
usize usize_of_big_nat(b_obj_arg a);
// Convert `a` to usize: a scalar object is simply unboxed;
// any other object is handed to usize_of_big_nat.
inline usize usize_of_nat(b_obj_arg a) { return is_scalar(a) ? unbox(a) : usize_of_big_nat(a); }
inline obj_res usize_to_nat(usize a) {
return mk_nat_obj(a);
}
@ -1352,7 +1329,7 @@ inline usize usize_modn(usize a1, b_obj_arg a2) {
inline uint8 usize_dec_eq(usize a1, usize a2) { return a1 == a2; }
inline uint8 usize_dec_lt(usize a1, usize a2) { return a1 < a2; }
inline uint8 usize_dec_le(usize a1, usize a2) { return a1 <= a2; }
usize usize_mix_hash(usize a1, usize a2);
// =======================================
// array functions for generated code
static_assert(sizeof(unsigned long) == sizeof(size_t), "we assume that `unsigned long` and `size_t` have the same size");

View file

@ -0,0 +1,7 @@
import init.data.hashable
-- Test driver for the hash primitives: prints three hash values, one per line.
-- `xs` is the argument list passed to `main`; only its head is used.
def main (xs : List String): IO Unit :=
-- hash of the first argument as a `String`
do IO.println $ hash xs.head,
-- hash of the first argument after `String.toNat`
IO.println $ hash xs.head.toNat,
-- `mixHash` applied to the literals 1 and 2
IO.println $ mixHash 1 2,
pure ()