diff --git a/src/library/type_context.cpp b/src/library/type_context.cpp index bf129ad90d..18fd9fbc38 100644 --- a/src/library/type_context.cpp +++ b/src/library/type_context.cpp @@ -187,6 +187,7 @@ void type_context::init_core(transparency_mode m) { /* default type class resolution mode */ m_cache->m_local_instances_initialized = false; } + m_unfold_pred = nullptr; } type_context::type_context(metavar_context & mctx, local_context const & lctx, type_context_cache & cache, @@ -398,7 +399,7 @@ optional type_context::unfold_definition(expr const & e) { if (auto f = unfold_definition_core(f0)) { buffer args; get_app_rev_args(e, args); - return some_expr(mk_rev_app(*f, args)); + return some_expr(apply_beta(*f, args.size(), args.data())); } else { return none_expr(); } @@ -407,6 +408,12 @@ optional type_context::unfold_definition(expr const & e) { } } +optional type_context::try_unfold_definition(expr const & e) { + if (m_unfold_pred && !(*m_unfold_pred)(e)) + return none_expr(); + return unfold_definition(e); +} + optional type_context::reduce_projection(expr const & e) { expr const & f = get_app_fn(e); if (!is_constant(f)) @@ -439,7 +446,7 @@ optional type_context::reduce_aux_recursor(expr const & e) { if (!is_constant(f)) return none_expr(); if (is_aux_recursor(env(), const_name(f))) - return unfold_definition(e); + return try_unfold_definition(e); else return none_expr(); } @@ -530,7 +537,7 @@ expr type_context::whnf_core(expr const & e) { return whnf_core(mk_rev_app(::lean::instantiate(binding_body(f), m, args.data() + (num_args - m)), num_args - m, args.data())); } else if (f == f0) { - if (auto r = env().norm_ext()(e, *this)) { + if (auto r = norm_ext(e)) { /* mainly iota-reduction, it also applies HIT and quotient reduction rules */ return whnf_core(*r); } else if (auto r = reduce_projection(e)) { @@ -547,13 +554,10 @@ expr type_context::whnf_core(expr const & e) { lean_unreachable(); } -template -expr type_context::whnf_loop(expr const & e, F const & pred) { +expr type_context::whnf(expr const & e) { expr t = e; while (true) { expr t1 = whnf_core(t); - if (!pred(t1)) - return t1; if (auto next_t = unfold_definition(t1)) { t = *next_t; } else { @@ -562,12 +566,17 @@ expr type_context::whnf_loop(expr const & e, F const & pred) { } } -expr type_context::whnf(expr const & e) { - return whnf_loop(e, [](expr const &) { return true; }); -} - expr type_context::whnf_pred(expr const & e, std::function const & pred) { - return whnf_loop(e, pred); + flet const *>set_unfold_pred(m_unfold_pred, &pred); + expr t = e; + while (true) { + expr t1 = whnf_core(t); + if (auto next_t = try_unfold_definition(t1)) { + t = *next_t; + } else { + return t1; + } + } } expr type_context::relaxed_whnf(expr const & e) { diff --git a/src/library/type_context.h b/src/library/type_context.h index df176d0133..8cd9a58307 100644 --- a/src/library/type_context.h +++ b/src/library/type_context.h @@ -136,7 +136,6 @@ class type_context : public abstract_type_context { m_mctx(mctx), m_tmp_uassignment_sz(usz), m_tmp_eassignment_sz(esz), m_tmp_trail_sz(tsz) {} }; typedef buffer scopes; - template expr whnf_loop(expr const & e, F const & pred); metavar_context & m_mctx; local_context m_lctx; @@ -166,6 +165,8 @@ class type_context : public abstract_type_context { /* Stack of backtracking point (aka scope) */ scopes m_scopes; + std::function const * m_unfold_pred; + public: type_context(metavar_context & mctx, local_context const & lctx, type_context_cache & cache, transparency_mode m = transparency_mode::Reducible); @@ -198,7 +199,9 @@ public: If pred(e') is false, then the method will not unfold definition in the head of e', and will return e'. This method is useful when we want to normalize the expression until we get a particular symbol as the head symbol. */ expr whnf_pred(expr const & e, std::function const & pred); + optional reduce_aux_recursor(expr const & e); optional reduce_projection(expr const & e); + optional norm_ext(expr const & e) { return env().norm_ext()(e, *this); } /** Given a metavariable \c mvar, and local constants in \c locals, return (mvar' C) where C is a superset of \c locals and includes all local constants that depend on \c locals. @@ -253,7 +256,7 @@ private: void init_core(transparency_mode m); optional unfold_definition_core(expr const & e); optional unfold_definition(expr const & e); - optional reduce_aux_recursor(expr const & e); + optional try_unfold_definition(expr const & e); bool should_unfold_macro(expr const & e); optional expand_macro(expr const & e); expr whnf_core(expr const & e);