feat(library/equations_compiler/elim_match): add variable/constructor transitions

This commit is contained in:
Leonardo de Moura 2016-08-20 09:55:44 -07:00
parent 6aa2ab6538
commit 67dc68b24d

View file

@ -11,13 +11,19 @@ Author: Leonardo de Moura
#include "library/string.h"
#include "library/pp_options.h"
#include "library/generic_exception.h"
#include "library/util.h"
#include "library/locals.h"
#include "library/app_builder.h"
#include "library/tactic/tactic_state.h"
#include "library/tactic/revert_tactic.h"
#include "library/tactic/cases_tactic.h"
#include "library/tactic/intro_tactic.h"
#include "library/equations_compiler/equations.h"
#include "library/equations_compiler/util.h"
namespace lean {
#define trace_match(Code) lean_trace(name({"eqn_compiler", "elim_match"}), Code)
#define trace_match_detail(Code) lean_trace(name({"eqn_compiler", "elim_match_detail"}), Code)
struct elim_match_fn {
environment m_env;
@ -27,29 +33,53 @@ struct elim_match_fn {
expr m_ref;
unsigned m_depth{0};
buffer<bool> m_used_eqns;
bool m_lemmas{true};
elim_match_fn(environment const & env, options const & opts,
metavar_context const & mctx):
m_env(env), m_opts(opts), m_mctx(mctx) {}
struct equation {
list<pair<name, name>> m_renames;
list<pair<expr, expr>> m_renames;
local_context m_lctx;
list<expr> m_patterns;
expr m_rhs;
expr m_ref; /* for reporting errors */
unsigned m_idx;
equation() {}
equation(equation const & eqn, list<expr> const & new_patterns):
m_renames(eqn.m_renames), m_lctx(eqn.m_lctx), m_patterns(new_patterns),
m_rhs(eqn.m_rhs), m_ref(eqn.m_ref), m_idx(eqn.m_idx) {}
};
struct program {
name m_fn_name; /* for debugging purposes */
/* Metavariable containing the context for the program */
expr m_goal;
/* Variables that still need to be matched/processed */
list<name> m_var_stack;
/* Number of variables that still need to be matched/processed */
unsigned m_nvars;
list<equation> m_equations;
};
struct lemma {
list<expr> m_vars;
expr m_eqn; /* equation (it might be conditional) */
expr m_proof;
lemma() {}
lemma(list<expr> const & vars, expr const & eqn, expr const & proof):
m_vars(vars), m_eqn(eqn), m_proof(proof) {}
};
/** Result for the compilation procedure. */
struct result {
/* m_code is the expression that implements a program. */
expr m_code;
/* List of equation lemmas that hold for m_code, and their proofs */
list<lemma> m_lemmas;
result() {}
result(expr const & c):m_code(c) {}
};
[[ noreturn ]] void throw_error(char const * msg) {
throw_generic_exception(msg, m_ref);
}
@ -58,19 +88,43 @@ struct elim_match_fn {
throw_generic_exception(strm, m_ref);
}
local_context get_local_context(expr const & mvar) {
lean_assert(is_metavar(mvar));
metavar_decl mdecl = *m_mctx.get_metavar_decl(mvar);
return mdecl.get_context();
}
local_context get_local_context(program const & P) {
return get_local_context(P.m_goal);
}
type_context mk_type_context(local_context const & lctx) {
return mk_type_context_for(m_env, m_opts, m_mctx, lctx);
}
type_context mk_type_context(program const & P) {
return mk_type_context(get_local_context(P));
}
std::function<format(expr const &)> mk_pp_ctx(local_context const & lctx) {
options opts = m_opts.update(get_pp_beta_name(), false);
type_context ctx = mk_type_context_for(m_env, opts, m_mctx, lctx);
return ::lean::mk_pp_ctx(ctx);
}
std::function<format(expr const &)> mk_pp_ctx(program const & P) {
return mk_pp_ctx(get_local_context(P));
}
format nest(format const & fmt) const {
return ::lean::nest(get_pp_indent(m_opts), fmt);
}
unsigned get_arity(local_context const & lctx, expr const & e) {
unsigned get_eqns_arity(local_context const & lctx, expr const & eqns) {
/* Naive way to retrieve the arity of the function being defined */
lean_assert(is_equations(e));
lean_assert(is_equations(eqns));
type_context ctx = mk_type_context(lctx);
unpack_eqns ues(ctx, e);
unpack_eqns ues(ctx, eqns);
return ues.get_arity_of(0);
}
@ -82,15 +136,26 @@ struct elim_match_fn {
return is_constructor(get_app_fn(e));
}
bool is_inductive(expr const & e) const {
return static_cast<bool>(eqns_env_interface(m_env).is_inductive(e));
}
bool is_inductive_app(expr const & e) const {
return is_inductive(get_app_fn(e));
}
bool is_value(expr const & e) const {
return to_num(e) || to_char(e) || to_string(e) || is_constructor(e);
}
/* Normalize until head is constructor or value */
expr whnf_pattern(type_context & ctx, expr const & e) {
return ctx.whnf_pred(e, [&](expr const & e) {
return !is_constructor_app(e) && !is_value(e);
});
if (is_inaccessible(e))
return e;
else
return ctx.whnf_pred(e, [&](expr const & e) {
return !is_constructor_app(e) && !is_value(e);
});
}
/* Normalize until head is constructor */
@ -100,21 +165,23 @@ struct elim_match_fn {
});
}
pair<expr, list<name>>
mk_main_goal(local_context lctx, expr fn_type, unsigned arity) {
type_context ctx = mk_type_context(lctx);
buffer<name> vars;
name x("_x");
for (unsigned i = 0; i < arity; i++) {
fn_type = ctx.whnf(fn_type);
if (!is_pi(fn_type)) throw_ill_formed_eqns();
expr var = ctx.push_local(x.append_after(i+1), binding_domain(fn_type));
vars.push_back(mlocal_name(var));
fn_type = instantiate(binding_body(fn_type), var);
}
m_mctx = ctx.mctx();
expr m = m_mctx.mk_metavar_decl(ctx.lctx(), fn_type);
return mk_pair(m, to_list(vars));
/* Normalize until head is an inductive datatype */
expr whnf_inductive(type_context & ctx, expr const & e) {
return ctx.whnf_pred(e, [&](expr const & e) {
return !is_inductive_app(e);
});
}
/* Store in args the parameters of the inductive datatype I */
levels get_inductive_levels_and_params(type_context & ctx, expr const & I, buffer<expr> & params) {
expr I1 = whnf_inductive(ctx, I);
buffer<expr> args;
expr const & Ifn = get_app_args(I1, args);
unsigned nparams = eqns_env_interface(m_env).get_inductive_num_params(const_name(Ifn));
lean_assert(nparams <= args.size());
for (unsigned i = 0; i < nparams; i++)
params.push_back(args[i]);
return const_levels(Ifn);
}
optional<equation> mk_equation(local_context const & lctx, expr const & eqn, unsigned idx) {
@ -169,18 +236,19 @@ struct elim_match_fn {
lean_assert(is_equations(e));
buffer<expr> eqns;
to_equations(e, eqns);
unsigned arity = get_arity(lctx, e);
unsigned arity = get_eqns_arity(lctx, e);
program P;
P.m_fn_name = binding_name(eqns[0]);
expr fn_type = binding_domain(eqns[0]);
std::tie(P.m_goal, P.m_var_stack) = mk_main_goal(lctx, fn_type, arity);
P.m_goal = m_mctx.mk_metavar_decl(lctx, fn_type);
P.m_nvars = arity;
P.m_equations = mk_equations(lctx, eqns);
return P;
}
format pp_equation(equation const & eqn) {
format r;
auto pp = mk_pp_ctx(m_env, m_opts, m_mctx, eqn.m_lctx);
auto pp = mk_pp_ctx(eqn.m_lctx);
bool first = true;
for (expr const & p : eqn.m_patterns) {
if (first) first = false; else r += format(" ");
@ -192,16 +260,7 @@ struct elim_match_fn {
format pp_program(program const & P) {
format r;
r += format("program") + space() + format(P.m_fn_name);
metavar_decl mdecl = *m_mctx.get_metavar_decl(P.m_goal);
local_context lctx = mdecl.get_context();
auto pp = mk_pp_ctx(m_env, m_opts, m_mctx, lctx);
format vstack;
for (name const & x : P.m_var_stack) {
local_decl x_decl = *lctx.get_local_decl(x);
vstack += line() + paren(format(x_decl.get_pp_name()) + space() + colon() + space() + pp(x_decl.get_type()));
}
r += group(nest(vstack));
r += format("program") + space() + format(P.m_fn_name) + space() + format("#") + format(P.m_nvars);
for (equation const & eqn : P.m_equations) {
r += nest(line() + pp_equation(eqn));
}
@ -287,13 +346,54 @@ struct elim_match_fn {
return found_inaccessible && found_not_inaccessible;
}
/** Result for the compilation procedure. */
struct result {
/* m_code is the expression that implements a program. */
expr m_code;
/* List of equation lemmas that hold for m_code, and their proofs */
list<pair<expr, expr>> m_eqns_proofs;
};
/* See update_eqn_lhs */
template<typename F>
expr update_eqn_lhs_core(expr const & lhs, unsigned arity, F && updt) {
buffer<expr> args;
auto it = lhs;
for (unsigned i = 0; i < arity; i++) {
lean_assert(is_app(it));
args.push_back(app_arg(it));
it = app_fn(it);
}
return updt(args);
}
/* Auxiliary method for updating the function in the left-hand-side of the given (conditional) equation.
The method assumes the left-hand-side is of the form:
(f a_1 ... a_n)
where n == arity.
The function updt must construct the new left-hand-side.
It take a buffer containing [a_n, ..., a_1]. */
template<typename F>
expr update_eqn_lhs(expr const & eqn, unsigned arity, F && updt) {
if (is_pi(eqn)) {
return update_binding(eqn, binding_domain(eqn), update_eqn_lhs(binding_body(eqn), arity, updt));
} else {
lean_assert(is_eq(eqn));
buffer<expr> eqn_args;
expr const & eq_fn = get_app_args(eqn, eqn_args);
lean_assert(eqn_args.size() == 3);
eqn_args[1] = update_eqn_lhs_core(eqn_args[1], arity, updt);
return mk_app(eq_fn, eqn_args);
}
}
/* Helper method for tracing intermediate lemmas produced during the compilation process. */
void trace_lemmas(program const & P, char const * header, buffer<lemma> const & lemmas) {
trace_match_detail({
tout() << "[" << m_depth << "] " << header << " lemmas:\n";
auto pp_fn = mk_pp_ctx(P);
for (lemma const & L : lemmas) {
/* Replace function with its name. */
expr tmp_eqn = update_eqn_lhs(L.m_eqn, P.m_nvars,
[&](buffer<expr> const & args) {
return mk_rev_app(mk_constant(P.m_fn_name), args);
});
tout() << " " << ::lean::nest(4, pp_fn(tmp_eqn)) << "\n";
}});
}
result compile_no_equation(program const & P) {
trace_match(tout() << "no equation transition\n";);
@ -305,14 +405,259 @@ struct elim_match_fn {
lean_unreachable();
}
/* Update the equation left hand side
(f a_1 ... a_n)
where n == arity, with
(new_fn x a_1 ... a_n) */
expr update_eqn_for_variable_transition(expr const & eqn, unsigned arity, expr const & new_fn, expr const & x) {
return update_eqn_lhs(eqn, arity, [&](buffer<expr> & args) {
args.push_back(x);
return mk_rev_app(new_fn, args);
});
}
result compile_variable(program const & P) {
lean_assert(is_variable_transition(P));
trace_match(tout() << "variable transition\n";);
lean_unreachable();
program new_P;
new_P.m_fn_name = P.m_fn_name;
buffer<name> new_names;
optional<expr> new_goal = intron(m_env, m_opts, m_mctx, P.m_goal, 1, new_names);
if (!new_goal) throw_ill_formed_eqns();
lean_assert(new_names.size() == 1);
new_P.m_goal = *new_goal;
new_P.m_nvars = P.m_nvars - 1;
name x_name = new_names[0];
expr x = get_local_context(new_P).get_local_decl(x_name)->mk_ref();
buffer<equation> new_eqns;
for (equation const & eqn : P.m_equations) {
equation new_eqn = eqn;
new_eqn.m_patterns = tail(eqn.m_patterns);
new_eqn.m_renames = cons(mk_pair(head(eqn.m_patterns), x), eqn.m_renames);
new_eqns.push_back(new_eqn);
}
new_P.m_equations = to_list(new_eqns);
result R = compile_core(new_P);
result new_R;
type_context ctx = mk_type_context(P);
new_R.m_code = m_mctx.instantiate_mvars(P.m_goal);
if (m_lemmas) {
buffer<lemma> new_lemmas;
for (lemma const & L : R.m_lemmas) {
lemma new_L;
new_L.m_vars = cons(x, L.m_vars);
new_L.m_eqn = update_eqn_for_variable_transition(L.m_eqn, new_P.m_nvars, new_R.m_code, x);
new_L.m_proof = L.m_proof;
new_lemmas.push_back(new_L);
}
trace_lemmas(P, "variable transition", new_lemmas);
new_R.m_lemmas = to_list(new_lemmas);
}
return new_R;
}
/* Populate R with the given equations. The equations are also updated by replacing the current
pattern (a constructor) with its arguments. Note that R[i].first is the name of the constructor.
Example: suppose the input eqns contains the equations
nil L_1 := R_1
(cons a b) L_2 := R_2
(cons c d) L_3 := R_3
nil L_4 := R_4
Then, R will contain the pairs
(nil, L_1 := R_1)
(cons, (cons a b) L_2 := R_2)
(cons, (cons c d) L_3 := R_3)
(nil L_4 := R_4)
*/
void distribute_constructor_equations(list<equation> const & eqns, buffer<pair<name, equation>> & R) {
for (equation const & eqn : eqns) {
lean_assert(eqn.m_patterns);
type_context ctx = mk_type_context(eqn.m_lctx);
expr pattern = whnf_constructor(ctx, head(eqn.m_patterns));
if (!is_constructor_app(pattern)) {
throw_error("equation compiler failed, pattern is not a constructor "
"(use 'set_option trace.eqn_compiler.elim_match true' for additional details)");
}
list<expr> new_patterns = cons(pattern, tail(eqn.m_patterns));
expr const & C = get_app_fn(pattern);
R.emplace_back(const_name(C), equation(eqn, new_patterns));
}
}
/* eqns is the data-structured returned by distribute_constructor_equations.
This method selects the ones such that eqns[i].first == C.
It also updates eqns[i].second.m_renames using \c renaming.
It also "replaces" the next pattern (a constructor) with its fields.
The map \c renaming is produced by the `cases` tactic.
It is needed because the `cases` tactic may revert and reintroduce hypothesis that
depend on the hypothesis being destructed.
The parameter \c field should be interpreted as a bit-mask here.
It says which constructor fields should be used. That is, "some" value means the field
should be considered.
*/
list<equation> get_equations_for(name const & C, list<optional<name>> const & fields, name_map<name> const & renaming,
local_context const & lctx, buffer<pair<name, equation>> const & eqns) {
buffer<equation> R;
for (auto p : eqns) {
if (p.first == C) {
equation eqn = p.second;
/* Update renames */
eqn.m_renames = map(eqn.m_renames, [&](pair<expr, expr> const & p) {
if (auto new_name = renaming.find(mlocal_name(p.second))) {
return mk_pair(p.first, lctx.get_local_decl(*new_name)->mk_ref());
} else {
return p;
}
});
/* Update patterns */
type_context ctx = mk_type_context(eqn.m_lctx);
lean_assert(eqn.m_patterns);
expr pattern = head(eqn.m_patterns);
buffer<expr> pattern_args;
DEBUG_CODE(expr const & C2 =) get_app_args(pattern, pattern_args);
lean_assert(const_name(C2) == C);
/* The inductive datatype parameters are always ignored. */
name I = *eqns_env_interface(m_env).is_constructor(C);
unsigned I_nparams = eqns_env_interface(m_env).get_inductive_num_params(I);
lean_assert(I_nparams <= pattern_args.size());
lean_assert(I_nparams + length(fields) == pattern_args.size());
buffer<expr> new_patterns;
auto it_fields = fields;
for (unsigned i = I_nparams; i < pattern_args.size(); i++) {
if (head(it_fields)) {
new_patterns.push_back(whnf_pattern(ctx, pattern_args[i]));
}
it_fields = tail(it_fields);
}
eqn.m_patterns = to_list(new_patterns.begin(), new_patterns.end(), tail(eqn.m_patterns));
R.push_back(eqn);
}
}
return to_list(R);
}
/* Store in R the local_decl_refs for ilist by using the local context of the metavariable mvar. */
void to_buffer_local(expr const & mvar, list<optional<name>> const & ilist, buffer<expr> & R) {
local_context lctx = get_local_context(mvar);
for (optional<name> const & x_name : ilist) {
if (x_name)
R.push_back(lctx.get_local_decl(*x_name)->mk_ref());
}
}
/* Update an equation left-hand-side of the form
(f a_1 ... a_n b_1 ... b_m)
where n == nfields and n+m == arity, with
(new_fn (c a_1 ... a_n) b_1 ... b_m) */
expr update_eqn_for_constructor_transition(expr const & eqn, unsigned arity, expr const & new_fn, expr const & c, unsigned nfields) {
return update_eqn_lhs(eqn, arity, [&](buffer<expr> & args) {
lean_assert(args.size() >= nfields);
expr c_app = mk_rev_app(c, nfields, args.end() - nfields);
args.shrink(args.size() - nfields);
args.push_back(c_app);
return mk_rev_app(new_fn, args);
});
}
result compile_constructor(program const & P) {
trace_match(tout() << "constructor transition\n";);
lean_unreachable();
lean_assert(is_constructor_transition(P));
buffer<name> new_names;
optional<expr> aux_mvar1 = intron(m_env, m_opts, m_mctx, P.m_goal, 1, new_names);
if (!aux_mvar1) throw_ill_formed_eqns();
expr x = get_local_context(*aux_mvar1).get_local_decl(new_names[0])->mk_ref();
cintros_list ilist;
renaming_list rlist;
list<expr> new_goals; list<name> new_goal_cnames;
try {
list<name> ids;
std::tie(new_goals, new_goal_cnames) =
cases(m_env, m_opts, transparency_mode::Semireducible, m_mctx,
*aux_mvar1, x, ids, &ilist, &rlist);
lean_assert(length(new_goals) == length(new_goal_cnames));
lean_assert(length(new_goals) == length(ilist));
lean_assert(length(new_goals) == length(rlist));
} catch (exception &) {
trace_match(tout() << "dependent pattern matching step failed\n";);
throw_error("equation compiler failed (use 'set_option trace.eqn_compiler.elim_match true' "
"for additional details)");
}
if (empty(new_goals)) {
return result(m_mctx.instantiate_mvars(P.m_goal));
} else {
buffer<pair<name, equation>> equations_by_constructor;
distribute_constructor_equations(P.m_equations, equations_by_constructor);
/* For each (reachable) case, we invoke compile recursively, and we store
- name of the constructor
- number of fields of this constructor
- "arity" of the auxiliary program being used in the recursive call
- result of the compilation for this auxiliary function. */
buffer<std::tuple<name, unsigned, unsigned, result>> result_by_constructor;
while (new_goals) {
lean_assert(new_goal_cnames && ilist && rlist);
program new_P;
new_P.m_fn_name = name(P.m_fn_name, head(new_goal_cnames).get_string());
expr new_goal = head(new_goals);
/* Revert constructor fields (which have not been eliminated by dependent pattern matching). */
buffer<expr> to_revert;
to_buffer_local(new_goal, head(ilist), to_revert);
unsigned to_revert_size = to_revert.size();
unsigned nfields = to_revert_size;
expr aux_mvar2 = revert(m_env, m_opts, m_mctx, head(new_goals), to_revert);
lean_assert(to_revert.size() == to_revert_size);
new_P.m_goal = aux_mvar2;
/* The arity of the auxiliary program is the arity of the original program
- 1 (we consumed one argument in this step) and + nfields (we added nfields new arguments). */
new_P.m_nvars = P.m_nvars - 1 + nfields;
new_P.m_equations = get_equations_for(head(new_goal_cnames), head(ilist), head(rlist),
get_local_context(aux_mvar2), equations_by_constructor);
result new_R = compile_core(new_P);
result_by_constructor.emplace_back(head(new_goal_cnames), nfields, new_P.m_nvars, new_R);
new_goals = tail(new_goals);
new_goal_cnames = tail(new_goal_cnames);
ilist = tail(ilist);
rlist = tail(rlist);
}
result new_R;
new_R.m_code = m_mctx.instantiate_mvars(P.m_goal);
if (m_lemmas) {
type_context ctx = mk_type_context(get_local_context(*aux_mvar1));
expr I = ctx.infer(x);
buffer<expr> I_params;
levels I_lvls = get_inductive_levels_and_params(ctx, I, I_params);
buffer<lemma> new_lemmas;
for (std::tuple<name, unsigned, unsigned, result> const & entry : result_by_constructor) {
name const & cname = std::get<0>(entry); /* constructor name */
unsigned nfields = std::get<1>(entry);
unsigned arity = std::get<2>(entry);
result const & Rc = std::get<3>(entry);
expr c = mk_app(mk_constant(cname, I_lvls), I_params);
for (lemma const & L : Rc.m_lemmas) {
lemma new_L;
new_L.m_vars = L.m_vars;
new_L.m_eqn = update_eqn_for_constructor_transition(L.m_eqn, arity, new_R.m_code, c, nfields);
new_L.m_proof = L.m_proof;
new_lemmas.push_back(new_L);
}
}
trace_lemmas(P, "constructor transition", new_lemmas);
new_R.m_lemmas = to_list(new_lemmas);
}
return new_R;
}
}
result compile_value(program const & P) {
@ -326,13 +671,34 @@ struct elim_match_fn {
}
result compile_leaf(program const & P) {
lean_unreachable();
if (!P.m_equations) {
throw_error("invalid non-exhaustive set of equations (use 'set_option trace.eqn_compiler.elim_match true' "
"for additional details)");
}
equation const & eqn = head(P.m_equations);
m_used_eqns[eqn.m_idx] = true;
buffer<expr> from, to;
for (pair<expr, expr> const & p : eqn.m_renames) {
from.push_back(p.first);
to.push_back(p.second);
}
expr rhs = replace_locals(eqn.m_rhs, from, to);
m_mctx.assign(P.m_goal, rhs);
result R;
R.m_code = rhs;
if (m_lemmas) {
type_context ctx = mk_type_context(get_local_context(P));
expr eq = mk_eq(ctx, rhs, rhs);
expr H = mk_eq_refl(ctx, rhs);
R.m_lemmas = to_list(lemma(list<expr>(), eq, H));
}
return R;
}
result compile_core(program const & P) {
flet<unsigned> inc_depth(m_depth, m_depth+1);
trace_match(tout() << "depth [" << m_depth << "]\n" << pp_program(P) << "\n";);
if (P.m_var_stack) {
if (P.m_nvars > 0) {
if (!P.m_equations) {
return compile_no_equation(P);
} else if (is_inaccessible_transition(P)) {
@ -393,6 +759,7 @@ expr elim_match(environment & env, options const & opts, metavar_context & mctx,
void initialize_elim_match() {
register_trace_class({"eqn_compiler", "elim_match"});
register_trace_class({"eqn_compiler", "elim_match_detail"});
}
void finalize_elim_match() {
}