feat(library/init/data): add ByteArray

This commit is contained in:
Leonardo de Moura 2019-05-08 16:43:00 -07:00
parent fd487d8db7
commit 18aa7de408
7 changed files with 195 additions and 22 deletions

View file

@ -0,0 +1,63 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
-/
prelude
import init.data.array.basic init.data.uint init.data.option.basic
universes u
structure ByteArray :=
(data : Array UInt8)
attribute [extern cpp "lean::byte_array_mk"] ByteArray.mk
attribute [extern cpp "lean::byte_array_data"] ByteArray.data
namespace ByteArray
@[extern cpp inline "lean::mk_empty_byte_array(#1)"]
def mkEmpty (c : @& Nat) : ByteArray :=
{ data := Array.empty }
def empty : ByteArray :=
mkEmpty 0
@[extern cpp "lean::byte_array_push"]
def push : ByteArray → UInt8 → ByteArray
| ⟨bs⟩ b := ⟨bs.push b⟩
@[extern cpp "lean::byte_array_size"]
def size : (@& ByteArray) → Nat
| ⟨bs⟩ := bs.size
@[extern cpp "lean::byte_array_get"]
def get : (@& ByteArray) → (@& Nat) → UInt8
| ⟨bs⟩ i := bs.get i
@[extern cpp "lean::byte_array_set"]
def set : ByteArray → (@& Nat) → UInt8 → ByteArray
| ⟨bs⟩ i b := ⟨bs.set i b⟩
def isEmpty (s : ByteArray) : Bool :=
s.size == 0
partial def toListAux (bs : ByteArray) : Nat → List UInt8 → List UInt8
| i r :=
if i < bs.size then
toListAux (i+1) (bs.get i :: r)
else
r.reverse
def toList (bs : ByteArray) : List UInt8 :=
toListAux bs 0 []
end ByteArray
def List.toByteArrayAux : List UInt8 → ByteArray → ByteArray
| [] r := r
| (b::bs) r := List.toByteArrayAux bs (r.push b)
def List.toByteArray (bs : List UInt8) : ByteArray :=
bs.toByteArrayAux ByteArray.empty
instance : HasToString ByteArray :=
⟨λ bs, bs.toList.toString⟩

View file

@ -0,0 +1,7 @@
/-
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
-/
prelude
import init.data.bytearray.basic

View file

@ -5,7 +5,7 @@ Authors: Leonardo de Moura
-/
prelude
import init.data.basic init.data.nat init.data.char init.data.string
import init.data.list init.data.int init.data.array
import init.data.list init.data.int init.data.array init.data.bytearray
import init.data.fin init.data.uint init.data.ordering
import init.data.rbtree init.data.rbmap init.data.option.basic init.data.option.instances
import init.data.hashmap init.data.random

View file

@ -128,6 +128,11 @@ def getOpt : Nat → List α → Option α
| (n+1) (a::as) := getOpt n as
| _ _ := none
def set : List α → Nat → α → List α
| (a::as) 0 b := b::as
| (a::as) (n+1) b := a::(set as n b)
| [] _ _ := []
def head [Inhabited α] : List αα
| [] := default α
| (a::_) := a

View file

