diff --git a/src/runtime/object.cpp b/src/runtime/object.cpp index d6d4602eab..93210f6806 100644 --- a/src/runtime/object.cpp +++ b/src/runtime/object.cpp @@ -187,7 +187,7 @@ void del(object * o) { #if 0 static object * sarray_ensure_capacity(object * o, size_t extra) { - lean_assert(!is_heap_obj(o) || !is_shared(o)); + lean_assert(!is_exclusive(o)); size_t sz = sarray_size(o); size_t cap = sarray_capacity(o); if (sz + extra > cap) { @@ -381,109 +381,6 @@ obj_res parray_copy(b_obj_arg o) { return r; } -// ======================================= -// Strings - -static inline char * w_string_data(object * o) { lean_assert(is_string(o)); return reinterpret_cast(o) + sizeof(string_object); } - -static object * string_ensure_capacity(object * o, size_t extra) { - lean_assert(!is_heap_obj(o) || !is_shared(o)); - size_t sz = string_size(o); - size_t cap = string_capacity(o); - if (sz + extra > cap) { - object * new_o = alloc_string(sz, cap + sz + extra, string_len(o)); - lean_assert(string_capacity(new_o) >= sz + extra); - memcpy(w_string_data(new_o), string_data(o), sz); - free_heap_obj(o); - return new_o; - } else { - return o; - } -} - -object * mk_string(char const * s) { - size_t sz = strlen(s); - size_t len = utf8_strlen(s); - size_t rsz = sz + 1; - object * r = alloc_string(rsz, rsz, len); - memcpy(w_string_data(r), s, sz+1); - return r; -} - -object * mk_string(std::string const & s) { - size_t sz = s.size(); - size_t len = utf8_strlen(s); - size_t rsz = sz + 1; - object * r = alloc_string(rsz, rsz, len); - memcpy(w_string_data(r), s.data(), sz); - w_string_data(r)[sz] = 0; - return r; -} - -static size_t mk_capacity(size_t sz) { - return sz*2; -} - -object * string_push(object * s, unsigned c) { - size_t sz = string_size(s); - size_t len = string_len(s); - object * r; - if (!is_heap_obj(s) || is_shared(s)) { - r = alloc_string(sz, mk_capacity(sz+5), len); - memcpy(w_string_data(r), string_data(s), sz - 1); - dec_ref(s); - } else { - r = string_ensure_capacity(s, 5); - } - unsigned consumed = push_unicode_scalar(w_string_data(r) + sz - 1, c); - to_string(r)->m_size = sz + consumed; - to_string(r)->m_length++; - w_string_data(r)[sz + consumed - 1] = 0; - return r; -} - -object * string_append(object * s1, object * s2) { - lean_assert(s1 != s2); - size_t sz1 = string_size(s1); - size_t sz2 = string_size(s2); - size_t len1 = string_len(s1); - size_t len2 = string_len(s2); - size_t new_len = len1 + len2; - unsigned new_sz = sz1 + sz2 - 1; - object * r; - if (!is_heap_obj(s1) || is_shared(s1)) { - r = alloc_string(new_sz, mk_capacity(new_sz), new_len); - memcpy(w_string_data(r), string_data(s1), sz1 - 1); - dec_ref(s1); - } else { - r = string_ensure_capacity(s1, sz2-1); - } - memcpy(w_string_data(r) + sz1 - 1, string_data(s2), sz2 - 1); - to_string(r)->m_size = new_sz; - to_string(r)->m_length = new_len; - w_string_data(r)[new_sz - 1] = 0; - return r; -} - -bool string_eq(object * s1, object * s2) { - if (string_size(s1) != string_size(s2)) - return false; - return std::memcmp(string_data(s1), string_data(s2), string_size(s1)) == 0; -} - -bool string_eq(object * s1, char const * s2) { - if (string_size(s1) != strlen(s2) + 1) - return false; - return std::memcmp(string_data(s1), s2, string_size(s1)) == 0; -} - -bool string_lt(object * s1, object * s2) { - size_t sz1 = string_size(s1) - 1; // ignore null char in the end - size_t sz2 = string_size(s2) - 1; // ignore null char in the end - int r = std::memcmp(string_data(s1), string_data(s2), std::min(sz1, sz2)); - return r < 0 || (r == 0 && sz1 < sz2); -} - // ======================================= // Closures @@ -966,6 +863,16 @@ b_obj_res io_wait_any_core(b_obj_arg task_list) { // ======================================= // Natural numbers +size_t size_t_of_nat(b_obj_arg n) { + if (is_scalar(n)) { + return unbox(n); + } else if (mpz_value(n).is_unsigned_long_int()) { + return static_cast(mpz_value(n).get_unsigned_long_int()); + } else { + return std::numeric_limits::max(); + } +} + object * nat_big_add(object * a1, object * a2) { lean_assert(!is_scalar(a1) || !is_scalar(a2)); if (is_scalar(a1)) @@ -1177,12 +1084,404 @@ bool int_big_lt(object * a1, object * a2) { } } +// ======================================= +// Strings + +static inline char * w_string_cstr(object * o) { lean_assert(is_string(o)); return reinterpret_cast(o) + sizeof(string_object); } + +static object * string_ensure_capacity(object * o, size_t extra) { + lean_assert(is_exclusive(o)); + size_t sz = string_size(o); + size_t cap = string_capacity(o); + if (sz + extra > cap) { + object * new_o = alloc_string(sz, cap + sz + extra, string_len(o)); + lean_assert(string_capacity(new_o) >= sz + extra); + memcpy(w_string_cstr(new_o), string_cstr(o), sz); + free_heap_obj(o); + return new_o; + } else { + return o; + } +} + +object * mk_string(char const * s) { + size_t sz = strlen(s); + size_t len = utf8_strlen(s); + size_t rsz = sz + 1; + object * r = alloc_string(rsz, rsz, len); + memcpy(w_string_cstr(r), s, sz+1); + return r; +} + +object * mk_string(std::string const & s) { + size_t sz = s.size(); + size_t len = utf8_strlen(s); + size_t rsz = sz + 1; + object * r = alloc_string(rsz, rsz, len); + memcpy(w_string_cstr(r), s.data(), sz); + w_string_cstr(r)[sz] = 0; + return r; +} + +static size_t mk_capacity(size_t sz) { + return sz*2; +} + +object * string_push(object * s, unsigned c) { + size_t sz = string_size(s); + size_t len = string_len(s); + object * r; + if (!is_exclusive(s)) { + r = alloc_string(sz, mk_capacity(sz+5), len); + memcpy(w_string_cstr(r), string_cstr(s), sz - 1); + dec_ref(s); + } else { + r = string_ensure_capacity(s, 5); + } + unsigned consumed = push_unicode_scalar(w_string_cstr(r) + sz - 1, c); + to_string(r)->m_size = sz + consumed; + to_string(r)->m_length++; + w_string_cstr(r)[sz + consumed - 1] = 0; + return r; +} + +object * string_append(object * s1, object * s2) { + lean_assert(s1 != s2); + size_t sz1 = string_size(s1); + size_t sz2 = string_size(s2); + size_t len1 = string_len(s1); + size_t len2 = string_len(s2); + size_t new_len = len1 + len2; + unsigned new_sz = sz1 + sz2 - 1; + object * r; + if (!is_exclusive(s1)) { + r = alloc_string(new_sz, mk_capacity(new_sz), new_len); + memcpy(w_string_cstr(r), string_cstr(s1), sz1 - 1); + dec_ref(s1); + } else { + r = string_ensure_capacity(s1, sz2-1); + } + memcpy(w_string_cstr(r) + sz1 - 1, string_cstr(s2), sz2 - 1); + to_string(r)->m_size = new_sz; + to_string(r)->m_length = new_len; + w_string_cstr(r)[new_sz - 1] = 0; + return r; +} + +bool string_eq(object * s1, object * s2) { + if (string_size(s1) != string_size(s2)) + return false; + return std::memcmp(string_cstr(s1), string_cstr(s2), string_size(s1)) == 0; +} + +bool string_eq(object * s1, char const * s2) { + if (string_size(s1) != strlen(s2) + 1) + return false; + return std::memcmp(string_cstr(s1), s2, string_size(s1)) == 0; +} + +bool string_lt(object * s1, object * s2) { + size_t sz1 = string_size(s1) - 1; // ignore null char in the end + size_t sz2 = string_size(s2) - 1; // ignore null char in the end + int r = std::memcmp(string_cstr(s1), string_cstr(s2), std::min(sz1, sz2)); + return r < 0 || (r == 0 && sz1 < sz2); +} + +static std::string list_as_string(b_obj_arg lst) { + std::string s; + b_obj_arg o = lst; + while (!is_scalar(o)) { + push_unicode_scalar(s, unbox(cnstr_get(o, 0))); + o = cnstr_get(o, 1); + } + return s; +} + +static obj_res string_to_list_core(std::string const & s, bool reverse = false) { + buffer tmp; + utf8_decode(s, tmp); + if (reverse) + std::reverse(tmp.begin(), tmp.end()); + obj_res r = box(0); + unsigned i = tmp.size(); + while (i > 0) { + --i; + obj_res new_r = alloc_cnstr(1, 2, 0); + cnstr_set(new_r, 0, box(tmp[i])); + cnstr_set(new_r, 1, r); + r = new_r; + } + return r; +} + +obj_res string_mk(obj_arg cs) { + std::string s = list_as_string(cs); + dec(cs); + return mk_string(s); +} + +obj_res string_data(obj_arg s) { + std::string tmp(string_cstr(s), string_size(s)); + dec_ref(s); + return string_to_list_core(tmp); +} + +/* `pos` is in bytes, and `remaining` is in characters */ +static obj_res mk_iterator(obj_arg s, size_t pos, size_t remaining) { + obj_res r = alloc_cnstr(0, 1, sizeof(size_t)*2); + cnstr_set(r, 0, s); + cnstr_set_scalar(r, sizeof(object*), pos); + cnstr_set_scalar(r, sizeof(object*)+sizeof(size_t), remaining); + return r; +} + +static b_obj_res it_string(b_obj_arg it) { return cnstr_get(it, 0); } +static size_t it_pos(b_obj_arg it) { return cnstr_get_scalar(it, sizeof(object*)); } +static size_t it_remaining(b_obj_arg it) { return cnstr_get_scalar(it, sizeof(object*)+sizeof(size_t)); } +static void it_set_string(u_obj_arg it, obj_arg s) { cnstr_set(it, 0, s); } +static void it_set_pos(u_obj_arg it, size_t pos) { cnstr_set_scalar(it, sizeof(object*), pos); } +static void it_set_remaining(u_obj_arg it, size_t r) { cnstr_set_scalar(it, sizeof(object*)+sizeof(size_t), r); } +/* instance : inhabited char := ⟨'A'⟩ */ +static uint32 mk_default_char() { return 65; } +static bool is_unshared_it_string(b_obj_arg it) { return is_exclusive(it) && !is_shared(cnstr_get(it, 0)); } + +static unsigned get_utf8_char_size_at(std::string const & s, unsigned i) { + if (auto sz = is_utf8_first_byte(s[i])) { + return *sz; + } else { + return 1; + } +} + +obj_res string_mk_iterator(obj_arg s) { + return mk_iterator(s, 0, string_len(s)); +} + +uint32 string_iterator_curr(b_obj_arg it) { + object * s = it_string(it); + size_t i = it_pos(it); + if (i < string_size(s)) { + return next_utf8(string_cstr(s), i); + } else { + return mk_default_char(); + } +} + +/* def set_curr : iterator → char → iterator */ +obj_res string_iterator_set_curr(obj_arg it, uint32 c) { + object * s = it_string(it); + size_t i = it_pos(it); + if (i >= string_size(s)) { + /* at end */ + return it; + } + if (is_unshared_it_string(it)) { + if (static_cast(string_cstr(s)[i]) < 128 && c < 128) { + /* easy case, old and new characters are encoded using 1 byte */ + w_string_cstr(s)[i] = c; + return it; + } + } + /* TODO(Leo): improve performance of the special cases. + Example: `it` is not shared, but string is; new and old characters have the same size; etc. */ + std::string tmp; + push_unicode_scalar(tmp, c); + std::string new_s(string_cstr(s), string_size(s)); + new_s.replace(i, get_utf8_char_size_at(new_s, i), tmp); + size_t rem = it_remaining(it); + dec_ref(it); + return mk_iterator(mk_string(new_s), i, rem); +} + +/* def next : iterator → iterator */ +obj_res string_iterator_next(obj_arg it) { + object * s = it_string(it); + size_t i = it_pos(it); + size_t r = it_remaining(it); + if (i < string_size(s)) { + next_utf8(string_cstr(s), i); + if (is_exclusive(it)) { + it_set_pos(it, i); + it_set_remaining(it, r-1); + return it; + } else { + inc_ref(s); + obj_res new_it = mk_iterator(s, i, r-1); + dec_ref(it); + return new_it; + } + } else { + return it; + } +} + +/* def prev : iterator → iterator */ +obj_res string_iterator_prev(obj_arg it) { + object * s = it_string(it); + size_t i = it_pos(it); + size_t r = it_remaining(it); + if (i > 0) { + size_t new_i = i; + /* we have to walk at most 4 steps backwards */ + for (unsigned j = 0; j < 4; j++) { + --new_i; + if (is_utf8_first_byte(string_cstr(s)[new_i])) { + if (is_exclusive(it)) { + it_set_pos(it, new_i); + it_set_remaining(it, r+1); + return it; + } else { + inc_ref(s); + obj_res new_it = mk_iterator(s, new_i, r+1); + dec_ref(it); + return new_it; + } + } + } + /* incorrectly encoded utf-8 string */ + return it; + } else { + return it; + } +} + +/* def has_next : iterator → bool */ +uint8 string_iterator_has_next(b_obj_arg it) { + return it_pos(it) < string_size(it_string(it)); +} + +/* def has_prev : iterator → bool */ +uint8 string_iterator_has_prev(b_obj_arg it) { + return it_pos(it) > 0; +} + +obj_res string_iterator_remaining(b_obj_arg it) { + return nat_of_size_t(it_remaining(it)); +} + +obj_res string_iterator_offset(b_obj_arg it) { + size_t len = string_len(it_string(it)); + size_t rem = it_remaining(it); + return nat_of_size_t(len - rem); +} + +/* def to_string : iterator → string */ +obj_res string_iterator_to_string(b_obj_arg it) { + object * s = it_string(it); + inc_ref(s); + return s; +} + +/* def to_end : iterator → iterator */ +obj_res string_iterator_to_end(obj_arg it) { + object * s = it_string(it); + if (is_exclusive(it)) { + it_set_pos(it, string_size(s)); + it_set_remaining(it, 0); + return it; + } else { + inc_ref(s); + obj_res new_it = mk_iterator(s, string_size(s), 0); + dec_ref(it); + return new_it; + } +} + +/* def remaining_to_string : iterator → string */ +obj_res string_iterator_remaining_to_string(b_obj_arg it) { + object * s = it_string(it); + size_t i = it_pos(it); + std::string r; + for (; i < string_size(s); i++) { + r += string_cstr(s)[i]; + } + return mk_string(r); +} + +/* def prev_to_string : iterator → string */ +obj_res string_iterator_prev_to_string(b_obj_arg it) { + object * s = it_string(it); + size_t pos = it_pos(it); + std::string r; + for (size_t i = 0; i < pos; i++) { + r += string_cstr(s)[i]; + } + return mk_string(r); +} + +/* def insert : iterator → string → iterator */ +obj_res string_iterator_insert(obj_arg it, b_obj_arg s) { + object * s_0 = it_string(it); + object * s_1 = s; + size_t i = it_pos(it); + size_t r = it_remaining(it); + if (i >= string_size(s_0)) { + /* insert s in the end */ + if (is_unshared_it_string(it)) { + object * new_s = string_append(s_0, s_1); + it_set_string(it, new_s); + it_set_remaining(it, r + string_len(s_1)); + return it; + } else { + inc_ref(s_0); + object * new_s = string_append(s_0, s_1); + dec_ref(it); + return mk_iterator(new_s, i, r + string_len(s_1)); + } + } else { + /* insert in the middle */ + /* TODO(Leo): optimize is_unshared_it_string(it) case */ + std::string new_s(string_cstr(s_0), string_size(s_0)); + new_s.insert(i, std::string(string_cstr(s_1), string_size(s_1))); + dec_ref(it); + return mk_iterator(mk_string(new_s), i, r + string_len(s_1)); + } +} + +/* def remove : iterator → nat → iterator */ +obj_res string_iterator_remove(obj_arg it, b_obj_arg n0) { + object * s = it_string(it); + size_t sz = string_size(s); + size_t i = it_pos(it); + size_t j = i; + size_t n = size_t_of_nat(n0); + size_t new_len = string_len(s); + size_t r = it_remaining(it); + for (size_t k = 0; k < n && j < sz; k++) { + next_utf8(string_cstr(s), j); + new_len--; + r--; + } + size_t count = j - i; + /* TODO(Leo): optimize case wher is_unshared_it_string(it) */ + std::string new_s(string_cstr(s), sz); + new_s.erase(i, count); + dec_ref(it); + return mk_iterator(mk_string(new_s), i, r); +} + +/* def extract : iterator → iterator → option string */ +obj_res string_iterator_extract(b_obj_arg it1, b_obj_arg it2) { + object * s1 = it_string(it1); + object * s2 = it_string(it2); + if (&s1 != &s2 && string_ne(s1, s2)) + return mk_option_none(); + size_t pos1 = it_pos(it1); + size_t pos2 = it_pos(it2); + if (pos2 < pos1) + return mk_option_none(); + size_t new_sz = pos2 - pos1; + object * r = alloc_cnstr(new_sz, new_sz, it_remaining(it1) - it_remaining(it2)); + memcpy(w_string_cstr(r), string_cstr(s1) + pos1, new_sz); + return mk_option_some(r); +} + // ======================================= // Debugging helper functions void dbg_print_str(object * o) { lean_assert(is_string(o)); - std::cout << string_data(o) << "\n"; + std::cout << string_cstr(o) << "\n"; } void dbg_print_num(object * o) { diff --git a/src/runtime/object.h b/src/runtime/object.h index 40e271f217..0f1677f471 100644 --- a/src/runtime/object.h +++ b/src/runtime/object.h @@ -17,6 +17,11 @@ Author: Leonardo de Moura #include "runtime/thread.h" namespace lean { +typedef unsigned char uint8; +typedef unsigned short uint16; +typedef unsigned uint32; +typedef unsigned long long uint64; + /* The primitives implemented in the runtime do not modify the RC of its arguments. Callers are responsible for increasing/decreasing the RCs using the `inc`/`dec` operations. @@ -241,6 +246,7 @@ inline rc_type get_rc(object * o) { } inline bool is_shared(object * o) { return get_rc(o) > 1; } +inline bool is_exclusive(object * o) { return is_heap_obj(o) && !is_shared(o); } inline void inc_ref(object * o) { if (is_mt_heap_obj(o)) { @@ -694,6 +700,10 @@ inline obj_res mk_nat_obj(uint64 n) { } } +inline obj_res nat_of_size_t(size_t n) { + return (sizeof(size_t) == sizeof(unsigned)) ? mk_nat_obj(static_cast(n)) : mk_nat_obj(static_cast(n)); +} + inline uint64 nat2uint64(b_obj_arg a) { lean_assert(is_scalar(a)); return unbox(a); @@ -971,6 +981,12 @@ inline bool int_lt(b_obj_arg a1, b_obj_arg a2) { } } +// ======================================= +// Option + +inline obj_res mk_option_none() { return box(0); } +inline obj_res mk_option_some(obj_arg v) { obj_res r = alloc_cnstr(1, 1, 0); cnstr_set(r, 0, v); return v; } + // ======================================= // String @@ -979,14 +995,31 @@ inline obj_res alloc_string(size_t size, size_t capacity, size_t len) { } obj_res mk_string(char const * s); obj_res mk_string(std::string const & s); -inline char const * string_data(b_obj_arg o) { lean_assert(is_string(o)); return reinterpret_cast(o) + sizeof(string_object); } +inline char const * string_cstr(b_obj_arg o) { lean_assert(is_string(o)); return reinterpret_cast(o) + sizeof(string_object); } inline size_t string_size(b_obj_arg o) { return to_string(o)->m_size; } inline size_t string_len(b_obj_arg o) { return to_string(o)->m_length; } -obj_res string_push(obj_arg s, unsigned c); +obj_res string_push(obj_arg s, uint32 c); obj_res string_append(obj_arg s1, b_obj_arg s2); -inline obj_res string_length(b_obj_arg s) { - return (sizeof(size_t) == sizeof(unsigned)) ? mk_nat_obj(static_cast(string_len(s))) : mk_nat_obj(static_cast(string_len(s))); -} +inline obj_res string_length(b_obj_arg s) { return nat_of_size_t(string_len(s)); } +obj_res string_mk(obj_arg cs); +obj_res string_data(obj_arg s); +obj_res string_mk_iterator(obj_arg s); +uint32 string_iterator_curr(b_obj_arg it); +obj_res string_iterator_set_curr(obj_arg it, uint32 c); +obj_res string_iterator_next(obj_arg it); +obj_res string_iterator_prev(obj_arg it); +uint8 string_iterator_has_next(b_obj_arg it); +uint8 string_iterator_has_prev(b_obj_arg it); +obj_res string_iterator_insert(obj_arg it, b_obj_arg s); +obj_res string_iterator_remove(obj_arg it, b_obj_arg n); +obj_res string_iterator_remaining(b_obj_arg it); +obj_res string_iterator_offset(b_obj_arg it); +obj_res string_iterator_remaining_to_string(b_obj_arg it); +obj_res string_iterator_prev_to_string(b_obj_arg it); +obj_res string_iterator_to_string(b_obj_arg it); +obj_res string_iterator_to_end(obj_arg it); +obj_res string_iterator_extract(b_obj_arg it1, b_obj_arg it2); + bool string_eq(b_obj_arg s1, b_obj_arg s2); inline bool string_ne(b_obj_arg s1, b_obj_arg s2) { return !string_eq(s1, s2); } bool string_eq(b_obj_arg s1, char const * s2); diff --git a/src/runtime/serializer.cpp b/src/runtime/serializer.cpp index a2f2952743..3ec554e0a5 100644 --- a/src/runtime/serializer.cpp +++ b/src/runtime/serializer.cpp @@ -138,7 +138,7 @@ void serializer::write_string_object(object * o) { size_t len = string_len(o); write_size_t(sz); write_size_t(len); - char const * it = string_data(o); + char const * it = string_cstr(o); char const * end = it + sz; for (; it != end; ++it) m_out.put(*it); @@ -332,7 +332,7 @@ object * deserializer::read_string_object() { size_t sz = read_size_t(); size_t len = read_size_t(); object * r = alloc_string(sz, sz, len); - unsigned char * it = const_cast(reinterpret_cast(string_data(r))); + unsigned char * it = const_cast(reinterpret_cast(string_cstr(r))); unsigned char * end = it + sz; for (; it != end; ++it) *it = m_in.get(); diff --git a/src/tests/util/object.cpp b/src/tests/util/object.cpp index 0a8c1308bb..f80a670d74 100644 --- a/src/tests/util/object.cpp +++ b/src/tests/util/object.cpp @@ -107,7 +107,7 @@ void tst5() { lean_assert(r1 == r2); lean_assert(is_thunk(r1)); object * str = thunk_get(r1); - lean_assert(strcmp(string_data(str), "hello world") == 0); + lean_assert(strcmp(string_cstr(str), "hello world") == 0); USED(r2); USED(str); } diff --git a/src/util/string_ref.h b/src/util/string_ref.h index fcbdb43a1e..33c7b25267 100644 --- a/src/util/string_ref.h +++ b/src/util/string_ref.h @@ -23,7 +23,7 @@ public: size_t num_bytes() const { return string_size(raw()) - 1; } /* The length is the number of unicode scalars. It is <= num_bytes. */ size_t length() const { return string_len(raw()); } - char const * data() const { return string_data(raw()); } + char const * data() const { return string_cstr(raw()); } std::string to_std_string() const { return std::string(data(), num_bytes()); } friend bool operator==(string_ref const & s1, string_ref const & s2) { return string_eq(s1.raw(), s2.raw()); } friend bool operator!=(string_ref const & s1, string_ref const & s2) { return string_ne(s1.raw(), s2.raw()); }