fix(library/equations_compiler): performance issues at structural_rec module and equational lemma generator

There were two performance bottlenecks in the recursive equation
compiler. Both bottlenecks were due to conversion checking.

1- We allow patterns such as (x+1) in the left-hand-side of a
   recursive equation. This is kind of pattern has to be reduced
   since it is not a constructor. Moreover, when we are trying to
   compile using structural recursion, we need to find an element
   that is structurally smaller in recursive applications.
   Again, we need to use reduction since the pattern may be (x+2),
   and in the recursive application we have (x+1). Now, consider
   the following equation

         f (x+1) (y+1) := f complex_term y

   It will first check whether complex_term is structurally smaller
   than (x+1), and the compiler will timeout trying to reduce
   complex_term.

   This commit adds the following workaround. The structural
   recursion module from now on will only unfold reducible constants
   and constants marked as patterns. This is not a complete
   solution. It will timeout in the following equation:

         f (x+1) (y+1) := f (x+1000000000000) y

   For this one, we need to add a whnf "fuel" option to type_context

2- Equational lemma generation was producing lemmas that are too
   expensive to check. Suppose we the following two definitions

       | f x 0     := 1
       | f x (y+1) := f complex_term y

    and

       | g 0     y    := 1
       | g (x+1) y    := g x complex_term

    Before this commit, we would generate the following proofs for
    the second equation of each definition:

         eq.refl (f complex_term y)
         eq.refl (g x complex_term)

    This proof triggers the following definitionally equality test:

             f x     (y+1)  =?= f complex_term y
             g (x+1) y      =?= g x complex_term

    Since, we have f/g on both sides, the type checker will try
    first to unify the arguments, and may timeout trying to solve

               x  =?= complex_term
               y  =?= complex_term

    since it may take a long time to reduce `complex_term`.

    We workaround this problem by creating a slightly different
    proof.

          eq.refl (unfold_of(f x (y+1)))
          eq.refl (unfold_of(g (x+1) y))

    where unfold_of(t) is the result of applying one delta reduction
    step.
This commit is contained in:
Leonardo de Moura 2017-02-15 20:45:57 -08:00
parent d4d5ac115c
commit 7ebf16ca26
9 changed files with 49 additions and 15 deletions

View file

@ -382,6 +382,10 @@ namespace nat
| a zero := a
| a (succ b) := succ (add a b)
/- We mark the following definitions as pattern to make sure they can be used in recursive equations,
and reduced by the equation compiler. -/
attribute [pattern] nat.add nat.add._main
def of_pos_num : pos_num → nat
| pos_num.one := succ zero
| (pos_num.bit0 a) := let r := of_pos_num a in nat.add r r

View file

