fix(library/type_context): make sure aux_recursors are not unfolded in whnf_pred IF the given predicate returns false for them.

This commit is contained in:
Leonardo de Moura 2016-05-03 17:26:53 -07:00
parent 8c878e8196
commit fd83b711b6
2 changed files with 26 additions and 14 deletions

View file

@ -187,6 +187,7 @@ void type_context::init_core(transparency_mode m) {
/* default type class resolution mode */
m_cache->m_local_instances_initialized = false;
}
m_unfold_pred = nullptr;
}
type_context::type_context(metavar_context & mctx, local_context const & lctx, type_context_cache & cache,
@ -398,7 +399,7 @@ optional<expr> type_context::unfold_definition(expr const & e) {
if (auto f = unfold_definition_core(f0)) {
buffer<expr> args;
get_app_rev_args(e, args);
return some_expr(mk_rev_app(*f, args));
return some_expr(apply_beta(*f, args.size(), args.data()));
} else {
return none_expr();
}
@ -407,6 +408,12 @@ optional<expr> type_context::unfold_definition(expr const & e) {
}
}
optional<expr> type_context::try_unfold_definition(expr const & e) {
if (m_unfold_pred && !(*m_unfold_pred)(e))
return none_expr();
return unfold_definition(e);
}
optional<expr> type_context::reduce_projection(expr const & e) {
expr const & f = get_app_fn(e);
if (!is_constant(f))
@ -439,7 +446,7 @@ optional<expr> type_context::reduce_aux_recursor(expr const & e) {
if (!is_constant(f))
return none_expr();
if (is_aux_recursor(env(), const_name(f)))
return unfold_definition(e);
return try_unfold_definition(e);
else
return none_expr();
}
@ -530,7 +537,7 @@ expr type_context::whnf_core(expr const & e) {
return whnf_core(mk_rev_app(::lean::instantiate(binding_body(f), m, args.data() + (num_args - m)),
num_args - m, args.data()));
} else if (f == f0) {
if (auto r = env().norm_ext()(e, *this)) {
if (auto r = norm_ext(e)) {
/* mainly iota-reduction, it also applies HIT and quotient reduction rules */
return whnf_core(*r);
} else if (auto r = reduce_projection(e)) {
@ -547,13 +554,10 @@ expr type_context::whnf_core(expr const & e) {
lean_unreachable();
}
template<typename F>
expr type_context::whnf_loop(expr const & e, F const & pred) {
expr type_context::whnf(expr const & e) {
expr t = e;
while (true) {
expr t1 = whnf_core(t);
if (!pred(t1))
return t1;
if (auto next_t = unfold_definition(t1)) {
t = *next_t;
} else {
@ -562,12 +566,17 @@ expr type_context::whnf_loop(expr const & e, F const & pred) {
}
}
expr type_context::whnf(expr const & e) {
return whnf_loop(e, [](expr const &) { return true; });
}
expr type_context::whnf_pred(expr const & e, std::function<bool(expr const &)> const & pred) {
return whnf_loop(e, pred);
flet<std::function<bool(expr const &)> const *>set_unfold_pred(m_unfold_pred, &pred);
expr t = e;
while (true) {
expr t1 = whnf_core(t);
if (auto next_t = try_unfold_definition(t1)) {
t = *next_t;
} else {
return t1;
}
}
}
expr type_context::relaxed_whnf(expr const & e) {

View file

@ -136,7 +136,6 @@ class type_context : public abstract_type_context {
m_mctx(mctx), m_tmp_uassignment_sz(usz), m_tmp_eassignment_sz(esz), m_tmp_trail_sz(tsz) {}
};
typedef buffer<scope_data> scopes;
template<typename F> expr whnf_loop(expr const & e, F const & pred);
metavar_context & m_mctx;
local_context m_lctx;
@ -166,6 +165,8 @@ class type_context : public abstract_type_context {
/* Stack of backtracking point (aka scope) */
scopes m_scopes;
std::function<bool(expr const & e)> const * m_unfold_pred;
public:
type_context(metavar_context & mctx, local_context const & lctx, type_context_cache & cache,
transparency_mode m = transparency_mode::Reducible);
@ -198,7 +199,9 @@ public:
If pred(e') is false, then the method will not unfold definition in the head of e', and will return e'.
This method is useful when we want to normalize the expression until we get a particular symbol as the head symbol. */
expr whnf_pred(expr const & e, std::function<bool(expr const &)> const & pred);
optional<expr> reduce_aux_recursor(expr const & e);
optional<expr> reduce_projection(expr const & e);
optional<expr> norm_ext(expr const & e) { return env().norm_ext()(e, *this); }
/** Given a metavariable \c mvar, and local constants in \c locals, return (mvar' C) where
C is a superset of \c locals and includes all local constants that depend on \c locals.
@ -253,7 +256,7 @@ private:
void init_core(transparency_mode m);
optional<expr> unfold_definition_core(expr const & e);
optional<expr> unfold_definition(expr const & e);
optional<expr> reduce_aux_recursor(expr const & e);
optional<expr> try_unfold_definition(expr const & e);
bool should_unfold_macro(expr const & e);
optional<expr> expand_macro(expr const & e);
expr whnf_core(expr const & e);