194 lines
6.6 KiB
C++
194 lines
6.6 KiB
C++
/*
Copyright (c) 2016 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.

Author: Leonardo de Moura
*/
|
|
#include "kernel/instantiate.h"
|
|
#include "kernel/abstract.h"
|
|
#include "kernel/find_fn.h"
|
|
#include "kernel/inductive/inductive.h"
|
|
#include "library/equations_compiler/equations.h"
|
|
#include "library/equations_compiler/util.h"
|
|
|
|
namespace lean {
|
|
[[ noreturn ]] void throw_ill_formed_eqns() {
|
|
throw exception("ill-formed match/equations expression");
|
|
}
|
|
|
|
/** \brief Skip the leading lambda binders of \c e and inspect the body.

    - If the body is a no_equation marker, return an empty optional.
    - If the body is an equation, return the head of its left-hand side application
      together with the number of arguments it is applied to.
    - Otherwise the expression is rejected with throw_ill_formed_eqns(). */
static optional<pair<expr, unsigned>> get_eqn_fn_and_arity(expr e) {
    typedef optional<pair<expr, unsigned>> result;
    /* Walk under every leading binder. */
    for (; is_lambda(e); e = binding_body(e)) {}
    if (is_no_equation(e))
        return result();
    if (!is_equation(e))
        throw_ill_formed_eqns();
    expr const & lhs  = equation_lhs(e);
    expr const & head = get_app_fn(lhs);
    /* The head of the left-hand side is expected to be a local constant. */
    lean_assert(is_local(head));
    return result(head, get_app_num_args(lhs));
}
|
|
|
|
/** \brief Remove one leading lambda from \c eq for each element of \c fns, then
    instantiate the remaining body with \c fns (via instantiate_rev).

    Calls throw_ill_formed_eqns() if \c eq has fewer leading lambdas than \c fns has elements. */
static expr consume_fn_prefix(expr eq, buffer<expr> const & fns) {
    unsigned remaining = fns.size();
    while (remaining > 0) {
        if (!is_lambda(eq))
            throw_ill_formed_eqns();
        eq = binding_body(eq);
        --remaining;
    }
    return instantiate_rev(eq, fns);
}
|
|
|
|
/** \brief Split the equations expression \c e into its functions and per-function equations.

    The first <tt>equations_num_fns(e)</tt> lambda binders of the first equation are turned
    into locals (stored in \c m_fns); their domains must be closed terms.
    Every equation then has that binder prefix consumed and instantiated with those locals
    (see consume_fn_prefix), and consecutive equations are grouped by the function found at
    the head of their left-hand side. Group \c i is stored in <tt>m_eqs[i]</tt> and its
    arity in <tt>m_arity[i]</tt>.

    If the group of a function starts with a no_equation entry, the arity is taken to be
    the number of leading Pi binders in the type of the corresponding local.

    throw_ill_formed_eqns() is called when the binder prefix is missing, a binder domain is
    not closed, a function has no entry, the entries are not ordered by function, two
    equations of the same function disagree on arity, a no_equation function has a type
    with no leading Pi, or entries are left over after the last function. */
unpack_eqns::unpack_eqns(type_context & ctx, expr const & e):
    m_locals(ctx) {
    lean_assert(is_equations(e));
    m_src = e;
    buffer<expr> eqs;
    unsigned num_fns = equations_num_fns(e);
    to_equations(e, eqs);
    /* Extract functions. */
    lean_assert(eqs.size() > 0);
    expr eq = eqs[0];
    for (unsigned i = 0; i < num_fns; i++) {
        if (!is_lambda(eq)) throw_ill_formed_eqns();
        if (!closed(binding_domain(eq))) throw_ill_formed_eqns();
        m_fns.push_back(m_locals.push_local(binding_name(eq), binding_domain(eq)));
        eq = binding_body(eq);
    }
    /* Extract equations */
    unsigned eqidx = 0;
    for (unsigned fidx = 0; fidx < num_fns; fidx++) {
        m_eqs.push_back(buffer<expr>());
        buffer<expr> & fn_eqs = m_eqs.back();
        if (eqidx >= eqs.size()) throw_ill_formed_eqns();
        /* The first entry of the group decides whether the function has equations at all. */
        expr first_eq = consume_fn_prefix(eqs[eqidx], m_fns);
        fn_eqs.push_back(first_eq);
        eqidx++;
        if (auto first_info = get_eqn_fn_and_arity(first_eq)) {
            if (first_info->first != m_fns[fidx]) throw_ill_formed_eqns();
            unsigned arity = first_info->second;
            m_arity.push_back(arity);
            /* Collect the remaining equations whose left-hand side is headed by the same function. */
            while (eqidx < eqs.size()) {
                expr next_eq = consume_fn_prefix(eqs[eqidx], m_fns);
                if (auto next_info = get_eqn_fn_and_arity(next_eq)) {
                    if (next_info->first == m_fns[fidx]) {
                        if (next_info->second != arity) throw_ill_formed_eqns();
                        fn_eqs.push_back(next_eq);
                        eqidx++;
                    } else {
                        /* Equation of another function: the current group is complete. */
                        break;
                    }
                } else {
                    /* no_equation entry: it starts the group of another function. */
                    break;
                }
            }
        } else {
            /* noequation, guess arity using type of function:
               the arity is the number of leading Pi binders of the type. */
            expr type = mlocal_type(m_fns[fidx]);
            unsigned arity = 0;
            while (is_pi(type)) {
                arity++;
                type = binding_body(type);
            }
            if (arity == 0) throw_ill_formed_eqns();
            m_arity.push_back(arity);
        }
    }
    if (eqs.size() != eqidx) throw_ill_formed_eqns();
    lean_assert(m_arity.size() == m_fns.size());
    lean_assert(m_eqs.size() == m_fns.size());
}
|
|
|
|
/** \brief Replace the local stored for function \c fidx by a fresh local that has the
    same pretty-printing name and the given \c type. Returns the new local. */
expr unpack_eqns::update_fn_type(unsigned fidx, expr const & type) {
    name const & pp_name = local_pp_name(m_fns[fidx]);
    m_fns[fidx] = m_locals.push_local(pp_name, type);
    return m_fns[fidx];
}
|
|
|
|
/** \brief Rebuild an equations expression from the unpacked state.

    Every stored equation, in group order, is abstracted over the function locals
    \c m_fns with mk_lambda; the resulting list is passed to update_equations together
    with the original expression \c m_src. */
expr unpack_eqns::repack() {
    buffer<expr> abstracted;
    for (unsigned fidx = 0; fidx < m_eqs.size(); fidx++) {
        buffer<expr> const & group = m_eqs[fidx];
        for (unsigned i = 0; i < group.size(); i++)
            abstracted.push_back(m_locals.ctx().mk_lambda(m_fns, group[i]));
    }
    return update_equations(m_src, abstracted);
}
|
|
|
|
/** \brief Open a single lambda-abstracted equation.

    Each leading lambda of \c eqn becomes a fresh local, recorded in \c m_vars; its domain
    is instantiated with the locals created so far. The remaining body, instantiated with
    all the locals, must be an equation (otherwise throw_ill_formed_eqns() is called). It is
    stored in \c m_nested_src and its two sides in \c m_lhs and \c m_rhs. */
unpack_eqn::unpack_eqn(type_context & ctx, expr const & eqn):
    m_src(eqn), m_locals(ctx) {
    expr body = eqn;
    for (; is_lambda(body); body = binding_body(body)) {
        auto const & locals = m_locals.as_buffer();
        expr domain = instantiate_rev(binding_domain(body), locals.size(), locals.data());
        m_vars.push_back(m_locals.push_local(binding_name(body), domain, binding_info(body)));
    }
    auto const & all_locals = m_locals.as_buffer();
    body = instantiate_rev(body, all_locals.size(), all_locals.data());
    if (!is_equation(body))
        throw_ill_formed_eqns();
    m_nested_src = body;
    m_lhs        = equation_lhs(body);
    m_rhs        = equation_rhs(body);
}
|
|
|
|
/** \brief Create a fresh local named \c n of the given \c type, append it to \c m_vars,
    and return it. Also flags the variable list as modified, so that repack() rebuilds
    the equation. */
expr unpack_eqn::add_var(name const & n, expr const & type) {
    m_modified_vars = true;
    expr fresh = m_locals.push_local(n, type);
    m_vars.push_back(fresh);
    return fresh;
}
|
|
|
|
/** \brief Close the equation again.

    When no variable was added and both sides still equal those of the equation that was
    opened, the original expression \c m_src is returned as is. Otherwise a new equation is
    built from \c m_lhs and \c m_rhs and abstracted over \c m_vars; copy_tag is applied with
    \c m_nested_src for the inner equation and with \c m_src for the final result. */
expr unpack_eqn::repack() {
    bool const changed =
        m_modified_vars ||
        equation_lhs(m_nested_src) != m_lhs ||
        equation_rhs(m_nested_src) != m_rhs;
    if (changed) {
        expr new_eq = copy_tag(m_nested_src, mk_equation(m_lhs, m_rhs));
        return copy_tag(m_src, m_locals.ctx().mk_lambda(m_vars, new_eq));
    }
    return m_src;
}
|
|
|
|
/** \brief Return true iff inductive::is_inductive_decl finds a declaration named \c n
    in the environment \c m_env. */
bool eqns_env_interface::is_inductive(name const & n) const {
    if (inductive::is_inductive_decl(m_env, n))
        return true;
    return false;
}
|
|
|
|
/** \brief Return true iff \c e is a constant whose name passes the name-based
    is_inductive test. */
bool eqns_env_interface::is_inductive(expr const & e) const {
    return is_constant(e) && is_inductive(const_name(e));
}
|
|
|
|
/** \brief If \c e is a constant, return the result of inductive::is_intro_rule for its
    name in \c m_env; for any other expression return an empty optional. */
optional<name> eqns_env_interface::is_constructor(expr const & e) const {
    if (is_constant(e))
        return inductive::is_intro_rule(m_env, const_name(e));
    return optional<name>();
}
|
|
|
|
/** \brief Return the value produced by inductive::get_num_params for \c n in \c m_env.
    \pre is_inductive(n) */
unsigned eqns_env_interface::get_inductive_num_params(name const & n) const {
    lean_assert(is_inductive(n));
    auto num_params = inductive::get_num_params(m_env, n);
    return *num_params;
}
|
|
|
|
/** \brief Return the value produced by inductive::get_num_indices for \c n in \c m_env.
    \pre is_inductive(n) */
unsigned eqns_env_interface::get_inductive_num_indices(name const & n) const {
    lean_assert(is_inductive(n));
    auto num_indices = inductive::get_num_indices(m_env, n);
    return *num_indices;
}
|
|
|
|
/** \brief Return true iff the right-hand side of some equation in \c e contains a local
    whose name equals the name of one of the function locals obtained by unpacking \c e.

    \c e is unpacked with unpack_eqns, so any exception raised there propagates.
    Entries that are no_equation markers have no right-hand side and are skipped.
    Any other entry that is not an equation after its leading lambdas are removed is
    rejected with throw_ill_formed_eqns(). */
bool is_recursive_eqns(type_context & ctx, expr const & e) {
    unpack_eqns ues(ctx, e);
    unsigned num_fns = ues.get_num_fns();
    /* Predicate for find: does the sub-term \c t name one of the functions being defined? */
    auto is_fn_occurrence = [&](expr const & t, unsigned) {
        if (!is_local(t))
            return false;
        for (unsigned i = 0; i < num_fns; i++) {
            if (mlocal_name(t) == mlocal_name(ues.get_fn(i)))
                return true;
        }
        return false;
    };
    for (unsigned fidx = 0; fidx < num_fns; fidx++) {
        buffer<expr> const & eqns = ues.get_eqns_of(fidx);
        for (expr const & eqn : eqns) {
            expr it = eqn;
            while (is_lambda(it)) {
                it = binding_body(it);
            }
            /* unpack_eqns accepts no_equation entries; they carry no rhs to inspect. */
            if (is_no_equation(it))
                continue;
            if (!is_equation(it)) throw_ill_formed_eqns();
            if (find(equation_rhs(it), is_fn_occurrence))
                return true;
        }
    }
    return false;
}
|
|
}
|