From ee050431e0b357101f62869b2620cdecdfffa2ff Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 3 Apr 2019 04:01:36 -0700 Subject: [PATCH] feat(runtime): add primitive hash functions --- library/init/data/hashable.lean | 12 +++---- src/runtime/mpz.h | 2 ++ src/runtime/object.cpp | 58 +++++++++++++++++++++++++++++++-- src/runtime/object.h | 47 +++++++------------------- tests/playground/hash.lean | 7 ++++ 5 files changed, 81 insertions(+), 45 deletions(-) create mode 100644 tests/playground/hash.lean diff --git a/library/init/data/hashable.lean b/library/init/data/hashable.lean index bbc0d9f3a4..85f69aa511 100644 --- a/library/init/data/hashable.lean +++ b/library/init/data/hashable.lean @@ -12,17 +12,13 @@ class Hashable (α : Type u) := export Hashable (hash) --- TODO: mark as builtin and opaque -def mixHash (u₁ u₂ : USize) : USize := -default USize - --- TODO: mark as builtin -protected def String.hash (s : String) : USize := -default USize +@[extern cpp "lean::usize_mix_hash"] +constant mixHash (u₁ u₂ : USize) : USize := default _ +@[extern cpp "lean::string_hash"] +protected constant String.hash (s : String) : USize := default _ instance : Hashable String := ⟨String.hash⟩ --- TODO: add builtin protected def Nat.hash (n : Nat) : USize := USize.ofNat n diff --git a/src/runtime/mpz.h b/src/runtime/mpz.h index e3b1d16f4b..24e4cf5502 100644 --- a/src/runtime/mpz.h +++ b/src/runtime/mpz.h @@ -209,6 +209,8 @@ public: friend void mul2k(mpz & a, mpz const & b, unsigned k) { mpz_mul_2exp(a.m_val, b.m_val, k); } // a <- b / 2^k friend void div2k(mpz & a, mpz const & b, unsigned k) { mpz_tdiv_q_2exp(a.m_val, b.m_val, k); } + // a <- b % 2^k + friend void mod2k(mpz & a, mpz const & b, unsigned k) { mpz_tdiv_r_2exp(a.m_val, b.m_val, k); } /** \brief Return the position of the most significant bit. diff --git a/src/runtime/object.cpp b/src/runtime/object.cpp index 7d4343a61b..18b27cb10c 100644 --- a/src/runtime/object.cpp +++ b/src/runtime/object.cpp @@ -1383,6 +1383,54 @@ bool int_big_lt(object * a1, object * a2) { } } +// ======================================= +// UInt + +uint8 uint8_of_big_nat(b_obj_arg a) { + mpz r; + mod2k(r, mpz_value(a), 8); + return static_cast(r.get_unsigned_int()); +} +uint16 uint16_of_big_nat(b_obj_arg a) { + mpz r; + mod2k(r, mpz_value(a), 16); + return static_cast(r.get_unsigned_int()); +} +uint32 uint32_of_big_nat(b_obj_arg a) { + mpz r; + mod2k(r, mpz_value(a), 32); + return static_cast(r.get_unsigned_int()); +} +uint64 uint64_of_big_nat(b_obj_arg a) { + mpz r; + mod2k(r, mpz_value(a), 64); + if (sizeof(void*) == 8) { + // 64 bit + return static_cast(r.get_unsigned_long_int()); + } else { + // 32 bit + mpz l; + mod2k(l, r, 32); + mpz h; + div2k(h, r, 32); + return (static_cast(h.get_unsigned_int()) << 32) + static_cast(l.get_unsigned_int()); + } +} + +usize usize_of_big_nat(b_obj_arg a) { + if (sizeof(void*) == 8) + return uint64_of_big_nat(a); + else + return uint32_of_big_nat(a); +} + +usize usize_mix_hash(usize a1, usize a2) { + if (sizeof(void*) == 8) + return hash(static_cast(a1), static_cast(a2)); + else + return hash(static_cast(a1), static_cast(a2)); +} + // ======================================= // Strings @@ -1591,8 +1639,8 @@ uint32 string_utf8_get(b_obj_arg s, b_obj_arg i0) { /* The reference implementation is: ``` - def utf8_next (s : @& String) (p : @& Pos) : Ppos := - let c := utf8_get s p in + def next (s : @& String) (p : @& Pos) : Ppos := + let c := get s p in p + csize c ``` */ @@ -1697,6 +1745,12 @@ obj_res string_utf8_set(obj_arg s, b_obj_arg i0, uint32 c) { return mk_string(new_s); } +usize string_hash(b_obj_arg s) { + usize sz = string_size(s) - 1; + char const * str = string_cstr(s); + return hash_str(sz, str, 11); +} + // ======================================= // array functions for generated code diff --git a/src/runtime/object.h b/src/runtime/object.h index cb89a54b95..1a58a23574 100644 --- a/src/runtime/object.h +++ b/src/runtime/object.h @@ -1209,10 +1209,12 @@ inline bool string_ne(b_obj_arg s1, b_obj_arg s2) { return !string_eq(s1, s2); } bool string_lt(b_obj_arg s1, b_obj_arg s2); inline uint8 string_dec_eq(b_obj_arg s1, b_obj_arg s2) { return string_eq(s1, s2); } inline uint8 string_dec_lt(b_obj_arg s1, b_obj_arg s2) { return string_lt(s1, s2); } +usize string_hash(b_obj_arg); // ======================================= // uint8 -inline uint8 uint8_of_nat(b_obj_arg a) { return is_scalar(a) ? static_cast(unbox(a)) : 0; } +uint8 uint8_of_big_nat(b_obj_arg a); +inline uint8 uint8_of_nat(b_obj_arg a) { return is_scalar(a) ? static_cast(unbox(a)) : uint8_of_big_nat(a); } inline obj_res uint8_to_nat(uint8 a) { return mk_nat_obj(static_cast(a)); } inline uint8 uint8_add(uint8 a1, uint8 a2) { return a1+a2; } inline uint8 uint8_sub(uint8 a1, uint8 a2) { return a1-a2; } @@ -1233,7 +1235,8 @@ inline uint8 uint8_dec_le(uint8 a1, uint8 a2) { return a1 <= a2; } // ======================================= // uint16 -inline uint16 uint16_of_nat(b_obj_arg a) { return is_scalar(a) ? static_cast(unbox(a)) : 0; } +uint16 uint16_of_big_nat(b_obj_arg a); +inline uint16 uint16_of_nat(b_obj_arg a) { return is_scalar(a) ? static_cast(unbox(a)) : uint16_of_big_nat(a); } inline obj_res uint16_to_nat(uint16 a) { return mk_nat_obj(static_cast(a)); } inline uint16 uint16_add(uint16 a1, uint16 a2) { return a1+a2; } inline uint16 uint16_sub(uint16 a1, uint16 a2) { return a1-a2; } @@ -1254,22 +1257,8 @@ inline uint8 uint16_dec_le(uint16 a1, uint16 a2) { return a1 <= a2; } // ======================================= // uint32 -inline uint32 uint32_of_nat(b_obj_arg a) { - if (is_scalar(a)) { - usize v = unbox(a); - if (v < std::numeric_limits::max()) - return v; - else - return 0; - } else if (sizeof(void*) == 4) { - // 32-bit - mpz const & m = mpz_value(a); - return m.is_unsigned_int() ? mpz_value(a).get_unsigned_int() : 0; - } else { - // 64-bit - return 0; - } -} +uint32 uint32_of_big_nat(b_obj_arg a); +inline uint32 uint32_of_nat(b_obj_arg a) { return is_scalar(a) ? static_cast(unbox(a)) : uint32_of_big_nat(a); } inline obj_res uint32_to_nat(uint32 a) { return mk_nat_obj(static_cast(a)); } inline uint32 uint32_add(uint32 a1, uint32 a2) { return a1+a2; } inline uint32 uint32_sub(uint32 a1, uint32 a2) { return a1-a2; } @@ -1295,14 +1284,8 @@ inline uint8 uint32_dec_le(uint32 a1, uint32 a2) { return a1 <= a2; } // ======================================= // uint64 -inline uint64 uint64_of_nat(b_obj_arg a) { - if (is_scalar(a)) { - return unbox(a); - } else { - // TODO(Leo): - return 0; - } -} +uint64 uint64_of_big_nat(b_obj_arg a); +inline uint64 uint64_of_nat(b_obj_arg a) { return is_scalar(a) ? static_cast(unbox(a)) : uint64_of_big_nat(a); } inline obj_res uint64_to_nat(uint64 a) { return mk_nat_obj(a); } inline uint64 uint64_add(uint64 a1, uint64 a2) { return a1+a2; } inline uint64 uint64_sub(uint64 a1, uint64 a2) { return a1-a2; } @@ -1324,14 +1307,8 @@ inline uint8 uint64_dec_le(uint64 a1, uint64 a2) { return a1 <= a2; } // ======================================= // usize -inline usize usize_of_nat(b_obj_arg a) { - if (is_scalar(a)) { - return unbox(a); - } else { - // TODO(Leo): - return 0; - } -} +usize usize_of_big_nat(b_obj_arg a); +inline usize usize_of_nat(b_obj_arg a) { return is_scalar(a) ? unbox(a) : usize_of_big_nat(a); } inline obj_res usize_to_nat(usize a) { return mk_nat_obj(a); } @@ -1352,7 +1329,7 @@ inline usize usize_modn(usize a1, b_obj_arg a2) { inline uint8 usize_dec_eq(usize a1, usize a2) { return a1 == a2; } inline uint8 usize_dec_lt(usize a1, usize a2) { return a1 < a2; } inline uint8 usize_dec_le(usize a1, usize a2) { return a1 <= a2; } - +usize usize_mix_hash(usize a1, usize a2); // ======================================= // array functions for generated code static_assert(sizeof(unsigned long) == sizeof(size_t), "we assume that `unsigned long` and `size_t` have the same size"); diff --git a/tests/playground/hash.lean b/tests/playground/hash.lean new file mode 100644 index 0000000000..c93fbd58d5 --- /dev/null +++ b/tests/playground/hash.lean @@ -0,0 +1,7 @@ +import init.data.hashable + +def main (xs : List String): IO Unit := +do IO.println $ hash xs.head, + IO.println $ hash xs.head.toNat, + IO.println $ mixHash 1 2, + pure ()