From 7a99d87cbd8ee45ff96abe4fe79ff3cf8eeb15ec Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 8 Mar 2017 13:46:49 -0800 Subject: [PATCH] fix(library/tactic/ac_tactics): allow nested ac_app macros in perm_ac macro fixes #1442 --- src/library/tactic/ac_tactics.cpp | 525 +++++++++++++++--------------- tests/lean/run/1442.lean | 11 + 2 files changed, 276 insertions(+), 260 deletions(-) create mode 100644 tests/lean/run/1442.lean diff --git a/src/library/tactic/ac_tactics.cpp b/src/library/tactic/ac_tactics.cpp index 0e9e6d84d1..cdb6cf967f 100644 --- a/src/library/tactic/ac_tactics.cpp +++ b/src/library/tactic/ac_tactics.cpp @@ -99,266 +99,6 @@ optional ac_manager::is_comm(expr const & e) { return r; } -struct flat_assoc_fn { - abstract_type_context & m_ctx; - expr m_op; - expr m_assoc; - - flat_assoc_fn(abstract_type_context & ctx, expr const & op, expr const & assoc): - m_ctx(ctx), m_op(op), m_assoc(assoc) {} - - bool is_op_app(expr const & e, expr & lhs, expr & rhs) { - if (!is_app(e)) return false; - expr const & fn1 = app_fn(e); - if (!is_app(fn1)) return false; - if (app_fn(fn1) != m_op) return false; - lhs = app_arg(fn1); - rhs = app_arg(e); - return true; - } - - bool is_op_app(expr const & e) { - if (!is_app(e)) return false; - expr const & fn1 = app_fn(e); - if (!is_app(fn1)) return false; - return app_fn(fn1) == m_op; - } - - expr mk_op(expr const & a, expr const & b) { - return mk_app(m_op, a, b); - } - - expr mk_assoc(expr const & a, expr const & b, expr const & c) { - return mk_app(m_assoc, a, b, c); - } - - expr mk_eq_refl(expr const & a) { - return ::lean::mk_eq_refl(m_ctx, a); - } - - expr mk_eq_trans(expr const & H1, expr const & H2) { - return ::lean::mk_eq_trans(m_ctx, H1, H2); - } - - expr mk_eq_trans(expr const & H1, optional const & H2) { - if (!H2) return H1; - return mk_eq_trans(H1, *H2); - } - - optional mk_eq_trans(optional const & H1, optional const & H2) { - if (!H1) return H2; - if (!H2) return H1; - return some_expr(mk_eq_trans(*H1, *H2)); - } - - expr mk_eq_symm(expr const & H) { - return ::lean::mk_eq_symm(m_ctx, H); - } - - optional mk_eq_symm(optional const & H) { - if (!H) return none_expr(); - return some_expr(mk_eq_symm(*H)); - } - - expr mk_congr_arg(expr const & fn, expr const & H) { - return ::lean::mk_congr_arg(m_ctx, fn, H); - } - - pair> flat_with(expr const & e, expr const & rest) { - expr lhs, rhs; - if (is_op_app(e, lhs, rhs)) { - auto p1 = flat_with(rhs, rest); - if (p1.second) { - auto p2 = flat_with(lhs, p1.first); - // H3 is a proof for (lhs `op` rhs) `op` rest = lhs `op` (rhs `op` rest) - expr H3 = mk_assoc(lhs, rhs, rest); - // H4 is a proof for lhs `op` (rhs `op` rest) = lhs `op` p1.first - expr H4 = mk_congr_arg(mk_app(m_op, lhs), *p1.second); - expr H = mk_eq_trans(mk_eq_trans(H3, H4), p2.second); - return mk_pair(p2.first, some_expr(H)); - } else { - if (is_op_app(lhs)) { - auto p2 = flat_with(lhs, p1.first); - // H3 is a proof for (lhs `op` rhs) `op` rest = lhs `op` (rhs `op` rest) - expr H3 = mk_assoc(lhs, rhs, rest); - expr H = mk_eq_trans(H3, p2.second); - return mk_pair(p2.first, some_expr(H)); - } else { - return mk_pair(mk_op(lhs, p1.first), some_expr(mk_assoc(lhs, rhs, rest))); - } - } - } else { - return mk_pair(mk_op(e, rest), none_expr()); - } - } - - pair> flat_core(expr const & e) { - expr lhs, rhs; - if (is_op_app(e, lhs, rhs)) { - auto p1 = flat_core(rhs); - if (p1.second) { - if (is_op_app(lhs)) { - auto p2 = flat_with(lhs, p1.first); - expr H3 = mk_congr_arg(mk_app(m_op, lhs), *p1.second); - expr H = mk_eq_trans(H3, p2.second); - return mk_pair(p2.first, some_expr(H)); - } else { - expr r = mk_op(lhs, p1.first); - expr H = mk_congr_arg(mk_app(m_op, lhs), *p1.second); - return mk_pair(r, some_expr(H)); - } - } else { - if (is_op_app(lhs)) { - return flat_with(lhs, rhs); - } else { - return mk_pair(e, none_expr()); - } - } - } else { - return mk_pair(e, none_expr()); - } - } - - pair flat(expr const & e) { - auto p = flat_core(e); - if (p.second) { - return mk_pair(p.first, *p.second); - } else { - return mk_pair(e, mk_eq_refl(e)); - } - } -}; - -#define lean_perm_ac_trace(code) lean_trace(name({"tactic", "perm_ac"}), scope_trace_env _scope1(m_ctx.env(), m_ctx); code) - -struct perm_ac_fn : public flat_assoc_fn { - expr m_comm; - optional m_left_comm; - - perm_ac_fn(abstract_type_context & ctx, expr const & op, expr const & assoc, expr const & comm): - flat_assoc_fn(ctx, op, assoc), m_comm(comm) { - } - - [[ noreturn ]] void throw_failed() { - throw exception("perm_ac failed, arguments are not equal modulo AC"); - } - - expr mk_comm(expr const & a, expr const & b) { - return mk_app(m_comm, a, b); - } - - level dec_level(level const & l) { - if (auto r = ::lean::dec_level(l)) - return *r; - throw_failed(); - } - - expr mk_left_comm(expr const & a, expr const & b, expr const & c) { - if (!m_left_comm) { - expr A = m_ctx.infer(a); - level lvl = dec_level(get_level(m_ctx, A)); - m_left_comm = mk_app(mk_constant(get_left_comm_name(), {lvl}), A, m_op, m_comm, m_assoc); - } - return mk_app(*m_left_comm, a, b, c); - } - - /* Given a term \c e of the form (op t_1 (op t_2 ... (op t_{n-1} t_n))), if - for some i, t_i == t, then produce the term - - (op t_i (op t_2 ... (op t_{n-1} t_n))) - - and a proof they are equal AC. - Throw exception if t is not found. */ - pair pull_term(expr const & t, expr const & e) { - expr lhs1, rhs1; - if (!is_op_app(e, lhs1, rhs1)) { - lean_perm_ac_trace(tout() << "right-hand-side does not contain:\n" << t << "\n";); - throw_failed(); - } - if (t == rhs1) { - return mk_pair(mk_op(rhs1, lhs1), mk_comm(lhs1, rhs1)); - } - expr lhs2, rhs2; - if (!is_op_app(rhs1, lhs2, rhs2)) { - lean_perm_ac_trace(tout() << "right-hand-side does not contain:\n" << t << "\n";); - throw_failed(); - } - if (t == lhs2) { - return mk_pair(mk_op(lhs2, mk_op(lhs1, rhs2)), mk_left_comm(lhs1, lhs2, rhs2)); - } - /* We have e := lhs1 `op` lhs2 `op` rhs2 */ - auto p = pull_term(t, rhs1); - expr lhs3, rhs3; - lean_verify(is_op_app(p.first, lhs3, rhs3)); - lean_assert(t == lhs3); - /* p.second : rhs1 = t `op` rhs3 */ - expr H1 = mk_congr_arg(mk_app(m_op, lhs1), p.second); - /* H1 : lhs1 `op` rhs1 = lhs1 `op` t `op` rhs3 */ - expr H2 = mk_left_comm(lhs1, t, rhs3); - /* H2 : lhs1 `op` t `op` rhs3 = t `op` lhs1 `op` rhs3 */ - return mk_pair(mk_op(t, mk_op(lhs1, rhs3)), mk_eq_trans(H1, H2)); - } - - /* Return a proof that e1 == e2 modulo AC. Return none if reflexivity. - Throw exception if failure */ - optional perm_flat(expr const & e1, expr const & e2) { - expr lhs1, rhs1; - expr lhs2, rhs2; - bool b1 = is_op_app(e1, lhs1, rhs1); - bool b2 = is_op_app(e2, lhs2, rhs2); - if (b1 != b2) { - lean_perm_ac_trace(tout() << "left and right-hand-sides have different number of terms\n";); - throw_failed(); - } - if (!b1 && !b2) { - if (e1 == e2) { - return none_expr(); // reflexivity - } else { - lean_perm_ac_trace(tout() << "the left and right hand sides contain the terms:\n" << e1 << "\n" << e2 << "\n";); - throw_failed(); - } - } - lean_assert(b1 && b2); - if (lhs1 == lhs2) { - optional H = perm_flat(rhs1, rhs2); - if (!H) return none_expr(); - return some_expr(mk_congr_arg(mk_app(m_op, lhs1), *H)); - } else { - auto p = pull_term(lhs2, e1); - is_op_app(p.first, lhs1, rhs1); - lean_assert(lhs1 == lhs2); - optional H1 = perm_flat(rhs1, rhs2); - if (!H1) return some_expr(p.second); - expr H2 = mk_congr_arg(mk_app(m_op, lhs1), *H1); - return some_expr(mk_eq_trans(p.second, H2)); - } - } - - /* Return a proof that lhs == rhs modulo AC. Return none if reflexivity. - Throw exception if failure */ - optional perm_core(expr const & lhs, expr const & rhs) { - auto p1 = flat_core(lhs); - auto p2 = flat_core(rhs); - auto H = perm_flat(p1.first, p2.first); - return mk_eq_trans(p1.second, mk_eq_trans(H, mk_eq_symm(p2.second))); - } - - expr perm(expr const & lhs, expr const & rhs) { - if (auto H = perm_core(lhs, rhs)) - return *H; - else - return mk_eq_refl(lhs); - } -}; - -pair> flat_assoc(abstract_type_context & ctx, expr const & op, expr const & assoc, expr const & e) { - return flat_assoc_fn(ctx, op, assoc).flat_core(e); -} - -expr perm_ac(abstract_type_context & ctx, expr const & op, expr const & assoc, expr const & comm, expr const & e1, expr const & e2) { - return perm_ac_fn(ctx, op, assoc, comm).perm(e1, e2); -} - static name * g_ac_app_name = nullptr; static macro_definition * g_ac_app_macro = nullptr; static std::string * g_ac_app_opcode = nullptr; @@ -612,6 +352,271 @@ static expr expand_if_ac_app(expr const & e) { return e; } +struct flat_assoc_fn { + abstract_type_context & m_ctx; + expr m_op; + expr m_assoc; + + flat_assoc_fn(abstract_type_context & ctx, expr const & op, expr const & assoc): + m_ctx(ctx), m_op(op), m_assoc(assoc) {} + + bool is_op_app(expr const & e, expr & lhs, expr & rhs) { + if (!is_app(e)) return false; + expr const & fn1 = app_fn(e); + if (!is_app(fn1)) return false; + if (app_fn(fn1) != m_op) return false; + lhs = app_arg(fn1); + rhs = app_arg(e); + return true; + } + + bool is_op_app(expr const & e) { + if (!is_app(e)) return false; + expr const & fn1 = app_fn(e); + if (!is_app(fn1)) return false; + return app_fn(fn1) == m_op; + } + + expr mk_op(expr const & a, expr const & b) { + return mk_app(m_op, a, b); + } + + expr mk_assoc(expr const & a, expr const & b, expr const & c) { + return mk_app(m_assoc, a, b, c); + } + + expr mk_eq_refl(expr const & a) { + return ::lean::mk_eq_refl(m_ctx, a); + } + + expr mk_eq_trans(expr const & H1, expr const & H2) { + return ::lean::mk_eq_trans(m_ctx, H1, H2); + } + + expr mk_eq_trans(expr const & H1, optional const & H2) { + if (!H2) return H1; + return mk_eq_trans(H1, *H2); + } + + optional mk_eq_trans(optional const & H1, optional const & H2) { + if (!H1) return H2; + if (!H2) return H1; + return some_expr(mk_eq_trans(*H1, *H2)); + } + + expr mk_eq_symm(expr const & H) { + return ::lean::mk_eq_symm(m_ctx, H); + } + + optional mk_eq_symm(optional const & H) { + if (!H) return none_expr(); + return some_expr(mk_eq_symm(*H)); + } + + expr mk_congr_arg(expr const & fn, expr const & H) { + return ::lean::mk_congr_arg(m_ctx, fn, H); + } + + pair> flat_with(expr const & e, expr const & rest) { + expr lhs, rhs; + if (is_op_app(e, lhs, rhs)) { + lhs = expand_if_ac_app(lhs); + rhs = expand_if_ac_app(rhs); + auto p1 = flat_with(rhs, rest); + if (p1.second) { + auto p2 = flat_with(lhs, p1.first); + // H3 is a proof for (lhs `op` rhs) `op` rest = lhs `op` (rhs `op` rest) + expr H3 = mk_assoc(lhs, rhs, rest); + // H4 is a proof for lhs `op` (rhs `op` rest) = lhs `op` p1.first + expr H4 = mk_congr_arg(mk_app(m_op, lhs), *p1.second); + expr H = mk_eq_trans(mk_eq_trans(H3, H4), p2.second); + return mk_pair(p2.first, some_expr(H)); + } else { + if (is_op_app(lhs)) { + auto p2 = flat_with(lhs, p1.first); + // H3 is a proof for (lhs `op` rhs) `op` rest = lhs `op` (rhs `op` rest) + expr H3 = mk_assoc(lhs, rhs, rest); + expr H = mk_eq_trans(H3, p2.second); + return mk_pair(p2.first, some_expr(H)); + } else { + return mk_pair(mk_op(lhs, p1.first), some_expr(mk_assoc(lhs, rhs, rest))); + } + } + } else { + return mk_pair(mk_op(e, rest), none_expr()); + } + } + + pair> flat_core(expr e) { + expr lhs, rhs; + e = expand_if_ac_app(e); + if (is_op_app(e, lhs, rhs)) { + lhs = expand_if_ac_app(lhs); + rhs = expand_if_ac_app(rhs); + auto p1 = flat_core(rhs); + if (p1.second) { + if (is_op_app(lhs)) { + auto p2 = flat_with(lhs, p1.first); + expr H3 = mk_congr_arg(mk_app(m_op, lhs), *p1.second); + expr H = mk_eq_trans(H3, p2.second); + return mk_pair(p2.first, some_expr(H)); + } else { + expr r = mk_op(lhs, p1.first); + expr H = mk_congr_arg(mk_app(m_op, lhs), *p1.second); + return mk_pair(r, some_expr(H)); + } + } else { + if (is_op_app(lhs)) { + return flat_with(lhs, rhs); + } else { + return mk_pair(e, none_expr()); + } + } + } else { + return mk_pair(e, none_expr()); + } + } + + pair flat(expr const & e) { + auto p = flat_core(e); + if (p.second) { + return mk_pair(p.first, *p.second); + } else { + return mk_pair(e, mk_eq_refl(e)); + } + } +}; + +#define lean_perm_ac_trace(code) lean_trace(name({"tactic", "perm_ac"}), scope_trace_env _scope1(m_ctx.env(), m_ctx); code) + +struct perm_ac_fn : public flat_assoc_fn { + expr m_comm; + optional m_left_comm; + + perm_ac_fn(abstract_type_context & ctx, expr const & op, expr const & assoc, expr const & comm): + flat_assoc_fn(ctx, op, assoc), m_comm(comm) { + } + + [[ noreturn ]] void throw_failed() { + throw exception("perm_ac failed, arguments are not equal modulo AC"); + } + + expr mk_comm(expr const & a, expr const & b) { + return mk_app(m_comm, a, b); + } + + level dec_level(level const & l) { + if (auto r = ::lean::dec_level(l)) + return *r; + throw_failed(); + } + + expr mk_left_comm(expr const & a, expr const & b, expr const & c) { + if (!m_left_comm) { + expr A = m_ctx.infer(a); + level lvl = dec_level(get_level(m_ctx, A)); + m_left_comm = mk_app(mk_constant(get_left_comm_name(), {lvl}), A, m_op, m_comm, m_assoc); + } + return mk_app(*m_left_comm, a, b, c); + } + + /* Given a term \c e of the form (op t_1 (op t_2 ... (op t_{n-1} t_n))), if + for some i, t_i == t, then produce the term + + (op t_i (op t_2 ... (op t_{n-1} t_n))) + + and a proof they are equal AC. + Throw exception if t is not found. */ + pair pull_term(expr const & t, expr const & e) { + expr lhs1, rhs1; + if (!is_op_app(e, lhs1, rhs1)) { + lean_perm_ac_trace(tout() << "right-hand-side does not contain:\n" << t << "\n";); + throw_failed(); + } + if (t == rhs1) { + return mk_pair(mk_op(rhs1, lhs1), mk_comm(lhs1, rhs1)); + } + expr lhs2, rhs2; + if (!is_op_app(rhs1, lhs2, rhs2)) { + lean_perm_ac_trace(tout() << "right-hand-side does not contain:\n" << t << "\n";); + throw_failed(); + } + if (t == lhs2) { + return mk_pair(mk_op(lhs2, mk_op(lhs1, rhs2)), mk_left_comm(lhs1, lhs2, rhs2)); + } + /* We have e := lhs1 `op` lhs2 `op` rhs2 */ + auto p = pull_term(t, rhs1); + expr lhs3, rhs3; + lean_verify(is_op_app(p.first, lhs3, rhs3)); + lean_assert(t == lhs3); + /* p.second : rhs1 = t `op` rhs3 */ + expr H1 = mk_congr_arg(mk_app(m_op, lhs1), p.second); + /* H1 : lhs1 `op` rhs1 = lhs1 `op` t `op` rhs3 */ + expr H2 = mk_left_comm(lhs1, t, rhs3); + /* H2 : lhs1 `op` t `op` rhs3 = t `op` lhs1 `op` rhs3 */ + return mk_pair(mk_op(t, mk_op(lhs1, rhs3)), mk_eq_trans(H1, H2)); + } + + /* Return a proof that e1 == e2 modulo AC. Return none if reflexivity. + Throw exception if failure */ + optional perm_flat(expr const & e1, expr const & e2) { + expr lhs1, rhs1; + expr lhs2, rhs2; + bool b1 = is_op_app(e1, lhs1, rhs1); + bool b2 = is_op_app(e2, lhs2, rhs2); + if (b1 != b2) { + lean_perm_ac_trace(tout() << "left and right-hand-sides have different number of terms\n";); + throw_failed(); + } + if (!b1 && !b2) { + if (e1 == e2) { + return none_expr(); // reflexivity + } else { + lean_perm_ac_trace(tout() << "the left and right hand sides contain the terms:\n" << e1 << "\n" << e2 << "\n";); + throw_failed(); + } + } + lean_assert(b1 && b2); + if (lhs1 == lhs2) { + optional H = perm_flat(rhs1, rhs2); + if (!H) return none_expr(); + return some_expr(mk_congr_arg(mk_app(m_op, lhs1), *H)); + } else { + auto p = pull_term(lhs2, e1); + is_op_app(p.first, lhs1, rhs1); + lean_assert(lhs1 == lhs2); + optional H1 = perm_flat(rhs1, rhs2); + if (!H1) return some_expr(p.second); + expr H2 = mk_congr_arg(mk_app(m_op, lhs1), *H1); + return some_expr(mk_eq_trans(p.second, H2)); + } + } + + /* Return a proof that lhs == rhs modulo AC. Return none if reflexivity. + Throw exception if failure */ + optional perm_core(expr const & lhs, expr const & rhs) { + auto p1 = flat_core(lhs); + auto p2 = flat_core(rhs); + auto H = perm_flat(p1.first, p2.first); + return mk_eq_trans(p1.second, mk_eq_trans(H, mk_eq_symm(p2.second))); + } + + expr perm(expr const & lhs, expr const & rhs) { + if (auto H = perm_core(lhs, rhs)) + return *H; + else + return mk_eq_refl(lhs); + } +}; + +pair> flat_assoc(abstract_type_context & ctx, expr const & op, expr const & assoc, expr const & e) { + return flat_assoc_fn(ctx, op, assoc).flat_core(e); +} + +expr perm_ac(abstract_type_context & ctx, expr const & op, expr const & assoc, expr const & comm, expr const & e1, expr const & e2) { + return perm_ac_fn(ctx, op, assoc, comm).perm(e1, e2); +} + static name * g_perm_ac_name = nullptr; static macro_definition * g_perm_ac_macro = nullptr; static std::string * g_perm_ac_opcode = nullptr; diff --git a/tests/lean/run/1442.lean b/tests/lean/run/1442.lean new file mode 100644 index 0000000000..d22c036f19 --- /dev/null +++ b/tests/lean/run/1442.lean @@ -0,0 +1,11 @@ +protected def rel : ℤ × ℤ → ℤ × ℤ → Prop +| ⟨n₁, d₁⟩ ⟨n₂, d₂⟩ := n₁ * d₂ = n₂ * d₁ + +private def mul' : ℤ × ℤ → ℤ × ℤ → ℤ × ℤ +| ⟨n₁, d₁⟩ ⟨n₂, d₂⟩ := ⟨n₁ * n₂, d₁ * d₂⟩ + +example : ∀(a b c d : ℤ × ℤ), rel a c → rel b d → rel (mul' a b) (mul' c d) := +λ⟨n₁, d₁⟩ ⟨n₂, d₂⟩ ⟨n₃, d₃⟩ ⟨n₄, d₄⟩, + assume (h₁ : n₁ * d₃ = n₃ * d₁) (h₂ : n₂ * d₄ = n₄ * d₂), + show (n₁ * n₂) * (d₃ * d₄) = (n₃ * n₄) * (d₁ * d₂), + by cc