lean4-htt/src/library/equations_compiler/pack_domain.cpp
Leonardo de Moura e9f843ddf6 refactor(kernel/expr): remove mlocal_* functions
The constructors `mvar` and `fvar` have different memory layouts.
2018-06-22 14:25:31 -07:00

160 lines
6.3 KiB
C++

/*
Copyright (c) 2016 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include "kernel/instantiate.h"
#include "library/constants.h"
#include "library/trace.h"
#include "library/app_builder.h"
#include "library/type_context.h"
#include "library/locals.h"
#include "library/replace_visitor_with_tc.h"
#include "library/equations_compiler/equations.h"
#include "library/equations_compiler/util.h"
namespace lean {
/** \brief Functional object that makes every function of an equations block
    whose arity is greater than one unary, by packing its first `arity`
    arguments into nested `psigma` (dependent pair) values.

    For `f : Pi (a_1 : A_1) ... (a_n : A_n), C` with arity `n > 1` the new type is
    `Pi (_p : psigma A_1 (fun a_1, ...)), C'`, where `C'` is `C` with each `a_i`
    replaced by the corresponding projection of `_p`. Applications of the old
    function inside the equations are rewritten to apply the new function to a
    `psigma.mk` tuple of the original arguments. */
struct sigma_packer_fn {
    type_context_old & m_ctx;

    sigma_packer_fn(type_context_old & ctx):m_ctx(ctx) {}

    /** \brief Build the nested `psigma` domain for the first `n` binders of `pi_type`.

        Returns `(domain, codomain)`: `domain` packs the first `n` binder domains and
        `codomain` is the remaining body, expressed in terms of the fresh locals that are
        appended (in binder order) to `out_locals`.
        Calls throw_ill_formed_eqns() if `pi_type` does not expose `n` leading Pi binders,
        even after relaxed_whnf.
        \pre n > 0 (for n == 0 the `n-1` recursion would wrap around) */
    expr_pair mk_sigma_domain(expr const & pi_type, buffer<expr> & out_locals, unsigned n) {
        lean_assert(n > 0);
        expr type = pi_type;
        if (!is_pi(type)) type = m_ctx.relaxed_whnf(type);
        if (!is_pi(type)) throw_ill_formed_eqns();
        expr const & A = binding_domain(type);
        type_context_old::tmp_locals locals(m_ctx);
        expr a = locals.push_local_from_binding(type);
        out_locals.push_back(a);
        expr next_pi_type = instantiate(binding_body(type), a);
        if (n == 1) return mk_pair(A, next_pi_type);
        expr B, codomain;
        std::tie(B, codomain) = mk_sigma_domain(next_pi_type, out_locals, n-1);
        // Abstract `a` so that the second component of the pair is the family `fun a, B`.
        B = locals.mk_lambda(B);
        return mk_pair(mk_app(m_ctx, get_psigma_name(), A, B), codomain);
    }

    /** \brief Replace `locals[0]`, ..., `locals[n-1]` in `codomain` with the projections
        `psigma.fst p`, `psigma.fst (psigma.snd p)`, ..., and, for the last local,
        `psigma.snd (... (psigma.snd p))` itself.
        \pre n > 0 and `locals` holds the `n` locals produced by mk_sigma_domain */
    expr mk_codomain(expr const & codomain, expr p, buffer<expr> const & locals, unsigned n) {
        lean_assert(n > 0);
        buffer<expr> terms;
        // `i + 1 < n` rather than `i < n - 1`: avoids unsigned wrap-around when n == 0.
        for (unsigned i = 0; i + 1 < n; i++) {
            terms.push_back(mk_app(m_ctx, get_psigma_fst_name(), p));
            p = mk_app(m_ctx, get_psigma_snd_name(), p);
        }
        terms.push_back(p);
        return replace_locals(codomain, locals, terms);
    }

    /** \brief Return the unary type `Pi (_p : D), C'` for a function type whose first `n`
        binders are packed into the nested psigma domain `D`. */
    expr pack_as_unary(expr const & pi_type, unsigned n) {
        buffer<expr> locals;
        expr domain, pre_codomain;
        std::tie(domain, pre_codomain) = mk_sigma_domain(pi_type, locals, n);
        type_context_old::tmp_locals plocal(m_ctx);
        expr p = plocal.push_local("_p", domain);
        expr codomain = mk_codomain(pre_codomain, p, locals, n);
        return plocal.mk_pi(codomain);
    }

    /** \brief Visitor rewriting uses of the original functions (`m_old_fns`) into uses of
        the functions stored in `m_ues`, whose types may have been packed. */
    class update_apps_fn : public replace_visitor_with_tc {
        buffer<expr> const & m_old_fns; // functions as they were before packing
        unpack_eqns const & m_ues;      // unpacked equations holding the current functions

        /* Position of `fn` among m_old_fns (matched by local name), or none if `fn` is not
           a local or is not one of those functions. */
        optional<unsigned> get_fn_idx(expr const & fn) {
            if (!is_local(fn)) return optional<unsigned>();
            for (unsigned fnidx = 0; fnidx < m_old_fns.size(); fnidx++) {
                if (local_name(fn) == local_name(m_old_fns[fnidx]))
                    return optional<unsigned>(fnidx);
            }
            return optional<unsigned>();
        }

