lean4-htt/src/library/equations_compiler/wf_rec.cpp
Leonardo de Moura 4e496b78d5 feat(library/equations_compiler): unpack auxiliary definition
We still need to unpack auxiliary lemmas, and propagate information in
the frontend.
2017-05-20 20:34:18 -07:00

391 lines
16 KiB
C++

/*
Copyright (c) 2017 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include "kernel/instantiate.h"
#include "library/type_context.h"
#include "library/trace.h"
#include "library/constants.h"
#include "library/pp_options.h"
#include "library/app_builder.h"
#include "library/sorry.h" // remove after we add tactic for proving recursive calls are decreasing
#include "library/replace_visitor_with_tc.h"
#include "library/equations_compiler/pack_domain.h"
#include "library/equations_compiler/pack_mutual.h"
#include "library/equations_compiler/elim_match.h"
#include "library/equations_compiler/util.h"
namespace lean {
/* Tracing helpers for the well founded recursion compiler.
   trace_wf / trace_debug_wf: run Code under lean_trace for the trace class
   `eqn_compiler.wf_rec` (resp. `debug.eqn_compiler.wf_rec`), after creating a fresh
   `ctx` with mk_type_context() and opening a trace scope for (m_env, ctx).
   They refer to m_env and mk_type_context(), so they are only usable where those are in
   scope (i.e. inside wf_rec_fn members). */
#define trace_wf(Code) lean_trace(name({"eqn_compiler", "wf_rec"}), type_context ctx = mk_type_context(); scope_trace_env _scope1(m_env, ctx); Code)
#define trace_debug_wf(Code) lean_trace(name({"debug", "eqn_compiler", "wf_rec"}), type_context ctx = mk_type_context(); scope_trace_env _scope1(m_env, ctx); Code)
/* trace_debug_wf_aux: same trace class as trace_debug_wf, but it does NOT create a type context;
   a variable named `ctx` must already be in scope at the use site. */
