From 82fb38b440dde9c8bc5f5692a7821f2004e37381 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Tue, 1 Mar 2016 15:40:22 -0800 Subject: [PATCH] feat(util/rb_tree): add for_each_greater --- src/tests/util/rb_tree.cpp | 17 ++++++++++++++++- src/util/rb_map.h | 6 ++++++ src/util/rb_tree.h | 24 ++++++++++++++++++++++++ 3 files changed, 46 insertions(+), 1 deletion(-) diff --git a/src/tests/util/rb_tree.cpp b/src/tests/util/rb_tree.cpp index 2dd2f8dec5..e5924a6a15 100644 --- a/src/tests/util/rb_tree.cpp +++ b/src/tests/util/rb_tree.cpp @@ -235,6 +235,21 @@ static void tst6() { #endif } +static void tst7() { + int_rb_tree t; + for (int i = 0; i < 1000; i++) { + t.insert(i); + } + for (int i = 0; i < 1000; i++) { + int c = 0; + t.for_each_greater(i, [&](int v) { + lean_assert(v > i); + c++; + }); + lean_assert(c == 1000 - i - 1, c, i); + } +} + int main() { tst1(); tst2(); @@ -242,6 +257,6 @@ int main() { tst4(); tst5(); tst6(); + tst7(); return has_violations() ? 1 : 0; } - diff --git a/src/util/rb_map.h b/src/util/rb_map.h index 126a9a5248..0c26c38303 100644 --- a/src/util/rb_map.h +++ b/src/util/rb_map.h @@ -82,6 +82,12 @@ public: return optional(); } + template + void for_each_greater(K const & k, F && f) const { + auto f_prime = [&](entry const & e) { f(e.first, e.second); }; + m_map.for_each_greater(mk_pair(k, T()), f_prime); + } + /** \brief (For debugging) Display the content of this splay map. */ friend std::ostream & operator<<(std::ostream & out, rb_map const & m) { out << "{"; diff --git a/src/util/rb_tree.h b/src/util/rb_tree.h index 2ca759cb4f..59baba6995 100644 --- a/src/util/rb_tree.h +++ b/src/util/rb_tree.h @@ -248,6 +248,25 @@ class rb_tree : public CMP { return optional(); } + template + void for_each_greater(T const & v, F && f, node_cell const * n) const { + if (n) { + int c = cmp(v, n->m_value); + if (c == 0) { + for_each(f, n->m_right.m_ptr); + } else if (c > 0) { + // v > n->m_value + for_each_greater(v, f, n->m_right.m_ptr); + } else { + // v < n->m_value + lean_assert(c < 0); + for_each_greater(v, f, n->m_left.m_ptr); + f(n->m_value); + for_each(f, n->m_right.m_ptr); + } + } + } + static void display(std::ostream & out, node_cell const * n) { if (n) { out << "("; @@ -371,6 +390,11 @@ public: template optional find_if(F && f) const { return find_if(f, m_root.m_ptr); } + template + void for_each_greater(T const & v, F && f) const { + for_each_greater(v, f, m_root.m_ptr); + } + // For debugging purposes void display(std::ostream & out) const { display(out, m_root.m_ptr); }