diff --git a/src/library/equations_compiler/elim_match.cpp b/src/library/equations_compiler/elim_match.cpp index e8fabd383e..1594d35608 100644 --- a/src/library/equations_compiler/elim_match.cpp +++ b/src/library/equations_compiler/elim_match.cpp @@ -16,6 +16,7 @@ Author: Leonardo de Moura #include "library/app_builder.h" #include "library/tactic/tactic_state.h" #include "library/tactic/revert_tactic.h" +#include "library/tactic/clear_tactic.h" #include "library/tactic/cases_tactic.h" #include "library/tactic/intro_tactic.h" #include "library/equations_compiler/equations.h" @@ -62,12 +63,13 @@ struct elim_match_fn { }; struct lemma { - list m_vars; - expr m_eqn; /* equation (it might be conditional) */ - expr m_proof; + local_context m_lctx; + list m_vars; + expr m_eqn; /* equation (it might be conditional) */ + expr m_proof; lemma() {} - lemma(list const & vars, expr const & eqn, expr const & proof): - m_vars(vars), m_eqn(eqn), m_proof(proof) {} + lemma(local_context const & lctx, list const & vars, expr const & eqn, expr const & proof): + m_lctx(lctx), m_vars(vars), m_eqn(eqn), m_proof(proof) {} }; /** Result for the compilation procedure. */ @@ -384,9 +386,9 @@ struct elim_match_fn { void trace_lemmas(program const & P, char const * header, buffer const & lemmas) { trace_match_detail({ tout() << "[" << m_depth << "] " << header << " lemmas:\n"; - auto pp_fn = mk_pp_ctx(P); for (lemma const & L : lemmas) { /* Replace function with its name. */ + auto pp_fn = mk_pp_ctx(L.m_lctx); expr tmp_eqn = update_eqn_lhs(L.m_eqn, P.m_nvars, [&](buffer const & args) { return mk_rev_app(mk_constant(P.m_fn_name), args); @@ -402,7 +404,30 @@ struct elim_match_fn { result compile_skip(program const & P) { trace_match(tout() << "skip transition\n";); - lean_unreachable(); + program new_P; + new_P.m_fn_name = P.m_fn_name; + buffer new_names; + optional aux_goal = intron(m_env, m_opts, m_mctx, P.m_goal, 1, new_names); + if (!aux_goal) throw_ill_formed_eqns(); + lean_assert(new_names.size() == 1); + expr H = m_mctx.get_metavar_decl(*aux_goal)->get_context().get_local_decl(new_names[0])->mk_ref(); + new_P.m_goal = *aux_goal; // clear(m_mctx, *aux_goal, H); + new_P.m_nvars = P.m_nvars - 1; + buffer new_eqns; + for (equation const & eqn : P.m_equations) { + equation new_eqn = eqn; + new_eqn.m_patterns = tail(eqn.m_patterns); + new_eqns.push_back(new_eqn); + } + new_P.m_equations = to_list(new_eqns); + result R = compile_core(new_P); + result new_R; + type_context ctx = mk_type_context(P); + if (m_lemmas) { + // TODO(Leo) + + } + return new_R; } /* Update the equation left hand side @@ -448,6 +473,7 @@ struct elim_match_fn { buffer new_lemmas; for (lemma const & L : R.m_lemmas) { lemma new_L; + new_L.m_lctx = L.m_lctx; new_L.m_vars = cons(x, L.m_vars); new_L.m_eqn = update_eqn_for_variable_transition(L.m_eqn, new_P.m_nvars, new_R.m_code, x); new_L.m_proof = L.m_proof; @@ -474,8 +500,7 @@ struct elim_match_fn { (nil, L_1 := R_1) (cons, (cons a b) L_2 := R_2) (cons, (cons c d) L_3 := R_3) - (nil L_4 := R_4) - */ + (nil L_4 := R_4) */ void distribute_constructor_equations(list const & eqns, buffer> & R) { for (equation const & eqn : eqns) { lean_assert(eqn.m_patterns); @@ -502,8 +527,7 @@ struct elim_match_fn { The parameter \c field should be interpreted as a bit-mask here. It says which constructor fields should be used. That is, "some" value means the field - should be considered. - */ + should be considered. */ list get_equations_for(name const & C, list> const & fields, name_map const & renaming, local_context const & lctx, buffer> const & eqns) { buffer R; @@ -554,20 +578,44 @@ struct elim_match_fn { } } + static list to_bitmask(list> const & ilist) { + return map2(ilist, [](optional const & ilist) { return static_cast(ilist); }); + } + /* Update an equation left-hand-side of the form (f a_1 ... a_n b_1 ... b_m) - where n == nfields and n+m == arity, with + where n == number of true entries in mask, and n+m == arity, with - (new_fn (c a_1 ... a_n) b_1 ... b_m) */ - expr update_eqn_for_constructor_transition(expr const & eqn, unsigned arity, expr const & new_fn, expr const & c, unsigned nfields) { - return update_eqn_lhs(eqn, arity, [&](buffer & args) { - lean_assert(args.size() >= nfields); - expr c_app = mk_rev_app(c, nfields, args.end() - nfields); - args.shrink(args.size() - nfields); - args.push_back(c_app); - return mk_rev_app(new_fn, args); + (new_fn (c I_params i ... a_1 ... a_n) b_1 ... b_m) + + if there are false entries in mask, we need to infer any missing arguments 'i'. */ + expr update_eqn_for_constructor_transition(lemma const & L, list const & mask, + unsigned arity, expr const & new_fn, + name const & c_name, buffer const & I_params) { + type_context ctx = mk_type_context(L.m_lctx); + return update_eqn_lhs(L.m_eqn, arity, [&](buffer & args) { + std::reverse(args.begin(), args.end()); + buffer c_mask; + buffer c_args; + /* Add I_params */ + for (expr const & p : I_params) { + c_mask.push_back(true); + c_args.push_back(p); + } + /* Add constructor fields */ + unsigned i = 0; + for (bool b : mask) { + /* Remark: b is false only for indexed families. */ + c_mask.push_back(b); + if (b) { + c_args.push_back(args[i]); + i++; + } + } + expr c_app = mk_app(ctx, c_name, c_mask.size(), c_mask.data(), c_args.data()); + return mk_app(mk_app(new_fn, c_app), args.size() - i, args.data() + i); }); } @@ -601,10 +649,11 @@ struct elim_match_fn { distribute_constructor_equations(P.m_equations, equations_by_constructor); /* For each (reachable) case, we invoke compile recursively, and we store - name of the constructor - - number of fields of this constructor + - bitmask for which fields were introduced. The lenght of the bitmask is equal + to the head(ilist). The value is true iff the corresponding element in head(ilist) is not none. - "arity" of the auxiliary program being used in the recursive call - result of the compilation for this auxiliary function. */ - buffer> result_by_constructor; + buffer, unsigned, result>> result_by_constructor; while (new_goals) { lean_assert(new_goal_cnames && ilist && rlist); program new_P; @@ -613,18 +662,18 @@ struct elim_match_fn { /* Revert constructor fields (which have not been eliminated by dependent pattern matching). */ buffer to_revert; to_buffer_local(new_goal, head(ilist), to_revert); - unsigned to_revert_size = to_revert.size(); - unsigned nfields = to_revert_size; - expr aux_mvar2 = revert(m_env, m_opts, m_mctx, head(new_goals), to_revert); + unsigned to_revert_size = to_revert.size(); + unsigned num_intro_fields = to_revert_size; + expr aux_mvar2 = revert(m_env, m_opts, m_mctx, head(new_goals), to_revert); lean_assert(to_revert.size() == to_revert_size); new_P.m_goal = aux_mvar2; /* The arity of the auxiliary program is the arity of the original program - - 1 (we consumed one argument in this step) and + nfields (we added nfields new arguments). */ - new_P.m_nvars = P.m_nvars - 1 + nfields; + - 1 (we consumed one argument in this step) and + number of introduced constructor fields. */ + new_P.m_nvars = P.m_nvars - 1 + num_intro_fields; new_P.m_equations = get_equations_for(head(new_goal_cnames), head(ilist), head(rlist), get_local_context(aux_mvar2), equations_by_constructor); result new_R = compile_core(new_P); - result_by_constructor.emplace_back(head(new_goal_cnames), nfields, new_P.m_nvars, new_R); + result_by_constructor.emplace_back(head(new_goal_cnames), to_bitmask(head(ilist)), new_P.m_nvars, new_R); new_goals = tail(new_goals); new_goal_cnames = tail(new_goal_cnames); @@ -639,16 +688,16 @@ struct elim_match_fn { buffer I_params; levels I_lvls = get_inductive_levels_and_params(ctx, I, I_params); buffer new_lemmas; - for (std::tuple const & entry : result_by_constructor) { + for (std::tuple, unsigned, result> const & entry : result_by_constructor) { name const & cname = std::get<0>(entry); /* constructor name */ - unsigned nfields = std::get<1>(entry); + list mask = std::get<1>(entry); /* bitmask indicating which constructor fields have been introduced by cases-tactic. */ unsigned arity = std::get<2>(entry); result const & Rc = std::get<3>(entry); - expr c = mk_app(mk_constant(cname, I_lvls), I_params); for (lemma const & L : Rc.m_lemmas) { lemma new_L; + new_L.m_lctx = L.m_lctx; new_L.m_vars = L.m_vars; - new_L.m_eqn = update_eqn_for_constructor_transition(L.m_eqn, arity, new_R.m_code, c, nfields); + new_L.m_eqn = update_eqn_for_constructor_transition(L, mask, arity, new_R.m_code, cname, I_params); new_L.m_proof = L.m_proof; new_lemmas.push_back(new_L); } @@ -690,7 +739,7 @@ struct elim_match_fn { type_context ctx = mk_type_context(get_local_context(P)); expr eq = mk_eq(ctx, rhs, rhs); expr H = mk_eq_refl(ctx, rhs); - R.m_lemmas = to_list(lemma(list(), eq, H)); + R.m_lemmas = to_list(lemma(ctx.lctx(), list(), eq, H)); } return R; } @@ -738,6 +787,17 @@ struct elim_match_fn { return R; } + void abstract_eqns_vars(list const & Ls, buffer & R) { + for (lemma const & L : Ls) { + type_context ctx = mk_type_context(L.m_lctx); + buffer vars; + to_buffer(L.m_vars, vars); + expr e = ctx.mk_pi(vars, L.m_eqn); + expr H = ctx.mk_lambda(vars, L.m_proof); + R.emplace_back(e, H); + } + } + expr operator()(local_context const & lctx, expr const & eqns) { lean_assert(equations_num_fns(eqns) == 1); DEBUG_CODE({ @@ -748,6 +808,13 @@ struct elim_match_fn { program P = mk_program(lctx, eqns); result R = compile(P); +#if 0 + if (m_lemmas) { + buffer Hs; + abstract_eqns_vars(R.m_lemmas, Hs); + } +#endif + lean_unreachable(); } };