1337 lines
54 KiB
C++
1337 lines
54 KiB
C++
/*
Copyright (c) 2016 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.

Author: Leonardo de Moura
*/
|
||
#include <string>
|
||
#include "runtime/flet.h"
|
||
#include "util/sexpr/option_declarations.h"
|
||
#include "kernel/instantiate.h"
|
||
#include "kernel/for_each_fn.h"
|
||
#include "kernel/replace_fn.h"
|
||
#include "kernel/abstract.h"
|
||
#include "library/placeholder.h"
|
||
#include "library/max_sharing.h"
|
||
#include "library/trace.h"
|
||
#include "library/num.h"
|
||
#include "library/constants.h"
|
||
#include "library/idx_metavar.h"
|
||
#include "library/string.h"
|
||
#include "library/pp_options.h"
|
||
#include "library/exception.h"
|
||
#include "library/util.h"
|
||
#include "library/locals.h"
|
||
#include "library/annotation.h"
|
||
#include "library/private.h"
|
||
#include "library/aux_definition.h"
|
||
#include "library/app_builder.h"
|
||
#include "library/sorry.h"
|
||
#include "library/tactic/tactic_state.h"
|
||
#include "library/tactic/revert_tactic.h"
|
||
#include "library/tactic/clear_tactic.h"
|
||
#include "library/tactic/cases_tactic.h"
|
||
#include "library/tactic/intro_tactic.h"
|
||
#include "library/equations_compiler/equations.h"
|
||
#include "library/equations_compiler/util.h"
|
||
#include "library/equations_compiler/elim_match.h"
|
||
#include "frontends/lean/elaborator.h"
|
||
|
||
/* Default passed to options::get_bool by get_eqn_compiler_ite.
   Only defined here when the build has not already provided a value. */
#if !defined(LEAN_DEFAULT_EQN_COMPILER_ITE)
#define LEAN_DEFAULT_EQN_COMPILER_ITE true
#endif
|
||
|
||
/* Default passed to options::get_unsigned by get_eqn_compiler_max_steps.
   Only defined here when the build has not already provided a value. */
