feat(library/aux_definition): add closure_helper

This commit is contained in:
Leonardo de Moura 2017-10-27 11:18:50 -07:00
parent 7999200676
commit 7b683427da
2 changed files with 219 additions and 124 deletions

View file

@ -16,127 +16,145 @@ Author: Leonardo de Moura
#include "library/replace_visitor_with_tc.h"
namespace lean {
struct mk_aux_definition_fn {
type_context & m_ctx;
name m_prefix;
unsigned m_next_idx;
name_set m_found_univ_params;
name_map<level> m_univ_meta_to_param;
name_map<level> m_univ_meta_to_param_inv;
name_set m_found_local;
name_map<expr> m_meta_to_param;
name_map<expr> m_meta_to_param_inv;
buffer<name> m_level_params;
buffer<expr> m_params;
mk_aux_definition_fn(type_context & ctx):
m_ctx(ctx),
m_prefix("_aux_param"),
m_next_idx(0) {}
level collect(level const & l) {
return replace(l, [&](level const & l) {
if (is_meta(l)) {
name const & id = meta_id(l);
if (auto r = m_univ_meta_to_param.find(id)) {
return some_level(*r);
} else {
name n = m_prefix.append_after(m_next_idx);
m_next_idx++;
level new_r = mk_param_univ(n);
m_univ_meta_to_param.insert(id, new_r);
m_univ_meta_to_param_inv.insert(n, l);
m_level_params.push_back(n);
return some_level(new_r);
}
} else if (is_param(l)) {
lean_assert(!is_placeholder(l));
name const & id = param_id(l);
if (!m_found_univ_params.contains(id)) {
m_found_univ_params.insert(id);
m_level_params.push_back(id);
}
level closure_helper::collect(level const & l) {
lean_assert(!m_finalized_collection);
return replace(l, [&](level const & l) {
if (is_meta(l)) {
name const & id = meta_id(l);
if (auto r = m_univ_meta_to_param.find(id)) {
return some_level(*r);
} else {
name n = m_prefix.append_after(m_next_idx);
m_next_idx++;
level new_r = mk_param_univ(n);
m_univ_meta_to_param.insert(id, new_r);
m_univ_meta_to_param_inv.insert(n, l);
m_level_params.push_back(n);
return some_level(new_r);
}
return none_level();
});
}
} else if (is_param(l)) {
lean_assert(!is_placeholder(l));
name const & id = param_id(l);
if (!m_found_univ_params.contains(id)) {
m_found_univ_params.insert(id);
m_level_params.push_back(id);
}
}
return none_level();
});
}
levels collect(levels const & ls) {
bool modified = false;
buffer<level> r;
for (level const & l : ls) {
level new_l = collect(l);
if (new_l != l) modified = true;
r.push_back(new_l);
}
if (!modified)
return ls;
levels closure_helper::collect(levels const & ls) {
lean_assert(!m_finalized_collection);
bool modified = false;
buffer<level> r;
for (level const & l : ls) {
level new_l = collect(l);
if (new_l != l) modified = true;
r.push_back(new_l);
}
if (!modified)
return ls;
else
return to_list(r);
}
expr closure_helper::collect(expr const & e) {
lean_assert(!m_finalized_collection);
return replace(e, [&](expr const & e, unsigned) {
if (is_metavar(e)) {
name const & id = mlocal_name(e);
if (auto r = m_meta_to_param.find(id)) {
return some_expr(*r);
} else {
expr type = m_ctx.infer(e);
expr x = m_ctx.push_local("_x", type);
m_meta_to_param.insert(id, x);
m_meta_to_param_inv.insert(mlocal_name(x), e);
m_params.push_back(x);
return some_expr(x);
}
} else if (is_local(e)) {
name const & id = mlocal_name(e);
if (!m_found_local.contains(id)) {
m_found_local.insert(id);
m_params.push_back(e);
}
} else if (is_sort(e)) {
return some_expr(update_sort(e, collect(sort_level(e))));
} else if (is_constant(e)) {
return some_expr(update_constant(e, collect(const_levels(e))));
}
return none_expr();
});
}
void closure_helper::finalize_collection() {
lean_assert(!m_finalized_collection);
name_map<expr> new_types;
for (unsigned i = 0; i < m_params.size(); i++) {
expr x = m_params[i];
expr new_type = collect(zeta_expand(m_ctx.lctx(), m_ctx.instantiate_mvars(m_ctx.infer(x))));
new_types.insert(mlocal_name(x), new_type);
}
local_context const & lctx = m_ctx.lctx();
std::sort(m_params.begin(), m_params.end(), [&](expr const & l1, expr const & l2) {
return lctx.get_local_decl(l1).get_idx() < lctx.get_local_decl(l2).get_idx();
});
for (unsigned i = 0; i < m_params.size(); i++) {
expr x = m_params[i];
expr type = *new_types.find(mlocal_name(x));
expr new_type = replace_locals(type, i, m_params.data(), m_norm_params.data());
expr new_param = m_ctx.push_local(mlocal_pp_name(x), new_type, local_info(x));
m_norm_params.push_back(new_param);
}
m_finalized_collection = true;
}
expr closure_helper::mk_pi_closure(expr const & e) {
lean_assert(m_finalized_collection);
expr new_e = replace_locals(e, m_params, m_norm_params);
return m_ctx.mk_pi(m_norm_params, new_e);
}
expr closure_helper::mk_lambda_closure(expr const & e) {
lean_assert(m_finalized_collection);
expr new_e = replace_locals(e, m_params, m_norm_params);
return m_ctx.mk_lambda(m_norm_params, new_e);
}
void closure_helper::get_level_closure(buffer<level> & ls) {
lean_assert(m_finalized_collection);
for (name const & n : m_level_params) {
if (level const * l = m_univ_meta_to_param_inv.find(n))
ls.push_back(*l);
else
return to_list(r);
ls.push_back(mk_param_univ(n));
}
}
expr collect(expr const & e) {
return replace(e, [&](expr const & e, unsigned) {
if (is_metavar(e)) {
name const & id = mlocal_name(e);
if (auto r = m_meta_to_param.find(id)) {
return some_expr(*r);
} else {
expr type = m_ctx.infer(e);
expr x = m_ctx.push_local("_x", type);
m_meta_to_param.insert(id, x);
m_meta_to_param_inv.insert(mlocal_name(x), e);
m_params.push_back(x);
return some_expr(x);
}
} else if (is_local(e)) {
name const & id = mlocal_name(e);
if (!m_found_local.contains(id)) {
m_found_local.insert(id);
m_params.push_back(e);
}
} else if (is_sort(e)) {
return some_expr(update_sort(e, collect(sort_level(e))));
} else if (is_constant(e)) {
return some_expr(update_constant(e, collect(const_levels(e))));
}
return none_expr();
});
void closure_helper::get_expr_closure(buffer<expr> & ps) {
lean_assert(m_finalized_collection);
for (expr const & x : m_params) {
if (expr const * m = m_meta_to_param_inv.find(mlocal_name(x)))
ps.push_back(*m);
else
ps.push_back(x);
}
}
/* Collect (and sort) dependencies of collected parameters */
void collect_and_normalize_dependencies(buffer<expr> & norm_params) {
name_map<expr> new_types;
for (unsigned i = 0; i < m_params.size(); i++) {
expr x = m_params[i];
expr new_type = collect(zeta_expand(m_ctx.lctx(), m_ctx.instantiate_mvars(m_ctx.infer(x))));
new_types.insert(mlocal_name(x), new_type);
}
local_context const & lctx = m_ctx.lctx();
std::sort(m_params.begin(), m_params.end(), [&](expr const & l1, expr const & l2) {
return lctx.get_local_decl(l1).get_idx() < lctx.get_local_decl(l2).get_idx();
});
for (unsigned i = 0; i < m_params.size(); i++) {
expr x = m_params[i];
expr type = *new_types.find(mlocal_name(x));
expr new_type = replace_locals(type, i, m_params.data(), norm_params.data());
expr new_param = m_ctx.push_local(mlocal_pp_name(x), new_type, local_info(x));
norm_params.push_back(new_param);
}
}
struct mk_aux_definition_fn : public closure_helper {
mk_aux_definition_fn(type_context & ctx):closure_helper(ctx) {}
pair<environment, expr> operator()(name const & c, expr const & type, expr const & value, bool is_lemma, optional<bool> const & is_meta) {
lean_assert(!is_lemma || is_meta);
lean_assert(!is_lemma || *is_meta == false);
expr new_type = collect(m_ctx.instantiate_mvars(type));
expr new_value = collect(m_ctx.instantiate_mvars(value));
environment env = m_ctx.env();
buffer<expr> norm_params;
collect_and_normalize_dependencies(norm_params);
new_type = replace_locals(new_type, m_params, norm_params);
new_value = replace_locals(new_value, m_params, norm_params);
expr def_type = m_ctx.mk_pi(norm_params, new_type);
expr def_value = m_ctx.mk_lambda(norm_params, new_value);
expr new_type = collect(ctx().instantiate_mvars(type));
expr new_value = collect(ctx().instantiate_mvars(value));
environment env = ctx().env();
finalize_collection();
expr def_type = mk_pi_closure(new_type);
expr def_value = mk_lambda_closure(new_value);
bool untrusted = false;
if (is_meta)
untrusted = *is_meta;
@ -148,26 +166,16 @@ struct mk_aux_definition_fn {
}
declaration d;
if (is_lemma) {
d = mk_theorem(c, to_list(m_level_params), def_type, def_value);
d = mk_theorem(c, get_norm_level_names(), def_type, def_value);
} else {
bool use_self_opt = true;
d = mk_definition(env, c, to_list(m_level_params), def_type, def_value, use_self_opt, !untrusted);
d = mk_definition(env, c, get_norm_level_names(), def_type, def_value, use_self_opt, !untrusted);
}
environment new_env = module::add(env, check(env, d, true));
buffer<level> ls;
for (name const & n : m_level_params) {
if (level const * l = m_univ_meta_to_param_inv.find(n))
ls.push_back(*l);
else
ls.push_back(mk_param_univ(n));
}
get_level_closure(ls);
buffer<expr> ps;
for (expr const & x : m_params) {
if (expr const * m = m_meta_to_param_inv.find(mlocal_name(x)))
ps.push_back(*m);
else
ps.push_back(x);
}
get_expr_closure(ps);
expr r = mk_app(mk_constant(c, to_list(ls)), ps);
return mk_pair(new_env, r);
}

View file

@ -7,6 +7,93 @@ Author: Leonardo de Moura
#pragma once
#include "library/type_context.h"
namespace lean {
/** Helper class for creating closures for nested terms.
There are two phases:
1- Parameter and metavariable collection.
2- Closure creation.
The methods \c collect are used in the first phase.
The method \c finalize_collection moves object to the second phase.
The collection phase collects parameters and metavariables.
A new parameter is created for each metavariable found in a subterm.
Parameter and metavariables occurring in the types of collected
parameters and metavariables are also collected.
The method \c finalize_collection moves the state to phase 2.
It also creates a new (normalized) parameter for each collected parameter and metavariable.
The type of the new parameter is the type of the source after normalization.
All new parameters are sorted based on dependencies.
*/
class closure_helper {
type_context & m_ctx;
name m_prefix;
unsigned m_next_idx;
name_set m_found_univ_params;
name_map<level> m_univ_meta_to_param;
name_map<level> m_univ_meta_to_param_inv;
name_set m_found_local;
name_map<expr> m_meta_to_param;
name_map<expr> m_meta_to_param_inv;
buffer<name> m_level_params;
buffer<expr> m_params;
bool m_finalized_collection{false};
buffer<expr> m_norm_params;
public:
closure_helper(type_context & ctx):
m_ctx(ctx),
m_prefix("_aux_param"),
m_next_idx(0) {}
type_context & ctx() { return m_ctx; }
/* \pre finalize_collection has not been invoked */
level collect(level const & l);
/* \pre finalize_collection has not been invoked */
levels collect(levels const & ls);
/* \pre finalize_collection has not been invoked */
expr collect(expr const & e);
/* \pre finalize_collection has not been invoked */
void finalize_collection();
/* Replace parameters in \c e with corresponding normalized parameters, obtain e', and
then return (Pi N, e') where N are the normalized parameters.
Remark \c e must not contain meta-variables. We can ensure this constraint by using the
collect method
\pre finalize_collection has been invoked */
expr mk_pi_closure(expr const & e);
/* Replace parameters in \c e with corresponding normalized parameters, obtain e', and
then return (fun N, e') where N are the normalized parameters.
Remark \c e must not contain meta-variables. We can ensure this constraint by using the
collect method
\pre finalize_collection has been invoked */
expr mk_lambda_closure(expr const & e);
/* Return the level parameters and meta-variables collected by collect methods.
\pre finalize_collection has been invoked */
void get_level_closure(buffer<level> & ls);
/* Return the parameters and meta-variables collected by collect methods.
\pre finalize_collection has been invoked */
void get_expr_closure(buffer<expr> & ps);
/* Return the name of normalized parameters. That is, it includes the collected
level parameters and new parameters created for universe meta-variables.
\pre finalize_collection has been invoked */
list<name> get_norm_level_names() const { return to_list(m_level_params); }
};
/** \brief Create an auxiliary definition with name `c` where `type` and `value` may contain local constants and
meta-variables. This function collects all dependencies (universe parameters, universe metavariables,
local constants and metavariables). The result is the updated environment and an expression of the form