lean4-htt/src/library/equations_compiler/compiler.cpp

244 lines
8.7 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/*
Copyright (c) 2016 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include <algorithm>
#include "kernel/find_fn.h"
#include "kernel/instantiate.h"
#include "library/trace.h"
#include "library/locals.h"
#include "library/replace_visitor.h"
#include "library/equations_compiler/compiler.h"
#include "library/equations_compiler/util.h"
#include "library/equations_compiler/pack_domain.h"
#include "library/equations_compiler/structural_rec.h"
#include "library/equations_compiler/unbounded_rec.h"
#include "library/equations_compiler/elim_match.h"
namespace lean {
// Helper for emitting trace output under the "eqn_compiler" trace class.
// The expansion refers to a variable named `ctx` (it uses `ctx.env()` and `ctx`
// to build a `scope_trace_env`), so the macro may only be used where such a
// variable is in scope. `Code` is the statement sequence to run, typically
// writing to `tout()`.
// NOTE(review): presumably `Code` is only evaluated when the trace class is
// enabled (see the 'set_option trace.eqn_compiler true' hint in the error
// message below) -- confirm against the definition of lean_trace.
#define trace_compiler(Code) lean_trace("eqn_compiler", scope_trace_env _scope1(ctx.env(), ctx); Code)
/** \brief Return true iff \c eqns contains a local constant whose
    \c local_info is marked as "rec". */
static bool has_nested_rec(expr const & eqns) {
    auto is_rec_local = [&](expr const & e, unsigned) {
        if (!is_local(e))
            return false;
        return local_info(e).is_rec();
    };
    return static_cast<bool>(find(eqns, is_rec_local));
}
/** \brief Compile \c eqns, choosing a strategy from the equations header.

    - meta definitions are compiled with \c unbounded_rec;
    - a single non-recursive function is compiled with \c mk_nonrec;
    - a single recursive function is compiled with \c try_structural_rec
      when that succeeds.

    In every other case an exception is thrown, since well-founded recursion
    is not supported here.

    For non-meta definitions, \c eqns is asserted to contain no nested
    recursive calls. */
