diff --git a/src/library/tactic/smt/congruence_closure.cpp b/src/library/tactic/smt/congruence_closure.cpp index eba0afa376..77f3d76eff 100644 --- a/src/library/tactic/smt/congruence_closure.cpp +++ b/src/library/tactic/smt/congruence_closure.cpp @@ -500,6 +500,8 @@ void congruence_closure::apply_simple_eqvs(expr const & e) { expr const & fn = get_app_fn(e); if (is_lambda(fn)) { expr reduced_e = head_beta_reduce(e); + if (m_phandler) + m_phandler->new_aux_cc_term(reduced_e); internalize_core(reduced_e, none_expr()); push_refl_eq(e, reduced_e); } diff --git a/src/library/tactic/smt/congruence_closure.h b/src/library/tactic/smt/congruence_closure.h index 91490bf98f..a146ea52b2 100644 --- a/src/library/tactic/smt/congruence_closure.h +++ b/src/library/tactic/smt/congruence_closure.h @@ -25,6 +25,9 @@ public: virtual ~cc_propagation_handler() {} virtual void propagated(unsigned n, expr const * data) = 0; void propagated(buffer const & p) { propagated(p.size(), p.data()); } + /* Congruence closure module invokes the following method when + a new auxiliary term is created during propagation. */ + virtual void new_aux_cc_term(expr const & e) = 0; }; /* The congruence_closure module (optionally) uses a normalizer. diff --git a/src/library/tactic/smt/ematch.cpp b/src/library/tactic/smt/ematch.cpp index 8f82236223..06a2159e08 100644 --- a/src/library/tactic/smt/ematch.cpp +++ b/src/library/tactic/smt/ematch.cpp @@ -688,6 +688,16 @@ struct ematch_fn { return true; } + bool match_args_prefix(state & s, buffer const & p_args, expr const & t) { + unsigned t_nargs = get_app_num_args(t); + expr it = t; + while (t_nargs > p_args.size()) { + t_nargs--; + it = app_fn(it); + } + return match_args(s, p_args, it); + } + bool process_continue(expr const & p) { lean_trace(name({"debug", "ematch"}), tout() << "process_continue: " << p << "\n";); buffer p_args; @@ -697,7 +707,7 @@ struct ematch_fn { s->for_each([&](expr const & t) { if (m_cc.is_congr_root(t) || m_cc.in_heterogeneous_eqc(t)) { state new_state = m_state; - if (match_args(new_state, p_args, t)) + if (match_args_prefix(new_state, p_args, t)) new_states.push_back(new_state); } }); @@ -919,7 +929,7 @@ struct ematch_fn { expr const & fn = get_app_args(p, p_args); if (!m_ctx.is_def_eq(fn, get_app_fn(t))) return; - if (!match_args(m_state, p_args, t)) + if (!match_args_prefix(m_state, p_args, t)) return; search(lemma); } diff --git a/src/library/tactic/smt/smt_state.cpp b/src/library/tactic/smt/smt_state.cpp index 3c2ed72d3d..938fadb6ee 100644 --- a/src/library/tactic/smt/smt_state.cpp +++ b/src/library/tactic/smt/smt_state.cpp @@ -88,6 +88,10 @@ void smt::internalize(expr const & e) { m_goal.m_em_state.internalize(m_ctx, e); } +void smt::new_aux_cc_term(expr const & e) { + m_goal.m_em_state.internalize(m_ctx, e); +} + void smt::add(expr const & type, expr const & proof) { m_goal.m_em_state.internalize(m_ctx, type); m_cc.add(type, proof); diff --git a/src/library/tactic/smt/smt_state.h b/src/library/tactic/smt/smt_state.h index 6d17f10a34..c95505d4b7 100644 --- a/src/library/tactic/smt/smt_state.h +++ b/src/library/tactic/smt/smt_state.h @@ -52,6 +52,7 @@ private: lbool get_value_core(expr const & e); lbool get_value(expr const & e); virtual void propagated(unsigned n, expr const * p) override; + virtual void new_aux_cc_term(expr const & e) override; virtual expr normalize(expr const & e) override; public: smt(type_context & ctx, defeq_can_state & dcs, smt_goal & g); diff --git a/tests/lean/run/ematch_partial_apps.lean b/tests/lean/run/ematch_partial_apps.lean new file mode 100644 index 0000000000..20ba2cb7ea --- /dev/null +++ b/tests/lean/run/ematch_partial_apps.lean @@ -0,0 +1,42 @@ +open tactic + +meta def check_expr (p : pexpr) (t : expr) : tactic unit := +do e ← to_expr p, guard (expr.alpha_eqv t e) + +meta def check_target (p : pexpr) : tactic unit := +do t ← target, check_expr p t + +set_option trace.smt.ematch true + +example (a : list nat) (f : nat → nat) : a = [1, 2] → a^.for f = [f 1, f 2] := +begin [smt] + intros, + ematch_using [list.for], + ematch_using [flip], + ematch_using [list.map], + ematch_using [list.map], + ematch_using [list.map] +end + +example (a : list nat) (f : nat → nat) : a = [1, 2] → a^.for f = [f 1, f 2] := +begin [smt] + intros, + repeat {ematch_using [list.for, flip, list.map], try { close }}, +end + +attribute [ematch] list.map flip list.for + +example (a : list nat) (f : nat → nat) : a = [1, 2] → a^.for f = [f 1, f 2] := +begin [smt] + intros, eblast +end + +constant f : nat → nat → nat +constant g : nat → nat → nat +axiom fgx : ∀ x y, (: f x :) = (λ y, y) ∧ (: g y :) = λ x, 0 +attribute [ematch] fgx + +example (a b c : nat) : f a b = b ∧ g b c = 0 := +begin [smt] + ematch +end