feat(runtime): implement string.iterator primitives in the new runtime

Some of the primitives do not yet have an optimal implementation.

@Kha Could you please check if everything we use in the parser has a
reasonable implementation?
This commit is contained in:
Leonardo de Moura 2018-11-15 10:42:23 -08:00
parent ed4eeddf0a
commit efa703d2b5
5 changed files with 446 additions and 114 deletions

View file

@ -187,7 +187,7 @@ void del(object * o) {
#if 0
static object * sarray_ensure_capacity(object * o, size_t extra) {
lean_assert(!is_heap_obj(o) || !is_shared(o));
lean_assert(!is_exclusive(o));
size_t sz = sarray_size(o);
size_t cap = sarray_capacity(o);
if (sz + extra > cap) {
@ -381,109 +381,6 @@ obj_res parray_copy(b_obj_arg o) {
return r;
}
// =======================================
// Strings
/* Mutable pointer to the bytes that start right after the `string_object` header of `o`. */
static inline char * w_string_data(object * o) {
    lean_assert(is_string(o));
    char * base = reinterpret_cast<char *>(o);
    return base + sizeof(string_object);
}
/* Make room for `extra` more bytes after the current `string_size(o)` bytes.
   Returns `o` unchanged when its capacity already suffices; otherwise copies the
   bytes into a larger string object, frees `o`, and returns the new object. */
static object * string_ensure_capacity(object * o, size_t extra) {
    lean_assert(!is_heap_obj(o) || !is_shared(o));
    size_t const used     = string_size(o);
    size_t const capacity = string_capacity(o);
    if (used + extra <= capacity)
        return o;
    object * grown = alloc_string(used, capacity + used + extra, string_len(o));
    lean_assert(string_capacity(grown) >= used + extra);
    memcpy(w_string_data(grown), string_data(o), used);
    free_heap_obj(o);
    return grown;
}
/* Build a string object from the NUL-terminated C string `s`.
   The stored size is `strlen(s) + 1` (terminator included); the length is `utf8_strlen(s)`. */
object * mk_string(char const * s) {
    size_t const nbytes = strlen(s);
    size_t const nchars = utf8_strlen(s);
    size_t const total  = nbytes + 1;
    object * result = alloc_string(total, total, nchars);
    /* copies the terminator as well */
    memcpy(w_string_data(result), s, total);
    return result;
}
/* Build a string object from `s`. The stored size is `s.size() + 1`; a NUL terminator is appended. */
object * mk_string(std::string const & s) {
    size_t const nbytes = s.size();
    size_t const nchars = utf8_strlen(s);
    size_t const total  = nbytes + 1;
    object * result = alloc_string(total, total, nchars);
    char * dst = w_string_data(result);
    memcpy(dst, s.data(), nbytes);
    dst[nbytes] = 0;
    return result;
}
/* Capacity to reserve for a string that needs `sz` bytes: twice the requested size. */
static size_t mk_capacity(size_t sz) {
    size_t const doubled = 2 * sz;
    return doubled;
}
/* Append the unicode scalar `c` to `s` and return the resulting string.
   A heap object that is not shared is extended in place (growing if needed);
   otherwise the bytes are copied into a fresh string and one reference to `s` is released. */
object * string_push(object * s, unsigned c) {
    size_t const old_sz = string_size(s);
    object * r;
    if (is_heap_obj(s) && !is_shared(s)) {
        r = string_ensure_capacity(s, 5);
    } else {
        r = alloc_string(old_sz, mk_capacity(old_sz + 5), string_len(s));
        memcpy(w_string_data(r), string_data(s), old_sz - 1);
        dec_ref(s);
    }
    /* `tail` is where the old terminator lives; the new character overwrites it */
    char * tail = w_string_data(r) + old_sz - 1;
    unsigned const nbytes = push_unicode_scalar(tail, c);
    tail[nbytes] = 0;
    to_string(r)->m_size = old_sz + nbytes;
    to_string(r)->m_length++;
    return r;
}
/* Append the bytes of `s2` to `s1` and return the resulting string.
   `s1` and `s2` must be distinct objects. A heap `s1` that is not shared is extended
   in place (growing if needed); otherwise the bytes are copied into a fresh string
   and one reference to `s1` is released. `s2` is only read. */
object * string_append(object * s1, object * s2) {
    lean_assert(s1 != s2);
    size_t const sz1     = string_size(s1);
    size_t const sz2     = string_size(s2);
    size_t const new_len = string_len(s1) + string_len(s2);
    /* Both sizes include a NUL terminator and the result needs only one.
       Kept as `size_t`: an `unsigned` would truncate large byte counts. */
    size_t const new_sz  = sz1 + sz2 - 1;
    object * r;
    if (!is_heap_obj(s1) || is_shared(s1)) {
        r = alloc_string(new_sz, mk_capacity(new_sz), new_len);
        memcpy(w_string_data(r), string_data(s1), sz1 - 1);
        dec_ref(s1);
    } else {
        r = string_ensure_capacity(s1, sz2 - 1);
    }
    memcpy(w_string_data(r) + sz1 - 1, string_data(s2), sz2 - 1);
    to_string(r)->m_size   = new_sz;
    to_string(r)->m_length = new_len;
    w_string_data(r)[new_sz - 1] = 0;
    return r;
}
/* Byte-wise equality of two string objects (sizes must match first). */
bool string_eq(object * s1, object * s2) {
    size_t const n = string_size(s1);
    return n == string_size(s2) && std::memcmp(string_data(s1), string_data(s2), n) == 0;
}
/* Byte-wise equality of a string object and a NUL-terminated C string.
   The object's size counts its terminator, hence the `+ 1`. */
bool string_eq(object * s1, char const * s2) {
    size_t const n = string_size(s1);
    return n == strlen(s2) + 1 && std::memcmp(string_data(s1), s2, n) == 0;
}
/* Lexicographic byte order: compare the common prefix, then the shorter string wins. */
bool string_lt(object * s1, object * s2) {
    /* ignore the null char at the end of each string */
    size_t const n1 = string_size(s1) - 1;
    size_t const n2 = string_size(s2) - 1;
    int const cmp = std::memcmp(string_data(s1), string_data(s2), std::min(n1, n2));
    if (cmp != 0)
        return cmp < 0;
    return n1 < n2;
}
// =======================================
// Closures
@ -966,6 +863,16 @@ b_obj_res io_wait_any_core(b_obj_arg task_list) {
// =======================================
// Natural numbers
/* Convert a natural-number object to `size_t`.
   Scalars are unboxed; big numbers that fit in an `unsigned long` are converted;
   anything larger yields `std::numeric_limits<size_t>::max()`. */
size_t size_t_of_nat(b_obj_arg n) {
    if (is_scalar(n))
        return unbox(n);
    auto const & big = mpz_value(n);
    if (!big.is_unsigned_long_int())
        return std::numeric_limits<size_t>::max();
    return static_cast<size_t>(big.get_unsigned_long_int());
}
object * nat_big_add(object * a1, object * a2) {
lean_assert(!is_scalar(a1) || !is_scalar(a2));
if (is_scalar(a1))
@ -1177,12 +1084,404 @@ bool int_big_lt(object * a1, object * a2) {
}
}
// =======================================
// Strings
/* Mutable pointer to the bytes that start right after the `string_object` header of `o`. */
static inline char * w_string_cstr(object * o) {
    lean_assert(is_string(o));
    char * base = reinterpret_cast<char *>(o);
    return base + sizeof(string_object);
}
/* Make room for `extra` more bytes after the current `string_size(o)` bytes.
   `o` must be exclusive. Returns `o` unchanged when its capacity already suffices;
   otherwise copies the bytes into a larger string object, frees `o`, and returns the new object. */
static object * string_ensure_capacity(object * o, size_t extra) {
    lean_assert(is_exclusive(o));
    size_t const used     = string_size(o);
    size_t const capacity = string_capacity(o);
    if (used + extra <= capacity)
        return o;
    object * grown = alloc_string(used, capacity + used + extra, string_len(o));
    lean_assert(string_capacity(grown) >= used + extra);
    memcpy(w_string_cstr(grown), string_cstr(o), used);
    free_heap_obj(o);
    return grown;
}
/* Build a string object from the NUL-terminated C string `s`.
   The stored size is `strlen(s) + 1` (terminator included); the length is `utf8_strlen(s)`. */
object * mk_string(char const * s) {
    size_t const nbytes = strlen(s);
    size_t const nchars = utf8_strlen(s);
    size_t const total  = nbytes + 1;
    object * result = alloc_string(total, total, nchars);
    /* copies the terminator as well */
    memcpy(w_string_cstr(result), s, total);
    return result;
}
/* Build a string object from `s`. The stored size is `s.size() + 1`; a NUL terminator is appended. */
object * mk_string(std::string const & s) {
    size_t const nbytes = s.size();
    size_t const nchars = utf8_strlen(s);
    size_t const total  = nbytes + 1;
    object * result = alloc_string(total, total, nchars);
    char * dst = w_string_cstr(result);
    memcpy(dst, s.data(), nbytes);
    dst[nbytes] = 0;
    return result;
}
/* Capacity to reserve for a string that needs `sz` bytes: twice the requested size. */
static size_t mk_capacity(size_t sz) {
    size_t const doubled = 2 * sz;
    return doubled;
}
/* Append the unicode scalar `c` to `s` and return the resulting string.
   An exclusive `s` is extended in place (growing if needed); otherwise the bytes
   are copied into a fresh string and one reference to `s` is released. */
object * string_push(object * s, unsigned c) {
    size_t const old_sz = string_size(s);
    object * r;
    if (is_exclusive(s)) {
        r = string_ensure_capacity(s, 5);
    } else {
        r = alloc_string(old_sz, mk_capacity(old_sz + 5), string_len(s));
        memcpy(w_string_cstr(r), string_cstr(s), old_sz - 1);
        dec_ref(s);
    }
    /* `tail` is where the old terminator lives; the new character overwrites it */
    char * tail = w_string_cstr(r) + old_sz - 1;
    unsigned const nbytes = push_unicode_scalar(tail, c);
    tail[nbytes] = 0;
    to_string(r)->m_size = old_sz + nbytes;
    to_string(r)->m_length++;
    return r;
}
/* Append the bytes of `s2` to `s1` and return the resulting string.
   `s1` and `s2` must be distinct objects. An exclusive `s1` is extended in place
   (growing if needed); otherwise the bytes are copied into a fresh string and one
   reference to `s1` is released. `s2` is only read. */
object * string_append(object * s1, object * s2) {
    lean_assert(s1 != s2);
    size_t const sz1     = string_size(s1);
    size_t const sz2     = string_size(s2);
    size_t const new_len = string_len(s1) + string_len(s2);
    /* Both sizes include a NUL terminator and the result needs only one.
       Kept as `size_t`: an `unsigned` would truncate large byte counts. */
    size_t const new_sz  = sz1 + sz2 - 1;
    object * r;
    if (!is_exclusive(s1)) {
        r = alloc_string(new_sz, mk_capacity(new_sz), new_len);
        memcpy(w_string_cstr(r), string_cstr(s1), sz1 - 1);
        dec_ref(s1);
    } else {
        r = string_ensure_capacity(s1, sz2 - 1);
    }
    memcpy(w_string_cstr(r) + sz1 - 1, string_cstr(s2), sz2 - 1);
    to_string(r)->m_size   = new_sz;
    to_string(r)->m_length = new_len;
    w_string_cstr(r)[new_sz - 1] = 0;
    return r;
}
/* Byte-wise equality of two string objects (sizes must match first). */
bool string_eq(object * s1, object * s2) {
    size_t const n = string_size(s1);
    return n == string_size(s2) && std::memcmp(string_cstr(s1), string_cstr(s2), n) == 0;
}
/* Byte-wise equality of a string object and a NUL-terminated C string.
   The object's size counts its terminator, hence the `+ 1`. */
bool string_eq(object * s1, char const * s2) {
    size_t const n = string_size(s1);
    return n == strlen(s2) + 1 && std::memcmp(string_cstr(s1), s2, n) == 0;
}
/* Lexicographic byte order: compare the common prefix, then the shorter string wins. */
bool string_lt(object * s1, object * s2) {
    /* ignore the null char at the end of each string */
    size_t const n1 = string_size(s1) - 1;
    size_t const n2 = string_size(s2) - 1;
    int const cmp = std::memcmp(string_cstr(s1), string_cstr(s2), std::min(n1, n2));
    if (cmp != 0)
        return cmp < 0;
    return n1 < n2;
}
/* Encode a list of characters as UTF-8 bytes.
   Walks the cells until a scalar is reached; field 0 of each cell is the boxed
   character and field 1 is the next cell. */
static std::string list_as_string(b_obj_arg lst) {
    std::string out;
    for (b_obj_arg cell = lst; !is_scalar(cell); cell = cnstr_get(cell, 1))
        push_unicode_scalar(out, unbox(cnstr_get(cell, 0)));
    return out;
}
/* Decode the UTF-8 bytes of `s` into a list of boxed characters.
   When `reverse` is true the characters appear in reverse order.
   The list is built back-to-front so the first character ends up at the head;
   the empty list is `box(0)`. */
static obj_res string_to_list_core(std::string const & s, bool reverse = false) {
    buffer<unsigned> chars;
    utf8_decode(s, chars);
    if (reverse)
        std::reverse(chars.begin(), chars.end());
    obj_res lst = box(0);
    for (unsigned i = chars.size(); i-- > 0; ) {
        obj_res cell = alloc_cnstr(1, 2, 0);
        cnstr_set(cell, 0, box(chars[i]));
        cnstr_set(cell, 1, lst);
        lst = cell;
    }
    return lst;
}
/* Build a string object from the character list `cs`; `cs` is released with `dec`. */
obj_res string_mk(obj_arg cs) {
    std::string const bytes = list_as_string(cs);
    dec(cs);
    return mk_string(bytes);
}
/* Convert the string `s` into a list of its characters; one reference to `s` is released. */
obj_res string_data(obj_arg s) {
    /* `string_size` counts the NUL terminator. Leave it out, otherwise the
       decoded list would end with a spurious character 0. */
    std::string tmp(string_cstr(s), string_size(s) - 1);
    dec_ref(s);
    return string_to_list_core(tmp);
}
/* Allocate an iterator cell: one object field (the string) followed by two
   `size_t` scalars. `pos` is in bytes, and `remaining` is in characters. */
static obj_res mk_iterator(obj_arg s, size_t pos, size_t remaining) {
    size_t const pos_off = sizeof(object*);
    size_t const rem_off = pos_off + sizeof(size_t);
    obj_res it = alloc_cnstr(0, 1, 2 * sizeof(size_t));
    cnstr_set(it, 0, s);
    cnstr_set_scalar<size_t>(it, pos_off, pos);
    cnstr_set_scalar<size_t>(it, rem_off, remaining);
    return it;
}
/* Field 0 of an iterator cell: the string being traversed. */
static b_obj_res it_string(b_obj_arg it) {
    return cnstr_get(it, 0);
}

/* First scalar of an iterator cell: the current position, in bytes. */
static size_t it_pos(b_obj_arg it) {
    return cnstr_get_scalar<size_t>(it, sizeof(object*));
}

/* Second scalar of an iterator cell: the number of characters left. */
static size_t it_remaining(b_obj_arg it) {
    return cnstr_get_scalar<size_t>(it, sizeof(object*) + sizeof(size_t));
}

/* Overwrite the string field of `it` with `s`. */
static void it_set_string(u_obj_arg it, obj_arg s) {
    cnstr_set(it, 0, s);
}

/* Overwrite the byte position stored in `it`. */
static void it_set_pos(u_obj_arg it, size_t pos) {
    cnstr_set_scalar<size_t>(it, sizeof(object*), pos);
}

/* Overwrite the remaining-character count stored in `it`. */
static void it_set_remaining(u_obj_arg it, size_t r) {
    cnstr_set_scalar<size_t>(it, sizeof(object*) + sizeof(size_t), r);
}
/* Default character, mirroring `instance : inhabited char := ⟨'A'⟩` (code point 65). */
static uint32 mk_default_char() {
    uint32 const latin_capital_a = 65;
    return latin_capital_a;
}
/* True when `it` is exclusive and the string in its field 0 is not shared,
   i.e. when the string may be updated in place through this iterator. */
static bool is_unshared_it_string(b_obj_arg it) {
    if (!is_exclusive(it))
        return false;
    return !is_shared(cnstr_get(it, 0));
}
/* Size in bytes of the character that starts at `s[i]`, as reported by
   `is_utf8_first_byte`; falls back to 1 when that byte is not a first byte. */
static unsigned get_utf8_char_size_at(std::string const & s, unsigned i) {
    auto const sz = is_utf8_first_byte(s[i]);
    return sz ? *sz : 1;
}
/* Iterator positioned at byte 0 of `s`, with all `string_len(s)` characters remaining. */
obj_res string_mk_iterator(obj_arg s) {
    size_t const nchars = string_len(s);
    return mk_iterator(s, 0, nchars);
}
/* Character at the iterator position, or the default character when the
   iterator is at the end of the string. */
uint32 string_iterator_curr(b_obj_arg it) {
    object * s = it_string(it);
    size_t i   = it_pos(it);
    /* `string_size` counts the NUL terminator, so content bytes end at size - 1.
       Comparing against the full size would decode the terminator as a character. */
    if (i < string_size(s) - 1) {
        return next_utf8(string_cstr(s), i);
    } else {
        return mk_default_char();
    }
}
/* def set_curr : iterator → char → iterator
   Replace the character at the iterator position with `c`.
   Returns `it` unchanged at the end of the string. Updates in place when the
   iterator and its string are unshared and both characters are single-byte;
   otherwise builds a new string and a new iterator and releases `it`. */
obj_res string_iterator_set_curr(obj_arg it, uint32 c) {
    object * s = it_string(it);
    size_t i   = it_pos(it);
    /* `string_size` counts the NUL terminator; content bytes end before it */
    size_t content_sz = string_size(s) - 1;
    if (i >= content_sz) {
        /* at end: never overwrite the terminator */
        return it;
    }
    if (is_unshared_it_string(it)) {
        if (static_cast<unsigned char>(string_cstr(s)[i]) < 128 && c < 128) {
            /* easy case, old and new characters are encoded using 1 byte */
            w_string_cstr(s)[i] = static_cast<char>(c);
            return it;
        }
    }
    /* TODO(Leo): improve performance of the special cases.
       Example: `it` is not shared, but string is; new and old characters have the same size; etc. */
    std::string tmp;
    push_unicode_scalar(tmp, c);
    /* copy the content only; `mk_string` appends its own terminator */
    std::string new_s(string_cstr(s), content_sz);
    new_s.replace(i, get_utf8_char_size_at(new_s, i), tmp);
    size_t rem = it_remaining(it);
    dec_ref(it);
    return mk_iterator(mk_string(new_s), i, rem);
}
/* def next : iterator → iterator
   Advance by one character. Returns `it` unchanged at the end of the string.
   An exclusive iterator is updated in place; otherwise a new iterator sharing
   the same string is created and `it` is released. */
obj_res string_iterator_next(obj_arg it) {
    object * s = it_string(it);
    size_t i   = it_pos(it);
    size_t r   = it_remaining(it);
    /* `string_size` counts the NUL terminator. Stopping at size - 1 keeps the
       iterator from stepping over it and underflowing `remaining`. */
    if (i >= string_size(s) - 1)
        return it;
    next_utf8(string_cstr(s), i);
    if (is_exclusive(it)) {
        it_set_pos(it, i);
        it_set_remaining(it, r-1);
        return it;
    } else {
        inc_ref(s);
        obj_res new_it = mk_iterator(s, i, r-1);
        dec_ref(it);
        return new_it;
    }
}
/* def prev : iterator → iterator
   Move back by one character. Returns `it` unchanged at the start of the string
   or when no first byte is found within the previous 4 bytes (ill-formed UTF-8).
   An exclusive iterator is updated in place; otherwise a new iterator sharing
   the same string is created and `it` is released. */
obj_res string_iterator_prev(obj_arg it) {
    object * s = it_string(it);
    size_t r   = it_remaining(it);
    /* `string_size` counts the NUL terminator. Clamp a position beyond the last
       content byte (e.g. one equal to `string_size(s)`) so the walk does not
       stop on the terminator itself. */
    size_t new_i = std::min(it_pos(it), string_size(s) - 1);
    /* we have to walk at most 4 steps backwards, and never before byte 0 */
    for (unsigned j = 0; j < 4 && new_i > 0; j++) {
        --new_i;
        if (is_utf8_first_byte(string_cstr(s)[new_i])) {
            if (is_exclusive(it)) {
                it_set_pos(it, new_i);
                it_set_remaining(it, r+1);
                return it;
            } else {
                inc_ref(s);
                obj_res new_it = mk_iterator(s, new_i, r+1);
                dec_ref(it);
                return new_it;
            }
        }
    }
    /* at the start, or incorrectly encoded utf-8 string */
    return it;
}
/* def has_next : iterator → bool
   True when at least one content byte lies at or after the iterator position. */
uint8 string_iterator_has_next(b_obj_arg it) {
    /* `string_size` counts the NUL terminator, which is not a character */
    return it_pos(it) < string_size(it_string(it)) - 1;
}
/* def has_prev : iterator → bool
   True unless the iterator is at byte 0. */
uint8 string_iterator_has_prev(b_obj_arg it) {
    size_t const pos = it_pos(it);
    return pos != 0;
}
/* Number of characters left after the iterator position, as a natural number. */
obj_res string_iterator_remaining(b_obj_arg it) {
    size_t const left = it_remaining(it);
    return nat_of_size_t(left);
}
/* Number of characters before the iterator position: total length minus remaining. */
obj_res string_iterator_offset(b_obj_arg it) {
    size_t const consumed = string_len(it_string(it)) - it_remaining(it);
    return nat_of_size_t(consumed);
}
/* def to_string : iterator → string
   Returns the iterator's string after taking an extra reference to it. */
obj_res string_iterator_to_string(b_obj_arg it) {
    object * str = it_string(it);
    inc_ref(str);
    return str;
}
/* def to_end : iterator → iterator
   Move the iterator to the end of its string (no characters remaining).
   An exclusive iterator is updated in place; otherwise a new iterator sharing
   the same string is created and `it` is released. */
obj_res string_iterator_to_end(obj_arg it) {
    object * s = it_string(it);
    /* `string_size` counts the NUL terminator. The end position is the
       terminator's own offset, not one past it. */
    size_t end = string_size(s) - 1;
    if (is_exclusive(it)) {
        it_set_pos(it, end);
        it_set_remaining(it, 0);
        return it;
    } else {
        inc_ref(s);
        obj_res new_it = mk_iterator(s, end, 0);
        dec_ref(it);
        return new_it;
    }
}
/* def remaining_to_string : iterator → string
   New string holding the bytes from the iterator position to the end of the content. */
obj_res string_iterator_remaining_to_string(b_obj_arg it) {
    object * s = it_string(it);
    size_t i   = it_pos(it);
    /* `string_size` counts the NUL terminator; `mk_string` adds its own, so it
       must not be copied into the result. */
    size_t end = string_size(s) - 1;
    std::string r;
    if (i < end)
        r.assign(string_cstr(s) + i, end - i);
    return mk_string(r);
}
/* def prev_to_string : iterator → string
   New string holding the bytes that precede the iterator position. */
obj_res string_iterator_prev_to_string(b_obj_arg it) {
    object * s = it_string(it);
    /* `string_size` counts the NUL terminator. Clamp so that a position beyond
       the content never copies the terminator into the result. */
    size_t pos = std::min(it_pos(it), string_size(s) - 1);
    return mk_string(std::string(string_cstr(s), pos));
}
/* def insert : iterator → string → iterator
   Insert the content of `s` at the iterator position. The position is kept and
   `remaining` grows by `string_len(s)`. `s` is only read.
   At the end of the string the text is appended (in place when the iterator and
   its string are unshared); in the middle a new string is built and `it` is released. */
obj_res string_iterator_insert(obj_arg it, b_obj_arg s) {
    object * s_0 = it_string(it);
    object * s_1 = s;
    /* `string_size` counts the NUL terminator; content bytes end before it */
    size_t end_0 = string_size(s_0) - 1;
    /* clamp so that a position beyond the content is treated as the end */
    size_t i     = std::min(it_pos(it), end_0);
    size_t r     = it_remaining(it) + string_len(s_1);
    if (i == end_0) {
        /* insert s in the end */
        if (is_unshared_it_string(it)) {
            object * new_s = string_append(s_0, s_1);
            it_set_string(it, new_s);
            it_set_pos(it, i);
            it_set_remaining(it, r);
            return it;
        } else {
            inc_ref(s_0);
            object * new_s = string_append(s_0, s_1);
            dec_ref(it);
            return mk_iterator(new_s, i, r);
        }
    } else {
        /* insert in the middle */
        /* TODO(Leo): optimize is_unshared_it_string(it) case */
        /* copy content bytes only: neither terminator may end up inside the result */
        std::string new_s(string_cstr(s_0), end_0);
        new_s.insert(i, string_cstr(s_1), string_size(s_1) - 1);
        dec_ref(it);
        return mk_iterator(mk_string(new_s), i, r);
    }
}
/* def remove : iterator → nat → iterator
   Remove up to `n0` characters starting at the iterator position (fewer when the
   end of the string is reached). Builds a new string and iterator; `it` is released. */
obj_res string_iterator_remove(obj_arg it, b_obj_arg n0) {
    object * s = it_string(it);
    /* `string_size` counts the NUL terminator; content bytes end before it */
    size_t end = string_size(s) - 1;
    /* clamp so that `erase` below never receives an out-of-range index */
    size_t i   = std::min(it_pos(it), end);
    size_t j   = i;
    size_t n   = size_t_of_nat(n0);
    size_t r   = it_remaining(it);
    /* advance `j` over the characters to drop, never over the terminator */
    for (size_t k = 0; k < n && j < end; k++) {
        next_utf8(string_cstr(s), j);
        r--;
    }
    /* TODO(Leo): optimize case where is_unshared_it_string(it) */
    std::string new_s(string_cstr(s), end);
    new_s.erase(i, j - i);
    dec_ref(it);
    return mk_iterator(mk_string(new_s), i, r);
}
/* def extract : iterator → iterator → option string
   Returns `some` of the text between the positions of `it1` and `it2`, or `none`
   when the iterators range over different strings or `it2` precedes `it1`. */
obj_res string_iterator_extract(b_obj_arg it1, b_obj_arg it2) {
    object * s1 = it_string(it1);
    object * s2 = it_string(it2);
    /* Compare the object pointers (not the addresses of the locals) so that the
       byte-wise comparison is skipped when both iterators hold the same object. */
    if (s1 != s2 && string_ne(s1, s2))
        return mk_option_none();
    /* `string_size` counts the NUL terminator; clamp positions to the content */
    size_t end  = string_size(s1) - 1;
    size_t pos1 = std::min(it_pos(it1), end);
    size_t pos2 = std::min(it_pos(it2), end);
    if (pos2 < pos1)
        return mk_option_none();
    size_t new_sz = pos2 - pos1;
    /* The result is a string object: allocate it as one, with room for the terminator */
    object * r = alloc_string(new_sz + 1, new_sz + 1, it_remaining(it1) - it_remaining(it2));
    memcpy(w_string_cstr(r), string_cstr(s1) + pos1, new_sz);
    w_string_cstr(r)[new_sz] = 0;
    return mk_option_some(r);
}
// =======================================
// Debugging helper functions
/* Debugging helper: write the bytes of the string object `o`, followed by a
   newline, to `std::cout`. Only the C-string accessor is used here:
   `string_data` is the list-producing primitive, which releases its argument. */
void dbg_print_str(object * o) {
    lean_assert(is_string(o));
    std::cout << string_cstr(o) << "\n";
}
void dbg_print_num(object * o) {

View file

@ -17,6 +17,11 @@ Author: Leonardo de Moura
#include "runtime/thread.h"
namespace lean {
/* Short names for the unsigned integer types used by the runtime primitives.
   NOTE(review): the names assume 8/16/32/64-bit widths for these built-in
   types — confirm for every supported platform. */
typedef unsigned char uint8;
typedef unsigned short uint16;
typedef unsigned uint32;
typedef unsigned long long uint64;
/*
The primitives implemented in the runtime do not modify the RC of its arguments.
Callers are responsible for increasing/decreasing the RCs using the `inc`/`dec` operations.
@ -241,6 +246,7 @@ inline rc_type get_rc(object * o) {
}
/* True when the reference count of `o` is greater than one. */
inline bool is_shared(object * o) {
    return 1 < get_rc(o);
}
/* True when `o` is a heap object that is not shared. */
inline bool is_exclusive(object * o) {
    if (!is_heap_obj(o))
        return false;
    return !is_shared(o);
}
inline void inc_ref(object * o) {
if (is_mt_heap_obj(o)) {
@ -694,6 +700,10 @@ inline obj_res mk_nat_obj(uint64 n) {
}
}
/* Natural-number object for `n`, using the `mk_nat_obj` overload that matches the width of `size_t`. */
inline obj_res nat_of_size_t(size_t n) {
    if (sizeof(size_t) == sizeof(unsigned))
        return mk_nat_obj(static_cast<unsigned>(n));
    return mk_nat_obj(static_cast<uint64>(n));
}
inline uint64 nat2uint64(b_obj_arg a) {
lean_assert(is_scalar(a));
return unbox(a);
@ -971,6 +981,12 @@ inline bool int_lt(b_obj_arg a1, b_obj_arg a2) {
}
}
// =======================================
// Option
/* `none` is represented by the boxed scalar 0. */
inline obj_res mk_option_none() {
    return box(0);
}
/* `some v`: a constructor cell with tag 1 and one object field holding `v`.
   Returns the cell itself; returning `v` would hand back the unwrapped value
   and leak the freshly allocated cell. */
inline obj_res mk_option_some(obj_arg v) {
    obj_res r = alloc_cnstr(1, 1, 0);
    cnstr_set(r, 0, v);
    return r;
}
// =======================================
// String
@ -979,14 +995,31 @@ inline obj_res alloc_string(size_t size, size_t capacity, size_t len) {
}
/* Constructors: build a string object from a C string or a std::string. */
obj_res mk_string(char const * s);
obj_res mk_string(std::string const & s);
/* Read-only pointer to the bytes stored right after the `string_object` header.
   NOTE(review): `string_data` and `string_cstr` have identical bodies, and `string_data`
   is declared again below as `obj_res string_data(obj_arg)`. This looks like the
   before/after pair of a rename to `string_cstr` — confirm only one of them survives. */
inline char const * string_data(b_obj_arg o) { lean_assert(is_string(o)); return reinterpret_cast<char*>(o) + sizeof(string_object); }
inline char const * string_cstr(b_obj_arg o) { lean_assert(is_string(o)); return reinterpret_cast<char*>(o) + sizeof(string_object); }
/* `m_size` field of the string object (size in bytes). */
inline size_t string_size(b_obj_arg o) { return to_string(o)->m_size; }
/* `m_length` field of the string object (length in characters). */
inline size_t string_len(b_obj_arg o) { return to_string(o)->m_length; }
/* NOTE(review): `string_push` is declared twice; `uint32` is a typedef of `unsigned`,
   so both lines declare the same function. */
obj_res string_push(obj_arg s, unsigned c);
obj_res string_push(obj_arg s, uint32 c);
obj_res string_append(obj_arg s1, b_obj_arg s2);
/* Length of `s` as a natural-number object.
   NOTE(review): `string_length` is defined twice below (explicit ternary, then via
   `nat_of_size_t`); presumably the old and new versions of the same line — confirm
   only one definition remains. */
inline obj_res string_length(b_obj_arg s) {
return (sizeof(size_t) == sizeof(unsigned)) ? mk_nat_obj(static_cast<unsigned>(string_len(s))) : mk_nat_obj(static_cast<uint64>(string_len(s)));
}
inline obj_res string_length(b_obj_arg s) { return nat_of_size_t(string_len(s)); }
/* Conversions between strings and character lists. */
obj_res string_mk(obj_arg cs);
obj_res string_data(obj_arg s);
/* String iterator primitives. */
obj_res string_mk_iterator(obj_arg s);
uint32 string_iterator_curr(b_obj_arg it);
obj_res string_iterator_set_curr(obj_arg it, uint32 c);
obj_res string_iterator_next(obj_arg it);
obj_res string_iterator_prev(obj_arg it);
uint8 string_iterator_has_next(b_obj_arg it);
uint8 string_iterator_has_prev(b_obj_arg it);
obj_res string_iterator_insert(obj_arg it, b_obj_arg s);
obj_res string_iterator_remove(obj_arg it, b_obj_arg n);
obj_res string_iterator_remaining(b_obj_arg it);
obj_res string_iterator_offset(b_obj_arg it);
obj_res string_iterator_remaining_to_string(b_obj_arg it);
obj_res string_iterator_prev_to_string(b_obj_arg it);
obj_res string_iterator_to_string(b_obj_arg it);
obj_res string_iterator_to_end(obj_arg it);
obj_res string_iterator_extract(b_obj_arg it1, b_obj_arg it2);
/* Comparisons. `string_ne` is the negation of the object/object `string_eq`. */
bool string_eq(b_obj_arg s1, b_obj_arg s2);
inline bool string_ne(b_obj_arg s1, b_obj_arg s2) { return !string_eq(s1, s2); }
bool string_eq(b_obj_arg s1, char const * s2);

View file

@ -138,7 +138,7 @@ void serializer::write_string_object(object * o) {
size_t len = string_len(o);
write_size_t(sz);
write_size_t(len);
char const * it = string_data(o);
char const * it = string_cstr(o);
char const * end = it + sz;
for (; it != end; ++it)
m_out.put(*it);
@ -332,7 +332,7 @@ object * deserializer::read_string_object() {
size_t sz = read_size_t();
size_t len = read_size_t();
object * r = alloc_string(sz, sz, len);
unsigned char * it = const_cast<unsigned char*>(reinterpret_cast<unsigned char const *>(string_data(r)));
unsigned char * it = const_cast<unsigned char*>(reinterpret_cast<unsigned char const *>(string_cstr(r)));
unsigned char * end = it + sz;
for (; it != end; ++it)
*it = m_in.get();

View file

@ -107,7 +107,7 @@ void tst5() {
lean_assert(r1 == r2);
lean_assert(is_thunk(r1));
object * str = thunk_get(r1);
lean_assert(strcmp(string_data(str), "hello world") == 0);
lean_assert(strcmp(string_cstr(str), "hello world") == 0);
USED(r2); USED(str);
}

View file

@ -23,7 +23,7 @@ public:
/* Number of bytes in the string: `string_size` minus one
   (presumably the NUL terminator is the byte left out — confirm). */
size_t num_bytes() const { return string_size(raw()) - 1; }
/* The length is the number of unicode scalars. It is <= num_bytes. */
size_t length() const { return string_len(raw()); }
/* Pointer to the string's bytes.
   NOTE(review): `data()` is defined twice, once via `string_data` and once via
   `string_cstr`; this looks like the before/after pair of a rename — confirm
   only the `string_cstr` version remains. */
char const * data() const { return string_data(raw()); }
char const * data() const { return string_cstr(raw()); }
/* Copy of the first `num_bytes()` bytes as a std::string. */
std::string to_std_string() const { return std::string(data(), num_bytes()); }
/* Equality and inequality delegate to `string_eq` / `string_ne` on the raw objects. */
friend bool operator==(string_ref const & s1, string_ref const & s2) { return string_eq(s1.raw(), s2.raw()); }
friend bool operator!=(string_ref const & s1, string_ref const & s2) { return string_ne(s1.raw(), s2.raw()); }