lean4-htt/src/library/equations_compiler/util.cpp

194 lines
6.6 KiB
C++

/*
Copyright (c) 2016 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include "kernel/instantiate.h"
#include "kernel/abstract.h"
#include "kernel/find_fn.h"
#include "kernel/inductive/inductive.h"
#include "library/equations_compiler/equations.h"
#include "library/equations_compiler/util.h"
namespace lean {
/* Signal that a match/equations expression does not have the expected shape.
   Always throws `exception`; never returns. */
[[ noreturn ]] void throw_ill_formed_eqns() {
    char const * msg = "ill-formed match/equations expression";
    throw exception(msg);
}
/* Skip the leading lambdas of `e` and inspect what is underneath.
   - no-equation marker: return none.
   - equation: return the head of its left-hand side together with the
     number of arguments the left-hand side applies it to.
   - anything else: throw the ill-formed equations exception. */
static optional<pair<expr, unsigned>> get_eqn_fn_and_arity(expr e) {
    using result_t = optional<pair<expr, unsigned>>;
    for (; is_lambda(e); e = binding_body(e)) {
        /* just peel binders */
    }
    if (is_no_equation(e))
        return result_t();
    if (!is_equation(e))
        throw_ill_formed_eqns();
    expr const & lhs = equation_lhs(e);
    expr const & head = get_app_fn(lhs);
    lean_assert(is_local(head));
    return result_t(head, get_app_num_args(lhs));
}
/* Strip one leading lambda from `eq` for every element of `fns`, then
   instantiate the resulting body with `fns` (via instantiate_rev).
   Throws the ill-formed equations exception if `eq` has fewer leading
   lambdas than `fns` has elements. */
static expr consume_fn_prefix(expr eq, buffer<expr> const & fns) {
    unsigned remaining = fns.size();
    while (remaining > 0) {
        if (!is_lambda(eq))
            throw_ill_formed_eqns();
        eq = binding_body(eq);
        --remaining;
    }
    return instantiate_rev(eq, fns);
}
/* Unpack the `equations` expression `e`.

   After construction:
   - m_fns[i]   is a local (created through m_locals) standing for the i-th
                function; its name and type are read from the leading lambdas
                of the first equation, and each type must be closed.
   - m_eqs[i]   holds the equations of the i-th function, with the function
                binders stripped and instantiated by m_fns.
   - m_arity[i] is the number of arguments used on the left-hand side of the
                equations of the i-th function.

   Equations must appear grouped by function, in the same order as the
   functions. All equations of a function must have the same arity. A function
   whose first entry is a no-equation marker gets its arity from the number of
   leading Pi binders in its type, which must be at least one.

   Throws the ill-formed equations exception whenever one of these
   expectations is violated. */
unpack_eqns::unpack_eqns(type_context & ctx, expr const & e):
    m_locals(ctx) {
    lean_assert(is_equations(e));
    m_src = e;
    buffer<expr> eqs;
    unsigned num_fns = equations_num_fns(e);
    to_equations(e, eqs);
    /* Extract functions. */
    lean_assert(eqs.size() > 0);
    expr eq = eqs[0];
    for (unsigned i = 0; i < num_fns; i++) {
        if (!is_lambda(eq)) throw_ill_formed_eqns();
        if (!closed(binding_domain(eq))) throw_ill_formed_eqns();
        m_fns.push_back(m_locals.push_local(binding_name(eq), binding_domain(eq)));
        eq = binding_body(eq);
    }
    /* Extract equations */
    unsigned eqidx = 0;
    for (unsigned fidx = 0; fidx < num_fns; fidx++) {
        m_eqs.push_back(buffer<expr>());
        buffer<expr> & fn_eqs = m_eqs.back();
        if (eqidx >= eqs.size()) throw_ill_formed_eqns();
        expr first_eq = consume_fn_prefix(eqs[eqidx], m_fns);
        fn_eqs.push_back(first_eq);
        eqidx++;
        if (auto first_info = get_eqn_fn_and_arity(first_eq)) {
            /* The first equation of the group must define the function at position fidx. */
            if (first_info->first != m_fns[fidx]) throw_ill_formed_eqns();
            unsigned arity = first_info->second;
            m_arity.push_back(arity);
            /* Collect the remaining equations that belong to the same function. */
            while (eqidx < eqs.size()) {
                expr next_eq = consume_fn_prefix(eqs[eqidx], m_fns);
                auto next_info = get_eqn_fn_and_arity(next_eq);
                /* A no-equation marker or an equation for another function ends this group. */
                if (!next_info || next_info->first != m_fns[fidx])
                    break;
                if (next_info->second != arity) throw_ill_formed_eqns();
                fn_eqs.push_back(next_eq);
                eqidx++;
            }
        } else {
            /* noequation, guess arity using type of function */
            expr type = mlocal_type(m_fns[fidx]);
            unsigned arity = 0;
            while (is_pi(type)) {
                /* One argument per leading Pi binder. */
                arity++;
                type = binding_body(type);
            }
            if (arity == 0) throw_ill_formed_eqns();
            m_arity.push_back(arity);
        }
    }
    /* Every equation must have been assigned to some function. */
    if (eqs.size() != eqidx) throw_ill_formed_eqns();
    lean_assert(m_arity.size() == m_fns.size());
    lean_assert(m_eqs.size() == m_fns.size());
}
/* Replace the local for function `fidx` by a new local that keeps the same
   pretty-printing name but has type `type`. Returns the new local. */
