fix(library/equations_compiler/elim_match): support dependent functions handling invertible functions

This commit is contained in:
Leonardo de Moura 2016-09-07 16:04:08 -07:00
parent 159653f253
commit f7699d8719

View file

@ -272,13 +272,17 @@ struct elim_match_fn {
return false;
}
bool is_invertible_app(expr const & e) {
bool is_transport_app(expr const & e) {
if (!is_app(e)) return false;
expr const & fn = get_app_fn(e);
if (!is_constant(fn)) return false;
optional<inverse_info> info = has_inverse(m_env, const_name(fn));
if (!info) return false;
return info->m_arity == get_app_num_args(e);
if (!info || info->m_arity != get_app_num_args(e)) return false;
optional<inverse_info> inv_info = has_inverse(m_env, info->m_inv);
return
inv_info &&
info->m_arity == inv_info->m_inv_arity &&
inv_info->m_arity == info->m_inv_arity;
}
unsigned get_inductive_num_params(name const & n) const { return *inductive::get_num_params(m_env, n); }
@ -303,7 +307,7 @@ struct elim_match_fn {
return e;
} else {
return ctx.whnf_pred(e, [&](expr const & e) {
return !is_constructor_app(e) && !is_value(e) && !is_invertible_app(e);
return !is_constructor_app(e) && !is_value(e) && !is_transport_app(e);
});
}
}
@ -441,11 +445,11 @@ struct elim_match_fn {
}
/* Return true iff the next pattern in all equations is the same invertible function. */
bool is_invertible_transition(problem const & P) {
bool is_transport_transition(problem const & P) {
if (!P.m_equations) return false;
optional<name> fn_name;
return all_next_pattern(P, [&](expr const & p) {
if (!is_invertible_app(p)) return false;
if (!is_transport_app(p)) return false;
name const & curr_name = const_name(get_app_fn(p));
if (fn_name) {
return *fn_name == curr_name;
@ -912,39 +916,129 @@ struct elim_match_fn {
return process(new_P);
}
list<lemma> process_invertible(problem const & P) {
trace_match(tout() << "step: invertible function\n";);
type_context ctx = mk_type_context(P);
expr const & x = head(P.m_var_stack);
/* make inverse */
expr const & p = head(head(P.m_equations).m_patterns);
lean_assert(is_invertible_app(p));
expr const & fn = get_app_fn(p);
inverse_info info = *has_inverse(m_env, const_name(fn));
/* Create (f ... x) with the given arity, where the other arguments are inferred using
type inference */
expr mk_app_with_arity(type_context & ctx, name const & f, unsigned arity, expr const & x) {
buffer<bool> mask;
mask.resize(info.m_inv_arity - 1, false);
mask.resize(arity - 1, false);
mask.push_back(true);
expr inv;
try {
inv = mk_app(ctx, info.m_inv, mask.size(), mask.data(), &x);
} catch (exception &) {
throw_error(sstream() << "equation compiler failed, when trying to inverse function "
<< "'" << const_name(fn) << "' (use 'set_option trace.eqn_compiler.elim_match true' "
return mk_app(ctx, f, mask.size(), mask.data(), &x);
} catch (app_builder_exception &) {
throw_error(sstream() << "equation compiler failed, when trying to build "
<< "'" << f << "' application (use 'set_option trace.eqn_compiler.elim_match true' "
<< "for additional details)");
}
expr inv_type = ctx.infer(inv);
local_context lctx = ctx.lctx();
name y_name("_y");
expr y = lctx.mk_local_decl(y_name, inv_type, inv);
expr goal_type = ctx.infer(P.m_goal);
expr new_goal = ctx.mk_metavar_decl(lctx, goal_type);
expr new_val = mk_let(y_name, inv_type, inv, mk_delayed_abstraction(new_goal, mlocal_name(y)));
m_mctx = ctx.mctx();
m_mctx.assign(P.m_goal, new_val);
}
/*
This step is applied to problems of the form
goal: ... (x : A) ... |- C x
var_stack: [x, ...]
equations:
(f y_1) ... := rhs_1
(f y_2) ... := rhs_2
...
(f y_n) ... := rhs_n
where f (B -> A) is not a constructor, and there is a function g : A -> B s.t.
g_f_eq : forall y, g (f y) = y
f_g_eq : forall x, f (g x) = x
Steps:
1) Revert x and obtain goal
M_1 ... |- forall (x : A), C' x
2) Create new goal
M_2 ... |- forall (y : B), C' (f y)
Solution for M_1 is
fun x : A, @@eq.rec (fun x, C' x) (M_2 (g x)) (f_g_eq x)
We need the eq.rec because (M_2 (g x)) has type (C' (f (g x)))
3) Create new problem by reintroducing all variables inverted in
step 1, replacing x with y in the var stack, and using the new set of equations
y_1 ... := rhs_1
y_2 ... := rhs_2
...
y_n ... := rhs_n
Remark: the lemma g_f_eq is used when we are trying to prove the equation lemmas.
*/
list<lemma> process_transport(problem const & P) {
trace_match(tout() << "step: transport function\n";);
expr x = head(P.m_var_stack);
expr p = head(head(P.m_equations).m_patterns);
lean_assert(is_transport_app(p));
expr f = get_app_fn(p);
name f_name = const_name(f);
inverse_info info = *has_inverse(m_env, f_name);
unsigned f_arity = info.m_arity;
name g_name = info.m_inv;
unsigned g_arity = info.m_inv_arity;
inverse_info info_inv = *has_inverse(m_env, g_name);
name f_g_eq_name = info_inv.m_lemma;
/* Step 1 */
buffer<expr> to_revert;
to_revert.push_back(x);
expr M_1 = revert(m_env, m_opts, m_mctx, P.m_goal, to_revert);
/* Step 2 */
type_context ctx1 = mk_type_context(M_1);
expr M_1_type = ctx1.relaxed_whnf(ctx1.infer(M_1));
lean_assert(is_pi(M_1_type));
expr x1 = ctx1.push_local(binding_name(M_1_type), binding_domain(M_1_type));
expr g_x1 = mk_app_with_arity(ctx1, g_name, g_arity, x1);
expr B = ctx1.infer(g_x1);
expr y1 = ctx1.push_local("_y", B);
expr f_y1 = mk_app_with_arity(ctx1, f_name, f_arity, y1);
expr C_x1 = instantiate(binding_body(M_1_type), x1);
expr C_f_y1 = replace_local(C_x1, x1, f_y1);
expr M_2_type = ctx1.mk_pi(y1, C_f_y1);
expr M_2 = ctx1.mk_metavar_decl(get_local_context(M_1), M_2_type);
expr eqrec;
try {
expr eqrec_motive = ctx1.mk_lambda(x1, C_x1);
expr eqrec_minor = mk_app(M_2, g_x1);
expr eqrec_major = mk_app(ctx1, f_g_eq_name, x1);
eqrec = mk_eq_rec(ctx1, eqrec_motive, eqrec_minor, eqrec_major);
} catch (app_builder_exception &) {
throw_error("equation compiler failed, when trying to build "
"'eq.rec'-application for transport step (use 'set_option trace.eqn_compiler.elim_match true' "
"for additional details)");
}
expr M_1_val = ctx1.mk_lambda(x1, eqrec);
/* M_1_val is (fun x1 : A, @@eq.rec (fun x1, C x1) (M_2 (g x1)) (f_g_eq x1)) */
m_mctx = ctx1.mctx();
m_mctx.assign(M_1, M_1_val);
/* Step 3 */
buffer<name> new_H_names;
optional<expr> M_3 = intron(m_env, m_opts, m_mctx, M_2, to_revert.size(), new_H_names);
if (!M_3) {
throw_error("equation compiler failed, when reintroducing reverted variables "
"(use 'set_option trace.eqn_compiler.elim_match true' "
"for additional details)");
}
local_context lctx3 = get_local_context(*M_3);
buffer<expr> new_Hs;
for (name const & H_name : new_H_names) {
new_Hs.push_back(lctx3.get_local(H_name));
}
lean_assert(to_revert.size() == new_Hs.size());
problem new_P;
new_P.m_fn_name = name(P.m_fn_name, "_inv");
new_P.m_var_stack = cons(y, tail(P.m_var_stack));
new_P.m_goal = new_goal;
new_P.m_fn_name = name(P.m_fn_name, "_transport");
new_P.m_goal = *M_3;
new_P.m_var_stack = map(P.m_var_stack,
[&](expr const & x) { return replace_locals(x, to_revert, new_Hs); });
buffer<equation> new_eqns;
for (equation const & eqn : P.m_equations) {
equation new_eqn = eqn;
@ -997,8 +1091,8 @@ struct elim_match_fn {
return process_value(P);
} else if (is_constructor_transition(P)) {
return process_constructor(P);
} else if (is_invertible_transition(P)) {
return process_invertible(P);
} else if (is_transport_transition(P)) {
return process_transport(P);
} else if (is_inaccessible_transition(P)) {
return process_inaccessible(P);
} else if (some_inaccessible(P)) {