From 7ebf16ca26da82b3d0e458dbcf32cda374ec785d Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 15 Feb 2017 20:45:57 -0800 Subject: [PATCH] fix(library/equations_compiler): performance issues at structural_rec module and equational lemma generator There were two performance bottlenecks in the recursive equation compiler. Both bottlenecks were due to conversion checking. 1- We allow patterns such as (x+1) in the left-hand-side of a recursive equation. This is kind of pattern has to be reduced since it is not a constructor. Moreover, when we are trying to compile using structural recursion, we need to find an element that is structurally smaller in recursive applications. Again, we need to use reduction since the pattern may be (x+2), and in the recursive application we have (x+1). Now, consider the following equation f (x+1) (y+1) := f complex_term y It will first check whether complex_term is structurally smaller than (x+1), and the compiler will timeout trying to reduce complex_term. This commit adds the following workaround. The structural recursion module from now on will only unfold reducible constants and constants marked as patterns. This is not a complete solution. It will timeout in the following equation: f (x+1) (y+1) := f (x+1000000000000) y For this one, we need to add a whnf "fuel" option to type_context 2- Equational lemma generation was producing lemmas that are too expensive to check. Suppose we the following two definitions | f x 0 := 1 | f x (y+1) := f complex_term y and | g 0 y := 1 | g (x+1) y := g x complex_term Before this commit, we would generate the following proofs for the second equation of each definition: eq.refl (f complex_term y) eq.refl (g x complex_term) This proof triggers the following definitionally equality test: f x (y+1) =?= f complex_term y g (x+1) y =?= g x complex_term Since, we have f/g on both sides, the type checker will try first to unify the arguments, and may timeout trying to solve x =?= complex_term y =?= complex_term since it may take a long time to reduce `complex_term`. We workaround this problem by creating a slightly different proof. eq.refl (unfold_of(f x (y+1))) eq.refl (unfold_of(g (x+1) y)) where unfold_of(t) is the result of applying one delta reduction step. --- library/init/core.lean | 4 ++++ .../equations_compiler/structural_rec.cpp | 14 ++++++++++-- src/library/equations_compiler/util.cpp | 22 ++++++++++++------- src/library/type_context.cpp | 2 +- src/library/type_context.h | 2 +- src/tests/shell/shell_test.produced.out | 2 -- tests/lean/671.lean.expected.out | 1 + tests/lean/run/cc_ac3.lean | 2 +- tests/lean/run/eqn_compiler_perf_issue.lean | 15 +++++++++++++ 9 files changed, 49 insertions(+), 15 deletions(-) create mode 100644 tests/lean/run/eqn_compiler_perf_issue.lean diff --git a/library/init/core.lean b/library/init/core.lean index 94e7d9064e..387d9ceb4b 100644 --- a/library/init/core.lean +++ b/library/init/core.lean @@ -382,6 +382,10 @@ namespace nat | a zero := a | a (succ b) := succ (add a b) + /- We mark the following definitions as pattern to make sure they can be used in recursive equations, + and reduced by the equation compiler. -/ + attribute [pattern] nat.add nat.add._main + def of_pos_num : pos_num → nat | pos_num.one := succ zero | (pos_num.bit0 a) := let r := of_pos_num a in nat.add r r diff --git a/src/library/equations_compiler/structural_rec.cpp b/src/library/equations_compiler/structural_rec.cpp index a881e9834f..bcf8dbf76e 100644 --- a/src/library/equations_compiler/structural_rec.cpp +++ b/src/library/equations_compiler/structural_rec.cpp @@ -11,6 +11,7 @@ Author: Leonardo de Moura #include "library/constants.h" #include "library/locals.h" #include "library/util.h" +#include "library/pattern_attribute.h" #include "library/app_builder.h" #include "library/replace_visitor_with_tc.h" #include "library/equations_compiler/equations.h" @@ -76,6 +77,15 @@ struct structural_rec_fn { return is_constant(e) && inductive::is_intro_rule(m_ctx.env(), const_name(e)); } + expr whnf(expr const & e) { + /* We only unfold patterns and reducible definitions */ + return m_ctx.whnf_transparency_pred(e, [&](name const & n) { + return + has_pattern_attribute(m_ctx.env(), n) || + is_reducible(m_ctx.env(), n); + }); + } + /** \brief Return true iff \c s is structurally smaller than \c t OR equal to \c t */ bool is_le(expr const & s, expr const & t) { return m_ctx.is_def_eq(s, t) || is_lt(s, t); @@ -83,8 +93,8 @@ struct structural_rec_fn { /** Return true iff \c s is structurally smaller than \c t */ bool is_lt(expr s, expr t) { - s = m_ctx.whnf(s); - t = m_ctx.whnf(t); + s = whnf(s); + t = whnf(t); if (is_app(s)) { expr const & s_fn = get_app_fn(s); if (!is_constructor(s_fn)) diff --git a/src/library/equations_compiler/util.cpp b/src/library/equations_compiler/util.cpp index 8f569f7477..fe1b3ae62b 100644 --- a/src/library/equations_compiler/util.cpp +++ b/src/library/equations_compiler/util.cpp @@ -480,7 +480,7 @@ static optional prove_eq_rec_invertible(type_context & ctx, expr cons return optional(mk_pair(h_a, pr)); } -static expr prove_eqn_lemma_core(type_context & ctx, buffer const & Hs, expr const & lhs, expr const & rhs) { +static expr prove_eqn_lemma_core(type_context & ctx, buffer const & Hs, expr const & lhs, expr const & rhs, bool root) { buffer ite_args; expr new_lhs = whnf_ite(ctx, lhs); if (is_ite_eq(new_lhs, ite_args)) { @@ -492,7 +492,7 @@ static expr prove_eqn_lemma_core(type_context & ctx, buffer const & Hs, ex expr A = ite_args[2]; level A_lvl = get_level(ctx, A); expr H1 = mk_app(mk_constant(get_if_neg_name(), {A_lvl}), {c, ite_args[1], *H, A, ite_args[3], lhs_else}); - expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_else, rhs); + expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_else, rhs, false); return mk_app(mk_constant(get_eq_trans_name(), {A_lvl}), {A, lhs, lhs_else, rhs, H1, H2}); } else if (quick_is_def_eq_when_values(ctx, c_lhs, c_rhs)) { expr H = mk_eq_refl(ctx, c_lhs); @@ -500,7 +500,7 @@ static expr prove_eqn_lemma_core(type_context & ctx, buffer const & Hs, ex expr A = ite_args[2]; level A_lvl = get_level(ctx, A); expr H1 = mk_app(mk_constant(get_if_pos_name(), {A_lvl}), {c, ite_args[1], H, A, lhs_then, ite_args[4]}); - expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_then, rhs); + expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_then, rhs, false); expr eq_trans = mk_constant(get_eq_trans_name(), {A_lvl}); return mk_app(eq_trans, {A, lhs, lhs_then, rhs, H1, H2}); } else if (compare_values(c_lhs, c_rhs) == l_false) { @@ -509,7 +509,7 @@ static expr prove_eqn_lemma_core(type_context & ctx, buffer const & Hs, ex expr A = ite_args[2]; level A_lvl = get_level(ctx, A); expr H1 = mk_app(mk_constant(get_if_neg_name(), {A_lvl}), {c, ite_args[1], *H, A, ite_args[3], lhs_else}); - expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_else, rhs); + expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_else, rhs, false); return mk_app(mk_constant(get_eq_trans_name(), {A_lvl}), {A, lhs, lhs_else, rhs, H1, H2}); } } @@ -518,12 +518,18 @@ static expr prove_eqn_lemma_core(type_context & ctx, buffer const & Hs, ex if (optional p = prove_eq_rec_invertible(ctx, new_lhs)) { expr new_new_lhs = p->first; expr H1 = p->second; - expr H2 = prove_eqn_lemma_core(ctx, Hs, new_new_lhs, rhs); + expr H2 = prove_eqn_lemma_core(ctx, Hs, new_new_lhs, rhs, false); return mk_eq_trans(ctx, H1, H2); } - if (ctx.is_def_eq(lhs, rhs)) { - return mk_eq_refl(ctx, rhs); + expr lhs_body = lhs; + if (root) { + if (auto b = unfold_term(ctx.env(), lhs)) + lhs_body = *b; + } + + if (ctx.is_def_eq(lhs_body, rhs)) { + return mk_eq_refl(ctx, lhs_body); } throw exception("equation compiler failed to prove equation lemma (workaround: " @@ -531,7 +537,7 @@ static expr prove_eqn_lemma_core(type_context & ctx, buffer const & Hs, ex } static expr prove_eqn_lemma(type_context & ctx, buffer const & Hs, expr const & lhs, expr const & rhs) { - expr body = prove_eqn_lemma_core(ctx, Hs, lhs, rhs); + expr body = prove_eqn_lemma_core(ctx, Hs, lhs, rhs, true); return ctx.mk_lambda(Hs, body); } diff --git a/src/library/type_context.cpp b/src/library/type_context.cpp index 51016e3bcf..5343164345 100644 --- a/src/library/type_context.cpp +++ b/src/library/type_context.cpp @@ -800,7 +800,7 @@ expr type_context::whnf_head_pred(expr const & e, std::function const & pred) { +expr type_context::whnf_transparency_pred(expr const & e, std::function const & pred) { // NOLINT flet const *>set_trans_pred(m_transparency_pred, &pred); // NOLINT return whnf(e); } diff --git a/src/library/type_context.h b/src/library/type_context.h index b97fc8a509..5bc39d4659 100644 --- a/src/library/type_context.h +++ b/src/library/type_context.h @@ -379,7 +379,7 @@ public: the given predicate to decide whether a constant should be unfolded or not. Remark: cache is not used. */ - expr whnf_transparency_pred(expr const & e, std::function const & pred); + expr whnf_transparency_pred(expr const & e, std::function const & pred); // NOLINT /** \brief Put \c e in whnf, it is a Pi, then return whnf, otherwise return e */ expr try_to_pi(expr const & e); diff --git a/src/tests/shell/shell_test.produced.out b/src/tests/shell/shell_test.produced.out index 7cd72b2568..e69de29bb2 100644 --- a/src/tests/shell/shell_test.produced.out +++ b/src/tests/shell/shell_test.produced.out @@ -1,2 +0,0 @@ -{"message":"file invalidated","response":"ok","seq_num":0} -{"record":{"full-id":"f","source":{"column":10,"file":"f","line":1},"type":"Type"},"response":"ok","seq_num":1} diff --git a/tests/lean/671.lean.expected.out b/tests/lean/671.lean.expected.out index 72952ec500..d7f2428f20 100644 --- a/tests/lean/671.lean.expected.out +++ b/tests/lean/671.lean.expected.out @@ -1,2 +1,3 @@ +@[pattern] protected def nat.add : ℕ → ℕ → ℕ := nat.add._main diff --git a/tests/lean/run/cc_ac3.lean b/tests/lean/run/cc_ac3.lean index 92ca32f0e4..de3bd032fd 100644 --- a/tests/lean/run/cc_ac3.lean +++ b/tests/lean/run/cc_ac3.lean @@ -31,6 +31,6 @@ section variables [is_associative α op] variables [is_commutative α op] - def ex (a b c d e : α) (f : α → α → α) : op b a = op d d → op b c = op e e → f (op a (op b c)) (op (op a b) c) = f (op (op c d) d) (op e (op a e)) := + lemma ex (a b c d e : α) (f : α → α → α) : op b a = op d d → op b c = op e e → f (op a (op b c)) (op (op a b) c) = f (op (op c d) d) (op e (op a e)) := by cc end diff --git a/tests/lean/run/eqn_compiler_perf_issue.lean b/tests/lean/run/eqn_compiler_perf_issue.lean new file mode 100644 index 0000000000..87c50af86a --- /dev/null +++ b/tests/lean/run/eqn_compiler_perf_issue.lean @@ -0,0 +1,15 @@ +def mk (a : nat) : nat → list nat +| 0 := [] +| (nat.succ n) := a :: mk n + +def Sum : list nat → nat → nat +| [] r := r +| (n::ns) r := Sum ns (r + n) + +def loop1 : nat → nat → nat +| s 0 := s +| s (nat.succ n) := loop1 (s + (Sum (mk 1 1000000) 0)) n + +def loop2 : nat → nat → nat +| 0 s := s +| (nat.succ n) s := loop2 n (s + (Sum (mk 1 1000000) 0))