feat(library/aux_definition): add closure_helper
This commit is contained in:
parent
7999200676
commit
7b683427da
2 changed files with 219 additions and 124 deletions
|
|
@ -16,127 +16,145 @@ Author: Leonardo de Moura
|
|||
#include "library/replace_visitor_with_tc.h"
|
||||
|
||||
namespace lean {
|
||||
struct mk_aux_definition_fn {
|
||||
type_context & m_ctx;
|
||||
name m_prefix;
|
||||
unsigned m_next_idx;
|
||||
name_set m_found_univ_params;
|
||||
name_map<level> m_univ_meta_to_param;
|
||||
name_map<level> m_univ_meta_to_param_inv;
|
||||
name_set m_found_local;
|
||||
name_map<expr> m_meta_to_param;
|
||||
name_map<expr> m_meta_to_param_inv;
|
||||
buffer<name> m_level_params;
|
||||
buffer<expr> m_params;
|
||||
|
||||
mk_aux_definition_fn(type_context & ctx):
|
||||
m_ctx(ctx),
|
||||
m_prefix("_aux_param"),
|
||||
m_next_idx(0) {}
|
||||
|
||||
level collect(level const & l) {
|
||||
return replace(l, [&](level const & l) {
|
||||
if (is_meta(l)) {
|
||||
name const & id = meta_id(l);
|
||||
if (auto r = m_univ_meta_to_param.find(id)) {
|
||||
return some_level(*r);
|
||||
} else {
|
||||
name n = m_prefix.append_after(m_next_idx);
|
||||
m_next_idx++;
|
||||
level new_r = mk_param_univ(n);
|
||||
m_univ_meta_to_param.insert(id, new_r);
|
||||
m_univ_meta_to_param_inv.insert(n, l);
|
||||
m_level_params.push_back(n);
|
||||
return some_level(new_r);
|
||||
}
|
||||
} else if (is_param(l)) {
|
||||
lean_assert(!is_placeholder(l));
|
||||
name const & id = param_id(l);
|
||||
if (!m_found_univ_params.contains(id)) {
|
||||
m_found_univ_params.insert(id);
|
||||
m_level_params.push_back(id);
|
||||
}
|
||||
level closure_helper::collect(level const & l) {
|
||||
lean_assert(!m_finalized_collection);
|
||||
return replace(l, [&](level const & l) {
|
||||
if (is_meta(l)) {
|
||||
name const & id = meta_id(l);
|
||||
if (auto r = m_univ_meta_to_param.find(id)) {
|
||||
return some_level(*r);
|
||||
} else {
|
||||
name n = m_prefix.append_after(m_next_idx);
|
||||
m_next_idx++;
|
||||
level new_r = mk_param_univ(n);
|
||||
m_univ_meta_to_param.insert(id, new_r);
|
||||
m_univ_meta_to_param_inv.insert(n, l);
|
||||
m_level_params.push_back(n);
|
||||
return some_level(new_r);
|
||||
}
|
||||
return none_level();
|
||||
});
|
||||
}
|
||||
} else if (is_param(l)) {
|
||||
lean_assert(!is_placeholder(l));
|
||||
name const & id = param_id(l);
|
||||
if (!m_found_univ_params.contains(id)) {
|
||||
m_found_univ_params.insert(id);
|
||||
m_level_params.push_back(id);
|
||||
}
|
||||
}
|
||||
return none_level();
|
||||
});
|
||||
}
|
||||
|
||||
levels collect(levels const & ls) {
|
||||
bool modified = false;
|
||||
buffer<level> r;
|
||||
for (level const & l : ls) {
|
||||
level new_l = collect(l);
|
||||
if (new_l != l) modified = true;
|
||||
r.push_back(new_l);
|
||||
}
|
||||
if (!modified)
|
||||
return ls;
|
||||
levels closure_helper::collect(levels const & ls) {
|
||||
lean_assert(!m_finalized_collection);
|
||||
bool modified = false;
|
||||
buffer<level> r;
|
||||
for (level const & l : ls) {
|
||||
level new_l = collect(l);
|
||||
if (new_l != l) modified = true;
|
||||
r.push_back(new_l);
|
||||
}
|
||||
if (!modified)
|
||||
return ls;
|
||||
else
|
||||
return to_list(r);
|
||||
}
|
||||
|
||||
expr closure_helper::collect(expr const & e) {
|
||||
lean_assert(!m_finalized_collection);
|
||||
return replace(e, [&](expr const & e, unsigned) {
|
||||
if (is_metavar(e)) {
|
||||
name const & id = mlocal_name(e);
|
||||
if (auto r = m_meta_to_param.find(id)) {
|
||||
return some_expr(*r);
|
||||
} else {
|
||||
expr type = m_ctx.infer(e);
|
||||
expr x = m_ctx.push_local("_x", type);
|
||||
m_meta_to_param.insert(id, x);
|
||||
m_meta_to_param_inv.insert(mlocal_name(x), e);
|
||||
m_params.push_back(x);
|
||||
return some_expr(x);
|
||||
}
|
||||
} else if (is_local(e)) {
|
||||
name const & id = mlocal_name(e);
|
||||
if (!m_found_local.contains(id)) {
|
||||
m_found_local.insert(id);
|
||||
m_params.push_back(e);
|
||||
}
|
||||
} else if (is_sort(e)) {
|
||||
return some_expr(update_sort(e, collect(sort_level(e))));
|
||||
} else if (is_constant(e)) {
|
||||
return some_expr(update_constant(e, collect(const_levels(e))));
|
||||
}
|
||||
return none_expr();
|
||||
});
|
||||
}
|
||||
|
||||
void closure_helper::finalize_collection() {
|
||||
lean_assert(!m_finalized_collection);
|
||||
name_map<expr> new_types;
|
||||
for (unsigned i = 0; i < m_params.size(); i++) {
|
||||
expr x = m_params[i];
|
||||
expr new_type = collect(zeta_expand(m_ctx.lctx(), m_ctx.instantiate_mvars(m_ctx.infer(x))));
|
||||
new_types.insert(mlocal_name(x), new_type);
|
||||
}
|
||||
local_context const & lctx = m_ctx.lctx();
|
||||
std::sort(m_params.begin(), m_params.end(), [&](expr const & l1, expr const & l2) {
|
||||
return lctx.get_local_decl(l1).get_idx() < lctx.get_local_decl(l2).get_idx();
|
||||
});
|
||||
for (unsigned i = 0; i < m_params.size(); i++) {
|
||||
expr x = m_params[i];
|
||||
expr type = *new_types.find(mlocal_name(x));
|
||||
expr new_type = replace_locals(type, i, m_params.data(), m_norm_params.data());
|
||||
expr new_param = m_ctx.push_local(mlocal_pp_name(x), new_type, local_info(x));
|
||||
m_norm_params.push_back(new_param);
|
||||
}
|
||||
m_finalized_collection = true;
|
||||
}
|
||||
|
||||
expr closure_helper::mk_pi_closure(expr const & e) {
|
||||
lean_assert(m_finalized_collection);
|
||||
expr new_e = replace_locals(e, m_params, m_norm_params);
|
||||
return m_ctx.mk_pi(m_norm_params, new_e);
|
||||
}
|
||||
|
||||
expr closure_helper::mk_lambda_closure(expr const & e) {
|
||||
lean_assert(m_finalized_collection);
|
||||
expr new_e = replace_locals(e, m_params, m_norm_params);
|
||||
return m_ctx.mk_lambda(m_norm_params, new_e);
|
||||
}
|
||||
|
||||
void closure_helper::get_level_closure(buffer<level> & ls) {
|
||||
lean_assert(m_finalized_collection);
|
||||
for (name const & n : m_level_params) {
|
||||
if (level const * l = m_univ_meta_to_param_inv.find(n))
|
||||
ls.push_back(*l);
|
||||
else
|
||||
return to_list(r);
|
||||
ls.push_back(mk_param_univ(n));
|
||||
}
|
||||
}
|
||||
|
||||
expr collect(expr const & e) {
|
||||
return replace(e, [&](expr const & e, unsigned) {
|
||||
if (is_metavar(e)) {
|
||||
name const & id = mlocal_name(e);
|
||||
if (auto r = m_meta_to_param.find(id)) {
|
||||
return some_expr(*r);
|
||||
} else {
|
||||
expr type = m_ctx.infer(e);
|
||||
expr x = m_ctx.push_local("_x", type);
|
||||
m_meta_to_param.insert(id, x);
|
||||
m_meta_to_param_inv.insert(mlocal_name(x), e);
|
||||
m_params.push_back(x);
|
||||
return some_expr(x);
|
||||
}
|
||||
} else if (is_local(e)) {
|
||||
name const & id = mlocal_name(e);
|
||||
if (!m_found_local.contains(id)) {
|
||||
m_found_local.insert(id);
|
||||
m_params.push_back(e);
|
||||
}
|
||||
} else if (is_sort(e)) {
|
||||
return some_expr(update_sort(e, collect(sort_level(e))));
|
||||
} else if (is_constant(e)) {
|
||||
return some_expr(update_constant(e, collect(const_levels(e))));
|
||||
}
|
||||
return none_expr();
|
||||
});
|
||||
void closure_helper::get_expr_closure(buffer<expr> & ps) {
|
||||
lean_assert(m_finalized_collection);
|
||||
for (expr const & x : m_params) {
|
||||
if (expr const * m = m_meta_to_param_inv.find(mlocal_name(x)))
|
||||
ps.push_back(*m);
|
||||
else
|
||||
ps.push_back(x);
|
||||
}
|
||||
}
|
||||
|
||||
/* Collect (and sort) dependencies of collected parameters */
|
||||
void collect_and_normalize_dependencies(buffer<expr> & norm_params) {
|
||||
name_map<expr> new_types;
|
||||
for (unsigned i = 0; i < m_params.size(); i++) {
|
||||
expr x = m_params[i];
|
||||
expr new_type = collect(zeta_expand(m_ctx.lctx(), m_ctx.instantiate_mvars(m_ctx.infer(x))));
|
||||
new_types.insert(mlocal_name(x), new_type);
|
||||
}
|
||||
local_context const & lctx = m_ctx.lctx();
|
||||
std::sort(m_params.begin(), m_params.end(), [&](expr const & l1, expr const & l2) {
|
||||
return lctx.get_local_decl(l1).get_idx() < lctx.get_local_decl(l2).get_idx();
|
||||
});
|
||||
for (unsigned i = 0; i < m_params.size(); i++) {
|
||||
expr x = m_params[i];
|
||||
expr type = *new_types.find(mlocal_name(x));
|
||||
expr new_type = replace_locals(type, i, m_params.data(), norm_params.data());
|
||||
expr new_param = m_ctx.push_local(mlocal_pp_name(x), new_type, local_info(x));
|
||||
norm_params.push_back(new_param);
|
||||
}
|
||||
}
|
||||
struct mk_aux_definition_fn : public closure_helper {
|
||||
mk_aux_definition_fn(type_context & ctx):closure_helper(ctx) {}
|
||||
|
||||
pair<environment, expr> operator()(name const & c, expr const & type, expr const & value, bool is_lemma, optional<bool> const & is_meta) {
|
||||
lean_assert(!is_lemma || is_meta);
|
||||
lean_assert(!is_lemma || *is_meta == false);
|
||||
expr new_type = collect(m_ctx.instantiate_mvars(type));
|
||||
expr new_value = collect(m_ctx.instantiate_mvars(value));
|
||||
environment env = m_ctx.env();
|
||||
buffer<expr> norm_params;
|
||||
collect_and_normalize_dependencies(norm_params);
|
||||
new_type = replace_locals(new_type, m_params, norm_params);
|
||||
new_value = replace_locals(new_value, m_params, norm_params);
|
||||
expr def_type = m_ctx.mk_pi(norm_params, new_type);
|
||||
expr def_value = m_ctx.mk_lambda(norm_params, new_value);
|
||||
expr new_type = collect(ctx().instantiate_mvars(type));
|
||||
expr new_value = collect(ctx().instantiate_mvars(value));
|
||||
environment env = ctx().env();
|
||||
finalize_collection();
|
||||
expr def_type = mk_pi_closure(new_type);
|
||||
expr def_value = mk_lambda_closure(new_value);
|
||||
bool untrusted = false;
|
||||
if (is_meta)
|
||||
untrusted = *is_meta;
|
||||
|
|
@ -148,26 +166,16 @@ struct mk_aux_definition_fn {
|
|||
}
|
||||
declaration d;
|
||||
if (is_lemma) {
|
||||
d = mk_theorem(c, to_list(m_level_params), def_type, def_value);
|
||||
d = mk_theorem(c, get_norm_level_names(), def_type, def_value);
|
||||
} else {
|
||||
bool use_self_opt = true;
|
||||
d = mk_definition(env, c, to_list(m_level_params), def_type, def_value, use_self_opt, !untrusted);
|
||||
d = mk_definition(env, c, get_norm_level_names(), def_type, def_value, use_self_opt, !untrusted);
|
||||
}
|
||||
environment new_env = module::add(env, check(env, d, true));
|
||||
buffer<level> ls;
|
||||
for (name const & n : m_level_params) {
|
||||
if (level const * l = m_univ_meta_to_param_inv.find(n))
|
||||
ls.push_back(*l);
|
||||
else
|
||||
ls.push_back(mk_param_univ(n));
|
||||
}
|
||||
get_level_closure(ls);
|
||||
buffer<expr> ps;
|
||||
for (expr const & x : m_params) {
|
||||
if (expr const * m = m_meta_to_param_inv.find(mlocal_name(x)))
|
||||
ps.push_back(*m);
|
||||
else
|
||||
ps.push_back(x);
|
||||
}
|
||||
get_expr_closure(ps);
|
||||
expr r = mk_app(mk_constant(c, to_list(ls)), ps);
|
||||
return mk_pair(new_env, r);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,6 +7,93 @@ Author: Leonardo de Moura
|
|||
#pragma once
|
||||
#include "library/type_context.h"
|
||||
namespace lean {
|
||||
/** Helper class for creating closures for nested terms.
|
||||
|
||||
There are two phases:
|
||||
1- Parameter and metavariable collection.
|
||||
2- Closure creation.
|
||||
|
||||
The methods \c collect are used in the first phase.
|
||||
The method \c finalize_collection moves object to the second phase.
|
||||
|
||||
The collection phase collects parameters and metavariables.
|
||||
A new parameter is created for each metavariable found in a subterm.
|
||||
|
||||
Parameter and metavariables occurring in the types of collected
|
||||
parameters and metavariables are also collected.
|
||||
|
||||
The method \c finalize_collection moves the state to phase 2.
|
||||
It also creates a new (normalized) parameter for each collected parameter and metavariable.
|
||||
The type of the new parameter is the type of the source after normalization.
|
||||
All new parameters are sorted based on dependencies.
|
||||
*/
|
||||
class closure_helper {
|
||||
type_context & m_ctx;
|
||||
name m_prefix;
|
||||
unsigned m_next_idx;
|
||||
name_set m_found_univ_params;
|
||||
name_map<level> m_univ_meta_to_param;
|
||||
name_map<level> m_univ_meta_to_param_inv;
|
||||
name_set m_found_local;
|
||||
name_map<expr> m_meta_to_param;
|
||||
name_map<expr> m_meta_to_param_inv;
|
||||
buffer<name> m_level_params;
|
||||
buffer<expr> m_params;
|
||||
bool m_finalized_collection{false};
|
||||
buffer<expr> m_norm_params;
|
||||
|
||||
public:
|
||||
closure_helper(type_context & ctx):
|
||||
m_ctx(ctx),
|
||||
m_prefix("_aux_param"),
|
||||
m_next_idx(0) {}
|
||||
|
||||
type_context & ctx() { return m_ctx; }
|
||||
|
||||
/* \pre finalize_collection has not been invoked */
|
||||
level collect(level const & l);
|
||||
/* \pre finalize_collection has not been invoked */
|
||||
levels collect(levels const & ls);
|
||||
/* \pre finalize_collection has not been invoked */
|
||||
expr collect(expr const & e);
|
||||
|
||||
/* \pre finalize_collection has not been invoked */
|
||||
void finalize_collection();
|
||||
|
||||
/* Replace parameters in \c e with corresponding normalized parameters, obtain e', and
|
||||
then return (Pi N, e') where N are the normalized parameters.
|
||||
|
||||
Remark \c e must not contain meta-variables. We can ensure this constraint by using the
|
||||
collect method
|
||||
|
||||
\pre finalize_collection has been invoked */
|
||||
expr mk_pi_closure(expr const & e);
|
||||
|
||||
/* Replace parameters in \c e with corresponding normalized parameters, obtain e', and
|
||||
then return (fun N, e') where N are the normalized parameters.
|
||||
|
||||
Remark \c e must not contain meta-variables. We can ensure this constraint by using the
|
||||
collect method
|
||||
|
||||
\pre finalize_collection has been invoked */
|
||||
expr mk_lambda_closure(expr const & e);
|
||||
|
||||
/* Return the level parameters and meta-variables collected by collect methods.
|
||||
|
||||
\pre finalize_collection has been invoked */
|
||||
void get_level_closure(buffer<level> & ls);
|
||||
/* Return the parameters and meta-variables collected by collect methods.
|
||||
|
||||
\pre finalize_collection has been invoked */
|
||||
void get_expr_closure(buffer<expr> & ps);
|
||||
|
||||
/* Return the name of normalized parameters. That is, it includes the collected
|
||||
level parameters and new parameters created for universe meta-variables.
|
||||
|
||||
\pre finalize_collection has been invoked */
|
||||
list<name> get_norm_level_names() const { return to_list(m_level_params); }
|
||||
};
|
||||
|
||||
/** \brief Create an auxiliary definition with name `c` where `type` and `value` may contain local constants and
|
||||
meta-variables. This function collects all dependencies (universe parameters, universe metavariables,
|
||||
local constants and metavariables). The result is the updated environment and an expression of the form
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue