fix(library/equations_compiler): performance issues at structural_rec module and equational lemma generator
There were two performance bottlenecks in the recursive equation
compiler. Both bottlenecks were due to conversion checking.
1- We allow patterns such as (x+1) in the left-hand-side of a
recursive equation. This is kind of pattern has to be reduced
since it is not a constructor. Moreover, when we are trying to
compile using structural recursion, we need to find an element
that is structurally smaller in recursive applications.
Again, we need to use reduction since the pattern may be (x+2),
and in the recursive application we have (x+1). Now, consider
the following equation
f (x+1) (y+1) := f complex_term y
It will first check whether complex_term is structurally smaller
than (x+1), and the compiler will timeout trying to reduce
complex_term.
This commit adds the following workaround. The structural
recursion module from now on will only unfold reducible constants
and constants marked as patterns. This is not a complete
solution. It will timeout in the following equation:
f (x+1) (y+1) := f (x+1000000000000) y
For this one, we need to add a whnf "fuel" option to type_context
2- Equational lemma generation was producing lemmas that are too
expensive to check. Suppose we the following two definitions
| f x 0 := 1
| f x (y+1) := f complex_term y
and
| g 0 y := 1
| g (x+1) y := g x complex_term
Before this commit, we would generate the following proofs for
the second equation of each definition:
eq.refl (f complex_term y)
eq.refl (g x complex_term)
This proof triggers the following definitionally equality test:
f x (y+1) =?= f complex_term y
g (x+1) y =?= g x complex_term
Since, we have f/g on both sides, the type checker will try
first to unify the arguments, and may timeout trying to solve
x =?= complex_term
y =?= complex_term
since it may take a long time to reduce `complex_term`.
We workaround this problem by creating a slightly different
proof.
eq.refl (unfold_of(f x (y+1)))
eq.refl (unfold_of(g (x+1) y))
where unfold_of(t) is the result of applying one delta reduction
step.
This commit is contained in:
parent
d4d5ac115c
commit
7ebf16ca26
9 changed files with 49 additions and 15 deletions
|
|
@ -382,6 +382,10 @@ namespace nat
|
|||
| a zero := a
|
||||
| a (succ b) := succ (add a b)
|
||||
|
||||
/- We mark the following definitions as pattern to make sure they can be used in recursive equations,
|
||||
and reduced by the equation compiler. -/
|
||||
attribute [pattern] nat.add nat.add._main
|
||||
|
||||
def of_pos_num : pos_num → nat
|
||||
| pos_num.one := succ zero
|
||||
| (pos_num.bit0 a) := let r := of_pos_num a in nat.add r r
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ Author: Leonardo de Moura
|
|||
#include "library/constants.h"
|
||||
#include "library/locals.h"
|
||||
#include "library/util.h"
|
||||
#include "library/pattern_attribute.h"
|
||||
#include "library/app_builder.h"
|
||||
#include "library/replace_visitor_with_tc.h"
|
||||
#include "library/equations_compiler/equations.h"
|
||||
|
|
@ -76,6 +77,15 @@ struct structural_rec_fn {
|
|||
return is_constant(e) && inductive::is_intro_rule(m_ctx.env(), const_name(e));
|
||||
}
|
||||
|
||||
expr whnf(expr const & e) {
|
||||
/* We only unfold patterns and reducible definitions */
|
||||
return m_ctx.whnf_transparency_pred(e, [&](name const & n) {
|
||||
return
|
||||
has_pattern_attribute(m_ctx.env(), n) ||
|
||||
is_reducible(m_ctx.env(), n);
|
||||
});
|
||||
}
|
||||
|
||||
/** \brief Return true iff \c s is structurally smaller than \c t OR equal to \c t */
|
||||
bool is_le(expr const & s, expr const & t) {
|
||||
return m_ctx.is_def_eq(s, t) || is_lt(s, t);
|
||||
|
|
@ -83,8 +93,8 @@ struct structural_rec_fn {
|
|||
|
||||
/** Return true iff \c s is structurally smaller than \c t */
|
||||
bool is_lt(expr s, expr t) {
|
||||
s = m_ctx.whnf(s);
|
||||
t = m_ctx.whnf(t);
|
||||
s = whnf(s);
|
||||
t = whnf(t);
|
||||
if (is_app(s)) {
|
||||
expr const & s_fn = get_app_fn(s);
|
||||
if (!is_constructor(s_fn))
|
||||
|
|
|
|||
|
|
@ -480,7 +480,7 @@ static optional<expr_pair> prove_eq_rec_invertible(type_context & ctx, expr cons
|
|||
return optional<expr_pair>(mk_pair(h_a, pr));
|
||||
}
|
||||
|
||||
static expr prove_eqn_lemma_core(type_context & ctx, buffer<expr> const & Hs, expr const & lhs, expr const & rhs) {
|
||||
static expr prove_eqn_lemma_core(type_context & ctx, buffer<expr> const & Hs, expr const & lhs, expr const & rhs, bool root) {
|
||||
buffer<expr> ite_args;
|
||||
expr new_lhs = whnf_ite(ctx, lhs);
|
||||
if (is_ite_eq(new_lhs, ite_args)) {
|
||||
|
|
@ -492,7 +492,7 @@ static expr prove_eqn_lemma_core(type_context & ctx, buffer<expr> const & Hs, ex
|
|||
expr A = ite_args[2];
|
||||
level A_lvl = get_level(ctx, A);
|
||||
expr H1 = mk_app(mk_constant(get_if_neg_name(), {A_lvl}), {c, ite_args[1], *H, A, ite_args[3], lhs_else});
|
||||
expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_else, rhs);
|
||||
expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_else, rhs, false);
|
||||
return mk_app(mk_constant(get_eq_trans_name(), {A_lvl}), {A, lhs, lhs_else, rhs, H1, H2});
|
||||
} else if (quick_is_def_eq_when_values(ctx, c_lhs, c_rhs)) {
|
||||
expr H = mk_eq_refl(ctx, c_lhs);
|
||||
|
|
@ -500,7 +500,7 @@ static expr prove_eqn_lemma_core(type_context & ctx, buffer<expr> const & Hs, ex
|
|||
expr A = ite_args[2];
|
||||
level A_lvl = get_level(ctx, A);
|
||||
expr H1 = mk_app(mk_constant(get_if_pos_name(), {A_lvl}), {c, ite_args[1], H, A, lhs_then, ite_args[4]});
|
||||
expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_then, rhs);
|
||||
expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_then, rhs, false);
|
||||
expr eq_trans = mk_constant(get_eq_trans_name(), {A_lvl});
|
||||
return mk_app(eq_trans, {A, lhs, lhs_then, rhs, H1, H2});
|
||||
} else if (compare_values(c_lhs, c_rhs) == l_false) {
|
||||
|
|
@ -509,7 +509,7 @@ static expr prove_eqn_lemma_core(type_context & ctx, buffer<expr> const & Hs, ex
|
|||
expr A = ite_args[2];
|
||||
level A_lvl = get_level(ctx, A);
|
||||
expr H1 = mk_app(mk_constant(get_if_neg_name(), {A_lvl}), {c, ite_args[1], *H, A, ite_args[3], lhs_else});
|
||||
expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_else, rhs);
|
||||
expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_else, rhs, false);
|
||||
return mk_app(mk_constant(get_eq_trans_name(), {A_lvl}), {A, lhs, lhs_else, rhs, H1, H2});
|
||||
}
|
||||
}
|
||||
|
|
@ -518,12 +518,18 @@ static expr prove_eqn_lemma_core(type_context & ctx, buffer<expr> const & Hs, ex
|
|||
if (optional<expr_pair> p = prove_eq_rec_invertible(ctx, new_lhs)) {
|
||||
expr new_new_lhs = p->first;
|
||||
expr H1 = p->second;
|
||||
expr H2 = prove_eqn_lemma_core(ctx, Hs, new_new_lhs, rhs);
|
||||
expr H2 = prove_eqn_lemma_core(ctx, Hs, new_new_lhs, rhs, false);
|
||||
return mk_eq_trans(ctx, H1, H2);
|
||||
}
|
||||
|
||||
if (ctx.is_def_eq(lhs, rhs)) {
|
||||
return mk_eq_refl(ctx, rhs);
|
||||
expr lhs_body = lhs;
|
||||
if (root) {
|
||||
if (auto b = unfold_term(ctx.env(), lhs))
|
||||
lhs_body = *b;
|
||||
}
|
||||
|
||||
if (ctx.is_def_eq(lhs_body, rhs)) {
|
||||
return mk_eq_refl(ctx, lhs_body);
|
||||
}
|
||||
|
||||
throw exception("equation compiler failed to prove equation lemma (workaround: "
|
||||
|
|
@ -531,7 +537,7 @@ static expr prove_eqn_lemma_core(type_context & ctx, buffer<expr> const & Hs, ex
|
|||
}
|
||||
|
||||
static expr prove_eqn_lemma(type_context & ctx, buffer<expr> const & Hs, expr const & lhs, expr const & rhs) {
|
||||
expr body = prove_eqn_lemma_core(ctx, Hs, lhs, rhs);
|
||||
expr body = prove_eqn_lemma_core(ctx, Hs, lhs, rhs, true);
|
||||
return ctx.mk_lambda(Hs, body);
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -800,7 +800,7 @@ expr type_context::whnf_head_pred(expr const & e, std::function<bool(expr const
|
|||
}
|
||||
}
|
||||
|
||||
expr type_context::whnf_transparency_pred(expr const & e, std::function<bool(name const &)> const & pred) {
|
||||
expr type_context::whnf_transparency_pred(expr const & e, std::function<bool(name const &)> const & pred) { // NOLINT
|
||||
flet<std::function<bool(name const &)> const *>set_trans_pred(m_transparency_pred, &pred); // NOLINT
|
||||
return whnf(e);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -379,7 +379,7 @@ public:
|
|||
the given predicate to decide whether a constant should be unfolded or not.
|
||||
|
||||
Remark: cache is not used. */
|
||||
expr whnf_transparency_pred(expr const & e, std::function<bool(name const &)> const & pred);
|
||||
expr whnf_transparency_pred(expr const & e, std::function<bool(name const &)> const & pred); // NOLINT
|
||||
|
||||
/** \brief Put \c e in whnf, it is a Pi, then return whnf, otherwise return e */
|
||||
expr try_to_pi(expr const & e);
|
||||
|
|
|
|||
|
|
@ -1,2 +0,0 @@
|
|||
{"message":"file invalidated","response":"ok","seq_num":0}
|
||||
{"record":{"full-id":"f","source":{"column":10,"file":"f","line":1},"type":"Type"},"response":"ok","seq_num":1}
|
||||
|
|
@ -1,2 +1,3 @@
|
|||
@[pattern]
|
||||
protected def nat.add : ℕ → ℕ → ℕ :=
|
||||
nat.add._main
|
||||
|
|
|
|||
|
|
@ -31,6 +31,6 @@ section
|
|||
variables [is_associative α op]
|
||||
variables [is_commutative α op]
|
||||
|
||||
def ex (a b c d e : α) (f : α → α → α) : op b a = op d d → op b c = op e e → f (op a (op b c)) (op (op a b) c) = f (op (op c d) d) (op e (op a e)) :=
|
||||
lemma ex (a b c d e : α) (f : α → α → α) : op b a = op d d → op b c = op e e → f (op a (op b c)) (op (op a b) c) = f (op (op c d) d) (op e (op a e)) :=
|
||||
by cc
|
||||
end
|
||||
|
|
|
|||
15
tests/lean/run/eqn_compiler_perf_issue.lean
Normal file
15
tests/lean/run/eqn_compiler_perf_issue.lean
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
def mk (a : nat) : nat → list nat
|
||||
| 0 := []
|
||||
| (nat.succ n) := a :: mk n
|
||||
|
||||
def Sum : list nat → nat → nat
|
||||
| [] r := r
|
||||
| (n::ns) r := Sum ns (r + n)
|
||||
|
||||
def loop1 : nat → nat → nat
|
||||
| s 0 := s
|
||||
| s (nat.succ n) := loop1 (s + (Sum (mk 1 1000000) 0)) n
|
||||
|
||||
def loop2 : nat → nat → nat
|
||||
| 0 s := s
|
||||
| (nat.succ n) s := loop2 n (s + (Sum (mk 1 1000000) 0))
|
||||
Loading…
Add table
Reference in a new issue