lean4-htt/src/library/tactic/simplifier/util.cpp

340 lines
13 KiB
C++

/*
Copyright (c) 2016 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Daniel Selsam
*/
#include <string>
#include "util/sstream.h"
#include "kernel/expr.h"
#include "kernel/instantiate.h"
#include "library/kernel_serializer.h"
#include "library/util.h"
#include "library/constants.h"
#include "library/app_builder.h"
#include "library/tactic/ac_tactics.h"
#include "library/tactic/simplifier/util.h"
namespace lean {
static void get_app_nary_args_core(type_context * tctx_ptr, expr const & op, expr const & e, buffer<expr> & nary_args, bool unsafe) {
lean_assert(unsafe || tctx_ptr);
auto next_op = get_binary_op(e);
if (next_op && (unsafe ? (get_app_fn(*next_op) == get_app_fn(op)) : tctx_ptr->is_def_eq(*next_op, op))) {
get_app_nary_args_core(tctx_ptr, op, app_arg(app_fn(e)), nary_args, unsafe);
get_app_nary_args_core(tctx_ptr, op, app_arg(e), nary_args, unsafe);
} else {
nary_args.push_back(e);
}
}
void unsafe_get_app_nary_args(expr const & op, expr const & e, buffer<expr> & nary_args) {
bool unsafe = true;
get_app_nary_args_core(nullptr, op, app_arg(app_fn(e)), nary_args, unsafe);
get_app_nary_args_core(nullptr, op, app_arg(e), nary_args, unsafe);
}
void get_app_nary_args(type_context & tctx, expr const & op, expr const & e, buffer<expr> & nary_args) {
bool unsafe = false;
get_app_nary_args_core(&tctx, op, app_arg(app_fn(e)), nary_args, unsafe);
get_app_nary_args_core(&tctx, op, app_arg(e), nary_args, unsafe);
}
optional<pair<expr, expr> > is_assoc(type_context & tctx, expr const & e) {
auto op = get_binary_op(e);
if (!op)
return optional<pair<expr, expr> >();
try {
expr assoc_class = mk_app(tctx, get_is_associative_name(), *op);
if (auto assoc_inst = tctx.mk_class_instance(assoc_class))
return optional<pair<expr, expr> >(mk_pair(mk_app(tctx, get_is_associative_op_assoc_name(), 3, *op, *assoc_inst), *op));
else
return optional<pair<expr, expr> >();
} catch (app_builder_exception ex) {
return optional<pair<expr, expr> >();
}
}
expr mk_congr_bin_op(abstract_type_context & tctx, expr const & H, expr const & arg1, expr const & arg2) {
expr eq = tctx.relaxed_whnf(tctx.infer(H));
expr A_op, op1, op2;
lean_verify(is_eq(eq, A_op, op1, op2));
lean_assert(is_pi(A_op));
expr A = binding_domain(A_op);
level lvl = get_level(tctx, A);
return ::lean::mk_app({mk_constant(get_simplifier_congr_bin_op_name(), {lvl}), A, op1, op2, H, arg1, arg2});
}
expr mk_congr_bin_arg1(abstract_type_context & tctx, expr const & op, expr const & H1, expr const & arg2) {
expr eq = tctx.relaxed_whnf(tctx.infer(H1));
expr A, arg11, arg12;
lean_verify(is_eq(eq, A, arg11, arg12));
level lvl = get_level(tctx, A);
return ::lean::mk_app({mk_constant(get_simplifier_congr_bin_arg1_name(), {lvl}), A, op, arg11, arg12, arg2, H1});
}
expr mk_congr_bin_arg2(abstract_type_context & tctx, expr const & op, expr const & arg1, expr const & H2) {
expr eq = tctx.relaxed_whnf(tctx.infer(H2));
expr A, arg21, arg22;
lean_verify(is_eq(eq, A, arg21, arg22));
level lvl = get_level(tctx, A);
return ::lean::mk_app({mk_constant(get_simplifier_congr_bin_arg2_name(), {lvl}), A, op, arg1, arg21, arg22, H2});
}
expr mk_congr_bin_args(abstract_type_context & tctx, expr const & op, expr const & H1, expr const & H2) {
expr eq1 = tctx.relaxed_whnf(tctx.infer(H1));
expr eq2 = tctx.relaxed_whnf(tctx.infer(H2));
expr A, A0, arg11, arg12, arg21, arg22;
lean_verify(is_eq(eq1, A, arg11, arg12));
lean_verify(is_eq(eq2, A0, arg21, arg22));
lean_assert(tctx.is_def_eq(A, A0));
level lvl = get_level(tctx, A);
return ::lean::mk_app({mk_constant(get_simplifier_congr_bin_args_name(), {lvl}), A, op, arg11, arg12, arg21, arg22, H1, H2});
}
expr mk_assoc_subst(abstract_type_context & tctx, expr const & old_op, expr const & new_op, expr const & pf_op, expr const & assoc) {
expr A_op = tctx.relaxed_whnf(tctx.infer(new_op));
lean_assert(is_pi(A_op));
expr A = binding_domain(A_op);
level lvl = get_level(tctx, A);
return ::lean::mk_app({mk_constant(get_simplifier_assoc_subst_name(), {lvl}), A, old_op, new_op, pf_op, assoc});
}
// flat macro
static name * g_flat_macro_name = nullptr;
static std::string * g_flat_opcode = nullptr;
class flat_macro_definition_cell : public macro_definition_cell {
void check_macro(expr const & m) const {
if (!is_macro(m) || macro_num_args(m) != 2)
throw exception(sstream() << "invalid 'flat' macro, incorrect number of arguments");
}
public:
flat_macro_definition_cell() {}
virtual name get_name() const override { return *g_flat_macro_name; }
virtual expr check_type(expr const & m, abstract_type_context &, bool) const override {
check_macro(m);
return macro_arg(m, 1);
}
virtual optional<expr> expand(expr const & m, abstract_type_context & tctx) const override {
check_macro(m);
expr const & assoc = macro_arg(m, 0);
expr const & thm = macro_arg(m, 1);
expr old_e = app_arg(app_fn(thm));
expr new_e = app_arg(thm);
optional<expr> op = get_binary_op(old_e);
lean_assert(op);
pair<expr, optional<expr>> r_assoc = flat_assoc(tctx, *op, assoc, old_e);
optional<expr> const & pf_of_assoc = r_assoc.second;
if (!pf_of_assoc)
return some_expr(mk_eq_refl(tctx, old_e));
else
return pf_of_assoc;
}
virtual void write(serializer & s) const override {
s.write_string(*g_flat_opcode);
}
virtual bool operator==(macro_definition_cell const & other) const override {
if (dynamic_cast<flat_macro_definition_cell const *>(&other)) {
return true;
} else {
return false;
}
}
virtual unsigned hash() const override {
return get_name().hash();
}
};
// Rewrite-assoc macro
static name * g_rewrite_assoc_macro_name = nullptr;
static std::string * g_rewrite_assoc_opcode = nullptr;
class rewrite_assoc_macro_definition_cell : public macro_definition_cell {
unsigned m_arg_idx;
unsigned m_num_patterns;
unsigned m_num_args;
void check_macro(expr const & m) const {
if (!is_macro(m) || macro_num_args(m) != 4)
throw exception(sstream() << "invalid 'rewrite_assoc' macro, incorrect number of arguments");
}
public:
rewrite_assoc_macro_definition_cell(unsigned arg_idx, unsigned num_patterns, unsigned num_args):
m_arg_idx(arg_idx), m_num_patterns(num_patterns), m_num_args(num_args) {}
virtual name get_name() const { return *g_rewrite_assoc_macro_name; }
virtual expr check_type(expr const & m, abstract_type_context &, bool) const {
check_macro(m);
return macro_arg(m, 1);
}
pair<expr, optional<expr> > group_args(expr const & e, unsigned pat_idx) const {
if (pat_idx + 1 < m_num_patterns) {
pair<expr, optional<expr> > p = group_args(app_arg(e), pat_idx + 1);
return mk_pair(mk_app(app_fn(e), p.first), p.second);
} else {
lean_assert(pat_idx + 1 == m_num_patterns);
if (m_arg_idx + m_num_patterns == m_num_args) {
return mk_pair(e, none_expr());
} else {
return mk_pair(app_arg(app_fn(e)), some_expr(app_arg(e)));
}
}
}
expr compute_pre_motive(abstract_type_context & tctx, expr const & e, expr & l, expr & lhs, unsigned i) const {
if (i == m_arg_idx) {
// (lhs, rest)
l = tctx.push_local(name("lhs"), tctx.infer(e));
pair<expr, optional<expr> > lhs_rest = group_args(e, 0);
lhs = lhs_rest.first;
if (lhs_rest.second)
return mk_app(app_fn(app_fn(e)), l, *lhs_rest.second);
else
return l;
} else {
return mk_app(app_fn(e), compute_pre_motive(tctx, app_arg(e), l, lhs, i+1));
}
}
virtual optional<expr> expand(expr const & m, abstract_type_context & tctx) const {
check_macro(m);
expr const & assoc = macro_arg(m, 0);
expr const & thm = macro_arg(m, 1);
/* expr const & step_rhs = macro_arg(m, 2); */
expr pf_of_step = macro_arg(m, 3);
expr const & old_e = app_arg(app_fn(thm));
/* expr const & new_e = app_arg(thm); */
expr op = app_fn(app_fn(old_e));
// Step 1: Re-arrange to group the args being rewritten
expr local;
expr step_lhs;
expr pre_motive = tctx.abstract_locals(compute_pre_motive(tctx, old_e, local, step_lhs, 0), 1, &local);
expr middle_e = instantiate(pre_motive, step_lhs);
expr motive = mk_lambda(mlocal_name(local), tctx.infer(local), mk_app(app_fn(app_fn(thm)), middle_e, pre_motive));
optional<expr> pf_flat = flat_assoc(tctx, op, assoc, middle_e).second;
simp_result r_middle;
if (pf_flat)
r_middle = simp_result(middle_e, mk_eq_symm(tctx, *pf_flat));
else
r_middle = simp_result(middle_e);
expr thm_of_step = tctx.infer(pf_of_step);
if (auto pf_step_lhs_not_flat = flat_assoc(tctx, op, assoc, app_arg(app_fn(thm_of_step))).second) {
// lemma needs to be flattened
expr l = tctx.push_local(name("pf_lhs"), tctx.infer(old_e));
expr motive = mk_lambda(mlocal_name(l), tctx.infer(l), tctx.abstract_locals(mk_app(app_fn(app_fn(thm)), l, app_arg(thm_of_step)), 1, &l));
pf_of_step = mk_eq_subst(tctx, motive, *pf_step_lhs_not_flat, pf_of_step);
}
if (r_middle.has_proof())
return some_expr(mk_eq_trans(tctx, r_middle.get_proof(), mk_eq_subst(tctx, motive, pf_of_step, mk_eq_refl(tctx, r_middle.get_new()))));
else
return some_expr(mk_eq_subst(tctx, motive, pf_of_step, mk_eq_refl(tctx, r_middle.get_new())));
}
virtual void write(serializer & s) const {
s.write_string(*g_rewrite_assoc_opcode);
s << m_arg_idx;
s << m_num_patterns;
s << m_num_args;
}
virtual bool operator==(macro_definition_cell const & other) const {
if (auto other_ptr = dynamic_cast<rewrite_assoc_macro_definition_cell const *>(&other)) {
return m_arg_idx == other_ptr->m_arg_idx && m_num_patterns == other_ptr->m_num_patterns
&& m_num_args == other_ptr->m_num_args;
} else {
return false;
}
}
virtual unsigned hash() const {
return ::lean::hash(m_arg_idx, ::lean::hash(m_num_patterns,
::lean::hash(m_num_args, get_name().hash())));
}
};
expr mk_flat_proof(expr const & assoc, expr const & thm) {
expr margs[3];
margs[0] = assoc;
margs[1] = thm;
macro_definition m(new flat_macro_definition_cell());
return mk_macro(m, 2, margs);
}
expr mk_flat_macro(unsigned num_args, expr const * args) {
lean_assert(num_args == 2);
macro_definition m(new flat_macro_definition_cell());
return mk_macro(m, num_args, args);
}
expr mk_rewrite_assoc_proof(expr const & assoc, expr const & thm, unsigned arg_idx, unsigned num_patterns,
unsigned num_args, expr const & step_rhs, expr const & pf_of_step) {
expr margs[4];
margs[0] = assoc;
margs[1] = thm;
margs[2] = step_rhs;
margs[3] = pf_of_step;
macro_definition m(new rewrite_assoc_macro_definition_cell(arg_idx, num_patterns, num_args));
return mk_macro(m, 4, margs);
}
expr mk_rewrite_assoc_macro(unsigned num_args, expr const * args, unsigned arg_idx, unsigned num_patterns, unsigned num_e_args) {
lean_assert(num_args == 4);
macro_definition m(new rewrite_assoc_macro_definition_cell(arg_idx, num_patterns, num_e_args));
return mk_macro(m, num_args, args);
}
// Setup and teardown
void initialize_simp_util() {
// flat macro
g_flat_macro_name = new name("flat");
g_flat_opcode = new std::string("FLAT");
register_macro_deserializer(*g_flat_opcode,
[](deserializer & /* d */, unsigned num, expr const * args) {
if (num != 2)
throw corrupted_stream_exception();
return mk_flat_macro(num, args);
});
// rewrite_assoc macro
g_rewrite_assoc_macro_name = new name("rewrite_assoc");
g_rewrite_assoc_opcode = new std::string("REWRITE_ASSOC");
register_macro_deserializer(*g_rewrite_assoc_opcode,
[](deserializer & d, unsigned num, expr const * args) {
if (num != 4)
throw corrupted_stream_exception();
unsigned arg_idx, num_patterns, num_e_args;
d >> arg_idx;
d >> num_patterns;
d >> num_e_args;
return mk_rewrite_assoc_macro(num, args, arg_idx, num_patterns, num_e_args);
});
}
void finalize_simp_util() {
// rewrite_assoc macro
delete g_rewrite_assoc_macro_name;
delete g_rewrite_assoc_opcode;
// flat macro
delete g_flat_macro_name;
delete g_flat_opcode;
}
}