diff --git a/src/library/equations_compiler/elim_match.cpp b/src/library/equations_compiler/elim_match.cpp index d01edc4742..2a14e51467 100644 --- a/src/library/equations_compiler/elim_match.cpp +++ b/src/library/equations_compiler/elim_match.cpp @@ -272,13 +272,17 @@ struct elim_match_fn { return false; } - bool is_invertible_app(expr const & e) { + bool is_transport_app(expr const & e) { if (!is_app(e)) return false; expr const & fn = get_app_fn(e); if (!is_constant(fn)) return false; optional info = has_inverse(m_env, const_name(fn)); - if (!info) return false; - return info->m_arity == get_app_num_args(e); + if (!info || info->m_arity != get_app_num_args(e)) return false; + optional inv_info = has_inverse(m_env, info->m_inv); + return + inv_info && + info->m_arity == inv_info->m_inv_arity && + inv_info->m_arity == info->m_inv_arity; } unsigned get_inductive_num_params(name const & n) const { return *inductive::get_num_params(m_env, n); } @@ -303,7 +307,7 @@ struct elim_match_fn { return e; } else { return ctx.whnf_pred(e, [&](expr const & e) { - return !is_constructor_app(e) && !is_value(e) && !is_invertible_app(e); + return !is_constructor_app(e) && !is_value(e) && !is_transport_app(e); }); } } @@ -441,11 +445,11 @@ struct elim_match_fn { } /* Return true iff the next pattern in all equations is the same invertible function. */ - bool is_invertible_transition(problem const & P) { + bool is_transport_transition(problem const & P) { if (!P.m_equations) return false; optional fn_name; return all_next_pattern(P, [&](expr const & p) { - if (!is_invertible_app(p)) return false; + if (!is_transport_app(p)) return false; name const & curr_name = const_name(get_app_fn(p)); if (fn_name) { return *fn_name == curr_name; @@ -912,39 +916,129 @@ struct elim_match_fn { return process(new_P); } - list process_invertible(problem const & P) { - trace_match(tout() << "step: invertible function\n";); - type_context ctx = mk_type_context(P); - expr const & x = head(P.m_var_stack); - /* make inverse */ - expr const & p = head(head(P.m_equations).m_patterns); - lean_assert(is_invertible_app(p)); - expr const & fn = get_app_fn(p); - inverse_info info = *has_inverse(m_env, const_name(fn)); + /* Create (f ... x) with the given arity, where the other arguments are inferred using + type inference */ + expr mk_app_with_arity(type_context & ctx, name const & f, unsigned arity, expr const & x) { buffer mask; - mask.resize(info.m_inv_arity - 1, false); + mask.resize(arity - 1, false); mask.push_back(true); - expr inv; try { - inv = mk_app(ctx, info.m_inv, mask.size(), mask.data(), &x); - } catch (exception &) { - throw_error(sstream() << "equation compiler failed, when trying to inverse function " - << "'" << const_name(fn) << "' (use 'set_option trace.eqn_compiler.elim_match true' " + return mk_app(ctx, f, mask.size(), mask.data(), &x); + } catch (app_builder_exception &) { + throw_error(sstream() << "equation compiler failed, when trying to build " + << "'" << f << "' application (use 'set_option trace.eqn_compiler.elim_match true' " << "for additional details)"); } - expr inv_type = ctx.infer(inv); - local_context lctx = ctx.lctx(); - name y_name("_y"); - expr y = lctx.mk_local_decl(y_name, inv_type, inv); - expr goal_type = ctx.infer(P.m_goal); - expr new_goal = ctx.mk_metavar_decl(lctx, goal_type); - expr new_val = mk_let(y_name, inv_type, inv, mk_delayed_abstraction(new_goal, mlocal_name(y))); - m_mctx = ctx.mctx(); - m_mctx.assign(P.m_goal, new_val); + } + + /* + This step is applied to problems of the form + + goal: ... (x : A) ... |- C x + var_stack: [x, ...] + equations: + (f y_1) ... := rhs_1 + (f y_2) ... := rhs_2 + ... + (f y_n) ... := rhs_n + + where f (B -> A) is not a constructor, and there is a function g : A -> B s.t. + g_f_eq : forall y, g (f y) = y + f_g_eq : forall x, f (g x) = x + + Steps: + + 1) Revert x and obtain goal + + M_1 ... |- forall (x : A), C' x + + 2) Create new goal + + M_2 ... |- forall (y : B), C' (f y) + + Solution for M_1 is + + fun x : A, @@eq.rec (fun x, C' x) (M_2 (g x)) (f_g_eq x) + + We need the eq.rec because (M_2 (g x)) has type (C' (f (g x))) + + 3) Create new problem by reintroducing all variables inverted in + step 1, replacing x with y in the var stack, and using the new set of equations + + y_1 ... := rhs_1 + y_2 ... := rhs_2 + ... + y_n ... := rhs_n + + Remark: the lemma g_f_eq is used when we are trying to prove the equation lemmas. + */ + list process_transport(problem const & P) { + trace_match(tout() << "step: transport function\n";); + expr x = head(P.m_var_stack); + expr p = head(head(P.m_equations).m_patterns); + lean_assert(is_transport_app(p)); + expr f = get_app_fn(p); + name f_name = const_name(f); + inverse_info info = *has_inverse(m_env, f_name); + unsigned f_arity = info.m_arity; + name g_name = info.m_inv; + unsigned g_arity = info.m_inv_arity; + inverse_info info_inv = *has_inverse(m_env, g_name); + name f_g_eq_name = info_inv.m_lemma; + + /* Step 1 */ + buffer to_revert; + to_revert.push_back(x); + expr M_1 = revert(m_env, m_opts, m_mctx, P.m_goal, to_revert); + + /* Step 2 */ + type_context ctx1 = mk_type_context(M_1); + expr M_1_type = ctx1.relaxed_whnf(ctx1.infer(M_1)); + lean_assert(is_pi(M_1_type)); + expr x1 = ctx1.push_local(binding_name(M_1_type), binding_domain(M_1_type)); + expr g_x1 = mk_app_with_arity(ctx1, g_name, g_arity, x1); + expr B = ctx1.infer(g_x1); + expr y1 = ctx1.push_local("_y", B); + expr f_y1 = mk_app_with_arity(ctx1, f_name, f_arity, y1); + expr C_x1 = instantiate(binding_body(M_1_type), x1); + expr C_f_y1 = replace_local(C_x1, x1, f_y1); + expr M_2_type = ctx1.mk_pi(y1, C_f_y1); + expr M_2 = ctx1.mk_metavar_decl(get_local_context(M_1), M_2_type); + expr eqrec; + try { + expr eqrec_motive = ctx1.mk_lambda(x1, C_x1); + expr eqrec_minor = mk_app(M_2, g_x1); + expr eqrec_major = mk_app(ctx1, f_g_eq_name, x1); + eqrec = mk_eq_rec(ctx1, eqrec_motive, eqrec_minor, eqrec_major); + } catch (app_builder_exception &) { + throw_error("equation compiler failed, when trying to build " + "'eq.rec'-application for transport step (use 'set_option trace.eqn_compiler.elim_match true' " + "for additional details)"); + } + expr M_1_val = ctx1.mk_lambda(x1, eqrec); + /* M_1_val is (fun x1 : A, @@eq.rec (fun x1, C x1) (M_2 (g x1)) (f_g_eq x1)) */ + m_mctx = ctx1.mctx(); + m_mctx.assign(M_1, M_1_val); + + /* Step 3 */ + buffer new_H_names; + optional M_3 = intron(m_env, m_opts, m_mctx, M_2, to_revert.size(), new_H_names); + if (!M_3) { + throw_error("equation compiler failed, when reintroducing reverted variables " + "(use 'set_option trace.eqn_compiler.elim_match true' " + "for additional details)"); + } + local_context lctx3 = get_local_context(*M_3); + buffer new_Hs; + for (name const & H_name : new_H_names) { + new_Hs.push_back(lctx3.get_local(H_name)); + } + lean_assert(to_revert.size() == new_Hs.size()); problem new_P; - new_P.m_fn_name = name(P.m_fn_name, "_inv"); - new_P.m_var_stack = cons(y, tail(P.m_var_stack)); - new_P.m_goal = new_goal; + new_P.m_fn_name = name(P.m_fn_name, "_transport"); + new_P.m_goal = *M_3; + new_P.m_var_stack = map(P.m_var_stack, + [&](expr const & x) { return replace_locals(x, to_revert, new_Hs); }); buffer new_eqns; for (equation const & eqn : P.m_equations) { equation new_eqn = eqn; @@ -997,8 +1091,8 @@ struct elim_match_fn { return process_value(P); } else if (is_constructor_transition(P)) { return process_constructor(P); - } else if (is_invertible_transition(P)) { - return process_invertible(P); + } else if (is_transport_transition(P)) { + return process_transport(P); } else if (is_inaccessible_transition(P)) { return process_inaccessible(P); } else if (some_inaccessible(P)) {