diff --git a/src/tests/util/CMakeLists.txt b/src/tests/util/CMakeLists.txt index c47a4b1b17..de168fdf7b 100644 --- a/src/tests/util/CMakeLists.txt +++ b/src/tests/util/CMakeLists.txt @@ -70,3 +70,6 @@ add_test(stackinfo ${CMAKE_CURRENT_BINARY_DIR}/stackinfo) add_executable(serializer serializer.cpp) target_link_libraries(serializer ${EXTRA_LIBS}) add_test(serializer ${CMAKE_CURRENT_BINARY_DIR}/serializer) +add_executable(trie trie.cpp) +target_link_libraries(trie ${EXTRA_LIBS}) +add_test(trie ${CMAKE_CURRENT_BINARY_DIR}/trie) diff --git a/src/tests/util/trie.cpp b/src/tests/util/trie.cpp new file mode 100644 index 0000000000..be4a90e921 --- /dev/null +++ b/src/tests/util/trie.cpp @@ -0,0 +1,38 @@ +/* +Copyright (c) 2014 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#include "util/test.h" +#include "util/trie.h" +using namespace lean; + +static void tst1() { + ctrie t; + lean_assert(!find(t, "hello")); + t = insert(t, "hello", 3); + lean_assert(*find(t, "hello") == 3); + lean_assert(!find(t, "hell")); + lean_assert(!find(t, "hellow")); + t = insert(t, "hallo", 2); + t = insert(t, "hell", 5); + lean_assert(*find(t, "hallo") == 2); + lean_assert(*find(t, "hell") == 5); + lean_assert(*find(t, "hello") == 3); + lean_assert(!find(t, "hel")); + ctrie t2 = t; + t2 = insert(t2, "abc", 10); + t2 = insert(t2, "abd", 11); + t2 = insert(t2, "help", 12); + lean_assert(*find(t2, "abd") == 11); + lean_assert(!find(t, "abd")); + ctrie t3 = *t2.find('a'); + lean_assert(*find(t3, "bc") == 10); + lean_assert(*find(t3, "bd") == 11); +} + +int main() { + tst1(); + return has_violations() ? 1 : 0; +} diff --git a/src/util/trie.h b/src/util/trie.h new file mode 100644 index 0000000000..006dab3ecc --- /dev/null +++ b/src/util/trie.h @@ -0,0 +1,129 @@ +/* +Copyright (c) 2014 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#pragma once +#include +#include +#include "util/rb_map.h" +#include "util/optional.h" + +namespace lean { +template +class trie : public KeyCMP { + struct cell; + struct node { + cell * m_ptr; + node():m_ptr(new cell()) { m_ptr->inc_ref(); } + node(cell * ptr):m_ptr(ptr) { if (m_ptr) ptr->inc_ref(); } + node(node const & s):m_ptr(s.m_ptr) { if (m_ptr) m_ptr->inc_ref(); } + node(node && s):m_ptr(s.m_ptr) { s.m_ptr = nullptr; } + ~node() { if (m_ptr) m_ptr->dec_ref(); } + node & operator=(node const & n) { LEAN_COPY_REF(n); } + node & operator=(node&& n) { LEAN_MOVE_REF(n); } + cell * operator->() const { lean_assert(m_ptr); return m_ptr; } + bool is_shared() const { return m_ptr && m_ptr->get_rc() > 1; } + friend void swap(node & n1, node & n2) { std::swap(n1.m_ptr, n2.m_ptr); } + node steal() { node r; swap(r, *this); return r; } + }; + + struct cell { + rb_map m_children; + optional m_value; + MK_LEAN_RC(); + void dealloc() { delete this; } + cell():m_rc(0) {} + cell(Val const & v):m_value(v), m_rc(0) {} + cell(cell const & s):m_children(s.m_children), m_value(s.m_value), m_rc(0) {} + }; + + static node ensure_unshared(node && n) { + if (n.is_shared()) + return node(new cell(*n.m_ptr)); + else + return n; + } + + template + static node insert(node && n, It const & begin, It const & end, Val const & v) { + node h = ensure_unshared(n.steal()); + if (begin == end) { + h->m_value = v; + return h; + } else { + Key k = *begin; + node const * c = h->m_children.find(k); + It it(begin); it++; + if (c == nullptr) { + h->m_children.insert(k, insert(node(), it, end, v)); + } else { + node n(*c); + h->m_children.erase(k); + h->m_children.insert(k, insert(n.steal(), it, end, v)); + } + return h; + } + } + + node m_root; + trie(node const & n):m_root(n) {} +public: + trie() {} + trie(trie const & s):m_root(s.m_root) {} + trie(trie && s):m_root(s.m_root) {} + + trie & operator=(trie const & s) { m_root = s.m_root; return *this; } + trie & operator=(trie && s) { m_root = s.m_root; return *this; } + + template + optional find(It const & begin, It const & end) const { + node const * n = &m_root; + for (It it = begin; it != end; ++it) { + n = (*n)->m_children.find(*it); + if (!n) + return optional(); + } + return (*n)->m_value; + } + + template + void insert(It const & begin, It const & end, Val const & v) { + m_root = insert(m_root.steal(), begin, end, v); + } + + optional find(Key const & k) const { + node const * c = m_root->m_children.find(k); + if (c) + return optional(trie(*c)); + else + return optional(); + } +}; + +struct char_cmp { int operator()(char c1, char c2) const { return c1 < c2 ? -1 : (c1 == c2 ? 0 : 1); } }; + +template +using ctrie = trie; + +template +inline ctrie insert(ctrie const & t, std::string const & str, Val const & v) { + ctrie r(t); + r.insert(str.begin(), str.end(), v); + return r; +} + +template +inline ctrie insert(ctrie const & t, char const * str, Val const & v) { + ctrie r(t); + r.insert(str, str+strlen(str), v); + return r; +} + +template +optional find(ctrie const & t, std::string const & str) { return t.find(str.begin(), str.end()); } + +template +optional find(ctrie const & t, char const * str) { return t.find(str, str + strlen(str)); } +}