1772 lines
78 KiB
C++
1772 lines
78 KiB
C++
/*
|
|
Copyright (c) 2014 Microsoft Corporation. All rights reserved.
|
|
Released under Apache 2.0 license as described in the file LICENSE.
|
|
|
|
Author: Leonardo de Moura
|
|
*/
|
|
#include <algorithm>
|
|
#include <string>
|
|
#include "util/sstream.h"
|
|
#include "util/list_fn.h"
|
|
#include "util/fresh_name.h"
|
|
#include "kernel/expr.h"
|
|
#include "kernel/type_checker.h"
|
|
#include "kernel/abstract.h"
|
|
#include "kernel/instantiate.h"
|
|
#include "kernel/error_msgs.h"
|
|
#include "kernel/for_each_fn.h"
|
|
#include "kernel/find_fn.h"
|
|
#include "kernel/replace_fn.h"
|
|
#include "library/generic_exception.h"
|
|
#include "library/kernel_serializer.h"
|
|
#include "library/io_state_stream.h"
|
|
#include "library/annotation.h"
|
|
#include "library/util.h"
|
|
#include "library/old_util.h"
|
|
#include "library/locals.h"
|
|
#include "library/constants.h"
|
|
#include "library/normalize.h"
|
|
#include "library/pp_options.h"
|
|
#include "library/definitional/old_inversion.h"
|
|
|
|
namespace lean {
|
|
static name * g_equations_name = nullptr;
|
|
static name * g_equation_name = nullptr;
|
|
static name * g_no_equation_name = nullptr;
|
|
static name * g_inaccessible_name = nullptr;
|
|
static name * g_equations_result_name = nullptr;
|
|
static std::string * g_equations_opcode = nullptr;
|
|
static std::string * g_equation_opcode = nullptr;
|
|
static std::string * g_no_equation_opcode = nullptr;
|
|
static std::string * g_equations_result_opcode = nullptr;
|
|
|
|
[[ noreturn ]] static void throw_eqs_ex() { throw exception("unexpected occurrence of 'equations' expression"); }
|
|
|
|
class equations_macro_cell : public macro_definition_cell {
|
|
unsigned m_num_fns;
|
|
public:
|
|
equations_macro_cell(unsigned num_fns):m_num_fns(num_fns) {}
|
|
virtual name get_name() const { return *g_equations_name; }
|
|
virtual expr check_type(expr const &, abstract_type_context &, bool) const { throw_eqs_ex(); }
|
|
virtual optional<expr> expand(expr const &, abstract_type_context &) const { throw_eqs_ex(); }
|
|
virtual void write(serializer & s) const { s << *g_equations_opcode << m_num_fns; }
|
|
unsigned get_num_fns() const { return m_num_fns; }
|
|
};
|
|
|
|
class equation_base_macro_cell : public macro_definition_cell {
|
|
public:
|
|
virtual expr check_type(expr const &, abstract_type_context &, bool) const {
|
|
expr dummy = mk_Prop();
|
|
return dummy;
|
|
}
|
|
virtual optional<expr> expand(expr const &, abstract_type_context &) const {
|
|
expr dummy = mk_Type();
|
|
return some_expr(dummy);
|
|
}
|
|
};
|
|
|
|
class equation_macro_cell : public equation_base_macro_cell {
|
|
public:
|
|
virtual name get_name() const { return *g_equation_name; }
|
|
virtual void write(serializer & s) const { s.write_string(*g_equation_opcode); }
|
|
};
|
|
|
|
// This is just a placeholder to indicate no equations were provided
|
|
class no_equation_macro_cell : public equation_base_macro_cell {
|
|
public:
|
|
virtual name get_name() const { return *g_no_equation_name; }
|
|
virtual void write(serializer & s) const { s.write_string(*g_no_equation_opcode); }
|
|
};
|
|
|
|
static macro_definition * g_equation = nullptr;
|
|
static macro_definition * g_no_equation = nullptr;
|
|
|
|
bool is_equation(expr const & e) { return is_macro(e) && macro_def(e) == *g_equation; }
|
|
|
|
bool is_lambda_equation(expr const & e) {
|
|
if (is_lambda(e))
|
|
return is_lambda_equation(binding_body(e));
|
|
else
|
|
return is_equation(e);
|
|
}
|
|
|
|
expr const & equation_lhs(expr const & e) { lean_assert(is_equation(e)); return macro_arg(e, 0); }
|
|
expr const & equation_rhs(expr const & e) { lean_assert(is_equation(e)); return macro_arg(e, 1); }
|
|
expr mk_equation(expr const & lhs, expr const & rhs) {
|
|
expr args[2] = { lhs, rhs };
|
|
return mk_macro(*g_equation, 2, args);
|
|
}
|
|
expr mk_no_equation() { return mk_macro(*g_no_equation); }
|
|
bool is_no_equation(expr const & e) { return is_macro(e) && macro_def(e) == *g_no_equation; }
|
|
|
|
bool is_lambda_no_equation(expr const & e) {
|
|
if (is_lambda(e))
|
|
return is_lambda_no_equation(binding_body(e));
|
|
else
|
|
return is_no_equation(e);
|
|
}
|
|
|
|
bool is_equations(expr const & e) { return is_macro(e) && macro_def(e).get_name() == *g_equations_name; }

/** \brief Return true iff the `equations` macro \c e carries a well-founded relation and proof.

    Well-founded equations are stored as (eqs..., R, Hwf): at least one equation followed
    by two extra arguments, so they have at least 3 arguments. Every argument of a plain
    `equations` macro is a (lambda-wrapped) `equation` or `no_equation` (see mk_equations),
    so the last argument distinguishes the two layouts. Both kinds of marker must be tested;
    otherwise a plain macro ending with a `no_equation` would be misclassified. This is the
    same test the `equations` deserializer performs. */
bool is_wf_equations_core(expr const & e) {
    lean_assert(is_equations(e));
    unsigned const num_args = macro_num_args(e);
    if (num_args < 3)
        return false;
    expr const & last = macro_arg(e, num_args - 1);
    return !is_lambda_equation(last) && !is_lambda_no_equation(last);
}

bool is_wf_equations(expr const & e) { return is_equations(e) && is_wf_equations_core(e); }
|
|
unsigned equations_size(expr const & e) {
|
|
lean_assert(is_equations(e));
|
|
if (is_wf_equations_core(e))
|
|
return macro_num_args(e) - 2;
|
|
else
|
|
return macro_num_args(e);
|
|
}
|
|
unsigned equations_num_fns(expr const & e) {
|
|
lean_assert(is_equations(e));
|
|
return static_cast<equations_macro_cell const*>(macro_def(e).raw())->get_num_fns();
|
|
}
|
|
expr const & equations_wf_proof(expr const & e) {
|
|
lean_assert(is_wf_equations(e));
|
|
return macro_arg(e, macro_num_args(e) - 1);
|
|
}
|
|
expr const & equations_wf_rel(expr const & e) {
|
|
lean_assert(is_wf_equations(e));
|
|
return macro_arg(e, macro_num_args(e) - 2);
|
|
}
|
|
void to_equations(expr const & e, buffer<expr> & eqns) {
|
|
lean_assert(is_equations(e));
|
|
unsigned sz = equations_size(e);
|
|
for (unsigned i = 0; i < sz; i++)
|
|
eqns.push_back(macro_arg(e, i));
|
|
}
|
|
/** \brief Create a (structural) `equations` macro defining \c num_fns functions.

    Each of the \c num_eqs entries of \c eqs must be a lambda-wrapped `equation` or `no_equation`. */
expr mk_equations(unsigned num_fns, unsigned num_eqs, expr const * eqs) {
    lean_assert(num_fns > 0);
    lean_assert(num_eqs > 0);
    lean_assert(std::all_of(eqs, eqs + num_eqs, [](expr const & eq) {
                return is_lambda_equation(eq) || is_lambda_no_equation(eq);
            }));
    macro_definition eqs_def(new equations_macro_cell(num_fns));
    return mk_macro(eqs_def, num_eqs, eqs);
}

/** \brief Create a well-founded `equations` macro with relation \c R and proof \c Hwf.

    The macro arguments are laid out as (eqs..., R, Hwf). */
expr mk_equations(unsigned num_fns, unsigned num_eqs, expr const * eqs, expr const & R, expr const & Hwf) {
    lean_assert(num_fns > 0);
    lean_assert(num_eqs > 0);
    lean_assert(std::all_of(eqs, eqs + num_eqs, is_lambda_equation));
    buffer<expr> margs;
    for (unsigned idx = 0; idx < num_eqs; ++idx)
        margs.push_back(eqs[idx]);
    margs.push_back(R);
    margs.push_back(Hwf);
    macro_definition eqs_def(new equations_macro_cell(num_fns));
    return mk_macro(eqs_def, margs.size(), margs.data());
}

/** \brief Rebuild \c eqns with the equations \c new_eqs, keeping its number of functions
    and, for well-founded equations, its relation and proof. */
expr update_equations(expr const & eqns, buffer<expr> const & new_eqs) {
    lean_assert(is_equations(eqns));
    lean_assert(!new_eqs.empty());
    if (!is_wf_equations(eqns))
        return mk_equations(equations_num_fns(eqns), new_eqs.size(), new_eqs.data());
    return mk_equations(equations_num_fns(eqns), new_eqs.size(), new_eqs.data(),
                        equations_wf_rel(eqns), equations_wf_proof(eqns));
}
|
|
|
|
expr mk_inaccessible(expr const & e) { return mk_annotation(*g_inaccessible_name, e); }
|
|
bool is_inaccessible(expr const & e) { return is_annotation(e, *g_inaccessible_name); }
|
|
|
|
// Auxiliary macro used to store the result of a set of equations defining a mutually recursive
|
|
// definition.
|
|
class equations_result_macro_cell : public macro_definition_cell {
|
|
public:
|
|
virtual name get_name() const { return *g_equations_result_name; }
|
|
virtual expr check_type(expr const & m, abstract_type_context & ctx, bool infer_only) const {
|
|
return ctx.check(macro_arg(m, 0), infer_only);
|
|
}
|
|
virtual optional<expr> expand(expr const & m, abstract_type_context &) const {
|
|
return some_expr(macro_arg(m, 0));
|
|
}
|
|
virtual void write(serializer & s) const { s << *g_equations_result_opcode; }
|
|
};
|
|
|
|
static macro_definition * g_equations_result = nullptr;
|
|
|
|
static expr mk_equations_result(unsigned n, expr const * rs) {
|
|
return mk_macro(*g_equations_result, n, rs);
|
|
}
|
|
|
|
bool is_equations_result(expr const & e) { return is_macro(e) && macro_def(e) == *g_equations_result; }
|
|
unsigned get_equations_result_size(expr const & e) { return macro_num_args(e); }
|
|
expr const & get_equations_result(expr const & e, unsigned i) { return macro_arg(e, i); }
|
|
|
|
void initialize_equations() {
|
|
g_equations_name = new name("equations");
|
|
g_equation_name = new name("equation");
|
|
g_no_equation_name = new name("no_equation");
|
|
g_inaccessible_name = new name("innaccessible");
|
|
g_equations_result_name = new name("equations_result");
|
|
g_equation = new macro_definition(new equation_macro_cell());
|
|
g_no_equation = new macro_definition(new no_equation_macro_cell());
|
|
g_equations_result = new macro_definition(new equations_result_macro_cell());
|
|
g_equations_opcode = new std::string("Eqns");
|
|
g_equation_opcode = new std::string("Eqn");
|
|
g_no_equation_opcode = new std::string("NEqn");
|
|
g_equations_result_opcode = new std::string("EqnR");
|
|
register_annotation(*g_inaccessible_name);
|
|
register_macro_deserializer(*g_equations_opcode,
|
|
[](deserializer & d, unsigned num, expr const * args) {
|
|
unsigned num_fns;
|
|
d >> num_fns;
|
|
if (num == 0 || num_fns == 0)
|
|
throw corrupted_stream_exception();
|
|
if (!is_lambda_equation(args[num-1]) && !is_lambda_no_equation(args[num-1])) {
|
|
if (num <= 2)
|
|
throw corrupted_stream_exception();
|
|
return mk_equations(num_fns, num-2, args, args[num-2], args[num-1]);
|
|
} else {
|
|
return mk_equations(num_fns, num, args);
|
|
}
|
|
});
|
|
register_macro_deserializer(*g_equation_opcode,
|
|
[](deserializer &, unsigned num, expr const * args) {
|
|
if (num != 2)
|
|
throw corrupted_stream_exception();
|
|
return mk_equation(args[0], args[1]);
|
|
});
|
|
register_macro_deserializer(*g_no_equation_opcode,
|
|
[](deserializer &, unsigned num, expr const *) {
|
|
if (num != 0)
|
|
throw corrupted_stream_exception();
|
|
return mk_no_equation();
|
|
});
|
|
register_macro_deserializer(*g_equations_result_opcode,
|
|
[](deserializer &, unsigned num, expr const * args) {
|
|
return mk_equations_result(num, args);
|
|
});
|
|
}
|
|
|
|
/** \brief Release every object allocated by initialize_equations. */
void finalize_equations() {
    // serialization opcodes
    delete g_equations_result_opcode;
    delete g_no_equation_opcode;
    delete g_equation_opcode;
    delete g_equations_opcode;
    // macro definitions
    delete g_equations_result;
    delete g_no_equation;
    delete g_equation;
    // names
    delete g_equations_result_name;
    delete g_inaccessible_name;
    delete g_no_equation_name;
    delete g_equation_name;
    delete g_equations_name;
}
|
|
|
|
class equation_compiler_fn {
|
|
old_type_checker & m_tc;
|
|
io_state const & m_ios;
|
|
expr m_meta;
|
|
expr m_meta_type;
|
|
buffer<expr> m_global_context;
|
|
// The additional context is used to store inductive datatype parameters that occur as arguments in recursive equations.
|
|
// For example, suppose the user writes
|
|
//
|
|
// definition append : Π (A : Type), list A → list A → list A,
|
|
// append A nil l := l,
|
|
// append A (h :: t) l := h :: (append t l)
|
|
//
|
|
// instead of
|
|
//
|
|
// definition append (A : Type) : list A → list A → list A,
|
|
// append nil l := l,
|
|
// append (h :: t) l := h :: (append t l)
|
|
//
|
|
// In this case, we move the parameter (A : Type) to m_additional_context and simplify the recursive equations.
|
|
// The simplification is necessary when we are translating the recursive applications into a brec_on recursor.
|
|
buffer<expr> m_additional_context;
|
|
buffer<expr> m_fns; // functions being defined
|
|
|
|
environment const & env() const { return m_tc.env(); }
|
|
io_state const & ios() const { return m_ios; }
|
|
io_state_stream out() const { return regular(env(), ios(), m_tc.get_type_context()); }
|
|
expr whnf(expr const & e) { return m_tc.whnf(e).first; }
|
|
expr infer_type(expr const & e) { return m_tc.infer(e).first; }
|
|
bool is_def_eq(expr const & e1, expr const & e2) { return m_tc.is_def_eq(e1, e2).first; }
|
|
bool is_proof_irrelevant() const { return m_tc.env().prop_proof_irrel(); }
|
|
|
|
// Return a value iff e is a constant naming an introduction rule. The value is whatever
// inductive::is_intro_rule reports; validate_pattern uses it as the inductive type name.
optional<name> is_constructor(expr const & e) const {
    if (is_constant(e))
        return inductive::is_intro_rule(env(), const_name(e));
    return optional<name>();
}
|
|
|
|
// Pi-telescope of e without overriding binder information.
expr to_telescope(expr const & e, buffer<expr> & tele) {
    optional<binder_info> no_binfo;
    return ::lean::to_telescope(e, tele, no_binfo);
}

// Lambda-telescope of e without overriding binder information.
expr fun_to_telescope(expr const & e, buffer<expr> & tele) {
    optional<binder_info> no_binfo;
    return ::lean::fun_to_telescope(e, tele, no_binfo);
}

// Like to_telescope, but passes m_tc so the type checker is used for normalization.
expr to_telescope_ext(expr const & e, buffer<expr> & tele) {
    optional<binder_info> no_binfo;
    return ::lean::to_telescope(m_tc, e, tele, no_binfo);
}
|
|
|
|
[[ noreturn ]] static void throw_error(char const * msg, expr const & src) { throw_generic_exception(msg, src); }
|
|
[[ noreturn ]] static void throw_error(sstream const & ss, expr const & src) { throw_generic_exception(ss, src); }
|
|
[[ noreturn ]] static void throw_error(expr const & src, pp_fn const & fn) { throw_generic_exception(src, fn); }
|
|
[[ noreturn ]] void throw_error(sstream const & ss) const { throw_generic_exception(ss, m_meta); }
|
|
[[ noreturn ]] void throw_error(expr const & src, sstream const & ss) const { throw_generic_exception(ss, src); }
|
|
|
|
// Reject configurations that are not supported yet: well-founded recursion combined with
// more than one function being defined.
void check_limitations(expr const & eqns) const {
    bool const unsupported = is_wf_equations(eqns) && equations_num_fns(eqns) != 1;
    if (unsupported)
        throw_error("mutually recursive equations do not support well-founded recursion yet", eqns);
}
|
|
|
|
#ifdef LEAN_DEBUG
// Debug helper: assert that no local constant in l1 has the same name as one in l2.
// It always returns true, so it can be used inside lean_assert.
static bool disjoint(list<expr> const & l1, list<expr> const & l2) {
    for (expr const & a : l1) {
        for (expr const & b : l2) {
            lean_assert(mlocal_name(a) != mlocal_name(b));
        }
    }
    return true;
}

// Return true iff every name in s1 is the name of a local constant in s2.
// "none" entries of s1 are ignored.
static bool contained(list<optional<name>> const & s1, list<expr> const & s2) {
    for (optional<name> const & n : s1) {
        if (!n)
            continue;
        bool found = false;
        for (expr const & l : s2) {
            if (mlocal_name(l) == *n) {
                found = true;
                break;
            }
        }
        if (!found)
            return false;
    }
    return true;
}
#endif
|
|
|
|
struct eqn {
|
|
// The local context for an equation is of additional
|
|
// local constants occurring in m_patterns and m_rhs
|
|
// which are not in m_global_context or
|
|
// in the function containing the equation.
|
|
// Remark: each function/program contains its own m_context.
|
|
// So, the variables occurring in m_patterns and m_rhs
|
|
// are in m_global_context, m_context, or m_local_context,
|
|
// or is one of the functions being defined.
|
|
// We say an equation is in "compiled" form
|
|
// if m_local_context and m_patterns are empty.
|
|
list<expr> m_local_context;
|
|
list<expr> m_patterns; // patterns to be processed
|
|
expr m_rhs; // right-hand-side
|
|
eqn(list<expr> const & c, list<expr> const & p, expr const & r):
|
|
m_local_context(c), m_patterns(p), m_rhs(r) {}
|
|
eqn(eqn const & e, list<expr> const & c, list<expr> const & p):
|
|
eqn(c, p, e.m_rhs) {}
|
|
};
|
|
|
|
// Data-structure used to store for compiling pattern matching.
|
|
// We create a program object for each function being defined
|
|
struct program {
|
|
expr m_fn; // function being defined
|
|
list<expr> m_context; // local constants
|
|
list<optional<name>> m_var_stack; // variables that must be matched with the patterns it is a "subset" of m_context.
|
|
list<eqn> m_eqns; // equations
|
|
expr m_type; // result type
|
|
|
|
// Due to dependent pattern matching some elements in m_var_stack are "none", and are skipped
|
|
// during dependent pattern matching.
|
|
|
|
// The goal of the compiler is to process all variables in m_var_stack
|
|
program(expr const & fn, list<expr> const & ctx, list<optional<name>> const & s, list<eqn> const & e, expr const & t):
|
|
m_fn(fn), m_context(ctx), m_var_stack(s), m_eqns(e), m_type(t) {
|
|
lean_assert(contained(m_var_stack, m_context));
|
|
}
|
|
program(program const & p, list<expr> const & ctx, list<optional<name>> const & new_s, list<eqn> const & new_e):
|
|
program(p.m_fn, ctx, new_s, new_e, p.m_type) {}
|
|
program(program const & p, list<optional<name>> const & new_s, list<eqn> const & new_e):
|
|
program(p.m_fn, p.m_context, new_s, new_e, p.m_type) {}
|
|
program(program const & p, list<expr> const & ctx):
|
|
program(p.m_fn, ctx, p.m_var_stack, p.m_eqns, p.m_type) {}
|
|
program(program const & p, list<eqn> const & new_e):
|
|
program(p, p.m_var_stack, new_e) {}
|
|
program() {}
|
|
expr const & get_var(name const & n) const {
|
|
for (expr const & v : m_context) {
|
|
if (mlocal_name(v) == n)
|
|
return v;
|
|
}
|
|
lean_unreachable();
|
|
}
|
|
};
|
|
|
|
// Auxiliary fields for producing error messages
|
|
buffer<program> m_init_prgs;
|
|
unsigned m_prg_idx; // current program index being compiled
|
|
|
|
#ifdef LEAN_DEBUG
|
|
// For debugging purposes: checks whether all local constants occurring in \c e
|
|
// are in local_ctx or m_global_context
|
|
bool check_ctx(expr const & e, list<expr> const & context, list<expr> const & local_context) const {
|
|
for_each(e, [&](expr const & e, unsigned) {
|
|
if (is_local(e)) {
|
|
if (!(contains_local(e, local_context) ||
|
|
contains_local(e, context) ||
|
|
contains_local(e, m_additional_context) ||
|
|
contains_local(e, m_global_context) ||
|
|
contains_local(e, m_fns))) {
|
|
lean_unreachable();
|
|
}
|
|
return false; // do not visit type
|
|
}
|
|
if (is_metavar(e))
|
|
return false; // do not visit type
|
|
return true;
|
|
});
|
|
return true;
|
|
}
|
|
|
|
// For debugging purposes: check if the program is well-formed
|
|
bool check_program(program const & s) const {
|
|
unsigned sz = length(s.m_var_stack);
|
|
lean_assert(contained(s.m_var_stack, s.m_context));
|
|
for (eqn const & e : s.m_eqns) {
|
|
// the number of patterns in each equation is equal to the variable stack size
|
|
if (length(e.m_patterns) != sz) {
|
|
lean_unreachable();
|
|
return false;
|
|
}
|
|
check_ctx(e.m_rhs, s.m_context, e.m_local_context);
|
|
for (expr const & p : e.m_patterns)
|
|
check_ctx(p, s.m_context, e.m_local_context);
|
|
lean_assert(disjoint(e.m_local_context, s.m_context));
|
|
}
|
|
return true;
|
|
}
|
|
#endif
|
|
|
|
// Initialize m_fns (the vector of functions to be compiled)
|
|
void initialize_fns(expr const & eqns) {
|
|
lean_assert(is_equations(eqns));
|
|
unsigned num_fns = equations_num_fns(eqns);
|
|
buffer<expr> eqs;
|
|
to_equations(eqns, eqs);
|
|
expr eq = eqs[0];
|
|
for (unsigned i = 0; i < num_fns; i++) {
|
|
expr fn = mk_local(mk_fresh_name(), binding_name(eq), binding_domain(eq), binder_info());
|
|
m_fns.push_back(fn);
|
|
eq = instantiate(binding_body(eq), fn);
|
|
}
|
|
}
|
|
|
|
// Store in \c arities the number of arguments of each function being defined.
|
|
// This procedure also makes sure that two different equations for the same function
|
|
// contain the same number of arguments in the left-hand-side.
|
|
// Remark: after executing this procedure the arity of m_fns[i] is stored in arities[i]
|
|
// if there is at least one equation for m_fns[i].
|
|
void initialize_arities(expr const & eqns, buffer<optional<unsigned>> & arities) {
|
|
lean_assert(arities.empty());
|
|
buffer<expr> eqs;
|
|
to_equations(eqns, eqs);
|
|
lean_assert(!eqs.empty());
|
|
arities.resize(m_fns.size());
|
|
for (expr eq : eqs) {
|
|
if (is_lambda_equation(eq)) {
|
|
for (expr const & fn : m_fns)
|
|
eq = instantiate(binding_body(eq), fn);
|
|
while (is_lambda(eq))
|
|
eq = binding_body(eq);
|
|
lean_assert(is_equation(eq));
|
|
expr const & lhs = equation_lhs(eq);
|
|
buffer<expr> lhs_args;
|
|
expr const & lhs_fn = get_app_args(lhs, lhs_args);
|
|
if (!is_local(lhs_fn))
|
|
throw_error(sstream() << "invalid equation, "
|
|
<< "left-hand-side is not one of the functions being defined", eq);
|
|
unsigned i = 0;
|
|
for (; i < m_fns.size(); i++) {
|
|
if (lhs_fn == m_fns[i]) {
|
|
if (arities[i] && *arities[i] != lhs_args.size())
|
|
throw_error(sstream() << "invalid equation for '" << lhs_fn << "' "
|
|
<< "left-hand-side of different equations have different number of arguments", eq);
|
|
arities[i] = lhs_args.size();
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Initialize the variable stack for each function that needs
|
|
// to be compiled.
|
|
// This method assumes m_fns has been already initialized.
|
|
// This method also initialized the buffer prg, but the eqns
|
|
// field of each program is not initialized by it.
|
|
//
|
|
// See initialize_arities for an explanation for \c arities.
|
|
void initialize_var_stack(buffer<program> & prgs, buffer<optional<unsigned>> const & arities) {
|
|
lean_assert(!m_fns.empty());
|
|
lean_assert(prgs.empty());
|
|
for (unsigned i = 0; i < m_fns.size(); i++) {
|
|
expr const & fn = m_fns[i];
|
|
buffer<expr> args;
|
|
expr r_type = to_telescope(mlocal_type(fn), args);
|
|
for (expr & arg : args)
|
|
arg = update_mlocal(arg, whnf(mlocal_type(arg)));
|
|
if (arities[i]) {
|
|
unsigned arity = *arities[i];
|
|
if (args.size() > arity) {
|
|
r_type = Pi(args.size() - arity, args.data() + arity, r_type);
|
|
args.shrink(arity);
|
|
}
|
|
}
|
|
list<expr> ctx = to_list(args);
|
|
list<optional<name>> vstack = map2<optional<name>>(ctx, [](expr const & e) {
|
|
return optional<name>(mlocal_name(e));
|
|
});
|
|
prgs.push_back(program(fn, ctx, vstack, list<eqn>(), r_type));
|
|
}
|
|
}
|
|
|
|
struct validate_exception {
|
|
expr m_expr;
|
|
validate_exception(expr const & e):m_expr(e) {}
|
|
};
|
|
|
|
void check_in_local_ctx(expr const & e, buffer<expr> const & local_ctx) {
|
|
if (!contains_local(e, local_ctx))
|
|
throw_error(e, sstream() << "invalid equation, variable '" << e
|
|
<< "' has the same name of a variable in an outer-scope (solution: rename this variable)");
|
|
}
|
|
|
|
// Validate/normalize the given pattern.
|
|
// It stores in reachable_vars any variable that does not occur
|
|
// in inaccessible terms.
|
|
expr validate_pattern(expr pat, buffer<expr> const & local_ctx, name_set & reachable_vars) {
|
|
if (is_inaccessible(pat))
|
|
return pat;
|
|
if (is_local(pat)) {
|
|
reachable_vars.insert(mlocal_name(pat));
|
|
check_in_local_ctx(pat, local_ctx);
|
|
return pat;
|
|
}
|
|
expr new_pat = whnf(pat);
|
|
if (is_local(new_pat)) {
|
|
reachable_vars.insert(mlocal_name(new_pat));
|
|
check_in_local_ctx(new_pat, local_ctx);
|
|
return new_pat;
|
|
}
|
|
buffer<expr> pat_args;
|
|
expr const & fn = get_app_args(new_pat, pat_args);
|
|
if (auto in = is_constructor(fn)) {
|
|
unsigned num_params = *inductive::get_num_params(env(), *in);
|
|
for (unsigned i = num_params; i < pat_args.size(); i++)
|
|
pat_args[i] = validate_pattern(pat_args[i], local_ctx, reachable_vars);
|
|
return mk_app(fn, pat_args, pat.get_tag());
|
|
} else {
|
|
throw validate_exception(pat);
|
|
}
|
|
}
|
|
|
|
// Validate/normalize the patterns associated with the given lhs.
|
|
// The lhs is only used to report errors.
|
|
// It stores in reachable_vars any variable that does not occur
|
|
// in inaccessible terms.
|
|
void validate_patterns(expr const & lhs, buffer<expr> const & local_ctx, buffer<expr> & patterns, name_set & reachable_vars) {
|
|
for (expr & pat : patterns) {
|
|
try {
|
|
pat = validate_pattern(pat, local_ctx, reachable_vars);
|
|
} catch (validate_exception & ex) {
|
|
expr problem_expr = ex.m_expr;
|
|
throw_error(lhs, [=](formatter const & fmt) {
|
|
format r("invalid argument, it is not a constructor, variable, "
|
|
"nor it is marked as an inaccessible pattern");
|
|
r += pp_indent_expr(fmt, problem_expr);
|
|
r += compose(line(), format("in the following equation left-hand-side"));
|
|
r += pp_indent_expr(fmt, lhs);
|
|
return r;
|
|
});
|
|
}
|
|
}
|
|
}
|
|
|
|
// Create initial program state for each function being defined.
|
|
void initialize(expr const & eqns, buffer<program> & prg) {
|
|
lean_assert(is_equations(eqns));
|
|
buffer<optional<unsigned>> arities;
|
|
initialize_fns(eqns);
|
|
initialize_arities(eqns, arities);
|
|
initialize_var_stack(prg, arities);
|
|
buffer<expr> eqs;
|
|
to_equations(eqns, eqs);
|
|
buffer<buffer<eqn>> res_eqns;
|
|
res_eqns.resize(m_fns.size());
|
|
for (expr eq : eqs) {
|
|
if (is_lambda_no_equation(eq))
|
|
continue; // skip marker
|
|
for (expr const & fn : m_fns)
|
|
eq = instantiate(binding_body(eq), fn);
|
|
buffer<expr> local_ctx;
|
|
eq = fun_to_telescope(eq, local_ctx);
|
|
expr const & lhs = equation_lhs(eq);
|
|
expr const & rhs = equation_rhs(eq);
|
|
buffer<expr> patterns;
|
|
expr const & fn = get_app_args(lhs, patterns);
|
|
name_set reachable_vars;
|
|
validate_patterns(lhs, local_ctx, patterns, reachable_vars);
|
|
for (expr const & v : local_ctx) {
|
|
// every variable in the local_ctx must be "reachable".
|
|
if (!reachable_vars.contains(mlocal_name(v))) {
|
|
throw_error(lhs, [=](formatter const & fmt) {
|
|
options o = fmt.get_options().update_if_undef(get_pp_implicit_name(), true);
|
|
formatter new_fmt = fmt.update_options(o);
|
|
format r("invalid equation left-hand-side, variable '");
|
|
r += format(local_pp_name(v));
|
|
r += format("' only occurs in inaccessible terms in the following equation left-hand-side");
|
|
r += pp_indent_expr(new_fmt, lhs);
|
|
return r;
|
|
});
|
|
}
|
|
}
|
|
for (unsigned i = 0; i < m_fns.size(); i++) {
|
|
if (mlocal_name(fn) == mlocal_name(m_fns[i])) {
|
|
if (patterns.size() != length(prg[i].m_var_stack))
|
|
throw_error("ill-formed equation, number of provided arguments does not match function type", eq);
|
|
res_eqns[i].push_back(eqn(to_list(local_ctx), to_list(patterns), rhs));
|
|
}
|
|
}
|
|
}
|
|
for (unsigned i = 0; i < m_fns.size(); i++) {
|
|
prg[i].m_eqns = to_list(res_eqns[i]);
|
|
lean_assert(check_program(prg[i]));
|
|
}
|
|
}
|
|
|
|
// For debugging purposes: display the context at m_ios
|
|
template<typename Ctx>
|
|
void display_ctx(Ctx const & ctx) const {
|
|
bool first = true;
|
|
for (expr const & e : ctx) {
|
|
out() << (first ? "" : ", ") << local_pp_name(e) << " : " << mlocal_type(e);
|
|
first = false;
|
|
}
|
|
}
|
|
|
|
// For debugging purposes: dump prg in m_ios
|
|
void display(program const & prg) const {
|
|
display_ctx(prg.m_context);
|
|
out() << " ;;";
|
|
for (optional<name> const & v : prg.m_var_stack) {
|
|
if (v)
|
|
out() << " " << local_pp_name(prg.get_var(*v));
|
|
else
|
|
out() << " <none>";
|
|
}
|
|
out() << " |- " << prg.m_type << "\n";
|
|
out() << "\n";
|
|
for (eqn const & e : prg.m_eqns) {
|
|
out() << "> ";
|
|
display_ctx(e.m_local_context);
|
|
out() << " |-";
|
|
for (expr const & p : e.m_patterns) {
|
|
if (is_atomic(p))
|
|
out() << " " << p;
|
|
else
|
|
out() << " (" << p << ")";
|
|
}
|
|
out() << " := " << e.m_rhs << "\n";
|
|
}
|
|
}
|
|
|
|
// Return true iff the next pattern in all equations is a variable or an inaccessible term
|
|
bool is_variable_transition(program const & p) const {
|
|
for (eqn const & e : p.m_eqns) {
|
|
lean_assert(e.m_patterns);
|
|
if (!is_local(head(e.m_patterns)) && !is_inaccessible(head(e.m_patterns)))
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
// Return true iff the next pattern in all equations is a constructor.
|
|
bool is_constructor_transition(program const & p) const {
|
|
for (eqn const & e : p.m_eqns) {
|
|
lean_assert(e.m_patterns);
|
|
if (!is_constructor(get_app_fn(head(e.m_patterns))))
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
/** Return true if there are no equations, and the next variable is an inductive datatype.
|
|
In this case, it is worth trying the cases tactic, since this may be a conflicting state. */
|
|
bool is_no_equation_constructor_transition(program const & p) {
|
|
lean_assert(p.m_var_stack);
|
|
if (!p.m_eqns && head(p.m_var_stack)) {
|
|
expr const & x = p.get_var(*head(p.m_var_stack));
|
|
expr const & I = get_app_fn(mlocal_type(x));
|
|
return is_constant(I) && inductive::is_inductive_decl(env(), const_name(I));
|
|
} else {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// Return true iff the next pattern of every equation is a constructor or variable,
|
|
// and there are at least one equation where it is a variable and another where it is a
|
|
// constructor.
|
|
bool is_complete_transition(program const & p) const {
|
|
bool has_variable = false;
|
|
bool has_constructor = false;
|
|
for (eqn const & e : p.m_eqns) {
|
|
lean_assert(e.m_patterns);
|
|
expr const & p = head(e.m_patterns);
|
|
if (is_local(p))
|
|
has_variable = true;
|
|
else if (is_constructor(get_app_fn(p)))
|
|
has_constructor = true;
|
|
else
|
|
return false;
|
|
}
|
|
return has_variable && has_constructor;
|
|
}
|
|
|
|
// Remove variable from local context
|
|
static list<expr> remove(list<expr> const & local_ctx, expr const & l) {
|
|
if (!local_ctx)
|
|
return local_ctx;
|
|
else if (mlocal_name(head(local_ctx)) == mlocal_name(l))
|
|
return tail(local_ctx);
|
|
else
|
|
return cons(head(local_ctx), remove(tail(local_ctx), l));
|
|
}
|
|
|
|
// Replace, in e, the local constants of from_buffer with the corresponding entries of to_buffer
// (abstract them, then instantiate the resulting variables).
static expr replace(expr const & e, buffer<expr> const & from_buffer, buffer<expr> const & to_buffer) {
    lean_assert(from_buffer.size() == to_buffer.size());
    expr const abst = abstract_locals(e, from_buffer.size(), from_buffer.data());
    return instantiate_rev(abst, to_buffer.size(), to_buffer.data());
}
|
|
|
|
// Replace from with to in the local context, patterns and right-hand-side of e.
// Whenever a local of the context is changed by the replacement, later occurrences of the old
// local are redirected to its updated version as well.
eqn replace(eqn const & e, expr const & from, expr const & to) {
    buffer<expr> from_buffer;
    buffer<expr> to_buffer;
    from_buffer.push_back(from);
    to_buffer.push_back(to);
    buffer<expr> new_ctx;
    for (expr const & local : e.m_local_context) {
        expr updated = replace(local, from_buffer, to_buffer);
        if (updated != local) {
            from_buffer.push_back(local);
            to_buffer.push_back(updated);
        }
        new_ctx.push_back(updated);
    }
    auto subst = [&](expr const & t) { return replace(t, from_buffer, to_buffer); };
    list<expr> new_patterns = map(e.m_patterns, subst);
    expr new_rhs = subst(e.m_rhs);
    return eqn(to_list(new_ctx), new_patterns, new_rhs);
}
|
|
|
|
// The top of the variable stack is "none": pop it and drop the first pattern of every equation.
expr compile_skip(program const & prg) {
    lean_assert(!head(prg.m_var_stack));
    buffer<eqn> new_eqs;
    for (eqn const & e : prg.m_eqns)
        new_eqs.emplace_back(e.m_local_context, tail(e.m_patterns), e.m_rhs);
    return compile_core(program(prg, tail(prg.m_var_stack), to_list(new_eqs)));
}
|
|
|
|
// The next pattern of every equation is a variable or an inaccessible term.
// Variables are renamed to the variable on top of the stack; inaccessible terms are ignored.
expr compile_variable(program const & prg) {
    expr x = prg.get_var(*head(prg.m_var_stack));
    buffer<eqn> new_eqs;
    for (eqn const & e : prg.m_eqns) {
        expr const pat = head(e.m_patterns);
        list<expr> rest = tail(e.m_patterns);
        if (is_inaccessible(pat)) {
            new_eqs.emplace_back(e.m_local_context, rest, e.m_rhs);
            continue;
        }
        lean_assert(is_local(pat));
        if (!contains_local(pat, e.m_local_context)) {
            new_eqs.emplace_back(eqn(e, e.m_local_context, rest));
            continue;
        }
        // pat is an equation variable: drop it from the local context and rename it to x
        list<expr> new_local_ctx = remove(e.m_local_context, pat);
        new_eqs.push_back(replace(eqn(e, new_local_ctx, rest), pat, x));
    }
    return compile_core(program(prg, tail(prg.m_var_stack), to_list(new_eqs)));
}
|
|
|
|
class implementation : public inversion::implementation {
|
|
eqn m_eqn;
|
|
public:
|
|
implementation(eqn const & e):m_eqn(e) {}
|
|
|
|
eqn const & get_eqn() const { return m_eqn; }
|
|
|
|
virtual name const & get_constructor_name() const {
|
|
return const_name(get_app_fn(head(m_eqn.m_patterns)));
|
|
}
|
|
|
|
virtual void update_exprs(std::function<expr(expr const &)> const & fn) {
|
|
m_eqn.m_local_context = map(m_eqn.m_local_context, fn);
|
|
m_eqn.m_patterns = map(m_eqn.m_patterns, fn);
|
|
m_eqn.m_rhs = fn(m_eqn.m_rhs);
|
|
}
|
|
};
|
|
|
|
// Wrap the equations from \c p as an "implementation_list" for the inversion package.
|
|
inversion::implementation_list to_implementation_list(program const & p) {
|
|
return map2<inversion::implementation_ptr>(p.m_eqns, [&](eqn const & e) {
|
|
return std::shared_ptr<inversion::implementation>(new implementation(e));
|
|
});
|
|
}
|
|
|
|
// Convert program into a goal. We need that to be able to invoke the inversion package.
|
|
old_goal to_goal(program const & p) {
|
|
buffer<expr> hyps;
|
|
to_buffer(p.m_context, hyps);
|
|
expr new_type = p.m_type;
|
|
expr new_meta = mk_app(mk_metavar(mk_fresh_name(), Pi(hyps, new_type)), hyps);
|
|
return old_goal(new_meta, new_type);
|
|
}
|
|
|
|
// Convert goal and implementation_list back into a program.
|
|
// - nvars is the number of new variables in the variable stack.
|
|
program to_program(expr const & fn, old_goal const & g, unsigned nvars, list<optional<name>> const & new_var_stack,
|
|
inversion::implementation_list const & imps) {
|
|
buffer<expr> new_context;
|
|
g.get_hyps(new_context);
|
|
expr new_type = g.get_type();
|
|
buffer<eqn> new_equations;
|
|
for (inversion::implementation_ptr const & imp : imps) {
|
|
eqn e = static_cast<implementation*>(imp.get())->get_eqn();
|
|
buffer<expr> pat_args;
|
|
get_app_args(head(e.m_patterns), pat_args);
|
|
lean_assert(pat_args.size() >= nvars);
|
|
list<expr> new_pats = to_list(pat_args.end() - nvars, pat_args.end(), tail(e.m_patterns));
|
|
new_equations.push_back(eqn(e.m_local_context, new_pats, e.m_rhs));
|
|
}
|
|
return program(fn, to_list(new_context), new_var_stack, to_list(new_equations), new_type);
|
|
}
|
|
|
|
/** \brief Compile constructor transition.
|
|
\remark if fail_if_subgoals is true, then it returns none if there are subgoals.
|
|
*/
|
|
optional<expr> compile_constructor_core(program const & p, bool fail_if_subgoals) {
|
|
expr h = p.get_var(*head(p.m_var_stack));
|
|
old_goal g = to_goal(p);
|
|
auto imps = to_implementation_list(p);
|
|
bool clear_elim = false;
|
|
if (auto r = apply(env(), ios(), m_tc, g, h, imps, clear_elim)) {
|
|
substitution subst = r->m_subst;
|
|
list<list<expr>> args = r->m_args;
|
|
list<rename_map> rn_maps = r->m_renames;
|
|
list<inversion::implementation_list> imps_list = r->m_implementation_lists;
|
|
if (fail_if_subgoals && r->m_goals)
|
|
return none_expr();
|
|
for (old_goal const & new_g : r->m_goals) {
|
|
list<optional<name>> new_vars = map2<optional<name>>(head(args),
|
|
[](expr const & a) {
|
|
if (is_local(a))
|
|
return optional<name>(mlocal_name(a));
|
|
else
|
|
return optional<name>();
|
|
});
|
|
rename_map const & rn = head(rn_maps);
|
|
list<optional<name>> new_var_stack = map(tail(p.m_var_stack),
|
|
[&](optional<name> const & n) -> optional<name> {
|
|
if (n)
|
|
return optional<name>(rn.find(*n));
|
|
else
|
|
return n;
|
|
});
|
|
list<optional<name>> new_case_stack = append(new_vars, new_var_stack);
|
|
program new_p = to_program(p.m_fn, new_g, length(new_vars), new_case_stack, head(imps_list));
|
|
args = tail(args);
|
|
imps_list = tail(imps_list);
|
|
rn_maps = tail(rn_maps);
|
|
expr t = compile_core(new_p);
|
|
assign(subst, new_g, t);
|
|
}
|
|
expr t = subst.instantiate_all(g.get_meta());
|
|
return some_expr(t);
|
|
} else {
|
|
throw_error(sstream() << "pattern matching failed");
|
|
}
|
|
}
|
|
|
|
/** \brief Constructor transition. Subgoals are allowed, so compile_constructor_core
    always produces a term (or throws). */
expr compile_constructor(program const & p) {
    optional<expr> r = compile_constructor_core(p, /* fail_if_subgoals */ false);
    lean_assert(r);
    return *r;
}
|
|
|
|
/** \brief Try a constructor (case split) transition on the head variable; when that yields none,
    fall back to the variable transition. */
expr compile_no_equations(program const & p) {
    lean_assert(head(p.m_var_stack));
    expr const & x = p.get_var(*head(p.m_var_stack));
    expr const & I = get_app_fn(mlocal_type(x));
    lean_assert(is_constant(I) && inductive::is_inductive_decl(env(), const_name(I)));
    // For recursive datatypes we fail if subgoals are generated. Reason: avoid non-termination.
    optional<expr> r = compile_constructor_core(p, is_recursive_datatype(env(), const_name(I)));
    return r ? *r : compile_variable(p);
}
|
|
|
|
/** \brief Return constructor \c n (universe levels \c ls) applied to \c params and to the
    remaining arguments computed by to_telescope_ext, which are stored in \c args. */
expr mk_constructor(name const & n, levels const & ls, buffer<expr> const & params, buffer<expr> & args) {
    expr partial   = mk_app(mk_constant(n, ls), params);
    expr ctor_type = infer_type(partial);
    to_telescope_ext(ctor_type, args);
    return mk_app(partial, args);
}
|
|
|
|
/** \brief Complete transition: the next pattern of every equation is a constructor or variable.
    Each equation whose next pattern is a variable is split into one equation per constructor of
    the variable's inductive type. That is, we are reducing this case to the compile_constructor case. */
expr compile_complete(program const & prg) {
    // Data about the inductive type of the variable on the top of the stack. It does not depend on
    // the equation being split, so it is computed once (on first need) instead of once per equation.
    bool         ind_info_ready = false;
    levels       I_ls;
    buffer<expr> I_params;
    buffer<name> constructor_names;

    buffer<eqn> new_eqns;
    for (eqn const & e : prg.m_eqns) {
        expr const & p = head(e.m_patterns);
        if (!is_local(p)) {
            new_eqns.push_back(e);
            continue;
        }
        if (!ind_info_ready) {
            expr x      = prg.get_var(*head(prg.m_var_stack));
            expr x_type = whnf(mlocal_type(x));
            buffer<expr> I_args;
            expr const & I      = get_app_args(x_type, I_args);
            name const & I_name = const_name(I);
            I_ls = const_levels(I);
            unsigned nparams = *inductive::get_num_params(env(), I_name);
            I_params.append(nparams, I_args.data());
            get_intro_rule_names(env(), I_name, constructor_names);
            ind_info_ready = true;
        }
        list<expr> rest_ctx = remove(e.m_local_context, p);
        for (name const & c_name : constructor_names) {
            // mk_constructor is still invoked per equation, so every split equation gets the
            // argument locals produced by its own to_telescope_ext call
            buffer<expr> new_args;
            expr c = mk_constructor(c_name, I_ls, I_params, new_args);
            list<expr> new_ctx      = to_list(new_args.begin(), new_args.end(), rest_ctx);
            list<expr> new_patterns = cons(c, tail(e.m_patterns));
            new_eqns.push_back(replace(eqn(e, new_ctx, new_patterns), p, c));
        }
    }
    return compile_core(program(prg, to_list(new_eqns)));
}
|
|
|
|
[[ noreturn ]] void throw_non_exhaustive() {
|
|
program prg = m_init_prgs[m_prg_idx];
|
|
throw_error(m_meta, [=](formatter const & _fmt) {
|
|
options opts = _fmt.get_options().update_if_undef(get_pp_implicit_name(), true);
|
|
opts = opts.update_if_undef(get_pp_purify_locals_name(), false);
|
|
formatter fmt = _fmt.update_options(opts);
|
|
format r = format("invalid non-exhaustive set of equations, "
|
|
"left-hand-side(s) after elaboration:");
|
|
for (eqn const & e : prg.m_eqns) {
|
|
expr lhs = prg.m_fn;
|
|
for (expr const & p : e.m_patterns) lhs = mk_app(lhs, p);
|
|
r += pp_indent_expr(fmt, lhs);
|
|
r += line();
|
|
}
|
|
return r;
|
|
});
|
|
}
|
|
|
|
/** \brief Main dispatcher of the pattern-matching compiler: select the transition that applies
    to program \c p. With an empty variable stack, the first remaining equation's right-hand-side
    is the result; with no equations left, the equations are reported as non-exhaustive. */
expr compile_core(program const & p) {
    lean_assert(check_program(p));
    if (!p.m_var_stack) {
        // variable stack is empty
        if (!p.m_eqns)
            throw_non_exhaustive();
        expr rhs = head(p.m_eqns).m_rhs;
        lean_assert(is_def_eq(infer_type(rhs), p.m_type));
        return rhs;
    }
    if (!head(p.m_var_stack))
        return compile_skip(p);
    if (is_no_equation_constructor_transition(p))
        return compile_no_equations(p);
    if (is_variable_transition(p))
        return compile_variable(p);
    if (is_constructor_transition(p))
        return compile_constructor(p);
    if (is_complete_transition(p))
        return compile_complete(p);
    // In some equations the next pattern is an inaccessible term,
    // and in others it is a constructor.
    throw_error(sstream() << "invalid equations for '" << local_pp_name(p.m_fn)
                << "', inconsistent use of inaccessible term annotation, "
                << "in some equations a pattern is a constructor, and in another it is an inaccessible term");
}
|
|
|
|
/** \brief Compile program \c p and abstract its context variables.
    \c i is the index of the program; it is only recorded in m_prg_idx for error messages. */
expr compile_pat_match(program const & p, unsigned i) {
    flet<unsigned> set(m_prg_idx, i); // we only use m_prg_idx for producing error messages
    buffer<expr> vars;
    to_buffer(p.m_context, vars);
    if (is_proof_irrelevant())
        return Fun(vars, compile_core(p));
    // The proof relevant version uses the class-instance resolution, and must be able to "see"
    // the complete context, so the global context is included.
    return Fun(vars, compile_core(program(p, append(to_list(m_global_context), p.m_context))));
}
|
|
|
|
/** \brief Return true iff \c e is one of the functions being defined */
|
|
bool is_fn(expr const & e) const {
|
|
return is_local(e) && contains_local(e, m_fns);
|
|
}
|
|
|
|
/** \brief Return true iff the equations are recursive. */
|
|
bool is_recursive(buffer<program> const & prgs) const {
|
|
lean_assert(!prgs.empty());
|
|
for (program const & p : prgs) {
|
|
for (eqn const & e : p.m_eqns) {
|
|
if (find(e.m_rhs, [&](expr const & e, unsigned) { return is_fn(e); }))
|
|
return true;
|
|
}
|
|
}
|
|
if (prgs.size() > 1)
|
|
throw_error(sstream() << "mutual recursion is not needed when defining non-recursive functions");
|
|
return false;
|
|
}
|
|
|
|
/** \brief Return true if all locals are distinct local constants. */
|
|
static bool all_distinct_locals(unsigned num, expr const * locals) {
|
|
for (unsigned i = 0; i < num; i++) {
|
|
if (!is_local(locals[i]))
|
|
return false;
|
|
if (contains_local(locals[i], locals, locals + i))
|
|
return false;
|
|
}
|
|
return true;
|
|
}
|
|
|
|
/** \brief Return true iff \c t is an inductive datatype (I A j) which constains an associated brec_on definition,
|
|
and all indices of t are in ctx. */
|
|
bool is_rec_inductive(list<expr> const & ctx, expr const & t) const {
|
|
expr const & I = get_app_fn(t);
|
|
if (is_constant(I) && env().find(name{const_name(I), "brec_on"})) {
|
|
unsigned nindices = *inductive::get_num_indices(env(), const_name(I));
|
|
if (nindices > 0) {
|
|
buffer<expr> args;
|
|
get_app_args(t, args);
|
|
lean_assert(args.size() >= nindices);
|
|
return
|
|
all_distinct_locals(nindices, args.end() - nindices) &&
|
|
std::all_of(args.end() - nindices, args.end(),
|
|
[&](expr const & idx) { return contains_local(idx, ctx); });
|
|
} else {
|
|
return true;
|
|
}
|
|
} else {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
/** \brief Return true iff t1 and t2 are inductive datatypes of the same mutually inductive declaration. */
|
|
bool is_compatible_inductive(expr const & t1, expr const & t2) {
|
|
buffer<expr> args1, args2;
|
|
name const & I1 = const_name(get_app_args(t1, args1));
|
|
name const & I2 = const_name(get_app_args(t2, args2));
|
|
inductive::inductive_decls decls = *inductive::is_inductive_decl(env(), I1);
|
|
unsigned nparams = std::get<1>(decls);
|
|
for (auto decl : std::get<2>(decls)) {
|
|
if (inductive::inductive_decl_name(decl) == I2) {
|
|
// parameters must be definitionally equal
|
|
unsigned i = 0;
|
|
for (; i < nparams; i++) {
|
|
if (!is_def_eq(args1[i], args2[i]))
|
|
break;
|
|
}
|
|
if (i == nparams)
|
|
return true;
|
|
}
|
|
}
|
|
return false;
|
|
}
|
|
|
|
/** \brief Return true iff \c t1 and \c t2 are instances of the same inductive datatype */
|
|
static bool is_same_inductive(expr const & t1, expr const & t2) {
|
|
return const_name(get_app_fn(t1)) == const_name(get_app_fn(t2));
|
|
}
|
|
|
|
/** \brief Return true iff \c s is structurally smaller than \c t OR equal to \c t */
|
|
bool is_le(expr const & s, expr const & t) {
|
|
return is_def_eq(s, t) || is_lt(s, t);
|
|
}
|
|
|
|
/** Return true iff \c s is structurally smaller than \c t */
|
|
bool is_lt(expr s, expr const & t) {
|
|
s = whnf(s);
|
|
if (is_app(s)) {
|
|
expr const & s_fn = get_app_fn(s);
|
|
if (!is_constructor(s_fn))
|
|
return is_lt(s_fn, t); // f < t ==> s := f a_1 ... a_n < t
|
|
}
|
|
buffer<expr> t_args;
|
|
expr const & t_fn = get_app_args(t, t_args);
|
|
if (!is_constructor(t_fn))
|
|
return false;
|
|
return std::any_of(t_args.begin(), t_args.end(), [&](expr const & t_arg) { return is_le(s, t_arg); });
|
|
}
|
|
|
|
/** \brief Auxiliary functional object for checking whether recursive application are structurally smaller or not */
|
|
struct check_rhs_fn {
|
|
equation_compiler_fn & m_main;
|
|
buffer<program> const & m_prgs;
|
|
buffer<unsigned> const & m_arg_pos;
|
|
|
|
check_rhs_fn(equation_compiler_fn & m, buffer<program> const & prgs, buffer<unsigned> const & arg_pos):
|
|
m_main(m), m_prgs(prgs), m_arg_pos(arg_pos) {}
|
|
|
|
/** \brief Return true iff all recursive applications in \c e are structurally smaller than \c arg. */
|
|
bool check_rhs(expr const & e, expr const & arg) const {
|
|
switch (e.kind()) {
|
|
case expr_kind::Var: case expr_kind::Meta:
|
|
case expr_kind::Local: case expr_kind::Constant:
|
|
case expr_kind::Sort:
|
|
return true;
|
|
case expr_kind::Macro:
|
|
for (unsigned i = 0; i < macro_num_args(e); i++)
|
|
if (!check_rhs(macro_arg(e, i), arg))
|
|
return false;
|
|
return true;
|
|
case expr_kind::App: {
|
|
buffer<expr> args;
|
|
expr const & fn = get_app_args(e, args);
|
|
if (!check_rhs(fn, arg))
|
|
return false;
|
|
for (unsigned i = 0; i < args.size(); i++)
|
|
if (!check_rhs(args[i], arg))
|
|
return false;
|
|
if (is_local(fn)) {
|
|
for (unsigned j = 0; j < m_prgs.size(); j++) {
|
|
if (mlocal_name(fn) == mlocal_name(m_prgs[j].m_fn)) {
|
|
// it is a recusive application
|
|
unsigned pos_j = m_arg_pos[j];
|
|
if (pos_j < args.size()) {
|
|
expr const & arg_j = args[pos_j];
|
|
// arg_j must be structurally smaller than arg
|
|
if (!m_main.is_lt(arg_j, arg))
|
|
return false;
|
|
} else {
|
|
return false;
|
|
}
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
case expr_kind::Let:
|
|
// TODO(Leo): improve
|
|
return check_rhs(instantiate(let_body(e), let_value(e)), arg);
|
|
case expr_kind::Lambda:
|
|
case expr_kind::Pi:
|
|
if (!check_rhs(binding_domain(e), arg)) {
|
|
return false;
|
|
} else {
|
|
expr l = mk_local(mk_fresh_name(), binding_name(e), binding_domain(e), binding_info(e));
|
|
return check_rhs(instantiate(binding_body(e), l), arg);
|
|
}
|
|
}
|
|
lean_unreachable();
|
|
}
|
|
|
|
bool operator()(expr const & e, expr const & arg) const {
|
|
return check_rhs(e, arg);
|
|
}
|
|
};
|
|
|
|
/** \brief Return true iff the recursive equations in prgs are "admissible" with respect to
|
|
the following configuration of recursive arguments.
|
|
We say the equations are admissible when every recursive application of prgs[i]
|
|
is structurally smaller at arguments arg_pos[i].
|
|
*/
|
|
bool check_rec_args(buffer<program> const & prgs, buffer<unsigned> const & arg_pos) {
|
|
lean_assert(prgs.size() == arg_pos.size());
|
|
check_rhs_fn check_rhs(*this, prgs, arg_pos);
|
|
for (unsigned i = 0; i < prgs.size(); i++) {
|
|
program const & prg = prgs[i];
|
|
unsigned pos_i = arg_pos[i];
|
|
for (eqn const & e : prg.m_eqns) {
|
|
expr const & p_i = get_ith(e.m_patterns, pos_i);
|
|
if (!check_rhs(e.m_rhs, p_i))
|
|
return false;
|
|
}
|
|
}
|
|
return true;
|
|
}
|
|
|
|
/** \brief Backtracking search: choose a recursive argument (a position in the variable stack) for
    each of prgs[i], prgs[i+1], ... and succeed iff the complete choice passes check_rec_args.
    \c arg_pos / \c arg_types hold the positions / types already chosen for prgs[0..i-1]; on success
    they contain the choice for all programs, on failure they are restored. */
bool find_rec_args(buffer<program> const & prgs, unsigned i, buffer<unsigned> & arg_pos, buffer<expr> & arg_types) {
    if (i == prgs.size())
        return check_rec_args(prgs, arg_pos);
    program const & p = prgs[i];
    auto is_candidate = [&](expr const & t) {
        return
            // argument must be an inductive datatype
            is_rec_inductive(p.m_context, t) &&
            // argument must be an inductive datatype different from the ones in arg_types
            std::none_of(arg_types.begin(), arg_types.end(),
                         [&](expr const & prev_type) { return is_same_inductive(t, prev_type); }) &&
            // argument type must be in the same mutually recursive declaration of previous argument types
            (arg_types.empty() || is_compatible_inductive(t, arg_types[0]));
    };
    unsigned pos = 0;
    for (optional<name> const & n : p.m_var_stack) {
        lean_assert(n);
        expr const & v = p.get_var(*n);
        expr const & t = mlocal_type(v);
        if (is_candidate(t)) {
            arg_pos.push_back(pos);
            arg_types.push_back(t);
            if (find_rec_args(prgs, i+1, arg_pos, arg_types))
                return true;
            arg_pos.pop_back();
            arg_types.pop_back();
        }
        pos++;
    }
    return false;
}
|
|
|
|
/** \brief Find, for every program in \c prgs, the position of an argument on which all recursive
    applications are structurally decreasing. On success the positions are stored in \c arg_pos. */
bool find_rec_args(buffer<program> const & prgs, buffer<unsigned> & arg_pos) {
    buffer<expr> selected_types;
    return find_rec_args(prgs, 0, arg_pos, selected_types);
}
|
|
|
|
// Auxiliary function object used to eliminate recursive applications using "below" applications
|
|
struct elim_rec_apps_fn {
|
|
equation_compiler_fn & m_main;
|
|
buffer<program> const & m_prgs;
|
|
unsigned m_nparams;
|
|
buffer<expr> const & m_below_cnsts; // below constants
|
|
buffer<expr> const & m_Cs_locals; // auxiliary local constants representing the "motives"
|
|
buffer<unsigned> const & m_rec_arg_pos; // position of recursive argument for each program
|
|
buffer<buffer<unsigned>> const & m_rest_pos; // position of remaining arguments for each program
|
|
|
|
elim_rec_apps_fn(equation_compiler_fn & m, buffer<program> const & prgs, unsigned nparams,
|
|
buffer<expr> const & below_cnsts, buffer<expr> const & Cs_locals,
|
|
buffer<unsigned> const & rec_arg_pos, buffer<buffer<unsigned>> const & rest_pos):
|
|
m_main(m), m_prgs(prgs), m_nparams(nparams), m_below_cnsts(below_cnsts), m_Cs_locals(Cs_locals),
|
|
m_rec_arg_pos(rec_arg_pos), m_rest_pos(rest_pos) {}
|
|
|
|
bool is_below_type(expr const & t) const {
|
|
expr const & fn = get_app_fn(t);
|
|
return is_constant(fn) && std::find(m_below_cnsts.begin(), m_below_cnsts.end(), fn) != m_below_cnsts.end();
|
|
}
|
|
|
|
/** \brief Return the number of arguments in the left-hand-side of program prg_idx */
|
|
unsigned get_lhs_size(unsigned prg_idx) const { return length(m_prgs[prg_idx].m_context); }
|
|
|
|
expr whnf(expr const & e) {
|
|
return m_main.m_tc.whnf(e).first;
|
|
}
|
|
|
|
bool is_def_eq(expr const & a, expr const & b) {
|
|
return m_main.m_tc.is_def_eq(a, b).first;
|
|
}
|
|
|
|
/** \brief Retrieve \c a from the below dictionary \c d. \c d is a term made of products, and C's from (m_Cs_locals).
|
|
\c b is the below constant that was used to create the below dictionary \c d.
|
|
*/
|
|
optional<expr> to_below(expr const & d, expr const & a, expr const & b) {
|
|
expr const & fn = get_app_fn(d);
|
|
if (is_constant(fn) && const_name(fn) == get_prod_name()) {
|
|
expr d_arg1 = whnf(app_arg(app_fn(d)));
|
|
expr d_arg2 = whnf(app_arg(d));
|
|
if (auto r = to_below(d_arg1, a, mk_pr1(m_main.m_tc, b)))
|
|
return r;
|
|
else if (auto r = to_below(d_arg2, a, mk_pr2(m_main.m_tc, b)))
|
|
return r;
|
|
else
|
|
return none_expr();
|
|
} else if (is_constant(fn) && const_name(fn) == get_and_name()) {
|
|
// For ibelow, we use "and" instead of products
|
|
expr d_arg1 = whnf(app_arg(app_fn(d)));
|
|
expr d_arg2 = whnf(app_arg(d));
|
|
if (auto r = to_below(d_arg1, a, mk_and_elim_left(m_main.m_tc, b)))
|
|
return r;
|
|
else if (auto r = to_below(d_arg2, a, mk_and_elim_right(m_main.m_tc, b)))
|
|
return r;
|
|
else
|
|
return none_expr();
|
|
} else if (is_local(fn)) {
|
|
for (expr const & C : m_Cs_locals) {
|
|
if (mlocal_name(C) == mlocal_name(fn) && is_def_eq(app_arg(d), a))
|
|
return some_expr(b);
|
|
}
|
|
return none_expr();
|
|
} else if (is_pi(d)) {
|
|
if (is_app(a)) {
|
|
expr new_d = whnf(instantiate(binding_body(d), app_arg(a)));
|
|
return to_below(new_d, a, mk_app(b, app_arg(a)));
|
|
} else {
|
|
return none_expr();
|
|
}
|
|
} else {
|
|
return none_expr();
|
|
}
|
|
}
|
|
|
|
expr elim(unsigned prg_idx, buffer<expr> const & args, expr const & below, tag g) {
|
|
// Replace motives with abstract ones. We use the abstract motives (m_Cs_locals) as "markers"
|
|
buffer<expr> below_args;
|
|
expr const & below_cnst = get_app_args(mlocal_type(below), below_args);
|
|
buffer<expr> abst_below_args;
|
|
abst_below_args.append(m_nparams, below_args.data());
|
|
abst_below_args.append(m_Cs_locals);
|
|
for (unsigned i = m_nparams + m_Cs_locals.size(); i < below_args.size(); i++)
|
|
abst_below_args.push_back(below_args[i]);
|
|
expr abst_below = mk_app(below_cnst, abst_below_args);
|
|
expr below_dict = whnf(abst_below);
|
|
expr rec_arg = whnf(args[m_rec_arg_pos[prg_idx]]);
|
|
unsigned lhs_size = get_lhs_size(prg_idx);
|
|
if (optional<expr> b = to_below(below_dict, rec_arg, below)) {
|
|
expr r = *b;
|
|
for (unsigned rest_pos : m_rest_pos[prg_idx]) {
|
|
if (rest_pos < args.size())
|
|
r = mk_app(r, args[rest_pos], g);
|
|
}
|
|
for (unsigned i = lhs_size; i < args.size(); i++) {
|
|
r = mk_app(r, args[i], g);
|
|
}
|
|
return r;
|
|
} else {
|
|
m_main.throw_error(sstream() << "failed to compile equations using "
|
|
<< "brec_on approach (possible solution: use well-founded recursion)");
|
|
}
|
|
}
|
|
|
|
/** \brief Return true iff all recursive applications in \c e are structurally smaller than \c arg. */
|
|
expr elim(expr const & e, optional<expr> const & b) {
|
|
switch (e.kind()) {
|
|
case expr_kind::Var: case expr_kind::Meta:
|
|
case expr_kind::Local: case expr_kind::Constant:
|
|
case expr_kind::Sort:
|
|
return e;
|
|
case expr_kind::Macro: {
|
|
buffer<expr> new_args;
|
|
for (unsigned i = 0; i < macro_num_args(e); i++)
|
|
new_args.push_back(elim(macro_arg(e, i), b));
|
|
return update_macro(e, new_args.size(), new_args.data());
|
|
}
|
|
case expr_kind::App: {
|
|
buffer<expr> args;
|
|
expr const & fn = get_app_args(e, args);
|
|
expr new_fn = elim(fn, b);
|
|
buffer<expr> new_args;
|
|
for (expr const & arg : args)
|
|
new_args.push_back(elim(arg, b));
|
|
if (is_local(fn) && b) {
|
|
for (unsigned j = 0; j < m_prgs.size(); j++) {
|
|
if (mlocal_name(fn) == mlocal_name(m_prgs[j].m_fn)) {
|
|
return elim(j, new_args, *b, e.get_tag());
|
|
}
|
|
}
|
|
}
|
|
return mk_app(new_fn, new_args, e.get_tag());
|
|
}
|
|
case expr_kind::Lambda: {
|
|
expr local = mk_local(mk_fresh_name(), binding_name(e), binding_domain(e), binding_info(e));
|
|
expr body = instantiate(binding_body(e), local);
|
|
expr new_body;
|
|
if (is_below_type(binding_domain(e)))
|
|
new_body = elim(body, some_expr(local));
|
|
else
|
|
new_body = elim(body, b);
|
|
return copy_tag(e, Fun(local, new_body));
|
|
}
|
|
case expr_kind::Pi: {
|
|
expr new_domain = elim(binding_domain(e), b);
|
|
expr local = mk_local(mk_fresh_name(), binding_name(e), new_domain, binding_info(e));
|
|
expr new_body = elim(instantiate(binding_body(e), local), b);
|
|
return copy_tag(e, Pi(local, new_body));
|
|
}
|
|
case expr_kind::Let: {
|
|
// TODO(Leo): improve
|
|
return elim(instantiate(let_body(e), let_value(e)), b);
|
|
}}
|
|
lean_unreachable();
|
|
}
|
|
|
|
expr operator()(expr const & e) {
|
|
return elim(e, none_expr());
|
|
}
|
|
};
|
|
|
|
// Fix the i-th argument in the Pi-type t
|
|
expr fix_fn_type(expr const & t, unsigned i, expr const & p) {
|
|
if (!is_pi(t)) {
|
|
throw_error(sstream() << "invalid equation, failed to move parameter '" << p << "'");
|
|
} else if (i == 0) {
|
|
return instantiate(binding_body(t), p);
|
|
} else {
|
|
expr local = mk_local(mk_fresh_name(), binding_name(t), binding_domain(t), binding_info(t));
|
|
expr body = fix_fn_type(instantiate(binding_body(t), local), i-1, p);
|
|
return Pi(local, body);
|
|
}
|
|
}
|
|
|
|
// For each function application (fn ...) in e, replace it with (new_fn ...) and remove the i-th
|
|
// argument.
|
|
expr fix_rec_arg(expr const & fn, expr const & new_fn, unsigned i, expr const & e) {
|
|
return ::lean::replace(e, [&](expr const & e) {
|
|
if (is_app(e) && get_app_fn(e) == fn) {
|
|
buffer<expr> args;
|
|
get_app_args(e, args);
|
|
if (i < args.size())
|
|
args.erase(i);
|
|
return some_expr(mk_app(new_fn, args));
|
|
} else {
|
|
return none_expr();
|
|
}
|
|
});
|
|
}
|
|
|
|
// Move inductive datatype parameters occuring in prg to m_additional_context
|
|
pair<program, unsigned> move_params(program const & prg, unsigned arg_pos) {
|
|
expr const & a_type = mlocal_type(get_ith(prg.m_context, arg_pos));
|
|
buffer<expr> a_type_args;
|
|
expr const & I = get_app_args(a_type, a_type_args);
|
|
unsigned nparams = *inductive::get_num_params(env(), const_name(I));
|
|
buffer<expr> params;
|
|
params.append(nparams, a_type_args.data());
|
|
if (std::all_of(params.begin(), params.end(),
|
|
[&](expr const & p) { return !is_local(p) || contains_local(p, m_global_context); })) {
|
|
return mk_pair(prg, arg_pos);
|
|
} else {
|
|
list<expr> new_context = prg.m_context;
|
|
buffer<optional<name>> new_var_stack;
|
|
buffer<eqn> new_eqns;
|
|
to_buffer(prg.m_var_stack, new_var_stack);
|
|
to_buffer(prg.m_eqns, new_eqns);
|
|
expr new_fn = prg.m_fn;
|
|
for (expr const & param : params) {
|
|
if (contains_local(param, m_global_context))
|
|
continue; // parameter doesn't need to be moved
|
|
m_additional_context.push_back(param);
|
|
new_context = remove(new_context, param);
|
|
unsigned i = 0;
|
|
for (; i < new_var_stack.size(); i++) {
|
|
if (*new_var_stack[i] == mlocal_name(param))
|
|
break;
|
|
}
|
|
lean_assert(i < new_var_stack.size());
|
|
lean_assert(i != arg_pos);
|
|
expr new_fn_type = fix_fn_type(mlocal_type(new_fn), i, param);
|
|
expr new_new_fn = update_mlocal(new_fn, new_fn_type);
|
|
|
|
if (i < arg_pos)
|
|
arg_pos--;
|
|
new_var_stack.erase(i);
|
|
for (eqn & e : new_eqns) {
|
|
expr const & p = get_ith(e.m_patterns, i);
|
|
if (!is_local(p)) {
|
|
throw_error(sstream() << "invalid equations, "
|
|
<< "trying to pattern match inductive datatype parameter '" << p << "'");
|
|
} else {
|
|
list<expr> new_local_ctx = remove(e.m_local_context, p);
|
|
list<expr> new_patterns = remove_ith(e.m_patterns, i);
|
|
expr new_rhs = fix_rec_arg(new_fn, new_new_fn, i, e.m_rhs);
|
|
e = replace(eqn(new_local_ctx, new_patterns, new_rhs), p, param);
|
|
}
|
|
}
|
|
new_fn = new_new_fn;
|
|
}
|
|
return mk_pair(program(new_fn, new_context, to_list(new_var_stack), to_list(new_eqns), prg.m_type), arg_pos);
|
|
}
|
|
}
|
|
|
|
// Move inductive datatype parameters occuring in prg to m_additional_context
|
|
void move_params(buffer<program> & prgs, buffer<unsigned> & arg_pos) {
|
|
if (prgs.size() != 1) {
|
|
// The parameters already occur in m_global_context when there is more than one program.
|
|
return;
|
|
}
|
|
auto p = move_params(prgs[0], arg_pos[0]);
|
|
prgs[0] = p.first;
|
|
arg_pos[0] = p.second;
|
|
lean_assert(check_program(prgs[0]));
|
|
}
|
|
|
|
expr compile_brec_on_core(buffer<program> const & prgs, buffer<unsigned> const & arg_pos) {
|
|
// Return the recursive argument of the i-th program
|
|
auto get_rec_arg = [&](unsigned i) -> expr {
|
|
program const & pi = prgs[i];
|
|
return get_ith(pi.m_context, arg_pos[i]);
|
|
};
|
|
|
|
// Return the type of the recursive argument of the i-th program
|
|
auto get_rec_type = [&](unsigned i) -> expr {
|
|
return mlocal_type(get_rec_arg(i));
|
|
};
|
|
|
|
// Return the program associated with the inductive datatype named I_name.
|
|
// Return none if there isn't one.
|
|
auto get_prg_for = [&](name const & I_name) -> optional<unsigned> {
|
|
for (unsigned i = 0; i < prgs.size(); i++) {
|
|
expr const & t = get_rec_type(i);
|
|
if (const_name(get_app_fn(t)) == I_name)
|
|
return optional<unsigned>(i);
|
|
}
|
|
return optional<unsigned>();
|
|
};
|
|
|
|
expr const & a0_type = get_rec_type(0);
|
|
lean_assert(is_rec_inductive(prgs[0].m_context, a0_type));
|
|
buffer<expr> a0_type_args;
|
|
expr const & I0 = get_app_args(a0_type, a0_type_args);
|
|
inductive::inductive_decls t = *inductive::is_inductive_decl(env(), const_name(I0));
|
|
unsigned nparams = std::get<1>(t);
|
|
list<inductive::inductive_decl> decls = std::get<2>(t);
|
|
buffer<expr> params;
|
|
params.append(nparams, a0_type_args.data());
|
|
|
|
// Return true if the local constant l is in the buffer b.
|
|
// This is similar to contains_local, but b may contain arbitrary terms
|
|
auto contains_local_at = [&](expr const & l, buffer<expr> const & b) {
|
|
lean_assert(is_mlocal(l));
|
|
for (expr const & e : b) {
|
|
if (is_local(e) && mlocal_name(l) == mlocal_name(e))
|
|
return true;
|
|
}
|
|
return false;
|
|
};
|
|
|
|
// Distribute parameters of the ith program intro three groups:
|
|
// indices, major premise (arg), and remaining arguments (rest)
|
|
// We store the position of the rest arguments in the buffer rest_pos.
|
|
// The buffer rest_pos is used to replace the recursive applications with below applications.
|
|
auto distribute_context_core = [&](unsigned i, buffer<expr> & indices, expr & arg, buffer<expr> & rest,
|
|
buffer<unsigned> & indices_pos, buffer<unsigned> & rest_pos) {
|
|
program const & p = prgs[i];
|
|
arg = get_rec_arg(i);
|
|
list<expr> const & ctx = p.m_context;
|
|
buffer<expr> arg_args;
|
|
get_app_args(mlocal_type(arg), arg_args);
|
|
lean_assert(nparams <= arg_args.size());
|
|
indices.append(arg_args.size() - nparams, arg_args.data() + nparams);
|
|
unsigned j = 0;
|
|
for (expr const & l : ctx) {
|
|
if (mlocal_name(l) == mlocal_name(arg) || contains_local_at(l, params)) {
|
|
// do nothing
|
|
} else if (contains_local_at(l, indices)) {
|
|
indices_pos.push_back(j);
|
|
} else {
|
|
rest.push_back(l);
|
|
rest_pos.push_back(j);
|
|
}
|
|
j++;
|
|
}
|
|
};
|
|
|
|
auto distribute_context = [&](unsigned i, buffer<expr> & indices, expr & arg, buffer<expr> & rest) {
|
|
buffer<unsigned> indices_pos, rest_pos;
|
|
distribute_context_core(i, indices, arg, rest, indices_pos, rest_pos);
|
|
};
|
|
|
|
// Compute the resulting universe level for brec_on
|
|
auto get_brec_on_result_level = [&]() -> level {
|
|
buffer<expr> indices, rest; expr arg;
|
|
distribute_context(0, indices, arg, rest);
|
|
expr r_type = Pi(rest, prgs[0].m_type);
|
|
return sort_level(m_tc.ensure_type(r_type).first);
|
|
};
|
|
|
|
level rlvl = get_brec_on_result_level();
|
|
bool reflexive = env().prop_proof_irrel() && is_reflexive_datatype(m_tc.get_type_context(), const_name(I0));
|
|
bool use_ibelow = reflexive && is_zero(rlvl);
|
|
if (reflexive) {
|
|
if (!is_zero(rlvl) && !is_not_zero(rlvl))
|
|
throw_error(sstream() << "invalid equations, "
|
|
<< "when trying to recurse over reflexive inductive datatype, "
|
|
<< "the universe level of the resultant universe must be zero OR "
|
|
<< "not zero for every level assignment");
|
|
if (!is_zero(rlvl)) {
|
|
// For reflexive type, the type of brec_on and ibelow perform a +1 on the motive universe.
|
|
// Example: for a reflexive formula type, we have:
|
|
// formula.below.{l_1} : Π {C : formula → Type.{l_1+1}}, formula → Type.{max (l_1+1) 1}
|
|
if (auto dlvl = dec_level(rlvl)) {
|
|
rlvl = *dlvl;
|
|
} else {
|
|
throw_error(sstream() << "invalid equations, "
|
|
<< "when trying to recurse over reflexive inductive datatype, "
|
|
<< "the universe level of the resultant universe must be zero OR "
|
|
<< "not zero for every level assignment, "
|
|
<< "the compiler managed to establish that the resultant "
|
|
<< "universe level L is never zero, but fail to comput L-1");
|
|
}
|
|
}
|
|
}
|
|
levels brec_on_lvls;
|
|
expr brec_on;
|
|
if (use_ibelow) {
|
|
brec_on_lvls = const_levels(I0);
|
|
brec_on = mk_constant(name{const_name(I0), "binduction_on"}, brec_on_lvls);
|
|
} else {
|
|
brec_on_lvls = cons(rlvl, const_levels(I0));
|
|
brec_on = mk_constant(name{const_name(I0), "brec_on"}, brec_on_lvls);
|
|
}
|
|
// add parameters
|
|
brec_on = mk_app(brec_on, params);
|
|
|
|
buffer<expr> Cs; // brec_on "motives"
|
|
// The following loop fills Cs_locals with auxiliary local constants that will be used to
|
|
// convert recursive applications into "below applications".
|
|
// These constants are essentially abstracting Cs.
|
|
buffer<expr> Cs_locals;
|
|
buffer<buffer<expr>> C_args_buffer;
|
|
for (inductive::inductive_decl const & decl : decls) {
|
|
name const & I_name = inductive::inductive_decl_name(decl);
|
|
expr C;
|
|
C_args_buffer.push_back(buffer<expr>());
|
|
buffer<expr> & C_args = C_args_buffer.back();
|
|
expr C_type = whnf(infer_type(brec_on));
|
|
expr C_local = mk_local(mk_fresh_name(), "C", C_type, binder_info());
|
|
Cs_locals.push_back(C_local);
|
|
if (optional<unsigned> p_idx = get_prg_for(I_name)) {
|
|
buffer<expr> indices, rest; expr arg;
|
|
distribute_context(*p_idx, indices, arg, rest);
|
|
expr type = Pi(rest, prgs[*p_idx].m_type);
|
|
C_args.append(indices);
|
|
C_args.push_back(arg);
|
|
C = Fun(C_args, type);
|
|
} else {
|
|
expr d = binding_domain(C_type);
|
|
expr unit = mk_constant(get_poly_unit_name(), to_list(rlvl));
|
|
to_telescope_ext(d, C_args);
|
|
C = Fun(C_args, unit);
|
|
}
|
|
brec_on = mk_app(brec_on, C);
|
|
Cs.push_back(C);
|
|
}
|
|
|
|
// add indices and major
|
|
buffer<expr> indices0, rest0; expr arg0;
|
|
distribute_context(0, indices0, arg0, rest0);
|
|
brec_on = mk_app(mk_app(brec_on, indices0), arg0);
|
|
|
|
// add functionals
|
|
unsigned i = 0;
|
|
buffer<expr> below_cnsts;
|
|
buffer<buffer<unsigned>> rest_arg_pos;
|
|
for (inductive::inductive_decl const & decl : decls) {
|
|
name const & I_name = inductive::inductive_decl_name(decl);
|
|
expr below_cnst;
|
|
if (use_ibelow)
|
|
below_cnst = mk_constant(name{I_name, "ibelow"}, brec_on_lvls);
|
|
else
|
|
below_cnst = mk_constant(name{I_name, "below"}, brec_on_lvls);
|
|
below_cnsts.push_back(below_cnst);
|
|
expr below = mk_app(mk_app(below_cnst, params), Cs);
|
|
expr F;
|
|
buffer<expr> & C_args = C_args_buffer[i];
|
|
rest_arg_pos.push_back(buffer<unsigned>());
|
|
if (optional<unsigned> p_idx = get_prg_for(I_name)) {
|
|
program const & prg_i = prgs[*p_idx];
|
|
buffer<expr> indices, rest; expr arg; buffer<unsigned> indices_pos;
|
|
buffer<unsigned> & rest_pos = rest_arg_pos.back();
|
|
distribute_context_core(*p_idx, indices, arg, rest, indices_pos, rest_pos);
|
|
below = mk_app(mk_app(below, indices), arg);
|
|
expr b = mk_local(mk_fresh_name(), "b", below, binder_info());
|
|
buffer<expr> new_ctx;
|
|
new_ctx.append(indices);
|
|
new_ctx.push_back(arg);
|
|
new_ctx.push_back(b);
|
|
new_ctx.append(rest);
|
|
F = compile_pat_match(program(prg_i, to_list(new_ctx)), *p_idx);
|
|
} else {
|
|
expr star = mk_constant(get_poly_unit_star_name(), to_list(rlvl));
|
|
buffer<expr> F_args;
|
|
F_args.append(C_args);
|
|
below = mk_app(below, F_args);
|
|
F_args.push_back(mk_local(mk_fresh_name(), "b", below, binder_info()));
|
|
F = Fun(F_args, star);
|
|
}
|
|
brec_on = mk_app(brec_on, F);
|
|
i++;
|
|
}
|
|
expr r = elim_rec_apps_fn(*this, prgs, nparams, below_cnsts, Cs_locals, arg_pos, rest_arg_pos)(brec_on);
|
|
// add remaining arguments
|
|
r = mk_app(r, rest0);
|
|
|
|
buffer<expr> ctx0_buffer;
|
|
to_buffer(prgs[0].m_context, ctx0_buffer);
|
|
r = Fun(m_additional_context, Fun(ctx0_buffer, r));
|
|
return r;
|
|
}
|
|
|
|
/** \brief Compile (possibly mutually) recursive programs using structural recursion.

    For every program, find_rec_args selects the position of an argument that is
    structurally smaller in the recursive calls; if no such selection exists an
    "invalid equations" error is raised. After move_params rearranges the programs,
    each program is compiled with compile_brec_on_core, which always treats the
    program stored in slot 0 of \c prgs as the one being compiled.

    \return the compiled term when exactly one term is produced; otherwise an
    equations_result holding, in order, r_0, type(fn_1), r_1, type(fn_2), r_2, ...
    where r_j is the compiled term for program j and type(fn_j) is mlocal_type of
    its function.
*/
expr compile_brec_on(buffer<program> & prgs) {
    lean_assert(!prgs.empty());
    buffer<unsigned> rec_arg_pos;
    if (!find_rec_args(prgs, rec_arg_pos))
        throw_error(sstream() << "invalid equations, "
                    << "failed to find recursive arguments that are structurally smaller "
                    << "(possible solution: use well-founded recursion)");
    move_params(prgs, rec_arg_pos);

    // Exchange program `j` (together with its recursive-argument position) with
    // slot 0. Applying it a second time restores the original order.
    auto exchange_with_front = [&](unsigned j) {
        std::swap(prgs[0], prgs[j]);
        std::swap(rec_arg_pos[0], rec_arg_pos[j]);
    };

    buffer<expr> results;
    for (unsigned j = 0; j < prgs.size(); ++j) {
        // Every program after the first is preceded by the type of its function.
        if (j != 0)
            results.push_back(mlocal_type(prgs[j].m_fn));
        // NOTE: this is hackish -- the whole set of programs is "compiled"
        // prgs.size() times, once per program. This is wasteful and should be rewritten.
        exchange_with_front(j);
        results.push_back(compile_brec_on_core(prgs, rec_arg_pos));
        exchange_with_front(j);
    }

    if (results.size() <= 1)
        return results[0];
    return mk_equations_result(results.size(), results.data());
}
|
|
|
|
/** \brief Compile recursive equations marked for well-founded recursion.

    Not implemented yet. Instead of returning a null \c expr() -- which operator()
    and compile_equations would hand back to their callers without any check --
    an explicit "invalid equations" error is reported.
*/
expr compile_wf(buffer<program> & /* prgs */) {
    // TODO(Leo): implement compilation using well-founded recursion.
    throw_error(sstream() << "invalid equations, "
                << "well-founded recursion is not supported yet");
    lean_unreachable(); // LCOV_EXCL_LINE
}
|
|
|
|
|
|
public:
|
|
/** \brief Create an equation compiler.

    Stores \c tc, \c ios, \c meta and \c meta_type in the corresponding members and
    uses get_app_args to collect the arguments of the application \c meta into
    \c m_global_context.
    NOTE(review): \c meta is presumably the metavariable application that stands for
    the definition being compiled -- confirm against the caller.
*/
equation_compiler_fn(old_type_checker & tc, io_state const & ios,
                     expr const & meta, expr const & meta_type):
    m_tc(tc), m_ios(ios), m_meta(meta), m_meta_type(meta_type) {
    get_app_args(meta, m_global_context);
}
|
|
|
|
/** \brief Compile the equations \c eqns.

    After check_limitations, the equations are turned into programs by initialize.
    Non-recursive equations (a single program) are compiled by pattern matching;
    recursive ones go through compile_wf when is_wf_equations holds, and through
    compile_brec_on otherwise.
*/
expr operator()(expr eqns) {
    check_limitations(eqns);
    buffer<program> prgs;
    initialize(eqns, prgs);
    // Keep a copy of the programs produced by initialize: the compilation
    // routines below take \c prgs by non-const reference and may modify it.
    m_init_prgs.append(prgs);

    if (!is_recursive(prgs)) {
        // Non-recursive equations are expected to yield exactly one program.
        lean_assert(prgs.size() == 1);
        return compile_pat_match(prgs[0], 0);
    }
    if (is_wf_equations(eqns))
        return compile_wf(prgs);
    return compile_brec_on(prgs);
}
|
|
};
|
|
|
|
/** \brief Entry point of the equation compiler.

    Builds an equation_compiler_fn from \c tc, \c ios, \c meta and \c meta_type
    and applies it to \c eqns, returning the compiled term.
*/
expr compile_equations(old_type_checker & tc, io_state const & ios, expr const & eqns,
                       expr const & meta, expr const & meta_type) {
    equation_compiler_fn compiler(tc, ios, meta, meta_type);
    return compiler(eqns);
}
|
|
}
|