diff --git a/library/standard/data/list/basic.lean b/library/standard/data/list/basic.lean index ed72043ac3..f4b93ebf4d 100644 --- a/library/standard/data/list/basic.lean +++ b/library/standard/data/list/basic.lean @@ -90,9 +90,9 @@ list_induction_on s assume H : length (concat s t) = length s + length t, calc length (concat (cons x s) t ) = succ (length (concat s t)) : refl _ - ... = succ (length s + length t) : { H } - ... = succ (length s) + length t : {symm (add_succ_left _ _)} - ... = length (cons x s) + length t : refl _) + ... = succ (length s + length t) : { H } + ... = succ (length s) + length t : {symm (add_succ_left _ _)} + ... = length (cons x s) + length t : refl _) -- add_rewrite length_nil length_cons @@ -138,7 +138,7 @@ list_induction_on l (refl _) assume H: reverse (reverse l') = l', show reverse (reverse (x :: l')) = x :: l', from calc - reverse (reverse (x :: l')) = reverse (reverse l' ++ [x]) : refl _ + reverse (reverse (x :: l')) = reverse (reverse l' ++ [x]) : refl _ ... = reverse [x] ++ reverse (reverse l') : reverse_concat _ _ ... = [x] ++ l' : { H } ... = x :: l' : refl _) @@ -221,6 +221,9 @@ list_induction_on s theorem mem_concat (x : T) (s t : list T) : x ∈ s ++ t ↔ x ∈ s ∨ x ∈ t := iff_intro (mem_concat_imp_or _ _ _) (mem_or_imp_concat _ _ _) +section +set_option unifier.expensive true -- TODO(Leo): remove after we add delta-split step +#erase_cache mem_split theorem mem_split (x : T) (l : list T) : x ∈ l → ∃s t : list T, l = s ++ (x :: t) := list_induction_on l (take H : x ∈ nil, false_elim _ (iff_elim_left (mem_nil x) H)) @@ -235,8 +238,9 @@ list_induction_on l obtain s (H2 : ∃t : list T, l = s ++ (x :: t)), from IH H1, obtain t (H3 : l = s ++ (x :: t)), from H2, have H4 : y :: l = (y :: s) ++ (x :: t), - from trans (subst H3 (refl (y :: l))) (cons_concat _ _ _), + from subst H3 (refl (y :: l)), exists_intro _ (exists_intro _ H4))) +end -- Find -- ---- diff --git a/library/standard/logic/axioms/examples/diaconescu.lean b/library/standard/logic/axioms/examples/diaconescu.lean index c0081207ce..223d3014ce 100644 --- a/library/standard/logic/axioms/examples/diaconescu.lean +++ b/library/standard/logic/axioms/examples/diaconescu.lean @@ -32,6 +32,7 @@ or_elim u_def (assume Hp : p, or_inr Hp)) (assume Hp : p, or_inr Hp) +set_option unifier.expensive true lemma p_implies_uv [private] : p → u = v := assume Hp : p, have Hpred : (λ x, x = true ∨ p) = (λ x, x = false ∨ p), from diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 783bbadcaf..b91d8ea7cc 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -248,10 +248,6 @@ add_subdirectory(kernel/record) set(LEAN_LIBS ${LEAN_LIBS} record) add_subdirectory(library) set(LEAN_LIBS ${LEAN_LIBS} library) -# add_subdirectory(library/rewriter) -# set(LEAN_LIBS ${LEAN_LIBS} rewriter) -# add_subdirectory(library/simplifier) -# set(LEAN_LIBS ${LEAN_LIBS} simplifier) add_subdirectory(library/tactic) set(LEAN_LIBS ${LEAN_LIBS} tactic) add_subdirectory(library/error_handling) diff --git a/src/frontends/lean/builtin_cmds.cpp b/src/frontends/lean/builtin_cmds.cpp index fbb2fe4089..7a90cccb19 100644 --- a/src/frontends/lean/builtin_cmds.cpp +++ b/src/frontends/lean/builtin_cmds.cpp @@ -96,7 +96,7 @@ environment check_cmd(parser & p) { level_param_names new_ls; std::tie(e, new_ls) = p.elaborate_relaxed(e, ctx); auto tc = mk_type_checker_with_hints(p.env(), p.mk_ngen(), true); - expr type = tc->check(e, append(ls, new_ls)); + expr type = tc->check(e, append(ls, new_ls)).first; auto reg = p.regular_stream(); formatter const & fmt = reg.get_formatter(); options opts = p.ios().get_options(); diff --git a/src/frontends/lean/elaborator.cpp b/src/frontends/lean/elaborator.cpp index cb1dd07bb8..1384aba6e1 100644 --- a/src/frontends/lean/elaborator.cpp +++ b/src/frontends/lean/elaborator.cpp @@ -301,7 +301,6 @@ class elaborator { context m_context; // current local context: a list of local constants context m_full_context; // superset of m_context, it also contains non-contextual locals. - constraint_vect m_constraints; // constraints that must be solved for the elaborated term to be type correct. local_tactic_hints m_local_tactic_hints; // mapping from metavariable name ?m to tactic expression that should be used to solve it. // this mapping is populated by the 'by tactic-expr' expression. name_set m_displayed_errors; // set of metavariables that we already reported unsolved/unassigned @@ -315,35 +314,17 @@ class elaborator { scope_ctx(elaborator & e):m_scope1(e.m_context), m_scope2(e.m_full_context) {} }; - /** \brief Auxiliary object for creating backtracking points, and replacing the local scopes. - - \remark A new scope can only be created when m_constraints is empty. - */ + /** \brief Auxiliary object for creating backtracking points, and replacing the local scopes. */ struct new_scope { elaborator & m_main; context::scope_replace m_context_scope; context::scope_replace m_full_context_scope; new_scope(elaborator & e, list const & ctx, list const & full_ctx): m_main(e), m_context_scope(e.m_context, ctx), m_full_context_scope(e.m_full_context, full_ctx) { - lean_assert(m_main.m_constraints.empty()); - m_main.m_tc[0]->push(); - m_main.m_tc[1]->push(); - } - ~new_scope() { - m_main.m_tc[0]->pop(); - m_main.m_tc[1]->pop(); - m_main.m_constraints.clear(); - lean_assert(m_main.m_constraints.empty()); } + ~new_scope() {} }; - /* \brief Move all constraints generated by the type checker to the buffer m_constraints. */ - void consume_tc_cnstrs() { - for (unsigned i = 0; i < 2; i++) - while (auto c = m_tc[i]->next_cnstr()) - m_constraints.push_back(*c); - } - struct choice_elaborator { bool m_ignore_failure; choice_elaborator(bool ignore_failure = false):m_ignore_failure(ignore_failure) {} @@ -379,11 +360,10 @@ class elaborator { m_elab.save_identifier_info(f); try { new_scope s(m_elab, m_ctx, m_full_ctx); - expr r = m_elab.visit(c); - m_elab.consume_tc_cnstrs(); - list cs = to_list(m_elab.m_constraints.begin(), m_elab.m_constraints.end()); - cs = cons(mk_eq_cnstr(m_mvar, r, justification(), m_relax_main_opaque), cs); - return optional(cs); + pair rcs = m_elab.visit(c); + expr r = rcs.first; + constraint_seq cs = mk_eq_cnstr(m_mvar, r, justification(), m_relax_main_opaque) + rcs.second; + return optional(cs.to_list()); } catch (exception &) {} } return optional(); @@ -444,13 +424,11 @@ class elaborator { } try { new_scope s(m_elab, m_ctx, m_full_ctx); - expr r = m_elab.visit(pre); // use elaborator to create metavariables, levels, etc. - m_elab.consume_tc_cnstrs(); - for (auto & c : m_elab.m_constraints) - c = update_justification(c, mk_composite1(m_jst, c.get_justification())); - list cs = to_list(m_elab.m_constraints.begin(), m_elab.m_constraints.end()); - cs = cons(mk_eq_cnstr(m_meta, r, m_jst, m_relax_main_opaque), cs); - return optional(cs); + pair rcs = m_elab.visit(pre); // use elaborator to create metavariables, levels, etc. + expr r = rcs.first; + buffer cs; + to_buffer(rcs.second, m_jst, cs); + return optional(cons(mk_eq_cnstr(m_meta, r, m_jst, m_relax_main_opaque), to_list(cs.begin(), cs.end()))); } catch (exception &) { return optional(); } @@ -539,31 +517,53 @@ public: return ::lean::mk_local(m_ngen.next(), n, t, bi); } - expr infer_type(expr const & e) { - lean_assert(closed(e)); - return m_tc[m_relax_main_opaque]->infer(e); + pair infer_type(expr const & e) { return m_tc[m_relax_main_opaque]->infer(e); } + pair whnf(expr const & e) { return m_tc[m_relax_main_opaque]->whnf(e); } + expr infer_type(expr const & e, constraint_seq & s) { return m_tc[m_relax_main_opaque]->infer(e, s); } + expr whnf(expr const & e, constraint_seq & s) { return m_tc[m_relax_main_opaque]->whnf(e, s); } + + static expr save_tag(expr && e, tag g) { e.set_tag(g); return e; } + expr mk_app(expr const & f, expr const & a, tag g) { return save_tag(::lean::mk_app(f, a), g); } + + /** \brief Store the pair (pos(e), type(r)) in the info_data if m_info_manager is available. */ + void save_info_data(expr const & e, expr const & r) { + if (!m_noinfo && infom() && pip() && (is_constant(e) || is_local(e) || is_placeholder(e))) { + if (auto p = pip()->get_pos_info(e)) { + expr t = m_tc[m_relax_main_opaque]->infer(r).first; + m_pre_info_data.add_type_info(p->first, p->second, t); + } + } } - expr whnf(expr const & e) { - return m_tc[m_relax_main_opaque]->whnf(e); + /** \brief Auxiliary function for saving information about which overloaded identifier was used by the elaborator. */ + void save_identifier_info(expr const & f) { + if (!m_noinfo && infom() && pip() && is_constant(f)) { + if (auto p = pip()->get_pos_info(f)) + m_pre_info_data.add_identifier_info(p->first, p->second, const_name(f)); + } } - /** \brief Clear constraint buffer \c m_constraints */ - void clear_constraints() { - m_constraints.clear(); + /** \brief Store actual term that was synthesized for an explicit placeholders */ + void save_synth_data(expr const & e, expr const & r) { + if (!m_noinfo && infom() && pip() && is_placeholder(e)) { + if (auto p = pip()->get_pos_info(e)) + m_pre_info_data.add_synth_info(p->first, p->second, r); + } } - void add_cnstr(constraint const & c) { - m_constraints.push_back(c); + void save_placeholder_info(expr const & e, expr const & r) { + if (is_explicit_placeholder(e)) { + save_info_data(e, r); + save_synth_data(e, r); + } } - static expr save_tag(expr && e, tag g) { - e.set_tag(g); - return e; - } - - expr mk_app(expr const & f, expr const & a, tag g) { - return save_tag(::lean::mk_app(f, a), g); + void copy_info_to_manager(substitution s) { + if (!infom()) + return; + m_pre_info_data.instantiate(s); + infom()->merge(m_pre_info_data); + m_pre_info_data.clear(); } list get_class_instances(expr const & type) { @@ -598,7 +598,7 @@ public: return mk_justification(m, [=](formatter const & fmt, substitution const & subst) { substitution tmp(subst); expr new_m = instantiate_meta(m, tmp); - expr new_type = type_checker(_env).infer(new_m); + expr new_type = type_checker(_env).infer(new_m).first; proof_state ps(goals(goal(new_m, new_type)), substitution(), name_generator("dontcare")); return format({format("failed to synthesize placeholder"), line(), ps.pp(fmt)}); }); @@ -607,7 +607,7 @@ public: /** \brief Create a metavariable, and attach choice constraint for generating solutions using class-instances and tactic-hints. */ - expr mk_placeholder_meta(optional const & type, tag g, bool is_strict = false) { + expr mk_placeholder_meta(optional const & type, tag g, bool is_strict, constraint_seq & cs) { expr m = m_context.mk_meta(type, g); list ctx = m_context.get_data(); list full_ctx = m_full_context.get_data(); @@ -639,7 +639,7 @@ public: j, ignore_failure, m_relax_main_opaque)); } }; - add_cnstr(mk_choice_cnstr(m, choice_fn, to_delay_factor(cnstr_group::ClassInstance), false, j, m_relax_main_opaque)); + cs += mk_choice_cnstr(m, choice_fn, to_delay_factor(cnstr_group::ClassInstance), false, j, m_relax_main_opaque); return m; } @@ -653,38 +653,31 @@ public: return none_expr(); } - void save_placeholder_info(expr const & e, expr const & r) { - if (is_explicit_placeholder(e)) { - save_info_data(e, r); - save_synth_data(e, r); - } - } - - expr visit_expecting_type(expr const & e) { + expr visit_expecting_type(expr const & e, constraint_seq & cs) { if (is_placeholder(e) && !placeholder_type(e)) { expr r = m_context.mk_type_meta(e.get_tag()); save_placeholder_info(e, r); return r; } else { - return visit(e); + return visit(e, cs); } } - expr visit_expecting_type_of(expr const & e, expr const & t) { + expr visit_expecting_type_of(expr const & e, expr const & t, constraint_seq & cs) { if (is_placeholder(e) && !placeholder_type(e)) { - expr r = mk_placeholder_meta(some_expr(t), e.get_tag(), is_strict_placeholder(e)); + expr r = mk_placeholder_meta(some_expr(t), e.get_tag(), is_strict_placeholder(e), cs); save_placeholder_info(e, r); return r; } else if (is_choice(e)) { - return visit_choice(e, some_expr(t)); + return visit_choice(e, some_expr(t), cs); } else if (is_by(e)) { - return visit_by(e, some_expr(t)); + return visit_by(e, some_expr(t), cs); } else { - return visit(e); + return visit(e, cs); } } - expr visit_choice(expr const & e, optional const & t) { + expr visit_choice(expr const & e, optional const & t, constraint_seq & cs) { lean_assert(is_choice(e)); // Possible optimization: try to lookahead and discard some of the alternatives. expr m = m_full_context.mk_meta(t, e.get_tag()); @@ -695,13 +688,13 @@ public: return choose(std::make_shared(*this, mvar, e, ctx, full_ctx, relax)); }; justification j = mk_justification("none of the overloads is applicable", some_expr(e)); - add_cnstr(mk_choice_cnstr(m, fn, to_delay_factor(cnstr_group::Basic), true, j, m_relax_main_opaque)); + cs += mk_choice_cnstr(m, fn, to_delay_factor(cnstr_group::Basic), true, j, m_relax_main_opaque); return m; } - expr visit_by(expr const & e, optional const & t) { + expr visit_by(expr const & e, optional const & t, constraint_seq & cs) { lean_assert(is_by(e)); - expr tac = visit(get_by_arg(e)); + expr tac = visit(get_by_arg(e), cs); expr m = m_context.mk_meta(t, e.get_tag()); m_local_tactic_hints.insert(mlocal_name(get_app_fn(m)), tac); return m; @@ -711,15 +704,17 @@ public: The result is a pair new_f, f_type, where new_f is the new value for \c f, and \c f_type is its type (and a Pi-expression) */ - pair ensure_fun(expr f) { - expr f_type = infer_type(f); + pair ensure_fun(expr f, constraint_seq & cs) { + expr f_type = infer_type(f, cs); if (!is_pi(f_type)) - f_type = whnf(f_type); + f_type = whnf(f_type, cs); if (!is_pi(f_type) && has_metavar(f_type)) { - f_type = whnf(f_type); + constraint_seq saved_cs = cs; + f_type = whnf(f_type, cs); if (!is_pi(f_type) && is_meta(f_type)) { + cs = saved_cs; // let type checker add constraint - f_type = m_tc[m_relax_main_opaque]->ensure_pi(f_type, f); + f_type = m_tc[m_relax_main_opaque]->ensure_pi(f_type, f, cs); } } if (!is_pi(f_type)) { @@ -727,7 +722,7 @@ public: optional c = get_coercion_to_fun(env(), f_type); if (c) { f = mk_app(*c, f, f.get_tag()); - f_type = infer_type(f); + f_type = infer_type(f, cs); lean_assert(is_pi(f_type)); } else { throw_kernel_exception(env(), f, [=](formatter const & fmt) { return pp_function_expected(fmt, f); }); @@ -738,18 +733,18 @@ public: } bool has_coercions_from(expr const & a_type) { - expr const & a_cls = get_app_fn(whnf(a_type)); + expr const & a_cls = get_app_fn(whnf(a_type).first); return is_constant(a_cls) && ::lean::has_coercions_from(env(), const_name(a_cls)); } bool has_coercions_to(expr const & d_type) { - expr const & d_cls = get_app_fn(whnf(d_type)); + expr const & d_cls = get_app_fn(whnf(d_type).first); return is_constant(d_cls) && ::lean::has_coercions_to(env(), const_name(d_cls)); } expr apply_coercion(expr const & a, expr a_type, expr d_type) { - a_type = whnf(a_type); - d_type = whnf(d_type); + a_type = whnf(a_type).first; + d_type = whnf(d_type).first; expr const & d_cls = get_app_fn(d_type); if (is_constant(d_cls)) { if (auto c = get_coercion(env(), a_type, const_name(d_cls))) @@ -781,7 +776,7 @@ public: return lazy_list(constraints(mk_eq_cnstr(mvar, a, justification(), relax))); } } - buffer cs; + constraint_seq cs; new_a_type = tc.whnf(new_a_type, cs); if (is_meta(d_type)) { // case-split @@ -789,17 +784,15 @@ public: get_user_coercions(env(), new_a_type, alts); buffer r; // first alternative: no coercion - cs.push_back(mk_eq_cnstr(mvar, a, justification(), relax)); - r.push_back(to_list(cs.begin(), cs.end())); - cs.pop_back(); + constraint_seq cs1 = cs + mk_eq_cnstr(mvar, a, justification(), relax); + r.push_back(cs1.to_list()); unsigned i = alts.size(); while (i > 0) { --i; auto const & t = alts[i]; expr new_a = mk_app(std::get<1>(t), a, a.get_tag()); - cs.push_back(mk_eq_cnstr(mvar, new_a, new_a_type_jst, relax)); - r.push_back(to_list(cs.begin(), cs.end())); - cs.pop_back(); + constraint_seq csi = cs + mk_eq_cnstr(mvar, new_a, new_a_type_jst, relax); + r.push_back(csi.to_list()); } return to_lazy(to_list(r.begin(), r.end())); } else { @@ -810,46 +803,49 @@ public: if (auto c = get_coercion(env(), new_a_type, const_name(d_cls))) new_a = mk_app(*c, a, a.get_tag()); } - cs.push_back(mk_eq_cnstr(mvar, new_a, new_a_type_jst, relax)); - return lazy_list(to_list(cs.begin(), cs.end())); + cs += mk_eq_cnstr(mvar, new_a, new_a_type_jst, relax); + return lazy_list(cs.to_list()); } }; return mk_choice_cnstr(m, choice_fn, delay_factor, true, j, m_relax_main_opaque); } /** \brief Given a term a : a_type, and an expected type generate a metavariable with a delayed coercion. */ - expr mk_delayed_coercion(expr const & a, expr const & a_type, expr const & expected_type, justification const & j) { + pair mk_delayed_coercion(expr const & a, expr const & a_type, expr const & expected_type, justification const & j) { expr m = m_full_context.mk_meta(some_expr(expected_type), a.get_tag()); - add_cnstr(mk_delayed_coercion_cnstr(m, a, a_type, j, to_delay_factor(cnstr_group::Basic))); - return m; + return to_ecs(m, mk_delayed_coercion_cnstr(m, a, a_type, j, to_delay_factor(cnstr_group::Basic))); } /** \brief Given a term a : a_type, ensure it has type \c expected_type. Apply coercions if needed \remark relax == true affects how opaque definitions in the main module are treated. */ - expr ensure_type(expr const & a, expr const & a_type, expr const & expected_type, justification const & j, bool relax) { + pair ensure_has_type(expr const & a, expr const & a_type, expr const & expected_type, + justification const & j, bool relax) { if (is_meta(expected_type) && has_coercions_from(a_type)) { return mk_delayed_coercion(a, a_type, expected_type, j); } else if (is_meta(a_type) && has_coercions_to(expected_type)) { return mk_delayed_coercion(a, a_type, expected_type, j); - } else if (m_tc[relax]->is_def_eq(a_type, expected_type, j)) { - return a; } else { - expr new_a = apply_coercion(a, a_type, expected_type); - bool coercion_worked = false; - if (!is_eqp(a, new_a)) { - expr new_a_type = infer_type(new_a); - coercion_worked = m_tc[relax]->is_def_eq(new_a_type, expected_type, j); - } - if (coercion_worked) { - return new_a; - } else if (has_metavar(a_type) || has_metavar(expected_type)) { - // rely on unification hints to solve this constraint - add_cnstr(mk_eq_cnstr(a_type, expected_type, j, relax)); - return a; + auto dcs = m_tc[relax]->is_def_eq(a_type, expected_type, j); + if (dcs.first) { + return to_ecs(a, dcs.second); } else { - throw unifier_exception(j, substitution()); + expr new_a = apply_coercion(a, a_type, expected_type); + bool coercion_worked = false; + constraint_seq cs; + if (!is_eqp(a, new_a)) { + expr new_a_type = infer_type(new_a, cs); + coercion_worked = m_tc[relax]->is_def_eq(new_a_type, expected_type, j, cs); + } + if (coercion_worked) { + return to_ecs(new_a, cs); + } else if (has_metavar(a_type) || has_metavar(expected_type)) { + // rely on unification hints to solve this constraint + return to_ecs(a, mk_eq_cnstr(a_type, expected_type, j, relax)); + } else { + throw unifier_exception(j, substitution()); + } } } } @@ -862,7 +858,7 @@ public: /** \brief Process ((choice f_1 ... f_n) a_1 ... a_k) as (choice (f_1 a_1 ... a_k) ... (f_n a_1 ... a_k)) */ - expr visit_choice_app(expr const & e) { + expr visit_choice_app(expr const & e, constraint_seq & cs) { buffer args; expr f = get_app_rev_args(e, args); bool expl = is_explicit(f); @@ -877,46 +873,51 @@ public: f_i = copy_tag(f_i, mk_explicit(f_i)); new_choices.push_back(mk_rev_app(f_i, args)); } - return visit_choice(copy_tag(e, mk_choice(new_choices.size(), new_choices.data())), none_expr()); + return visit_choice(copy_tag(e, mk_choice(new_choices.size(), new_choices.data())), none_expr(), cs); } - expr visit_app(expr const & e) { + expr visit_app(expr const & e, constraint_seq & cs) { if (is_choice_app(e)) - return visit_choice_app(e); + return visit_choice_app(e, cs); + constraint_seq f_cs; bool expl = is_explicit(get_app_fn(e)); - expr f = visit(app_fn(e)); - auto f_t = ensure_fun(f); + expr f = visit(app_fn(e), f_cs); + auto f_t = ensure_fun(f, f_cs); f = f_t.first; expr f_type = f_t.second; lean_assert(is_pi(f_type)); if (!expl) { bool first = true; while (binding_info(f_type).is_strict_implicit() || (!first && binding_info(f_type).is_implicit())) { - tag g = f.get_tag(); - expr imp_arg = mk_placeholder_meta(some_expr(binding_domain(f_type)), g); - f = mk_app(f, imp_arg, g); - auto f_t = ensure_fun(f); - f = f_t.first; - f_type = f_t.second; - first = false; + tag g = f.get_tag(); + bool is_strict = false; + expr imp_arg = mk_placeholder_meta(some_expr(binding_domain(f_type)), g, is_strict, f_cs); + f = mk_app(f, imp_arg, g); + auto f_t = ensure_fun(f, f_cs); + f = f_t.first; + f_type = f_t.second; + first = false; } if (!first) { // we save the info data again for application of functions with strict implicit arguments save_info_data(get_app_fn(e), f); } } + constraint_seq a_cs; expr d_type = binding_domain(f_type); - expr a = visit_expecting_type_of(app_arg(e), d_type); - expr a_type = infer_type(a); + expr a = visit_expecting_type_of(app_arg(e), d_type, a_cs); + expr a_type = infer_type(a, a_cs); expr r = mk_app(f, a, e.get_tag()); justification j = mk_app_justification(r, a, d_type, a_type); - expr new_a = ensure_type(a, a_type, d_type, j, m_relax_main_opaque); + auto new_a_cs = ensure_has_type(a, a_type, d_type, j, m_relax_main_opaque); + expr new_a = new_a_cs.first; + cs += f_cs + new_a_cs.second + a_cs; return update_app(r, app_fn(r), new_a); } - expr visit_placeholder(expr const & e) { - expr r = mk_placeholder_meta(placeholder_type(e), e.get_tag(), is_strict_placeholder(e)); + expr visit_placeholder(expr const & e, constraint_seq & cs) { + expr r = mk_placeholder_meta(placeholder_type(e), e.get_tag(), is_strict_placeholder(e), cs); save_placeholder_info(e, r); return r; } @@ -934,7 +935,7 @@ public: return update_sort(e, replace_univ_placeholder(sort_level(e))); } - expr visit_macro(expr const & e) { + expr visit_macro(expr const & e, constraint_seq & cs) { if (is_as_is(e)) { return get_as_is_arg(e); } else { @@ -942,37 +943,11 @@ public: // Perhaps, we should throw error. buffer args; for (unsigned i = 0; i < macro_num_args(e); i++) - args.push_back(visit(macro_arg(e, i))); + args.push_back(visit(macro_arg(e, i), cs)); return update_macro(e, args.size(), args.data()); } } - /** \brief Store the pair (pos(e), type(r)) in the info_data if m_info_manager is available. */ - void save_info_data(expr const & e, expr const & r) { - if (!m_noinfo && infom() && pip() && (is_constant(e) || is_local(e) || is_placeholder(e))) { - if (auto p = pip()->get_pos_info(e)) { - type_checker::scope scope(*m_tc[m_relax_main_opaque]); - expr t = m_tc[m_relax_main_opaque]->infer(r); - m_pre_info_data.add_type_info(p->first, p->second, t); - } - } - } - - void save_identifier_info(expr const & f) { - if (!m_noinfo && infom() && pip() && is_constant(f)) { - if (auto p = pip()->get_pos_info(f)) - m_pre_info_data.add_identifier_info(p->first, p->second, const_name(f)); - } - } - - void save_synth_data(expr const & e, expr const & r) { - if (!m_noinfo && infom() && pip() && is_placeholder(e)) { - if (auto p = pip()->get_pos_info(e)) { - m_pre_info_data.add_synth_info(p->first, p->second, r); - } - } - } - expr visit_constant(expr const & e) { declaration d = env().get(const_name(e)); buffer ls; @@ -991,20 +966,20 @@ public: } /** \brief Make sure \c e is a type. If it is not, then try to apply coercions. */ - expr ensure_type(expr const & e) { - expr t = infer_type(e); + expr ensure_type(expr const & e, constraint_seq & cs) { + expr t = infer_type(e, cs); if (is_sort(t)) return e; - t = whnf(t); + t = whnf(t, cs); if (is_sort(t)) return e; if (has_metavar(t)) { - t = whnf(t); + t = whnf(t, cs); if (is_sort(t)) return e; if (is_meta(t)) { // let type checker add constraint - m_tc[m_relax_main_opaque]->ensure_sort(t, e); + m_tc[m_relax_main_opaque]->ensure_sort(t, e, cs); return e; } } @@ -1040,14 +1015,14 @@ public: }); } - expr visit_binding(expr e, expr_kind k) { + expr visit_binding(expr e, expr_kind k, constraint_seq & cs) { scope_ctx scope(*this); buffer ds, ls, es; while (e.kind() == k) { es.push_back(e); expr d = binding_domain(e); d = instantiate_rev_locals(d, ls.size(), ls.data()); - d = ensure_type(visit_expecting_type(d)); + d = ensure_type(visit_expecting_type(d, cs), cs); ds.push_back(d); expr l = mk_local(binding_name(e), d, binding_info(e)); if (binding_info(e).is_contextual()) @@ -1058,7 +1033,7 @@ public: } lean_assert(ls.size() == es.size() && ls.size() == ds.size()); e = instantiate_rev_locals(e, ls.size(), ls.data()); - e = (k == expr_kind::Pi) ? ensure_type(visit_expecting_type(e)) : visit(e); + e = (k == expr_kind::Pi) ? ensure_type(visit_expecting_type(e, cs), cs) : visit(e, cs); e = abstract_locals(e, ls.size(), ls.data()); unsigned i = ls.size(); while (i > 0) { @@ -1067,19 +1042,19 @@ public: } return e; } - expr visit_pi(expr const & e) { return visit_binding(e, expr_kind::Pi); } - expr visit_lambda(expr const & e) { return visit_binding(e, expr_kind::Lambda); } + expr visit_pi(expr const & e, constraint_seq & cs) { return visit_binding(e, expr_kind::Pi, cs); } + expr visit_lambda(expr const & e, constraint_seq & cs) { return visit_binding(e, expr_kind::Lambda, cs); } - expr visit_core(expr const & e) { + expr visit_core(expr const & e, constraint_seq & cs) { if (is_placeholder(e)) { - return visit_placeholder(e); + return visit_placeholder(e, cs); } else if (is_choice(e)) { - return visit_choice(e, none_expr()); + return visit_choice(e, none_expr(), cs); } else if (is_by(e)) { - return visit_by(e, none_expr()); + return visit_by(e, none_expr(), cs); } else if (is_noinfo(e)) { flet let(m_noinfo, true); - return visit(get_annotation_arg(e)); + return visit(get_annotation_arg(e), cs); } else { switch (e.kind()) { case expr_kind::Local: return e; @@ -1087,53 +1062,59 @@ public: case expr_kind::Sort: return visit_sort(e); case expr_kind::Var: lean_unreachable(); // LCOV_EXCL_LINE case expr_kind::Constant: return visit_constant(e); - case expr_kind::Macro: return visit_macro(e); - case expr_kind::Lambda: return visit_lambda(e); - case expr_kind::Pi: return visit_pi(e); - case expr_kind::App: return visit_app(e); + case expr_kind::Macro: return visit_macro(e, cs); + case expr_kind::Lambda: return visit_lambda(e, cs); + case expr_kind::Pi: return visit_pi(e, cs); + case expr_kind::App: return visit_app(e, cs); } lean_unreachable(); // LCOV_EXCL_LINE } } - expr visit(expr const & e) { + pair visit(expr const & e) { expr r; expr b = e; + constraint_seq cs; if (is_explicit(e)) { b = get_explicit_arg(e); - r = visit_core(get_explicit_arg(e)); + r = visit_core(get_explicit_arg(e), cs); } else if (is_explicit(get_app_fn(e))) { - r = visit_core(e); + r = visit_core(e, cs); } else { if (is_implicit(e)) { r = get_implicit_arg(e); if (is_explicit(r)) r = get_explicit_arg(r); b = r; - r = visit_core(r); + r = visit_core(r, cs); } else { - r = visit_core(e); + r = visit_core(e, cs); } if (!is_lambda(r)) { - tag g = e.get_tag(); - expr r_type = whnf(infer_type(r)); + tag g = e.get_tag(); + expr r_type = whnf(infer_type(r, cs), cs); expr imp_arg; + bool is_strict = false; while (is_pi(r_type) && binding_info(r_type).is_implicit()) { - imp_arg = mk_placeholder_meta(some_expr(binding_domain(r_type)), g); + imp_arg = mk_placeholder_meta(some_expr(binding_domain(r_type)), g, is_strict, cs); r = mk_app(r, imp_arg, g); - r_type = whnf(instantiate(binding_body(r_type), imp_arg)); + r_type = whnf(instantiate(binding_body(r_type), imp_arg), cs); } } } save_info_data(b, r); - return r; + return mk_pair(r, cs); } - lazy_list solve() { - consume_tc_cnstrs(); - buffer cs; - cs.append(m_constraints); - m_constraints.clear(); - return unify(env(), cs.size(), cs.data(), m_ngen.mk_child(), true, ios().get_options()); + expr visit(expr const & e, constraint_seq & cs) { + auto r = visit(e); + cs += r.second; + return r.first; + } + + lazy_list solve(constraint_seq const & cs) { + buffer tmp; + cs.linearize(tmp); + return unify(env(), tmp.size(), tmp.data(), m_ngen.mk_child(), true, ios().get_options()); } static void collect_metavars(expr const & e, buffer & mvars) { @@ -1245,7 +1226,8 @@ public: if (!meta) return; meta = instantiate_meta(*meta, subst); - expr type = m_tc[m_relax_main_opaque]->infer(*meta); + // TODO(Leo): we are discarding constraints here + expr type = m_tc[m_relax_main_opaque]->infer(*meta).first; // first solve unassigned metavariables in type type = solve_unassigned_mvars(subst, type, visited); proof_state ps(goals(goal(*meta, type)), subst, m_ngen.mk_child()); @@ -1279,7 +1261,7 @@ public: return has_metavar(e); if (auto it = m_mvar2meta.find(mlocal_name(e))) { expr meta = tmp_s.instantiate(*it); - expr meta_type = tmp_s.instantiate(type_checker(env()).infer(meta)); + expr meta_type = tmp_s.instantiate(type_checker(env()).infer(meta).first); goal g(meta, meta_type); display_unsolved_proof_state(e, proof_state(goals(g), substitution(), m_ngen), "don't know how to synthesize it"); @@ -1306,20 +1288,13 @@ public: return std::make_tuple(r, to_list(new_ps.begin(), new_ps.end())); } - void copy_info_to_manager(substitution s) { - if (!infom()) - return; - m_pre_info_data.instantiate(s); - infom()->merge(m_pre_info_data); - m_pre_info_data.clear(); - } - std::tuple operator()(expr const & e, bool _ensure_type, bool relax_main_opaque) { flet set_relax(m_relax_main_opaque, relax_main_opaque && !get_hide_main_opaque(env())); - expr r = visit(e); + constraint_seq cs; + expr r = visit(e, cs); if (_ensure_type) - r = ensure_type(r); - auto p = solve().pull(); + r = ensure_type(r, cs); + auto p = solve(cs).pull(); lean_assert(p); substitution s = p->first; auto result = apply(s, r); @@ -1329,17 +1304,21 @@ public: std::tuple operator()(expr const & t, expr const & v, name const & n, bool is_opaque) { lean_assert(!has_local(t)); lean_assert(!has_local(v)); - expr r_t = ensure_type(visit(t)); + constraint_seq t_cs; + expr r_t = ensure_type(visit(t, t_cs), t_cs); // Opaque definitions in the main module may treat other opaque definitions (in the main module) as transparent. flet set_relax(m_relax_main_opaque, is_opaque && !get_hide_main_opaque(env())); - expr r_v = visit(v); - expr r_v_type = infer_type(r_v); + constraint_seq v_cs; + expr r_v = visit(v, v_cs); + expr r_v_type = infer_type(r_v, v_cs); justification j = mk_justification(r_v, [=](formatter const & fmt, substitution const & subst) { substitution s(subst); return pp_def_type_mismatch(fmt, n, s.instantiate(r_t), s.instantiate(r_v_type)); }); - r_v = ensure_type(r_v, r_v_type, r_t, j, is_opaque); - auto p = solve().pull(); + pair r_v_cs = ensure_has_type(r_v, r_v_type, r_t, j, is_opaque); + r_v = r_v_cs.first; + constraint_seq cs = t_cs + r_v_cs.second + v_cs; + auto p = solve(cs).pull(); lean_assert(p); substitution s = p->first; name_set univ_params = collect_univ_params(r_v, collect_univ_params(r_t)); diff --git a/src/frontends/lean/inductive_cmd.cpp b/src/frontends/lean/inductive_cmd.cpp index f8b41bb41a..ece4aa78af 100644 --- a/src/frontends/lean/inductive_cmd.cpp +++ b/src/frontends/lean/inductive_cmd.cpp @@ -174,9 +174,9 @@ struct inductive_cmd_fn { /** \brief Return the universe level of the given type, if it is not a sort, then raise an exception. */ level get_datatype_result_level(expr d_type) { - d_type = m_tc->whnf(d_type); + d_type = m_tc->whnf(d_type).first; while (is_pi(d_type)) { - d_type = m_tc->whnf(binding_body(d_type)); + d_type = m_tc->whnf(binding_body(d_type)).first; } if (!is_sort(d_type)) throw_error(sstream() << "invalid inductive datatype, resultant type is not a sort"); @@ -185,7 +185,7 @@ struct inductive_cmd_fn { /** \brief Update the result sort of the given type */ expr update_result_sort(expr t, level const & l) { - t = m_tc->whnf(t); + t = m_tc->whnf(t).first; if (is_pi(t)) { return update_binding(t, binding_domain(t), update_result_sort(binding_body(t), l)); } else if (is_sort(t)) { @@ -215,7 +215,7 @@ struct inductive_cmd_fn { /** \brief Check if the parameters of \c d_type and \c first_d_type are equal. */ void check_params(expr d_type, expr first_d_type) { for (unsigned i = 0; i < m_num_params; i++) { - if (!m_tc->is_def_eq(binding_domain(d_type), binding_domain(first_d_type))) + if (!m_tc->is_def_eq(binding_domain(d_type), binding_domain(first_d_type)).first) throw_error(sstream() << "invalid parameter #" << (i+1) << " in mutually recursive inductive declaration, " << "all inductive types must have equivalent parameters"); expr l = mk_local_for(d_type); @@ -418,7 +418,7 @@ struct inductive_cmd_fn { unsigned i = 0; while (is_pi(intro_type)) { if (i >= m_num_params) { - expr s = m_tc->ensure_type(binding_domain(intro_type)); + expr s = m_tc->ensure_type(binding_domain(intro_type)).first; level l = sort_level(s); if (l == m_u) { // ignore, this is the auxiliary level diff --git a/src/frontends/lean/pp.cpp b/src/frontends/lean/pp.cpp index 0c7f38a8c4..9856ebb3a1 100644 --- a/src/frontends/lean/pp.cpp +++ b/src/frontends/lean/pp.cpp @@ -131,7 +131,7 @@ bool pretty_fn::is_implicit(expr const & f) { if (m_implict) return false; // showing implicit arguments try { - binder_info bi = binding_info(m_tc.ensure_pi(m_tc.infer(f))); + binder_info bi = binding_info(m_tc.ensure_pi(m_tc.infer(f).first).first); return bi.is_implicit() || bi.is_strict_implicit(); } catch (...) { return false; @@ -140,7 +140,7 @@ bool pretty_fn::is_implicit(expr const & f) { bool pretty_fn::is_prop(expr const & e) { try { - return m_env.impredicative() && m_tc.is_prop(e); + return m_env.impredicative() && m_tc.is_prop(e).first; } catch (...) { return false; } diff --git a/src/frontends/lean/proof_qed_ext.cpp b/src/frontends/lean/proof_qed_ext.cpp index 36dbbab613..d344198a29 100644 --- a/src/frontends/lean/proof_qed_ext.cpp +++ b/src/frontends/lean/proof_qed_ext.cpp @@ -66,7 +66,7 @@ typedef scoped_ext proof_qed_ext; static void check_valid_tactic(environment const & env, expr const & pre_tac) { type_checker tc(env); - if (!tc.is_def_eq(tc.infer(pre_tac), get_tactic_type())) + if (!tc.is_def_eq(tc.infer(pre_tac).first, get_tactic_type()).first) throw exception("invalid proof-qed pre-tactic update, argument is not a tactic"); } diff --git a/src/frontends/lean/structure_cmd.cpp b/src/frontends/lean/structure_cmd.cpp index bd09ca6a7e..5897b49752 100644 --- a/src/frontends/lean/structure_cmd.cpp +++ b/src/frontends/lean/structure_cmd.cpp @@ -239,7 +239,7 @@ struct structure_cmd_fn { unsigned i = 0; while (is_pi(intro_type)) { if (i >= num_params) { - expr s = tc->ensure_type(binding_domain(intro_type)); + expr s = tc->ensure_type(binding_domain(intro_type)).first; level l = sort_level(s); if (l == m_u) { // ignore, this is the auxiliary level diff --git a/src/kernel/constraint.cpp b/src/kernel/constraint.cpp index bf8adf691c..2aba36bd68 100644 --- a/src/kernel/constraint.cpp +++ b/src/kernel/constraint.cpp @@ -5,7 +5,11 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #include "util/rc.h" +#include "kernel/expr.h" +#include "kernel/justification.h" +#include "kernel/metavar.h" #include "kernel/constraint.h" + namespace lean { struct constraint_cell { void dealloc(); @@ -104,6 +108,12 @@ constraint update_justification(constraint const & c, justification const & j) { lean_unreachable(); // LCOV_EXCL_LINE } +void to_buffer(constraint_seq const & cs, justification const & j, buffer & r) { + return cs.for_each([&](constraint const & c) { + r.push_back(update_justification(c, mk_composite1(c.get_justification(), j))); + }); +} + std::ostream & operator<<(std::ostream & out, constraint const & c) { switch (c.kind()) { case constraint_kind::Eq: diff --git a/src/kernel/constraint.h b/src/kernel/constraint.h index f88b089741..4af0d564d8 100644 --- a/src/kernel/constraint.h +++ b/src/kernel/constraint.h @@ -9,11 +9,13 @@ Author: Leonardo de Moura #include "util/lazy_list.h" #include "util/list.h" #include "util/name_generator.h" -#include "kernel/expr.h" -#include "kernel/justification.h" -#include "kernel/metavar.h" +#include "util/sequence.h" +#include "kernel/level.h" namespace lean { +class expr; +class justification; +class substitution; /** \brief The lean kernel type checker produces two kinds of constraints: @@ -117,6 +119,11 @@ unsigned cnstr_delay_factor(constraint const & c); /** \brief Return true iff the given choice constraints owns the right to assign the metavariable in \c c. */ bool cnstr_is_owner(constraint const & c); +typedef sequence constraint_seq; +inline constraint_seq empty_cs() { return constraint_seq(); } +/** \brief Copy constraints in cs to r, and append justification j to them. */ +void to_buffer(constraint_seq const & cs, justification const & j, buffer & r); + /** \brief Printer for debugging purposes */ std::ostream & operator<<(std::ostream & out, constraint const & c); } diff --git a/src/kernel/converter.cpp b/src/kernel/converter.cpp index 4b54fe5dc5..c7cf8399cc 100644 --- a/src/kernel/converter.cpp +++ b/src/kernel/converter.cpp @@ -43,7 +43,8 @@ bool is_opaque(declaration const & d, name_set const & extra_opaque, optional is_delta_core(environment const & env, expr const & e, name_set const & extra_opaque, optional const & mod_idx) { +static optional is_delta_core(environment const & env, expr const & e, name_set const & extra_opaque, + optional const & mod_idx) { if (is_constant(e)) { if (auto d = env.find(const_name(e))) if (d->is_definition() && !is_opaque(*d, extra_opaque, mod_idx)) @@ -66,14 +67,18 @@ optional is_delta(environment const & env, expr const & e, name_set } static no_delayed_justification g_no_delayed_jst; -bool converter::is_def_eq(expr const & t, expr const & s, type_checker & c) { +pair converter::is_def_eq(expr const & t, expr const & s, type_checker & c) { return is_def_eq(t, s, c, g_no_delayed_jst); } /** \brief Do nothing converter */ struct dummy_converter : public converter { - virtual expr whnf(expr const & e, type_checker &) { return e; } - virtual bool is_def_eq(expr const &, expr const &, type_checker &, delayed_justification &) { return true; } + virtual pair whnf(expr const & e, type_checker &) { + return mk_pair(e, constraint_seq()); + } + virtual pair is_def_eq(expr const &, expr const &, type_checker &, delayed_justification &) { + return mk_pair(true, constraint_seq()); + } virtual optional get_module_idx() const { return optional(); } }; @@ -82,18 +87,17 @@ std::unique_ptr mk_dummy_converter() { } name converter::mk_fresh_name(type_checker & tc) { return tc.mk_fresh_name(); } -expr converter::infer_type(type_checker & tc, expr const & e) { return tc.infer_type(e); } -void converter::add_cnstr(type_checker & tc, constraint const & c) { return tc.add_cnstr(c); } +pair converter::infer_type(type_checker & tc, expr const & e) { return tc.infer_type(e); } extension_context & converter::get_extension(type_checker & tc) { return tc.get_extension(); } static expr g_dont_care(Const("dontcare")); struct default_converter : public converter { - environment m_env; - optional m_module_idx; - bool m_memoize; - name_set m_extra_opaque; - expr_struct_map m_whnf_core_cache; - expr_struct_map m_whnf_cache; + environment m_env; + optional m_module_idx; + bool m_memoize; + name_set m_extra_opaque; + expr_struct_map m_whnf_core_cache; + expr_struct_map> m_whnf_cache; default_converter(environment const & env, optional mod_idx, bool memoize, name_set const & extra_opaque): m_env(env), m_module_idx(mod_idx), m_memoize(memoize), m_extra_opaque(extra_opaque) { @@ -109,10 +113,19 @@ struct default_converter : public converter { } /** \brief Apply normalizer extensions to \c e. */ - optional norm_ext(expr const & e, type_checker & c) { + optional> norm_ext(expr const & e, type_checker & c) { return m_env.norm_ext()(e, get_extension(c)); } + optional d_norm_ext(expr const & e, type_checker & c, constraint_seq & cs) { + if (auto r = norm_ext(e, c)) { + cs = cs + r->second; + return some_expr(r->first); + } else { + return none_expr(); + } + } + /** \brief Return true if \c e may be reduced later after metavariables are instantiated. */ bool may_reduce_later(expr const & e, type_checker & c) { return m_env.norm_ext().may_reduce_later(e, get_extension(c)); @@ -260,11 +273,11 @@ struct default_converter : public converter { } /** \brief Put expression \c t in weak head normal form */ - virtual expr whnf(expr const & e_prime, type_checker & c) { + virtual pair whnf(expr const & e_prime, type_checker & c) { // Do not cache easy cases switch (e_prime.kind()) { case expr_kind::Var: case expr_kind::Sort: case expr_kind::Meta: case expr_kind::Local: case expr_kind::Pi: - return e_prime; + return to_ecs(e_prime); case expr_kind::Lambda: case expr_kind::Macro: case expr_kind::App: case expr_kind::Constant: break; } @@ -278,19 +291,30 @@ struct default_converter : public converter { } expr t = e; + constraint_seq cs; while (true) { expr t1 = whnf_core(t, 0, c); - auto new_t = norm_ext(t1, c); - if (new_t) { - t = *new_t; + if (auto new_t = d_norm_ext(t1, c, cs)) { + t = *new_t; } else { + auto r = mk_pair(t1, cs); if (m_memoize) - m_whnf_cache.insert(mk_pair(e, t1)); - return t1; + m_whnf_cache.insert(mk_pair(e, r)); + return r; } } } + expr whnf(expr const & e_prime, type_checker & c, constraint_seq & cs) { + auto r = whnf(e_prime, c); + cs = cs + r.second; + return r.first; + } + + pair to_bcs(bool b) { return mk_pair(b, constraint_seq()); } + pair to_bcs(bool b, constraint const & c) { return mk_pair(b, constraint_seq(c)); } + pair to_bcs(bool b, constraint_seq const & cs) { return mk_pair(b, cs); } + /** \brief Given lambda/Pi expressions \c t and \c s, return true iff \c t is def eq to \c s. @@ -300,7 +324,7 @@ struct default_converter : public converter { and body(t) is definitionally equal to body(s) */ - bool is_def_eq_binding(expr t, expr s, type_checker & c, delayed_justification & jst) { + bool is_def_eq_binding(expr t, expr s, type_checker & c, delayed_justification & jst, constraint_seq & cs) { lean_assert(t.kind() == s.kind()); lean_assert(is_binding(t)); expr_kind k = t.kind(); @@ -310,7 +334,7 @@ struct default_converter : public converter { if (binding_domain(t) != binding_domain(s)) { var_s_type = instantiate_rev(binding_domain(s), subst.size(), subst.data()); expr var_t_type = instantiate_rev(binding_domain(t), subst.size(), subst.data()); - if (!is_def_eq(var_t_type, *var_s_type, c, jst)) + if (!is_def_eq(var_t_type, *var_s_type, c, jst, cs)) return false; } if (!closed(binding_body(t)) || !closed(binding_body(s))) { @@ -325,44 +349,53 @@ struct default_converter : public converter { s = binding_body(s); } while (t.kind() == k && s.kind() == k); return is_def_eq(instantiate_rev(t, subst.size(), subst.data()), - instantiate_rev(s, subst.size(), subst.data()), c, jst); + instantiate_rev(s, subst.size(), subst.data()), c, jst, cs); } - bool is_def_eq(level const & l1, level const & l2, type_checker & c, delayed_justification & jst) { + bool is_def_eq(level const & l1, level const & l2, delayed_justification & jst, constraint_seq & cs) { if (is_equivalent(l1, l2)) { return true; } else if (has_meta(l1) || has_meta(l2)) { - add_cnstr(c, mk_level_eq_cnstr(l1, l2, jst.get())); + cs = cs + constraint_seq(mk_level_eq_cnstr(l1, l2, jst.get())); return true; } else { return false; } } - bool is_def_eq(levels const & ls1, levels const & ls2, type_checker & c, delayed_justification & jst) { - if (is_nil(ls1) && is_nil(ls2)) + bool is_def_eq(levels const & ls1, levels const & ls2, type_checker & c, delayed_justification & jst, constraint_seq & cs) { + if (is_nil(ls1) && is_nil(ls2)) { return true; - else if (!is_nil(ls1) && !is_nil(ls2)) - return is_def_eq(head(ls1), head(ls2), c, jst) && is_def_eq(tail(ls1), tail(ls2), c, jst); - else + } else if (!is_nil(ls1) && !is_nil(ls2)) { + return + is_def_eq(head(ls1), head(ls2), jst, cs) && + is_def_eq(tail(ls1), tail(ls2), c, jst, cs); + } else { return false; + } + } + + static pair to_lbcs(lbool l) { return mk_pair(l, constraint_seq()); } + static pair to_lbcs(lbool l, constraint const & c) { return mk_pair(l, constraint_seq(c)); } + static pair to_lbcs(pair const & bcs) { + return mk_pair(to_lbool(bcs.first), bcs.second); } /** \brief This is an auxiliary method for is_def_eq. It handles the "easy cases". */ - lbool quick_is_def_eq(expr const & t, expr const & s, type_checker & c, delayed_justification & jst) { + lbool quick_is_def_eq(expr const & t, expr const & s, type_checker & c, delayed_justification & jst, constraint_seq & cs) { if (t == s) return l_true; // t and s are structurally equal if (is_meta(t) || is_meta(s)) { // if t or s is a metavariable (or the application of a metavariable), then add constraint - add_cnstr(c, mk_eq_cnstr(t, s, jst.get())); + cs = cs + constraint_seq(mk_eq_cnstr(t, s, jst.get())); return l_true; } if (t.kind() == s.kind()) { switch (t.kind()) { case expr_kind::Lambda: case expr_kind::Pi: - return to_lbool(is_def_eq_binding(t, s, c, jst)); + return to_lbool(is_def_eq_binding(t, s, c, jst, cs)); case expr_kind::Sort: - return to_lbool(is_def_eq(sort_level(t), sort_level(s), c, jst)); + return to_lbool(is_def_eq(sort_level(t), sort_level(s), c, jst, cs)); case expr_kind::Meta: lean_unreachable(); // LCOV_EXCL_LINE case expr_kind::Var: case expr_kind::Local: case expr_kind::App: @@ -378,9 +411,9 @@ struct default_converter : public converter { \brief Return true if arguments of \c t are definitionally equal to arguments of \c s. This method is used to implement an optimization in the method \c is_def_eq. */ - bool is_def_eq_args(expr t, expr s, type_checker & c, delayed_justification & jst) { + bool is_def_eq_args(expr t, expr s, type_checker & c, delayed_justification & jst, constraint_seq & cs) { while (is_app(t) && is_app(s)) { - if (!is_def_eq(app_arg(t), app_arg(s), c, jst)) + if (!is_def_eq(app_arg(t), app_arg(s), c, jst, cs)) return false; t = app_fn(t); s = app_fn(s); @@ -395,33 +428,47 @@ struct default_converter : public converter { } /** \brief Try to solve (fun (x : A), B) =?= s by trying eta-expansion on s */ - bool try_eta_expansion(expr const & t, expr const & s, type_checker & c, delayed_justification & jst) { + bool try_eta_expansion(expr const & t, expr const & s, type_checker & c, delayed_justification & jst, constraint_seq & cs) { if (is_lambda(t) && !is_lambda(s)) { - type_checker::scope scope(c); - expr s_type = whnf(infer_type(c, s), c); + auto tcs = infer_type(c, s); + auto wcs = whnf(tcs.first, c); + expr s_type = wcs.first; if (!is_pi(s_type)) return false; - expr new_s = mk_lambda(binding_name(s_type), binding_domain(s_type), mk_app(s, Var(0)), binding_info(s_type)); - bool r = is_def_eq(t, new_s, c, jst); - if (r) scope.keep(); - return r; + expr new_s = mk_lambda(binding_name(s_type), binding_domain(s_type), mk_app(s, Var(0)), binding_info(s_type)); + auto dcs = is_def_eq(t, new_s, c, jst); + if (!dcs.first) + return false; + cs = cs + dcs.second + wcs.second + tcs.second; + return true; + } else { + return false; + } + } + + bool is_def_eq(expr const & t, expr const & s, type_checker & c, delayed_justification & jst, constraint_seq & cs) { + auto bcs = is_def_eq(t, s, c, jst); + if (bcs.first) { + cs = cs + bcs.second; + return true; } else { return false; } } /** Return true iff t is definitionally equal to s. */ - virtual bool is_def_eq(expr const & t, expr const & s, type_checker & c, delayed_justification & jst) { + virtual pair is_def_eq(expr const & t, expr const & s, type_checker & c, delayed_justification & jst) { check_system("is_definitionally_equal"); - lbool r = quick_is_def_eq(t, s, c, jst); - if (r != l_undef) return r == l_true; + constraint_seq cs; + lbool r = quick_is_def_eq(t, s, c, jst, cs); + if (r != l_undef) return to_bcs(r == l_true, cs); // apply whnf (without using delta-reduction or normalizer extensions) expr t_n = whnf_core(t, c); expr s_n = whnf_core(s, c); if (!is_eqp(t_n, t) || !is_eqp(s_n, s)) { - r = quick_is_def_eq(t_n, s_n, c, jst); - if (r != l_undef) return r == l_true; + r = quick_is_def_eq(t_n, s_n, c, jst, cs); + if (r != l_undef) return to_bcs(r == l_true, cs); } // lazy delta-reduction and then normalizer extensions @@ -447,6 +494,7 @@ struct default_converter : public converter { if (has_expr_metavar(t_n) || has_expr_metavar(s_n)) { // We let the unifier deal with cases such as // (f ...) =?= (f ...) + // when t_n or s_n contains metavariables break; } else { // Optimization: @@ -454,41 +502,37 @@ struct default_converter : public converter { // If they are, then t_n and s_n must be definitionally equal, and we can // skip the delta-reduction step. // If the flag use_conv_opt() is not true, then we skip this optimization - if (!is_opaque(*d_t) && d_t->use_conv_opt()) { - type_checker::scope scope(c); - if (is_def_eq_args(t_n, s_n, c, jst)) { - scope.keep(); - return true; - } - } + if (!is_opaque(*d_t) && d_t->use_conv_opt() && + is_def_eq_args(t_n, s_n, c, jst, cs)) + return to_bcs(true, cs); } } t_n = whnf_core(unfold_names(t_n, d_t->get_weight() - 1), c); s_n = whnf_core(unfold_names(s_n, d_s->get_weight() - 1), c); } - r = quick_is_def_eq(t_n, s_n, c, jst); - if (r != l_undef) return r == l_true; + r = quick_is_def_eq(t_n, s_n, c, jst, cs); + if (r != l_undef) return to_bcs(r == l_true, cs); } // try normalizer extensions - auto new_t_n = norm_ext(t_n, c); - auto new_s_n = norm_ext(s_n, c); + auto new_t_n = d_norm_ext(t_n, c, cs); + auto new_s_n = d_norm_ext(s_n, c, cs); if (!new_t_n && !new_s_n) break; // t_n and s_n are in weak head normal form if (new_t_n) t_n = whnf_core(*new_t_n, c); if (new_s_n) s_n = whnf_core(*new_s_n, c); - r = quick_is_def_eq(t_n, s_n, c, jst); - if (r != l_undef) return r == l_true; + r = quick_is_def_eq(t_n, s_n, c, jst, cs); + if (r != l_undef) return to_bcs(r == l_true, cs); } if (is_constant(t_n) && is_constant(s_n) && const_name(t_n) == const_name(s_n) && - is_def_eq(const_levels(t_n), const_levels(s_n), c, jst)) - return true; + is_def_eq(const_levels(t_n), const_levels(s_n), c, jst, cs)) + return to_bcs(true, cs); if (is_local(t_n) && is_local(s_n) && mlocal_name(t_n) == mlocal_name(s_n) && - is_def_eq(mlocal_type(t_n), mlocal_type(s_n), c, jst)) - return true; + is_def_eq(mlocal_type(t_n), mlocal_type(s_n), c, jst, cs)) + return to_bcs(true, cs); optional d_t, d_s; bool delay_check = false; @@ -503,66 +547,69 @@ struct default_converter : public converter { // At this point, t_n and s_n are in weak head normal form (modulo meta-variables and proof irrelevance) if (!delay_check && is_app(t_n) && is_app(s_n)) { - type_checker::scope scope(c); buffer t_args; buffer s_args; expr t_fn = get_app_args(t_n, t_args); expr s_fn = get_app_args(s_n, s_args); - if (is_def_eq(t_fn, s_fn, c, jst) && t_args.size() == s_args.size()) { + constraint_seq cs_prime = cs; + if (is_def_eq(t_fn, s_fn, c, jst, cs_prime) && t_args.size() == s_args.size()) { unsigned i = 0; for (; i < t_args.size(); i++) { - if (!is_def_eq(t_args[i], s_args[i], c, jst)) + if (!is_def_eq(t_args[i], s_args[i], c, jst, cs_prime)) break; } if (i == t_args.size()) { - scope.keep(); - return true; + return to_bcs(true, cs_prime); } } } - if (try_eta_expansion(t_n, s_n, c, jst) || - try_eta_expansion(s_n, t_n, c, jst)) - return true; + if (try_eta_expansion(t_n, s_n, c, jst, cs) || + try_eta_expansion(s_n, t_n, c, jst, cs)) + return to_bcs(true, cs); if (m_env.prop_proof_irrel()) { // Proof irrelevance support for Prop (aka Type.{0}) - type_checker::scope scope(c); - expr t_type = infer_type(c, t); - if (is_prop(t_type, c) && is_def_eq(t_type, infer_type(c, s), c, jst)) { - scope.keep(); - return true; + auto tcs = infer_type(c, t); + expr t_type = tcs.first; + auto pcs = is_prop(t_type, c); + if (pcs.first) { + auto scs = infer_type(c, s); + auto dcs = is_def_eq(t_type, scs.first, c, jst); + if (dcs.first) + return to_bcs(true, dcs.second + scs.second + pcs.second + tcs.second); } } list const & cls_proof_irrel = m_env.cls_proof_irrel(); if (!is_nil(cls_proof_irrel)) { // Proof irrelevance support for classes - type_checker::scope scope(c); - expr t_type = whnf(infer_type(c, t), c); - if (std::any_of(cls_proof_irrel.begin(), cls_proof_irrel.end(), - [&](name const & cls_name) { return is_app_of(t_type, cls_name); }) && - is_def_eq(t_type, infer_type(c, s), c, jst)) { - scope.keep(); - return true; + auto tcs = infer_type(c, t); + auto wcs = whnf(tcs.first, c); + expr t_type = wcs.first; + if (std::any_of(cls_proof_irrel.begin(), cls_proof_irrel.end(), [&](name const & cls_name) { return is_app_of(t_type, cls_name); })) { + auto ccs = infer_type(c, s); + auto cs_prime = tcs.second + wcs.second + ccs.second; + if (is_def_eq(t_type, ccs.first, c, jst, cs_prime)) + return to_bcs(true, cs_prime); } } - if (may_reduce_later(t_n, c) || may_reduce_later(s_n, c)) { - add_cnstr(c, mk_eq_cnstr(t_n, s_n, jst.get())); - return true; + if (may_reduce_later(t_n, c) || may_reduce_later(s_n, c) || delay_check) { + cs = cs + constraint_seq(mk_eq_cnstr(t_n, s_n, jst.get())); + return to_bcs(true, cs); } - if (delay_check) { - add_cnstr(c, mk_eq_cnstr(t_n, s_n, jst.get())); - return true; - } - - return false; + return to_bcs(false); } - bool is_prop(expr const & e, type_checker & c) { - return whnf(infer_type(c, e), c) == Prop; + pair is_prop(expr const & e, type_checker & c) { + auto tcs = infer_type(c, e); + auto wcs = whnf(tcs.first, c); + if (wcs.first == Prop) + return to_bcs(true, wcs.second + tcs.second); + else + return to_bcs(false); } virtual optional get_module_idx() const { diff --git a/src/kernel/converter.h b/src/kernel/converter.h index 1ebd8b42a0..5cbea013a0 100644 --- a/src/kernel/converter.h +++ b/src/kernel/converter.h @@ -13,15 +13,14 @@ class type_checker; class converter { protected: name mk_fresh_name(type_checker & tc); - expr infer_type(type_checker & tc, expr const & e); - void add_cnstr(type_checker & tc, constraint const & c); + pair infer_type(type_checker & tc, expr const & e); extension_context & get_extension(type_checker & tc); public: virtual ~converter() {} - virtual expr whnf(expr const & e, type_checker & c) = 0; - virtual bool is_def_eq(expr const & t, expr const & s, type_checker & c, delayed_justification & j) = 0; + virtual pair whnf(expr const & e, type_checker & c) = 0; + virtual pair is_def_eq(expr const & t, expr const & s, type_checker & c, delayed_justification & j) = 0; virtual optional get_module_idx() const = 0; - bool is_def_eq(expr const & t, expr const & s, type_checker & c); + pair is_def_eq(expr const & t, expr const & s, type_checker & c); }; std::unique_ptr mk_dummy_converter(); diff --git a/src/kernel/expr.h b/src/kernel/expr.h index 454ceca973..631d9eeb1d 100644 --- a/src/kernel/expr.h +++ b/src/kernel/expr.h @@ -358,7 +358,9 @@ public: macro_definition & operator=(macro_definition && s); name get_name() const { return m_ptr->get_name(); } - expr get_type(expr const & m, expr const * arg_types, extension_context & ctx) const { return m_ptr->get_type(m, arg_types, ctx); } + expr get_type(expr const & m, expr const * arg_types, extension_context & ctx) const { + return m_ptr->get_type(m, arg_types, ctx); + } optional expand(expr const & m, extension_context & ctx) const { return m_ptr->expand(m, ctx); } optional expand1(expr const & m, extension_context & ctx) const { return m_ptr->expand1(m, ctx); } unsigned trust_level() const { return m_ptr->trust_level(); } diff --git a/src/kernel/extension_context.h b/src/kernel/extension_context.h index 0bbfa60ff0..cfbce95b1f 100644 --- a/src/kernel/extension_context.h +++ b/src/kernel/extension_context.h @@ -6,10 +6,9 @@ Author: Leonardo de Moura */ #pragma once #include "util/name.h" +#include "kernel/constraint.h" namespace lean { -class expr; -class constraint; class environment; class delayed_justification; @@ -26,10 +25,9 @@ class extension_context { public: virtual ~extension_context() {} virtual environment const & env() const = 0; - virtual expr whnf(expr const & e) = 0; - virtual bool is_def_eq(expr const & e1, expr const & e2, delayed_justification & j) = 0; - virtual expr infer_type(expr const & e) = 0; + virtual pair whnf(expr const & e) = 0; + virtual pair is_def_eq(expr const & e1, expr const & e2, delayed_justification & j) = 0; + virtual pair infer_type(expr const & e) = 0; virtual name mk_fresh_name() = 0; - virtual void add_cnstr(constraint const & c) = 0; }; } diff --git a/src/kernel/inductive/inductive.cpp b/src/kernel/inductive/inductive.cpp index 8230a7a65d..21a66f43bd 100644 --- a/src/kernel/inductive/inductive.cpp +++ b/src/kernel/inductive/inductive.cpp @@ -215,6 +215,9 @@ struct add_inductive_fn { } type_checker & tc() { return *(m_tc.get()); } + bool is_def_eq(expr const & t, expr const & s) { return tc().is_def_eq(t, s).first; } + expr whnf(expr const & e) { return tc().whnf(e).first; } + expr ensure_type(expr const & e) { return tc().ensure_type(e).first; } /** \brief Return a fresh name. */ name mk_fresh_name() { return m_ngen.next(); } @@ -246,7 +249,7 @@ struct add_inductive_fn { m_param_consts.push_back(l); t = instantiate(binding_body(t), l); } else { - if (!tc().is_def_eq(binding_domain(t), get_param_type(i))) + if (!is_def_eq(binding_domain(t), get_param_type(i))) throw kernel_exception(m_env, "parameters of all inductive datatypes must match"); t = instantiate(binding_body(t), m_param_consts[i]); } @@ -258,7 +261,7 @@ struct add_inductive_fn { } if (i != m_num_params) throw kernel_exception(m_env, "number of parameters mismatch in inductive datatype declaration"); - t = tc().ensure_sort(t); + t = tc().ensure_sort(t).first; if (m_env.impredicative()) { // if the environment is impredicative, then the resultant universe is 0 (Prop), // or is never zero (under any parameter assignment). @@ -301,7 +304,7 @@ struct add_inductive_fn { bool is_valid_it_app(expr const & t, unsigned d_idx) { buffer args; expr I = get_app_args(t, args); - if (!tc().is_def_eq(I, m_it_consts[d_idx]) || args.size() != m_it_num_args[d_idx]) + if (!is_def_eq(I, m_it_consts[d_idx]) || args.size() != m_it_num_args[d_idx]) return false; for (unsigned i = 0; i < m_num_params; i++) { if (m_param_consts[i] != args[i]) @@ -336,15 +339,15 @@ struct add_inductive_fn { Return none otherwise. */ optional is_rec_argument(expr t) { - t = tc().whnf(t); + t = whnf(t); while (is_pi(t)) - t = tc().whnf(instantiate(binding_body(t), mk_local_for(t))); + t = whnf(instantiate(binding_body(t), mk_local_for(t))); return is_valid_it_app(t); } /** \brief Check if \c t contains only positive occurrences of the inductive datatypes being declared. */ void check_positivity(expr t, name const & intro_name, int arg_idx) { - t = tc().whnf(t); + t = whnf(t); if (!has_it_occ(t)) { // nonrecursive argument } else if (is_pi(t)) { @@ -373,12 +376,12 @@ struct add_inductive_fn { bool found_rec = false; while (is_pi(t)) { if (i < m_num_params) { - if (!tc().is_def_eq(binding_domain(t), get_param_type(i))) + if (!is_def_eq(binding_domain(t), get_param_type(i))) throw kernel_exception(m_env, sstream() << "arg #" << (i + 1) << " of '" << n << "' " << "does not match inductive datatypes parameters'"); t = instantiate(binding_body(t), m_param_consts[i]); } else { - expr s = tc().ensure_type(binding_domain(t)); + expr s = ensure_type(binding_domain(t)); // the sort is ok IF // 1- its level is <= inductive datatype level, OR // 2- m_env is impredicative and inductive datatype is at level 0 @@ -457,7 +460,7 @@ struct add_inductive_fn { unsigned i = 0; while (is_pi(t)) { if (i >= m_num_params) { - expr s = tc().ensure_type(binding_domain(t)); + expr s = ensure_type(binding_domain(t)); if (!is_zero(sort_level(s))) return true; } @@ -569,12 +572,12 @@ struct add_inductive_fn { // populate v using u for (unsigned i = 0; i < u.size(); i++) { expr u_i = u[i]; - expr u_i_ty = tc().whnf(mlocal_type(u_i)); + expr u_i_ty = whnf(mlocal_type(u_i)); buffer xs; while (is_pi(u_i_ty)) { expr x = mk_local_for(u_i_ty); xs.push_back(x); - u_i_ty = tc().whnf(instantiate(binding_body(u_i_ty), x)); + u_i_ty = whnf(instantiate(binding_body(u_i_ty), x)); } buffer it_indices; unsigned it_idx = get_I_indices(u_i_ty, it_indices); @@ -707,12 +710,12 @@ struct add_inductive_fn { if (m_dep_elim) { for (unsigned i = 0; i < u.size(); i++) { expr u_i = u[i]; - expr u_i_ty = tc().whnf(mlocal_type(u_i)); + expr u_i_ty = whnf(mlocal_type(u_i)); buffer xs; while (is_pi(u_i_ty)) { expr x = mk_local_for(u_i_ty); xs.push_back(x); - u_i_ty = tc().whnf(instantiate(binding_body(u_i_ty), x)); + u_i_ty = whnf(instantiate(binding_body(u_i_ty), x)); } buffer it_indices; unsigned it_idx = get_I_indices(u_i_ty, it_indices); @@ -758,46 +761,50 @@ bool inductive_normalizer_extension::supports(name const & feature) const { return feature == g_inductive_extension; } -optional inductive_normalizer_extension::operator()(expr const & e, extension_context & ctx) const { +optional> inductive_normalizer_extension::operator()(expr const & e, extension_context & ctx) const { // Reduce terms \c e of the form // elim_k A C e p[A,b] (intro_k_i A b u) inductive_env_ext const & ext = get_extension(ctx.env()); expr const & elim_fn = get_app_fn(e); if (!is_constant(elim_fn)) - return none_expr(); + return none_ecs(); auto it1 = ext.m_elim_info.find(const_name(elim_fn)); if (!it1) - return none_expr(); // it is not an eliminator + return none_ecs(); // it is not an eliminator buffer elim_args; get_app_args(e, elim_args); unsigned major_idx = it1->m_num_ACe + it1->m_num_indices; if (elim_args.size() < major_idx + 1) - return none_expr(); // major premise is missing - expr intro_app = ctx.whnf(elim_args[major_idx]); + return none_ecs(); // major premise is missing + auto intro_app_cs = ctx.whnf(elim_args[major_idx]); + expr intro_app = intro_app_cs.first; + constraint_seq cs = intro_app_cs.second; expr const & intro_fn = get_app_fn(intro_app); // Last argument must be a constant and an application of a constant. if (!is_constant(intro_fn)) - return none_expr(); + return none_ecs(); // Check if intro_fn is an introduction rule matching elim_fn auto it2 = ext.m_comp_rules.find(const_name(intro_fn)); if (!it2 || it2->m_elim_name != const_name(elim_fn)) - return none_expr(); + return none_ecs(); buffer intro_args; get_app_args(intro_app, intro_args); // Check intro num_args if (intro_args.size() != it1->m_num_params + it2->m_num_bu) - return none_expr(); + return none_ecs(); if (it1->m_num_params > 0) { // Global parameters of elim and intro be definitionally equal simple_delayed_justification jst([=]() { return mk_justification("elim/intro global parameters must match", some_expr(e)); }); for (unsigned i = 0; i < it1->m_num_params; i++) { - if (!ctx.is_def_eq(elim_args[i], intro_args[i], jst)) - return none_expr(); + auto dcs = ctx.is_def_eq(elim_args[i], intro_args[i], jst); + if (!dcs.first) + return none_ecs(); + cs = cs + dcs.second; } } // Number of universe levels must match. if (length(const_levels(elim_fn)) != length(it1->m_level_names)) - return none_expr(); + return none_ecs(); buffer ACebu; for (unsigned i = 0; i < it1->m_num_ACe; i++) ACebu.push_back(elim_args[i]); @@ -810,7 +817,7 @@ optional inductive_normalizer_extension::operator()(expr const & e, extens unsigned num_args = elim_args.size() - major_idx - 1; r = mk_app(r, num_args, elim_args.data() + major_idx + 1); } - return some_expr(r); + return some_ecs(r, cs); } template @@ -827,7 +834,7 @@ bool is_elim_meta_app_core(Ctx & ctx, expr const & e) { unsigned major_idx = it1->m_num_ACe + it1->m_num_indices; if (elim_args.size() < major_idx + 1) return false; - expr intro_app = ctx.whnf(elim_args[major_idx]); + expr intro_app = ctx.whnf(elim_args[major_idx]).first; return has_expr_metavar_strict(intro_app); } diff --git a/src/kernel/inductive/inductive.h b/src/kernel/inductive/inductive.h index a33760a67b..4a17cc7d3a 100644 --- a/src/kernel/inductive/inductive.h +++ b/src/kernel/inductive/inductive.h @@ -16,7 +16,7 @@ namespace inductive { /** \brief Normalizer extension for applying inductive datatype computational rules. */ class inductive_normalizer_extension : public normalizer_extension { public: - virtual optional operator()(expr const & e, extension_context & ctx) const; + virtual optional> operator()(expr const & e, extension_context & ctx) const; virtual bool may_reduce_later(expr const & e, extension_context & ctx) const; virtual bool supports(name const & feature) const; }; diff --git a/src/kernel/normalizer_extension.cpp b/src/kernel/normalizer_extension.cpp index e460a71e95..1532fc927b 100644 --- a/src/kernel/normalizer_extension.cpp +++ b/src/kernel/normalizer_extension.cpp @@ -9,7 +9,9 @@ Author: Leonardo de Moura namespace lean { class id_normalizer_extension : public normalizer_extension { public: - virtual optional operator()(expr const &, extension_context &) const { return none_expr(); } + virtual optional> operator()(expr const &, extension_context &) const { + return optional>(); + } virtual bool may_reduce_later(expr const &, extension_context &) const { return false; } virtual bool supports(name const &) const { return false; } }; @@ -25,7 +27,7 @@ public: comp_normalizer_extension(std::unique_ptr && ext1, std::unique_ptr && ext2): m_ext1(std::move(ext1)), m_ext2(std::move(ext2)) {} - virtual optional operator()(expr const & e, extension_context & ctx) const { + virtual optional> operator()(expr const & e, extension_context & ctx) const { if (auto r = (*m_ext1)(e, ctx)) return r; else diff --git a/src/kernel/normalizer_extension.h b/src/kernel/normalizer_extension.h index f50c3b9c54..4f5b3bfd05 100644 --- a/src/kernel/normalizer_extension.h +++ b/src/kernel/normalizer_extension.h @@ -17,7 +17,7 @@ namespace lean { class normalizer_extension { public: virtual ~normalizer_extension() {} - virtual optional operator()(expr const & e, extension_context & ctx) const = 0; + virtual optional> operator()(expr const & e, extension_context & ctx) const = 0; /** \brief Return true if the extension may reduce \c e after metavariables are instantiated. */ virtual bool may_reduce_later(expr const & e, extension_context & ctx) const = 0; /** \brief Return true iff the extension supports a feature with the given name, @@ -25,6 +25,11 @@ public: virtual bool supports(name const & feature) const = 0; }; +inline optional> none_ecs() { return optional>(); } +inline optional> some_ecs(expr const & e, constraint_seq const & cs) { + return optional>(e, cs); +} + /** \brief Create the do-nothing normalizer extension */ std::unique_ptr mk_id_normalizer_extension(); diff --git a/src/kernel/record/record.cpp b/src/kernel/record/record.cpp index 2c24916e5f..4343929dad 100644 --- a/src/kernel/record/record.cpp +++ b/src/kernel/record/record.cpp @@ -52,8 +52,8 @@ environment add_record(environment const & env, level_param_names const & level_ return add_record_fn(env, level_params, rec_name, rec_type, intro_name, intro_type)(); } -optional record_normalizer_extension::operator()(expr const &, extension_context &) const { - return optional(); +optional> record_normalizer_extension::operator()(expr const &, extension_context &) const { + return none_ecs(); } bool record_normalizer_extension::may_reduce_later(expr const &, extension_context &) const { diff --git a/src/kernel/record/record.h b/src/kernel/record/record.h index 4c5e9f9722..46e4d6c729 100644 --- a/src/kernel/record/record.h +++ b/src/kernel/record/record.h @@ -27,7 +27,7 @@ environment add_record(environment const & env, /** \brief Normalizer extension for applying record computational rules. */ class record_normalizer_extension : public normalizer_extension { public: - virtual optional operator()(expr const & e, extension_context & ctx) const; + virtual optional> operator()(expr const & e, extension_context & ctx) const; virtual bool may_reduce_later(expr const & e, extension_context & ctx) const; virtual bool supports(name const & feature) const; }; diff --git a/src/kernel/type_checker.cpp b/src/kernel/type_checker.cpp index 881b9beb90..c39c41cc20 100644 --- a/src/kernel/type_checker.cpp +++ b/src/kernel/type_checker.cpp @@ -71,21 +71,6 @@ expr mk_pi_for(name_generator & ngen, expr const & meta) { return mk_pi(ngen.next(), A, B); } -type_checker::scope::scope(type_checker & tc): - m_tc(tc), m_keep(false) { - m_tc.push(); -} -type_checker::scope::~scope() { - if (m_keep) - m_tc.keep(); - else - m_tc.pop(); -} - -void type_checker::scope::keep() { - m_keep = true; -} - optional type_checker::expand_macro(expr const & m) { lean_assert(is_macro(m)); return macro_def(m).expand(m, m_tc_ctx); @@ -100,25 +85,10 @@ pair type_checker::open_binding_body(expr const & e) { return mk_pair(instantiate(binding_body(e), local), local); } -/** \brief Add given constraint using m_add_cnstr_fn. */ -void type_checker::add_cnstr(constraint const & c) { - m_cs.push_back(c); -} - constraint type_checker::mk_eq_cnstr(expr const & lhs, expr const & rhs, justification const & j) { return ::lean::mk_eq_cnstr(lhs, rhs, j, static_cast(m_conv->get_module_idx())); } -optional type_checker::next_cnstr() { - if (m_cs_qhead < m_cs.size()) { - constraint c = m_cs[m_cs_qhead]; - m_cs_qhead++; - return optional(c); - } else { - return optional(); - } -} - /** \brief Make sure \c e "is" a sort, and return the corresponding sort. If \c e is not a sort, then the whnf procedure is invoked. Then, there are @@ -128,39 +98,37 @@ optional type_checker::next_cnstr() { \remark \c s is used to extract position (line number information) when an error message is produced */ -expr type_checker::ensure_sort_core(expr e, expr const & s) { +pair type_checker::ensure_sort_core(expr e, expr const & s) { if (is_sort(e)) - return e; - e = whnf(e); - if (is_sort(e)) { - return e; - } else if (is_meta(e)) { + return to_ecs(e); + auto ecs = whnf(e); + if (is_sort(ecs.first)) { + return ecs; + } else if (is_meta(ecs.first)) { expr r = mk_sort(mk_meta_univ(m_gen.next())); justification j = mk_justification(s, [=](formatter const & fmt, substitution const & subst) { return pp_type_expected(fmt, substitution(subst).instantiate(s)); }); - add_cnstr(mk_eq_cnstr(e, r, j)); - return r; + return to_ecs(r, mk_eq_cnstr(ecs.first, r, j), ecs.second); } else { throw_kernel_exception(m_env, s, [=](formatter const & fmt) { return pp_type_expected(fmt, s); }); } } /** \brief Similar to \c ensure_sort, but makes sure \c e "is" a Pi. */ -expr type_checker::ensure_pi_core(expr e, expr const & s) { +pair type_checker::ensure_pi_core(expr e, expr const & s) { if (is_pi(e)) - return e; - e = whnf(e); - if (is_pi(e)) { - return e; - } else if (is_meta(e)) { - expr r = mk_pi_for(m_gen, e); + return to_ecs(e); + auto ecs = whnf(e); + if (is_pi(ecs.first)) { + return ecs; + } else if (is_meta(ecs.first)) { + expr r = mk_pi_for(m_gen, ecs.first); justification j = mk_justification(s, [=](formatter const & fmt, substitution const & subst) { return pp_function_expected(fmt, substitution(subst).instantiate(s)); }); - add_cnstr(mk_eq_cnstr(e, r, j)); - return r; + return to_ecs(r, mk_eq_cnstr(ecs.first, r, j), ecs.second); } else { throw_kernel_exception(m_env, s, [=](formatter const & fmt) { return pp_function_expected(fmt, s); }); } @@ -217,26 +185,34 @@ expr type_checker::infer_constant(expr const & e, bool infer_only) { return instantiate_type_univ_params(d, ls); } -expr type_checker::infer_macro(expr const & e, bool infer_only) { +pair type_checker::infer_macro(expr const & e, bool infer_only) { buffer arg_types; - for (unsigned i = 0; i < macro_num_args(e); i++) - arg_types.push_back(infer_type_core(macro_arg(e, i), infer_only)); - expr r = macro_def(e).get_type(e, arg_types.data(), m_tc_ctx); - if (!infer_only && macro_def(e).trust_level() >= m_env.trust_lvl()) { + constraint_seq cs; + for (unsigned i = 0; i < macro_num_args(e); i++) { + arg_types.push_back(infer_type_core(macro_arg(e, i), infer_only, cs)); + } + auto def = macro_def(e); + expr t = def.get_type(e, arg_types.data(), m_tc_ctx); + if (!infer_only && def.trust_level() >= m_env.trust_lvl()) { optional m = expand_macro(e); if (!m) throw_kernel_exception(m_env, "failed to expand macro", e); - expr t = infer_type_core(*m, infer_only); + pair tmcs = infer_type_core(*m, infer_only); + cs = cs + tmcs.second; simple_delayed_justification jst([=]() { return mk_macro_jst(e); }); - if (!is_def_eq(r, t, jst)) + pair bcs = is_def_eq(t, tmcs.first, jst); + if (!bcs.first) throw_kernel_exception(m_env, g_macro_error_msg, e); + return mk_pair(t, bcs.second + cs); + } else { + return mk_pair(t, cs); } - return r; } -expr type_checker::infer_lambda(expr const & _e, bool infer_only) { +pair type_checker::infer_lambda(expr const & _e, bool infer_only) { buffer es, ds, ls; expr e = _e; + constraint_seq cs; while (is_lambda(e)) { es.push_back(e); ds.push_back(binding_domain(e)); @@ -244,82 +220,107 @@ expr type_checker::infer_lambda(expr const & _e, bool infer_only) { expr l = mk_local(m_gen.next(), binding_name(e), d, binding_info(e)); ls.push_back(l); if (!infer_only) { - expr t = infer_type_core(d, infer_only); - ensure_sort_core(t, d); + pair dtcs = infer_type_core(d, infer_only); + pair scs = ensure_sort_core(dtcs.first, d); + cs = cs + scs.second + dtcs.second; } e = binding_body(e); } - expr r = infer_type_core(instantiate_rev(e, ls.size(), ls.data()), infer_only); - r = abstract_locals(r, ls.size(), ls.data()); + pair rcs = infer_type_core(instantiate_rev(e, ls.size(), ls.data()), infer_only); + cs = cs + rcs.second; + expr r = abstract_locals(rcs.first, ls.size(), ls.data()); unsigned i = es.size(); while (i > 0) { --i; r = mk_pi(binding_name(es[i]), ds[i], r, binding_info(es[i])); } - return r; + return mk_pair(r, cs); } -expr type_checker::infer_pi(expr const & _e, bool infer_only) { +pair type_checker::infer_pi(expr const & _e, bool infer_only) { buffer ls; buffer us; expr e = _e; + constraint_seq cs; while (is_pi(e)) { - expr d = instantiate_rev(binding_domain(e), ls.size(), ls.data()); - expr t1 = ensure_sort_core(infer_type_core(d, infer_only), d); + expr d = instantiate_rev(binding_domain(e), ls.size(), ls.data()); + pair dtcs = infer_type_core(d, infer_only); + pair scs = ensure_sort_core(dtcs.first, d); + cs = cs + scs.second + dtcs.second; + expr t1 = scs.first; us.push_back(sort_level(t1)); expr l = mk_local(m_gen.next(), binding_name(e), d, binding_info(e)); ls.push_back(l); e = binding_body(e); } e = instantiate_rev(e, ls.size(), ls.data()); - level r = sort_level(ensure_sort_core(infer_type_core(e, infer_only), e)); + pair etcs = infer_type_core(e, infer_only); + pair scs = ensure_sort_core(etcs.first, e); + cs = cs + scs.second + etcs.second; + level r = sort_level(scs.first); unsigned i = ls.size(); while (i > 0) { --i; r = m_env.impredicative() ? mk_imax(us[i], r) : mk_max(us[i], r); } - return mk_sort(r); + return mk_pair(mk_sort(r), cs); } -expr type_checker::infer_app(expr const & e, bool infer_only) { +pair type_checker::infer_app(expr const & e, bool infer_only) { if (!infer_only) { - expr f_type = ensure_pi_core(infer_type_core(app_fn(e), infer_only), e); - expr a_type = infer_type_core(app_arg(e), infer_only); + pair ftcs = infer_type_core(app_fn(e), infer_only); + pair pics = ensure_pi_core(ftcs.first, e); + expr f_type = pics.first; + pair acs = infer_type_core(app_arg(e), infer_only); + expr a_type = acs.first; app_delayed_justification jst(e, app_arg(e), f_type, a_type); - if (!is_def_eq(a_type, binding_domain(f_type), jst)) { + expr d_type = binding_domain(pics.first); + pair dcs = is_def_eq(a_type, d_type, jst); + if (!dcs.first) { throw_kernel_exception(m_env, e, [=](formatter const & fmt) { - return pp_app_type_mismatch(fmt, e, app_arg(e), binding_domain(f_type), a_type); + return pp_app_type_mismatch(fmt, e, app_arg(e), d_type, a_type); }); } - return instantiate(binding_body(f_type), app_arg(e)); + return mk_pair(instantiate(binding_body(pics.first), app_arg(e)), + pics.second + ftcs.second + dcs.second + acs.second); } else { buffer args; expr const & f = get_app_args(e, args); - expr f_type = infer_type_core(f, true); - unsigned j = 0; - unsigned nargs = args.size(); + pair ftcs = infer_type_core(f, true); + expr f_type = ftcs.first; + constraint_seq cs = ftcs.second; + unsigned j = 0; + unsigned nargs = args.size(); for (unsigned i = 0; i < nargs; i++) { if (is_pi(f_type)) { f_type = binding_body(f_type); } else { f_type = instantiate_rev(f_type, i-j, args.data()+j); - f_type = ensure_pi_core(f_type, e); + pair pics = ensure_pi_core(f_type, e); + f_type = pics.first; + cs = pics.second + cs; f_type = binding_body(f_type); j = i; } } expr r = instantiate_rev(f_type, nargs-j, args.data()+j); - return r; + return mk_pair(r, cs); } } +expr type_checker::infer_type_core(expr const & e, bool infer_only, constraint_seq & cs) { + auto r = infer_type_core(e, infer_only); + cs = cs + r.second; + return r.first; +} + /** \brief Return type of expression \c e, if \c infer_only is false, then it also check whether \c e is type correct or not. \pre closed(e) */ -expr type_checker::infer_type_core(expr const & e, bool infer_only) { +pair type_checker::infer_type_core(expr const & e, bool infer_only) { if (is_var(e)) throw_kernel_exception(m_env, "type checker does not support free variables, replace them with local constants before invoking it", e); @@ -332,20 +333,20 @@ expr type_checker::infer_type_core(expr const & e, bool infer_only) { return it->second; } - expr r; + pair r; switch (e.kind()) { - case expr_kind::Local: case expr_kind::Meta: r = mlocal_type(e); break; + case expr_kind::Local: case expr_kind::Meta: r.first = mlocal_type(e); break; case expr_kind::Var: lean_unreachable(); // LCOV_EXCL_LINE case expr_kind::Sort: if (!infer_only) check_level(sort_level(e), e); - r = mk_sort(mk_succ(sort_level(e))); + r.first = mk_sort(mk_succ(sort_level(e))); break; - case expr_kind::Constant: r = infer_constant(e, infer_only); break; - case expr_kind::Macro: r = infer_macro(e, infer_only); break; - case expr_kind::Lambda: r = infer_lambda(e, infer_only); break; - case expr_kind::Pi: r = infer_pi(e, infer_only); break; - case expr_kind::App: r = infer_app(e, infer_only); break; + case expr_kind::Constant: r.first = infer_constant(e, infer_only); break; + case expr_kind::Macro: r = infer_macro(e, infer_only); break; + case expr_kind::Lambda: r = infer_lambda(e, infer_only); break; + case expr_kind::Pi: r = infer_pi(e, infer_only); break; + case expr_kind::App: r = infer_app(e, infer_only); break; } if (m_memoize) @@ -354,151 +355,63 @@ expr type_checker::infer_type_core(expr const & e, bool infer_only) { return r; } -expr type_checker::infer_type(expr const & e) { - scope mk_scope(*this); - expr r = infer_type_core(e, true); - mk_scope.keep(); - return r; +pair type_checker::infer_type(expr const & e) { + return infer_type_core(e, true); } -void type_checker::copy_constraints(unsigned qhead, buffer & new_cnstrs) { - for (unsigned i = qhead; i < m_cs.size(); i++) - new_cnstrs.push_back(m_cs[i]); -} - -expr type_checker::infer(expr const & e, buffer & new_cnstrs) { - scope mk_scope(*this); - unsigned cs_qhead = m_cs.size(); - expr r = infer_type_core(e, true); - copy_constraints(cs_qhead, new_cnstrs); - return r; -} - -expr type_checker::check(expr const & e, level_param_names const & ps) { - scope mk_scope(*this); +pair type_checker::check(expr const & e, level_param_names const & ps) { flet updt(m_params, &ps); - expr r = infer_type_core(e, false); - mk_scope.keep(); - return r; + return infer_type_core(e, false); } -expr type_checker::check(expr const & e, buffer & new_cnstrs) { - scope mk_scope(*this); - unsigned cs_qhead = m_cs.size(); - expr r = infer_type_core(e, false); - copy_constraints(cs_qhead, new_cnstrs); - return r; +pair type_checker::ensure_sort(expr const & e, expr const & s) { + return ensure_sort_core(e, s); } -expr type_checker::ensure_sort(expr const & e, expr const & s) { - scope mk_scope(*this); - expr r = ensure_sort_core(e, s); - mk_scope.keep(); - return r; -} - -expr type_checker::ensure_pi(expr const & e, expr const & s) { - scope mk_scope(*this); - expr r = ensure_pi_core(e, s); - mk_scope.keep(); - return r; +pair type_checker::ensure_pi(expr const & e, expr const & s) { + return ensure_pi_core(e, s); } /** \brief Return true iff \c t and \c s are definitionally equal */ -bool type_checker::is_def_eq(expr const & t, expr const & s, delayed_justification & jst) { - scope mk_scope(*this); - bool r = m_conv->is_def_eq(t, s, *this, jst); - if (r) mk_scope.keep(); - return r; +pair type_checker::is_def_eq(expr const & t, expr const & s, delayed_justification & jst) { + return m_conv->is_def_eq(t, s, *this, jst); } -bool type_checker::is_def_eq(expr const & t, expr const & s) { - scope mk_scope(*this); - bool r = m_conv->is_def_eq(t, s, *this); - if (r) mk_scope.keep(); - return r; +pair type_checker::is_def_eq(expr const & t, expr const & s) { + return m_conv->is_def_eq(t, s, *this); } -bool type_checker::is_def_eq(expr const & t, expr const & s, justification const & j) { +pair type_checker::is_def_eq(expr const & t, expr const & s, justification const & j) { as_delayed_justification djst(j); return is_def_eq(t, s, djst); } -bool type_checker::is_def_eq(expr const & t, expr const & s, justification const & j, buffer & new_cnstrs) { - unsigned cs_qhead = m_cs.size(); - scope mk_scope(*this); +pair type_checker::is_def_eq_types(expr const & t, expr const & s, justification const & j) { + auto tcs1 = infer_type_core(t, true); + auto tcs2 = infer_type_core(s, true); as_delayed_justification djst(j); - if (m_conv->is_def_eq(t, s, *this, djst)) { - copy_constraints(cs_qhead, new_cnstrs); - return true; - } else { - return false; - } -} - -bool type_checker::is_def_eq_types(expr const & t, expr const & s, justification const & j, buffer & new_cnstrs) { - scope mk_scope(*this); - unsigned cs_qhead = m_cs.size(); - expr r1 = infer_type_core(t, true); - expr r2 = infer_type_core(s, true); - as_delayed_justification djst(j); - if (m_conv->is_def_eq(r1, r2, *this, djst)) { - copy_constraints(cs_qhead, new_cnstrs); - return true; - } else { - return false; - } + auto bcs = m_conv->is_def_eq(tcs1.first, tcs2.first, *this, djst); + return mk_pair(bcs.first, bcs.first ? bcs.second + tcs1.second + tcs2.second : constraint_seq()); } /** \brief Return true iff \c e is a proposition */ -bool type_checker::is_prop(expr const & e) { - scope mk_scope(*this); - bool r = whnf(infer_type(e)) == Prop; - if (r) mk_scope.keep(); - return r; +pair type_checker::is_prop(expr const & e) { + auto tcs = infer_type(e); + auto wtcs = whnf(tcs.first); + bool r = wtcs.first == Prop; + if (r) + return mk_pair(true, tcs.second + wtcs.second); + else + return mk_pair(false, constraint_seq()); } -expr type_checker::whnf(expr const & t) { +pair type_checker::whnf(expr const & t) { return m_conv->whnf(t, *this); } -expr type_checker::whnf(expr const & t, buffer & new_cnstrs) { - scope mk_scope(*this); - unsigned cs_qhead = m_cs.size(); - expr r = m_conv->whnf(t, *this); - copy_constraints(cs_qhead, new_cnstrs); - return r; -} - -void type_checker::push() { - m_infer_type_cache[0].push(); - m_infer_type_cache[1].push(); - m_trail.emplace_back(m_cs.size(), m_cs_qhead); -} - -void type_checker::pop() { - m_infer_type_cache[0].pop(); - m_infer_type_cache[1].pop(); - m_cs.shrink(m_trail.back().first); - m_cs_qhead = m_trail.back().second; - m_trail.pop_back(); -} - -void type_checker::keep() { - m_infer_type_cache[0].keep(); - m_infer_type_cache[1].keep(); - m_trail.pop_back(); -} - -unsigned type_checker::num_scopes() const { - lean_assert(m_infer_type_cache[0].num_scopes() == m_infer_type_cache[1].num_scopes()); - return m_infer_type_cache[0].num_scopes(); -} - type_checker::type_checker(environment const & env, name_generator const & g, std::unique_ptr && conv, bool memoize): m_env(env), m_gen(g), m_conv(std::move(conv)), m_tc_ctx(*this), m_memoize(memoize), m_params(nullptr) { - m_cs_qhead = 0; } static name g_tmp_prefix = name::mk_internal_unique_name(); @@ -547,15 +460,15 @@ certified_declaration check(environment const & env, declaration const & d, name check_name(env, d.get_name()); check_duplicated_params(env, d); type_checker checker1(env, g, mk_default_converter(env, optional(), memoize, extra_opaque)); - expr sort = checker1.check(d.get_type(), d.get_univ_params()); + expr sort = checker1.check(d.get_type(), d.get_univ_params()).first; checker1.ensure_sort(sort, d.get_type()); if (d.is_definition()) { optional midx; if (d.is_opaque()) midx = optional(d.get_module_idx()); type_checker checker2(env, g, mk_default_converter(env, midx, memoize, extra_opaque)); - expr val_type = checker2.check(d.get_value(), d.get_univ_params()); - if (!checker2.is_def_eq(val_type, d.get_type())) { + expr val_type = checker2.check(d.get_value(), d.get_univ_params()).first; + if (!checker2.is_def_eq(val_type, d.get_type()).first) { throw_kernel_exception(env, d.get_value(), [=](formatter const & fmt) { return pp_def_type_mismatch(fmt, d.get_name(), d.get_type(), val_type); }); diff --git a/src/kernel/type_checker.h b/src/kernel/type_checker.h index 18673aedb9..5a320e8bda 100644 --- a/src/kernel/type_checker.h +++ b/src/kernel/type_checker.h @@ -10,13 +10,21 @@ Author: Leonardo de Moura #include #include "util/name_generator.h" #include "util/name_set.h" -#include "util/scoped_map.h" #include "kernel/environment.h" #include "kernel/constraint.h" +#include "kernel/justification.h" #include "kernel/converter.h" +#include "kernel/expr_maps.h" namespace lean { +inline pair to_ecs(expr const & e) { return mk_pair(e, empty_cs()); } +inline pair to_ecs(expr const & e, constraint const & c, constraint_seq const & cs) { + return mk_pair(e, constraint_seq(constraint_seq(c), cs)); +} +inline pair to_ecs(expr const & e, constraint const & c) { return mk_pair(e, constraint_seq(c)); } +inline pair to_ecs(expr const & e, constraint_seq const & cs) { return mk_pair(e, cs); } + /** \brief Given \c type of the form (Pi ctx, r), return (Pi ctx, new_range) */ expr replace_range(expr const & type, expr const & new_range); @@ -56,7 +64,7 @@ expr mk_pi_for(name_generator & ngen, expr const & meta); The type checker produces constraints, and they are sent to the constraint handler. */ class type_checker { - typedef scoped_map cache; + typedef expr_bi_struct_map> cache; /** \brief Interface type_checker <-> macro & normalizer_extension */ class type_checker_context : public extension_context { @@ -64,11 +72,12 @@ class type_checker { public: type_checker_context(type_checker & tc):m_tc(tc) {} virtual environment const & env() const { return m_tc.m_env; } - virtual expr whnf(expr const & e) { return m_tc.whnf(e); } - virtual bool is_def_eq(expr const & e1, expr const & e2, delayed_justification & j) { return m_tc.is_def_eq(e1, e2, j); } - virtual expr infer_type(expr const & e) { return m_tc.infer_type(e); } + virtual pair whnf(expr const & e) { return m_tc.whnf(e); } + virtual pair is_def_eq(expr const & e1, expr const & e2, delayed_justification & j) { + return m_tc.is_def_eq(e1, e2, j); + } + virtual pair infer_type(expr const & e) { return m_tc.infer_type(e); } virtual name mk_fresh_name() { return m_tc.m_gen.next(); } - virtual void add_cnstr(constraint const & c) { m_tc.add_cnstr(c); } }; environment m_env; @@ -83,27 +92,24 @@ class type_checker { bool m_memoize; // temp flag level_param_names const * m_params; - buffer m_cs; // temporary cache of constraints - unsigned m_cs_qhead; - buffer> m_trail; friend class converter; // allow converter to access the following methods name mk_fresh_name() { return m_gen.next(); } optional expand_macro(expr const & m); pair open_binding_body(expr const & e); - void add_cnstr(constraint const & c); - expr ensure_sort_core(expr e, expr const & s); - expr ensure_pi_core(expr e, expr const & s); + pair ensure_sort_core(expr e, expr const & s); + pair ensure_pi_core(expr e, expr const & s); justification mk_macro_jst(expr const & e); void check_level(level const & l, expr const & s); expr infer_constant(expr const & e, bool infer_only); - expr infer_macro(expr const & e, bool infer_only); - expr infer_lambda(expr const & e, bool infer_only); - expr infer_pi(expr const & e, bool infer_only); - expr infer_app(expr const & e, bool infer_only); - expr infer_type_core(expr const & e, bool infer_only); - expr infer_type(expr const & e); - void copy_constraints(unsigned qhead, buffer & new_cnstrs); + pair infer_macro(expr const & e, bool infer_only); + pair infer_lambda(expr const & e, bool infer_only); + pair infer_pi(expr const & e, bool infer_only); + pair infer_app(expr const & e, bool infer_only); + pair infer_type_core(expr const & e, bool infer_only); + pair infer_type(expr const & e); + expr infer_type_core(expr const & e, bool infer_only, constraint_seq & cs); + extension_context & get_extension() { return m_tc_ctx; } constraint mk_eq_cnstr(expr const & lhs, expr const & rhs, justification const & j); public: @@ -130,71 +136,69 @@ public: type is correct. Throw an exception if a type error is found. - The result is meaningful only if the constraints sent to the - constraint handler can be solved. + The result is meaningful only if the generated constraints can be solved. */ - expr infer(expr const & t) { return infer_type(t); } - /** \brief Infer \c t type and copy constraints associated with type inference to \c new_cnstrs */ - expr infer(expr const & t, buffer & new_cnstrs); + pair infer(expr const & t) { return infer_type(t); } /** \brief Type check the given expression, and return the type of \c t. Throw an exception if a type error is found. - The result is meaningful only if the constraints sent to the - constraint handler can be solved. + The result is meaningful only if the generated constraints can be solved. */ - expr check(expr const & t, level_param_names const & ps = level_param_names()); - expr check(expr const & t, buffer & new_cnstrs); + pair check(expr const & t, level_param_names const & ps = level_param_names()); + /** \brief Return true iff t is definitionally equal to s. */ - bool is_def_eq(expr const & t, expr const & s); - bool is_def_eq(expr const & t, expr const & s, justification const & j); - bool is_def_eq(expr const & t, expr const & s, delayed_justification & jst); - /** \brief Return true iff \c t and \c s are (may be) definitionally equal (module constraints) - New constraints associated with test are store in \c new_cnstrs. - */ - bool is_def_eq(expr const & t, expr const & s, justification const & j, buffer & new_cnstrs); - /** \brief Return true iff types of \c t and \c s are (may be) definitionally equal (modulo constraints) - New constraints associated with test are store in \c new_cnstrs. - */ - bool is_def_eq_types(expr const & t, expr const & s, justification const & j, buffer & new_cnstrs); + pair is_def_eq(expr const & t, expr const & s); + pair is_def_eq(expr const & t, expr const & s, justification const & j); + pair is_def_eq(expr const & t, expr const & s, delayed_justification & jst); + /** \brief Return true iff types of \c t and \c s are (may be) definitionally equal (modulo constraints) */ + pair is_def_eq_types(expr const & t, expr const & s, justification const & j); /** \brief Return true iff t is a proposition. */ - bool is_prop(expr const & t); + pair is_prop(expr const & t); /** \brief Return the weak head normal form of \c t. */ - expr whnf(expr const & t); - /** \brief Similar to the previous method, but it also returns the new constraints created in the process. */ - expr whnf(expr const & t, buffer & new_cnstrs); + pair whnf(expr const & t); /** \brief Return a Pi if \c t is convertible to a Pi type. Throw an exception otherwise. The argument \c s is used when reporting errors */ - expr ensure_pi(expr const & t, expr const & s); - expr ensure_pi(expr const & t) { return ensure_pi(t, t); } + pair ensure_pi(expr const & t, expr const & s); + pair ensure_pi(expr const & t) { return ensure_pi(t, t); } /** \brief Mare sure type of \c e is a Pi, and return it. Throw an exception otherwise. */ - expr ensure_fun(expr const & e) { return ensure_pi(infer(e), e); } + pair ensure_fun(expr const & e) { + auto tcs = infer(e); + auto pics = ensure_pi(tcs.first, e); + return mk_pair(pics.first, pics.second + tcs.second); + } /** \brief Return a Sort if \c t is convertible to Sort. Throw an exception otherwise. The argument \c s is used when reporting errors. */ - expr ensure_sort(expr const & t, expr const & s); + pair ensure_sort(expr const & t, expr const & s); /** \brief Return a Sort if \c t is convertible to Sort. Throw an exception otherwise. */ - expr ensure_sort(expr const & t) { return ensure_sort(t, t); } + pair ensure_sort(expr const & t) { return ensure_sort(t, t); } /** \brief Mare sure type of \c e is a sort, and return it. Throw an exception otherwise. */ - expr ensure_type(expr const & e) { return ensure_sort(infer(e), e); } + pair ensure_type(expr const & e) { + auto tcs = infer(e); + auto scs = ensure_sort(tcs.first, e); + return mk_pair(scs.first, scs.second + tcs.second); + } - /** \brief Return the number of backtracking points. */ - unsigned num_scopes() const; - /** \brief Consume next constraint in the produced constraint queue */ - optional next_cnstr(); + expr whnf(expr const & e, constraint_seq & cs) { auto r = whnf(e); cs += r.second; return r.first; } + expr infer(expr const & e, constraint_seq & cs) { auto r = infer(e); cs += r.second; return r.first; } + expr ensure_pi(expr const & e, constraint_seq & cs) { auto r = ensure_pi(e); cs += r.second; return r.first; } + expr ensure_pi(expr const & e, expr const & s, constraint_seq & cs) { auto r = ensure_pi(e, s); cs += r.second; return r.first; } + expr ensure_sort(expr const & t, expr const & s, constraint_seq & cs) { auto r = ensure_sort(t, s); cs += r.second; return r.first; } - void push(); - void pop(); - void keep(); + bool is_def_eq(expr const & t, expr const & s, justification const & j, constraint_seq & cs) { + auto r = is_def_eq(t, s, j); + if (r.first) + cs = r.second + cs; + return r.first; + } - class scope { - type_checker & m_tc; - bool m_keep; - public: - scope(type_checker & tc); - ~scope(); - void keep(); - }; + bool is_def_eq_types(expr const & t, expr const & s, justification const & j, constraint_seq & cs) { + auto r = is_def_eq_types(t, s, j); + if (r.first) + cs = r.second + cs; + return r.first; + } }; typedef std::shared_ptr type_checker_ref; diff --git a/src/library/inductive_unifier_plugin.cpp b/src/library/inductive_unifier_plugin.cpp index f0411f06b6..44268964f9 100644 --- a/src/library/inductive_unifier_plugin.cpp +++ b/src/library/inductive_unifier_plugin.cpp @@ -45,7 +45,7 @@ class inductive_unifier_plugin_cell : public unifier_plugin_cell { */ lazy_list add_elim_meta_cnstrs(type_checker & tc, name_generator ngen, inductive::inductive_decl const & decl, expr const & elim, buffer & args, expr const & t, justification const & j, - buffer & tc_cnstr_buffer, bool relax) const { + constraint_seq cs, bool relax) const { lean_assert(is_constant(elim)); environment const & env = tc.env(); levels elim_lvls = const_levels(elim); @@ -55,12 +55,10 @@ class inductive_unifier_plugin_cell : public unifier_plugin_cell { lean_assert(has_expr_metavar_strict(meta)); buffer margs; expr const & m = get_app_args(meta, margs); - expr mtype = tc.infer(m, tc_cnstr_buffer); - lean_assert(!tc.next_cnstr()); - unsigned buff_sz = tc_cnstr_buffer.size(); + expr mtype = tc.infer(m, cs); buffer alts; for (auto const & intro : inductive::inductive_decl_intros(decl)) { - tc_cnstr_buffer.shrink(buff_sz); + constraint_seq cs_intro = cs; name const & intro_name = inductive::intro_rule_name(intro); declaration intro_decl = env.get(intro_name); levels intro_lvls; @@ -72,33 +70,30 @@ class inductive_unifier_plugin_cell : public unifier_plugin_cell { } expr intro_fn = mk_constant(inductive::intro_rule_name(intro), intro_lvls); expr hint = intro_fn; - expr intro_type = tc.whnf(inductive::intro_rule_type(intro), tc_cnstr_buffer); + expr intro_type = tc.whnf(inductive::intro_rule_type(intro), cs_intro); while (is_pi(intro_type)) { hint = mk_app(hint, mk_app(mk_aux_metavar_for(ngen, mtype), margs)); - intro_type = tc.whnf(binding_body(intro_type), tc_cnstr_buffer); - lean_assert(!tc.next_cnstr()); + intro_type = tc.whnf(binding_body(intro_type), cs_intro); } constraint c1 = mk_eq_cnstr(meta, hint, j, relax); args[major_idx] = hint; - lean_assert(!tc.next_cnstr()); - expr reduce_elim = tc.whnf(mk_app(elim, args), tc_cnstr_buffer); - lean_assert(!tc.next_cnstr()); + expr reduce_elim = tc.whnf(mk_app(elim, args), cs_intro); constraint c2 = mk_eq_cnstr(reduce_elim, t, j, relax); - list tc_cnstrs = to_list(tc_cnstr_buffer.begin(), tc_cnstr_buffer.end()); - alts.push_back(cons(c1, cons(c2, tc_cnstrs))); + cs_intro = constraint_seq(c1) + constraint_seq(c2) + cs_intro; + buffer cs_buffer; + cs_intro.linearize(cs_buffer); + alts.push_back(to_list(cs_buffer.begin(), cs_buffer.end())); } - lean_assert(!tc.next_cnstr()); return to_lazy(to_list(alts.begin(), alts.end())); } lazy_list process_elim_meta_core(type_checker & tc, name_generator const & ngen, expr const & lhs, expr const & rhs, justification const & j, bool relax) const { lean_assert(inductive::is_elim_meta_app(tc, lhs)); - buffer tc_cnstr_buffer; - lean_assert(!tc.next_cnstr()); - if (!tc.is_def_eq_types(lhs, rhs, j, tc_cnstr_buffer)) + auto dcs = tc.is_def_eq_types(lhs, rhs, j); + if (!dcs.first) return lazy_list(); - lean_assert(!tc.next_cnstr()); + constraint_seq cs = dcs.second; buffer args; expr const & elim = get_app_args(lhs, args); environment const & env = tc.env(); @@ -106,7 +101,7 @@ class inductive_unifier_plugin_cell : public unifier_plugin_cell { auto decls = *inductive::is_inductive_decl(env, it_name); for (auto const & d : std::get<2>(decls)) { if (inductive::inductive_decl_name(d) == it_name) - return add_elim_meta_cnstrs(tc, ngen, d, elim, args, rhs, j, tc_cnstr_buffer, relax); + return add_elim_meta_cnstrs(tc, ngen, d, elim, args, rhs, j, cs, relax); } lean_unreachable(); // LCOV_EXCL_LINE } diff --git a/src/library/kernel_bindings.cpp b/src/library/kernel_bindings.cpp index 8186a92fad..8414ad20d1 100644 --- a/src/library/kernel_bindings.cpp +++ b/src/library/kernel_bindings.cpp @@ -951,6 +951,19 @@ static void open_formatter(lua_State * L) { SET_GLOBAL_FUN(set_formatter_factory, "set_formatter_factory"); } +// Helper function for push pair expr, constraint_seq +int push_ecs(lua_State * L, pair const & p) { + push_expr(L, p.first); + push_constraint_seq(L, p.second); + return 2; +} + +int push_bcs(lua_State * L, pair const & p) { + push_boolean(L, p.first); + push_constraint_seq(L, p.second); + return 2; +} + // Environment_id DECL_UDATA(environment_id) static int environment_id_descendant(lua_State * L) { return push_boolean(L, to_environment_id(L, 1).is_descendant(to_environment_id(L, 2))); } @@ -1032,10 +1045,10 @@ static int mk_environment(lua_State * L) { return push_environment(L, mk_en static int mk_hott_environment(lua_State * L) { return push_environment(L, mk_hott_environment(get_trust_lvl(L, 1))); } static int environment_forget(lua_State * L) { return push_environment(L, to_environment(L, 1).forget()); } -static int environment_whnf(lua_State * L) { return push_expr(L, type_checker(to_environment(L, 1)).whnf(to_expr(L, 2))); } +static int environment_whnf(lua_State * L) { return push_ecs(L, type_checker(to_environment(L, 1)).whnf(to_expr(L, 2))); } static int environment_normalize(lua_State * L) { return push_expr(L, normalize(to_environment(L, 1), to_expr(L, 2))); } -static int environment_infer_type(lua_State * L) { return push_expr(L, type_checker(to_environment(L, 1)).infer(to_expr(L, 2))); } -static int environment_type_check(lua_State * L) { return push_expr(L, type_checker(to_environment(L, 1)).check(to_expr(L, 2))); } +static int environment_infer_type(lua_State * L) { return push_ecs(L, type_checker(to_environment(L, 1)).infer(to_expr(L, 2))); } +static int environment_type_check(lua_State * L) { return push_ecs(L, type_checker(to_environment(L, 1)).check(to_expr(L, 2))); } static int environment_for_each_decl(lua_State * L) { environment const & env = to_environment(L, 1); luaL_checktype(L, 2, LUA_TFUNCTION); // user-fun @@ -1577,12 +1590,57 @@ static const struct luaL_Reg constraint_m[] = { {0, 0} }; +// Constraint sequences +DECL_UDATA(constraint_seq) + +static int constraint_seq_mk(lua_State * L) { + unsigned nargs = lua_gettop(L); + constraint_seq cs; + for (unsigned i = 0; i < nargs; i++) { + cs += to_constraint(L, i); + } + return push_constraint_seq(L, cs); +} + +static int constraint_seq_concat(lua_State * L) { + if (is_constraint_seq(L, 1) && is_constraint(L, 2)) + return push_constraint_seq(L, to_constraint_seq(L, 1) + to_constraint(L, 2)); + else + return push_constraint_seq(L, to_constraint_seq(L, 1) + to_constraint_seq(L, 2)); +} + +static int constraint_seq_linearize(lua_State * L) { + buffer tmp; + to_constraint_seq(L, 1).linearize(tmp); + lua_newtable(L); + int i = 1; + for (constraint const & c : tmp) { + push_constraint(L, c); + lua_rawseti(L, -2, i); + i++; + } + return 1; +} + +static const struct luaL_Reg constraint_seq_m[] = { + {"__gc", constraint_seq_gc}, + {"__concat", constraint_seq_concat}, + {"concat", constraint_seq_concat}, + {"linearize", constraint_seq_linearize}, + {0, 0} +}; + static void open_constraint(lua_State * L) { luaL_newmetatable(L, constraint_mt); lua_pushvalue(L, -1); lua_setfield(L, -2, "__index"); setfuncs(L, constraint_m, 0); + luaL_newmetatable(L, constraint_seq_mt); + lua_pushvalue(L, -1); + lua_setfield(L, -2, "__index"); + setfuncs(L, constraint_seq_m, 0); + SET_GLOBAL_FUN(constraint_pred, "is_constraint"); SET_GLOBAL_FUN(mk_eq_cnstr, "mk_eq_cnstr"); SET_GLOBAL_FUN(mk_level_eq_cnstr, "mk_level_eq_cnstr"); @@ -1593,6 +1651,9 @@ static void open_constraint(lua_State * L) { SET_ENUM("LevelEq", constraint_kind::LevelEq); SET_ENUM("Choice", constraint_kind::Choice); lua_setglobal(L, "constraint_kind"); + + SET_GLOBAL_FUN(constraint_seq_pred, "is_constraint_seq"); + SET_GLOBAL_FUN(constraint_seq_mk, "constraint_seq"); } // Substitution @@ -1787,44 +1848,30 @@ static int mk_type_checker(lua_State * L) { return push_type_checker_ref(L, t); } } -static int type_checker_whnf(lua_State * L) { return push_expr(L, to_type_checker_ref(L, 1)->whnf(to_expr(L, 2))); } +static int type_checker_whnf(lua_State * L) { return push_ecs(L, to_type_checker_ref(L, 1)->whnf(to_expr(L, 2))); } static int type_checker_ensure_pi(lua_State * L) { if (lua_gettop(L) == 2) - return push_expr(L, to_type_checker_ref(L, 1)->ensure_pi(to_expr(L, 2))); + return push_ecs(L, to_type_checker_ref(L, 1)->ensure_pi(to_expr(L, 2))); else - return push_expr(L, to_type_checker_ref(L, 1)->ensure_pi(to_expr(L, 2), to_expr(L, 3))); + return push_ecs(L, to_type_checker_ref(L, 1)->ensure_pi(to_expr(L, 2), to_expr(L, 3))); } static int type_checker_ensure_sort(lua_State * L) { if (lua_gettop(L) == 2) - return push_expr(L, to_type_checker_ref(L, 1)->ensure_sort(to_expr(L, 2))); + return push_ecs(L, to_type_checker_ref(L, 1)->ensure_sort(to_expr(L, 2))); else - return push_expr(L, to_type_checker_ref(L, 1)->ensure_sort(to_expr(L, 2), to_expr(L, 3))); + return push_ecs(L, to_type_checker_ref(L, 1)->ensure_sort(to_expr(L, 2), to_expr(L, 3))); } static int type_checker_check(lua_State * L) { int nargs = lua_gettop(L); if (nargs <= 2) - return push_expr(L, to_type_checker_ref(L, 1)->check(to_expr(L, 2), level_param_names())); + return push_ecs(L, to_type_checker_ref(L, 1)->check(to_expr(L, 2), level_param_names())); else - return push_expr(L, to_type_checker_ref(L, 1)->check(to_expr(L, 2), to_level_param_names(L, 3))); + return push_ecs(L, to_type_checker_ref(L, 1)->check(to_expr(L, 2), to_level_param_names(L, 3))); } -static int type_checker_infer(lua_State * L) { return push_expr(L, to_type_checker_ref(L, 1)->infer(to_expr(L, 2))); } -static int type_checker_is_def_eq(lua_State * L) { return push_boolean(L, to_type_checker_ref(L, 1)->is_def_eq(to_expr(L, 2), to_expr(L, 3))); } -static int type_checker_is_prop(lua_State * L) { return push_boolean(L, to_type_checker_ref(L, 1)->is_prop(to_expr(L, 2))); } -static int type_checker_push(lua_State * L) { to_type_checker_ref(L, 1)->push(); return 0; } -static int type_checker_pop(lua_State * L) { - if (to_type_checker_ref(L, 1)->num_scopes() == 0) - throw exception("invalid pop method, type_checker does not have backtracking points"); - to_type_checker_ref(L, 1)->pop(); - return 0; -} -static int type_checker_keep(lua_State * L) { - if (to_type_checker_ref(L, 1)->num_scopes() == 0) - throw exception("invalid pop method, type_checker does not have backtracking points"); - to_type_checker_ref(L, 1)->keep(); - return 0; -} -static int type_checker_num_scopes(lua_State * L) { return push_integer(L, to_type_checker_ref(L, 1)->num_scopes()); } -static int type_checker_next_cnstr(lua_State * L) { return push_optional_constraint(L, to_type_checker_ref(L, 1)->next_cnstr()); } +static int type_checker_infer(lua_State * L) { return push_ecs(L, to_type_checker_ref(L, 1)->infer(to_expr(L, 2))); } + +static int type_checker_is_def_eq(lua_State * L) { return push_bcs(L, to_type_checker_ref(L, 1)->is_def_eq(to_expr(L, 2), to_expr(L, 3))); } +static int type_checker_is_prop(lua_State * L) { return push_bcs(L, to_type_checker_ref(L, 1)->is_prop(to_expr(L, 2))); } static name g_tmp_prefix = name::mk_internal_unique_name(); @@ -1851,11 +1898,6 @@ static const struct luaL_Reg type_checker_ref_m[] = { {"infer", safe_function}, {"is_def_eq", safe_function}, {"is_prop", safe_function}, - {"push", safe_function}, - {"pop", safe_function}, - {"keep", safe_function}, - {"next_cnstr", safe_function}, - {"num_scopes", safe_function}, {0, 0} }; diff --git a/src/library/kernel_bindings.h b/src/library/kernel_bindings.h index 5880221bbd..4aaf12b571 100644 --- a/src/library/kernel_bindings.h +++ b/src/library/kernel_bindings.h @@ -20,6 +20,7 @@ UDATA_DEFS(environment) UDATA_DEFS(substitution) UDATA_DEFS(justification) UDATA_DEFS(constraint) +UDATA_DEFS_CORE(constraint_seq) UDATA_DEFS(substitution) UDATA_DEFS(io_state) UDATA_DEFS_CORE(type_checker_ref) diff --git a/src/library/match.cpp b/src/library/match.cpp index d6991413ff..d61ce7fcb1 100644 --- a/src/library/match.cpp +++ b/src/library/match.cpp @@ -347,10 +347,10 @@ bool match(expr const & p, expr const & t, buffer> & esubst, buff match_plugin mk_whnf_match_plugin(std::shared_ptr const & tc) { return [=](expr const & p, expr const & t, match_context & ctx) { // NOLINT try { - buffer cs; + constraint_seq cs; expr p1 = tc->whnf(p, cs); expr t1 = tc->whnf(t, cs); - return cs.empty() && (p1 != p || t1 != t) && ctx.match(p1, t1); + return !cs && (p1 != p || t1 != t) && ctx.match(p1, t1); } catch (exception&) { return false; } diff --git a/src/library/normalize.cpp b/src/library/normalize.cpp index 0f3f53c85d..1a0c3b4af2 100644 --- a/src/library/normalize.cpp +++ b/src/library/normalize.cpp @@ -32,7 +32,7 @@ class normalize_fn { } expr normalize(expr e) { - e = m_tc.whnf(e); + e = m_tc.whnf(e).first; switch (e.kind()) { case expr_kind::Var: case expr_kind::Constant: case expr_kind::Sort: case expr_kind::Meta: case expr_kind::Local: case expr_kind::Macro: diff --git a/src/library/num.cpp b/src/library/num.cpp index 128cbd2e8a..231ada991b 100644 --- a/src/library/num.cpp +++ b/src/library/num.cpp @@ -20,11 +20,11 @@ bool has_num_decls(environment const & env) { try { type_checker tc(env); return - tc.infer(g_zero) == g_num && - tc.infer(g_pos) == mk_arrow(g_pos_num, g_num) && - tc.infer(g_one) == g_pos_num && - tc.infer(g_bit0) == mk_arrow(g_pos_num, g_pos_num) && - tc.infer(g_bit1) == mk_arrow(g_pos_num, g_pos_num); + tc.infer(g_zero).first == g_num && + tc.infer(g_pos).first == mk_arrow(g_pos_num, g_num) && + tc.infer(g_one).first == g_pos_num && + tc.infer(g_bit0).first == mk_arrow(g_pos_num, g_pos_num) && + tc.infer(g_bit1).first == mk_arrow(g_pos_num, g_pos_num); } catch (...) { return false; } diff --git a/src/library/resolve_macro.cpp b/src/library/resolve_macro.cpp index ebcf18e2bc..d6e1d02739 100644 --- a/src/library/resolve_macro.cpp +++ b/src/library/resolve_macro.cpp @@ -84,7 +84,24 @@ public: // Begin of resolve_macro get_type implementation // This section of code is trusted when the environment has trust_level == 1 - bool is_def_eq(expr const & l1, expr const & l2, extension_context & ctx) const { return ctx.is_def_eq(l1, l2, jst()); } + bool is_def_eq(expr const & l1, expr const & l2, extension_context & ctx) const { + auto r = ctx.is_def_eq(l1, l2, jst()); + return r.first && !r.second; + } + + expr whnf(expr const & e, extension_context & ctx) const { + auto r = ctx.whnf(e); + if (r.second) + throw_kernel_exception(ctx.env(), "invalid resolve macro, constraints were generated while computing whnf", e); + return r.first; + } + + expr infer_type(expr const & e, extension_context & ctx) const { + auto r = ctx.infer_type(e); + if (r.second) + throw_kernel_exception(ctx.env(), "invalid resolve macro, constraints were generated while inferring type", e); + return r.first; + } /** \brief Return true if \c ls already contains a literal that is definitionally equal to \c l */ bool already_contains(expr const & l, buffer const & ls, extension_context & ctx) const { @@ -109,7 +126,7 @@ public: if (is_or(cls, lhs, rhs)) { return collect(lhs, rhs, l, R, ctx); } else { - cls = ctx.whnf(cls); + cls = whnf(cls, ctx); if (is_or(cls, lhs, rhs)) { return collect(lhs, rhs, l, R, ctx); } else if (is_def_eq(cls, l, ctx)) { @@ -125,8 +142,8 @@ public: virtual expr get_type(expr const & m, expr const * arg_types, extension_context & ctx) const { environment const & env = ctx.env(); check_num_args(env, m); - expr l = ctx.whnf(macro_arg(m, 0)); - expr not_l = ctx.whnf(g_not(l)); + expr l = whnf(macro_arg(m, 0), ctx); + expr not_l = whnf(g_not(l), ctx); expr C1 = arg_types[1]; expr C2 = arg_types[2]; buffer R; // resolvent @@ -148,12 +165,12 @@ public: virtual optional expand(expr const & m, extension_context & ctx) const { environment const & env = ctx.env(); check_num_args(env, m); - expr l = ctx.whnf(macro_arg(m, 0)); - expr not_l = ctx.whnf(g_not(l)); + expr l = whnf(macro_arg(m, 0), ctx); + expr not_l = whnf(g_not(l), ctx); expr H1 = macro_arg(m, 1); expr H2 = macro_arg(m, 2); - expr C1 = ctx.infer_type(H1); - expr C2 = ctx.infer_type(H2); + expr C1 = infer_type(H1, ctx); + expr C2 = infer_type(H2, ctx); expr arg_types[3] = { expr() /* get_type() does not use first argument */, C1, C2 }; expr R = get_type(m, arg_types, ctx); return some_expr(mk_or_elim_tree1(l, not_l, C1, H1, C2, H2, R, ctx)); @@ -198,7 +215,7 @@ public: if (is_or(C1, lhs, rhs)) { return mk_or_elim_tree1(l, not_l, lhs, rhs, H1, C2, H2, R, ctx); } else { - C1 = ctx.whnf(C1); + C1 = whnf(C1, ctx); if (is_or(C1, lhs, rhs)) { return mk_or_elim_tree1(l, not_l, lhs, rhs, H1, C2, H2, R, ctx); } else if (is_def_eq(C1, l, ctx)) { @@ -243,7 +260,7 @@ public: if (is_or(C2, lhs, rhs)) { return mk_or_elim_tree2(l, H, not_l, lhs, rhs, H2, R, ctx); } else { - C2 = ctx.whnf(C2); + C2 = whnf(C2, ctx); if (is_or(C2, lhs, rhs)) { return mk_or_elim_tree2(l, H, not_l, lhs, rhs, H2, R, ctx); } else if (is_def_eq(C2, not_l, ctx)) { diff --git a/src/library/string.cpp b/src/library/string.cpp index f7f895764f..48ac0a4d7e 100644 --- a/src/library/string.cpp +++ b/src/library/string.cpp @@ -101,11 +101,11 @@ bool has_string_decls(environment const & env) { try { type_checker tc(env); return - tc.infer(g_ff) == g_bool && - tc.infer(g_tt) == g_bool && - tc.infer(g_ascii) == g_bool >> (g_bool >> (g_bool >> (g_bool >> (g_bool >> (g_bool >> (g_bool >> (g_bool >> g_char))))))) && - tc.infer(g_empty) == g_string && - tc.infer(g_str) == g_char >> (g_string >> g_string); + tc.infer(g_ff).first == g_bool && + tc.infer(g_tt).first == g_bool && + tc.infer(g_ascii).first == g_bool >> (g_bool >> (g_bool >> (g_bool >> (g_bool >> (g_bool >> (g_bool >> (g_bool >> g_char))))))) && + tc.infer(g_empty).first == g_string && + tc.infer(g_str).first == g_char >> (g_string >> g_string); } catch (...) { return false; } diff --git a/src/library/tactic/apply_tactic.cpp b/src/library/tactic/apply_tactic.cpp index fd6e7317e6..e0b8947155 100644 --- a/src/library/tactic/apply_tactic.cpp +++ b/src/library/tactic/apply_tactic.cpp @@ -44,7 +44,7 @@ bool collect_simple_metas(expr const & e, buffer & result) { unsigned get_expect_num_args(type_checker & tc, expr e) { unsigned r = 0; while (true) { - e = tc.whnf(e); + e = tc.whnf(e).first; if (!is_pi(e)) return r; e = binding_body(e); @@ -93,6 +93,7 @@ static void remove_redundant_metas(buffer & metas) { proof_state_seq apply_tactic_core(environment const & env, io_state const & ios, proof_state const & s, expr const & _e, bool add_meta, bool add_subgoals, bool relax_main_opaque) { + // TODO(Leo): we are ignoring constraints produces by type checker goals const & gs = s.get_goals(); if (empty(gs)) return proof_state_seq(); @@ -102,7 +103,7 @@ proof_state_seq apply_tactic_core(environment const & env, io_state const & ios, goals tail_gs = tail(gs); expr t = g.get_type(); expr e = _e; - expr e_t = tc.infer(e); + expr e_t = tc.infer(e).first; buffer metas; collect_simple_meta(e, metas); if (add_meta) { @@ -111,7 +112,7 @@ proof_state_seq apply_tactic_core(environment const & env, io_state const & ios, if (num_t > num_e_t) return proof_state_seq(); // no hope to unify then for (unsigned i = 0; i < num_e_t - num_t; i++) { - e_t = tc.whnf(e_t); + e_t = tc.whnf(e_t).first; expr meta = g.mk_meta(ngen.next(), binding_domain(e_t)); e = mk_app(e, meta); e_t = instantiate(binding_body(e_t), meta); @@ -139,7 +140,7 @@ proof_state_seq apply_tactic_core(environment const & env, io_state const & ios, unsigned i = metas.size(); while (i > 0) { --i; - new_gs = cons(goal(metas[i], new_subst.instantiate_all(tc.infer(metas[i]))), new_gs); + new_gs = cons(goal(metas[i], new_subst.instantiate_all(tc.infer(metas[i]).first)), new_gs); } } return proof_state(new_gs, new_subst, new_ngen); diff --git a/src/library/tactic/expr_to_tactic.cpp b/src/library/tactic/expr_to_tactic.cpp index b6eadc09de..3240a4ff89 100644 --- a/src/library/tactic/expr_to_tactic.cpp +++ b/src/library/tactic/expr_to_tactic.cpp @@ -45,10 +45,10 @@ bool has_tactic_decls(environment const & env) { try { type_checker tc(env); return - tc.infer(g_builtin_tac) == g_tac_type && - tc.infer(g_and_then_tac_fn) == g_tac_type >> (g_tac_type >> g_tac_type) && - tc.infer(g_or_else_tac_fn) == g_tac_type >> (g_tac_type >> g_tac_type) && - tc.infer(g_repeat_tac_fn) == g_tac_type >> g_tac_type; + tc.infer(g_builtin_tac).first == g_tac_type && + tc.infer(g_and_then_tac_fn).first == g_tac_type >> (g_tac_type >> g_tac_type) && + tc.infer(g_or_else_tac_fn).first == g_tac_type >> (g_tac_type >> g_tac_type) && + tc.infer(g_repeat_tac_fn).first == g_tac_type >> g_tac_type; } catch (...) { return false; } @@ -69,7 +69,7 @@ static bool is_builtin_tactic(expr const & v) { } tactic expr_to_tactic(type_checker & tc, expr e, pos_info_provider const * p) { - e = tc.whnf(e); + e = tc.whnf(e).first; expr f = get_app_fn(e); if (!is_constant(f)) throw_failed(e); @@ -165,7 +165,7 @@ register_unary_num_tac::register_unary_num_tac(name const & n, std::function k = to_num(args[1]); if (!k) - k = to_num(tc.whnf(args[1])); + k = to_num(tc.whnf(args[1]).first); if (!k) throw expr_to_tactic_exception(e, "invalid tactic, second argument must be a numeral"); if (!k->is_unsigned_int()) @@ -199,7 +199,7 @@ static register_tac reg_trace(name(g_tac, "trace"), [](type_checker & tc, expr c throw expr_to_tactic_exception(e, "invalid trace tactic, argument expected"); if (auto str = to_string(args[0])) return trace_tactic(*str); - else if (auto str = to_string(tc.whnf(args[0]))) + else if (auto str = to_string(tc.whnf(args[0]).first)) return trace_tactic(*str); else throw expr_to_tactic_exception(e, "invalid trace tactic, string value expected"); @@ -224,7 +224,7 @@ static register_unary_num_tac reg_try_for(name(g_tac, "try_for"), [](tactic cons static register_tac reg_fixpoint(g_fixpoint_name, [](type_checker & tc, expr const & e, pos_info_provider const *) { if (!is_constant(app_fn(e))) throw expr_to_tactic_exception(e, "invalid fixpoint tactic, it must have one argument"); - expr r = tc.whnf(mk_app(app_arg(e), e)); + expr r = tc.whnf(mk_app(app_arg(e), e)).first; return fixpoint(r); }); diff --git a/src/library/tactic/goal.cpp b/src/library/tactic/goal.cpp index 67dc9821fb..46df614d8a 100644 --- a/src/library/tactic/goal.cpp +++ b/src/library/tactic/goal.cpp @@ -86,7 +86,7 @@ bool goal::validate_locals() const { bool goal::validate(environment const & env) const { if (validate_locals()) { type_checker tc(env); - return tc.is_def_eq(tc.check(m_meta), m_type); + return tc.is_def_eq(tc.check(m_meta).first, m_type).first; } else { return false; } diff --git a/src/library/tactic/proof_state.h b/src/library/tactic/proof_state.h index 8f34485b71..b23ef934cc 100644 --- a/src/library/tactic/proof_state.h +++ b/src/library/tactic/proof_state.h @@ -10,6 +10,7 @@ Author: Leonardo de Moura #include "util/lua.h" #include "util/optional.h" #include "util/name_set.h" +#include "kernel/metavar.h" #include "library/tactic/goal.h" namespace lean { diff --git a/src/library/tactic/tactic.cpp b/src/library/tactic/tactic.cpp index 1462f84305..43e94b515f 100644 --- a/src/library/tactic/tactic.cpp +++ b/src/library/tactic/tactic.cpp @@ -248,9 +248,11 @@ tactic exact_tactic(expr const & _e) { goals const & gs = s.get_goals(); goal const & g = head(gs); expr e = subst.instantiate(_e); - expr e_t = subst.instantiate(tc.infer(e)); + auto e_t_cs = tc.infer(e); + expr e_t = subst.instantiate(e_t_cs.first); expr t = subst.instantiate(g.get_type()); - if (tc.is_def_eq(e_t, t) && !tc.next_cnstr()) { + auto dcs = tc.is_def_eq(e_t, t); + if (dcs.first && !dcs.second && !e_t_cs.second) { expr new_p = g.abstract(e); check_has_no_local(new_p, _e, "exact"); subst.assign(g.get_name(), new_p); diff --git a/src/library/unifier.cpp b/src/library/unifier.cpp index 3558dbccca..d843facced 100644 --- a/src/library/unifier.cpp +++ b/src/library/unifier.cpp @@ -310,17 +310,11 @@ struct unifier_fn { m_assumption_idx(u.m_next_assumption_idx), m_jst(j), m_subst(u.m_subst), m_cnstrs(u.m_cnstrs), m_mvar_occs(u.m_mvar_occs), m_owned_map(u.m_owned_map), m_pattern(u.m_pattern) { u.m_next_assumption_idx++; - u.m_tc[0]->push(); - u.m_tc[1]->push(); } /** \brief Restore unifier's state with saved values, and update m_assumption_idx and m_failed_justifications. */ void restore_state(unifier_fn & u) { lean_assert(u.in_conflict()); - u.m_tc[0]->pop(); // restore type checker state - u.m_tc[1]->pop(); // restore type checker state - u.m_tc[0]->push(); - u.m_tc[1]->push(); u.m_subst = m_subst; u.m_cnstrs = m_cnstrs; u.m_mvar_occs = m_mvar_occs; @@ -438,54 +432,76 @@ struct unifier_fn { \remark If relax is true then opaque definitions from the main module are treated as transparent. */ bool is_def_eq(expr const & t1, expr const & t2, justification const & j, bool relax) { - if (m_tc[relax]->is_def_eq(t1, t2, j)) { - return true; - } else { + auto dcs = m_tc[relax]->is_def_eq(t1, t2, j); + if (!dcs.first) { // std::cout << "conflict: " << t1 << " =?= " << t2 << "\n"; set_conflict(j); return false; + } else { + return process_constraints(dcs.second); } } + /** \brief Process the given constraints. Return true iff no conflict was detected. */ + bool process_constraints(constraint_seq const & cs) { + return cs.all_of([&](constraint const & c) { return process_constraint(c); }); + } + + bool process_constraints(buffer const & cs) { + for (auto const & c : cs) { + if (!process_constraint(c)) + return false; + } + return true; + } + + + /** \brief Process constraints in \c cs, and append justification \c j to them. */ + bool process_constraints(constraint_seq const & cs, justification const & j) { + return cs.all_of([&](constraint const & c) { + return process_constraint(update_justification(c, mk_composite1(c.get_justification(), j))); + }); + } + + template + bool process_constraints(Constraints const & cs, justification const & j) { + for (auto const & c : cs) { + if (!process_constraint(update_justification(c, mk_composite1(c.get_justification(), j)))) + return false; + } + return true; + } + /** \brief Put \c e in weak head normal form. \remark If relax is true then opaque definitions from the main module are treated as transparent. - \remark Constraints generated in the process are stored in \c cs. The justification \c j is composed with them. + \remark Constraints generated in the process are stored in \c cs. */ - expr whnf(expr const & e, justification const & j, bool relax, buffer & cs) { - unsigned cs_sz = cs.size(); - expr r = m_tc[relax]->whnf(e, cs); - for (unsigned i = cs_sz; i < cs.size(); i++) - cs[i] = update_justification(cs[i], mk_composite1(j, cs[i].get_justification())); - return r; - } - - /** \brief Process the given constraints. Return true iff no conflict was detected. */ - bool process_constraints(buffer & cs) { - for (auto const & c : cs) - if (!process_constraint(c)) - return false; - return true; + expr whnf(expr const & e, bool relax, constraint_seq & cs) { + return m_tc[relax]->whnf(e, cs); } /** \brief Infer \c e type. \remark Return none if an exception was throw when inferring the type. \remark If relax is true then opaque definitions from the main module are treated as transparent. - \remark Constraints generated in the process are stored in \c cs. The justification \c j is composed with them. + \remark Constraints generated in the process are stored in \c cs. */ - optional infer(expr const & e, justification const & j, bool relax, buffer & cs) { + optional infer(expr const & e, bool relax, constraint_seq & cs) { try { - unsigned cs_sz = cs.size(); - expr r = m_tc[relax]->infer(e, cs); - for (unsigned i = cs_sz; i < cs.size(); i++) - cs[i] = update_justification(cs[i], mk_composite1(j, cs[i].get_justification())); - return some_expr(r); + return some_expr(m_tc[relax]->infer(e, cs)); } catch (exception &) { return none_expr(); } } + expr whnf(expr const & e, justification const & j, bool relax, buffer & cs) { + constraint_seq _cs; + expr r = whnf(e, relax, _cs); + to_buffer(_cs, j, cs); + return r; + } + justification mk_assign_justification(expr const & m, expr const & m_type, expr const & v_type, justification const & j) { auto r = j.get_main_expr(); if (!r) r = m; @@ -524,28 +540,11 @@ struct unifier_fn { lean_assert(is_metavar(m)); lean_assert(!in_conflict()); m_subst.assign(m, v, j); - #if 0 - expr m_type = mlocal_type(m); - expr v_type; - buffer cs; - if (auto type = infer(v, j, relax, cs)) { - v_type = *type; - if (!process_constraints(cs)) - return false; - } else { - set_conflict(j); - return false; - } - lean_assert(!in_conflict()); - justification new_j = mk_assign_justification(m, m_type, v_type, j); - if (!is_def_eq(m_type, v_type, new_j, relax)) - return false; - #else - buffer cs; - auto lhs_type = infer(lhs, j, relax, cs); - auto rhs_type = infer(rhs, j, relax, cs); + constraint_seq cs; + auto lhs_type = infer(lhs, relax, cs); + auto rhs_type = infer(rhs, relax, cs); if (lhs_type && rhs_type) { - if (!process_constraints(cs)) + if (!process_constraints(cs, j)) return false; justification new_j = mk_assign_justification(m, *lhs_type, *rhs_type, j); if (!is_def_eq(*lhs_type, *rhs_type, new_j, relax)) @@ -554,7 +553,6 @@ struct unifier_fn { set_conflict(j); return false; } - #endif auto it = m_mvar_occs.find(mlocal_name(m)); if (it) { cnstr_idx_set s = *it; @@ -943,8 +941,6 @@ struct unifier_fn { } void pop_case_split() { - m_tc[0]->pop(); - m_tc[1]->pop(); m_case_splits.pop_back(); } @@ -974,13 +970,6 @@ struct unifier_fn { return optional(); } - /** \brief Process constraints in \c cs, and append justification \c j to them. */ - bool process_constraints(constraints const & cs, justification const & j) { - for (constraint const & c : cs) - process_constraint(update_justification(c, mk_composite1(c.get_justification(), j))); - return !in_conflict(); - } - bool next_lazy_constraints_case_split(lazy_constraints_case_split & cs) { auto r = cs.m_tail.pull(); if (r) { @@ -1029,29 +1018,30 @@ struct unifier_fn { if (!is_constant(f_lhs) || !is_constant(f_rhs) || const_name(f_lhs) != const_name(f_rhs)) return lazy_list(); justification const & j = c.get_justification(); - buffer cs; + constraint_seq cs; bool relax = relax_main_opaque(c); - lean_assert(!m_tc[relax]->next_cnstr()); - if (!m_tc[relax]->is_def_eq(f_lhs, f_rhs, j, cs)) + auto fcs = m_tc[relax]->is_def_eq(f_lhs, f_rhs, j); + if (!fcs.first) return lazy_list(); + cs = fcs.second; buffer args_lhs, args_rhs; get_app_args(lhs, args_lhs); get_app_args(rhs, args_rhs); if (args_lhs.size() != args_rhs.size()) return lazy_list(); - lean_assert(!m_tc[relax]->next_cnstr()); - for (unsigned i = 0; i < args_lhs.size(); i++) - if (!m_tc[relax]->is_def_eq(args_lhs[i], args_rhs[i], j, cs)) + for (unsigned i = 0; i < args_lhs.size(); i++) { + auto acs = m_tc[relax]->is_def_eq(args_lhs[i], args_rhs[i], j); + if (!acs.first) return lazy_list(); - return lazy_list(to_list(cs.begin(), cs.end())); + cs = acs.second + cs; + } + return lazy_list(cs.to_list()); } bool process_plugin_constraint(constraint const & c) { bool relax = relax_main_opaque(c); lean_assert(!is_choice_cnstr(c)); - lean_assert(!m_tc[relax]->next_cnstr()); lazy_list alts = m_plugin->solve(*m_tc[relax], c, m_ngen.mk_child()); - lean_assert(!m_tc[relax]->next_cnstr()); alts = append(alts, process_const_const_cnstr(c)); return process_lazy_constraints(alts, c.get_justification()); } @@ -1068,8 +1058,8 @@ struct unifier_fn { expr m_type; bool relax = relax_main_opaque(c); - buffer cs; - if (auto type = infer(m, c.get_justification(), relax, cs)) { + constraint_seq cs; + if (auto type = infer(m, relax, cs)) { m_type = *type; if (!process_constraints(cs)) return false; @@ -1128,11 +1118,11 @@ struct unifier_fn { expr t = apply_beta(lhs_fn_val, lhs_args.size(), lhs_args.data()); expr s = apply_beta(rhs_fn_val, rhs_args.size(), rhs_args.data()); bool relax = relax_main_opaque(c); - buffer cs2; - if (m_tc[relax]->is_def_eq(t, s, j, cs2)) { + auto dcs = m_tc[relax]->is_def_eq(t, s, j); + if (dcs.first) { // create a case split a = mk_assumption_justification(m_next_assumption_idx); - add_case_split(std::unique_ptr(new simple_case_split(*this, j, to_list(cs2.begin(), cs2.end())))); + add_case_split(std::unique_ptr(new simple_case_split(*this, j, dcs.second.to_list()))); } // process first case @@ -1231,26 +1221,20 @@ struct unifier_fn { return true; } - /** \brief Copy pending constraints in u.m_tc[relax] to cs and append justification j to them */ - void copy_pending_constraints(buffer & cs) { - while (auto c = u.m_tc[relax]->next_cnstr()) - cs.push_back(update_justification(*c, mk_composite1(c->get_justification(), j))); - } - /** \see ensure_sufficient_args */ - expr ensure_sufficient_args_core(expr mtype, unsigned i) { + expr ensure_sufficient_args_core(expr mtype, unsigned i, constraint_seq & cs) { if (i == margs.size()) return mtype; - mtype = u.m_tc[relax]->ensure_pi(mtype); + mtype = u.m_tc[relax]->ensure_pi(mtype, cs); expr local = u.mk_local_for(mtype); expr body = instantiate(binding_body(mtype), local); - return Pi(local, ensure_sufficient_args_core(body, i+1)); + return Pi(local, ensure_sufficient_args_core(body, i+1, cs)); } /** \brief Make sure mtype is a Pi of size at least margs.size(). If it is not, we use ensure_pi and (potentially) add new constaints to enforce it. */ - expr ensure_sufficient_args(expr const & mtype, buffer & cs) { + expr ensure_sufficient_args(expr const & mtype, constraint_seq & cs) { expr t = mtype; unsigned num = 0; while (is_pi(t)) { @@ -1259,12 +1243,7 @@ struct unifier_fn { } if (num == margs.size()) return mtype; - lean_assert(!u.m_tc[relax]->next_cnstr()); // make sure there are no pending constraints - // We must create a scope to make sure no constraints "leak" into the current state. - type_checker::scope scope(*u.m_tc[relax]); - auto new_mtype = ensure_sufficient_args_core(mtype, 0); - copy_pending_constraints(cs); - return new_mtype; + return ensure_sufficient_args_core(mtype, 0, cs); } /** @@ -1278,10 +1257,10 @@ struct unifier_fn { lean_assert(is_metavar(m)); lean_assert(is_sort(rhs) || is_constant(rhs)); expr const & mtype = mlocal_type(m); - buffer cs; - auto new_mtype = ensure_sufficient_args(mtype, cs); - cs.push_back(mk_eq_cnstr(m, mk_lambda_for(new_mtype, rhs), j, relax)); - alts.push_back(to_list(cs.begin(), cs.end())); + constraint_seq cs; + expr new_mtype = ensure_sufficient_args(mtype, cs); + cs = cs + mk_eq_cnstr(m, mk_lambda_for(new_mtype, rhs), j, relax); + alts.push_back(cs.to_list()); } /** @@ -1302,15 +1281,15 @@ struct unifier_fn { expr const & mtype = mlocal_type(m); unsigned vidx = margs.size() - i - 1; expr const & marg = margs[i]; - buffer cs; + constraint_seq cs; auto new_mtype = ensure_sufficient_args(mtype, cs); // Remark: we should not use mk_eq_cnstr(marg, rhs, j) since is_def_eq may be able to reduce them. // The unifier assumes the eq constraints are reduced. if (u.m_tc[relax]->is_def_eq_types(marg, rhs, j, cs) && u.m_tc[relax]->is_def_eq(marg, rhs, j, cs)) { expr v = mk_lambda_for(new_mtype, mk_var(vidx)); - cs.push_back(mk_eq_cnstr(m, v, j, relax)); - alts.push_back(to_list(cs.begin(), cs.end())); + cs = cs + mk_eq_cnstr(m, v, j, relax); + alts.push_back(cs.to_list()); } } @@ -1344,11 +1323,11 @@ struct unifier_fn { mk_simple_nonlocal_projection(i); } else if (is_local(marg) && is_local(rhs) && mlocal_name(marg) == mlocal_name(rhs)) { // if the argument is local, and rhs is equal to it, then we also add a projection - buffer cs; + constraint_seq cs; auto new_mtype = ensure_sufficient_args(mtype, cs); expr v = mk_lambda_for(new_mtype, mk_var(vidx)); - cs.push_back(mk_eq_cnstr(m, v, j, relax)); - alts.push_back(to_list(cs.begin(), cs.end())); + cs = cs + mk_eq_cnstr(m, v, j, relax); + alts.push_back(cs.to_list()); } } } @@ -1386,30 +1365,6 @@ struct unifier_fn { return v; } - /** \brief Check if term \c e (produced by an imitation step) is - type correct, and store generated constraints in \c cs. - Include \c j in all generated constraints */ - bool check_imitation(expr e, buffer & cs) { - buffer ls; - while (is_lambda(e)) { - expr d = instantiate_rev(binding_domain(e), ls.size(), ls.data()); - expr l = mk_local(u.m_ngen.next(), binding_name(e), d, binding_info(e)); - ls.push_back(l); - e = binding_body(e); - } - e = instantiate_rev(e, ls.size(), ls.data());; - try { - buffer aux; - u.m_tc[relax]->check(e, aux); - for (auto c : aux) { - cs.push_back(update_justification(c, mk_composite1(j, c.get_justification()))); - } - return true; - } catch (exception&) { - return false; - } - } - void mk_app_projections() { lean_assert(is_metavar(m)); lean_assert(is_app(rhs)); @@ -1433,7 +1388,7 @@ struct unifier_fn { /** \brief Create the local context \c locals for the imitiation step. */ - void mk_local_context(buffer & locals, buffer & cs) { + void mk_local_context(buffer & locals, constraint_seq & cs) { expr mtype = mlocal_type(m); unsigned nargs = margs.size(); mtype = ensure_sufficient_args(mtype, cs); @@ -1456,7 +1411,7 @@ struct unifier_fn { } expr mk_imitiation_arg(expr const & arg, expr const & type, buffer const & locals, - buffer & cs) { + constraint_seq & cs) { if (!has_meta_args() && is_local(arg) && contains_local(arg, locals)) { return arg; } else { @@ -1464,39 +1419,36 @@ struct unifier_fn { if (context_check(type, locals)) { expr maux = mk_metavar(u.m_ngen.next(), Pi(locals, type)); // std::cout << " >> " << maux << " : " << mlocal_type(maux) << "\n"; - cs.push_back(mk_eq_cnstr(mk_app(maux, margs), arg, j, relax)); + cs = mk_eq_cnstr(mk_app(maux, margs), arg, j, relax) + cs; return mk_app(maux, locals); } else { expr maux_type = mk_metavar(u.m_ngen.next(), Pi(locals, mk_sort(mk_meta_univ(u.m_ngen.next())))); expr maux = mk_metavar(u.m_ngen.next(), Pi(locals, mk_app(maux_type, locals))); - cs.push_back(mk_eq_cnstr(mk_app(maux, margs), arg, j, relax)); + cs = mk_eq_cnstr(mk_app(maux, margs), arg, j, relax) + cs; return mk_app(maux, locals); } } } - void mk_app_imitation_core(expr const & f, buffer const & locals, buffer & cs) { + void mk_app_imitation_core(expr const & f, buffer const & locals, constraint_seq & cs) { buffer rargs; get_app_args(rhs, rargs); buffer sargs; try { - // create a scope to make sure no constraints "leak" into the current state - type_checker::scope scope(*u.m_tc[relax]); - expr f_type = u.m_tc[relax]->infer(f); + expr f_type = u.m_tc[relax]->infer(f, cs); for (expr const & rarg : rargs) { - f_type = u.m_tc[relax]->ensure_pi(f_type); + f_type = u.m_tc[relax]->ensure_pi(f_type, cs); expr d_type = binding_domain(f_type); expr sarg = mk_imitiation_arg(rarg, d_type, locals, cs); sargs.push_back(sarg); f_type = instantiate(binding_body(f_type), sarg); } - copy_pending_constraints(cs); } catch (exception&) {} expr v = Fun(locals, mk_app(f, sargs)); // std::cout << " >> app imitation, v: " << v << "\n"; lean_assert(!has_local(v)); - cs.push_back(mk_eq_cnstr(m, v, j, relax)); - alts.push_back(to_list(cs.begin(), cs.end())); + cs = cs + mk_eq_cnstr(m, v, j, relax); + alts.push_back(cs.to_list()); } /** @@ -1519,22 +1471,20 @@ struct unifier_fn { void mk_app_imitation() { lean_assert(is_metavar(m)); lean_assert(is_app(rhs)); - lean_assert(!u.m_tc[relax]->next_cnstr()); // make sure there are no pending constraints buffer locals; - buffer cs; + constraint_seq cs; flet let(j, j); // save j value mk_local_context(locals, cs); lean_assert(margs.size() == locals.size()); expr const & f = get_app_fn(rhs); lean_assert(is_constant(f) || is_local(f)); if (is_local(f)) { - unsigned cs_sz = cs.size(); unsigned i = margs.size(); while (i > 0) { --i; if (is_local(margs[i]) && mlocal_name(margs[i]) == mlocal_name(f)) { - cs.shrink(cs_sz); - mk_app_imitation_core(locals[i], locals, cs); + constraint_seq new_cs = cs; + mk_app_imitation_core(locals[i], locals, new_cs); } } } else { @@ -1556,30 +1506,27 @@ struct unifier_fn { void mk_bindings_imitation() { lean_assert(is_metavar(m)); lean_assert(is_binding(rhs)); - lean_assert(!u.m_tc[relax]->next_cnstr()); // make sure there are no pending constraints - buffer cs; + constraint_seq cs; buffer locals; flet let(j, j); // save j value mk_local_context(locals, cs); lean_assert(margs.size() == locals.size()); try { // create a scope to make sure no constraints "leak" into the current state - type_checker::scope scope(*u.m_tc[relax]); expr rhs_A = binding_domain(rhs); - expr A_type = u.m_tc[relax]->infer(rhs_A); + expr A_type = u.m_tc[relax]->infer(rhs_A, cs); expr A = mk_imitiation_arg(rhs_A, A_type, locals, cs); expr local = mk_local(u.m_ngen.next(), binding_name(rhs), A, binding_info(rhs)); locals.push_back(local); margs.push_back(local); expr rhs_B = instantiate(binding_body(rhs), local); - expr B_type = u.m_tc[relax]->infer(rhs_B); + expr B_type = u.m_tc[relax]->infer(rhs_B, cs); expr B = mk_imitiation_arg(rhs_B, B_type, locals, cs); expr binding = is_pi(rhs) ? Pi(local, B) : Fun(local, B); locals.pop_back(); expr v = Fun(locals, binding); - copy_pending_constraints(cs); - cs.push_back(mk_eq_cnstr(m, v, j, relax)); - alts.push_back(to_list(cs.begin(), cs.end())); + cs = cs + mk_eq_cnstr(m, v, j, relax); + alts.push_back(cs.to_list()); } catch (exception&) {} margs.pop_back(); } @@ -1596,24 +1543,27 @@ struct unifier_fn { ?m =?= fun (x_1 ... x_k), M((?m_1 x_1 ... x_k) ... (?m_n x_1 ... x_k)) */ void mk_macro_imitation() { + // TODO(Leo): use same approach used in mk_app_imitation lean_assert(is_metavar(m)); lean_assert(is_macro(rhs)); - buffer cs; + constraint_seq cs; expr mtype = mlocal_type(m); mtype = ensure_sufficient_args(mtype, cs); // create an auxiliary metavariable for each macro argument buffer sargs; for (unsigned i = 0; i < macro_num_args(rhs); i++) { expr maux = mk_aux_metavar_for(u.m_ngen, mtype); - cs.push_back(mk_eq_cnstr(mk_app(maux, margs), macro_arg(rhs, i), j, relax)); + cs = mk_eq_cnstr(mk_app(maux, margs), macro_arg(rhs, i), j, relax) + cs; sargs.push_back(mk_app_vars(maux, margs.size())); } expr v = mk_macro(macro_def(rhs), sargs.size(), sargs.data()); v = mk_lambda_for(mtype, v); - if (check_imitation(v, cs)) { - cs.push_back(mk_eq_cnstr(m, v, j, relax)); - alts.push_back(to_list(cs.begin(), cs.end())); - } + // if (check_imitation(v, cs)) { + // cs.push_back(mk_eq_cnstr(m, v, j, relax)); + // alts.push_back(to_list(cs.begin(), cs.end())); + // } + cs = cs + mk_eq_cnstr(m, v, j, relax); + alts.push_back(cs.to_list()); } public: @@ -1689,10 +1639,10 @@ struct unifier_fn { if (is_app(rhs)) { expr const & f = get_app_fn(rhs); if (!is_local(f) && !is_constant(f)) { - buffer cs; - expr new_rhs = whnf(rhs, j, relax, cs); + constraint_seq cs; + expr new_rhs = whnf(rhs, relax, cs); lean_assert(new_rhs != rhs); - if (!process_constraints(cs)) + if (!process_constraints(cs, j)) return false; return is_def_eq(lhs, new_rhs, j, relax); } @@ -1770,20 +1720,6 @@ struct unifier_fn { return true; } - void consume_tc_cnstrs() { - while (true) { - if (in_conflict()) { - return; - } else if (auto c = m_tc[0]->next_cnstr()) { - process_constraint(*c); - } else if (auto c = m_tc[1]->next_cnstr()) { - process_constraint(*c); - } else { - break; - } - } - } - /** \brief Process the following constraints 1. (max l1 l2) =?= 0 OR @@ -1857,8 +1793,6 @@ struct unifier_fn { /** \brief Process the next constraint in the constraint queue m_cnstrs */ bool process_next() { lean_assert(!m_cnstrs.empty()); - lean_assert(!m_tc[0]->next_cnstr()); - lean_assert(!m_tc[1]->next_cnstr()); auto const * p = m_cnstrs.min(); unsigned cidx = p->second; if (!m_expensive && cidx >= get_group_first_index(cnstr_group::ClassInstance)) @@ -1871,8 +1805,6 @@ struct unifier_fn { } else { auto r = instantiate_metavars(c); c = r.first; - lean_assert(!m_tc[0]->next_cnstr()); - lean_assert(!m_tc[1]->next_cnstr()); bool modified = r.second; if (is_level_eq_cnstr(c)) { if (modified) @@ -1933,7 +1865,6 @@ struct unifier_fn { return optional(); } while (true) { - consume_tc_cnstrs(); if (!in_conflict()) { if (m_cnstrs.empty()) break; @@ -1942,8 +1873,6 @@ struct unifier_fn { if (in_conflict() && !resolve_conflict()) return failure(); } - lean_assert(!m_tc[0]->next_cnstr()); - lean_assert(!m_tc[1]->next_cnstr()); lean_assert(!in_conflict()); lean_assert(m_cnstrs.empty()); substitution s = m_subst; @@ -1983,10 +1912,12 @@ lazy_list unify(environment const & env, expr const & lhs, expr co expr _lhs = new_s.instantiate(lhs); expr _rhs = new_s.instantiate(rhs); auto u = std::make_shared(env, 0, nullptr, ngen, new_s, false, max_steps, expensive); - if (!u->m_tc[relax]->is_def_eq(_lhs, _rhs)) + constraint_seq cs; + if (!u->m_tc[relax]->is_def_eq(_lhs, _rhs, justification(), cs) || !u->process_constraints(cs)) { return lazy_list(); - else + } else { return unify(u); + } } lazy_list unify(environment const & env, expr const & lhs, expr const & rhs, name_generator const & ngen, diff --git a/src/tests/kernel/environment.cpp b/src/tests/kernel/environment.cpp index f2be70692a..402817b84f 100644 --- a/src/tests/kernel/environment.cpp +++ b/src/tests/kernel/environment.cpp @@ -67,12 +67,12 @@ static void tst1() { expr c = mk_local("c", Prop); expr id = Const("id"); type_checker checker(env3, name_generator("tmp")); - lean_assert(checker.check(id(Prop)) == Prop >> Prop); - lean_assert(checker.whnf(id(Prop, c)) == c); - lean_assert(checker.whnf(id(Prop, id(Prop, id(Prop, c)))) == c); + lean_assert(checker.check(id(Prop)).first == Prop >> Prop); + lean_assert(checker.whnf(id(Prop, c)).first == c); + lean_assert(checker.whnf(id(Prop, id(Prop, id(Prop, c)))).first == c); type_checker checker2(env2, name_generator("tmp")); - lean_assert(checker2.whnf(id(Prop, id(Prop, id(Prop, c)))) == id(Prop, id(Prop, id(Prop, c)))); + lean_assert(checker2.whnf(id(Prop, id(Prop, id(Prop, c)))).first == id(Prop, id(Prop, id(Prop, c)))); } static void tst2() { @@ -99,34 +99,34 @@ static void tst2() { expr c1 = mk_local("c1", Prop); expr c2 = mk_local("c2", Prop); expr id = Const("id"); - std::cout << checker.whnf(f3(c1, c2)) << "\n"; + std::cout << checker.whnf(f3(c1, c2)).first << "\n"; lean_assert_eq(env.find(name(base, 98))->get_weight(), 98); - lean_assert(checker.is_def_eq(f98(c1, c2), f97(f97(c1, c2), f97(c2, c1)))); - lean_assert(checker.is_def_eq(f98(c1, id(Prop, id(Prop, c2))), f97(f97(c1, id(Prop, c2)), f97(c2, c1)))); + lean_assert(checker.is_def_eq(f98(c1, c2), f97(f97(c1, c2), f97(c2, c1))).first); + lean_assert(checker.is_def_eq(f98(c1, id(Prop, id(Prop, c2))), f97(f97(c1, id(Prop, c2)), f97(c2, c1))).first); name_set s; s.insert(name(base, 96)); type_checker checker2(env, name_generator("tmp"), mk_default_converter(env, optional(), true, s)); - lean_assert_eq(checker2.whnf(f98(c1, c2)), + lean_assert_eq(checker2.whnf(f98(c1, c2)).first, f96(f96(f97(c1, c2), f97(c2, c1)), f96(f97(c2, c1), f97(c1, c2)))); } class normalizer_extension_tst : public normalizer_extension { public: - virtual optional operator()(expr const & e, extension_context & ctx) const { + virtual optional> operator()(expr const & e, extension_context & ctx) const { if (!is_app(e)) - return none_expr(); + return optional>(); expr const & f = app_fn(e); expr const & a = app_arg(e); if (!is_constant(f) || const_name(f) != name("proj1")) - return none_expr(); - expr a_n = ctx.whnf(a); + return optional>(); + expr a_n = ctx.whnf(a).first; if (!is_app(a_n) || !is_app(app_fn(a_n)) || !is_constant(app_fn(app_fn(a_n)))) - return none_expr(); + return optional>(); expr const & mk = app_fn(app_fn(a_n)); if (const_name(mk) != name("mk")) - return none_expr(); + return optional>(); // In a real implementation, we must check if proj1 and mk were defined in the environment. - return some_expr(app_arg(app_fn(a_n))); + return optional>(app_arg(app_fn(a_n)), constraint_seq()); } virtual bool may_reduce_later(expr const &, extension_context &) const { return false; } virtual bool supports(name const &) const { return false; } @@ -145,7 +145,7 @@ static void tst3() { expr a = Const("a"); expr b = Const("b"); type_checker checker(env, name_generator("tmp")); - lean_assert_eq(checker.whnf(proj1(proj1(mk(id(A, mk(a, b)), b)))), a); + lean_assert_eq(checker.whnf(proj1(proj1(mk(id(A, mk(a, b)), b)))).first, a); } class dummy_ext : public environment_extension {}; diff --git a/src/util/sequence.h b/src/util/sequence.h index 9a0e7ecc72..f6b5e3eb08 100644 --- a/src/util/sequence.h +++ b/src/util/sequence.h @@ -10,6 +10,7 @@ Author: Leonardo de Moura #include "util/buffer.h" #include "util/optional.h" #include "util/memory_pool.h" +#include "util/list.h" namespace lean { /** \brief Sequence datastructure with O(1) concatenation operation */ @@ -76,9 +77,13 @@ public: friend bool is_eqp(sequence const & s1, sequence const & s2) { return s1.m_node.raw() == s2.m_node.raw(); } friend sequence operator+(sequence const & s1, sequence const & s2) { return sequence(s1, s2); } + friend sequence operator+(sequence const & s, T const & v) { return s + sequence(v); } + friend sequence operator+(T const & v, sequence const & s) { return sequence(v) + s; } + sequence & operator+=(T const & v) { *this = *this + v; return *this; } + sequence & operator+=(sequence const & s) { *this = *this + s; return *this; } - /** \brief Store sequence elements in \c r */ - void linearize(buffer & r) const { + template + bool all_of(F && f) const { buffer todo; if (m_node) todo.push_back(m_node.raw()); while (!todo.empty()) { @@ -88,9 +93,23 @@ public: todo.push_back(static_cast(c)->m_second.raw()); todo.push_back(static_cast(c)->m_first.raw()); } else { - r.push_back(static_cast(c)->m_value); + if (!f(static_cast(c)->m_value)) + return false; } } + return true; + } + + template + void for_each(F && f) const { all_of([&](T const & v) { f(v); return true; }); } + + /** \brief Store sequence elements in \c r */ + void linearize(buffer & r) const { for_each([&](T const & v) { r.push_back(v); }); } + + list to_list() const { + buffer tmp; + linearize(tmp); + return ::lean::to_list(tmp.begin(), tmp.end()); } }; diff --git a/tests/lua/env5.lua b/tests/lua/env5.lua index 35a69f0843..75746d8d14 100644 --- a/tests/lua/env5.lua +++ b/tests/lua/env5.lua @@ -41,8 +41,11 @@ assert(not env:is_descendant(env2)) local tc2 = type_checker(env2) id_u = Const("id", {mk_global_univ("u")}) print(tc2:check(id_u)) -print(tc2:check(tc2:check(id_u))) -print(tc2:check(tc2:check(tc2:check(id_u)))) +local tmp = tc2:check(id_u) +print(tc2:check(tmp)) +local tmp1 = tc2:check(id_u) +local tmp2 = tc2:check(tmp1) +print(tc2:check(tmp2)) env2 = add_decl(env2, mk_var_decl("a", Type)) local tc2 = type_checker(env2) local a = Const("a") diff --git a/tests/lua/tc1.lua b/tests/lua/tc1.lua index f57a284955..f152bd1bb2 100644 --- a/tests/lua/tc1.lua +++ b/tests/lua/tc1.lua @@ -21,6 +21,3 @@ print("check(t): ") print(tc2:check(t)) print("check(t2): ") print(tc2:check(t2)) -assert(tc2:next_cnstr()) -assert(tc2:next_cnstr()) -assert(not tc2:next_cnstr()) diff --git a/tests/lua/tc5.lua b/tests/lua/tc5.lua index 9fd02ba2e5..846925defc 100644 --- a/tests/lua/tc5.lua +++ b/tests/lua/tc5.lua @@ -10,7 +10,6 @@ assert(tc:is_prop(Const("C"))) assert(not tc:is_prop(Const("T"))) assert(not tc:is_prop(Const("B2"))) print(tc:check(mk_lambda("x", mk_metavar("m", mk_metavar("t", mk_sort(mk_meta_univ("l")))), Var(0)))) -assert(tc:next_cnstr()) print(tc:ensure_sort(Const("B2"))) assert(not pcall(function() print(tc:ensure_sort(Const("A"))) diff --git a/tests/lua/tc8.lua b/tests/lua/tc8.lua index f7c1d157c9..1735819df6 100644 --- a/tests/lua/tc8.lua +++ b/tests/lua/tc8.lua @@ -8,17 +8,16 @@ local a = Const("a") local m1 = mk_metavar("m1", mk_metavar("m2", mk_sort(mk_meta_univ("l")))) local ngen = name_generator("tst") local tc = type_checker(env, ngen) -assert(tc:num_scopes() == 0) -tc:push() -assert(tc:num_scopes() == 1) -print(tc:check(f(m1))) -assert(tc:next_cnstr()) -assert(not tc:next_cnstr()) -print(tc:check(f(f(m1)))) -assert(not tc:next_cnstr()) -- New constraint is not generated -tc:pop() -- forget that we checked f(m1) -print(tc:check(f(m1))) --- constraint is generated again -assert(tc:next_cnstr()) -assert(not tc:next_cnstr()) -check_error(function() tc:pop() end) + +function test_check(e) + t, cs = tc:check(e) + print(tostring(e) .. " : " .. tostring(t)) + cs = cs:linearize() + for i = 1, #cs do + print(" >> " .. tostring(cs[i])) + end +end + +test_check(f(m1)) +test_check(f(f(m1))) +test_check(f(m1)) diff --git a/tests/lua/tc_bug1.lua b/tests/lua/tc_bug1.lua index ef7b06e738..cd6d99db96 100644 --- a/tests/lua/tc_bug1.lua +++ b/tests/lua/tc_bug1.lua @@ -34,27 +34,26 @@ local ng = name_generator("foo") local tc = type_checker(env, ng) local m1 = mk_metavar("m1", Prop) print("before is_def_eq") -assert(not tc:next_cnstr()) local tc = type_checker(env, ng) -assert(tc:is_def_eq(foo_intro(m1, q(a), q(a), Ax1), foo_intro(q(a), q(a), q(a), Ax2))) -local c = tc:next_cnstr() -assert(c) -assert(not tc:next_cnstr()) -assert(c:lhs() == m1) -assert(c:rhs() == q(a)) +r, cs = tc:is_def_eq(foo_intro(m1, q(a), q(a), Ax1), foo_intro(q(a), q(a), q(a), Ax2)) +assert(r) +cs = cs:linearize() +assert(#cs == 1) +assert(cs[1]:lhs() == m1) +assert(cs[1]:rhs() == q(a)) local tc = type_checker(env, ng) -assert(not tc:next_cnstr()) print(tostring(foo_intro) .. " : " .. tostring(tc:check(foo_intro))) print(tostring(foo_intro2) .. " : " .. tostring(tc:check(foo_intro2))) assert(tc:is_def_eq(foo_intro, foo_intro2)) print("before is_def_eq2") -assert(tc:is_def_eq(foo_intro(m1, q(a), q(b), Ax1), foo_intro2(q(a), q(a), q(a), Ax2))) -assert(not tc:next_cnstr()) +local r, cs = tc:is_def_eq(foo_intro(m1, q(a), q(b), Ax1), foo_intro2(q(a), q(a), q(a), Ax2)) +assert(r) +cs = cs:linearize() +assert(#cs == 0) local tc = type_checker(env, ng) print("before failure") assert(not pcall(function() print(tc:check(and_intro(m1, q(a), Ax1, Ax3))) end)) -assert(not tc:next_cnstr()) print("before success") -print(tc:check(and_intro(m1, q(a), Ax1, Ax2))) -assert(tc:next_cnstr()) -assert(not tc:next_cnstr()) +local t, cs = tc:check(and_intro(m1, q(a), Ax1, Ax2)) +cs = cs:linearize() +assert(#cs == 1)