/* Copyright (c) 2016 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. Author: Daniel Selsam */ #include #include "util/sstream.h" #include "kernel/expr.h" #include "kernel/instantiate.h" #include "library/kernel_serializer.h" #include "library/util.h" #include "library/constants.h" #include "library/app_builder.h" #include "library/tactic/ac_tactics.h" #include "library/tactic/simplifier/util.h" namespace lean { static void get_app_nary_args_core(type_context * tctx_ptr, expr const & op, expr const & e, buffer & nary_args, bool unsafe) { lean_assert(unsafe || tctx_ptr); auto next_op = get_binary_op(e); if (next_op && (unsafe ? (get_app_fn(*next_op) == get_app_fn(op)) : tctx_ptr->is_def_eq(*next_op, op))) { get_app_nary_args_core(tctx_ptr, op, app_arg(app_fn(e)), nary_args, unsafe); get_app_nary_args_core(tctx_ptr, op, app_arg(e), nary_args, unsafe); } else { nary_args.push_back(e); } } void unsafe_get_app_nary_args(expr const & op, expr const & e, buffer & nary_args) { bool unsafe = true; get_app_nary_args_core(nullptr, op, app_arg(app_fn(e)), nary_args, unsafe); get_app_nary_args_core(nullptr, op, app_arg(e), nary_args, unsafe); } void get_app_nary_args(type_context & tctx, expr const & op, expr const & e, buffer & nary_args) { bool unsafe = false; get_app_nary_args_core(&tctx, op, app_arg(app_fn(e)), nary_args, unsafe); get_app_nary_args_core(&tctx, op, app_arg(e), nary_args, unsafe); } optional > is_assoc(type_context & tctx, expr const & e) { auto op = get_binary_op(e); if (!op) return optional >(); try { expr assoc_class = mk_app(tctx, get_is_associative_name(), *op); if (auto assoc_inst = tctx.mk_class_instance(assoc_class)) return optional >(mk_pair(mk_app(tctx, get_is_associative_op_assoc_name(), 3, *op, *assoc_inst), *op)); else return optional >(); } catch (app_builder_exception ex) { return optional >(); } } expr mk_congr_bin_op(abstract_type_context & tctx, expr const & H, expr const & arg1, expr const & arg2) { expr eq = tctx.relaxed_whnf(tctx.infer(H)); expr A_op, op1, op2; lean_verify(is_eq(eq, A_op, op1, op2)); lean_assert(is_pi(A_op)); expr A = binding_domain(A_op); level lvl = get_level(tctx, A); return ::lean::mk_app({mk_constant(get_simplifier_congr_bin_op_name(), {lvl}), A, op1, op2, H, arg1, arg2}); } expr mk_congr_bin_arg1(abstract_type_context & tctx, expr const & op, expr const & H1, expr const & arg2) { expr eq = tctx.relaxed_whnf(tctx.infer(H1)); expr A, arg11, arg12; lean_verify(is_eq(eq, A, arg11, arg12)); level lvl = get_level(tctx, A); return ::lean::mk_app({mk_constant(get_simplifier_congr_bin_arg1_name(), {lvl}), A, op, arg11, arg12, arg2, H1}); } expr mk_congr_bin_arg2(abstract_type_context & tctx, expr const & op, expr const & arg1, expr const & H2) { expr eq = tctx.relaxed_whnf(tctx.infer(H2)); expr A, arg21, arg22; lean_verify(is_eq(eq, A, arg21, arg22)); level lvl = get_level(tctx, A); return ::lean::mk_app({mk_constant(get_simplifier_congr_bin_arg2_name(), {lvl}), A, op, arg1, arg21, arg22, H2}); } expr mk_congr_bin_args(abstract_type_context & tctx, expr const & op, expr const & H1, expr const & H2) { expr eq1 = tctx.relaxed_whnf(tctx.infer(H1)); expr eq2 = tctx.relaxed_whnf(tctx.infer(H2)); expr A, A0, arg11, arg12, arg21, arg22; lean_verify(is_eq(eq1, A, arg11, arg12)); lean_verify(is_eq(eq2, A0, arg21, arg22)); lean_assert(tctx.is_def_eq(A, A0)); level lvl = get_level(tctx, A); return ::lean::mk_app({mk_constant(get_simplifier_congr_bin_args_name(), {lvl}), A, op, arg11, arg12, arg21, arg22, H1, H2}); } expr mk_assoc_subst(abstract_type_context & tctx, expr const & old_op, expr const & new_op, expr const & pf_op, expr const & assoc) { expr A_op = tctx.relaxed_whnf(tctx.infer(new_op)); lean_assert(is_pi(A_op)); expr A = binding_domain(A_op); level lvl = get_level(tctx, A); return ::lean::mk_app({mk_constant(get_simplifier_assoc_subst_name(), {lvl}), A, old_op, new_op, pf_op, assoc}); } // flat macro static name * g_flat_macro_name = nullptr; static std::string * g_flat_opcode = nullptr; class flat_macro_definition_cell : public macro_definition_cell { void check_macro(expr const & m) const { if (!is_macro(m) || macro_num_args(m) != 2) throw exception(sstream() << "invalid 'flat' macro, incorrect number of arguments"); } public: flat_macro_definition_cell() {} virtual name get_name() const override { return *g_flat_macro_name; } virtual expr check_type(expr const & m, abstract_type_context &, bool) const override { check_macro(m); return macro_arg(m, 1); } virtual optional expand(expr const & m, abstract_type_context & tctx) const override { check_macro(m); expr const & assoc = macro_arg(m, 0); expr const & thm = macro_arg(m, 1); expr old_e = app_arg(app_fn(thm)); expr new_e = app_arg(thm); optional op = get_binary_op(old_e); lean_assert(op); pair> r_assoc = flat_assoc(tctx, *op, assoc, old_e); optional const & pf_of_assoc = r_assoc.second; if (!pf_of_assoc) return some_expr(mk_eq_refl(tctx, old_e)); else return pf_of_assoc; } virtual void write(serializer & s) const override { s.write_string(*g_flat_opcode); } virtual bool operator==(macro_definition_cell const & other) const override { if (dynamic_cast(&other)) { return true; } else { return false; } } virtual unsigned hash() const override { return get_name().hash(); } }; // Rewrite-assoc macro static name * g_rewrite_assoc_macro_name = nullptr; static std::string * g_rewrite_assoc_opcode = nullptr; class rewrite_assoc_macro_definition_cell : public macro_definition_cell { unsigned m_arg_idx; unsigned m_num_patterns; unsigned m_num_args; void check_macro(expr const & m) const { if (!is_macro(m) || macro_num_args(m) != 4) throw exception(sstream() << "invalid 'rewrite_assoc' macro, incorrect number of arguments"); } public: rewrite_assoc_macro_definition_cell(unsigned arg_idx, unsigned num_patterns, unsigned num_args): m_arg_idx(arg_idx), m_num_patterns(num_patterns), m_num_args(num_args) {} virtual name get_name() const { return *g_rewrite_assoc_macro_name; } virtual expr check_type(expr const & m, abstract_type_context &, bool) const { check_macro(m); return macro_arg(m, 1); } pair > group_args(expr const & e, unsigned pat_idx) const { if (pat_idx + 1 < m_num_patterns) { pair > p = group_args(app_arg(e), pat_idx + 1); return mk_pair(mk_app(app_fn(e), p.first), p.second); } else { lean_assert(pat_idx + 1 == m_num_patterns); if (m_arg_idx + m_num_patterns == m_num_args) { return mk_pair(e, none_expr()); } else { return mk_pair(app_arg(app_fn(e)), some_expr(app_arg(e))); } } } expr compute_pre_motive(abstract_type_context & tctx, expr const & e, expr & l, expr & lhs, unsigned i) const { if (i == m_arg_idx) { // (lhs, rest) l = tctx.push_local(name("lhs"), tctx.infer(e)); pair > lhs_rest = group_args(e, 0); lhs = lhs_rest.first; if (lhs_rest.second) return mk_app(app_fn(app_fn(e)), l, *lhs_rest.second); else return l; } else { return mk_app(app_fn(e), compute_pre_motive(tctx, app_arg(e), l, lhs, i+1)); } } virtual optional expand(expr const & m, abstract_type_context & tctx) const { check_macro(m); expr const & assoc = macro_arg(m, 0); expr const & thm = macro_arg(m, 1); /* expr const & step_rhs = macro_arg(m, 2); */ expr pf_of_step = macro_arg(m, 3); expr const & old_e = app_arg(app_fn(thm)); /* expr const & new_e = app_arg(thm); */ expr op = app_fn(app_fn(old_e)); // Step 1: Re-arrange to group the args being rewritten expr local; expr step_lhs; expr pre_motive = tctx.abstract_locals(compute_pre_motive(tctx, old_e, local, step_lhs, 0), 1, &local); expr middle_e = instantiate(pre_motive, step_lhs); expr motive = mk_lambda(mlocal_name(local), tctx.infer(local), mk_app(app_fn(app_fn(thm)), middle_e, pre_motive)); optional pf_flat = flat_assoc(tctx, op, assoc, middle_e).second; simp_result r_middle; if (pf_flat) r_middle = simp_result(middle_e, mk_eq_symm(tctx, *pf_flat)); else r_middle = simp_result(middle_e); expr thm_of_step = tctx.infer(pf_of_step); if (auto pf_step_lhs_not_flat = flat_assoc(tctx, op, assoc, app_arg(app_fn(thm_of_step))).second) { // lemma needs to be flattened expr l = tctx.push_local(name("pf_lhs"), tctx.infer(old_e)); expr motive = mk_lambda(mlocal_name(l), tctx.infer(l), tctx.abstract_locals(mk_app(app_fn(app_fn(thm)), l, app_arg(thm_of_step)), 1, &l)); pf_of_step = mk_eq_subst(tctx, motive, *pf_step_lhs_not_flat, pf_of_step); } if (r_middle.has_proof()) return some_expr(mk_eq_trans(tctx, r_middle.get_proof(), mk_eq_subst(tctx, motive, pf_of_step, mk_eq_refl(tctx, r_middle.get_new())))); else return some_expr(mk_eq_subst(tctx, motive, pf_of_step, mk_eq_refl(tctx, r_middle.get_new()))); } virtual void write(serializer & s) const { s.write_string(*g_rewrite_assoc_opcode); s << m_arg_idx; s << m_num_patterns; s << m_num_args; } virtual bool operator==(macro_definition_cell const & other) const { if (auto other_ptr = dynamic_cast(&other)) { return m_arg_idx == other_ptr->m_arg_idx && m_num_patterns == other_ptr->m_num_patterns && m_num_args == other_ptr->m_num_args; } else { return false; } } virtual unsigned hash() const { return ::lean::hash(m_arg_idx, ::lean::hash(m_num_patterns, ::lean::hash(m_num_args, get_name().hash()))); } }; expr mk_flat_proof(expr const & assoc, expr const & thm) { expr margs[3]; margs[0] = assoc; margs[1] = thm; macro_definition m(new flat_macro_definition_cell()); return mk_macro(m, 2, margs); } expr mk_flat_macro(unsigned num_args, expr const * args) { lean_assert(num_args == 2); macro_definition m(new flat_macro_definition_cell()); return mk_macro(m, num_args, args); } expr mk_rewrite_assoc_proof(expr const & assoc, expr const & thm, unsigned arg_idx, unsigned num_patterns, unsigned num_args, expr const & step_rhs, expr const & pf_of_step) { expr margs[4]; margs[0] = assoc; margs[1] = thm; margs[2] = step_rhs; margs[3] = pf_of_step; macro_definition m(new rewrite_assoc_macro_definition_cell(arg_idx, num_patterns, num_args)); return mk_macro(m, 4, margs); } expr mk_rewrite_assoc_macro(unsigned num_args, expr const * args, unsigned arg_idx, unsigned num_patterns, unsigned num_e_args) { lean_assert(num_args == 4); macro_definition m(new rewrite_assoc_macro_definition_cell(arg_idx, num_patterns, num_e_args)); return mk_macro(m, num_args, args); } // Setup and teardown void initialize_simp_util() { // flat macro g_flat_macro_name = new name("flat"); g_flat_opcode = new std::string("FLAT"); register_macro_deserializer(*g_flat_opcode, [](deserializer & /* d */, unsigned num, expr const * args) { if (num != 2) throw corrupted_stream_exception(); return mk_flat_macro(num, args); }); // rewrite_assoc macro g_rewrite_assoc_macro_name = new name("rewrite_assoc"); g_rewrite_assoc_opcode = new std::string("REWRITE_ASSOC"); register_macro_deserializer(*g_rewrite_assoc_opcode, [](deserializer & d, unsigned num, expr const * args) { if (num != 4) throw corrupted_stream_exception(); unsigned arg_idx, num_patterns, num_e_args; d >> arg_idx; d >> num_patterns; d >> num_e_args; return mk_rewrite_assoc_macro(num, args, arg_idx, num_patterns, num_e_args); }); } void finalize_simp_util() { // rewrite_assoc macro delete g_rewrite_assoc_macro_name; delete g_rewrite_assoc_opcode; // flat macro delete g_flat_macro_name; delete g_flat_opcode; } }