@ -11,6 +11,7 @@ Author: Leonardo de Moura
#include "library/constants.h"
#include "library/locals.h"
#include "library/util.h"
#include "library/pattern_attribute.h"
#include "library/app_builder.h"
#include "library/replace_visitor_with_tc.h"
#include "library/equations_compiler/equations.h"
@ -76,6 +77,15 @@ struct structural_rec_fn {
return is_constant(e) && inductive::is_intro_rule(m_ctx.env(), const_name(e));
}
expr whnf(expr const & e) {
/* We only unfold patterns and reducible definitions */
return m_ctx.whnf_transparency_pred(e, [&](name const & n) {
return
has_pattern_attribute(m_ctx.env(), n) ||
is_reducible(m_ctx.env(), n);
});
}
/** \brief Return true iff \c s is structurally smaller than \c t OR equal to \c t */
bool is_le(expr const & s, expr const & t) {
return m_ctx.is_def_eq(s, t) || is_lt(s, t);
@ -83,8 +93,8 @@ struct structural_rec_fn {
/** Return true iff \c s is structurally smaller than \c t */
bool is_lt(expr s, expr t) {
s = m_ctx.whnf(s);
t = m_ctx.whnf(t);
s = whnf(s);
t = whnf(t);
if (is_app(s)) {
expr const & s_fn = get_app_fn(s);
if (!is_constructor(s_fn))

View file

@ -480,7 +480,7 @@ static optional<expr_pair> prove_eq_rec_invertible(type_context & ctx, expr cons
return optional<expr_pair>(mk_pair(h_a, pr));
}
static expr prove_eqn_lemma_core(type_context & ctx, buffer<expr> const & Hs, expr const & lhs, expr const & rhs) {
static expr prove_eqn_lemma_core(type_context & ctx, buffer<expr> const & Hs, expr const & lhs, expr const & rhs, bool root) {
buffer<expr> ite_args;
expr new_lhs = whnf_ite(ctx, lhs);
if (is_ite_eq(new_lhs, ite_args)) {
@ -492,7 +492,7 @@ static expr prove_eqn_lemma_core(type_context & ctx, buffer<expr> const & Hs, ex
expr A = ite_args[2];
level A_lvl = get_level(ctx, A);
expr H1 = mk_app(mk_constant(get_if_neg_name(), {A_lvl}), {c, ite_args[1], *H, A, ite_args[3], lhs_else});
expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_else, rhs);
expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_else, rhs, false);
return mk_app(mk_constant(get_eq_trans_name(), {A_lvl}), {A, lhs, lhs_else, rhs, H1, H2});
} else if (quick_is_def_eq_when_values(ctx, c_lhs, c_rhs)) {
expr H = mk_eq_refl(ctx, c_lhs);
@ -500,7 +500,7 @@ static expr prove_eqn_lemma_core(type_context & ctx, buffer<expr> const & Hs, ex
expr A = ite_args[2];
level A_lvl = get_level(ctx, A);
expr H1 = mk_app(mk_constant(get_if_pos_name(), {A_lvl}), {c, ite_args[1], H, A, lhs_then, ite_args[4]});
expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_then, rhs);
expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_then, rhs, false);
expr eq_trans = mk_constant(get_eq_trans_name(), {A_lvl});
return mk_app(eq_trans, {A, lhs, lhs_then, rhs, H1, H2});
} else if (compare_values(c_lhs, c_rhs) == l_false) {
@ -509,7 +509,7 @@ static expr prove_eqn_lemma_core(type_context & ctx, buffer<expr> const & Hs, ex
expr A = ite_args[2];
level A_lvl = get_level(ctx, A);
expr H1 = mk_app(mk_constant(get_if_neg_name(), {A_lvl}), {c, ite_args[1], *H, A, ite_args[3], lhs_else});
expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_else, rhs);
expr H2 = prove_eqn_lemma_core(ctx, Hs, lhs_else, rhs, false);
return mk_app(mk_constant(get_eq_trans_name(), {A_lvl}), {A, lhs, lhs_else, rhs, H1, H2});
}
}
@ -518,12 +518,18 @@ static expr prove_eqn_lemma_core(type_context & ctx, buffer<expr> const & Hs, ex
if (optional<expr_pair> p = prove_eq_rec_invertible(ctx, new_lhs)) {
expr new_new_lhs = p->first;
expr H1 = p->second;
expr H2 = prove_eqn_lemma_core(ctx, Hs, new_new_lhs, rhs);
expr H2 = prove_eqn_lemma_core(ctx, Hs, new_new_lhs, rhs, false);
return mk_eq_trans(ctx, H1, H2);
}
if (ctx.is_def_eq(lhs, rhs)) {
return mk_eq_refl(ctx, rhs);
expr lhs_body = lhs;
if (root) {
if (auto b = unfold_term(ctx.env(), lhs))
lhs_body = *b;
}
if (ctx.is_def_eq(lhs_body, rhs)) {
return mk_eq_refl(ctx, lhs_body);
}
throw exception("equation compiler failed to prove equation lemma (workaround: "
@ -531,7 +537,7 @@ static expr prove_eqn_lemma_core(type_context & ctx, buffer<expr> const & Hs, ex
}
static expr prove_eqn_lemma(type_context & ctx, buffer<expr> const & Hs, expr const & lhs, expr const & rhs) {
expr body = prove_eqn_lemma_core(ctx, Hs, lhs, rhs);
expr body = prove_eqn_lemma_core(ctx, Hs, lhs, rhs, true);
return ctx.mk_lambda(Hs, body);
}

View file

@ -800,7 +800,7 @@ expr type_context::whnf_head_pred(expr const & e, std::function<bool(expr const
}
}
expr type_context::whnf_transparency_pred(expr const & e, std::function<bool(name const &)> const & pred) {
expr type_context::whnf_transparency_pred(expr const & e, std::function<bool(name const &)> const & pred) { // NOLINT
flet<std::function<bool(name const &)> const *>set_trans_pred(m_transparency_pred, &pred); // NOLINT
return whnf(e);
}

View file

@ -379,7 +379,7 @@ public:
the given predicate to decide whether a constant should be unfolded or not.
Remark: cache is not used. */
expr whnf_transparency_pred(expr const & e, std::function<bool(name const &)> const & pred);
expr whnf_transparency_pred(expr const & e, std::function<bool(name const &)> const & pred); // NOLINT
/** \brief Put \c e in whnf, it is a Pi, then return whnf, otherwise return e */
expr try_to_pi(expr const & e);

View file

@ -1,2 +0,0 @@
{"message":"file invalidated","response":"ok","seq_num":0}
{"record":{"full-id":"f","source":{"column":10,"file":"f","line":1},"type":"Type"},"response":"ok","seq_num":1}

View file

@ -1,2 +1,3 @@
@[pattern]
protected def nat.add : :=
nat.add._main

View file

@ -31,6 +31,6 @@ section
variables [is_associative α op]
variables [is_commutative α op]
def ex (a b c d e : α) (f : ααα) : op b a = op d d → op b c = op e e → f (op a (op b c)) (op (op a b) c) = f (op (op c d) d) (op e (op a e)) :=
lemma ex (a b c d e : α) (f : ααα) : op b a = op d d → op b c = op e e → f (op a (op b c)) (op (op a b) c) = f (op (op c d) d) (op e (op a e)) :=
by cc
end

View file

@ -0,0 +1,15 @@
def mk (a : nat) : nat → list nat
| 0 := []
| (nat.succ n) := a :: mk n
def Sum : list nat → nat → nat
| [] r := r
| (n::ns) r := Sum ns (r + n)
def loop1 : nat → nat → nat
| s 0 := s
| s (nat.succ n) := loop1 (s + (Sum (mk 1 1000000) 0)) n
def loop2 : nat → nat → nat
| 0 s := s
| (nat.succ n) s := loop2 n (s + (Sum (mk 1 1000000) 0))