        /* Build the nested `psigma.mk` term packing `args[i..arity)` into a value of `type`,
           the psigma domain for those arguments. The last argument is returned unwrapped. */
        expr pack(unsigned i, unsigned arity, buffer<expr> const & args, expr const & type) {
            lean_assert(arity > 0);
            if (i == arity - 1) {
                return args[i];
            } else {
                lean_assert(is_constant(get_app_fn(type), get_psigma_name()));
                expr a = args[i];
                expr A = app_arg(app_fn(type));
                expr B = app_arg(type);
                lean_assert(is_lambda(B));
                // The type of the second component depends on the first one.
                expr new_type = instantiate(binding_body(B), a);
                expr b = pack(i+1, arity, args, new_type);
                bool mask[2] = {true, true};
                expr AB[2] = {A, B};
                return mk_app(mk_app(m_ctx, get_psigma_mk_name(), 2, mask, AB), a, b);
            }
        }

        /* Rewrite `f a_1 ... a_k` where `f` is an old function whose type was packed:
           the first `arity` (visited) arguments become a single psigma value applied to the
           new function, and the remaining arguments are applied afterwards. Partial
           applications are eta-expanded first. */
        virtual expr visit_app(expr const & e) override {
            buffer<expr> args;
            expr const & fn = get_app_args(e, args);
            // Decide whether this application needs rewriting *before* visiting the
            // arguments: on every fall-back path below the arguments are visited again
            // anyway (by the base class or by visiting the eta-expanded term).
            auto fnidx = get_fn_idx(fn);
            if (!fnidx) return replace_visitor_with_tc::visit_app(e);
            expr new_fn = m_ues.get_fn(*fnidx);
            if (fn == new_fn) return replace_visitor_with_tc::visit_app(e);
            unsigned arity = m_ues.get_arity_of(*fnidx);
            if (args.size() < arity) {
                expr new_e = m_ctx.eta_expand(e);
                if (!is_lambda(new_e)) throw_ill_formed_eqns();
                return visit(new_e);
            }
            for (expr & arg : args)
                arg = visit(arg);
            expr new_fn_type = m_ctx.infer(new_fn);
            expr sigma_type  = binding_domain(new_fn_type);
            expr arg = pack(0, arity, args, sigma_type);
            expr r = mk_app(new_fn, arg);
            return copy_pos(e, mk_app(r, args.size() - arity, args.data() + arity));
        }

        /* A bare occurrence of a packed old function is eta-expanded, so that visit_app
           can rewrite the resulting full application. */
        virtual expr visit_local(expr const & e) override {
            auto fnidx = get_fn_idx(e);
            if (!fnidx) return replace_visitor_with_tc::visit_local(e);
            expr new_fn = m_ues.get_fn(*fnidx);
            if (e == new_fn) return replace_visitor_with_tc::visit_local(e);
            unsigned arity = m_ues.get_arity_of(*fnidx);
            if (0 < arity) {
                expr new_e = m_ctx.eta_expand(e);
                if (!is_lambda(new_e)) throw_ill_formed_eqns();
                return visit(new_e);
            }
            return new_fn;
        }
    public:
        update_apps_fn(type_context_old & ctx, buffer<expr> const & old_fns, unpack_eqns const & ues):
            replace_visitor_with_tc(ctx), m_old_fns(old_fns), m_ues(ues) {}
    };

    /** \brief Pack the domains of all functions of arity > 1 in the equations `e`.
        Returns `e` unchanged when no function has arity greater than one. */
    expr operator()(expr const & e) {
        unpack_eqns ues(m_ctx, e);
        buffer<expr> old_fns;
        bool modified = false;
        for (unsigned fidx = 0; fidx < ues.get_num_fns(); fidx++) {
            // Take a copy rather than a reference: update_fn_type below may replace the
            // entry that get_fn refers to.
            expr fn = ues.get_fn(fidx);
            old_fns.push_back(fn);
            unsigned arity = ues.get_arity_of(fidx);
            if (arity > 1) {
                expr new_type = pack_as_unary(m_ctx.infer(fn), arity);
                ues.update_fn_type(fidx, new_type);
                modified = true;
            }
        }
        if (!modified) return e;
        update_apps_fn updt(m_ctx, old_fns, ues);
        for (unsigned fidx = 0; fidx < ues.get_num_fns(); fidx++) {
            buffer<expr> & eqs = ues.get_eqns_of(fidx);
            for (expr & eq : eqs)
                eq = updt(eq);
        }
        expr r = ues.repack();
        lean_trace("eqn_compiler", tout() << "making function(s) unary:\n" << r << "\n";);
        return r;
    }
};
/** \brief Make every function of the equations block `e` that takes more than one
    argument unary, by packing those arguments into nested `psigma` values.
    The result is `e` itself when no function needed packing. */
expr pack_domain(type_context_old & ctx, expr const & e) {
    sigma_packer_fn packer(ctx);
    return packer(e);
}
}