diff --git a/src/library/type_context.cpp b/src/library/type_context.cpp index 774b91d2f8..e5637e1163 100644 --- a/src/library/type_context.cpp +++ b/src/library/type_context.cpp @@ -395,7 +395,6 @@ optional type_context::unfold_definition(expr const & e) { } } - optional type_context::reduce_projection(expr const & e) { expr const & f = get_app_fn(e); if (!is_constant(f)) @@ -536,10 +535,13 @@ expr type_context::whnf_core(expr const & e) { lean_unreachable(); } -expr type_context::whnf(expr const & e) { +template +expr type_context::whnf_loop(expr const & e, F const & pred) { expr t = e; while (true) { expr t1 = whnf_core(t); + if (!pred(t1)) + return t1; if (auto next_t = unfold_definition(t1)) { t = *next_t; } else { @@ -548,6 +550,14 @@ expr type_context::whnf(expr const & e) { } } +expr type_context::whnf(expr const & e) { + return whnf_loop(e, [](expr const &) { return true; }); +} + +expr type_context::whnf_pred(expr const & e, std::function const & pred) { + return whnf_loop(e, pred); +} + expr type_context::relaxed_whnf(expr const & e) { flet set(m_transparency_mode, transparency_mode::All); return whnf(e); diff --git a/src/library/type_context.h b/src/library/type_context.h index 4e33dc0c79..e58a784815 100644 --- a/src/library/type_context.h +++ b/src/library/type_context.h @@ -136,6 +136,7 @@ class type_context : public abstract_type_context { m_mctx(mctx), m_tmp_uassignment_sz(usz), m_tmp_eassignment_sz(esz), m_tmp_trail_sz(tsz) {} }; typedef buffer scopes; + template expr whnf_loop(expr const & e, F const & pred); metavar_context & m_mctx; local_context m_lctx; @@ -191,6 +192,11 @@ public: virtual void pop_local() override; virtual expr abstract_locals(expr const & e, unsigned num_locals, expr const * locals) override; + /** Similar to whnf, but invoked the given predicate before unfolding constants in the head. + If pred(e') is false, then the method will not unfold definition in the head of e', and will return e'. + This method is useful when we want to normalize the expression until we get a particular symbol as the head symbol. */ + expr whnf_pred(expr const & e, std::function const & pred); + /** Given a metavariable \c mvar, and local constants in \c locals, return (mvar' C) where C is a superset of \c locals and includes all local constants that depend on \c locals. \pre all local constants in \c locals are in metavariable context. */