diff --git a/src/frontends/lean/structure_cmd.cpp b/src/frontends/lean/structure_cmd.cpp index ac590c2ff6..f0b19f070d 100644 --- a/src/frontends/lean/structure_cmd.cpp +++ b/src/frontends/lean/structure_cmd.cpp @@ -623,21 +623,7 @@ struct structure_cmd_fn { field_map & fmap = m_field_maps.back(); buffer args; expr const & parent_fn = get_app_args(parent, args); - level_param_names lparams; unsigned nparams; inductive::intro_rule intro; name const & parent_name = const_name(parent_fn); - std::tie(lparams, nparams, intro) = get_parent_info(parent_name); - expr intro_type = inductive::intro_rule_type(intro); - intro_type = instantiate_univ_params(intro_type, lparams, const_levels(parent_fn)); - if (nparams != args.size()) { - throw elaborator_exception(parent, - sstream() << "invalid 'structure' header, number of argument " - "mismatch for parent structure '" << parent_name << "'"); - } - for (expr const & arg : args) { - if (!is_pi(intro_type)) - throw_ill_formed_parent(parent_name); - intro_type = instantiate(binding_body(intro_type), arg); - } if (m_subobjects) { name fname; if (auto const & ref = m_parent_refs[i]) @@ -666,8 +652,9 @@ struct structure_cmd_fn { // by projecting our new subobject field and then obtain `A` as // `(fun {ps : Ps} (x : base_S_name ps), A) x`. expr base_obj = *mk_base_projections(m_env, parent_name, base_S_name, field); - auto nparams = std::get<1>(get_structure_info(m_env, base_S_name)); - expr type = m_p.env().get(full_fname).get_type(); + level_param_names lparams; unsigned nparams; inductive::intro_rule intro; + std::tie(lparams, nparams, intro) = get_parent_info(base_S_name); + expr type = instantiate_univ_params(m_p.env().get(full_fname).get_type(), lparams, const_levels(parent_fn)); std::function pi_to_lam = [&](expr const & e, unsigned i) { if (i == nparams + 1) return mk_as_is(e); @@ -682,6 +669,21 @@ struct structure_cmd_fn { subfields.emplace_back(subfield, some_expr(proj), field_kind::from_parent); } } else { + level_param_names lparams; unsigned nparams; inductive::intro_rule intro; + std::tie(lparams, nparams, intro) = get_parent_info(parent_name); + expr intro_type = inductive::intro_rule_type(intro); + intro_type = instantiate_univ_params(intro_type, lparams, const_levels(parent_fn)); + if (nparams != args.size()) { + throw elaborator_exception(parent, + sstream() << "invalid 'structure' header, number of argument " + "mismatch for parent structure '" << parent_name << "'"); + } + for (expr const & arg : args) { + if (!is_pi(intro_type)) + throw_ill_formed_parent(parent_name); + intro_type = instantiate(binding_body(intro_type), arg); + } + size_t fmap_start = fmap.size(); while (is_pi(intro_type)) { name fname = binding_name(intro_type); diff --git a/tests/lean/run/struct_extend_univ.lean b/tests/lean/run/struct_extend_univ.lean new file mode 100644 index 0000000000..36a86aa7e1 --- /dev/null +++ b/tests/lean/run/struct_extend_univ.lean @@ -0,0 +1,3 @@ +universe u +class foo (α : Sort u) := (a : α) +class bar (α : Type u) extends foo α := (b : α)