From fa70930ef483fa2bcf09bf3e9b8f24e9c8b99618 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 11 May 2015 16:19:51 -0700 Subject: [PATCH] feat(library/blast): add union-find datastructure --- src/CMakeLists.txt | 1 + src/library/blast/union_find.h | 232 +++++++++++++++++++++++++ src/tests/library/blast/CMakeLists.txt | 3 + src/tests/library/blast/union_find.cpp | 60 +++++++ 4 files changed, 296 insertions(+) create mode 100644 src/library/blast/union_find.h create mode 100644 src/tests/library/blast/CMakeLists.txt create mode 100644 src/tests/library/blast/union_find.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2ed3a04d2c..2c09cf16b5 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -340,6 +340,7 @@ add_subdirectory(tests/util/numerics) add_subdirectory(tests/util/interval) add_subdirectory(tests/kernel) add_subdirectory(tests/library) +add_subdirectory(tests/library/blast) add_subdirectory(tests/frontends/lean) # Include style check diff --git a/src/library/blast/union_find.h b/src/library/blast/union_find.h new file mode 100644 index 0000000000..a63ce6aab9 --- /dev/null +++ b/src/library/blast/union_find.h @@ -0,0 +1,232 @@ +/* +Copyright (c) 2015 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#pragma once +#include "util/rb_map.h" +#include "util/optional.h" + +namespace lean { +/** \brief (template for) Union-find datastructure that "explains" implied equalities. + We use functional datastructures to be able to have a O(1) copy operation. + + Each join/union is decorated with a justification. + + \c cmp implements a total order on \c node. That is, it provides the method: + int operator()(node const & n1, node const & n2) const + s.t. the result is negative when n1 < n2, 0 if n1 == n2, and positive if n1 > n2. + + The implementation also provides a method to traverse the elements of an equivalence + class. 
The implementation is based on a datastructure used in the Simplify theorem prover.
+
+    Since it provides extra functionality, it does not implement the classical algorithm with O(alpha(n))
+    amortized time per operation: join visits every node of the class that is merged into the other one.
+*/
+template<typename node, typename jst, typename cmp> // NOTE(review): parameter list reconstructed from usage; confirm order
+class union_find : private cmp {
+    rb_map<node, node, cmp>            m_root; // node -> root of its class; no entry: the node is its own root
+    rb_map<node, node, cmp>            m_next; // node -> successor in the circular list of its class; no entry: itself
+    rb_map<node, unsigned, cmp>        m_rank; // node -> rank; no entry: rank 0
+    rb_map<node, pair<node, jst>, cmp> m_jst;  // node -> (target, justification) edge of the path towards the root
+
+    bool is_equal(node const & n1, node const & n2) const { // equality as decided by the comparator base class
+        return cmp::operator()(n1, n2) == 0;
+    }
+
+    unsigned rank(node const & n) const { // rank recorded for n, 0 when there is no entry
+        if (auto r = m_rank.find(n))
+            return *r;
+        else
+            return 0;
+    }
+    void set_rank(node const & n, unsigned r) { m_rank.insert(n, r); }
+
+    node const & root(node const & n) const { // root recorded for n, n itself when there is no entry
+        if (auto r = m_root.find(n))
+            return *r;
+        else
+            return n;
+    }
+    void set_root(node const & n, node const & r) { m_root.insert(n, r); }
+
+    node const & next(node const & n) const { // successor recorded for n, n itself when there is no entry
+        if (auto r = m_next.find(n))
+            return *r;
+        else
+            return n;
+    }
+    void set_next(node const & n, node const & nx) { m_next.insert(n, nx); }
+    void set_justification(node const & n, node const & t, jst const & j) { m_jst.insert(n, mk_pair(t, j)); }
+
+    // Debugging aid: asserts that all members of n's class share one root, reach it via m_jst, and agree on the class size
+    bool check_inv(node const & n) const {
+        node r = root(n);
+        unsigned sz = size(r);
+        node it = n;
+        do {
+            lean_assert_eq(root(it), r);
+            lean_assert(reaches(it, r));
+            lean_assert_eq(size(it), sz); // was lean_assert(size(it), sz), which only checked that the size is non-zero
+            it = next(it);
+        } while (!is_equal(it, n));
+        return true;
+    }
+
+    void join_core(node const & n1, node r1, node const & n2, node r2, jst const & j) { // r1/r2 by value: m_root changes below
+        // Merge the class of n2 (root r2) into the class of n1 (root r1); r1 will be the root of the resulting class.
+        DEBUG_CODE(unsigned sz1 = size(n1); unsigned sz2 = size(n2););
+        // Step 1) update m_jst
+        //
+        // Given justification paths
+        //     n1 -> ... -> r1
+        //     n2 -> ... -> r2
+        // we generate the path
+        //     r2 -> ... -> n2 -> n1 -> ... -> r1
+        buffer<pair<node, jst>> trace; // edges of the old path n2 -> ... -> r2, in traversal order
+        node it2 = n2;
+        while (pair<node, jst> const * p = m_jst.find(it2)) {
+            trace.push_back(*p);
+            it2 = p->first;
+        }
+        lean_assert(is_equal(it2, r2));
+        unsigned i = trace.size();
+        while (i > 1) { // reverse every edge but the first one: trace[i].first now points back to trace[i-1].first
+            --i;
+            set_justification(trace[i].first, trace[i-1].first, trace[i].second);
+        }
+        if (i > 0) { // the first edge left n2, so its reversal points back to n2
+            set_justification(trace[0].first, n2, trace[0].second);
+        }
+        set_justification(n2, n1, j);
+
+        // Step 2) update m_root of nodes in n2 equivalence class to r1
+        it2 = n2;
+        do {
+            set_root(it2, r1);
+            it2 = next(it2);
+        } while (!is_equal(it2, n2));
+
+        // Step 3) update m_next of r1 and r2 (swapping the successors splices the two circular lists into one)
+        node next1 = next(r1);
+        node next2 = next(r2);
+        set_next(r1, next2);
+        set_next(r2, next1);
+
+        lean_assert(check_inv(r1));
+        lean_assert_eq(size(n1), sz1 + sz2);
+    }
+
+    /** \brief Return true if \c s reaches \c r by following m_jst edges */
+    bool reaches(node const & s, node const & r) const {
+        node it = s;
+        while (true) {
+            if (is_equal(it, r))
+                return true;
+            pair<node, jst> const * p = m_jst.find(it);
+            if (p) {
+                it = p->first;
+            } else {
+                return false;
+            }
+        }
+    }
+
+    void explain_core(node const & n1, node const & n2, node const & r, buffer<jst> & js) const { // appends the path n1 ... n2
+        lean_assert(is_equal(root(n1), r));
+        lean_assert(is_equal(root(n2), r));
+        node it1 = n1;
+        while (true) {
+            if (reaches(n2, it1)) {
+                // it1 is the first node on the path n1 -> r that also lies on the path n2 -> r
+                node it2 = n2;
+                unsigned sz1 = js.size();
+                while (true) {
+                    if (is_equal(it2, it1)) {
+                        std::reverse(js.begin() + sz1, js.end()); // the n2 side was collected backwards
+                        return;
+                    }
+                    pair<node, jst> const * p = m_jst.find(it2);
+                    lean_assert(p);
+                    js.push_back(p->second);
+                    it2 = p->first;
+                }
+            } else {
+                pair<node, jst> const * p = m_jst.find(it1);
+                lean_assert(p);
+                js.push_back(p->second);
+                it1 = p->first;
+            }
+        }
+    }
+
+public:
+    union_find(cmp const & c = cmp()):cmp(c) {}
+
+    /** \brief Merge the classes of \c n1 and \c n2 using justification \c j (no-op if they are already equal). */
+    void join(node const & n1, node const & n2, jst const & j) {
+        node r1 = root(n1); // copies: r1 is still needed after join_core has inserted into m_root
+        node r2 = root(n2);
+        if (is_equal(r1, r2))
+            return;
+        unsigned k1 = rank(r1); // ranks are kept on the roots (the original read them from n1/n2)
+        unsigned k2 = rank(r2);
+        if (k1 > k2) {
+            join_core(n1, r1, n2, r2, j);
+        } else if (k1 == k2) {
+            join_core(n1, r1, n2, r2, j);
+            set_rank(r1, k1+1);
+        } else {
+            join_core(n2, r2, n1, r1, j);
+        }
+    }
+
+    /** \brief Return the size of the equivalence class containing \c n */
+    unsigned size(node const & n) const {
+        unsigned r = 0;
+        node it = n;
+        do {
+            lean_assert(is_eq(it, n));
+            r++;
+            it = next(it);
+        } while (!is_equal(it, n));
+        return r;
+    }
+
+    /** \brief Return the representative for the equivalence class containing \c n. */
+    node rep(node const & n) const { return root(n); }
+
+    /** \brief Return true iff \c n1 and \c n2 are in the same equivalence class. */
+    bool is_eq(node const & n1, node const & n2) const { return is_equal(rep(n1), rep(n2)); }
+
+    /** \brief For each node \c m in the equivalence class of \c n, execute f(m) */
+    template<typename F>
+    void for_each(node const & n, F f) const {
+        node it = n;
+        do {
+            lean_assert(is_eq(it, n));
+            f(it);
+            it = next(it);
+        } while (!is_equal(it, n));
+    }
+
+    /** \brief If is_eq(n1, n2), then return true and append to \c js the justifications that can be used to produce
+        a transitivity+symmetry proof for n1 = n2. Otherwise, return false and leave \c js untouched. */
+    bool explain(node const & n1, node const & n2, buffer<jst> & js) const {
+        node r1 = root(n1);
+        node r2 = root(n2);
+        if (is_equal(r1, r2)) {
+            if (rank(r1) >= rank(r2)) { // NOTE(review): r1 and r2 compare equal here, so the else branch looks unreachable
+                explain_core(n1, n2, r1, js);
+            } else {
+                explain_core(n2, n1, r1, js);
+                std::reverse(js.begin(), js.end());
+            }
+            return true;
+        } else {
+            return false;
+        }
+    }
+};
+}
diff --git a/src/tests/library/blast/CMakeLists.txt b/src/tests/library/blast/CMakeLists.txt
new file mode 100644
index 0000000000..973c7d3505
--- /dev/null
+++ b/src/tests/library/blast/CMakeLists.txt
@@ -0,0 +1,3 @@
+add_executable(union_find union_find.cpp)
+target_link_libraries(union_find "util" ${EXTRA_LIBS})
+add_test(union_find "${CMAKE_CURRENT_BINARY_DIR}/union_find")
diff --git a/src/tests/library/blast/union_find.cpp b/src/tests/library/blast/union_find.cpp
new file mode 100644
index 0000000000..fb8cffeb49
--- /dev/null
+++ b/src/tests/library/blast/union_find.cpp
@@ -0,0 +1,60 @@
+/*
+Copyright (c) 2015 Microsoft Corporation. All rights reserved.
+Released under Apache 2.0 license as described in the file LICENSE.
+
+Author: Leonardo de Moura
+*/
+#include "util/test.h"
+#include "util/init_module.h"
+#include "library/blast/union_find.h"
+using namespace lean;
+
+typedef union_find<unsigned, unsigned, unsigned_cmp> uf; // NOTE(review): arguments reconstructed; unsigned_cmp is assumed -- confirm
+
+static void check_explain(uf const & m, unsigned n1, unsigned n2, std::initializer_list<unsigned> const & expected_js) {
+    buffer<unsigned> js1; // justifications reported by explain(n1, n2)
+    bool r = m.explain(n1, n2, js1);
+    lean_assert(r);
+    lean_assert(m.rep(n1) == m.rep(n2));
+    std::sort(js1.begin(), js1.end()); // both sides are sorted: only the set of justifications is compared, not the order
+    buffer<unsigned> js2;
+    js2.append(expected_js.size(), expected_js.begin());
+    std::sort(js2.begin(), js2.end());
+    lean_assert(js1.size() == js2.size());
+    for (unsigned i = 0; i < js1.size(); i++) {
+        lean_assert(js1[i] == js2[i]);
+    }
+}
+
+static void tst1() { // each join is tagged with a distinct number so that explanations can be checked by value
+    uf m;
+    m.join(1, 2, 0);
+    lean_assert(m.is_eq(1, 1));
+    lean_assert(m.is_eq(1, 2));
+    m.join(1, 3, 1);
+    lean_assert(m.is_eq(2, 3));
+    check_explain(m, 2, 3, {0, 1});
+    check_explain(m, 2, 1, {0});
+    check_explain(m, 1, 3, {1});
+    m.join(3, 4, 2);
+    m.join(5, 1, 3);
+    m.join(6, 2, 4);
+    lean_assert(m.rep(6) == m.rep(4));
+    check_explain(m, 2, 3, {0, 1});
+    check_explain(m, 6, 4, {0, 1, 2, 4});
+    check_explain(m, 5, 6, {0, 3, 4});
+    lean_assert_eq(m.size(1), 6);
+
+    for (unsigned i = 10; i < 30; i++) // chain 10 = 11 = ... = 30, where the join of i and i+1 is justified by i
+        m.join(i, i+1, i);
+    check_explain(m, 10, 15, {10, 11, 12, 13, 14});
+    lean_assert_eq(m.size(10), 21);
+}
+
+int main() {
+    save_stack_info();
+    initialize_util_module();
+    tst1();
+    finalize_util_module();
+    return has_violations() ? 1 : 0;
+}