diff --git a/src/kernel/environment.cpp b/src/kernel/environment.cpp index 5d6e113ee9..ef6978d3ab 100644 --- a/src/kernel/environment.cpp +++ b/src/kernel/environment.cpp @@ -341,6 +341,26 @@ struct environment::imp { } } + unsigned get_num_objects(bool local) const { + if (local || !has_parent()) { + return m_objects.size(); + } else { + return m_objects.size() + m_parent->get_num_objects(false); + } + } + + object const & get_object(unsigned i, bool local) const { + if (local || !has_parent()) { + return *(m_objects[i]); + } else { + unsigned num_parent_objects = m_parent->get_num_objects(false); + if (i >= num_parent_objects) + return *(m_objects[i - num_parent_objects]); + else + return m_parent->get_object(i, false); + } + } + /** \brief Display universal variable constraints and objects stored in this environment and its parents. */ void display(std::ostream & out, environment const & env) const { if (has_parent()) @@ -468,12 +488,12 @@ named_object const * environment::get_object_ptr(name const & n) const { return m_imp->get_object_ptr(n); } -unsigned environment::get_num_objects() const { - return m_imp->m_objects.size(); +unsigned environment::get_num_objects(bool local) const { + return m_imp->get_num_objects(local); } -object const & environment::get_object(unsigned i) const { - return *(m_imp->m_objects[i]); +object const & environment::get_object(unsigned i, bool local) const { + return m_imp->get_object(i, local); } void environment::display(std::ostream & out) const { diff --git a/src/kernel/environment.h b/src/kernel/environment.h index c9d6e59e0e..d989ca1faa 100644 --- a/src/kernel/environment.h +++ b/src/kernel/environment.h @@ -24,8 +24,8 @@ private: void check_type(name const & n, expr const & t, expr const & v); explicit environment(std::shared_ptr const & ptr); explicit environment(imp * new_ptr); - unsigned get_num_objects() const; - object const & get_object(unsigned i) const; + unsigned get_num_objects(bool local) const; + object const & get_object(unsigned i, bool local) const; public: environment(); ~environment(); @@ -140,22 +140,49 @@ public: class object_iterator { environment const & m_env; unsigned m_idx; + bool m_local; friend class environment; - object_iterator(environment const & env, unsigned idx):m_env(env), m_idx(idx) {} + object_iterator(environment const & env, unsigned idx, bool local):m_env(env), m_idx(idx), m_local(local) {} public: - object_iterator(object_iterator const & s):m_env(s.m_env), m_idx(s.m_idx) {} + object_iterator(object_iterator const & s):m_env(s.m_env), m_idx(s.m_idx), m_local(s.m_local) {} object_iterator & operator++() { ++m_idx; return *this; } object_iterator operator++(int) { object_iterator tmp(*this); operator++(); return tmp; } bool operator==(object_iterator const & s) const { lean_assert(&m_env == &(s.m_env)); return m_idx == s.m_idx; } bool operator!=(object_iterator const & s) const { return !operator==(s); } - object const & operator*() { return m_env.get_object(m_idx); } + object const & operator*() { return m_env.get_object(m_idx, m_local); } }; - /** \brief Return an iterator to the beginning of the sequence of objects stored in this environment */ - object_iterator begin_objects() const { return object_iterator(*this, 0); } + /** + \brief Return an iterator to the beginning of the sequence of + objects stored in this environment. - /** \brief Return an iterator to the end of the sequence of objects stored in this environment */ - object_iterator end_objects() const { return object_iterator(*this, get_num_objects()); } + \remark The objects in this environment and ancestor + environments are considered + */ + object_iterator begin_objects() const { return object_iterator(*this, 0, false); } + + /** + \brief Return an iterator to the end of the sequence of + objects stored in this environment. + + \remark The objects in this environment and ancestor + environments are considered + */ + object_iterator end_objects() const { return object_iterator(*this, get_num_objects(false), false); } + + /** + \brief Return an iterator to the beginning of the sequence of + objects stored in this environment (objects in ancestor + environments are ingored). + */ + object_iterator begin_local_objects() const { return object_iterator(*this, 0, true); } + + /** + \brief Return an iterator to the end of the sequence of + objects stored in this environment (objects in ancestor + environments are ingored). + */ + object_iterator end_local_objects() const { return object_iterator(*this, get_num_objects(true), true); } // ======================================= /** \brief Display universal variable constraints and objects stored in this environment and its parents. */ diff --git a/src/tests/kernel/environment.cpp b/src/tests/kernel/environment.cpp index af8b3a649d..74b1799325 100644 --- a/src/tests/kernel/environment.cpp +++ b/src/tests/kernel/environment.cpp @@ -168,6 +168,30 @@ static void tst7() { std::cout << "Environment\n" << env; } +static void tst8() { + environment env; + std::cout << "=======================\n"; + env.add_var("a", Type()); + env.add_var("b", Type()); + environment env2 = env.mk_child(); + env2.add_var("c", Type()); + env2.add_var("d", Type()); + env2.add_var("e", Type()); + unsigned counter = 0; + std::for_each(env2.begin_local_objects(), env2.end_local_objects(), [&](object const & obj) { std::cout << obj.pp(env2) << "\n"; counter++; }); + lean_assert(counter == 3); + std::cout << "=======================\n"; + counter = 0; + std::for_each(env2.begin_objects(), env2.end_objects(), [&](object const & obj) { std::cout << obj.pp(env2) << "\n"; counter++; }); + lean_assert(counter == 5); + environment env3 = env2.mk_child(); + env3.add_var("f", Type() >> Type()); + std::cout << "=======================\n"; + counter = 0; + std::for_each(env3.begin_objects(), env3.end_objects(), [&](object const & obj) { std::cout << obj.pp(env3) << "\n"; counter++; }); + lean_assert(counter == 6); +} + int main() { enable_trace("is_convertible"); tst1(); @@ -177,5 +201,6 @@ int main() { tst5(); tst6(); tst7(); + tst8(); return has_violations() ? 1 : 0; }