From 193ce3541970bab264184f35a5da7ec135d3df70 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sat, 28 Jun 2014 15:33:56 -0700 Subject: [PATCH] refactor(frontends/lean/inductive_cmd): redesign inductive datatype elaboration, use the new elaborator, and use simpler algorithm to infer the resulting universe Signed-off-by: Leonardo de Moura --- library/standard/logic.lean | 2 +- src/frontends/lean/inductive_cmd.cpp | 188 ++++++++++++++++----------- src/frontends/lean/parser.cpp | 4 + src/frontends/lean/parser.h | 1 + src/kernel/type_checker.h | 2 + src/library/unifier.cpp | 6 +- tests/lean/run/ind3.lean | 4 +- tests/lean/run/ind6.lean | 2 +- 8 files changed, 122 insertions(+), 87 deletions(-) diff --git a/library/standard/logic.lean b/library/standard/logic.lean index 977e8c1262..526a107281 100644 --- a/library/standard/logic.lean +++ b/library/standard/logic.lean @@ -54,7 +54,7 @@ theorem or_elim (a b c : Bool) (H1 : a ∨ b) (H2 : a → c) (H3 : b → c) : c := or_rec H2 H3 H1 inductive eq {A : Type} (a : A) : A → Bool := -| refl : eq A a a -- TODO: use elaborator in inductive_cmd module, we should not need to type A here +| refl : eq a a infix `=` 50 := eq diff --git a/src/frontends/lean/inductive_cmd.cpp b/src/frontends/lean/inductive_cmd.cpp index dd9491ee0a..4ba8b60ac4 100644 --- a/src/frontends/lean/inductive_cmd.cpp +++ b/src/frontends/lean/inductive_cmd.cpp @@ -70,13 +70,17 @@ static level mk_result_level(bool impredicative, buffer const & ls) { level r = ls[0]; for (unsigned i = 1; i < ls.size(); i++) r = mk_max(r, ls[i]); - return impredicative ? mk_max(r, mk_level_one()) : r; + if (is_not_zero(r)) + return r; + else + return impredicative ? mk_max(r, mk_level_one()) : r; } } -static expr update_result_sort(expr const & t, level const & l) { +static expr update_result_sort(type_checker & tc, expr t, level const & l) { + t = tc.whnf(t); if (is_pi(t)) { - return update_binding(t, binding_domain(t), update_result_sort(binding_body(t), l)); + return update_binding(t, binding_domain(t), update_result_sort(tc, binding_body(t), l)); } else if (is_sort(t)) { return update_sort(t, l); } else { @@ -84,75 +88,110 @@ static expr update_result_sort(expr const & t, level const & l) { } } +/** \brief Return the universe level of the given inductive datatype declaration. */ +level get_datatype_result_level(type_checker & tc, inductive_decl const & d) { + expr d_t = tc.whnf(inductive_decl_type(d)); + while (is_pi(d_t)) { + d_t = tc.whnf(binding_body(d_t)); + } + if (!is_sort(d_t)) { + std::cout << "ERROR: " << inductive_decl_type(d) << "\n"; + throw exception(sstream() << "invalid inductive datatype '" << inductive_decl_name(d) << "', " + "resultant type is not a sort"); + } + return sort_level(d_t); +} + +/** \brief Return true if \c u occurs in \c l */ +bool occurs(level const & u, level const & l) { + bool found = false; + for_each(l, [&](level const & l) { + if (found) return false; + if (l == u) { found = true; return false; } + return true; + }); + return found; +} + static name g_tmp_prefix = name::mk_internal_unique_name(); -static void set_result_universes(buffer & decls, level_param_names const & lvls, unsigned num_params, parser & p) { - if (std::all_of(decls.begin(), decls.end(), [](inductive_decl const & d) { - return !has_placeholder(inductive_decl_type(d)); - })) - return; // nothing to be done - // We can't infer the type of intro rule arguments because we did declare the inductive datatypes. - // So, we use the following trick, we create a "draft" environment where the inductive datatypes - // are asserted as variable declarations, and keep doing that until we reach a "fix" point. - unsigned num_rounds = 0; - while (true) { - if (num_rounds > 2*decls.size() + 1) { - // TODO(Leo): this is test is a hack to avoid non-termination. - // We should use a better termination condition - throw exception("failed to compute resultant universe level for inductive datatypes, " - "provide explicit universe levels"); - } - num_rounds++; - bool progress = false; - environment env = p.env(); - bool impredicative = env.impredicative(); - // first assert inductive types that do not have placeholders - for (auto const & d : decls) { - expr type = inductive_decl_type(d); - if (!has_placeholder(type)) - env = env.add(check(env, mk_var_decl(inductive_decl_name(d), lvls, inductive_decl_type(d)))); - } - type_checker tc(env); - name_generator ngen(g_tmp_prefix); - // try to update resultant universe levels - for (auto & d : decls) { - expr d_t = inductive_decl_type(d); - while (is_pi(d_t)) { - d_t = binding_body(d_t); - } - if (!is_sort(d_t)) - throw exception(sstream() << "invalid inductive datatype '" << inductive_decl_name(d) << "', " - "resultant type is not a sort"); - level r_lvl = sort_level(d_t); - if (impredicative && is_zero(r_lvl)) - continue; - buffer lvls; - for (intro_rule const & ir : inductive_decl_intros(d)) { - expr t = intro_rule_type(ir); - unsigned i = 0; - while (is_pi(t)) { - if (i >= num_params) { - try { - expr s = tc.ensure_type(binding_domain(t)); - level lvl = sort_level(s); - if (std::find(lvls.begin(), lvls.end(), lvl) == lvls.end()) - lvls.push_back(lvl); - } catch (...) { - } - } - t = instantiate(binding_body(t), mk_local(ngen.next(), binding_name(t), binding_domain(t))); - i++; - } - } - level m_lvl = normalize(mk_result_level(impredicative, lvls)); - if (is_placeholder(r_lvl) || !(is_geq(r_lvl, m_lvl))) { - progress = true; - // update result level - expr new_type = update_result_sort(inductive_decl_type(d), m_lvl); - d = inductive_decl(inductive_decl_name(d), new_type, inductive_decl_intros(d)); +/** + \brief Given a type \c t for an introduction rule, store the universe of the types of non-parameters in \c ls. + + \remark aux_u is an temporary universe used for inductive decls. It should be ignored. +*/ +static void accumulate_levels(type_checker & tc, expr t, unsigned num_params, level const & aux_u, buffer & ls) { + name_generator ngen(g_tmp_prefix); + unsigned i = 0; + while (is_pi(t)) { + if (i >= num_params) { + expr s = tc.ensure_type(binding_domain(t)); + level l = sort_level(s); + if (l == aux_u) { + // ignore, this is the auxiliary level + } else if (occurs(aux_u, l)) { + throw exception("failed to infer inductive datatype resultant universe, provide the universe levels explicitly"); + } else if (std::find(ls.begin(), ls.end(), l) == ls.end()) { + ls.push_back(l); } } - if (!progress) - break; + t = instantiate(binding_body(t), mk_local(ngen.next(), binding_name(t), binding_domain(t))); + i++; + } +} + +void throw_all_or_nothing() { + throw exception("invalid mutually recursive datatype declaration, " + "if the universe level of one type is provided, then all of them should be"); +} + +static void elaborate_inductive(buffer & decls, level_param_names const & lvls, unsigned num_params, parser & p) { + // temporary environment used during elaboration + environment env = p.env(); + // add fake universe level + name u_name(g_tmp_prefix, "u"); + env = env.add_universe(u_name); + level u = mk_global_univ(u_name); + std::unique_ptr tc(new type_checker(env)); + bool infer_result_universe = false; + unsigned first = true; + // elaborate inductive datatype types, and declare them in temporary environment. + for (inductive_decl & d : decls) { + level l = get_datatype_result_level(*tc, d); + expr t = inductive_decl_type(d); + if (is_placeholder(l)) { + if (first) + infer_result_universe = true; + else if (!infer_result_universe) + throw_all_or_nothing(); + t = update_result_sort(*tc, t, u); + } else if (!first && infer_result_universe) { + throw_all_or_nothing(); + } + t = p.elaborate(env, t); + env = env.add(check(env, mk_var_decl(inductive_decl_name(d), lvls, t))); + d = inductive_decl(inductive_decl_name(d), t, inductive_decl_intros(d)); + first = false; + } + tc.reset(new type_checker(env)); + buffer r_lvls; // used for inferring the universe level of resultant datatypes. + // elaborate introduction rules using temporary environment + for (inductive_decl & d : decls) { + buffer intros; + for (intro_rule const & ir : inductive_decl_intros(d)) { + expr t = p.elaborate(env, intro_rule_type(ir)); + if (infer_result_universe) + accumulate_levels(*tc, t, num_params, u, r_lvls); + intros.push_back(intro_rule(intro_rule_name(ir), t)); + } + d = inductive_decl(inductive_decl_name(d), inductive_decl_type(d), to_list(intros.begin(), intros.end())); + } + if (infer_result_universe) { + level r_lvl = normalize(mk_result_level(env.impredicative(), r_lvls)); + for (inductive_decl & d : decls) { + expr t = inductive_decl_type(d); + t = update_result_sort(*tc, t, r_lvl); + d = inductive_decl(inductive_decl_name(d), t, inductive_decl_intros(d)); + } } } @@ -319,16 +358,7 @@ environment inductive_cmd(parser & p) { num_params += section_params.size(); level_param_names ls = to_list(ls_buffer.begin(), ls_buffer.end()); - // Check if introduction rules do not have placeholders - for (inductive_decl const & d : decls) { - for (auto const & ir : inductive_decl_intros(d)) { - if (has_placeholder(intro_rule_type(ir))) - throw exception(sstream() << "invalid inductive datatype '" << inductive_decl_name(d) << "', " - << "introduction rule '" << intro_rule_name(ir) << "' has placeholders"); - } - } - // "Fix" the inductive type resultant type universe level, if it was not explicitly provided. - set_result_universes(decls, ls, num_params, p); + elaborate_inductive(decls, ls, num_params, p); env = module::add_inductive(env, ls, num_params, to_list(decls.begin(), decls.end())); // Create aliases/local refs levels section_levels = collect_section_levels(ls, p); diff --git a/src/frontends/lean/parser.cpp b/src/frontends/lean/parser.cpp index 7d7ca6c88c..5956ac8533 100644 --- a/src/frontends/lean/parser.cpp +++ b/src/frontends/lean/parser.cpp @@ -549,6 +549,10 @@ expr parser::elaborate(expr const & e) { return ::lean::elaborate(m_env, m_ios, e); } +expr parser::elaborate(environment const & env, expr const & e) { + return ::lean::elaborate(env, m_ios, e); +} + std::pair parser::elaborate(name const & n, expr const & t, expr const & v) { return ::lean::elaborate(m_env, m_ios, n, t, v); } diff --git a/src/frontends/lean/parser.h b/src/frontends/lean/parser.h index 8e88f2899e..34febfbdbe 100644 --- a/src/frontends/lean/parser.h +++ b/src/frontends/lean/parser.h @@ -263,6 +263,7 @@ public: struct no_undef_id_error_scope { parser & m_p; bool m_old; no_undef_id_error_scope(parser &); ~no_undef_id_error_scope(); }; expr elaborate(expr const & e); + expr elaborate(environment const & env, expr const & e); std::pair elaborate(name const & n, expr const & t, expr const & v); /** parse all commands in the input stream */ diff --git a/src/kernel/type_checker.h b/src/kernel/type_checker.h index cc90a1b6c2..42785478d0 100644 --- a/src/kernel/type_checker.h +++ b/src/kernel/type_checker.h @@ -81,6 +81,8 @@ public: type_checker(environment const & env); ~type_checker(); + environment const & env() const { return m_env; } + /** \brief Return the type of \c t. diff --git a/src/library/unifier.cpp b/src/library/unifier.cpp index f781898e4b..607aac7116 100644 --- a/src/library/unifier.cpp +++ b/src/library/unifier.cpp @@ -139,7 +139,7 @@ std::pair unify_simple(substitution const & s, expr } // Return true if m occurs in e -bool occurs(level const & m, level const & e) { +bool occurs_meta(level const & m, level const & e) { lean_assert(is_meta(m)); bool contains = false; for_each(e, [&](level const & l) { @@ -156,7 +156,7 @@ bool occurs(level const & m, level const & e) { std::pair unify_simple_core(substitution const & s, level const & lhs, level const & rhs, justification const & j) { lean_assert(is_meta(lhs)); - bool contains = occurs(lhs, rhs); + bool contains = occurs_meta(lhs, rhs); if (contains) { if (is_succ(rhs)) return mk_pair(unify_status::Failed, s); @@ -620,7 +620,7 @@ struct unifier_fn { status process_metavar_eq(level const & lhs, level const & rhs, justification const & j) { if (!is_meta(lhs)) return Continue; - bool contains = occurs(lhs, rhs); + bool contains = occurs_meta(lhs, rhs); if (contains) { if (is_succ(rhs)) return Failed; diff --git a/tests/lean/run/ind3.lean b/tests/lean/run/ind3.lean index e90350996f..27cc3c794e 100644 --- a/tests/lean/run/ind3.lean +++ b/tests/lean/run/ind3.lean @@ -1,6 +1,6 @@ inductive tree (A : Type) : Type := | node : A → forest A → tree A -with forest {A : Type} : Type := +with forest (A : Type) : Type := | nil : forest A | cons : tree A → forest A → forest A @@ -17,5 +17,3 @@ inductive group : Type := check group.{1} check group.{2} check group_rec.{1 1} - - diff --git a/tests/lean/run/ind6.lean b/tests/lean/run/ind6.lean index b091208dab..611dfc7cbe 100644 --- a/tests/lean/run/ind6.lean +++ b/tests/lean/run/ind6.lean @@ -1,6 +1,6 @@ inductive tree.{u} (A : Type.{u}) : Type.{max u 1} := | node : A → forest.{u} A → tree.{u} A -with forest.{u} {A : Type.{u}} : Type.{max u 1} := +with forest.{u} (A : Type.{u}) : Type.{max u 1} := | nil : forest.{u} A | cons : tree.{u} A → forest.{u} A → forest.{u} A