diff --git a/src/library/blast/unit/unit_propagate.cpp b/src/library/blast/unit/unit_propagate.cpp index fd8065dbab..befe4cdaa1 100644 --- a/src/library/blast/unit/unit_propagate.cpp +++ b/src/library/blast/unit/unit_propagate.cpp @@ -9,6 +9,7 @@ Author: Daniel Selsam #include "library/constants.h" #include "library/util.h" #include "library/blast/blast.h" +#include "library/blast/trace.h" #include "library/blast/action_result.h" #include "library/blast/unit/unit_actions.h" #include "library/blast/proof_expr.h" @@ -20,8 +21,7 @@ Author: Daniel Selsam namespace lean { namespace blast { - -bool is_lemma(expr const & _type) { +static bool is_lemma(expr const & _type) { expr type = _type; bool has_antecedent = false; if (!is_prop(type)) return false; @@ -34,11 +34,20 @@ bool is_lemma(expr const & _type) { else return false; } -bool is_fact(expr const & type) { - return !is_lemma(type); +/** \brief We say \c type is a dependent lemma iff + - \c type is a proposition + - \c type is of the form (Pi (x : A), B) + - 'A' is a proposition + - 'B' depends on 'x' */ +static bool is_dep_lemma(expr const & type) { + return is_pi(type) && is_prop(type) && is_prop(binding_domain(type)) && !closed(binding_body(type)); } -expr flip(expr const & e) { +static bool is_fact(expr const & type) { + return !is_lemma(type) && !is_dep_lemma(type); +} + +static expr flip(expr const & e) { expr not_e; if (!blast::is_not(e, not_e)) { // we use whnf to make sure we get a uniform representation @@ -53,12 +62,15 @@ struct unit_branch_extension : public branch_extension { /* We map each lemma to the two facts that it is watching. */ rb_multi_map m_lemmas_to_facts; - /* We map each fact back to the lemma hypotheses that are watching it. */ + /* We map each fact (i.e., the type of the hypothesis) back to the lemma hypotheses that are watching it. */ rb_multi_map m_facts_to_lemmas; - /* We map each fact expression to its hypothesis. */ + /* We map each fact expression (i.e., the type of the hypothesis) to its hypothesis. */ rb_map m_facts; + /* We map each fact (i.e., the type of the hypothesis) back to the dependent lemma hypotheses that are watching it. */ + rb_multi_map m_facts_to_dep_lemmas; + unit_branch_extension() {} unit_branch_extension(unit_branch_extension const & b): m_lemmas_to_facts(b.m_lemmas_to_facts), @@ -77,9 +89,16 @@ struct unit_branch_extension : public branch_extension { unwatch(hidx, fact); }); } + } else if (is_dep_lemma(h.get_type())) { + expr fact_type = binding_domain(h.get_type()); + m_facts_to_lemmas.erase(fact_type, hidx); } else if (is_fact(h.get_type())) { m_facts.erase(h.get_type()); + // TODO(Leo): it is not clear to me why the following assertion should hold. + // I think we need + // m_facts_to_lemmas.erase(h.get_type()) lean_assert(!find_lemmas_watching_fact(h.get_type())); + m_facts_to_dep_lemmas.erase(h.get_type()); } } @@ -103,6 +122,9 @@ public: m_lemmas_to_facts.insert(lemma_hidx, fact_type); m_facts_to_lemmas.insert(fact_type, lemma_hidx); } + list const * find_dep_lemmas_watching_fact(expr const & fact_type) { + return m_facts_to_dep_lemmas.find(fact_type); + } }; void initialize_unit_propagate() { @@ -136,7 +158,7 @@ static unit_branch_extension & get_extension() { */ -bool can_propagate(expr const & _type, buffer & to_watch) { +static bool can_propagate(expr const & _type, buffer & to_watch) { lean_assert(is_lemma(_type)); expr type = _type; unsigned num_watching = 0; @@ -178,7 +200,7 @@ bool can_propagate(expr const & _type, buffer & to_watch) { return !missing_non_Prop; } -action_result unit_lemma(hypothesis_idx hidx, expr const & _type, expr const & _proof) { +static action_result unit_lemma(hypothesis_idx hidx, expr const & _type, expr const & _proof) { lean_assert(is_lemma(_type)); unit_branch_extension & ext = get_extension(); @@ -270,16 +292,51 @@ action_result unit_lemma(hypothesis_idx hidx, expr const & _type, expr const & _ return action_result::new_branch(); } -action_result unit_fact(expr const & type) { +static action_result unit_dep_lemma(hypothesis_idx hidx, expr type, expr proof) { + lean_assert(is_dep_lemma(type)); + unit_branch_extension & ext = get_extension(); + bool propagated = false; + while (is_pi(type)) { + expr d = binding_domain(type); + if (auto hidx = ext.find_fact(d)) { + propagated = true; + expr h = mk_href(*hidx); + proof = mk_app(proof, h); + type = instantiate(binding_body(type), h); + } else { + break; + } + } + if (propagated) { + curr_state().del_hypothesis(hidx); + curr_state().mk_hypothesis(type, proof); + return action_result::new_branch(); + } + lean_assert(is_pi(type)); + ext.m_facts_to_dep_lemmas.insert(binding_domain(type), hidx); + return action_result::failed(); +} + +static action_result unit_fact(expr const & type) { unit_branch_extension & ext = get_extension(); - list const * lemmas = ext.find_lemmas_watching_fact(type); - if (!lemmas) return action_result::failed(); bool success = false; - for_each(*lemmas, [&](hypothesis_idx const & hidx) { - hypothesis const & h = curr_state().get_hypothesis_decl(hidx); - action_result r = unit_lemma(hidx, whnf(h.get_type()), h.get_self()); - success = success || (r.get_kind() == action_result::NewBranch); - }); + /* non dependent lemmas */ + if (list const * lemmas = ext.find_lemmas_watching_fact(type)) { + for_each(*lemmas, [&](hypothesis_idx const & hidx) { + hypothesis const & h = curr_state().get_hypothesis_decl(hidx); + // TODO(Leo): it is not clear to me why we need whnf in the following statement. + action_result r = unit_lemma(hidx, whnf(h.get_type()), h.get_self()); + success = success || (r.get_kind() == action_result::NewBranch); + }); + } + /* dependent lemmas */ + if (list const * lemmas = ext.find_dep_lemmas_watching_fact(type)) { + for_each(*lemmas, [&](hypothesis_idx const & hidx) { + hypothesis const & h = curr_state().get_hypothesis_decl(hidx); + action_result r = unit_dep_lemma(hidx, whnf(h.get_type()), h.get_self()); + success = success || (r.get_kind() == action_result::NewBranch); + }); + } if (success) return action_result::new_branch(); else return action_result::failed(); } @@ -288,6 +345,7 @@ action_result unit_propagate(unsigned hidx) { hypothesis const & h = curr_state().get_hypothesis_decl(hidx); expr type = whnf(h.get_type()); if (is_lemma(type)) return unit_lemma(hidx, type, h.get_self()); + else if (is_dep_lemma(type)) return unit_dep_lemma(hidx, type, h.get_self()); else if (is_fact(type)) return unit_fact(type); else return action_result::failed(); } diff --git a/tests/lean/run/blast_unit_dep.lean b/tests/lean/run/blast_unit_dep.lean new file mode 100644 index 0000000000..9a4f323015 --- /dev/null +++ b/tests/lean/run/blast_unit_dep.lean @@ -0,0 +1,22 @@ +constant p : nat → Prop +constant q : Π a, p a → Prop + +set_option blast.strategy "unit" + +example (a : nat) (h₁ : p a) (h₂ : ∀ h : p a, q a h) : q a h₁ := +by blast + +example (a : nat) (h₂ : ∀ h : p a, q a h) (h₁ : p a) : q a h₁ := +by blast + +example (a b : nat) (H : ∀ (p₁ : p a) (p₂ : p b), q b p₂ → q a p₁) (h₁ : p a) (h₂ : p b) : q b h₂ → q a h₁ := +by blast + +example (a b : nat) (h₂ : p b) (H : ∀ (p₁ : p a) (p₂ : p b), q b p₂ → q a p₁) (h₁ : p a) : q b h₂ → q a h₁ := +by blast + +example (a b : nat) (h₂ : p b) (h₁ : p a) (H : ∀ (p₁ : p a) (p₂ : p b), q b p₂ → q a p₁) : q b h₂ → q a h₁ := +by blast + +example (a b : nat) (h₁ : p a) (H : ∀ (p₁ : p a) (p₂ : p b), q b p₂ → q a p₁) (h₂ : p b) : q b h₂ → q a h₁ := +by blast