expr unpack_eqns::update_fn_type(unsigned fidx, expr const & type) {
    expr & slot = m_fns[fidx];
    expr replacement = m_locals.push_local(local_pp_name(slot), type);
    slot = replacement;
    return replacement;
}
/* Rebuild an `equations` expression from the current state: every stored
   equation is abstracted over m_fns again (mk_lambda) and the resulting list
   replaces the equations of the original expression m_src. */
expr unpack_eqns::repack() {
    buffer<expr> packed;
    auto & ctx = m_locals.ctx();
    for (auto const & group : m_eqs) {
        for (unsigned i = 0; i < group.size(); i++) {
            packed.push_back(ctx.mk_lambda(m_fns, group[i]));
        }
    }
    return update_equations(m_src, packed);
}
/* Unpack a single equation `eqn`.
   Each leading lambda becomes a fresh local (recorded in m_vars); binder types
   and the final body are instantiated with the locals created so far.
   The body must be an equation, otherwise the ill-formed equations exception
   is thrown. Its two sides are stored in m_lhs and m_rhs. */
unpack_eqn::unpack_eqn(type_context & ctx, expr const & eqn):
    m_src(eqn), m_locals(ctx) {
    expr body = eqn;
    for (; is_lambda(body); body = binding_body(body)) {
        auto const & so_far = m_locals.as_buffer();
        expr domain = instantiate_rev(binding_domain(body), so_far.size(), so_far.data());
        m_vars.push_back(m_locals.push_local(binding_name(body), domain, binding_info(body)));
    }
    auto const & all_locals = m_locals.as_buffer();
    expr nested = instantiate_rev(body, all_locals.size(), all_locals.data());
    if (!is_equation(nested))
        throw_ill_formed_eqns();
    m_nested_src = nested;
    m_lhs = equation_lhs(nested);
    m_rhs = equation_rhs(nested);
}
/* Create a new local named `n` with type `type`, append it to the variables of
   this equation and return it. Marks the equation as modified so that repack()
   rebuilds it. */
expr unpack_eqn::add_var(name const & n, expr const & type) {
    m_modified_vars = true;
    expr fresh = m_locals.push_local(n, type);
    m_vars.push_back(fresh);
    return fresh;
}
/* Rebuild the equation from m_lhs, m_rhs and m_vars.
   If no variable was added and both sides are still equal to the ones of the
   unpacked equation, the original expression is returned unchanged. Otherwise
   a new equation is created and abstracted over m_vars; tags are copied from
   the corresponding original expressions. */
expr unpack_eqn::repack() {
    bool const changed =
        m_modified_vars ||
        equation_lhs(m_nested_src) != m_lhs ||
        equation_rhs(m_nested_src) != m_rhs;
    if (!changed)
        return m_src;
    expr new_eq = copy_tag(m_nested_src, mk_equation(m_lhs, m_rhs));
    return copy_tag(m_src, m_locals.ctx().mk_lambda(m_vars, new_eq));
}
/* Return true iff the environment has an inductive declaration named `n`. */
bool eqns_env_interface::is_inductive(name const & n) const {
    if (inductive::is_inductive_decl(m_env, n))
        return true;
    return false;
}
/* Return true iff `e` is a constant whose name is an inductive declaration. */
bool eqns_env_interface::is_inductive(expr const & e) const {
    return is_constant(e) && is_inductive(const_name(e));
}
/* If `e` is a constant, return the result of inductive::is_intro_rule for its
   name; for any other expression return none. */
optional<name> eqns_env_interface::is_constructor(expr const & e) const {
    if (is_constant(e))
        return inductive::is_intro_rule(m_env, const_name(e));
    return optional<name>();
}
/* Return the number of parameters reported by inductive::get_num_params for `n`.
   \pre is_inductive(n) */
unsigned eqns_env_interface::get_inductive_num_params(name const & n) const {
    lean_assert(is_inductive(n));
    auto num_params = inductive::get_num_params(m_env, n);
    return *num_params;
}
/* Return the number of indices reported by inductive::get_num_indices for `n`.
   \pre is_inductive(n) */
unsigned eqns_env_interface::get_inductive_num_indices(name const & n) const {
    lean_assert(is_inductive(n));
    auto num_indices = inductive::get_num_indices(m_env, n);
    return *num_indices;
}
/* Return true iff the right-hand side of some equation in `e` mentions one of
   the functions being defined by `e`.

   `e` is unpacked with unpack_eqns, so the functions are represented by locals
   and a reference is detected by comparing local names. No-equation markers
   have no right-hand side and are skipped. Any other non-equation entry raises
   the ill-formed equations exception. */
bool is_recursive_eqns(type_context & ctx, expr const & e) {
    unpack_eqns ues(ctx, e);
    unsigned const num_fns = ues.get_num_fns();
    /* Predicate for `find`: is `t` one of the locals standing for a defined function? */
    auto is_defined_fn = [&](expr const & t, unsigned) {
        if (is_local(t)) {
            for (unsigned i = 0; i < num_fns; i++) {
                if (mlocal_name(t) == mlocal_name(ues.get_fn(i)))
                    return true;
            }
        }
        return false;
    };
    for (unsigned fidx = 0; fidx < num_fns; fidx++) {
        buffer<expr> const & eqns = ues.get_eqns_of(fidx);
        for (expr const & eqn : eqns) {
            expr it = eqn;
            while (is_lambda(it)) {
                it = binding_body(it);
            }
            /* unpack_eqns accepts no-equation markers, so they can show up here. */
            if (is_no_equation(it))
                continue;
            if (!is_equation(it)) throw_ill_formed_eqns();
            if (find(equation_rhs(it), is_defined_fn))
                return true;
        }
    }
    return false;
}
}