From 7b683427da3d6f4ba24216709ce358fd8904cfc5 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 27 Oct 2017 11:18:50 -0700 Subject: [PATCH] feat(library/aux_definition): add closure_helper --- src/library/aux_definition.cpp | 256 +++++++++++++++++---------------- src/library/aux_definition.h | 87 +++++++++++ 2 files changed, 219 insertions(+), 124 deletions(-) diff --git a/src/library/aux_definition.cpp b/src/library/aux_definition.cpp index a6fffe1d5f..cd0ffa1afc 100644 --- a/src/library/aux_definition.cpp +++ b/src/library/aux_definition.cpp @@ -16,127 +16,145 @@ Author: Leonardo de Moura #include "library/replace_visitor_with_tc.h" namespace lean { -struct mk_aux_definition_fn { - type_context & m_ctx; - name m_prefix; - unsigned m_next_idx; - name_set m_found_univ_params; - name_map m_univ_meta_to_param; - name_map m_univ_meta_to_param_inv; - name_set m_found_local; - name_map m_meta_to_param; - name_map m_meta_to_param_inv; - buffer m_level_params; - buffer m_params; - - mk_aux_definition_fn(type_context & ctx): - m_ctx(ctx), - m_prefix("_aux_param"), - m_next_idx(0) {} - - level collect(level const & l) { - return replace(l, [&](level const & l) { - if (is_meta(l)) { - name const & id = meta_id(l); - if (auto r = m_univ_meta_to_param.find(id)) { - return some_level(*r); - } else { - name n = m_prefix.append_after(m_next_idx); - m_next_idx++; - level new_r = mk_param_univ(n); - m_univ_meta_to_param.insert(id, new_r); - m_univ_meta_to_param_inv.insert(n, l); - m_level_params.push_back(n); - return some_level(new_r); - } - } else if (is_param(l)) { - lean_assert(!is_placeholder(l)); - name const & id = param_id(l); - if (!m_found_univ_params.contains(id)) { - m_found_univ_params.insert(id); - m_level_params.push_back(id); - } +level closure_helper::collect(level const & l) { + lean_assert(!m_finalized_collection); + return replace(l, [&](level const & l) { + if (is_meta(l)) { + name const & id = meta_id(l); + if (auto r = m_univ_meta_to_param.find(id)) { + return some_level(*r); + } else { + name n = m_prefix.append_after(m_next_idx); + m_next_idx++; + level new_r = mk_param_univ(n); + m_univ_meta_to_param.insert(id, new_r); + m_univ_meta_to_param_inv.insert(n, l); + m_level_params.push_back(n); + return some_level(new_r); } - return none_level(); - }); - } + } else if (is_param(l)) { + lean_assert(!is_placeholder(l)); + name const & id = param_id(l); + if (!m_found_univ_params.contains(id)) { + m_found_univ_params.insert(id); + m_level_params.push_back(id); + } + } + return none_level(); + }); +} - levels collect(levels const & ls) { - bool modified = false; - buffer r; - for (level const & l : ls) { - level new_l = collect(l); - if (new_l != l) modified = true; - r.push_back(new_l); - } - if (!modified) - return ls; +levels closure_helper::collect(levels const & ls) { + lean_assert(!m_finalized_collection); + bool modified = false; + buffer r; + for (level const & l : ls) { + level new_l = collect(l); + if (new_l != l) modified = true; + r.push_back(new_l); + } + if (!modified) + return ls; + else + return to_list(r); +} + +expr closure_helper::collect(expr const & e) { + lean_assert(!m_finalized_collection); + return replace(e, [&](expr const & e, unsigned) { + if (is_metavar(e)) { + name const & id = mlocal_name(e); + if (auto r = m_meta_to_param.find(id)) { + return some_expr(*r); + } else { + expr type = m_ctx.infer(e); + expr x = m_ctx.push_local("_x", type); + m_meta_to_param.insert(id, x); + m_meta_to_param_inv.insert(mlocal_name(x), e); + m_params.push_back(x); + return some_expr(x); + } + } else if (is_local(e)) { + name const & id = mlocal_name(e); + if (!m_found_local.contains(id)) { + m_found_local.insert(id); + m_params.push_back(e); + } + } else if (is_sort(e)) { + return some_expr(update_sort(e, collect(sort_level(e)))); + } else if (is_constant(e)) { + return some_expr(update_constant(e, collect(const_levels(e)))); + } + return none_expr(); + }); +} + +void closure_helper::finalize_collection() { + lean_assert(!m_finalized_collection); + name_map new_types; + for (unsigned i = 0; i < m_params.size(); i++) { + expr x = m_params[i]; + expr new_type = collect(zeta_expand(m_ctx.lctx(), m_ctx.instantiate_mvars(m_ctx.infer(x)))); + new_types.insert(mlocal_name(x), new_type); + } + local_context const & lctx = m_ctx.lctx(); + std::sort(m_params.begin(), m_params.end(), [&](expr const & l1, expr const & l2) { + return lctx.get_local_decl(l1).get_idx() < lctx.get_local_decl(l2).get_idx(); + }); + for (unsigned i = 0; i < m_params.size(); i++) { + expr x = m_params[i]; + expr type = *new_types.find(mlocal_name(x)); + expr new_type = replace_locals(type, i, m_params.data(), m_norm_params.data()); + expr new_param = m_ctx.push_local(mlocal_pp_name(x), new_type, local_info(x)); + m_norm_params.push_back(new_param); + } + m_finalized_collection = true; +} + +expr closure_helper::mk_pi_closure(expr const & e) { + lean_assert(m_finalized_collection); + expr new_e = replace_locals(e, m_params, m_norm_params); + return m_ctx.mk_pi(m_norm_params, new_e); +} + +expr closure_helper::mk_lambda_closure(expr const & e) { + lean_assert(m_finalized_collection); + expr new_e = replace_locals(e, m_params, m_norm_params); + return m_ctx.mk_lambda(m_norm_params, new_e); +} + +void closure_helper::get_level_closure(buffer & ls) { + lean_assert(m_finalized_collection); + for (name const & n : m_level_params) { + if (level const * l = m_univ_meta_to_param_inv.find(n)) + ls.push_back(*l); else - return to_list(r); + ls.push_back(mk_param_univ(n)); } +} - expr collect(expr const & e) { - return replace(e, [&](expr const & e, unsigned) { - if (is_metavar(e)) { - name const & id = mlocal_name(e); - if (auto r = m_meta_to_param.find(id)) { - return some_expr(*r); - } else { - expr type = m_ctx.infer(e); - expr x = m_ctx.push_local("_x", type); - m_meta_to_param.insert(id, x); - m_meta_to_param_inv.insert(mlocal_name(x), e); - m_params.push_back(x); - return some_expr(x); - } - } else if (is_local(e)) { - name const & id = mlocal_name(e); - if (!m_found_local.contains(id)) { - m_found_local.insert(id); - m_params.push_back(e); - } - } else if (is_sort(e)) { - return some_expr(update_sort(e, collect(sort_level(e)))); - } else if (is_constant(e)) { - return some_expr(update_constant(e, collect(const_levels(e)))); - } - return none_expr(); - }); +void closure_helper::get_expr_closure(buffer & ps) { + lean_assert(m_finalized_collection); + for (expr const & x : m_params) { + if (expr const * m = m_meta_to_param_inv.find(mlocal_name(x))) + ps.push_back(*m); + else + ps.push_back(x); } +} - /* Collect (and sort) dependencies of collected parameters */ - void collect_and_normalize_dependencies(buffer & norm_params) { - name_map new_types; - for (unsigned i = 0; i < m_params.size(); i++) { - expr x = m_params[i]; - expr new_type = collect(zeta_expand(m_ctx.lctx(), m_ctx.instantiate_mvars(m_ctx.infer(x)))); - new_types.insert(mlocal_name(x), new_type); - } - local_context const & lctx = m_ctx.lctx(); - std::sort(m_params.begin(), m_params.end(), [&](expr const & l1, expr const & l2) { - return lctx.get_local_decl(l1).get_idx() < lctx.get_local_decl(l2).get_idx(); - }); - for (unsigned i = 0; i < m_params.size(); i++) { - expr x = m_params[i]; - expr type = *new_types.find(mlocal_name(x)); - expr new_type = replace_locals(type, i, m_params.data(), norm_params.data()); - expr new_param = m_ctx.push_local(mlocal_pp_name(x), new_type, local_info(x)); - norm_params.push_back(new_param); - } - } +struct mk_aux_definition_fn : public closure_helper { + mk_aux_definition_fn(type_context & ctx):closure_helper(ctx) {} pair operator()(name const & c, expr const & type, expr const & value, bool is_lemma, optional const & is_meta) { lean_assert(!is_lemma || is_meta); lean_assert(!is_lemma || *is_meta == false); - expr new_type = collect(m_ctx.instantiate_mvars(type)); - expr new_value = collect(m_ctx.instantiate_mvars(value)); - environment env = m_ctx.env(); - buffer norm_params; - collect_and_normalize_dependencies(norm_params); - new_type = replace_locals(new_type, m_params, norm_params); - new_value = replace_locals(new_value, m_params, norm_params); - expr def_type = m_ctx.mk_pi(norm_params, new_type); - expr def_value = m_ctx.mk_lambda(norm_params, new_value); + expr new_type = collect(ctx().instantiate_mvars(type)); + expr new_value = collect(ctx().instantiate_mvars(value)); + environment env = ctx().env(); + finalize_collection(); + expr def_type = mk_pi_closure(new_type); + expr def_value = mk_lambda_closure(new_value); bool untrusted = false; if (is_meta) untrusted = *is_meta; @@ -148,26 +166,16 @@ struct mk_aux_definition_fn { } declaration d; if (is_lemma) { - d = mk_theorem(c, to_list(m_level_params), def_type, def_value); + d = mk_theorem(c, get_norm_level_names(), def_type, def_value); } else { bool use_self_opt = true; - d = mk_definition(env, c, to_list(m_level_params), def_type, def_value, use_self_opt, !untrusted); + d = mk_definition(env, c, get_norm_level_names(), def_type, def_value, use_self_opt, !untrusted); } environment new_env = module::add(env, check(env, d, true)); buffer ls; - for (name const & n : m_level_params) { - if (level const * l = m_univ_meta_to_param_inv.find(n)) - ls.push_back(*l); - else - ls.push_back(mk_param_univ(n)); - } + get_level_closure(ls); buffer ps; - for (expr const & x : m_params) { - if (expr const * m = m_meta_to_param_inv.find(mlocal_name(x))) - ps.push_back(*m); - else - ps.push_back(x); - } + get_expr_closure(ps); expr r = mk_app(mk_constant(c, to_list(ls)), ps); return mk_pair(new_env, r); } diff --git a/src/library/aux_definition.h b/src/library/aux_definition.h index a0785b29b5..a0a44edd08 100644 --- a/src/library/aux_definition.h +++ b/src/library/aux_definition.h @@ -7,6 +7,93 @@ Author: Leonardo de Moura #pragma once #include "library/type_context.h" namespace lean { +/** Helper class for creating closures for nested terms. + + There are two phases: + 1- Parameter and metavariable collection. + 2- Closure creation. + + The methods \c collect are used in the first phase. + The method \c finalize_collection moves object to the second phase. + + The collection phase collects parameters and metavariables. + A new parameter is created for each metavariable found in a subterm. + + Parameter and metavariables occurring in the types of collected + parameters and metavariables are also collected. + + The method \c finalize_collection moves the state to phase 2. + It also creates a new (normalized) parameter for each collected parameter and metavariable. + The type of the new parameter is the type of the source after normalization. + All new parameters are sorted based on dependencies. +*/ +class closure_helper { + type_context & m_ctx; + name m_prefix; + unsigned m_next_idx; + name_set m_found_univ_params; + name_map m_univ_meta_to_param; + name_map m_univ_meta_to_param_inv; + name_set m_found_local; + name_map m_meta_to_param; + name_map m_meta_to_param_inv; + buffer m_level_params; + buffer m_params; + bool m_finalized_collection{false}; + buffer m_norm_params; + +public: + closure_helper(type_context & ctx): + m_ctx(ctx), + m_prefix("_aux_param"), + m_next_idx(0) {} + + type_context & ctx() { return m_ctx; } + + /* \pre finalize_collection has not been invoked */ + level collect(level const & l); + /* \pre finalize_collection has not been invoked */ + levels collect(levels const & ls); + /* \pre finalize_collection has not been invoked */ + expr collect(expr const & e); + + /* \pre finalize_collection has not been invoked */ + void finalize_collection(); + + /* Replace parameters in \c e with corresponding normalized parameters, obtain e', and + then return (Pi N, e') where N are the normalized parameters. + + Remark \c e must not contain meta-variables. We can ensure this constraint by using the + collect method + + \pre finalize_collection has been invoked */ + expr mk_pi_closure(expr const & e); + + /* Replace parameters in \c e with corresponding normalized parameters, obtain e', and + then return (fun N, e') where N are the normalized parameters. + + Remark \c e must not contain meta-variables. We can ensure this constraint by using the + collect method + + \pre finalize_collection has been invoked */ + expr mk_lambda_closure(expr const & e); + + /* Return the level parameters and meta-variables collected by collect methods. + + \pre finalize_collection has been invoked */ + void get_level_closure(buffer & ls); + /* Return the parameters and meta-variables collected by collect methods. + + \pre finalize_collection has been invoked */ + void get_expr_closure(buffer & ps); + + /* Return the name of normalized parameters. That is, it includes the collected + level parameters and new parameters created for universe meta-variables. + + \pre finalize_collection has been invoked */ + list get_norm_level_names() const { return to_list(m_level_params); } +}; + /** \brief Create an auxiliary definition with name `c` where `type` and `value` may contain local constants and meta-variables. This function collects all dependencies (universe parameters, universe metavariables, local constants and metavariables). The result is the updated environment and an expression of the form