diff --git a/src/frontends/lean/decl_cmds.cpp b/src/frontends/lean/decl_cmds.cpp index 56d2e66f0e..9f89b83788 100644 --- a/src/frontends/lean/decl_cmds.cpp +++ b/src/frontends/lean/decl_cmds.cpp @@ -8,6 +8,7 @@ Author: Leonardo de Moura #include "util/sstream.h" #include "kernel/type_checker.h" #include "kernel/abstract.h" +#include "kernel/replace_fn.h" #include "kernel/for_each_fn.h" #include "library/scoped_ext.h" #include "library/aliases.h" @@ -315,13 +316,87 @@ static bool is_curr_with_or_comma(parser & p) { return p.curr_is_token(get_with_tk()) || p.curr_is_token(get_comma_tk()); } -expr parse_equations(parser & p, name const & n, expr const & type, buffer & auxs) { +/** + For convenience, the left-hand-side of a recursive equation may contain + undeclared variables. + We use parser::undef_id_to_local_scope to force the parser to create a local constant for + each undefined identifier. + + This method validates occurrences of these variables. They can only occur as an application + or macro argument. +*/ +static void validate_equation_lhs(parser const & p, expr const & lhs, buffer const & locals) { + if (is_app(lhs)) { + validate_equation_lhs(p, app_fn(lhs), locals); + validate_equation_lhs(p, app_arg(lhs), locals); + } else if (is_macro(lhs)) { + for (unsigned i = 0; i < macro_num_args(lhs); i++) + validate_equation_lhs(p, macro_arg(lhs, i), locals); + } else if (!is_local(lhs)) { + for_each(lhs, [&](expr const & e, unsigned) { + if (is_local(e) && + std::any_of(locals.begin(), locals.end(), [&](expr const & local) { + return mlocal_name(e) == mlocal_name(local); + })) { + throw parser_error(sstream() << "invalid occurrence of variable '" << mlocal_name(lhs) << + "' in the left-hand-side of recursive equation", p.pos_of(lhs)); + } + return has_local(e); + }); + } +} + +/** + \brief Merge multiple occurrences of a variable in the left-hand-side of a recursive equation. + + \see validate_equation_lhs +*/ +static expr merge_equation_lhs_vars(expr const & lhs, buffer & locals) { + expr_map m; + unsigned j = 0; + for (unsigned i = 0; i < locals.size(); i++) { + unsigned k; + for (k = 0; k < i; k++) { + if (mlocal_name(locals[k]) == mlocal_name(locals[i])) { + m.insert(mk_pair(locals[i], locals[k])); + break; + } + } + if (k == i) { + locals[j] = locals[i]; + j++; + } + } + if (j == locals.size()) + return lhs; + locals.shrink(j); + return replace(lhs, [&](expr const & e) { + if (!has_local(e)) + return some_expr(e); + if (is_local(e)) { + auto it = m.find(e); + if (it != m.end()) + return some_expr(it->second); + } + return none_expr(); + }); +} + +static void throw_invalid_equation_lhs(name const & n, pos_info const & p) { + throw parser_error(sstream() << "invalid recursive equation, head symbol '" + << n << "' in the left-hand-side does not correspond to function(s) being defined", p); +} + +expr parse_equations(parser & p, name const & n, expr const & type, buffer & auxs, + optional const & lenv, buffer const & ps) { buffer eqns; + buffer fns; { - parser::local_scope scope1(p); - parser::undef_id_to_local_scope scope2(p); + parser::local_scope scope1(p, lenv); + for (expr const & param : ps) + p.add_local(param); lean_assert(is_curr_with_or_comma(p)); - expr f = mk_local(n, type); + fns.push_back(mk_local(n, type)); if (p.curr_is_token(get_with_tk())) { while (p.curr_is_token(get_with_tk())) { p.next(); @@ -330,28 +405,57 @@ expr parse_equations(parser & p, name const & n, expr const & type, buffer expr g_type = p.parse_expr(); expr g = mk_local(g_name, g_type); auxs.push_back(g); + fns.push_back(g); } } p.check_token_next(get_comma_tk(), "invalid declaration, ',' expected"); - p.add_local(f); - for (expr const & g : auxs) - p.add_local(g); + for (expr const & fn : fns) + p.add_local(fn); while (true) { - expr lhs = p.parse_expr(); + expr lhs; + unsigned prev_num_undef_ids = p.get_num_undef_ids(); + buffer locals; + { + parser::undef_id_to_local_scope scope2(p); + auto lhs_pos = p.pos(); + lhs = p.parse_expr(); + expr lhs_fn = get_app_fn(lhs); + if (is_explicit(lhs_fn)) + lhs_fn = get_explicit_arg(lhs_fn); + if (is_constant(lhs_fn)) + throw_invalid_equation_lhs(const_name(lhs_fn), lhs_pos); + if (is_local(lhs_fn) && std::all_of(fns.begin(), fns.end(), [&](expr const & fn) { return fn != lhs_fn; })) + throw_invalid_equation_lhs(local_pp_name(lhs_fn), lhs_pos); + if (!is_local(lhs_fn)) + throw parser_error("invalid recursive equation, head symbol in left-hand-side is not a constant", lhs_pos); + unsigned num_undef_ids = p.get_num_undef_ids(); + for (unsigned i = prev_num_undef_ids; i < num_undef_ids; i++) { + locals.push_back(p.get_undef_id(i)); + } + } + validate_equation_lhs(p, lhs, locals); + lhs = merge_equation_lhs_vars(lhs, locals); p.check_token_next(get_assign_tk(), "invalid declaration, ':=' expected"); - expr rhs = p.parse_expr(); - eqns.push_back(mk_equation(lhs, rhs)); + { + parser::local_scope scope2(p); + for (expr const & local : locals) + p.add_local(local); + expr rhs = p.parse_expr(); + eqns.push_back(Fun(fns, Fun(locals, mk_equation(lhs, rhs), p))); + } if (!p.curr_is_token(get_comma_tk())) break; p.next(); } } if (p.curr_is_token(get_wf_tk())) { + auto pos = p.pos(); p.next(); + expr R = p.save_pos(mk_expr_placeholder(), pos); expr Hwf = p.parse_expr(); - return mk_equations(eqns.size(), eqns.data(), Hwf); + return mk_equations(fns.size(), eqns.size(), eqns.data(), R, Hwf); } else { - return mk_equations(eqns.size(), eqns.data()); + return mk_equations(fns.size(), eqns.size(), eqns.data()); } } @@ -409,7 +513,7 @@ class definition_cmd_fn { auto pos = m_p.pos(); m_type = m_p.parse_expr(); if (is_curr_with_or_comma(m_p)) { - m_value = parse_equations(m_p, m_name, m_type, m_aux_decls); + m_value = parse_equations(m_p, m_name, m_type, m_aux_decls, optional(), buffer()); } else if (!is_definition() && !m_p.curr_is_token(get_assign_tk())) { check_end_of_theorem(m_p); m_value = m_p.save_pos(mk_expr_placeholder(), pos); @@ -427,7 +531,7 @@ class definition_cmd_fn { m_p.next(); m_type = m_p.parse_scoped_expr(ps, *lenv); if (is_curr_with_or_comma(m_p)) { - m_value = parse_equations(m_p, m_name, m_type, m_aux_decls); + m_value = parse_equations(m_p, m_name, m_type, m_aux_decls, lenv, ps); } else if (!is_definition() && !m_p.curr_is_token(get_assign_tk())) { check_end_of_theorem(m_p); m_value = m_p.save_pos(mk_expr_placeholder(), pos); diff --git a/src/frontends/lean/elaborator.cpp b/src/frontends/lean/elaborator.cpp index 1cd707db85..1df7c78e22 100644 --- a/src/frontends/lean/elaborator.cpp +++ b/src/frontends/lean/elaborator.cpp @@ -33,6 +33,7 @@ Author: Leonardo de Moura #include "library/local_context.h" #include "library/tactic/expr_to_tactic.h" #include "library/error_handling/error_handling.h" +#include "library/definitional/equations.h" #include "frontends/lean/local_decls.h" #include "frontends/lean/class.h" #include "frontends/lean/tactic_hint.h" @@ -99,10 +100,11 @@ elaborator::elaborator(elaborator_context & ctx, name_generator const & ngen, bo m_has_sorry = has_sorry(m_ctx.m_env); m_relax_main_opaque = false; m_use_tactic_hints = true; - m_no_info = false; - m_tc[0] = mk_type_checker(ctx.m_env, m_ngen.mk_child(), false); - m_tc[1] = mk_type_checker(ctx.m_env, m_ngen.mk_child(), true); - m_nice_mvar_names = nice_mvar_names; + m_no_info = false; + m_in_equation_lhs = false; + m_tc[0] = mk_type_checker(ctx.m_env, m_ngen.mk_child(), false); + m_tc[1] = mk_type_checker(ctx.m_env, m_ngen.mk_child(), true); + m_nice_mvar_names = nice_mvar_names; } expr elaborator::mk_local(name const & n, expr const & t, binder_info const & bi) { @@ -812,6 +814,155 @@ expr elaborator::visit_sorry(expr const & e) { return mk_app(update_constant(e, to_list(u)), m, e.get_tag()); } +expr const & elaborator::get_equation_fn(expr const & eq) const { + expr it = eq; + while (is_lambda(it)) + it = binding_body(it); + if (!is_equation(it)) + throw_elaborator_exception(env(), "ill-formed equation", eq); + expr const & fn = get_app_fn(equation_lhs(it)); + if (!is_local(fn)) + throw_elaborator_exception(env(), "ill-formed equation", eq); + return fn; +} + +static expr copy_domain(unsigned num, expr const & source, expr const & target) { + if (num == 0) { + return target; + } else { + lean_assert(is_binding(source) && is_binding(target)); + return update_binding(source, mk_as_is(binding_domain(source)), copy_domain(num-1, binding_body(source), binding_body(target))); + } +} + +static constraint mk_equations_cnstr(environment const & env, io_state const & ios, expr const & m, expr const & eqns) { + justification j = mk_failed_to_synthesize_jst(env, m); + auto choice_fn = [=](expr const & , expr const &, substitution const & s, + name_generator const &) { + expr new_eqns = substitution(s).instantiate(eqns); + regular(env, ios) << "Equations:\n" << new_eqns << "\n\n"; + // TODO(Leo); + return lazy_list(constraints()); + }; + bool owner = true; + bool relax = false; + return mk_choice_cnstr(m, choice_fn, to_delay_factor(cnstr_group::MaxDelayed), owner, j, relax); +} + +expr elaborator::visit_equations(expr const & eqns, constraint_seq & cs) { + buffer eqs; + buffer new_eqs; + optional new_R; + optional new_Hwf; + + to_equations(eqns, eqs); + + if (eqs.empty()) + throw_elaborator_exception(env(), "invalid empty set of recursive equations", eqns); + + if (is_wf_equations(eqns)) { + new_R = visit(equations_wf_rel(eqns), cs); + new_Hwf = visit(equations_wf_proof(eqns), cs); + expr Hwf_type = infer_type(*new_Hwf, cs); + expr wf = visit(mk_constant("well_founded"), cs); + wf = ::lean::mk_app(wf, *new_R); + justification j = mk_type_mismatch_jst(*new_Hwf, Hwf_type, wf, equations_wf_proof(eqns)); + auto new_Hwf_cs = ensure_has_type(*new_Hwf, Hwf_type, wf, j, m_relax_main_opaque); + new_Hwf = new_Hwf_cs.first; + cs += new_Hwf_cs.second; + } + + flet> set1(m_equation_R, new_R); + unsigned num_fns = equations_num_fns(eqns); + + optional first_eq; + for (expr const & eq : eqs) { + expr new_eq; + if (first_eq) { + // Replace first num_fns domains of eq with the ones in first_eq. + // This is a trick/hack to ensure the fns in each equation have + // the same elaborated type. + new_eq = visit(copy_domain(num_fns, *first_eq, eq), cs); + } else { + new_eq = visit(eq, cs); + first_eq = new_eq; + } + new_eqs.push_back(new_eq); + } + + expr new_eqns; + if (new_R) { + new_eqns = mk_equations(num_fns, new_eqs.size(), new_eqs.data(), *new_R, *new_Hwf); + } else { + new_eqns = mk_equations(num_fns, new_eqs.size(), new_eqs.data()); + } + + lean_assert(first_eq && is_lambda(*first_eq)); + expr type = binding_domain(*first_eq); + expr m = m_full_context.mk_meta(m_ngen, some_expr(type), eqns.get_tag()); + register_meta(m); + constraint c = mk_equations_cnstr(env(), ios(), m, new_eqns); + cs += c; + return m; +} + +expr elaborator::visit_equation(expr const & eq, constraint_seq & cs) { + expr const & lhs = equation_lhs(eq); + expr const & rhs = equation_rhs(eq); + expr lhs_fn = get_app_fn(lhs); + if (is_explicit(lhs_fn)) + lhs_fn = get_explicit_arg(lhs_fn); + if (!is_local(lhs_fn)) + throw exception("ill-formed equation"); + expr new_lhs, new_rhs; + { + flet set(m_in_equation_lhs, true); + new_lhs = visit(lhs, cs); + } + { + optional some_new_lhs(new_lhs); + flet> set1(m_equation_lhs, some_new_lhs); + new_rhs = visit(rhs, cs); + } + + expr lhs_type = infer_type(new_lhs, cs); + expr rhs_type = infer_type(new_rhs, cs); + justification j = mk_justification(eq, [=](formatter const & fmt, substitution const & subst) { + substitution s(subst); + return pp_def_type_mismatch(fmt, local_pp_name(lhs_fn), s.instantiate(lhs_type), s.instantiate(rhs_type)); + }); + pair new_rhs_cs = ensure_has_type(new_rhs, rhs_type, lhs_type, j, m_relax_main_opaque); + new_rhs = new_rhs_cs.first; + cs += new_rhs_cs.second; + return mk_equation(new_lhs, new_rhs); +} + +expr elaborator::visit_inaccessible(expr const & e, constraint_seq & cs) { + if (!m_in_equation_lhs) + throw_elaborator_exception(env(), "invalid occurrence of 'inaccessible' annotation, it must only occur in the " + "left-hand-side of recursive equations", e); + return mk_inaccessible(visit(get_annotation_arg(e), cs)); +} + +expr elaborator::visit_decreasing(expr const & e, constraint_seq & cs) { + if (!m_equation_lhs) + throw_elaborator_exception(env(), "invalid occurrence of 'decreasing' annotation, it must only occur in " + "the right-hand-side of recursive equations", e); + if (!m_equation_R) + throw_elaborator_exception(env(), "invalid occurrence of 'decreasing' annotation, it can only be used when " + "recursive equations are being defined by well-founded recursion", e); + expr const & lhs_fn = get_app_fn(*m_equation_lhs); + if (get_app_fn(decreasing_app(e)) != lhs_fn) + throw_elaborator_exception(env(), "invalid occurrence of 'decreasing' annotation, expression must be an " + "application of the recursive function being defined", e); + expr dec_app = visit(decreasing_app(e), cs); + expr dec_proof = visit(decreasing_proof(e), cs); + // Remark: perhaps we should enforce the type of dec_proof here. + // We may have enough information to wrap the arguments in a sigma type (reason: the type of the function being elaborated has holes). + // Possible solution: create a constraint that enforces the type as soon the type of function has been elaborated. + return mk_decreasing(dec_app, dec_proof); +} + expr elaborator::visit_core(expr const & e, constraint_seq & cs) { if (is_placeholder(e)) { return visit_placeholder(e, cs); @@ -841,6 +992,14 @@ expr elaborator::visit_core(expr const & e, constraint_seq & cs) { return visit_core(get_explicit_arg(e), cs); } else if (is_sorry(e)) { return visit_sorry(e); + } else if (is_equations(e)) { + lean_unreachable(); + } else if (is_equation(e)) { + return visit_equation(e, cs); + } else if (is_inaccessible(e)) { + return visit_inaccessible(e, cs); + } else if (is_decreasing(e)) { + return visit_decreasing(e, cs); } else { switch (e.kind()) { case expr_kind::Local: return e; @@ -882,6 +1041,8 @@ pair elaborator::visit(expr const & e) { } else { r = visit_core(b, cs); } + } else if (is_equations(e)) { + r = visit_equations(e, cs); } else if (is_explicit(get_app_fn(e))) { r = visit_core(e, cs); } else { diff --git a/src/frontends/lean/elaborator.h b/src/frontends/lean/elaborator.h index 3196022c3e..6fae783fb5 100644 --- a/src/frontends/lean/elaborator.h +++ b/src/frontends/lean/elaborator.h @@ -53,6 +53,14 @@ class elaborator : public coercion_info_manager { // if m_no_info is true, we do not collect information when true, // we set is to true whenever we find no_info annotation. bool m_no_info; + // if m_in_equation_lhs is true, we are processing the left-hand-side of an equation + // and inaccessible expressions are allowed + bool m_in_equation_lhs; + // if m_equation_lhs is not none, we are processing the right-hand-side of an equation + // and decreasing expressions are allowed + optional m_equation_lhs; + // if m_equation_R is not none when elaborator is processing recursive equation using the well-founded relation R. + optional m_equation_R; bool m_use_tactic_hints; info_manager m_pre_info_data; bool m_has_sorry; @@ -151,6 +159,13 @@ class elaborator : public coercion_info_manager { std::tuple apply(substitution & s, expr const & e); pair elaborate_nested(list const & g, expr const & e, bool relax, bool use_tactic_hints, bool report_unassigned); + + expr const & get_equation_fn(expr const & eq) const; + expr visit_equations(expr const & eqns, constraint_seq & cs); + expr visit_equation(expr const & e, constraint_seq & cs); + expr visit_inaccessible(expr const & e, constraint_seq & cs); + expr visit_decreasing(expr const & e, constraint_seq & cs); + public: elaborator(elaborator_context & ctx, name_generator const & ngen, bool nice_mvar_names = false); std::tuple operator()(list const & ctx, expr const & e, bool _ensure_type, diff --git a/src/frontends/lean/parser.cpp b/src/frontends/lean/parser.cpp index ff51bb9db8..3569061f9a 100644 --- a/src/frontends/lean/parser.cpp +++ b/src/frontends/lean/parser.cpp @@ -76,6 +76,12 @@ parser::local_scope::local_scope(parser & p, environment const & env): m_p.m_env = env; m_p.push_local_scope(); } +parser::local_scope::local_scope(parser & p, optional const & env): + m_p(p), m_env(p.env()) { + if (env) + m_p.m_env = *env; + m_p.push_local_scope(); +} parser::local_scope::~local_scope() { m_p.pop_local_scope(); m_p.m_env = m_env; @@ -362,8 +368,11 @@ expr parser::propagate_levels(expr const & e, levels const & ls) { } } -pos_info parser::pos_of(expr const & e, pos_info default_pos) { - if (auto it = m_pos_table.find(get_tag(e))) +pos_info parser::pos_of(expr const & e, pos_info default_pos) const { + tag t = e.get_tag(); + if (t == nulltag) + return default_pos; + if (auto it = m_pos_table.find(t)) return *it; else return default_pos; @@ -432,7 +441,7 @@ void parser::push_local_scope(bool save_options) { optional opts; if (save_options) opts = m_ios.get_options(); - m_parser_scope_stack = cons(parser_scope_stack_elem(opts, m_level_variables, m_variables, m_include_vars), + m_parser_scope_stack = cons(parser_scope_stack_elem(opts, m_level_variables, m_variables, m_include_vars, m_undef_ids.size()), m_parser_scope_stack); } @@ -451,6 +460,7 @@ void parser::pop_local_scope() { m_level_variables = s.m_level_variables; m_variables = s.m_variables; m_include_vars = s.m_include_vars; + m_undef_ids.shrink(s.m_num_undef_ids); m_parser_scope_stack = tail(m_parser_scope_stack); } @@ -1111,7 +1121,9 @@ expr parser::id_to_expr(name const & id, pos_info const & p) { if (m_undef_id_behavior == undef_id_behavior::AssumeConstant) { r = save_pos(mk_constant(get_namespace(m_env) + id, ls), p); } else if (m_undef_id_behavior == undef_id_behavior::AssumeLocal) { - r = save_pos(mk_local(id, mk_expr_placeholder()), p); + expr local = mk_local(id, mk_expr_placeholder()); + m_undef_ids.push_back(local); + r = save_pos(local, p); } } if (!r) diff --git a/src/frontends/lean/parser.h b/src/frontends/lean/parser.h index 86b6336c5c..9ae0350e2f 100644 --- a/src/frontends/lean/parser.h +++ b/src/frontends/lean/parser.h @@ -52,8 +52,10 @@ struct parser_scope_stack_elem { name_set m_level_variables; name_set m_variables; name_set m_include_vars; - parser_scope_stack_elem(optional const & o, name_set const & lvs, name_set const & vs, name_set const & ivs): - m_options(o), m_level_variables(lvs), m_variables(vs), m_include_vars(ivs) {} + unsigned m_num_undef_ids; + parser_scope_stack_elem(optional const & o, name_set const & lvs, name_set const & vs, name_set const & ivs, + unsigned num_undef_ids): + m_options(o), m_level_variables(lvs), m_variables(vs), m_include_vars(ivs), m_num_undef_ids(num_undef_ids) {} }; typedef list parser_scope_stack; @@ -130,6 +132,8 @@ class parser { // curr command token name m_cmd_token; + buffer m_undef_ids; + void display_warning_pos(unsigned line, unsigned pos); void display_warning_pos(pos_info p); void display_error_pos(unsigned line, unsigned pos); @@ -255,8 +259,8 @@ public: pos_info pos() const { return pos_info(m_scanner.get_line(), m_scanner.get_pos()); } expr save_pos(expr e, pos_info p); expr rec_save_pos(expr const & e, pos_info p); - pos_info pos_of(expr const & e, pos_info default_pos); - pos_info pos_of(expr const & e) { return pos_of(e, pos()); } + pos_info pos_of(expr const & e, pos_info default_pos) const; + pos_info pos_of(expr const & e) const { return pos_of(e, pos()); } pos_info cmd_pos() const { return m_last_cmd_pos; } name const & get_cmd_token() const { return m_cmd_token; } void set_line(unsigned p) { return m_scanner.set_line(p); } @@ -359,7 +363,10 @@ public: expr parse_scoped_expr(buffer const & ps, unsigned rbp = 0) { return parse_scoped_expr(ps.size(), ps.data(), rbp); } struct local_scope { parser & m_p; environment m_env; - local_scope(parser & p, bool save_options = false); local_scope(parser & p, environment const & env); ~local_scope(); + local_scope(parser & p, bool save_options = false); + local_scope(parser & p, environment const & env); + local_scope(parser & p, optional const & env); + ~local_scope(); }; bool has_locals() const { return !m_local_decls.empty() || !m_local_level_decls.empty(); } void add_local_level(name const & n, level const & l, bool is_variable = false); @@ -395,6 +402,11 @@ public: struct undef_id_to_const_scope : public flet { undef_id_to_const_scope(parser & p); }; struct undef_id_to_local_scope : public flet { undef_id_to_local_scope(parser &); }; + /** \brief Return the size of the stack of undefined local constants */ + unsigned get_num_undef_ids() const { return m_undef_ids.size(); } + /** \brief Return the i-th undefined local constants */ + expr const & get_undef_id(unsigned i) const { return m_undef_ids[i]; } + /** \brief Elaborate \c e, and tolerate metavariables in the result. */ std::tuple elaborate_relaxed(expr const & e, list const & ctx = list()); /** \brief Elaborate \c e, and ensure it is a type. */ diff --git a/src/library/definitional/equations.cpp b/src/library/definitional/equations.cpp index 3eae9603ff..8e23acc02e 100644 --- a/src/library/definitional/equations.cpp +++ b/src/library/definitional/equations.cpp @@ -22,18 +22,27 @@ static std::string * g_decreasing_opcode = nullptr; [[ noreturn ]] static void throw_eq_ex() { throw exception("unexpected occurrence of 'equation' expression"); } class equations_macro_cell : public macro_definition_cell { + unsigned m_num_fns; public: + equations_macro_cell(unsigned num_fns):m_num_fns(num_fns) {} virtual name get_name() const { return *g_equations_name; } virtual pair get_type(expr const &, extension_context &) const { throw_eqs_ex(); } virtual optional expand(expr const &, extension_context &) const { throw_eqs_ex(); } - virtual void write(serializer & s) const { s.write_string(*g_equations_opcode); } + virtual void write(serializer & s) const { s << *g_equations_opcode << m_num_fns; } + unsigned get_num_fns() const { return m_num_fns; } }; class equation_macro_cell : public macro_definition_cell { public: virtual name get_name() const { return *g_equation_name; } - virtual pair get_type(expr const &, extension_context &) const { throw_eq_ex(); } - virtual optional expand(expr const &, extension_context &) const { throw_eq_ex(); } + virtual pair get_type(expr const &, extension_context &) const { + expr dummy = mk_Prop(); + return mk_pair(dummy, constraint_seq()); + } + virtual optional expand(expr const &, extension_context &) const { + expr dummy = mk_Type(); + return some_expr(dummy); + } virtual void write(serializer & s) const { s.write_string(*g_equation_opcode); } }; @@ -56,11 +65,18 @@ public: virtual void write(serializer & s) const { s.write_string(*g_decreasing_opcode); } }; -static macro_definition * g_equations = nullptr; static macro_definition * g_equation = nullptr; static macro_definition * g_decreasing = nullptr; bool is_equation(expr const & e) { return is_macro(e) && macro_def(e) == *g_equation; } + +bool is_lambda_equation(expr const & e) { + if (is_lambda(e)) + return is_lambda_equation(binding_body(e)); + else + return is_equation(e); +} + expr const & equation_lhs(expr const & e) { lean_assert(is_equation(e)); return macro_arg(e, 0); } expr const & equation_rhs(expr const & e) { lean_assert(is_equation(e)); return macro_arg(e, 1); } expr mk_equation(expr const & lhs, expr const & rhs) { @@ -76,40 +92,54 @@ expr mk_decreasing(expr const & t, expr const & H) { return mk_macro(*g_decreasing, 2, args); } -bool is_equations(expr const & e) { return is_macro(e) && macro_def(e) == *g_equations; } +bool is_equations(expr const & e) { return is_macro(e) && macro_def(e).get_name() == *g_equations_name; } bool is_wf_equations_core(expr const & e) { lean_assert(is_equations(e)); - return !is_equation(macro_arg(e, macro_num_args(e) - 1)); + return macro_num_args(e) >= 3 && !is_lambda_equation(macro_arg(e, macro_num_args(e) - 1)); } bool is_wf_equations(expr const & e) { return is_equations(e) && is_wf_equations_core(e); } unsigned equations_size(expr const & e) { + lean_assert(is_equations(e)); if (is_wf_equations_core(e)) - return macro_num_args(e) - 1; + return macro_num_args(e) - 2; else return macro_num_args(e); } -void to_equations(expr const & e, buffer & eqns) { - lean_assert(is_equation(e)); - unsigned sz = equations_size(e); - for (unsigned i = 0; i < sz; i++) - eqns.push_back(macro_arg(e, i)); +unsigned equations_num_fns(expr const & e) { + lean_assert(is_equations(e)); + return static_cast(macro_def(e).raw())->get_num_fns(); } expr const & equations_wf_proof(expr const & e) { lean_assert(is_wf_equations(e)); return macro_arg(e, macro_num_args(e) - 1); } -expr mk_equations(unsigned num, expr const * eqns) { - lean_assert(std::all_of(eqns, eqns+num, is_equation)); - lean_assert(num > 0); - return mk_macro(*g_equations, num, eqns); +expr const & equations_wf_rel(expr const & e) { + lean_assert(is_wf_equations(e)); + return macro_arg(e, macro_num_args(e) - 2); } -expr mk_equations(unsigned num, expr const * eqns, expr const & Hwf) { - lean_assert(std::all_of(eqns, eqns+num, is_equation)); - lean_assert(num > 0); +void to_equations(expr const & e, buffer & eqns) { + lean_assert(is_equations(e)); + unsigned sz = equations_size(e); + for (unsigned i = 0; i < sz; i++) + eqns.push_back(macro_arg(e, i)); +} +expr mk_equations(unsigned num_fns, unsigned num_eqs, expr const * eqs) { + lean_assert(num_fns > 0); + lean_assert(num_eqs > 0); + lean_assert(std::all_of(eqs, eqs+num_eqs, is_lambda_equation)); + macro_definition def(new equations_macro_cell(num_fns)); + return mk_macro(def, num_eqs, eqs); +} +expr mk_equations(unsigned num_fns, unsigned num_eqs, expr const * eqs, expr const & R, expr const & Hwf) { + lean_assert(num_fns > 0); + lean_assert(num_eqs > 0); + lean_assert(std::all_of(eqs, eqs+num_eqs, is_lambda_equation)); buffer args; - args.append(num, eqns); + args.append(num_eqs, eqs); + args.push_back(R); args.push_back(Hwf); - return mk_macro(*g_equations, args.size(), args.data()); + macro_definition def(new equations_macro_cell(num_fns)); + return mk_macro(def, args.size(), args.data()); } expr mk_inaccessible(expr const & e) { return mk_annotation(*g_inaccessible_name, e); } @@ -120,7 +150,6 @@ void initialize_equations() { g_equation_name = new name("equation"); g_decreasing_name = new name("decreasing"); g_inaccessible_name = new name("innaccessible"); - g_equations = new macro_definition(new equations_macro_cell()); g_equation = new macro_definition(new equation_macro_cell()); g_decreasing = new macro_definition(new decreasing_macro_cell()); g_equations_opcode = new std::string("Eqns"); @@ -128,15 +157,17 @@ void initialize_equations() { g_decreasing_opcode = new std::string("Decr"); register_annotation(*g_inaccessible_name); register_macro_deserializer(*g_equations_opcode, - [](deserializer &, unsigned num, expr const * args) { - if (num == 0) + [](deserializer & d, unsigned num, expr const * args) { + unsigned num_fns; + d >> num_fns; + if (num == 0 || num_fns == 0) throw corrupted_stream_exception(); - if (!is_equation(args[num-1])) { - if (num == 1) + if (!is_lambda_equation(args[num-1])) { + if (num <= 2) throw corrupted_stream_exception(); - return mk_equations(num-1, args, args[num-1]); + return mk_equations(num_fns, num-2, args, args[num-2], args[num-1]); } else { - return mk_equations(num, args); + return mk_equations(num_fns, num, args); } }); register_macro_deserializer(*g_equation_opcode, @@ -157,7 +188,6 @@ void finalize_equations() { delete g_equation_opcode; delete g_equations_opcode; delete g_decreasing_opcode; - delete g_equations; delete g_equation; delete g_decreasing; delete g_equations_name; diff --git a/src/library/definitional/equations.h b/src/library/definitional/equations.h index e6d0bbf70f..0c992d0263 100644 --- a/src/library/definitional/equations.h +++ b/src/library/definitional/equations.h @@ -12,6 +12,8 @@ bool is_equation(expr const & e); expr const & equation_lhs(expr const & e); expr const & equation_rhs(expr const & e); expr mk_equation(expr const & lhs, expr const & rhs); +/** \brief Return true if e is of the form fun a_1 ... a_n, equation */ +bool is_lambda_equation(expr const & e); bool is_decreasing(expr const & e); expr const & decreasing_app(expr const & e); @@ -21,10 +23,12 @@ expr mk_decreasing(expr const & t, expr const & H); bool is_equations(expr const & e); bool is_wf_equations(expr const & e); unsigned equations_size(expr const & e); +unsigned equations_num_fns(expr const & e); void to_equations(expr const & e, buffer & eqns); expr const & equations_wf_proof(expr const & e); -expr mk_equations(unsigned num, expr const * eqns); -expr mk_equations(unsigned num, expr const * eqns, expr const & Hwf); +expr const & equations_wf_rel(expr const & e); +expr mk_equations(unsigned num_fns, unsigned num_eqs, expr const * eqs); +expr mk_equations(unsigned num_fns, unsigned num_eqs, expr const * eqs, expr const & R, expr const & Hwf); expr mk_inaccessible(expr const & e); bool is_inaccessible(expr const & e); diff --git a/tests/lean/extra/rec.lean b/tests/lean/extra/rec.lean new file mode 100644 index 0000000000..da24a61b50 --- /dev/null +++ b/tests/lean/extra/rec.lean @@ -0,0 +1,31 @@ +import data.vector +open nat vector + +check lt.base +set_option pp.implicit true + +definition add : nat → nat → nat, +add zero b := b, +add (succ a) b := succ (add a b) + +definition map {A B C : Type} (f : A → B → C) : Π {n}, vector A n → vector B n → vector C n, +map nil nil := nil, +map (a :: va) (b :: vb) := f a b :: map va vb + +definition fib : nat → nat, +fib 0 := 1, +fib 1 := 1, +fib (a+2) := (fib a ↓ lt.step (lt.base a)) + (fib (a+1) ↓ lt.base (a+1)) +[wf] lt.wf + +definition half : nat → nat, +half 0 := 0, +half 1 := 0, +half (x+2) := half x + 1 + +variables {A B : Type} +inductive image_of (f : A → B) : B → Type := +mk : Π a, image_of f (f a) + +definition inv {f : A → B} : Π b, image_of f b → A, +inv ⌞f a⌟ (image_of.mk f a) := a