From 7fa2b7cace836a907e476a1f607e060451f4f5de Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 29 Nov 2015 06:40:19 -0700 Subject: [PATCH] feat(library/blast/forward/ematch): ematching skeleton --- src/library/blast/congruence_closure.cpp | 8 + src/library/blast/congruence_closure.h | 3 + src/library/blast/forward/CMakeLists.txt | 3 +- src/library/blast/forward/ematch.cpp | 300 ++++++++++++++++++++++ src/library/blast/forward/ematch.h | 15 ++ src/library/blast/forward/init_module.cpp | 3 + src/library/blast/forward/pattern.cpp | 4 + src/library/blast/forward/pattern.h | 4 + 8 files changed, 339 insertions(+), 1 deletion(-) create mode 100644 src/library/blast/forward/ematch.cpp create mode 100644 src/library/blast/forward/ematch.h diff --git a/src/library/blast/congruence_closure.cpp b/src/library/blast/congruence_closure.cpp index a1912d9f9a..06f0a93c9d 100644 --- a/src/library/blast/congruence_closure.cpp +++ b/src/library/blast/congruence_closure.cpp @@ -1256,6 +1256,14 @@ expr congruence_closure::get_next(name const & R, expr const & e) const { } } +unsigned congruence_closure::get_mt(name const & R, expr const & e) const { + if (auto n = m_entries.find(eqc_key(R, e))) { + return n->m_mt; + } else { + return m_gmt; + } +} + void congruence_closure::freeze_partitions() { m_froze_partitions = true; entries new_entries; diff --git a/src/library/blast/congruence_closure.h b/src/library/blast/congruence_closure.h index e62107f591..f221963a0a 100644 --- a/src/library/blast/congruence_closure.h +++ b/src/library/blast/congruence_closure.h @@ -216,6 +216,9 @@ public: void inc_gmt() { m_gmt++; } + unsigned get_gmt() const { return m_gmt; } + unsigned get_mt(name const & R, expr const & e) const; + /** \brief dump for debugging purposes. */ void display() const; void display_eqcs() const; diff --git a/src/library/blast/forward/CMakeLists.txt b/src/library/blast/forward/CMakeLists.txt index 0ad1d170f7..d13aaeeb5c 100644 --- a/src/library/blast/forward/CMakeLists.txt +++ b/src/library/blast/forward/CMakeLists.txt @@ -1 +1,2 @@ -add_library(forward OBJECT init_module.cpp forward_extension.cpp qcf.cpp pattern.cpp) +add_library(forward OBJECT init_module.cpp forward_extension.cpp qcf.cpp pattern.cpp + ematch.cpp) diff --git a/src/library/blast/forward/ematch.cpp b/src/library/blast/forward/ematch.cpp new file mode 100644 index 0000000000..2b8077fb06 --- /dev/null +++ b/src/library/blast/forward/ematch.cpp @@ -0,0 +1,300 @@ +/* +Copyright (c) 2015 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#include "library/constants.h" +#include "library/blast/blast.h" +#include "library/blast/trace.h" +#include "library/blast/congruence_closure.h" +#include "library/blast/forward/pattern.h" + +namespace lean { +namespace blast { +/* +When a hypothesis hidx is activated: +1- Traverse its type and for each f-application. + If it is the first f-application found, and f is a constant then + retrieve lemmas which contain a multi-pattern starting with f. + +2- If hypothesis is a proposition and a quantifier, +try to create a hi-lemma for it, and add it to +set of recently activated hi_lemmas + +E-match round action + +1- For each active hi-lemma L, and mulit-pattern P, + If L has been recently activated, then we ematch ignoring + gmt. + + If L has been processed before, we try to ematch starting + at each each element of the multi-pattern. + We only consider the head f-applications that have a mt + equal to gmt + +*/ +typedef rb_tree expr_set; +typedef rb_tree hi_lemma_set; +static unsigned g_ext_id = 0; +struct ematch_branch_extension : public branch_extension { + hi_lemma_set m_lemmas; + hi_lemma_set m_new_lemmas; + rb_map m_apps; + name_set m_initialized; + + ematch_branch_extension() {} + ematch_branch_extension(ematch_branch_extension const &) {} + + void collect_apps(expr const & e) { + switch (e.kind()) { + case expr_kind::Var: case expr_kind::Sort: + case expr_kind::Constant: case expr_kind::Meta: + case expr_kind::Local: case expr_kind::Lambda: + break; + case expr_kind::Pi: + if (is_arrow(e) && is_prop(e)) { + collect_apps(binding_domain(e)); + collect_apps(binding_body(e)); + } + break; + case expr_kind::Macro: + for (unsigned i = 0; i < macro_num_args(e); i++) + collect_apps(macro_arg(e, i)); + break; + case expr_kind::App: { + buffer args; + expr const & f = get_app_args(e, args); + if (is_constant(f) && !m_initialized.contains(const_name(f))) { + m_initialized.insert(const_name(f)); + if (auto lemmas = get_hi_lemma_index(env()).find(const_name(f))) { + for (hi_lemma const & lemma : *lemmas) { + m_new_lemmas.insert(lemma); + } + } + } + if ((is_constant(f) && !is_no_pattern(env(), const_name(f))) || + (is_local(f))) { + expr_set s; + if (auto old_s = m_apps.find(f)) + s = *old_s; + s.insert(e); + m_apps.insert(f, s); + } + for (expr const & arg : args) { + collect_apps(arg); + } + break; + }} + } + + void register_lemma(hypothesis const & h) { + if (is_pi(h.get_type()) && !is_arrow(h.get_type())) { + blast_tmp_type_context ctx; + try { + m_new_lemmas.insert(mk_hi_lemma(*ctx, h.get_self())); + } catch (exception &) {} + } + } + + virtual ~ematch_branch_extension() {} + virtual branch_extension * clone() override { return new ematch_branch_extension(*this); } + virtual void initialized() override {} + virtual void hypothesis_activated(hypothesis const & h, hypothesis_idx) override { + collect_apps(h.get_type()); + register_lemma(h); + } + virtual void hypothesis_deleted(hypothesis const &, hypothesis_idx) override {} + virtual void target_updated() override { collect_apps(curr_state().get_target()); } +}; + +void initialize_ematch() { + g_ext_id = register_branch_extension(new ematch_branch_extension()); +} + +void finalize_ematch() {} + +struct ematch_fn { + ematch_branch_extension & m_ext; + blast_tmp_type_context m_ctx; + congruence_closure & m_cc; + + enum frame_kind { DefEqOnly, Match, Continue }; + + typedef std::tuple entry; + typedef list state; + typedef list choice; + + state m_state; + buffer m_choice_stack; + + bool m_new_instances; + + ematch_fn(): + m_ext(static_cast(curr_state().get_extension(g_ext_id))), + m_cc(get_cc()), + m_new_instances(false) { + } + + bool is_done() const { + return !m_state; + } + + bool is_eqv(name const & R, expr const & p, expr const & t) { + if (!has_expr_metavar(p)) + return m_cc.is_eqv(R, p, t) || m_ctx->is_def_eq(p, t); + else + return m_ctx->is_def_eq(p, t); + } + + bool process_match(name const & R, expr const & p, expr const & t) { + if (!is_app(p)) + return is_eqv(R, p, t); + buffer p_args; + expr const & fn = get_app_args(p, p_args); + if (m_ctx->is_mvar(fn)) + return is_eqv(R, p, t); + buffer candidates; + expr it = t; + do { + if (m_cc.is_congr_root(R, t) && m_ctx->is_def_eq(get_app_fn(it), fn) && + get_app_num_args(it) == p_args.size()) { + candidates.push_back(it); + } + it = m_cc.get_next(R, it); + } while (it != t); + if (candidates.empty()) + return false; + optional lemma = mk_ext_congr_lemma(R, fn, p_args.size()); + if (!lemma) + return false; + buffer new_states; + for (expr const & c : candidates) { + buffer c_args; + get_app_args(c, c_args); + lean_assert(c_args.size() == p_args.size()); + state new_state = m_state; + auto const * r_names = &lemma->m_rel_names; + for (unsigned i = 0; i < p_args.size(); i++) { + lean_assert(*r_names); + if (auto Rc = head(*r_names)) { + new_state = cons(entry(*Rc, Match, p_args[i], c_args[i]), new_state); + + } else { + new_state = cons(entry(get_eq_name(), DefEqOnly, p_args[i], c_args[i]), new_state); + } + r_names = &tail(*r_names); + } + new_states.push_back(new_state); + } + lean_assert(candidates.size() == new_states.size()); + if (candidates.size() == 1) { + m_state = new_states[0]; + return true; + } else { + m_state = new_states.back(); + new_states.pop_back(); + choice c = to_list(new_states); + m_choice_stack.push_back(c); + m_ctx->push(); + return true; + } + } + + bool process_continue(expr const &) { + // TODO(Leo): + return false; + } + + bool process_next() { + lean_assert(!is_done()); + name R; frame_kind kind; expr p, t; + std::tie(R, kind, p, t) = head(m_state); + m_state = tail(m_state); + switch (kind) { + case DefEqOnly: + return m_ctx->is_def_eq(p, t); + case Match: + return process_match(R, p, t); + case Continue: + return process_continue(p); + } + lean_unreachable(); + } + + bool match() { + // TODO(Leo) + return false; + } + + void instantiate_lemma_using(hi_lemma const & lemma, buffer const & ps, bool filter) { + expr const & p0 = ps[0]; + expr const & f = get_app_fn(p0); + name const & R = is_prop(p0) ? get_iff_name() : get_eq_name(); + unsigned gmt = m_cc.get_gmt(); + if (auto s = m_ext.m_apps.find(f)) { + s->for_each([&](expr const & t) { + if (m_cc.is_congr_root(R, t) && (!filter || m_cc.get_mt(R, t) == gmt)) { + m_ctx->set_next_uvar_idx(lemma.m_num_uvars); + m_ctx->set_next_mvar_idx(lemma.m_num_mvars); + state s; + unsigned i = ps.size(); + while (i > 1) { + --i; + s = cons(entry(name(), Continue, ps[i], expr()), s); + } + s = cons(entry(R, Match, p0, t), s); + diagnostic(env(), ios()) << "ematch " << ppb(p0) << " =?= " << ppb(t) << "\n"; + if (match()) { + // TODO(Leo): add instance + } + } + }); + } + } + + void instantiate_lemma_using(hi_lemma const & lemma, multi_pattern const & mp, bool filter) { + buffer ps; + to_buffer(mp, ps); + if (filter) { + for (unsigned i = 0; i < ps.size(); i++) { + std::swap(ps[0], ps[i]); + instantiate_lemma_using(lemma, ps, filter); + std::swap(ps[0], ps[i]); + } + } else { + instantiate_lemma_using(lemma, ps, filter); + } + } + + void instantiate_lemma(hi_lemma const & lemma, bool filter) { + for (multi_pattern const & mp : lemma.m_multi_patterns) { + instantiate_lemma_using(lemma, mp, filter); + } + } + + /* (Try to) instantiate lemmas in \c s. If \c filter is true, then use gmt optimization. */ + void instantiate_lemmas(hi_lemma_set const & s, bool filter) { + s.for_each([&](hi_lemma const & l) { + instantiate_lemma(l, filter); + }); + } + + action_result operator()() { + instantiate_lemmas(m_ext.m_new_lemmas, false); + instantiate_lemmas(m_ext.m_lemmas, true); + m_ext.m_lemmas.merge(m_ext.m_new_lemmas); + m_ext.m_new_lemmas = hi_lemma_set(); + m_cc.inc_gmt(); + if (m_new_instances) { + return action_result::new_branch(); + } else { + return action_result::failed(); + } + } +}; + +action_result ematch_action() { + return ematch_fn()(); +} +}} diff --git a/src/library/blast/forward/ematch.h b/src/library/blast/forward/ematch.h new file mode 100644 index 0000000000..5f1d206c8c --- /dev/null +++ b/src/library/blast/forward/ematch.h @@ -0,0 +1,15 @@ +/* +Copyright (c) 2015 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#pragma once +#include "library/blast/action_result.h" + +namespace lean { +namespace blast { +action_result ematch_action(); +void initialize_ematch(); +void finalize_ematch(); +}} diff --git a/src/library/blast/forward/init_module.cpp b/src/library/blast/forward/init_module.cpp index 5b5ba77d7f..d446108001 100644 --- a/src/library/blast/forward/init_module.cpp +++ b/src/library/blast/forward/init_module.cpp @@ -6,6 +6,7 @@ Author: Daniel Selsam #include "library/blast/forward/init_module.h" #include "library/blast/forward/forward_extension.h" #include "library/blast/forward/pattern.h" +#include "library/blast/forward/ematch.h" namespace lean { namespace blast { @@ -13,9 +14,11 @@ namespace blast { void initialize_forward_module() { initialize_forward_extension(); initialize_pattern(); + initialize_ematch(); } void finalize_forward_module() { + finalize_ematch(); finalize_pattern(); finalize_forward_extension(); } diff --git a/src/library/blast/forward/pattern.cpp b/src/library/blast/forward/pattern.cpp index 638776e72b..544352de8e 100644 --- a/src/library/blast/forward/pattern.cpp +++ b/src/library/blast/forward/pattern.cpp @@ -656,6 +656,10 @@ hi_lemma const * get_hi_lemma(environment const & env, name const & c) { return hi_ext::get_state(env).m_name_to_lemma.find(c); } +hi_lemmas get_hi_lemma_index(environment const & env) { + return hi_ext::get_state(env).m_lemmas; +} + void initialize_pattern() { g_hi_name = new name("hi"); g_key = new std::string("HI"); diff --git a/src/library/blast/forward/pattern.h b/src/library/blast/forward/pattern.h index 53e09e9ab2..68f0ec70b0 100644 --- a/src/library/blast/forward/pattern.h +++ b/src/library/blast/forward/pattern.h @@ -7,6 +7,7 @@ Author: Leonardo de Moura #pragma once #include "util/rb_multi_map.h" #include "kernel/environment.h" +#include "library/expr_lt.h" #include "library/tmp_type_context.h" #ifndef LEAN_HI_LEMMA_DEFAULT_PRIORITY @@ -45,6 +46,9 @@ struct hi_lemma { inline bool operator==(hi_lemma const & l1, hi_lemma const & l2) { return l1.m_prop == l2.m_prop; } inline bool operator!=(hi_lemma const & l1, hi_lemma const & l2) { return l1.m_prop != l2.m_prop; } +struct hi_lemma_cmp { + int operator()(hi_lemma const & l1, hi_lemma const & l2) const { return expr_quick_cmp()(l1.m_prop, l2.m_prop); } +}; /** \brief Mapping c -> S, where c is a constant name and S is a set of hi_lemmas that contain a pattern where the head symbol is c. */