From 7e4b79b2214aeefddbaa056b5153d6537fd17ff4 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Fri, 6 Jan 2017 00:24:25 -0800 Subject: [PATCH] feat(library/tactic/smt/smt_state): add ematch_using tactic --- library/init/meta/smt/ematch.lean | 3 + library/init/meta/smt/interactive.lean | 28 ++++++ library/init/meta/smt/smt_tactic.lean | 18 +++- src/library/tactic/smt/hinst_lemmas.h | 4 +- src/library/tactic/smt/smt_state.cpp | 126 ++++++++++++++----------- src/library/tactic/smt/smt_state.h | 1 + tests/lean/run/smt_ematch1.lean | 16 ++++ 7 files changed, 136 insertions(+), 60 deletions(-) diff --git a/library/init/meta/smt/ematch.lean b/library/init/meta/smt/ematch.lean index e2f3b7bbc9..e8ecc8b544 100644 --- a/library/init/meta/smt/ematch.lean +++ b/library/init/meta/smt/ematch.lean @@ -35,6 +35,9 @@ meta constant hinst_lemmas.add : hinst_lemmas → hinst_lemma → hi meta constant hinst_lemmas.fold {α : Type} : hinst_lemmas → α → (hinst_lemma → α → α) → α meta constant hinst_lemmas.merge : hinst_lemmas → hinst_lemmas → hinst_lemmas +meta def mk_hinst_singleton : hinst_lemma → hinst_lemmas := +hinst_lemmas.add hinst_lemmas.mk + meta def hinst_lemmas.pp (s : hinst_lemmas) : tactic format := let tac := s^.fold (return format.nil) (λ h tac, do diff --git a/library/init/meta/smt/interactive.lean b/library/init/meta/smt/interactive.lean index f98ed3e9e9..4dfb4355f0 100644 --- a/library/init/meta/smt/interactive.lean +++ b/library/init/meta/smt/interactive.lean @@ -147,6 +147,34 @@ add_eqn_lemmas_for_core reducible ids meta def add_eqn_lemmas (ids : raw_ident_list) : smt_tactic unit := add_eqn_lemmas_for ids +private meta def add_hinst_lemma_from_name (md : transparency) (lhs_lemma : bool) (n : name) (hs : hinst_lemmas) : smt_tactic hinst_lemmas := +do { + e ← resolve_name n, + match e with + | expr.const n _ := do h ← hinst_lemma.mk_from_decl_core md n lhs_lemma, return $ hs^.add h + | expr.local_const _ _ _ _ := do h ← hinst_lemma.mk_core md e lhs_lemma, return $ hs^.add h + | _ := fail "failed" + end +} +<|> +fail ("invalid ematch lemma '" ++ to_string n ++ "'") + +private meta def add_hinst_lemma_from_pexpr (md : transparency) (lhs_lemma : bool) (p : pexpr) (hs : hinst_lemmas) : smt_tactic hinst_lemmas := +let e := pexpr.to_raw_expr p in +match e with +| (expr.const c []) := add_hinst_lemma_from_name md lhs_lemma c hs +| (expr.local_const c _ _ _) := add_hinst_lemma_from_name md lhs_lemma c hs +| _ := do new_e ← to_expr p, h ← hinst_lemma.mk_core md new_e lhs_lemma, return $ hs^.add h +end + +private meta def add_hinst_lemmas_from_pexprs (md : transparency) (lhs_lemma : bool) : list pexpr → hinst_lemmas → smt_tactic hinst_lemmas +| [] hs := return hs +| (p::ps) hs := do hs₁ ← add_hinst_lemma_from_pexpr md lhs_lemma p hs, add_hinst_lemmas_from_pexprs ps hs₁ + +meta def ematch_using (l : qexpr_list_or_qexpr0) : smt_tactic unit := +do hs ← add_hinst_lemmas_from_pexprs reducible ff l hinst_lemmas.mk, + smt_tactic.ematch_using hs + meta def try (t : itactic) : smt_tactic unit := smt_tactic.try t diff --git a/library/init/meta/smt/smt_tactic.lean b/library/init/meta/smt/smt_tactic.lean index 2694dab9d0..1f636337e1 100644 --- a/library/init/meta/smt/smt_tactic.lean +++ b/library/init/meta/smt/smt_tactic.lean @@ -101,14 +101,26 @@ open tactic (transparency) meta constant intros_core : bool → smt_tactic unit meta constant close : smt_tactic unit meta constant ematch_core : (expr → bool) → smt_tactic unit -meta constant add_ematch_lemma_core : transparency → bool → expr → smt_tactic unit -meta constant add_ematch_lemma_from_decl_core : transparency → bool → name → smt_tactic unit -meta constant add_ematch_eqn_lemmas_for_core : transparency → name → smt_tactic unit +meta constant ematch_using : hinst_lemmas → smt_tactic unit +meta constant mk_ematch_eqn_lemmas_for_core : transparency → name → smt_tactic hinst_lemmas meta constant to_cc_state : smt_tactic cc_state meta constant to_em_state : smt_tactic ematch_state meta constant preprocess : expr → smt_tactic (expr × expr) meta constant get_lemmas : smt_tactic hinst_lemmas meta constant set_lemmas : hinst_lemmas → smt_tactic unit +meta constant add_lemmas : hinst_lemmas → smt_tactic unit + +meta def add_ematch_lemma_core (md : transparency) (as_simp : bool) (e : expr) : smt_tactic unit := +do h ← hinst_lemma.mk_core md e as_simp, + add_lemmas (mk_hinst_singleton h) + +meta def add_ematch_lemma_from_decl_core (md : transparency) (as_simp : bool) (n : name) : smt_tactic unit := +do h ← hinst_lemma.mk_from_decl_core md n as_simp, + add_lemmas (mk_hinst_singleton h) + +meta def add_ematch_eqn_lemmas_for_core (md : transparency) (n : name) : smt_tactic unit := +do hs ← mk_ematch_eqn_lemmas_for_core md n, + add_lemmas hs meta def intros : smt_tactic unit := intros_core tt diff --git a/src/library/tactic/smt/hinst_lemmas.h b/src/library/tactic/smt/hinst_lemmas.h index b3c0e7f370..d24e0fd965 100644 --- a/src/library/tactic/smt/hinst_lemmas.h +++ b/src/library/tactic/smt/hinst_lemmas.h @@ -35,8 +35,8 @@ typedef list multi_pattern; /** Heuristic instantiation lemma */ struct hinst_lemma { name m_id; - unsigned m_num_uvars; - unsigned m_num_mvars; + unsigned m_num_uvars{0}; + unsigned m_num_mvars{0}; list m_multi_patterns; list m_is_inst_implicit; list m_mvars; diff --git a/src/library/tactic/smt/smt_state.cpp b/src/library/tactic/smt/smt_state.cpp index 3a7f55783d..bb41ab6286 100644 --- a/src/library/tactic/smt/smt_state.cpp +++ b/src/library/tactic/smt/smt_state.cpp @@ -95,6 +95,12 @@ void smt::ematch(buffer & result) { ::lean::ematch(m_ctx, m_goal.m_em_state, m_cc, result); } +void smt::ematch_using(hinst_lemmas const & lemmas, buffer & result) { + lemmas.for_each([&](hinst_lemma const & lemma) { + ::lean::ematch(m_ctx, m_goal.m_em_state, m_cc, lemma, false, result); + }); +} + struct vm_smt_goal : public vm_external { smt_goal m_val; vm_smt_goal(smt_goal const & v):m_val(v) {} @@ -680,74 +686,33 @@ vm_obj smt_tactic_ematch_core(vm_obj const & pred, vm_obj const & ss, vm_obj con LEAN_TACTIC_CATCH(ts); } -vm_obj smt_tactic_add_ematch_lemma_core(vm_obj const & md, vm_obj const & as_simp, vm_obj const & _h, vm_obj const & ss, vm_obj const & _ts) { +vm_obj smt_tactic_mk_ematch_eqn_lemmas_for_core(vm_obj const & md, vm_obj const & decl_name, vm_obj const & ss, vm_obj const & _ts) { tactic_state ts = to_tactic_state(_ts); if (is_nil(ss)) return mk_smt_state_empty_exception(ts); lean_assert(ts.goals()); LEAN_TACTIC_TRY; type_context ctx = mk_type_context_for(ts); - smt_goal g = to_smt_goal(head(ss)); - expr h = to_expr(_h); - expr type = ctx.infer(h); - std::tie(type, h) = preprocess_forward(ctx, g, type, h); - hinst_lemma lemma = mk_hinst_lemma(ctx, to_transparency_mode(md), h, to_bool(as_simp)); - g.add_lemma(lemma); - lean_trace(name({"smt", "ematch"}), scope_trace_env _(ctx.env(), ctx); - tout() << "new lemma " << lemma << "\n" << lemma.m_prop << "\n";); - vm_obj new_ss = mk_vm_cons(to_obj(g), tail(ss)); - tactic_state new_ts = set_env_mctx(ts, ctx.env(), ctx.mctx()); - return mk_smt_tactic_success(new_ss, new_ts); - LEAN_TACTIC_CATCH(ts); -} - -vm_obj smt_tactic_add_ematch_lemma_from_decl_core(vm_obj const & md, vm_obj const & as_simp, vm_obj const & decl_name, vm_obj const & ss, vm_obj const & _ts) { - tactic_state ts = to_tactic_state(_ts); - if (is_nil(ss)) return mk_smt_state_empty_exception(ts); - lean_assert(ts.goals()); - LEAN_TACTIC_TRY; - type_context ctx = mk_type_context_for(ts); - smt_goal g = to_smt_goal(head(ss)); - hinst_lemma lemma = mk_hinst_lemma(ctx, to_transparency_mode(md), to_name(decl_name), to_bool(as_simp)); - g.add_lemma(lemma); - lean_trace(name({"smt", "ematch"}), scope_trace_env _(ctx.env(), ctx); - tout() << "new lemma " << lemma << "\n" << lemma.m_prop << "\n";); - vm_obj new_ss = mk_vm_cons(to_obj(g), tail(ss)); - tactic_state new_ts = set_env_mctx(ts, ctx.env(), ctx.mctx()); - return mk_smt_tactic_success(new_ss, new_ts); - LEAN_TACTIC_CATCH(ts); -} - -vm_obj smt_tactic_add_ematch_eqn_lemmas_for_core(vm_obj const & md, vm_obj const & decl_name, vm_obj const & ss, vm_obj const & _ts) { - tactic_state ts = to_tactic_state(_ts); - if (is_nil(ss)) return mk_smt_state_empty_exception(ts); - lean_assert(ts.goals()); - LEAN_TACTIC_TRY; - type_context ctx = mk_type_context_for(ts); - smt_goal g = to_smt_goal(head(ss)); - smt S(ctx, g); buffer eqns; get_ext_eqn_lemmas_for(ts.env(), to_name(decl_name), eqns); if (eqns.empty()) return mk_tactic_exception(sstream() << "tactic failed, '" << to_name(decl_name) << "' does not have equation lemmas", ts); + hinst_lemmas hs; for (name const & eqn : eqns) { declaration eqn_decl = ctx.env().get(eqn); if (eqn_decl.get_num_univ_params() == 0 && !is_pi(ctx.relaxed_whnf(ctx.env().get(eqn).get_type()))) { - expr h = mk_constant(eqn); - expr type = ctx.infer(h); - std::tie(type, h) = preprocess_forward(ctx, g, type, h); - lean_trace(name({"smt", "ematch"}), scope_trace_env _(ctx.env(), ctx); - tout() << "new ground fact: " << type << "\n";); - S.add(type, h); + hinst_lemma h; + h.m_id = eqn; + h.m_proof = mk_constant(eqn); + h.m_prop = ctx.infer(h.m_proof); + h.m_expr = h.m_proof; + hs.insert(h); } else { - hinst_lemma lemma = mk_hinst_lemma(ctx, to_transparency_mode(md), eqn, true); - g.add_lemma(lemma); - lean_trace(name({"smt", "ematch"}), scope_trace_env _(ctx.env(), ctx); - tout() << "new equation lemma " << lemma << "\n" << lemma.m_prop << "\n";); + hinst_lemma h = mk_hinst_lemma(ctx, to_transparency_mode(md), eqn, true); + hs.insert(h); } } - vm_obj new_ss = mk_vm_cons(to_obj(g), tail(ss)); tactic_state new_ts = set_env_mctx(ts, ctx.env(), ctx.mctx()); - return mk_smt_tactic_success(new_ss, new_ts); + return mk_smt_tactic_success(to_obj(hs), ss, to_obj(new_ts)); LEAN_TACTIC_CATCH(ts); } @@ -794,6 +759,57 @@ vm_obj smt_tactic_set_lemmas(vm_obj const & lemmas, vm_obj const & ss, vm_obj co return mk_smt_tactic_success(new_ss, _ts); } +vm_obj smt_tactic_add_lemmas(vm_obj const & lemmas, vm_obj const & ss, vm_obj const & _ts) { + tactic_state ts = to_tactic_state(_ts); + if (is_nil(ss)) return mk_smt_state_empty_exception(ts); + type_context ctx = mk_type_context_for(ts); + smt_goal g = to_smt_goal(head(ss)); + smt S(ctx, g); + to_hinst_lemmas(lemmas).for_each([&](hinst_lemma const & lemma) { + if (lemma.m_num_mvars == 0 && lemma.m_num_uvars == 0) { + expr type = lemma.m_prop; + expr h = lemma.m_proof; + std::tie(type, h) = preprocess_forward(ctx, g, type, h); + lean_trace(name({"smt", "ematch"}), scope_trace_env _(ctx.env(), ctx); + tout() << "new ground fact: " << type << "\n";); + S.add(type, h); + } else { + lean_trace(name({"smt", "ematch"}), scope_trace_env _(ctx.env(), ctx); + tout() << "new equation lemma " << lemma << "\n" << lemma.m_prop << "\n";); + g.add_lemma(lemma); + } + }); + vm_obj new_ss = mk_vm_cons(to_obj(g), tail(ss)); + tactic_state new_ts = set_env_mctx(ts, ctx.env(), ctx.mctx()); + return mk_smt_tactic_success(new_ss, new_ts); +} + +vm_obj smt_tactic_ematch_using(vm_obj const & hs, vm_obj const & ss, vm_obj const & _ts) { + tactic_state ts = to_tactic_state(_ts); + if (is_nil(ss)) return mk_smt_state_empty_exception(ts); + lean_assert(ts.goals()); + LEAN_TACTIC_TRY; + expr target = ts.get_main_goal_decl()->get_type(); + type_context ctx = mk_type_context_for(ts); + smt_goal g = to_smt_goal(head(ss)); + smt S(ctx, g); + S.internalize(target); + buffer new_instances; + S.ematch_using(to_hinst_lemmas(hs), new_instances); + for (expr_pair const & p : new_instances) { + expr type = p.first; + expr proof = p.second; + std::tie(type, proof) = preprocess_forward(ctx, g, type, proof); + lean_trace(name({"smt", "ematch"}), scope_trace_env _(ctx.env(), ctx); + tout() << "new instance\n" << type << "\n";); + S.add(type, proof); + } + vm_obj new_ss = mk_vm_cons(to_obj(g), tail(ss)); + tactic_state new_ts = set_env_mctx(ts, ctx.env(), ctx.mctx()); + return mk_smt_tactic_success(new_ss, new_ts); + LEAN_TACTIC_CATCH(ts); +} + void initialize_smt_state() { register_trace_class("smt"); register_trace_class(name({"smt", "fact"})); @@ -807,14 +823,14 @@ void initialize_smt_state() { DECLARE_VM_BUILTIN(name({"smt_tactic", "close"}), smt_tactic_close); DECLARE_VM_BUILTIN(name({"smt_tactic", "intros_core"}), smt_tactic_intros_core); DECLARE_VM_BUILTIN(name({"smt_tactic", "ematch_core"}), smt_tactic_ematch_core); + DECLARE_VM_BUILTIN(name({"smt_tactic", "ematch_using"}), smt_tactic_ematch_using); DECLARE_VM_BUILTIN(name({"smt_tactic", "to_cc_state"}), smt_tactic_to_cc_state); DECLARE_VM_BUILTIN(name({"smt_tactic", "to_em_state"}), smt_tactic_to_em_state); DECLARE_VM_BUILTIN(name({"smt_tactic", "preprocess"}), smt_tactic_preprocess); DECLARE_VM_BUILTIN(name({"smt_tactic", "get_lemmas"}), smt_tactic_get_lemmas); DECLARE_VM_BUILTIN(name({"smt_tactic", "set_lemmas"}), smt_tactic_set_lemmas); - DECLARE_VM_BUILTIN(name({"smt_tactic", "add_ematch_lemma_core"}), smt_tactic_add_ematch_lemma_core); - DECLARE_VM_BUILTIN(name({"smt_tactic", "add_ematch_lemma_from_decl_core"}), smt_tactic_add_ematch_lemma_from_decl_core); - DECLARE_VM_BUILTIN(name({"smt_tactic", "add_ematch_eqn_lemmas_for_core"}), smt_tactic_add_ematch_eqn_lemmas_for_core); + DECLARE_VM_BUILTIN(name({"smt_tactic", "add_lemmas"}), smt_tactic_add_lemmas); + DECLARE_VM_BUILTIN(name({"smt_tactic", "mk_ematch_eqn_lemmas_for_core"}), smt_tactic_mk_ematch_eqn_lemmas_for_core); } void finalize_smt_state() { diff --git a/src/library/tactic/smt/smt_state.h b/src/library/tactic/smt/smt_state.h index a0951ec2ef..1500168356 100644 --- a/src/library/tactic/smt/smt_state.h +++ b/src/library/tactic/smt/smt_state.h @@ -59,6 +59,7 @@ public: void internalize(expr const & e); void add(expr const & type, expr const & proof); void ematch(buffer & result); + void ematch_using(hinst_lemmas const & lemmas, buffer & result); optional get_proof(expr const & e); bool inconsistent() const { return m_cc.inconsistent(); } diff --git a/tests/lean/run/smt_ematch1.lean b/tests/lean/run/smt_ematch1.lean index d85c4d9aa8..aa8486a152 100644 --- a/tests/lean/run/smt_ematch1.lean +++ b/tests/lean/run/smt_ematch1.lean @@ -18,6 +18,11 @@ begin [smt] ematch end +example (a b c d e : nat) : d = a → c = e → g a d = b → b = g e c → f a = c := +begin [smt] + ematch_using [fax, gax] +end + local attribute [-ematch] fax example (a b c d e : nat) : d = a → c = e → g a d = b → b = g e c → f a = c := @@ -26,6 +31,11 @@ begin [smt] ematch end +example (a b c d e : nat) : d = a → c = e → g a d = b → b = g e c → f a = c := +begin [smt] + ematch_using [fax, gax] +end + example (a b c d e : nat) : (∀ x, g x (f x) = 0) → a = f b → g b a + 0 = f 0 := begin [smt] assert h : ∀ x, g x (f x) = 0, @@ -33,6 +43,12 @@ begin [smt] ematch end +example (a b c d e : nat) : (∀ x, g x (f x) = 0) → a = f b → g b a + 0 = f 0 := +begin [smt] + assert h : ∀ x, g x (f x) = 0, + ematch_using [h, fax, add_zero] +end + local attribute [ematch] fax add_zero open smt_tactic