@ -322,27 +322,6 @@ void * alloc_heap_object(size_t sz) {
return static_cast<char *>(r) + sizeof(rc_type);
}
// =======================================
// Scalar arrays
#if 0
static object * sarray_ensure_capacity(object * o, size_t extra) {
lean_assert(!is_exclusive(o));
size_t sz = sarray_size(o);
size_t cap = sarray_capacity(o);
if (sz + extra > cap) {
unsigned esize = sarray_elem_size(o);
object * new_o = alloc_sarray(esize, sz, cap + sz + extra);
lean_assert(sarray_capacity(new_o) >= sz + extra);
memcpy(sarray_cptr<char>(new_o), sarray_cptr<char>(o), esize * sz);
free_heap_obj(o);
return new_o;
} else {
return o;
}
}
#endif
// =======================================
// Arrays
static object * g_array_empty = nullptr;
@ -1748,6 +1727,72 @@ usize string_hash(b_obj_arg s) {
return hash_str(sz, str, 11);
}
// =======================================
// ByteArray
obj_res copy_sarray(obj_arg a, bool expand) {
unsigned esz = sarray_elem_size(a);
size_t sz = sarray_size(a);
size_t cap = sarray_capacity(a);
lean_assert(cap >= sz);
if (expand) cap = (cap + 1) * 2;
lean_assert(!expand || cap > sz);
object * r = alloc_sarray(esz, sz, cap);
uint8 * it = sarray_cptr<uint8>(a);
uint8 * dest = sarray_cptr<uint8>(r);
memcpy(dest, it, esz*sz);
dec(a);
return r;
}
obj_res copy_byte_array(obj_arg a) {
return copy_sarray(a, false);
}
obj_res byte_array_mk(obj_arg a) {
usize sz = array_size(a);
obj_res r = alloc_sarray(1, sz, sz);
object ** it = array_cptr(a);
object ** end = it + sz;
uint8 * dest = sarray_cptr<uint8>(r);
for (; it != end; ++it, ++dest) {
*dest = unbox(*it);
}
dec(a);
return r;
}
obj_res byte_array_data(obj_arg a) {
usize sz = sarray_size(a);
obj_res r = alloc_array(sz, sz);
uint8 * it = sarray_cptr<uint8>(a);
uint8 * end = it+sz;
object ** dest = array_cptr(r);
for (; it != end; ++it, ++dest) {
*dest = box(*it);
}
dec(a);
return r;
}
obj_res byte_array_push(obj_arg a, uint8 b) {
object * r;
if (is_exclusive(a)) {
if (sarray_capacity(a) > sarray_size(a))
r = a;
else
r = copy_sarray(a, true);
} else {
r = copy_sarray(a, sarray_capacity(a) < 2*sarray_size(a) + 1);
}
lean_assert(sarray_capacity(r) > sarray_size(r));
size_t & sz = to_sarray(r)->m_size;
uint8 * it = sarray_cptr<uint8>(r) + sz;
*it = b;
sz++;
return r;
}
// =======================================
// array functions for generated code

View file

@ -1213,6 +1213,47 @@ inline uint8 string_dec_eq(b_obj_arg s1, b_obj_arg s2) { return string_eq(s1, s2
inline uint8 string_dec_lt(b_obj_arg s1, b_obj_arg s2) { return string_lt(s1, s2); }
usize string_hash(b_obj_arg);
// =======================================
// ByteArray
obj_res byte_array_mk(obj_arg a);
obj_res byte_array_data(obj_arg a);
obj_res copy_byte_array(obj_arg a);
inline obj_res mk_empty_byte_array(b_obj_arg capacity) {
if (!is_scalar(capacity)) throw std::bad_alloc(); // we will run out of memory
usize cap = unbox(capacity);
return alloc_sarray(1, 0, cap);
}
inline obj_res byte_array_size(b_obj_arg a) {
return box(sarray_size(a));
}
inline uint8 byte_array_get(b_obj_arg a, b_obj_arg i) {
if (is_scalar(i)) {
usize idx = unbox(i);
return idx < sarray_size(a) ? sarray_get<uint8>(a, idx) : 0;
} else {
/* The index must be out of bounds. Otherwise we would be out of memory. */
return 0;
}
}
obj_res byte_array_push(obj_arg a, uint8 b);
inline obj_res byte_array_set(obj_arg a, b_obj_arg i, uint8 b) {
if (!is_scalar(i)) return a;
usize idx = unbox(i);
if (idx >= sarray_size(a)) return a;
obj_res r;
if (is_exclusive(a)) r = a;
else r = copy_byte_array(a);
uint8 * it = sarray_cptr<uint8>(r) + idx;
*it = b;
return r;
}
// =======================================
// uint8
uint8 uint8_of_big_nat(b_obj_arg a);

View file

@ -0,0 +1,12 @@
def main (xs : List String) : IO Unit :=
do
let bs := [1, 2, 3].toByteArray,
IO.println bs,
let bs := bs.push 4,
let bs := bs.set 1 20,
IO.println bs,
let bs₁ := bs.set 2 30,
IO.println bs₁,
IO.println bs,
IO.println bs.size,
pure ()