static expr compile_equations_core(environment & env, options const & opts,
                                   metavar_context & mctx, local_context const & lctx,
                                   expr const & eqns) {
    // The local must be called `ctx`: the trace_compiler macro refers to it by name.
    type_context ctx(env, opts, mctx, lctx, transparency_mode::Semireducible);
    trace_compiler(tout() << "compiling\n" << eqns << "\n";);
    trace_compiler(tout() << "recursive: " << is_recursive_eqns(ctx, eqns) << "\n";);
    trace_compiler(tout() << "nested recursion: " << has_nested_rec(eqns) << "\n";);

    equations_header const & header = get_equations_header(eqns);
    lean_assert(header.m_is_meta || !has_nested_rec(eqns));

    if (header.m_is_meta) {
        return mk_equations_result(unbounded_rec(env, opts, mctx, lctx, eqns));
    }

    if (header.m_num_fns == 1) {
        if (is_recursive_eqns(ctx, eqns)) {
            optional<expr> r = try_structural_rec(env, opts, mctx, lctx, eqns);
            if (r) {
                return mk_equations_result(*r);
            }
            // Structural recursion failed: fall through to the error below.
        } else {
            return mk_equations_result(mk_nonrec(env, opts, mctx, lctx, eqns));
        }
    }

    throw exception("support for well-founded recursion has not been implemented yet, "
                    "use 'set_option trace.eqn_compiler true' for additional information");
    // Unreachable experiment kept from the original code:
    // test pack_domain
    // pack_domain(ctx, eqns);
}
/** Auxiliary class for pulling nested recursive calls.
Example: given
definition f : nat → (nat × nat) → nat
| 0 m := m.1
| (n+1) m :=
match m with
| (a, b) := f n (b, a + 1)
end
when we compile
match m with
| (a, b) := f n (b, a + 1)
end
we consider (f n (b, a + 1)) to be a nested recursive call.
Then, we transform the expression to
(fun g,
match m with
| (a, b) := g a b
end) (fun a b, f n (b, a + 1))
Compile the match, and then beta-reduce.

Implementation outline (see the members below): the visitor replaces each
application whose head is a "rec" local declared in the base local context
by a fresh auxiliary local applied to the locals the call depends on, and
records the abstracted call. After compiling the rewritten equations, the
auxiliary locals are substituted back using replace_locals.
NOTE(review): the final step in operator() is a call to replace_locals; whether
that also performs the beta-reduction mentioned above is not visible here --
confirm against replace_locals.
*/
struct pull_nested_rec_fn : public replace_visitor {
// Environment, options and metavariable context are forwarded to every
// type_context created by this visitor and to compile_equations_core.
// m_env and m_mctx are references to the caller's objects; m_opts is a copy.
environment & m_env;
options m_opts;
metavar_context & m_mctx;
// Stack of local contexts. Entry 0 is the context given to the constructor
// (the "base" context); visit_lambda_pi_let pushes a copy of the top entry
// when it enters a binder chain and pops it when it leaves.
buffer<local_context> m_lctx_stack;
// Parallel buffers: m_new_locals[i] is the auxiliary local introduced for a
// pulled recursive call and m_new_values[i] is the lambda it stands for.
buffer<expr> m_new_locals;
buffer<expr> m_new_values;
// Store the given environment/options/metavariable context and seed the
// local-context stack with \c lctx as the base context.
pull_nested_rec_fn(environment & env, options const & opts, metavar_context & mctx, local_context const & lctx):
m_env(env), m_opts(opts), m_mctx(mctx) {
m_lctx_stack.push_back(lctx);
}
// The context the visitor was created with (bottom of the stack).
local_context & base_lctx() { return m_lctx_stack[0]; }
// The context for the binders currently being visited (top of the stack).
local_context & lctx() { return m_lctx_stack.back(); }
// Create a type_context over \c lctx using the stored environment, options
// and metavariable context, with Semireducible transparency.
type_context mk_type_context(local_context const & lctx) {
return type_context(m_env, m_opts, m_mctx, lctx, transparency_mode::Semireducible);
}
// Visit a chain of nested binders starting at \c e.
// When \c is_lam is true the chain may consist of lambdas and lets; when it
// is false it may consist of pis and lets. Each binder is turned into a local
// declaration in a fresh copy of the current local context, the remaining
// body is visited with the bound variables instantiated by those locals, and
// the result is re-abstracted with mk_lambda (is_lam) or mk_pi (!is_lam).
expr visit_lambda_pi_let(bool is_lam, expr const & e) {
buffer<expr> locals;
// Work on a copy so declarations made here disappear when we pop below.
m_lctx_stack.push_back(m_lctx_stack.back());
expr t = e;
while (true) {
if ((is_lam && is_lambda(t)) || (!is_lam && is_pi(t))) {
// Replace loose bound variables in the domain by the locals created so far,
// then visit it before declaring the new local.
expr d = instantiate_rev(binding_domain(t), locals.size(), locals.data());
d = visit(d);
expr x = lctx().mk_local_decl(binding_name(t), d, binding_info(t));
locals.push_back(x);
t = binding_body(t);
} else if (is_let(t)) {
// Same treatment for let: both the type and the value are instantiated and
// visited, and the local is declared together with its value.
expr d = instantiate_rev(let_type(t), locals.size(), locals.data());
expr v = instantiate_rev(let_value(t), locals.size(), locals.data());
d = visit(d);
v = visit(v);
expr x = lctx().mk_local_decl(let_name(t), d, v);
locals.push_back(x);
t = let_body(t);
} else {
// First term that is not part of the binder chain: this is the body.
break;
}
}
t = instantiate_rev(t, locals.size(), locals.data());
t = visit(t);
type_context ctx = mk_type_context(lctx());
// NOTE(review): for locals declared with a value (the let case) this
// presumably rebuilds a let-expression rather than a lambda/pi -- confirm
// against type_context::mk_lambda / mk_pi.
t = is_lam ? ctx.mk_lambda(locals, t) : ctx.mk_pi(locals, t);
// Copy the metavariable context of the temporary type_context back into the
// shared one.
m_mctx = ctx.mctx();
m_lctx_stack.pop_back();
return t;
}
virtual expr visit_lambda(expr const & e) override {
return visit_lambda_pi_let(true, e);
}
// Lets are handled like lambda chains (is_lam == true).
virtual expr visit_let(expr const & e) override {
return visit_lambda_pi_let(true, e);
}
virtual expr visit_pi(expr const & e) override {
return visit_lambda_pi_let(false, e);
}
// Visit the head \c fn and every argument of the application \c e.
// \c args is updated in place with the visited arguments. Returns \c e itself
// when is_eqp reports that neither the head nor any argument changed, and a
// new application otherwise.
expr default_visit_app(expr const & e, expr const & fn, buffer<expr> & args) {
expr new_fn = visit(fn);
bool modified = !is_eqp(fn, new_fn);
for (expr & arg : args) {
expr new_arg = visit(arg);
if (!is_eqp(new_arg, arg))
modified = true;
arg = new_arg;
}
if (!modified)
return e;
else
return mk_app(new_fn, args);
}
// Append to \c R every local occurring in \c e that is not declared in the
// base context and whose name is not yet in \c found; \c found is updated
// accordingly.
void collect_locals_core(expr const & e, name_set & found, buffer<expr> & R) {
for_each(e, [&](expr const & e, unsigned) {
if (is_local(e) && !base_lctx().get_local_decl(e) && !found.contains(mlocal_name(e))) {
found.insert(mlocal_name(e));
R.push_back(e);
}
return true;
});
}
// Collect into \c R the non-base locals occurring in \c e, together with the
// non-base locals occurring in the types of the locals already collected,
// and sort them by their declaration index in the current local context.
void collect_locals(expr const & e, buffer<expr> & R) {
name_set found;
collect_locals_core(e, found, R);
// Index-based loop on purpose: R may grow while we iterate, and the newly
// added locals must have their types scanned too.
for (unsigned i = 0; i < R.size(); i++) {
expr const & x = R[i];
collect_locals_core(lctx().get_local_decl(x)->get_type(), found, R);
}
std::sort(R.begin(), R.end(), [&](expr const & x1, expr const & x2) {
return lctx().get_local_decl(x1)->get_idx() < lctx().get_local_decl(x2)->get_idx();
});
}
// Declare a local with the given unique id, user-facing name and type in
// every local context on the stack (base context included), and return it.
// Asserts that at least one nested context is open.
expr declare_new_local(name const & uid, name const & n, expr const & type) {
lean_assert(m_lctx_stack.size() > 1);
expr r;
for (unsigned i = 0; i < m_lctx_stack.size(); i++) {
local_context & lctx = m_lctx_stack[i];
if (i == 0) {
r = lctx.mk_local_decl(uid, n, type);
} else {
// The declaration is performed unconditionally; the result is only bound to
// r2 (and compared with r) when DEBUG_CODE keeps its argument.
DEBUG_CODE(expr r2 =) lctx.mk_local_decl(uid, n, type);
lean_assert(r == r2);
}
}
return r;
}
// If the head of \c e is a "rec" local declared in the base context, pull the
// whole application out: abstract it over the non-base locals it depends on,
// declare an auxiliary local of the resulting type, record the pair, and
// return the auxiliary local applied to those locals. In this branch the
// arguments of \c e are not visited. Any other application is handled by
// default_visit_app.
virtual expr visit_app(expr const & e) override {
buffer<expr> args;
expr const & fn = get_app_args(e, args);
if (is_local(fn) && local_info(fn).is_rec() && base_lctx().get_local_decl(fn)) {
buffer<expr> local_deps;
collect_locals(e, local_deps);
type_context ctx = mk_type_context(lctx());
expr val = ctx.mk_lambda(local_deps, e);
expr val_type = ctx.infer(val);
// Name derived from "_f" and the 1-based position of the new auxiliary local.
name fn_aux = name("_f").append_after(m_new_locals.size() + 1);
name uid = mk_local_decl_name();
expr g = declare_new_local(uid, fn_aux, val_type);
m_new_locals.push_back(g);
m_new_values.push_back(val);
return mk_app(g, local_deps);
} else {
return default_visit_app(e, fn, args);
}
}
// Entry point: rewrite \c e by pulling nested recursive calls, compile the
// rewritten equations with compile_equations_core using the base context
// (which by now also contains the auxiliary locals added by
// declare_new_local), and finally replace the auxiliary locals by the
// recorded values.
expr operator()(expr const & e) {
expr new_e = visit(e);
// Every context pushed by visit_lambda_pi_let must have been popped again.
lean_assert(m_lctx_stack.size() == 1);
local_context new_lctx = m_lctx_stack[0];
expr r = compile_equations_core(m_env, m_opts, m_mctx, new_lctx, new_e);
// ctx is only used to copy its metavariable context back into m_mctx below.
type_context ctx = mk_type_context(new_lctx);
expr new_r = replace_locals(r, m_new_locals, m_new_values);
m_mctx = ctx.mctx();
return new_r;
}
};
/** \brief Public entry point of the equation compiler.

    Non-meta equations that contain nested recursive calls are first
    processed by \c pull_nested_rec_fn; everything else goes straight to
    \c compile_equations_core. */
expr compile_equations(environment & env, options const & opts, metavar_context & mctx, local_context const & lctx,
                       expr const & eqns) {
    bool const needs_pulling = !get_equations_header(eqns).m_is_meta && has_nested_rec(eqns);
    if (!needs_pulling) {
        return compile_equations_core(env, opts, mctx, lctx, eqns);
    }
    pull_nested_rec_fn pull_nested(env, opts, mctx, lctx);
    return pull_nested(eqns);
}
// Module initialization hook for the equation compiler driver.
// Intentionally empty: nothing in this file is set up here.
void initialize_compiler() {
}
// Module finalization hook for the equation compiler driver.
// Intentionally empty: nothing in this file is torn down here.
void finalize_compiler() {
}
}