#if !defined(LEAN_DEFAULT_EQN_COMPILER_MAX_STEPS)
#define LEAN_DEFAULT_EQN_COMPILER_MAX_STEPS 2048
#endif
|
||
|
||
namespace lean {
|
||
/* Option name dereferenced by get_eqn_compiler_ite. Starts out null; it is
   presumably assigned by module initialization code outside this chunk. */
static name * g_eqn_compiler_ite = nullptr;
|
||
/* Option name dereferenced by get_eqn_compiler_max_steps. Starts out null; it is
   presumably assigned by module initialization code outside this chunk. */
static name * g_eqn_compiler_max_steps = nullptr;
|
||
|
||
/* Look up the boolean option named by g_eqn_compiler_ite in `o`, passing
   LEAN_DEFAULT_EQN_COMPILER_ITE as the default argument of the lookup. */
static bool get_eqn_compiler_ite(options const & o) {
    bool const use_ite = o.get_bool(*g_eqn_compiler_ite, LEAN_DEFAULT_EQN_COMPILER_ITE);
    return use_ite;
}
|
||
|
||
/* Look up the unsigned option named by g_eqn_compiler_max_steps in `o`, passing
   LEAN_DEFAULT_EQN_COMPILER_MAX_STEPS as the default argument of the lookup. */
static unsigned get_eqn_compiler_max_steps(options const & o) {
    unsigned const max_steps = o.get_unsigned(*g_eqn_compiler_max_steps, LEAN_DEFAULT_EQN_COMPILER_MAX_STEPS);
    return max_steps;
}
|
||
|
||
/* Run `Code` through lean_trace under the trace class name eqn_compiler.elim_match. */
#define trace_match(Code) lean_trace(name({"eqn_compiler", "elim_match"}), Code)
|
||
/* Run `Code` through lean_trace under the trace class name debug.eqn_compiler.elim_match. */
#define trace_match_debug(Code) lean_trace(name({"debug", "eqn_compiler", "elim_match"}), Code)
|
||
|
||
struct elim_match_fn {
|
||
environment m_env;
|
||
elaborator & m_elab;
|
||
metavar_context m_mctx;
|
||
|
||
expr m_ref;
|
||
unsigned m_depth{0};
|
||
buffer<bool> m_used_eqns;
|
||
bool m_aux_lemmas;
|
||
unsigned m_num_steps{0};
|
||
|
||
bool m_error_during_process = false;
|
||
|
||
/* configuration options */
|
||
bool m_use_ite;
|
||
unsigned m_max_steps;
|
||
|
||
/* m_enum is a mapping from inductive type name to flag indicating whether it is
|
||
an enumeration type or not. */
|
||
name_map<bool> m_enum;
|
||
|
||
/* Store the environment, a reference to the elaborator and a copy of the
   metavariable context, then read the two configuration options from the
   elaborator's options. The initializers are listed in member declaration order. */
elim_match_fn(environment const & env, elaborator & elab,
              metavar_context const & mctx):
    m_env(env),
    m_elab(elab),
    m_mctx(mctx),
    m_use_ite(get_eqn_compiler_ite(elab.get_options())),
    m_max_steps(get_eqn_compiler_max_steps(elab.get_options())) {
}
|
||
|
||
struct equation {
|
||
local_context m_lctx;
|
||
list<expr> m_patterns;
|
||
expr m_rhs;
|
||
/* m_renames map variables in this->m_lctx to problem local context */
|
||
hsubstitution m_subst;
|
||
expr m_ref; /* for producing error messages */
|
||
unsigned m_eqn_idx; /* for producing error message */
|
||
/* The following fields are only used for lemma generation */
|
||
list<expr> m_hs;
|
||
list<expr> m_vars;
|
||
list<expr> m_lhs_args;
|
||
};
|
||
|
||
struct problem {
|
||
name m_fn_name;
|
||
expr m_goal;
|
||
list<expr> m_var_stack;
|
||
list<equation> m_equations;
|
||
list<expr> m_example;
|
||
};
|
||
|
||
/* Buffer of problems; not touched by the methods visible in this chunk
   (presumably collects problems that could not be solved -- confirm against its users). */
buffer<problem> m_unsolved;
|
||
|
||
struct lemma {
|
||
local_context m_lctx;
|
||
list<expr> m_vars;
|
||
list<expr> m_hs;
|
||
list<expr> m_lhs_args;
|
||
expr m_rhs;
|
||
unsigned m_eqn_idx;
|
||
};
|
||
|
||
[[ noreturn ]] void throw_error(char const * msg) {
|
||
throw generic_exception(m_ref, msg);
|
||
}
|
||
|
||
[[ noreturn ]] void throw_error(sstream const & strm) {
|
||
throw generic_exception(m_ref, strm);
|
||
}
|
||
|
||
/* Return the local context recorded in m_mctx for the metavariable `mvar`. */
local_context get_local_context(expr const & mvar) {
    auto const & decl = m_mctx.get_metavar_decl(mvar);
    return decl.get_context();
}

/* Return the local context of the goal metavariable of `P`. */
local_context get_local_context(problem const & P) {
    expr const & goal = P.m_goal;
    return get_local_context(goal);
}
|
||
|
||
/* (for debugging, i.e., writing assertions) make sure all locals
|
||
in `e` are defined in `lctx` */
|
||
bool check_locals_decl_at(expr const & e, local_context const & lctx) {
|
||
for_each(e, [&](expr const & e, unsigned) {
|
||
if (is_local(e) && !lctx.find_local_decl(e)) {
|
||
lean_unreachable();
|
||
}
|
||
return true;
|
||
});
|
||
return true;
|
||
}
|
||
|
||
/* For debugging */
|
||
bool check_equation(problem const & P, equation const & eqn) {
|
||
lean_assert(length(eqn.m_patterns) == length(P.m_var_stack));
|
||
local_context const & lctx = get_local_context(P);
|
||
eqn.m_subst.for_each([&](name const & n, expr const & e) {
|
||
if (!eqn.m_lctx.find_local_decl(n)) {
|
||
lean_unreachable();
|
||
}
|
||
if (!check_locals_decl_at(e, lctx)) {
|
||
lean_unreachable();
|
||
}
|
||
});
|
||
lean_assert(check_locals_decl_at(eqn.m_rhs, eqn.m_lctx));
|
||
for (expr const & p : eqn.m_patterns) {
|
||
if (!check_locals_decl_at(p, eqn.m_lctx)) {
|
||
lean_unreachable();
|
||
}
|
||
}
|
||
return true;
|
||
}
|
||
|
||
/* For debugging */
|
||
bool check_problem(problem const & P) {
|
||
local_context const & lctx = get_local_context(P);
|
||
for (expr const & x : P.m_var_stack) {
|
||
if (!check_locals_decl_at(x, lctx)) {
|
||
lean_unreachable();
|
||
}
|
||
}
|
||
for (equation const & eqn : P.m_equations) {
|
||
if (!check_equation(P, eqn)) {
|
||
lean_unreachable();
|
||
}
|
||
}
|
||
auto tc = mk_type_context(P);
|
||
return true;
|
||
}
|
||
|
||
/* Build a type context over `lctx` from m_env, m_mctx and the elaborator's
   cache, using Semireducible transparency. */
type_context_old mk_type_context(local_context const & lctx) {
    return type_context_old(m_env, m_mctx, lctx, m_elab.get_cache(),
                            transparency_mode::Semireducible);
}

/* Build a type context over the local context of metavariable `mvar`. */
type_context_old mk_type_context(expr const & mvar) {
    local_context mvar_lctx = get_local_context(mvar);
    return mk_type_context(mvar_lctx);
}

/* Build a type context over the local context of the goal of `P`. */
type_context_old mk_type_context(problem const & P) {
    local_context problem_lctx = get_local_context(P);
    return mk_type_context(problem_lctx);
}
|
||
|
||
/* Options of the elaborator this compiler was created with. */
options const & get_options() const {
    return m_elab.get_options();
}
|
||
|
||
/* Return a pretty-printing function for expressions living in `lctx`.
   The printer is obtained from ::lean::mk_pp_ctx applied to a fresh type
   context over `lctx`.

   NOTE(review): the previous body also copied the options and set
   get_pp_beta_name() to false, but that copy was never handed to the printer,
   so it had no effect. The dead code is removed here; if disabling pp.beta was
   intended, the updated options must be routed to the printer explicitly. */
std::function<format(expr const &)> mk_pp_ctx(local_context const & lctx) {
    type_context_old ctx = mk_type_context(lctx);
    return ::lean::mk_pp_ctx(ctx);
}
|
||
|
||
/* Return a pretty-printing function for expressions living in the local context of P's goal. */
std::function<format(expr const &)> mk_pp_ctx(problem const & P) {
    local_context problem_lctx = get_local_context(P);
    return mk_pp_ctx(problem_lctx);
}
|
||
|
||
/* Wrap `fmt` with ::lean::nest, using the indentation read from the current options. */
format nest(format const & fmt) const {
    unsigned const indent = get_pp_indent(get_options());
    return ::lean::nest(indent, fmt);
}
|
||
|
||
/* Format `eqn` as its parenthesized patterns separated by single spaces,
   followed by ":=" and the right-hand side nested on a new line; the result is grouped. */
format pp_equation(equation const & eqn) {
    auto pp = mk_pp_ctx(eqn.m_lctx);
    format out;
    bool need_separator = false;
    for (expr const & pattern : eqn.m_patterns) {
        if (need_separator)
            out += format(" ");
        need_separator = true;
        out += paren(pp(pattern));
    }
    out += space() + format(":=") + nest(line() + pp(eqn.m_rhs));
    return group(out);
}
|
||
|
||
/* Format problem `P` for tracing: a "match <name> : <goal type>" header, the
   variable stack in square brackets, one nested line per equation, and a final
   "example:" line listing the entries of P.m_example in parentheses. */
format pp_problem(problem const & P) {
    format out;
    auto pp = mk_pp_ctx(P);
    type_context_old ctx = mk_type_context(P);
    out += format("match") + space() + format(P.m_fn_name) + space() + format(":") + space() + pp(ctx.infer(P.m_goal));

    /* comma-separated variable stack */
    format stack_fmt;
    bool need_comma = false;
    for (expr const & var : P.m_var_stack) {
        if (need_comma)
            stack_fmt += comma() + space();
        need_comma = true;
        stack_fmt += pp(var);
    }
    out += bracket("[", stack_fmt, "]");

    for (equation const & eqn : P.m_equations)
        out += nest(line() + pp_equation(eqn));

    auto example_fmt = format("example:");
    for (auto & ex : P.m_example)
        example_fmt += space() + paren(pp(ex));
    out += line() + nest(example_fmt);

    return out;
}
|
||
|
||
/* If the constant `n` is a constructor in m_env, return the name of the
   inductive type it belongs to; otherwise return none. */
optional<name> is_constructor(name const & n) const {
    constant_info info = m_env.get(n);
    if (!info.is_constructor())
        return optional<name>();
    return optional<name>(info.to_constructor_val().get_induct());
}

/* Same as above for an expression: none unless `e` is a constant naming a constructor. */
optional<name> is_constructor(expr const & e) const {
    if (is_constant(e))
        return is_constructor(const_name(e));
    return optional<name>();
}
|
||
/* If the head of `e` is a constructor and the (relaxed whnf of the) type of `e`
   is an application of that constructor's inductive type, return the inductive
   type name; otherwise return none. */
optional<name> is_constructor_app(type_context_old & ctx, expr const & e) const {
    optional<name> ind_type = is_constructor(get_app_fn(e));
    if (!ind_type)
        return optional<name>();
    // Check that e is not a partially applied constructor.
    auto e_type = ctx.relaxed_whnf(ctx.infer(e));
    if (!is_app_of(e_type, *ind_type))
        return optional<name>();
    return ind_type;
}
|
||
|
||
bool is_inductive(name const & n) const { return ::lean::is_inductive(m_env, n); }
|
||
bool is_inductive(expr const & e) const { return is_constant(e) && is_inductive(const_name(e)); }
|
||
bool is_inductive_app(expr const & e) const { return is_inductive(get_app_fn(e)); }
|
||
|
||
/* Collect into `c_names` the constructor names of `n`, as reported by
   get_constructor_names. `n` is asserted to be an inductive datatype. */
void get_constructors_of(name const & n, buffer<name> & c_names) const {
    lean_assert(is_inductive(n));
    get_constructor_names(m_env, n, c_names);
}
|
||
|
||
/* Return true iff `e` is of the form (I.below ...) or (I.ibelow ...) where `I` is an inductive datatype.
|
||
Move to a different module? */
|
||
bool is_below_type(expr const & e) const {
|
||
expr const & fn = get_app_fn(e);
|
||
if (!is_constant(fn)) return false;
|
||
name const & fn_name = const_name(fn);
|
||
if (fn_name.is_atomic() || !fn_name.is_string()) return false;
|
||
std::string s = fn_name.get_string().to_std_string();
|
||
return is_inductive(fn_name.get_prefix()) && (s == "below" || s == "ibelow");
|
||
}
|
||
|
||
/* Return true iff I_name is an enumeration type with more than 2 elements.
|
||
|
||
\remark It is not worth to apply the if-then-else compilation step if the enumeration
|
||
types has only 2 (or 1) element. */
|
||
bool is_nontrivial_enum(name const & I_name) {
|
||
if (auto r = m_enum.find(I_name))
|
||
return *r;
|
||
buffer<name> c_names;
|
||
get_constructors_of(I_name, c_names);
|
||
bool result = true;
|
||
if (c_names.size() <= 2) {
|
||
result = false;
|
||
} else {
|
||
for (name const & c_name : c_names) {
|
||
constant_info d = m_env.get(c_name);
|
||
expr type = d.get_type();
|
||
if (!is_constant(type) || const_name(type) != I_name) {
|
||
result = false;
|
||
break;
|
||
}
|
||
}
|
||
}
|
||
m_enum.insert(I_name, result);
|
||
return result;
|
||
}
|
||
|
||
/* True iff the if-then-else option (m_use_ite) is enabled and `e` is accepted
   by is_nat_int_char_string_name_value. An exception escaping that check is
   treated as unreachable. */
bool is_value(type_context_old & ctx, expr const & e) {
    try {
        if (m_use_ite && is_nat_int_char_string_name_value(ctx, e))
            return true;
        // TODO(Leo, Sebastian): decide whether we ever want to have this behavior back
        // if (optional<name> I_name = is_constructor(e)) return is_nontrivial_enum(*I_name);
        return false;
    } catch (exception &) {
        lean_unreachable();
    }
}
|
||
|
||
/* For an `e` that satisfies is_value (asserted), report whether it is a character value. */
bool is_finite_value(type_context_old & ctx, expr const & e) {
    lean_assert(is_value(ctx, e));
    bool const is_char = is_char_value(ctx, e);
    return is_char;
}
|
||
|
||
unsigned get_inductive_num_params(name const & n) const { return m_env.get(n).to_inductive_val().get_nparams(); }
|
||
unsigned get_inductive_num_params(expr const & I) const { return get_inductive_num_params(const_name(I)); }
|
||
|
||
/* Normalize until head is constructor or value */
|
||
expr whnf_pattern(type_context_old & ctx, expr const & e) {
|
||
if (is_inaccessible(e)) {
|
||
return e;
|
||
} else if (is_value(ctx, e)) {
|
||
return e;
|
||
} else {
|
||
/* The case is_value(ctx, e) above is needed because whnf_head_pred does not check the given predicate
|
||
before unfolding projections. Moreover, some values are projections applications.
|
||
Example:
|
||
(@has_zero.zero nat nat.has_zero)
|
||
(@has_one.one nat nat.has_one)
|
||
*/
|
||
return ctx.whnf_head_pred(e, [&](expr const & e) {
|
||
return !is_constructor_app(ctx, e) && !is_value(ctx, e);
|
||
});
|
||
}
|
||
}
|
||
|
||
/* Normalize until head is constructor */
|
||
expr whnf_constructor(type_context_old & ctx, expr const & e) {
|
||
return ctx.whnf_head_pred(e, [&](expr const & e) {
|
||
return !is_constructor_app(ctx, e);
|
||
});
|
||
}
|
||
|
||
/* Normalize until head is an inductive datatype */
|
||
expr whnf_inductive(type_context_old & ctx, expr const & e) {
|
||
return ctx.whnf_head_pred(e, [&](expr const & e) {
|
||
return !is_inductive_app(e);
|
||
});
|
||
}
|
||
|
||
/* Convert the raw equation expression `eqn` (a binder for the function header,
   followed by lambdas for the pattern variables, followed by the equation
   itself) into an `equation` record with index `idx`.
   Returns none when the body under the header is a "no equation" marker.
   Side effect: appends one entry to m_used_eqns for every equation built. */
optional<equation> mk_equation(local_context const & lctx, expr const & eqn, unsigned idx) {
    expr body = binding_body(eqn); /* consume fn header */
    if (is_no_equation(body))
        return optional<equation>();
    type_context_old ctx = mk_type_context(lctx);
    /* Turn each lambda-bound pattern variable into a fresh local of `ctx`. */
    buffer<expr> locals;
    for (; is_lambda(body); body = binding_body(body)) {
        expr local_type = instantiate_rev(binding_domain(body), locals);
        expr new_local  = ctx.push_local(binding_name(body), local_type);
        locals.push_back(new_local);
    }
    lean_assert(is_equation(body));
    equation E;
    bool ignore_if_unused = ignore_equation_if_unused(body);
    m_used_eqns.push_back(ignore_if_unused);
    E.m_vars = to_list(locals);
    E.m_lctx = ctx.lctx();
    E.m_rhs  = instantiate_rev(equation_rhs(body), locals);
    /* The function being defined is not recursive. So, E.m_rhs
       must be closed even if we "consumed" the fn header in
       the beginning of the method. */
    lean_assert(!has_loose_bvars(E.m_rhs));
    /* Instantiate the left-hand-side arguments and normalize each one into a pattern. */
    buffer<expr> lhs_args;
    get_app_args(equation_lhs(body), lhs_args);
    buffer<expr> patterns;
    for (expr & lhs_arg : lhs_args) {
        lhs_arg = instantiate_rev(lhs_arg, locals);
        patterns.push_back(whnf_pattern(ctx, lhs_arg));
    }
    E.m_lhs_args = to_list(lhs_args);
    E.m_patterns = to_list(patterns);
    E.m_ref      = eqn;
    E.m_eqn_idx  = idx;
    return optional<equation>(E);
}
|
||
|
||
/* Build the list of `equation` records for `eqns`, numbering them from 0.
   If mk_equation yields none for an entry (asserted to happen only when there
   is a single entry), the empty list is returned. */
list<equation> mk_equations(local_context const & lctx, buffer<expr> const & eqns) {
    buffer<equation> result;
    for (unsigned idx = 0; idx < eqns.size(); idx++) {
        optional<equation> r = mk_equation(lctx, eqns[idx], idx);
        if (!r) {
            lean_assert(eqns.size() == 1);
            return list<equation>();
        }
        result.push_back(*r);
        /* all equations must have the same number of patterns */
        lean_assert(length(result[0].m_patterns) == length(r->m_patterns));
    }
    return to_list(result);
}
|
||
|
||
/* Naive way to retrieve the arity of the function being defined:
   unpack `eqns` with unpack_eqns and ask for the arity of entry 0. */
unsigned get_eqns_arity(local_context const & lctx, expr const & eqns) {
    lean_assert(is_equations(eqns));
    type_context_old ctx = mk_type_context(lctx);
    unpack_eqns unpacked(ctx, eqns);
    return unpacked.get_arity_of(0);
}
|
||
|
||
/* Build the initial problem for the equations expression `e`.
   A metavariable of the function type is created in `lctx`, `arity` binders are
   introduced with intron, and the resulting locals become the variable stack
   (and the initial example). Returns the problem paired with the metavariable
   created before the introductions. Calls throw_ill_formed_eqns() if intron fails. */
pair<problem, expr> mk_problem(local_context const & lctx, expr const & e) {
    lean_assert(is_equations(e));
    buffer<expr> eqns;
    to_equations(e, eqns);
    problem P;
    /* name and type of the function come from the header binder of the first equation */
    P.m_fn_name  = binding_name(eqns[0]);
    expr fn_type = binding_domain(eqns[0]);
    expr mvar    = m_mctx.mk_metavar_decl(lctx, fn_type);
    unsigned arity = get_eqns_arity(lctx, e);
    buffer<name> var_names;
    bool use_unused_names = false;
    optional<expr> goal = intron(m_env, get_options(), m_mctx, mvar,
                                 arity, var_names, use_unused_names);
    if (!goal)
        throw_ill_formed_eqns();
    P.m_goal = *goal;
    local_context goal_lctx = get_local_context(*goal);
    buffer<expr> vars;
    for (name const & var_name : var_names) {
        vars.push_back(goal_lctx.get_local_decl(var_name).mk_ref());
    }
    P.m_var_stack = to_list(vars);
    P.m_example   = P.m_var_stack;
    P.m_equations = mk_equations(lctx, eqns);
    return mk_pair(P, mvar);
}
|
||
|
||
/* Return true iff the next element in the variable stack is a variable.
|
||
\remark It may not be because of dependent pattern matching. */
|
||
bool is_next_var(problem const & P) {
|
||
lean_assert(P.m_var_stack);
|
||
expr const & x = head(P.m_var_stack);
|
||
return is_local(x);
|
||
}
|
||
|
||
template<typename Pred>
|
||
bool all_equations(problem const & P, Pred && p) const {
|
||
for (equation const & eqn : P.m_equations) {
|
||
if (!p(eqn))
|
||
return false;
|
||
}
|
||
return true;
|
||
}
|
||
|
||
template<typename Pred>
|
||
bool all_next_pattern(problem const & P, Pred && p) const {
|
||
return all_equations(P, [&](equation const & eqn) {
|
||
lean_assert(eqn.m_patterns);
|
||
return p(head(eqn.m_patterns));
|
||
});
|
||
}
|
||
|
||
/* Return true iff the next pattern in all equations is a variable. */
|
||
bool is_variable_transition(problem const & P) const {
|
||
return all_next_pattern(P, is_local);
|
||
}
|
||
|
||
/* Return true iff the next pattern in all equations is a constructor. */
|
||
bool is_constructor_transition(problem const & P) {
|
||
return all_equations(P, [&](equation const & eqn) {
|
||
expr const & p = head(eqn.m_patterns);
|
||
type_context_old ctx = mk_type_context(eqn.m_lctx);
|
||
if (is_constructor_app(ctx, p))
|
||
return true;
|
||
return is_value(ctx, p);
|
||
});
|
||
}
|
||
|
||
/* Return true iff the next pattern of every equation is a constructor or variable,
|
||
and there is at least one equation where it is a variable and another where it is a
|
||
constructor. */
|
||
bool is_complete_transition(problem const & P) {
|
||
bool has_variable = false;
|
||
bool has_constructor = false;
|
||
bool r = all_equations(P, [&](equation const & eqn) {
|
||
expr const & p = head(eqn.m_patterns);
|
||
if (is_local(p)) {
|
||
has_variable = true; return true;
|
||
}
|
||
type_context_old ctx = mk_type_context(eqn.m_lctx);
|
||
if (is_constructor_app(ctx, p)) {
|
||
has_constructor = true; return true;
|
||
}
|
||
if (is_value(ctx, p)) {
|
||
has_constructor = true; return true;
|
||
}
|
||
return false;
|
||
});
|
||
return r && has_variable && has_constructor;
|
||
}
|
||
|
||
/* Return true if the next pattern of every equation is a value or variable,
|
||
and there are at least one equation where it is a variable and another where it is a
|
||
value.
|
||
|
||
We also perform a value transition if one of the next patterns is a
|
||
string literal.
|
||
*/
|
||
bool is_value_transition(problem const & P) {
|
||
bool has_value = false;
|
||
bool has_variable = false;
|
||
bool has_finite_value = false;
|
||
bool r = all_equations(P, [&](equation const & eqn) {
|
||
expr const & p = head(eqn.m_patterns);
|
||
if (is_local(p)) {
|
||
has_variable = true; return true;
|
||
} else {
|
||
type_context_old ctx = mk_type_context(eqn.m_lctx);
|
||
if (is_value(ctx, p)) {
|
||
has_value = true;
|
||
if (is_finite_value(ctx, p))
|
||
has_finite_value = true;
|
||
return true;
|
||
} else {
|
||
return false;
|
||
}
|
||
}
|
||
});
|
||
if (!r || !has_value)
|
||
return false;
|
||
if (!has_variable && has_finite_value)
|
||
return false;
|
||
type_context_old ctx = mk_type_context(P);
|
||
/* Check whether other variables on the variable stack depend on the head. */
|
||
expr const & v = head(P.m_var_stack);
|
||
if (depends_on(ctx.infer(P.m_goal), v)) {
|
||
trace_match(tout() << "variable transition is not used because the target depends on '" << v << "'\n";);
|
||
return false;
|
||
}
|
||
for (expr const & w : tail(P.m_var_stack)) {
|
||
expr w_type = ctx.instantiate_mvars(ctx.infer(w));
|
||
if (depends_on(w_type, v)) {
|
||
trace_match(tout() << "variable transition is not used because type of '" << w << "' depends on '" << v << "'\n";);
|
||
return false;
|
||
}
|
||
}
|
||
return true;
|
||
}
|
||
|
||
/** Return true iff the next pattern of some equations is an inaccessible term */
|
||
bool some_inaccessible(problem const & P) const {
|
||
for (equation const & eqn : P.m_equations) {
|
||
lean_assert(eqn.m_patterns);
|
||
expr const & p = head(eqn.m_patterns);
|
||
if (is_inaccessible(p))
|
||
return true;
|
||
}
|
||
return false;
|
||
}
|
||
|
||
/** Return true iff the next pattern of all equations is an inaccessible term */
|
||
bool all_inaccessible(problem const & P) const {
|
||
for (equation const & eqn : P.m_equations) {
|
||
lean_assert(eqn.m_patterns);
|
||
expr const & p = head(eqn.m_patterns);
|
||
if (!is_inaccessible(p))
|
||
return false;
|
||
}
|
||
return true;
|
||
}
|
||
|
||
/* Return true iff the next pattern in some of the equations is an inaccessible term. */
|
||
bool is_inaccessible_transition(problem const & P) const {
|
||
return some_inaccessible(P);
|
||
}
|
||
|
||
/** Replace local `x` in `e` with `renaming.find(x)` */
|
||
expr apply_renaming(expr const & e, name_map<expr> const & renaming) {
|
||
return replace(e, [&](expr const & e, unsigned) {
|
||
if (is_local(e)) {
|
||
if (auto new_e = renaming.find(local_name(e)))
|
||
return some_expr(*new_e);
|
||
}
|
||
return none_expr();
|
||
});
|
||
}
|
||
|
||
/* Return the first pattern of the equation in `eqns` whose m_eqn_idx equals
   `eqn_idx`. Every equation visited is asserted to have a pattern; not finding
   a matching equation is treated as unreachable. */
expr get_next_pattern_of(list<equation> const & eqns, unsigned eqn_idx) {
    for (equation const & candidate : eqns) {
        lean_assert(candidate.m_patterns);
        if (candidate.m_eqn_idx != eqn_idx)
            continue;
        return head(candidate.m_patterns);
    }
    lean_unreachable();
}
|
||
|
||
/* Return `subst` extended with the entry local_name(src) := target.
   An existing entry for `src` is kept as is. `src` is asserted to be a local;
   `subst` is taken by value, so the caller's substitution is not modified. */
hsubstitution add_subst(hsubstitution subst, expr const & src, expr const & target) {
    lean_assert(is_local(src));
    bool const already_mapped = subst.contains(local_name(src));
    if (already_mapped)
        return subst;
    subst.insert(local_name(src), target);
    return subst;
}
|
||
|
||
/* Variable and Inaccessible transition are very similar, this method implements both. */
|
||
list<lemma> process_variable_inaccessible(problem const & P, bool is_var_transition) {
|
||
lean_assert(is_variable_transition(P) || is_inaccessible_transition(P));
|
||
lean_assert(is_var_transition == is_variable_transition(P));
|
||
problem new_P;
|
||
new_P.m_fn_name = P.m_fn_name;
|
||
new_P.m_goal = P.m_goal;
|
||
new_P.m_var_stack = tail(P.m_var_stack);
|
||
new_P.m_example = P.m_example;
|
||
buffer<equation> new_eqns;
|
||
for (equation const & eqn : P.m_equations) {
|
||
equation new_eqn = eqn;
|
||
new_eqn.m_patterns = tail(eqn.m_patterns);
|
||
if (is_var_transition || is_local(head(eqn.m_patterns))) {
|
||
new_eqn.m_subst = add_subst(eqn.m_subst, head(eqn.m_patterns), head(P.m_var_stack));
|
||
}
|
||
new_eqns.push_back(new_eqn);
|
||
}
|
||
new_P.m_equations = to_list(new_eqns);
|
||
return process(new_P);
|
||
}
|
||
|
||
/* Variable step: trace it and delegate to process_variable_inaccessible
   with the variable-transition flag set. */
list<lemma> process_variable(problem const & P) {
    trace_match(tout() << "step: variables only\n";);
    bool const is_var_transition = true;
    return process_variable_inaccessible(P, is_var_transition);
}
|
||
|
||
/* Inaccessible step: trace it and delegate to process_variable_inaccessible
   with the variable-transition flag cleared. */
list<lemma> process_inaccessible(problem const & P) {
    trace_match(tout() << "step: inaccessible terms only\n";);
    bool const is_var_transition = false;
    return process_variable_inaccessible(P, is_var_transition);
}
|
||
|
||
/* Make sure then next pattern of each equation is a constructor application. */
|
||
list<equation> normalize_next_pattern(list<equation> const & eqns) {
|
||
buffer<equation> R;
|
||
for (equation const & eqn : eqns) {
|
||
lean_assert(eqn.m_patterns);
|
||
type_context_old ctx = mk_type_context(eqn.m_lctx);
|
||
/* Remark: reverted bcf44f7020, see issue #1739 */
|
||
/* expr pattern = whnf_constructor(ctx, head(eqn.m_patterns)); */
|
||
/* We use ctx.relaxed_whnf to make sure we expose the kernel constructor */
|
||
expr pattern = ctx.relaxed_whnf(head(eqn.m_patterns));
|
||
if (!is_constructor_app(ctx, pattern)) {
|
||
throw_error("equation compiler failed, pattern is not a constructor "
|
||
"(use 'set_option trace.eqn_compiler.elim_match true' for additional details)");
|
||
}
|
||
equation new_eqn = eqn;
|
||
new_eqn.m_patterns = cons(pattern, tail(eqn.m_patterns));
|
||
R.push_back(new_eqn);
|
||
}
|
||
return to_list(R);
|
||
}
|
||
|
||
/* Append `ilist` and `var_stack`. Due to dependent pattern matching, ilist may contain terms that
|
||
are not local constants. */
|
||
list<expr> update_var_stack(list<expr> const & ilist, list<expr> const & var_stack, hsubstitution const & subst) {
|
||
buffer<expr> new_var_stack;
|
||
for (expr const & e : ilist) {
|
||
new_var_stack.push_back(e);
|
||
}
|
||
for (expr const & v : var_stack) {
|
||
new_var_stack.push_back(apply(v, subst));
|
||
}
|
||
return to_list(new_var_stack);
|
||
}
|
||
|
||
/* eqns is the data-structured returned by distribute_constructor_equations.
|
||
This method selects the ones such that eqns[i].first == C.
|
||
It also updates eqns[i].second.m_subst using \c new_subst.
|
||
It also "replaces" the next pattern (a constructor) with its fields.
|
||
|
||
The map \c new_subst is produced by the `cases` tactic.
|
||
It is needed because the `cases` tactic may revert and reintroduce hypothesis that
|
||
depend on the hypothesis being destructed. */
|
||
list<equation> get_equations_for(name const & C, unsigned nparams, hsubstitution const & new_subst,
|
||
list<equation> const & eqns) {
|
||
buffer<equation> R;
|
||
for (equation const & eqn : eqns) {
|
||
expr pattern = head(eqn.m_patterns);
|
||
buffer<expr> pattern_args;
|
||
expr const & C2 = get_app_args(pattern, pattern_args);
|
||
if (!is_constant(C2, C)) continue;
|
||
equation new_eqn = eqn;
|
||
new_eqn.m_subst = apply(eqn.m_subst, new_subst);
|
||
/* Update patterns */
|
||
type_context_old ctx = mk_type_context(eqn.m_lctx);
|
||
for (unsigned i = nparams; i < pattern_args.size(); i++)
|
||
pattern_args[i] = whnf_pattern(ctx, pattern_args[i]);
|
||
new_eqn.m_patterns = to_list(pattern_args.begin() + nparams, pattern_args.end(), tail(eqn.m_patterns));
|
||
R.push_back(new_eqn);
|
||
}
|
||
return to_list(R);
|
||
}
|
||
|
||
/* Constructor step: run the `cases` tactic on the head of the variable stack
   and build one sub-problem per resulting goal.

   - If `cases` produces no goals, an empty lemma list is returned.
   - Otherwise, if `fail_if_subgoals` is set, none is returned.
   - Otherwise every goal is turned into a problem (named after the goal's
     constructor, with the constructor fields pushed on the variable stack and
     the matching equations selected by get_equations_for) and passed to
     process(); the lemmas of all sub-problems are concatenated.

   An exception raised by `cases` is rethrown wrapped in a nested_exception. */
optional<list<lemma>> process_constructor_core(problem const & P, bool fail_if_subgoals) {
    trace_match(tout() << "step: constructors only\n";);
    lean_assert(is_constructor_transition(P));
    type_context_old ctx = mk_type_context(P);
    expr x               = head(P.m_var_stack);
    /* Remark: reverted bcf44f7020, see issue #1739 */
    /* expr x_type = whnf_inductive(ctx, ctx.infer(x)); */
    expr x_type          = ctx.relaxed_whnf(whnf_inductive(ctx, ctx.infer(x)));
    lean_assert(is_inductive_app(x_type));
    buffer<expr> x_type_args;
    auto x_type_const    = get_app_args(x_type, x_type_args);
    name I_name          = const_name(x_type_const);
    unsigned I_nparams   = get_inductive_num_params(I_name);
    lean_assert(x_type_args.size() >= I_nparams);

    intros_list       ilist;
    hsubstitution_list slist;
    list<expr>        new_goals;
    names             new_goal_cnames;
    try {
        names ids;
        std::tie(new_goals, new_goal_cnames) =
            cases(m_env, get_options(), transparency_mode::Semireducible, m_mctx,
                  P.m_goal, x, ids, &ilist, &slist);
        lean_assert(length(new_goals) == length(new_goal_cnames));
        lean_assert(length(new_goals) == length(ilist));
        lean_assert(length(new_goals) == length(slist));
    } catch (exception &) {
        throw nested_exception("equation compiler failed (use 'set_option trace.eqn_compiler.elim_match true' "
                               "for additional details)", std::current_exception());
    }

    if (empty(new_goals))
        return some(list<lemma>());
    if (fail_if_subgoals)
        return optional<list<lemma>>();

    list<equation> eqns = normalize_next_pattern(P.m_equations);
    buffer<lemma> new_Ls;
    /* The four lists are walked in lock step, one entry per new goal. */
    while (new_goals) {
        lean_assert(new_goal_cnames && ilist && slist);
        name const & C = head(new_goal_cnames);
        problem new_P;
        new_P.m_fn_name   = name(P.m_fn_name, C.get_string());
        expr new_goal     = head(new_goals);
        new_P.m_var_stack = update_var_stack(head(ilist), tail(P.m_var_stack), head(slist));
        new_P.m_goal      = new_goal;
        /* In each example, replace the destructed variable by the constructor
           applied to the type parameters and the introduced fields, then apply
           the substitution of this goal. */
        new_P.m_example   = map(P.m_example, [&] (expr ex) {
                ex = instantiate(abstract(ex, head(P.m_var_stack)),
                                 mk_app(mk_app(mk_constant(C, const_levels(x_type_const)), I_nparams, x_type_args.begin()), head(ilist)));
                ex = apply(ex, head(slist));
                return ex;
            });
        new_P.m_equations = get_equations_for(C, I_nparams, head(slist), eqns);
        to_buffer(process(new_P), new_Ls);
        new_goals       = tail(new_goals);
        new_goal_cnames = tail(new_goal_cnames);
        ilist           = tail(ilist);
        slist           = tail(slist);
    }
    return some(to_list(new_Ls));
}
|
||
|
||
/* Constructor step that never gives up on sub-goals: process_constructor_core
   is called with fail_if_subgoals = false, in which case it always returns a
   value, so the result can be dereferenced directly. */
list<lemma> process_constructor(problem const & P) {
    bool const fail_if_subgoals = false;
    optional<list<lemma>> result = process_constructor_core(P, fail_if_subgoals);
    return *result;
}
|
||
|
||
/* Value step: compile the head of the variable stack with a chain of if-then-else.

   1. Collect the distinct value patterns (in equation order). For each one a
      fresh goal metavariable and the equation `x = value` are created. When `x`
      is the last variable on the stack, collection stops at the first variable
      pattern.
   2. Assign P.m_goal to `ite (x = v1) ?g1 (ite (x = v2) ?g2 ... ?else)`.
   3. For every value, process the sub-problem made of the equations whose next
      pattern is that value, plus the equations whose next pattern is a variable
      (with the variable replaced by the value).
   4. Process the else-case sub-problem made of the equations whose next pattern
      is a variable; each receives one hypothesis `_h<i> : not (p = value)` per
      collected value.
   When `x` is the last variable, only the first applicable equation is kept in
   each sub-problem. */
list<lemma> process_value(problem const & P) {
    trace_match(tout() << "step: if-then-else\n";);
    bool is_last         = !tail(P.m_var_stack);
    expr x               = head(P.m_var_stack);
    local_context lctx   = get_local_context(P.m_goal);
    type_context_old ctx = mk_type_context(P);
    expr goal_type       = ctx.infer(P.m_goal);
    expr else_goal       = ctx.mk_metavar_decl(lctx, goal_type);
    buffer<expr> values;
    buffer<expr> value_goals;
    buffer<expr> eqs;
    for (equation const & eqn : P.m_equations) {
        expr const & p = head(eqn.m_patterns);
        if (is_last && is_local(p))
            break;
        if (is_local(p))
            continue;
        if (std::find(values.begin(), values.end(), p) != values.end())
            continue;
        values.push_back(p);
        value_goals.push_back(ctx.mk_metavar_decl(lctx, goal_type));
        expr const & eq = mk_eq(ctx, x, p);
        eqs.push_back(eq);
    }
    /* Build the if-then-else chain from the innermost (else) branch outwards. */
    expr goal_val = else_goal;
    for (unsigned k = value_goals.size(); k > 0; ) {
        --k;
        goal_val = mk_ite(ctx, eqs[k], value_goals[k], goal_val);
    }
    m_mctx = ctx.mctx();
    m_mctx.assign(P.m_goal, goal_val);
    buffer<lemma> new_Ls;
    for (unsigned i = 0; i < values.size(); i++) {
        /* Process equations for values[i] */
        problem new_P;
        expr val          = values[i];
        new_P.m_fn_name   = name(P.m_fn_name, "_ite_val");
        new_P.m_goal      = value_goals[i];
        new_P.m_var_stack = tail(P.m_var_stack);
        new_P.m_example   = cons(val, tail(P.m_example));
        buffer<equation> new_eqns;
        for (equation const & eqn : P.m_equations) {
            expr const & p = head(eqn.m_patterns);
            if (p == val) {
                equation new_eqn   = eqn;
                new_eqn.m_patterns = tail(new_eqn.m_patterns);
                new_eqns.push_back(new_eqn);
                if (is_last) break;
            } else if (is_local(p)) {
                /* Replace variable `p` with `val` in this equation */
                type_context_old eqn_ctx = mk_type_context(eqn.m_lctx);
                buffer<expr> from;
                buffer<expr> to;
                buffer<expr> new_vars;
                for (expr const & curr : eqn.m_vars) {
                    if (curr == p) {
                        from.push_back(p);
                        to.push_back(val);
                        continue;
                    }
                    /* A variable whose type mentions a replaced local is itself
                       replaced by a fresh local with the updated type. */
                    expr curr_type     = eqn_ctx.infer(curr);
                    expr new_curr_type = replace_locals(curr_type, from, to);
                    if (curr_type == new_curr_type) {
                        new_vars.push_back(curr);
                    } else {
                        expr new_curr = eqn_ctx.push_local(local_pp_name(curr), new_curr_type);
                        from.push_back(curr);
                        to.push_back(new_curr);
                        new_vars.push_back(new_curr);
                    }
                }
                equation new_eqn   = eqn;
                new_eqn.m_vars     = to_list(new_vars);
                new_eqn.m_lctx     = eqn_ctx.lctx();
                new_eqn.m_lhs_args = map(eqn.m_lhs_args, [&](expr const & arg) {
                        return replace_locals(arg, from, to); });
                new_eqn.m_rhs      = replace_locals(eqn.m_rhs, from, to);
                new_eqn.m_patterns = map(tail(eqn.m_patterns), [&](expr const & q) {
                        return replace_locals(q, from, to); });
                new_eqns.push_back(new_eqn);
                if (is_last) break;
            }
        }
        new_P.m_equations = to_list(new_eqns);
        to_buffer(process(new_P), new_Ls);
    }
    {
        /* Else-case */
        problem new_P;
        new_P.m_fn_name   = name(P.m_fn_name, "_ite_else");
        new_P.m_goal      = else_goal;
        new_P.m_var_stack = tail(P.m_var_stack);
        new_P.m_example   = P.m_example;
        buffer<equation> new_eqns;
        for (equation const & eqn : P.m_equations) {
            expr const & p = head(eqn.m_patterns);
            if (!is_local(p))
                continue;
            equation new_eqn   = eqn;
            new_eqn.m_patterns = tail(new_eqn.m_patterns);
            new_eqn.m_subst    = add_subst(eqn.m_subst, p, x);
            type_context_old eqn_ctx = mk_type_context(eqn.m_lctx);
            new_eqn.m_hs       = eqn.m_hs;
            /* hypothesis names continue the numbering after the existing ones */
            unsigned idx = length(eqn.m_hs) + 1;
            for (unsigned j = 0; j < values.size(); j++) {
                expr eq = mk_eq(eqn_ctx, p, values[j]);
                expr ne = mk_not(eqn_ctx, eq);
                expr H  = eqn_ctx.push_local(name("_h").append_after(idx), ne);
                idx++;
                new_eqn.m_hs = cons(H, new_eqn.m_hs);
            }
            new_eqn.m_lctx = eqn_ctx.lctx();
            new_eqns.push_back(new_eqn);
            if (is_last) break;
        }
        new_P.m_equations = to_list(new_eqns);
        to_buffer(process(new_P), new_Ls);
    }
    return to_list(new_Ls);
}
|
||
|
||
/* Transition for problems where the next pattern of every equation is a constructor
   or a variable (checked by the `is_complete_transition` assertion).

   Each equation whose next pattern is a variable (a local) is split into one equation
   per constructor enumerated by `for_each_compatible_constructor`; the remaining
   equations are copied unchanged. The resulting problem is handed back to `process`,
   i.e., this case is reduced to the constructor case. */
list<lemma> process_complete(problem const & P) {
    lean_assert(is_complete_transition(P));
    trace_match(tout() << "step: variables and constructors\n";);
    buffer<equation> expanded;
    for (equation const & eqn : P.m_equations) {
        expr const & next_pat = head(eqn.m_patterns);
        if (!is_local(next_pat)) {
            /* Not a variable pattern: nothing to split. */
            expanded.push_back(eqn);
            continue;
        }
        type_context_old ctx = mk_type_context(eqn.m_lctx);
        for_each_compatible_constructor(ctx, next_pat,
            [&](expr const & c, buffer<expr> const & new_c_vars) {
                /* Replace `var` with `c` in this equation. */
                expr var = next_pat;
                buffer<expr> old_vars;
                to_buffer(eqn.m_vars, old_vars);
                buffer<expr> new_vars;
                buffer<expr> from;
                buffer<expr> to;
                update_telescope(ctx, old_vars, var, c, new_c_vars, new_vars, from, to);
                /* Applies the `from ==> to` renaming computed by `update_telescope`. */
                auto subst = [&](expr const & e) { return replace_locals(e, from, to); };
                equation split     = eqn;
                split.m_lctx       = ctx.lctx();
                split.m_vars       = to_list(new_vars);
                split.m_lhs_args   = map(eqn.m_lhs_args, subst);
                split.m_rhs        = replace_locals(eqn.m_rhs, from, to);
                split.m_patterns   = cons(c, map(tail(eqn.m_patterns), subst));
                expanded.push_back(split);
            });
    }
    problem next_P     = P;
    next_P.m_equations = to_list(expanded);
    return process(next_P);
}
|
||
|
||
/* Handles a problem that has variables left on the stack but no equations.

   The problem is treated as a variable step (`process_variable`) unless the next
   stack entry is a variable whose type is not a "below" type and reduces (via
   `whnf_inductive`) to an inductive datatype application. In that case we try
   `process_constructor_core`; on success an empty lemma list is returned. */
list<lemma> process_no_equation(problem const & P) {
    if (!is_next_var(P))
        return process_variable(P);

    type_context_old ctx = mk_type_context(P);
    expr x               = head(P.m_var_stack);
    expr arg_type        = ctx.infer(x);
    if (is_below_type(arg_type))
        return process_variable(P);

    expr I = whnf_inductive(ctx, arg_type);
    if (!is_inductive_app(I))
        return process_variable(P);

    /* Keep a copy of the metavariable context so a failed attempt can be undone. */
    metavar_context saved_mctx = m_mctx;
    bool fail_if_subgoals      = is_recursive_datatype(m_env, const_name(get_app_fn(I)));
    optional<list<lemma>> r;
    try {
        r = process_constructor_core(P, fail_if_subgoals);
    } catch (exception &) {
        /* Deliberately ignored: `r` stays empty and we fall back below. */
    }
    if (r)
        return list<lemma>();

    /* `process_constructor_core` did not succeed (e.g., it produced subgoals for a
       recursive datatype, which may lead to non-termination). Restore the saved
       metavariable context and handle this as a variable case instead. */
    m_mctx = saved_mctx;
    return process_variable(P);
}
|
||
|
||
/* Handles a problem whose next stack entry is not a variable (a local).

   - If `all_inaccessible(P)` holds, the entry and the head pattern of every equation
     are simply dropped.
   - Otherwise the entry is reduced with `whnf_constructor` and must be a constructor
     application; its non-parameter arguments are pushed onto the variable stack and
     the equations are filtered with `get_equations_for`.

   Throws (via `throw_error`) when the reduced entry is not a constructor application. */
list<lemma> process_non_variable(problem const & P) {
    expr p = head(P.m_var_stack);
    lean_assert(!is_local(p));
    type_context_old ctx = mk_type_context(P);

    if (all_inaccessible(P)) {
        trace_match(tout() << "step: skip inaccessible patterns\n";);
        problem skipped;
        skipped.m_fn_name   = P.m_fn_name;
        skipped.m_goal      = P.m_goal;
        skipped.m_example   = P.m_example;
        skipped.m_var_stack = tail(P.m_var_stack);
        buffer<equation> shortened;
        for (equation const & eqn : P.m_equations) {
            equation e   = eqn;
            e.m_patterns = tail(eqn.m_patterns);
            shortened.push_back(e);
        }
        skipped.m_equations = to_list(shortened);
        return process(skipped);
    }

    trace_match(tout() << "step: filter equations using constructor\n";);
    p = whnf_constructor(ctx, p);
    if (!is_constructor_app(ctx, p)) {
        throw_error("dependent pattern matching result is not a constructor application "
                    "(use 'set_option trace.eqn_compiler.elim_match true' "
                    "for additional details)");
    }
    expr p_type = whnf_inductive(ctx, ctx.infer(p));
    lean_assert(is_inductive_app(p_type));
    name I_name       = const_name(get_app_fn(p_type));
    unsigned I_nparams = get_inductive_num_params(I_name);
    buffer<expr> C_args;
    expr const & C = get_app_args(p, C_args);
    list<equation> eqns = normalize_next_pattern(P.m_equations);

    problem filtered;
    filtered.m_fn_name = P.m_fn_name;
    filtered.m_goal    = P.m_goal;
    filtered.m_example = P.m_example;
    /* New stack: the constructor arguments after the first `I_nparams` ones,
       followed by the rest of the old stack. */
    buffer<expr> stack;
    for (unsigned idx = I_nparams; idx < C_args.size(); idx++) {
        stack.push_back(whnf_constructor(ctx, C_args[idx]));
    }
    to_buffer(tail(P.m_var_stack), stack);
    filtered.m_var_stack = to_list(stack);
    filtered.m_equations = get_equations_for(const_name(C), I_nparams, hsubstitution(), eqns);
    return process(filtered);
}
|
||
|
||
/* Create (f ... x) with the given arity, where the other arguments are inferred using
|
||
type inference */
|
||
expr mk_app_with_arity(type_context_old & ctx, name const & f, unsigned arity, expr const & x) {
|
||
buffer<bool> mask;
|
||
mask.resize(arity - 1, false);
|
||
mask.push_back(true);
|
||
try {
|
||
return mk_app(ctx, f, mask.size(), mask.data(), &x);
|
||
} catch (app_builder_exception &) {
|
||
throw_error(sstream() << "equation compiler failed, when trying to build "
|
||
<< "'" << f << "' application (use 'set_option trace.eqn_compiler.elim_match true' "
|
||
<< "for additional details)");
|
||
}
|
||
}
|
||
|
||
/* Leaf of the compilation: the variable stack of `P` is empty.

   - With no equation left, `P` is recorded in `m_unsolved`, its goal is assigned a
     `sorry` term, and no lemma is produced.
   - Otherwise the first equation is marked as used, the goal is assigned its
     right-hand side (after applying the equation's substitution), and, when
     `m_aux_lemmas` is set, a single lemma describing the equation is returned. */
list<lemma> process_leaf(problem const & P) {
    if (!P.m_equations) {
        m_unsolved.push_back(P);
        type_context_old ctx = mk_type_context(P);
        m_mctx.assign(P.m_goal, mk_sorry(ctx, m_mctx.get_metavar_decl(P.m_goal).get_type(), true));
        return list<lemma>();
    }

    equation const & first = head(P.m_equations);
    m_used_eqns[first.m_eqn_idx] = true;
    expr rhs = apply(first.m_rhs, first.m_subst);
    if (m_env.find(get_id_rhs_name())) {
        /* The rhs is wrapped with `id_rhs` to work around a performance problem related
           to whnf_ite when proving the equational lemmas: `id_rhs` acts as a marker,
           and whnf computation stops as soon as an `id_rhs` application is found at
           whnf_ite.

           Remark: `id_rhs` is defined using the `abbrev` hint, so the is_def_eq
           procedure in the kernel is not affected by it. That is, a problem
                 t =?= id_rhs s
           is reduced to
                 t =?= s

           Remark: `id_rhs` is also used to implement "smart reduction" at
           type_context_old. */
        type_context_old ctx = mk_type_context(P);
        rhs = mk_id_rhs(ctx, rhs);
    }
    m_mctx.assign(P.m_goal, rhs);

    if (!m_aux_lemmas)
        return list<lemma>();

    lemma result;
    result.m_lctx     = first.m_lctx;
    result.m_vars     = first.m_vars;
    result.m_hs       = first.m_hs;
    result.m_lhs_args = erase_inaccessible_annotations(first.m_lhs_args);
    result.m_rhs      = first.m_rhs;
    result.m_eqn_idx  = first.m_eqn_idx;
    return to_list(result);
}
|
||
|
||
/* Return true iff every pattern in `eqn.m_patterns` satisfies `is_fvar`
   (trivially true when there are no patterns). */
bool is_var_only_lhs(equation const & eqn) {
    bool only_vars = true;
    for (expr const & pat : eqn.m_patterns) {
        if (!is_fvar(pat)) {
            only_vars = false;
            break;
        }
    }
    return only_vars;
}
|
||
|
||
/* Drop every equation that comes after the first equation whose patterns are all
   variables (see `is_var_only_lhs`). That equation itself is kept. If no such
   equation exists, `P` is returned unchanged. */
problem truncate_after_var_only_lhs(problem P) {
    buffer<equation> prefix;
    bool hit_var_only = false;
    for (equation const & eqn : P.m_equations) {
        prefix.push_back(eqn);
        hit_var_only = is_var_only_lhs(eqn);
        if (hit_var_only)
            break;
    }
    if (hit_var_only)
        P.m_equations = to_list(prefix);
    return P;
}
|
||
|
||
|
||
/* Main dispatcher of the pattern-matching compiler.

   Increments the depth (restored on exit by `flet`) and the step counter, truncates
   the equations after the first variable-only one, and then selects the transition
   to apply. The order of the tests below is significant.

   Errors: an exception raised while processing is passed to `m_elab.try_report`;
   if it is not reported there it is rethrown. Otherwise `m_error_during_process`
   is set, the goal is assigned a `sorry` term and an empty lemma list is returned. */
list<lemma> process(problem P) {
    flet<unsigned> inc_depth(m_depth, m_depth+1);
    P = truncate_after_var_only_lhs(P);
    trace_match(tout() << "depth [" << m_depth << "]\n" << pp_problem(P) << "\n";);
    lean_assert(check_problem(P));
    m_num_steps++;
    try {
        if (m_num_steps > m_max_steps) {
            throw_error(sstream() << "equation compiler failed, maximum number of steps (" << m_max_steps << ") exceeded"
                        << " (possible solution: use 'set_option eqn_compiler.max_steps <new-threshold>')"
                        << " (use 'set_option trace.eqn_compiler.elim_match true' for additional details)");
        }

        /* Nothing left on the variable stack: we reached a leaf. */
        if (!P.m_var_stack)
            return process_leaf(P);

        if (!P.m_equations)
            return process_no_equation(P);
        if (!is_next_var(P))
            return process_non_variable(P);
        if (is_variable_transition(P))
            return process_variable(P);
        if (is_value_transition(P))
            return process_value(P);
        if (is_complete_transition(P))
            return process_complete(P);
        if (is_constructor_transition(P))
            return process_constructor(P);
        if (is_inaccessible_transition(P))
            return process_inaccessible(P);

        /* No transition applies. */
        trace_match(tout() << "compilation failed at\n" << pp_problem(P) << "\n";);
        throw_error("equation compiler failed (use 'set_option trace.eqn_compiler.elim_match true' "
                    "for additional details)");
    } catch (exception & ex) {
        if (!m_elab.try_report(ex, some_expr(m_ref))) throw;
        m_error_during_process = true;
        m_mctx.assign(P.m_goal, m_elab.mk_sorry(m_mctx.get_metavar_decl(P.m_goal).get_type(), true));
    }
    return list<lemma>();
}
|
||
|
||
/* Turn the lemma `L` into a closed statement: the equation
       fn L.m_lhs_args = L.m_rhs
   abstracted (with `mk_pi`) over `L.m_vars` followed by `L.m_hs`. */
expr finalize_lemma(expr const & fn, lemma const & L) {
    buffer<expr> lhs_args;
    to_buffer(L.m_lhs_args, lhs_args);
    type_context_old ctx = mk_type_context(L.m_lctx);
    expr statement = mk_eq(ctx, mk_app(fn, lhs_args), L.m_rhs);
    /* Binders: first the variables, then the hypotheses. */
    buffer<expr> binders;
    to_buffer(L.m_vars, binders);
    to_buffer(L.m_hs, binders);
    return ctx.mk_pi(binders, statement);
}
|
||
|
||
/* Apply `finalize_lemma` (with function `fn`) to every lemma in `Ls`. */
list<expr> finalize_lemmas(expr const & fn, list<lemma> const & Ls) {
    auto to_statement = [&](lemma const & L) { return finalize_lemma(fn, L); };
    return map2<expr>(Ls, to_statement);
}
|
||
|
||
/* Report (through `m_elab.report_or_throw`) every equation that was never marked in
   `m_used_eqns`.

   For an unused equation #i we look for an earlier equation #j whose left-hand side
   has the form `(f x)` with `x` a variable; when one exists the error message points
   at it, otherwise the message suggests deleting equation #i. */
void check_no_unused_eqns(expr const & eqns) {
    for (unsigned i = 0; i < m_used_eqns.size(); i++) {
        if (m_used_eqns[i])
            continue;
        buffer<expr> all_eqns;
        to_equations(eqns, all_eqns);
        /* Search for an equation occurring before #i s.t. the lhs is of the form
              (f x)
           where x is a variable. `j == i` after the loop means "not found". */
        unsigned j = 0;
        for (; j < i; j++) {
            expr candidate = all_eqns[j];
            while (is_lambda(candidate))
                candidate = binding_body(candidate);
            if (!is_equation(candidate))
                continue;
            buffer<expr> lhs_args;
            get_app_args(equation_lhs(candidate), lhs_args);
            if (lhs_args.size() == 1 && is_var(lhs_args[0]))
                break; // found it
        }
        /* The error is reported at equation #i with its leading lambdas stripped. */
        expr ref = all_eqns[i];
        while (is_lambda(ref))
            ref = binding_body(ref);
        if (j == i) {
            m_elab.report_or_throw(elaborator_exception(ref,
                sstream() << "equation compiler error, equation #" << (i+1)
                << " has not been used in the compilation (possible solution: delete equation)"));
        } else {
            m_elab.report_or_throw(elaborator_exception(ref,
                sstream() << "equation compiler error, equation #" << (i+1)
                << " has not been used in the compilation, note that the left-hand-side of equation #" << (j+1)
                << " is a variable"));
        }
    }
}
|
||
|
||
/* Build one counter-example per problem stored in `m_unsolved`.

   Each counter-example is the problem's `m_example` list in which every local
   is replaced by a placeholder expression; subterms without locals are kept as is. */
list<list<expr>> get_counter_examples() {
    buffer<list<expr>> result;
    auto underscore = mk_expr_placeholder();

    /* Replacement callback: stop at local-free terms, turn locals into `underscore`,
       and descend into everything else. */
    auto hide_locals = [&] (expr const & e, unsigned) {
        if (!has_local(e)) return some_expr(e);
        if (is_local(e)) return some_expr(underscore);
        return none_expr();
    };
    auto anonymize = [&] (expr const & e) {
        return replace(e, hide_locals);
    };

    for (auto & unsolved : m_unsolved) {
        result.push_back(map(unsolved.m_example, anonymize));
    }
    return to_list(result);
}
|
||
|
||
/* Entry point: compile the (single, non-recursive) function described by `eqns`.

   Steps: build the initial problem, run `process`, collect counter-examples for
   unsolved problems, check for unused equations (only when there are no
   counter-examples and no error was reported during processing), and finally
   produce the compiled function together with the finalized lemmas. */
elim_match_result operator()(local_context const & lctx, expr const & eqns) {
    lean_assert(equations_num_fns(eqns) == 1);
    DEBUG_CODE({
            type_context_old ctx = mk_type_context(lctx);
            lean_assert(!is_recursive_eqns(ctx, eqns));
        });
    m_aux_lemmas = get_equations_header(eqns).m_aux_lemmas;
    m_ref        = eqns;

    problem P;
    expr fn;
    std::tie(P, fn) = mk_problem(lctx, eqns);
    lean_assert(check_problem(P));
    list<lemma> pre_Ls = process(P);

    auto counter_examples = get_counter_examples();
    bool const clean_run  = !counter_examples && !m_error_during_process;
    if (clean_run)
        check_no_unused_eqns(eqns);

    /* The method `process` may create many common subexpressions because of wildcards
       occurring in patterns. We reduce this redundancy and improve the performance
       with the function max_sharing. The performance improvement can be observed in
       the following example:
       ```
       universes u

       inductive node (α : Type u)
       | leaf : node
       | red_node : node → α → node → node
       | black_node : node → α → node → node

       namespace node
       variable {α : Type u}

       def balance : node α → α → node α → node α
       | (red_node (red_node a x b) y c) k d := red_node (black_node a x b) y (black_node c k d)
       | (red_node a x (red_node b y c)) k d := red_node (black_node a x b) y (black_node c k d)
       | l k r := black_node l k r

       end node
       ```
       It produces 121 equations.
       At commit 47994fe14ec7982d5b727c4f8a4f29ae9abce95c, `balance` takes 781 ms to be
       elaborated on Leo's office desktop. Most of the time is spent proving equation
       lemmas. The runtime is reduced to 479 ms after we added max_sharing. */
    fn = max_sharing(m_mctx.instantiate_mvars(fn));
    trace_match_debug(tout() << "code:\n" << fn << "\n";);
    list<expr> Ls = finalize_lemmas(fn, pre_Ls);
    return { fn, Ls, counter_examples };
}
|
||
};
|
||
|
||
/* Run the pattern-matching compiler (`elim_match_fn`) on `eqns` in the local
   context `lctx`. On return `env` is overwritten with the environment held by the
   compiler object (`m_env`). */
elim_match_result elim_match(environment & env, elaborator & elab, metavar_context & mctx,
                             local_context const & lctx, expr const & eqns) {
    elim_match_fn compiler(env, elab, mctx);
    auto result = compiler(lctx, eqns);
    /* Propagate the compiler's environment back to the caller. */
    env = compiler.m_env;
    return result;
}
|
||
|
||
/* Return the binder domain of the first equation extracted from `eqns`; callers use
   it as the type of the function being defined.
   NOTE(review): indexes element 0 without a check, so it assumes `to_equations`
   yields at least one equation — confirm against callers. */
static expr get_fn_type_from_eqns(expr const & eqns) {
    /* TODO(Leo): implement more efficient version if needed */
    buffer<expr> all_eqns;
    to_equations(eqns, all_eqns);
    expr const & first_eqn = all_eqns[0];
    return binding_domain(first_eqn);
}
|
||
|
||
/* Compile a non-recursive definition given by `eqns`.

   The equations are first compiled with `elim_match`. For unsafe definitions and
   lemmas the compiled term is returned directly; otherwise an auxiliary definition
   is created with `mk_aux_definition` (which also updates `env`) and the resulting
   function is returned. In both cases each counter-example argument list is turned
   into an application of the function. */
eqn_compiler_result mk_nonrec(environment & env, elaborator & elab, metavar_context & mctx,
                              local_context const & lctx, expr const & eqns) {
    equations_header header = get_equations_header(eqns);
    auto R = elim_match(env, elab, mctx, lctx, eqns);

    /* Applies `f` to every counter-example argument list in `R`. */
    auto apply_to_examples = [&] (expr const & f) {
        return map2<expr>(R.m_counter_examples, [&] (list<expr> const & e) { return mk_app(f, e); });
    };

    if (header.m_is_unsafe || header.m_is_lemma) {
        /* Do not generate auxiliary equation or equational lemmas */
        auto fn = mk_constant(head(header.m_fn_names));
        auto counter_examples = apply_to_examples(fn);
        return { {R.m_fn}, counter_examples };
    }

    /* NOTE(review): `ctx1` is not referenced again in this function; it is kept
       because its construction may have effects not visible here. */
    type_context_old ctx1(env, mctx, lctx, elab.get_cache(), transparency_mode::Semireducible);
    /*
       We should use the type specified at eqns instead of m_ctx.infer(R.m_fn).
       These two types must be definitionally equal, but the shape of
       m_ctx.infer(R.m_fn) may confuse automation. For example,
       it might be of the form (Pi (_a : nat), (fun x, nat) _a) which is
       definitionally equal to (nat -> nat), but may confuse the simplifier and
       congruence closure modules and make them "believe" that this is
       a dependent function.
    */
    expr fn_type        = get_fn_type_from_eqns(eqns);
    name fn_name        = head(header.m_fn_names);
    name fn_actual_name = head(header.m_fn_actual_names);
    expr fn;
    std::tie(env, fn) = mk_aux_definition(env, elab.get_options(), mctx, lctx, header,
                                          fn_name, fn_actual_name, fn_type, R.m_fn);
    /* NOTE(review): like `ctx1`, `ctx2` is not referenced again in this function. */
    type_context_old ctx2(env, mctx, lctx, elab.get_cache(), transparency_mode::Semireducible);
    auto counter_examples = apply_to_examples(fn);
    return { {fn}, counter_examples };
}
|
||
|
||
/* Module initialization: registers the trace classes used by this file and the
   options `eqn_compiler.ite` and `eqn_compiler.max_steps`, allocating the global
   option names. The names are released by `finalize_elim_match`. */
void initialize_elim_match() {
    /* Trace classes. */
    register_trace_class({"eqn_compiler", "elim_match"});
    register_trace_class({"debug", "eqn_compiler", "elim_match"});

    /* Option `eqn_compiler.ite`. */
    g_eqn_compiler_ite = new name{"eqn_compiler", "ite"};
    register_bool_option(*g_eqn_compiler_ite, LEAN_DEFAULT_EQN_COMPILER_ITE,
                         "(equation compiler) use if-then-else terms when pattern matching on simple values "
                         "(e.g., strings and characters)");

    /* Option `eqn_compiler.max_steps`. */
    g_eqn_compiler_max_steps = new name{"eqn_compiler", "max_steps"};
    register_unsigned_option(*g_eqn_compiler_max_steps, LEAN_DEFAULT_EQN_COMPILER_MAX_STEPS,
                             "(equation compiler) maximum number of pattern matching compilation steps");
}
|
||
/* Module finalization: frees the option names allocated by
   `initialize_elim_match`. */
void finalize_elim_match() {
    delete g_eqn_compiler_max_steps;
    delete g_eqn_compiler_ite;
}
|
||
}
|