diff --git a/library/init/core.lean b/library/init/core.lean index 94e7d9064e..387d9ceb4b 100644 --- a/library/init/core.lean +++ b/library/init/core.lean @@ -382,6 +382,10 @@ namespace nat | a zero := a | a (succ b) := succ (add a b) + /- We mark the following definitions as pattern to make sure they can be used in recursive equations, + and reduced by the equation compiler. -/ + attribute [pattern] nat.add nat.add._main + def of_pos_num : pos_num → nat | pos_num.one := succ zero | (pos_num.bit0 a) := let r := of_pos_num a in nat.add r r diff --git a/src/library/equations_compiler/structural_rec.cpp b/src/library/equations_compiler/structural_rec.cpp index a881e9834f..bcf8dbf76e 100644 --- a/src/library/equations_compiler/structural_rec.cpp +++ b/src/library/equations_compiler/structural_rec.cpp @@ -11,6 +11,7 @@ Author: Leonardo de Moura #include "library/constants.h" #include "library/locals.h" #include "library/util.h" +#include "library/pattern_attribute.h" #include "library/app_builder.h" #include "library/replace_visitor_with_tc.h" #include "library/equations_compiler/equations.h" @@ -76,6 +77,15 @@ struct structural_rec_fn { return is_constant(e) && inductive::is_intro_rule(m_ctx.env(), const_name(e)); } + expr whnf(expr const & e) { + /* We only unfold patterns and reducible definitions */ + return m_ctx.whnf_transparency_pred(e, [&](name const & n) { + return + has_pattern_attribute(m_ctx.env(), n) || + is_reducible(m_ctx.env(), n); + }); + } + /** \brief Return true iff \c s is structurally smaller than \c t OR equal to \c t */ bool is_le(expr const & s, expr const & t) { return m_ctx.is_def_eq(s, t) || is_lt(s, t); @@ -83,8 +93,8 @@ struct structural_rec_fn { /** Return true iff \c s is structurally smaller than \c t */ bool is_lt(expr s, expr t) { - s = m_ctx.whnf(s); - t = m_ctx.whnf(t); + s = whnf(s); + t = whnf(t); if (is_app(s)) { expr const & s_fn = get_app_fn(s); if (!is_constructor(s_fn)) diff --git a/src/library/equations_compiler/util.cpp b/src/library/equations_compiler/util.cpp index 8f569f7477..fe1b3ae62b 100644 --- a/src/library/equations_compiler/util.cpp +++ b/src/library/equations_compiler/util.cpp @@ -480,7 +480,7 @@ static optional prove_eq_rec_invertible(type_context & ctx, expr cons return optional(mk_pair(h_a, pr)); } -static expr prove_eqn_lemma_core(type_context & ctx, buffer const & Hs, expr const & lhs, expr const & rhs) { +static expr prove_eqn_lemma_core(type_context & ctx, buffer const & Hs, expr const & lhs, expr const & rhs, bool root) { buffer ite_args; expr new_lhs = whnf_ite(ctx, lhs); if (is_ite_eq(new_lhs, ite_args)) { @@ -492,7 +492,7 @@ static expr prove_eqn_lemma_core(type_context & ctx, buffer const & Hs, ex expr A = ite_args[2]; level A_lvl = get_level(ctx, A); expr H1 = mk_app(mk_constant(get_if_neg_name(), {A_lvl}), {c, ite_args[1], *H, A, ite_args[3], lhs_else}); - expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_else, rhs); + expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_else, rhs, false); return mk_app(mk_constant(get_eq_trans_name(), {A_lvl}), {A, lhs, lhs_else, rhs, H1, H2}); } else if (quick_is_def_eq_when_values(ctx, c_lhs, c_rhs)) { expr H = mk_eq_refl(ctx, c_lhs); @@ -500,7 +500,7 @@ static expr prove_eqn_lemma_core(type_context & ctx, buffer const & Hs, ex expr A = ite_args[2]; level A_lvl = get_level(ctx, A); expr H1 = mk_app(mk_constant(get_if_pos_name(), {A_lvl}), {c, ite_args[1], H, A, lhs_then, ite_args[4]}); - expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_then, rhs); + expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_then, rhs, false); expr eq_trans = mk_constant(get_eq_trans_name(), {A_lvl}); return mk_app(eq_trans, {A, lhs, lhs_then, rhs, H1, H2}); } else if (compare_values(c_lhs, c_rhs) == l_false) { @@ -509,7 +509,7 @@ static expr prove_eqn_lemma_core(type_context & ctx, buffer const & Hs, ex expr A = ite_args[2]; level A_lvl = get_level(ctx, A); expr H1 = mk_app(mk_constant(get_if_neg_name(), {A_lvl}), {c, ite_args[1], *H, A, ite_args[3], lhs_else}); - expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_else, rhs); + expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_else, rhs, false); return mk_app(mk_constant(get_eq_trans_name(), {A_lvl}), {A, lhs, lhs_else, rhs, H1, H2}); } } @@ -518,12 +518,18 @@ static expr prove_eqn_lemma_core(type_context & ctx, buffer const & Hs, ex if (optional p = prove_eq_rec_invertible(ctx, new_lhs)) { expr new_new_lhs = p->first; expr H1 = p->second; - expr H2 = prove_eqn_lemma_core(ctx, Hs, new_new_lhs, rhs); + expr H2 = prove_eqn_lemma_core(ctx, Hs, new_new_lhs, rhs, false); return mk_eq_trans(ctx, H1, H2); } - if (ctx.is_def_eq(lhs, rhs)) { - return mk_eq_refl(ctx, rhs); + expr lhs_body = lhs; + if (root) { + if (auto b = unfold_term(ctx.env(), lhs)) + lhs_body = *b; + } + + if (ctx.is_def_eq(lhs_body, rhs)) { + return mk_eq_refl(ctx, lhs_body); } throw exception("equation compiler failed to prove equation lemma (workaround: " @@ -531,7 +537,7 @@ static expr prove_eqn_lemma_core(type_context & ctx, buffer const & Hs, ex } static expr prove_eqn_lemma(type_context & ctx, buffer const & Hs, expr const & lhs, expr const & rhs) { - expr body = prove_eqn_lemma_core(ctx, Hs, lhs, rhs); + expr body = prove_eqn_lemma_core(ctx, Hs, lhs, rhs, true); return ctx.mk_lambda(Hs, body); } diff --git a/src/library/type_context.cpp b/src/library/type_context.cpp index 51016e3bcf..5343164345 100644 --- a/src/library/type_context.cpp +++ b/src/library/type_context.cpp @@ -800,7 +800,7 @@ expr type_context::whnf_head_pred(expr const & e, std::function const & pred) { +expr type_context::whnf_transparency_pred(expr const & e, std::function const & pred) { // NOLINT flet const *>set_trans_pred(m_transparency_pred, &pred); // NOLINT return whnf(e); } diff --git a/src/library/type_context.h b/src/library/type_context.h index b97fc8a509..5bc39d4659 100644 --- a/src/library/type_context.h +++ b/src/library/type_context.h @@ -379,7 +379,7 @@ public: the given predicate to decide whether a constant should be unfolded or not. Remark: cache is not used. */ - expr whnf_transparency_pred(expr const & e, std::function const & pred); + expr whnf_transparency_pred(expr const & e, std::function const & pred); // NOLINT /** \brief Put \c e in whnf, it is a Pi, then return whnf, otherwise return e */ expr try_to_pi(expr const & e); diff --git a/src/tests/shell/shell_test.produced.out b/src/tests/shell/shell_test.produced.out index 7cd72b2568..e69de29bb2 100644 --- a/src/tests/shell/shell_test.produced.out +++ b/src/tests/shell/shell_test.produced.out @@ -1,2 +0,0 @@ -{"message":"file invalidated","response":"ok","seq_num":0} -{"record":{"full-id":"f","source":{"column":10,"file":"f","line":1},"type":"Type"},"response":"ok","seq_num":1} diff --git a/tests/lean/671.lean.expected.out b/tests/lean/671.lean.expected.out index 72952ec500..d7f2428f20 100644 --- a/tests/lean/671.lean.expected.out +++ b/tests/lean/671.lean.expected.out @@ -1,2 +1,3 @@ +@[pattern] protected def nat.add : ℕ → ℕ → ℕ := nat.add._main diff --git a/tests/lean/run/cc_ac3.lean b/tests/lean/run/cc_ac3.lean index 92ca32f0e4..de3bd032fd 100644 --- a/tests/lean/run/cc_ac3.lean +++ b/tests/lean/run/cc_ac3.lean @@ -31,6 +31,6 @@ section variables [is_associative α op] variables [is_commutative α op] - def ex (a b c d e : α) (f : α → α → α) : op b a = op d d → op b c = op e e → f (op a (op b c)) (op (op a b) c) = f (op (op c d) d) (op e (op a e)) := + lemma ex (a b c d e : α) (f : α → α → α) : op b a = op d d → op b c = op e e → f (op a (op b c)) (op (op a b) c) = f (op (op c d) d) (op e (op a e)) := by cc end diff --git a/tests/lean/run/eqn_compiler_perf_issue.lean b/tests/lean/run/eqn_compiler_perf_issue.lean new file mode 100644 index 0000000000..87c50af86a --- /dev/null +++ b/tests/lean/run/eqn_compiler_perf_issue.lean @@ -0,0 +1,15 @@ +def mk (a : nat) : nat → list nat +| 0 := [] +| (nat.succ n) := a :: mk n + +def Sum : list nat → nat → nat +| [] r := r +| (n::ns) r := Sum ns (r + n) + +def loop1 : nat → nat → nat +| s 0 := s +| s (nat.succ n) := loop1 (s + (Sum (mk 1 1000000) 0)) n + +def loop2 : nat → nat → nat +| 0 s := s +| (nat.succ n) s := loop2 n (s + (Sum (mk 1 1000000) 0))