#define trace_debug_wf_aux(Code) lean_trace(name({"debug", "eqn_compiler", "wf_rec"}), scope_trace_env _scope1(m_env, ctx); Code)
struct wf_rec_fn {
environment m_env;
options m_opts;
metavar_context m_mctx;
local_context m_lctx;
expr m_ref;
equations_header m_header;
expr m_R;
expr m_R_wf;
wf_rec_fn(environment const & env, options const & opts,
metavar_context const & mctx, local_context const & lctx):
m_env(env), m_opts(opts), m_mctx(mctx), m_lctx(lctx) {
}
type_context mk_type_context(local_context const & lctx) {
return type_context(m_env, m_opts, m_mctx, lctx, transparency_mode::Semireducible);
}
type_context mk_type_context() {
return mk_type_context(m_lctx);
}
expr pack_domain(expr const & eqns) {
type_context ctx = mk_type_context();
expr r = ::lean::pack_domain(ctx, eqns);
m_env = ctx.env();
m_mctx = ctx.mctx();
return r;
}
expr pack_mutual(expr const & eqns) {
type_context ctx = mk_type_context();
expr r = ::lean::pack_mutual(ctx, eqns);
m_env = ctx.env();
m_mctx = ctx.mctx();
return r;
}
expr_pair mk_wf_relation(expr const & eqns) {
lean_assert(get_equations_header(eqns).m_num_fns == 1);
type_context ctx = mk_type_context();
unpack_eqns ues(ctx, eqns);
try {
expr fn_type = ctx.relaxed_whnf(ctx.infer(ues.get_fn(0)));
lean_assert(is_pi(fn_type));
expr d = binding_domain(fn_type);
expr wf = mk_app(ctx, get_has_well_founded_name(), d);
if (auto inst = ctx.mk_class_instance(wf)) {
bool mask[2] = {true, true};
expr args[2] = {d, *inst};
expr r = mk_app(ctx, get_has_well_founded_r_name(), 2, mask, args);
expr wf = mk_app(ctx, get_has_well_founded_wf_name(), 2, mask, args);
return expr_pair(r, wf);
}
} catch (exception & ex) {
throw nested_exception(some_expr(m_ref),
"failed to create well founded relation using type class resolution",
ex);
}
throw generic_exception(m_ref, "failed to create well founded relation using type class resolution");
}
/* Return the type of the functional. */
expr mk_new_fn_type(type_context & ctx, unpack_eqns const & ues) {
type_context::tmp_locals locals(ctx);
expr fn = ues.get_fn(0);
expr fn_type = ctx.relaxed_whnf(ctx.infer(fn));
lean_assert(ues.get_arity_of(0) == 1);
expr x = locals.push_local("_x", binding_domain(fn_type));
expr y = locals.push_local("_y", binding_domain(fn_type));
expr hlt = mk_app(m_R, y, x);
expr Cy = instantiate(binding_body(fn_type), y);
expr F_type = ctx.mk_pi(y, mk_arrow(hlt, Cy));
expr F = locals.push_local("_F", F_type);
expr Cx = instantiate(binding_body(fn_type), x);
return ctx.mk_pi(x, ctx.mk_pi(F, Cx));
}
struct elim_rec_apps_fn : public replace_visitor_with_tc {
expr m_fn;
expr m_R;
expr m_x;
expr m_F;
elim_rec_apps_fn(type_context & ctx, expr const & fn, expr const & R, expr const & x, expr const & F):
replace_visitor_with_tc(ctx), m_fn(fn), m_R(R), m_x(x), m_F(F) {}
virtual expr visit_local(expr const & e) {
if (mlocal_name(e) == mlocal_name(m_fn)) {
/* unexpected occurrence of recursive function */
throw generic_exception(e, "unexpected occurrence of recursive function\n");
}
return e;
}
/* Prove that y < x */
expr mk_dec_proof(expr const & y) {
expr y_R_x = mk_app(m_R, y, m_x);
// TODO(Leo): invoke tactic, we use sorry for now
return mk_sorry(y_R_x);
}
virtual expr visit_app(expr const & e) {
expr const & fn = app_fn(e);
if (is_local(fn) && mlocal_name(fn) == mlocal_name(m_fn)) {
expr y = visit(app_arg(e));
expr hlt = mk_dec_proof(y);
return mk_app(m_F, y, hlt);
} else {
return replace_visitor_with_tc::visit_app(e);
}
}
};
void update_eqs(type_context & ctx, unpack_eqns & ues, expr const & fn, expr const & new_fn) {
buffer<expr> & eqns = ues.get_eqns_of(0);
buffer<expr> new_eqns;
for (expr const & eqn : eqns) {
unpack_eqn ue(ctx, eqn);
expr lhs = ue.lhs();
expr rhs = ue.rhs();
buffer<expr> lhs_args;
get_app_args(lhs, lhs_args);
lean_assert(lhs_args.size() == 1);
expr new_lhs = mk_app(new_fn, lhs_args);
expr type = ctx.whnf(ctx.infer(new_lhs));
lean_assert(is_pi(type));
ue.lhs() = new_lhs;
type_context::tmp_locals locals(ctx);
expr F = locals.push_local_from_binding(type);
ue.rhs() = ctx.mk_lambda(F, elim_rec_apps_fn(ctx, fn, m_R, lhs_args[0], F)(rhs));
new_eqns.push_back(ue.repack());
}
eqns = new_eqns;
}
expr elim_recursion(expr const & eqns) {
type_context ctx = mk_type_context();
unpack_eqns ues(ctx, eqns);
lean_assert(ues.get_num_fns() == 1);
expr fn = ues.get_fn(0);
expr fn_type = ctx.infer(fn);
expr new_fn_type = mk_new_fn_type(ctx, ues);
trace_debug_wf(tout() << "\n"; tout() << "new function type: " << new_fn_type << "\n";);
expr new_fn = ues.update_fn_type(0, new_fn_type);
update_eqs(ctx, ues, fn, new_fn);
expr new_eqns = ues.repack();
trace_debug_wf(tout() << "after well_founded elim_recursion:\n" << new_eqns << "\n";);
m_mctx = ctx.mctx();
return new_eqns;
}
expr mk_fix(expr const & aux_fn) {
type_context ctx = mk_type_context();
type_context::tmp_locals locals(ctx);
buffer<expr> fn_args;
expr it = ctx.relaxed_whnf(ctx.infer(aux_fn));
lean_assert(is_pi(it));
expr x_ty = binding_domain(it);
expr x = locals.push_local("_x", x_ty);
it = ctx.relaxed_whnf(instantiate(binding_body(it), x));
lean_assert(is_pi(it));
expr Cx = binding_body(it);
lean_assert(closed(it));
expr C = ctx.mk_lambda(x, Cx);
level u_1 = get_level(ctx, x_ty);
optional<level> dec_u_1 = dec_level(u_1);
if (!dec_u_1)
throw generic_exception(m_ref, "equation compiler failed to compute universe level parameter");
level u_2 = get_level(ctx, Cx);
expr fix = mk_app({mk_constant(get_well_founded_fix_name(), {*dec_u_1, u_2}), x_ty, C, m_R, m_R_wf, aux_fn, x});
return ctx.mk_lambda(x, fix);
}
expr mk_fix_aux_function(equations_header const & header, expr fn) {
type_context ctx = mk_type_context();
fn = mk_fix(fn);
expr fn_type = ctx.infer(fn);
expr r;
std::tie(m_env, r) = mk_aux_definition(m_env, m_opts, m_mctx, m_lctx, header,
head(header.m_fn_names), fn_type, fn);
return r;
}
struct mk_lemma_rhs_fn : public replace_visitor_with_tc {
expr m_fn;
expr m_F;
mk_lemma_rhs_fn(type_context & ctx, expr const & fn, expr const & F):
replace_visitor_with_tc(ctx), m_fn(fn), m_F(F) {}
virtual expr visit_local(expr const & e) override {
if (e == m_F) {
throw exception("equation compiler failed when generation equational lemmas");
} else {
return e;
}
}
virtual expr visit_app(expr const & e) override {
if (is_app(app_fn(e)) && app_fn(app_fn(e)) == m_F) {
return mk_app(m_fn, visit(app_arg(app_fn(e))));
} else {
return replace_visitor_with_tc::visit_app(e);
}
}
};
expr mk_lemma_rhs(type_context & ctx, expr const & fn, expr rhs) {
rhs = ctx.relaxed_whnf(rhs);
lean_assert(is_lambda(rhs));
type_context::tmp_locals locals(ctx);
expr F = locals.push_local_from_binding(rhs);
rhs = instantiate(binding_body(rhs), F);
return mk_lemma_rhs_fn(ctx, fn, F)(rhs);
}
void mk_lemmas(expr const & fn, list<expr> const & lemmas) {
name const & fn_name = const_name(get_app_fn(fn));
unsigned eqn_idx = 1;
type_context ctx = mk_type_context();
for (expr type : lemmas) {
type_context::tmp_locals locals(ctx);
type = ctx.relaxed_whnf(type);
while (is_pi(type)) {
expr local = locals.push_local_from_binding(type);
type = instantiate(binding_body(type), local);
}
lean_assert(is_eq(type));
expr lhs = app_arg(app_fn(type));
expr rhs = app_arg(type);
expr new_lhs = mk_app(fn, app_arg(lhs));
expr new_rhs = mk_lemma_rhs(ctx, fn, rhs);
trace_debug_wf_aux(tout() << "aux equation [" << eqn_idx << "]:\n" << new_lhs << "\n=\n" << new_rhs << "\n";);
m_env = mk_equation_lemma(m_env, m_opts, m_mctx, ctx.lctx(), fn_name,
eqn_idx, m_header.m_is_private, locals.as_buffer(), new_lhs, new_rhs);
eqn_idx++;
}
m_mctx = ctx.mctx();
}
expr_pair mk_sigma(type_context & ctx, unsigned i, buffer<expr> const & args) {
lean_assert(args.size() > 0);
if (i == args.size() - 1) {
return mk_pair(args[i], ctx.infer(args[i]));
} else {
expr as, as_type;
std::tie(as, as_type) = mk_sigma(ctx, i+1, args);
expr a = args[i];
lean_assert(is_local(a));
expr a_type = ctx.infer(a);
level a_lvl = get_level(ctx, a_type);
level as_lvl = get_level(ctx, as_type);
as_type = ctx.mk_lambda(a, as_type);
expr r_type = mk_app(mk_constant(get_psigma_name(), {a_lvl, as_lvl}), a_type, as_type);
expr r = mk_app(mk_constant(get_psigma_mk_name(), {a_lvl, as_lvl}),
a_type, as_type, a, as);
return mk_pair(r, r_type);
}
}
expr unpack(expr const & packed_fn, expr const & eqns_before_pack) {
equations_header const & header = get_equations_header(eqns_before_pack);
list<name> fn_names = header.m_fn_names;
type_context ctx = mk_type_context();
buffer<expr> result_fns;
expr packed_fn_type = ctx.relaxed_whnf(ctx.infer(packed_fn));
expr packed_domain = binding_domain(packed_fn_type);
unpack_eqns ues(ctx, eqns_before_pack);
unsigned num_fns = ues.get_num_fns();
for (unsigned fidx = 0; fidx < num_fns; fidx++) {
unsigned arity = ues.get_arity_of(fidx);
expr fn_type = ctx.infer(ues.get_fn(fidx));
type_context::tmp_locals args(ctx);
expr it = fn_type;
for (unsigned i = 0; i < arity; i++) {
it = ctx.relaxed_whnf(it);
lean_assert(is_pi(it));
expr arg = args.push_local_from_binding(it);
it = instantiate(binding_body(it), arg);
}
expr sigma_mk = mk_sigma(ctx, 0, args.as_buffer()).first;
expr packed_arg = mk_mutual_arg(ctx, sigma_mk, fidx, num_fns, packed_domain);
expr fn_val = args.mk_lambda(mk_app(packed_fn, packed_arg));
name fn_name = head(fn_names);
fn_names = tail(fn_names);
trace_debug_wf(tout() << fn_name << " := " << fn_val << "\n";);
expr r;
std::tie(m_env, r) = mk_aux_definition(m_env, m_opts, m_mctx, m_lctx, header, fn_name, fn_type, fn_val);
result_fns.push_back(r);
/* TODO(Leo): unpack equations */
}
return mk_equations_result(result_fns.size(), result_fns.data());
}
expr operator()(expr eqns) {
m_ref = eqns;
m_header = get_equations_header(eqns);
/* Make sure all functions are unary */
expr before_pack = eqns;
eqns = pack_domain(eqns);
trace_debug_wf(tout() << "after pack_domain\n" << eqns << "\n";);
/* Make sure we have only one function */
expr before_mutual = eqns;
equations_header const & header = get_equations_header(eqns);
if (header.m_num_fns > 1) {
eqns = pack_mutual(eqns);
}
/* Retrieve well founded relation */
if (is_wf_equations(eqns)) {
m_R = equations_wf_rel(eqns);
m_R_wf = equations_wf_proof(eqns);
} else {
std::tie(m_R, m_R_wf) = mk_wf_relation(eqns);
}
{
lean_trace_init_bool(name({"eqn_compiler", "wf_rec"}), get_pp_implicit_name(), true);
trace_wf(tout() << "using well_founded relation\n" << m_R << " :\n "
<< mk_type_context().infer(m_R) << "\n";);
}
/* Eliminate recursion using functional. */
eqns = elim_recursion(eqns);
trace_debug_wf(tout() << "after elim_recursion\n" << eqns << "\n";);
/* Eliminate pattern matching */
elim_match_result r = elim_match(m_env, m_opts, m_mctx, m_lctx, eqns);
expr fn = mk_fix_aux_function(get_equations_header(eqns), r.m_fn);
trace_debug_wf(tout() << "after mk_fix\n" << fn << " :\n " << mk_type_context().infer(fn) << "\n";);
if (m_header.m_aux_lemmas) {
lean_assert(!m_header.m_is_meta);
mk_lemmas(fn, r.m_lemmas);
}
return unpack(fn, before_pack);
}
};
/** \brief (Try to) eliminate "recursive calls" in the equations \c eqns using well founded recursion,
    then compile pattern matching with elim_match.

    On success, \c env and \c mctx are replaced by the environment and metavariable context
    produced during compilation. If compilation throws, both are left untouched, because
    they are only written after the compiler returns. */
expr wf_rec(environment & env, options const & opts,
            metavar_context & mctx, local_context const & lctx,
            expr const & eqns) {
    wf_rec_fn compiler(env, opts, mctx, lctx);
    expr compiled = compiler(eqns);
    mctx = compiler.m_mctx;
    env  = compiler.m_env;
    return compiled;
}
void initialize_wf_rec() {
register_trace_class({"eqn_compiler", "wf_rec"});
register_trace_class({"debug", "eqn_compiler", "wf_rec"});
}
/* Module finalization hook paired with initialize_wf_rec. Currently a no-op. */
void finalize_wf_rec() {
}
}