From 970e11bf5e7613c07d697a6be38be4ccebfa011f Mon Sep 17 00:00:00 2001 From: Sebastian Ullrich Date: Tue, 7 Mar 2017 18:59:19 +0100 Subject: [PATCH] feat(frontends/lean/{elaborator,structure_cmd}): allow overriding field defaults --- src/frontends/lean/elaborator.cpp | 68 ++++++-- src/frontends/lean/structure_cmd.cpp | 245 +++++++++++++++------------ 2 files changed, 189 insertions(+), 124 deletions(-) diff --git a/src/frontends/lean/elaborator.cpp b/src/frontends/lean/elaborator.cpp index f0be1da1ff..e1c28c2520 100644 --- a/src/frontends/lean/elaborator.cpp +++ b/src/frontends/lean/elaborator.cpp @@ -2477,10 +2477,11 @@ expr elaborator::visit_structure_instance(expr const & e, optional const & get_intro_rule_names(m_env, S_name, c_names); lean_assert(c_names.size() == 1); expr c = copy_tag(e, mk_constant(c_names[0])); - expr c_type = m_env.get(c_names[0]).get_type(); - unsigned i = 0; name_map field2value; - while (is_pi(c_type)) { + + // first check for explicit, implicit and parent fields + expr c_type = m_env.get(c_names[0]).get_type(); + for (unsigned i = 0; is_pi(c_type); c_type = binding_body(c_type), i++) { if (i < nparams) { if (is_explicit(binding_info(c_type))) { throw elaborator_exception(e, sstream() << "invalid structure value {...}, structure parameter '" << @@ -2495,7 +2496,6 @@ expr elaborator::visit_structure_instance(expr const & e, optional const & if (S_fname == fnames[j]) { used[j] = true; field2value.insert(S_fname, fvalues[j]); - c = copy_tag(e, mk_app(c, fvalues[j])); break; } } @@ -2517,19 +2517,9 @@ expr elaborator::visit_structure_instance(expr const & e, optional const & } f = copy_tag(e, mk_as_is(f)); field2value.insert(S_fname, f); - c = copy_tag(e, mk_app(c, f)); } else { name full_S_fname = S_name + S_fname; - if (optional default_value_fn = has_default_value(m_env, full_S_fname)) { - expr value = mk_field_default_value(m_env, full_S_fname, [&](name const & fname) { - if (auto v = field2value.find(fname)) - return some_expr(*v); - else - return none_expr(); - }); - field2value.insert(S_fname, value); - c = copy_tag(e, mk_app(c, value)); - } else { + if (!has_default_value(m_env, full_S_fname)) { throw elaborator_exception(e, sstream() << "invalid structure value { ... }, field '" << S_fname << "' was not provided"); } @@ -2543,15 +2533,59 @@ expr elaborator::visit_structure_instance(expr const & e, optional const & } } } - c_type = binding_body(c_type); - i++; } + for (unsigned i = 0; i < fnames.size(); i++) { if (!used[i]) { throw elaborator_exception(e, sstream() << "invalid structure value { ... }, '" << fnames[i] << "'" << " is not a field of structure '" << S_name << "'"); } } + + // now repeatedly try to insert defaulted fields + bool last_progress = true; + bool done = false; + while (!done) { + done = true; + bool progress = false; + c_type = m_env.get(c_names[0]).get_type(); + for (unsigned i = 0; is_pi(c_type); c_type = binding_body(c_type), i++) { + if (is_explicit(binding_info(c_type)) && !src) { + name S_fname = binding_name(c_type); + if (!field2value.find(S_fname)) { + name full_S_fname = S_name + S_fname; + if (optional default_value_fn = has_default_value(m_env, full_S_fname)) { + try { + expr value = mk_field_default_value(m_env, full_S_fname, [&](name const &fname) { + if (auto v = field2value.find(fname)) + return some_expr(*v); + else + return none_expr(); + }); + field2value.insert(S_fname, value); + progress = true; + } catch (exception &) { + done = false; + if (!last_progress) + throw; + } + } + } + } + } + last_progress = progress; + } + + // finally apply fields to the constructor + c_type = m_env.get(c_names[0]).get_type(); + for (unsigned i = 0; is_pi(c_type); c_type = binding_body(c_type), i++) { + if (is_explicit(binding_info(c_type))) { + name S_fname = binding_name(c_type); + lean_assert(field2value.find(S_fname)); + c = copy_tag(e, mk_app(c, field2value[S_fname])); + } + } + return visit(c, expected_type); } diff --git a/src/frontends/lean/structure_cmd.cpp b/src/frontends/lean/structure_cmd.cpp index de29d0ecb0..8601a709f0 100644 --- a/src/frontends/lean/structure_cmd.cpp +++ b/src/frontends/lean/structure_cmd.cpp @@ -109,7 +109,12 @@ struct structure_cmd_fn { typedef std::vector> rename_vector; // field_map[i] contains the position of the \c i-th field of a parent structure into this one. typedef std::vector field_map; - typedef pair> field_decl; + struct field_decl { + expr local; // name, type, and pos as an expr::local + optional default_val; + bool from_parent; + bool explicit_type; + }; parser & m_p; decl_modifiers m_modifiers; @@ -333,19 +338,30 @@ struct structure_cmd_fn { return new_tmp; } - expr update_fields(expr new_tmp, buffer & decls) { - for (unsigned i = 0; i < decls.size(); i++) { - if (decls[i].second) { + expr update_default_values(expr new_tmp, buffer & decls) { + for (auto & decl : decls) { + if (decl.default_val && decl.explicit_type) { lean_assert(is_let(new_tmp)); - expr new_local = mk_local(mlocal_name(decls[i].first), let_name(new_tmp), let_type(new_tmp), binder_info()); - decls[i].first = new_local; - decls[i].second = let_value(new_tmp); - new_tmp = instantiate(let_body(new_tmp), new_local); + decl.default_val = let_value(new_tmp); + new_tmp = let_body(new_tmp); + } + } + return new_tmp; + } + + expr update_fields(expr new_tmp, buffer & decls) { + for (auto & decl : decls) { + if (decl.default_val && !decl.explicit_type) { + lean_assert(is_let(new_tmp)); + expr new_local = mk_local(mlocal_name(decl.local), let_name(new_tmp), let_type(new_tmp), {}); + decl.local = new_local; + decl.default_val = let_value(new_tmp); + new_tmp = instantiate(let_body(new_tmp), new_local); } else { lean_assert(is_pi(new_tmp)); - expr new_local = mk_local(mlocal_name(decls[i].first), binding_name(new_tmp), binding_domain(new_tmp), + expr new_local = mk_local(mlocal_name(decl.local), binding_name(new_tmp), binding_domain(new_tmp), binding_info(new_tmp)); - decls[i].first = new_local; + decl.local = new_local; new_tmp = instantiate(binding_body(new_tmp), new_local); } } @@ -422,29 +438,13 @@ struct structure_cmd_fn { /** \brief If \c fname matches the name of an existing field, then check if the types are definitionally equal (store any generated unification constraints in cseq), and return the index of the existing field. */ - optional merge(expr const & parent, name const & fname, expr const & ftype, optional const & fdefault) { + optional merge(expr const & parent, name const & fname, expr const & ftype) { for (unsigned i = 0; i < m_fields.size(); i++) { - if (local_pp_name(m_fields[i].first) == fname) { - if (m_ctx.is_def_eq(mlocal_type(m_fields[i].first), ftype)) { - if (!m_fields[i].second && fdefault) { - m_fields[i].second = fdefault; - } else if (m_fields[i].second && fdefault && !m_ctx.is_def_eq(*m_fields[i].second, *fdefault)) { - expr prev_default = *m_fields[i].second; - throw generic_exception(parent, [=](formatter const & fmt) { - format r = format("invalid 'structure' header, field '"); - r += format(fname); - r += format("' from '"); - r += format(const_name(get_app_fn(parent))); - r += format("' has already been declared with a different default value"); - r += pp_indent_expr(fmt, prev_default); - r += compose(line(), format("and")); - r += pp_indent_expr(fmt, *fdefault); - return r; - }); - } + if (local_pp_name(m_fields[i].local) == fname) { + if (m_ctx.is_def_eq(mlocal_type(m_fields[i].local), ftype)) { return optional(i); } else { - expr prev_ftype = mlocal_type(m_fields[i].first); + expr prev_ftype = mlocal_type(m_fields[i].local); throw generic_exception(parent, [=](formatter const & fmt) { format r = format("invalid 'structure' header, field '"); r += format(fname); @@ -465,8 +465,8 @@ struct structure_cmd_fn { expr mk_field_default_value(name const & full_field_name) { return ::lean::mk_field_default_value(m_env, full_field_name, [&](name const & fname) { for (field_decl const & d : m_fields) { - if (local_pp_name(d.first) == fname) - return some_expr(mk_explicit(d.first)); + if (local_pp_name(d.local) == fname) + return some_expr(mk_explicit(d.local)); } return none_expr(); }); @@ -499,19 +499,16 @@ struct structure_cmd_fn { throw_ill_formed_parent(parent_name); intro_type = instantiate(binding_body(intro_type), arg); } + size_t fmap_start = fmap.size(); while (is_pi(intro_type)) { name fname = binding_name(intro_type); fname = rename(renames, fname); expr const & ftype = binding_domain(intro_type); - optional fdefault; name full_fname = parent_name + fname; - if (optional fdefault_name = has_default_value(m_env, full_fname)) { - fdefault = mk_field_default_value(full_fname); - } expr field; - if (auto fidx = merge(parent, fname, ftype, fdefault)) { + if (auto fidx = merge(parent, fname, ftype)) { fmap.push_back(*fidx); - field = m_fields[*fidx].first; + field = m_fields[*fidx].local; if (local_info(field) != binding_info(intro_type)) { throw elaborator_exception(parent, sstream() << "invalid 'structure' header, field '" << fname << @@ -520,10 +517,36 @@ struct structure_cmd_fn { } else { field = mk_local(fname, ftype, binding_info(intro_type)); fmap.push_back(m_fields.size()); - m_fields.push_back(field_decl(field, fdefault)); + m_fields.push_back({field, none_expr(), /* from_parent */ true, /* explicit_type */ true}); } intro_type = instantiate(binding_body(intro_type), field); } + // construct and add default values now that all fields have been defined + for (size_t fmap_idx = fmap_start; fmap_idx < fmap.size(); fmap_idx++) { + field_decl & field = m_fields[fmap[fmap_start]]; + name fname = local_pp_name(field.local); + name full_fname = parent_name + fname; + if (optional fdefault_name = has_default_value(m_env, full_fname)) { + expr fdefault = mk_field_default_value(full_fname); + if (!field.default_val) { + field.default_val = fdefault; + } else if (field.default_val && !m_ctx.is_def_eq(*field.default_val, fdefault)) { + expr prev_default = *field.default_val; + throw generic_exception(parent, [=](formatter const &fmt) { + format r = format("invalid 'structure' header, field '"); + r += format(fname); + r += format("' from '"); + r += format(const_name(get_app_fn(parent))); + r += format("' has already been declared with a different default value"); + r += pp_indent_expr(fmt, prev_default); + r += compose(line(), format("and")); + r += pp_indent_expr(fmt, fdefault); + return r; + }); + } + } + fmap_start++; + } } lean_assert(m_parents.size() == m_field_maps.size()); } @@ -534,9 +557,9 @@ struct structure_cmd_fn { for (expr & param : m_params) param = m_ctx.instantiate_mvars(param); for (field_decl & decl : m_fields) { - decl.first = m_ctx.instantiate_mvars(decl.first); - if (decl.second) - decl.second = m_ctx.instantiate_mvars(*decl.second); + decl.local = m_ctx.instantiate_mvars(decl.local); + if (decl.default_val) + decl.default_val = m_ctx.instantiate_mvars(*decl.default_val); } } @@ -560,7 +583,7 @@ struct structure_cmd_fn { name const & parent_intro_name = inductive::intro_rule_name(std::get<2>(parent_info)); expr parent_intro = mk_app(mk_constant(parent_intro_name, parent_ls), parent_params); for (unsigned idx : fmap) { - expr const & field = m_fields[idx].first; + expr const & field = m_fields[idx].local; parent_intro = mk_app(parent_intro, field); } return parent_intro; @@ -575,29 +598,21 @@ struct structure_cmd_fn { for (expr const & param : m_params) m_p.add_local(param); for (field_decl const & decl : m_fields) - m_p.add_local(decl.first); + m_p.add_local(decl.local); for (unsigned i = 0; i < m_parents.size(); i++) { if (auto n = m_parent_refs[i]) m_p.add_local_expr(*n, mk_as_is(mk_parent_expr(i))); } } - /** \brief Check if new field names collide with fields inherited from parent datastructures */ - void check_new_field_names(buffer const & new_fields) { - for (field_decl const & new_field : new_fields) { - /* TODO(Leo): allow new field to set default value for parent field */ - if (std::find_if(m_fields.begin(), m_fields.end(), - [&](field_decl const & inherited_field) { - return local_pp_name(inherited_field.first) == local_pp_name(new_field.first); - }) != m_fields.end()) { - throw elaborator_exception(new_field.first, - sstream() << "field '" << local_pp_name(new_field.first) << - "' has been declared in parent structure"); - } - } + field_decl * get_field_by_name(name const & name) { + auto it = std::find_if(m_fields.begin(), m_fields.end(), [&](field_decl const & inherited_field) { + return local_pp_name(inherited_field.local) == name; + }); + return it != m_fields.end() ? it : nullptr; } - void parse_field_block(buffer & new_fields, binder_info const & bi) { + void parse_field_block(binder_info const & bi) { buffer> names; auto start_pos = m_p.pos(); while (m_p.curr_is_identifier()) { @@ -625,57 +640,76 @@ struct structure_cmd_fn { throw parser_error("invalid field, it is not explicit, but it has a default value", start_pos); } for (auto p : names) { - expr local = m_p.save_pos(mk_local(p.second, type, bi), p.first); - m_p.add_local(local); - new_fields.push_back(field_decl(local, default_value)); + if (auto old_field = get_field_by_name(p.second)) { + if (is_placeholder(type)) { + old_field->default_val = default_value; + } else { + sstream msg; + msg << "field '" << p.second; + if (old_field->from_parent) + msg << "' has been declared in parent structure"; + else + msg <<"' has already been declared"; + if (default_value) + msg << " (omit its type to set a new default value)"; + throw parser_error(msg, start_pos); + } + } else { + expr local = m_p.save_pos(mk_local(p.second, type, bi), p.first); + m_p.add_local(local); + m_fields.push_back({local, default_value, /* from_parent */ false, /* explicit_type */ !is_placeholder(type)}); + } } } /** \brief Parse new fields declared in this structure */ - void parse_new_fields(buffer & new_fields) { + void parse_new_fields() { parser::local_scope scope(m_p); add_locals(); while (!m_p.curr_is_command_like()) { if (m_p.curr_is_identifier()) { - parse_field_block(new_fields, binder_info()); + parse_field_block(binder_info()); } else { binder_info bi = m_p.parse_binder_info(); - parse_field_block(new_fields, bi); + parse_field_block(bi); m_p.parse_close_binder_info(bi); } } - check_new_field_names(new_fields); } - expr mk_field_binder(buffer const & decls, expr const & type, bool as_is = false) { + expr mk_field_binder(buffer const & decls, expr const & type, + bool typed_defaults_only = false) { expr r = type; unsigned i = decls.size(); while (i > 0) { --i; field_decl const & decl = decls[i]; - if (decl.second) { - expr type = mlocal_type(decl.first); - expr value = *decl.second; - if (as_is) { + if (decl.default_val && (typed_defaults_only == decl.explicit_type)) { + expr type = mlocal_type(decl.local); + expr value = *decl.default_val; + if (decl.from_parent) { type = mk_as_is(type); } - r = mk_let(local_pp_name(decl.first), type, value, abstract_local(r, decl.first)); - } else if (as_is) { - r = Pi_as_is(decl.first, r); - } else { - r = Pi(decl.first, r); + if (typed_defaults_only) + r = mk_let(mk_fresh_name(), type, value, r); + else + r = mk_let(local_pp_name(decl.local), type, value, abstract_local(r, decl.local)); + } else if (!typed_defaults_only) { + if (decl.from_parent) { + r = Pi_as_is(decl.local, r); + } else { + r = Pi(decl.local, r); + } } } return r; } - expr mk_field_binder_as_is(buffer const & decls, expr const & type) { - return mk_field_binder(decls, type, true); - } - - /** \brief Elaborate new fields, and copy them to m_fields */ - void elaborate_new_fields(buffer & new_fields) { - expr tmp = mk_field_binder_as_is(m_fields, mk_field_binder(new_fields, mk_Prop())); + /** \brief Elaborate new fields */ + void elaborate_new_fields() { + // start with typed default values so that they can depend on any field + expr tmp = mk_field_binder(m_fields, mk_Prop(), true); + tmp = mk_field_binder(m_fields, tmp, false); unsigned j = m_parents.size(); while (j > 0) { --j; @@ -692,28 +726,25 @@ struct structure_cmd_fn { new_tmp = update_locals(new_tmp, m_params); new_tmp = update_parents(new_tmp, true); new_tmp = update_fields(new_tmp, m_fields); - new_tmp = update_fields(new_tmp, new_fields); + new_tmp = update_default_values(new_tmp, m_fields); lean_assert(new_tmp == mk_Prop()); - m_fields.append(new_fields); } /** \brief Parse new fields declared by this structure, and elaborate them. */ void process_new_fields() { - buffer new_fields; - parse_new_fields(new_fields); - elaborate_new_fields(new_fields); + parse_new_fields(); + elaborate_new_fields(); } void process_empty_new_fields() { - buffer new_fields; - elaborate_new_fields(new_fields); + elaborate_new_fields(); } /** \brief Traverse fields and collect the universes they reside in \c r_lvls. This information is used to compute the resultant universe level for the inductive datatype declaration. */ void accumulate_levels(buffer & r_lvls) { for (field_decl const & decl : m_fields) { - level l = get_level(m_ctx, mlocal_type(decl.first)); + level l = get_level(m_ctx, mlocal_type(decl.local)); if (std::find(r_lvls.begin(), r_lvls.end(), l) == r_lvls.end()) { r_lvls.push_back(l); } @@ -734,9 +765,9 @@ struct structure_cmd_fn { /** \brief Display m_fields (for debugging purposes) */ void display_fields(std::ostream & out) { for (field_decl const & decl : m_fields) { - out << ">> " << mlocal_name(decl.first) << " : " << mlocal_type(decl.first); - if (decl.second) - out << " := " << *decl.second; + out << ">> " << mlocal_name(decl.local) << " : " << mlocal_type(decl.local); + if (decl.default_val) + out << " := " << *decl.default_val; out << "\n"; } } @@ -780,9 +811,9 @@ struct structure_cmd_fn { for (expr const & p : m_params) all_lvl_params = collect_univ_params(mlocal_type(p), all_lvl_params); for (field_decl const & f : m_fields) { - all_lvl_params = collect_univ_params(mlocal_type(f.first), all_lvl_params); - if (f.second) - all_lvl_params = collect_univ_params(*f.second, all_lvl_params); + all_lvl_params = collect_univ_params(mlocal_type(f.local), all_lvl_params); + if (f.default_val) + all_lvl_params = collect_univ_params(*f.default_val, all_lvl_params); } buffer section_lvls; all_lvl_params.for_each([&](name const & l) { @@ -811,7 +842,7 @@ struct structure_cmd_fn { levels ls = param_names_to_levels(to_list(m_level_names.begin(), m_level_names.end())); expr r = mk_app(mk_constant(m_name, ls), m_params); buffer field_wo_defaults; - for (field_decl const & decl : m_fields) field_wo_defaults.push_back(decl.first); + for (field_decl const & decl : m_fields) field_wo_defaults.push_back(decl.local); r = Pi(m_params, Pi(field_wo_defaults, r, m_p), m_p); return infer_implicit_params(r, m_params.size(), m_mk_infer); } @@ -820,7 +851,7 @@ struct structure_cmd_fn { levels ls = param_names_to_levels(to_list(m_level_names.begin(), m_level_names.end())); expr r = mk_app(mk_constant(m_name, ls), m_params); buffer field_wo_defaults; - for (field_decl const & decl : m_fields) field_wo_defaults.push_back(decl.first); + for (field_decl const & decl : m_fields) field_wo_defaults.push_back(decl.local); r = Pi(field_wo_defaults, r, m_p); return r; } @@ -871,14 +902,14 @@ struct structure_cmd_fn { void declare_projections() { m_env = mk_projections(m_env, m_name, m_mk_infer, m_attrs.has_class()); for (field_decl const & field : m_fields) { - name field_name = m_name + mlocal_name(field.first); + name field_name = m_name + mlocal_name(field.local); add_alias(field_name); } } bool is_field(expr const & local) { for (field_decl const & field : m_fields) { - if (mlocal_name(field.first) == mlocal_name(local)) + if (mlocal_name(field.local) == mlocal_name(local)) return true; } return false; @@ -886,10 +917,10 @@ struct structure_cmd_fn { void declare_defaults() { for (field_decl const & field : m_fields) { - if (field.second) { + if (field.default_val) { collected_locals used_locals; - collect_locals(mlocal_type(field.first), used_locals); - collect_locals(*field.second, used_locals); + collect_locals(mlocal_type(field.local), used_locals); + collect_locals(*field.default_val, used_locals); buffer args; /* Copy params first */ for (expr const & local : used_locals.get_collected()) { @@ -905,11 +936,11 @@ struct structure_cmd_fn { if (is_field(local)) args.push_back(local); } - name decl_name = name(m_name + local_pp_name(field.first), "_default"); + name decl_name = name(m_name + local_pp_name(field.local), "_default"); /* TODO(Leo): add helper function for adding definition. It should unfold_untrusted_macros */ - expr decl_type = unfold_untrusted_macros(m_env, Pi(args, mlocal_type(field.first))); - expr val = mk_app(m_ctx, get_id_name(), *field.second); + expr decl_type = unfold_untrusted_macros(m_env, Pi(args, mlocal_type(field.local))); + expr val = mk_app(m_ctx, get_id_name(), *field.default_val); expr decl_value = unfold_untrusted_macros(m_env, Fun(args, val)); name_set used_univs; used_univs = collect_univ_params(decl_value, used_univs); @@ -1007,7 +1038,7 @@ struct structure_cmd_fn { expr coercion_type = infer_implicit(Pi(m_params, Pi(st, parent, m_p), m_p), m_params.size(), true);; expr coercion_value = parent_intro; for (unsigned idx : fmap) { - expr const & field = m_fields[idx].first; + expr const & field = m_fields[idx].local; name proj_name = m_name + mlocal_name(field); expr proj = mk_app(mk_app(mk_constant(proj_name, st_ls), m_params), st); coercion_value = mk_app(coercion_value, proj);