perf(library/equations_compiler): performance problem for definitions that produce many equational lemmas
The new test and comment at src/library/equations_compiler/util.cpp explains the issue.
This commit is contained in:
parent
dd9d8e9552
commit
64f575a2d5
11 changed files with 170 additions and 6 deletions
|
|
@ -80,6 +80,23 @@ reserve infixl `; `:1
|
|||
|
||||
universes u v w
|
||||
|
||||
/-
|
||||
The kernel definitional equality test (t =?= s) has special support for id_delta applications.
|
||||
It implements the following rules
|
||||
|
||||
1) (id_delta t) =?= t
|
||||
2) t =?= (id_delta t)
|
||||
3) (id_delta t) =?= s IF (unfold_of t) =?= s
|
||||
4) t =?= id_delta s IF t =?= (unfold_of s)
|
||||
|
||||
This is mechanism for controlling the delta reduction (aka unfolding) used in the kernel.
|
||||
|
||||
We use id_delta applications to address performance problems when type checking
|
||||
lemmas generated by the equation compiler.
|
||||
-/
|
||||
@[inline] def id_delta {α : Sort u} (a : α) : α :=
|
||||
a
|
||||
|
||||
/-- Gadget for optional parameter support. -/
|
||||
@[reducible] def opt_param (α : Sort u) (default : α) : Sort u :=
|
||||
α
|
||||
|
|
|
|||
|
|
@ -578,6 +578,8 @@ void type_checker::cache_failure(expr const & t, expr const & s) {
|
|||
m_failure_cache.insert(mk_pair(s, t));
|
||||
}
|
||||
|
||||
static name * g_id_delta = nullptr;
|
||||
|
||||
/** \brief Perform one lazy delta-reduction step.
|
||||
Return
|
||||
- l_true if t_n and s_n are definitionally equal.
|
||||
|
|
@ -590,6 +592,20 @@ auto type_checker::lazy_delta_reduction_step(expr & t_n, expr & s_n) -> reductio
|
|||
auto d_s = is_delta(s_n);
|
||||
if (!d_t && !d_s) {
|
||||
return reduction_status::DefUnknown;
|
||||
} else if (d_t && d_t->get_name() == *g_id_delta) {
|
||||
t_n = whnf_core(*unfold_definition(t_n));
|
||||
if (t_n == s_n)
|
||||
return reduction_status::DefEqual; /* id_delta t =?= t */
|
||||
if (auto u = unfold_definition(t_n)) /* id_delta t =?= s ===> unfold(t) =?= s */
|
||||
t_n = whnf_core(*u);
|
||||
return reduction_status::Continue;
|
||||
} else if (d_s && d_s->get_name() == *g_id_delta) {
|
||||
s_n = whnf_core(*unfold_definition(s_n));
|
||||
if (t_n == s_n)
|
||||
return reduction_status::DefEqual; /* t =?= id_delta t */
|
||||
if (auto u = unfold_definition(s_n)) /* t =?= id_delta s ===> t =?= unfold(s) */
|
||||
s_n = whnf_core(*u);
|
||||
return reduction_status::Continue;
|
||||
} else if (d_t && !d_s) {
|
||||
t_n = whnf_core(*unfold_definition(t_n));
|
||||
} else if (!d_t && d_s) {
|
||||
|
|
@ -790,10 +806,12 @@ certified_declaration certify_unchecked::certify_or_check(environment const & en
|
|||
}
|
||||
|
||||
void initialize_type_checker() {
|
||||
g_id_delta = new name("id_delta");
|
||||
g_dont_care = new expr(Const("dontcare"));
|
||||
}
|
||||
|
||||
void finalize_type_checker() {
|
||||
delete g_dont_care;
|
||||
delete g_id_delta;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1136,6 +1136,12 @@ expr mk_id_rhs(type_context & ctx, expr const & h) {
|
|||
return mk_app(mk_constant(get_id_rhs_name(), {lvl}), type, h);
|
||||
}
|
||||
|
||||
expr mk_id_delta(type_context & ctx, expr const & h) {
|
||||
expr type = ctx.infer(h);
|
||||
level lvl = get_level(ctx, type);
|
||||
return mk_app(mk_constant(get_id_delta_name(), {lvl}), type, h);
|
||||
}
|
||||
|
||||
static bool is_eq_trans(expr const & h, expr & h1, expr & h2) {
|
||||
if (is_app_of(h, get_eq_trans_name(), 6)) {
|
||||
h1 = app_arg(app_fn(h));
|
||||
|
|
|
|||
|
|
@ -186,6 +186,9 @@ expr mk_id_locked(type_context & ctx, expr const & type, expr const & h);
|
|||
/* (id_rhs h) */
|
||||
expr mk_id_rhs(type_context & ctx, expr const & h);
|
||||
|
||||
/* (id_delta h) */
|
||||
expr mk_id_delta(type_context & ctx, expr const & h);
|
||||
|
||||
expr mk_iff_mp(type_context & ctx, expr const & h1, expr const & h2);
|
||||
expr mk_iff_mpr(type_context & ctx, expr const & h1, expr const & h2);
|
||||
expr mk_eq_mp(type_context & ctx, expr const & h1, expr const & h2);
|
||||
|
|
|
|||
|
|
@ -129,6 +129,7 @@ name const * g_heq_of_eq = nullptr;
|
|||
name const * g_hole_command = nullptr;
|
||||
name const * g_id_locked = nullptr;
|
||||
name const * g_id_rhs = nullptr;
|
||||
name const * g_id_delta = nullptr;
|
||||
name const * g_if_neg = nullptr;
|
||||
name const * g_if_pos = nullptr;
|
||||
name const * g_iff = nullptr;
|
||||
|
|
@ -510,6 +511,7 @@ void initialize_constants() {
|
|||
g_hole_command = new name{"hole_command"};
|
||||
g_id_locked = new name{"id_locked"};
|
||||
g_id_rhs = new name{"id_rhs"};
|
||||
g_id_delta = new name{"id_delta"};
|
||||
g_if_neg = new name{"if_neg"};
|
||||
g_if_pos = new name{"if_pos"};
|
||||
g_iff = new name{"iff"};
|
||||
|
|
@ -892,6 +894,7 @@ void finalize_constants() {
|
|||
delete g_hole_command;
|
||||
delete g_id_locked;
|
||||
delete g_id_rhs;
|
||||
delete g_id_delta;
|
||||
delete g_if_neg;
|
||||
delete g_if_pos;
|
||||
delete g_iff;
|
||||
|
|
@ -1273,6 +1276,7 @@ name const & get_heq_of_eq_name() { return *g_heq_of_eq; }
|
|||
name const & get_hole_command_name() { return *g_hole_command; }
|
||||
name const & get_id_locked_name() { return *g_id_locked; }
|
||||
name const & get_id_rhs_name() { return *g_id_rhs; }
|
||||
name const & get_id_delta_name() { return *g_id_delta; }
|
||||
name const & get_if_neg_name() { return *g_if_neg; }
|
||||
name const & get_if_pos_name() { return *g_if_pos; }
|
||||
name const & get_iff_name() { return *g_iff; }
|
||||
|
|
|
|||
|
|
@ -131,6 +131,7 @@ name const & get_heq_of_eq_name();
|
|||
name const & get_hole_command_name();
|
||||
name const & get_id_locked_name();
|
||||
name const & get_id_rhs_name();
|
||||
name const & get_id_delta_name();
|
||||
name const & get_if_neg_name();
|
||||
name const & get_if_pos_name();
|
||||
name const & get_iff_name();
|
||||
|
|
|
|||
|
|
@ -124,6 +124,7 @@ heq_of_eq
|
|||
hole_command
|
||||
id_locked
|
||||
id_rhs
|
||||
id_delta
|
||||
if_neg
|
||||
if_pos
|
||||
iff
|
||||
|
|
|
|||
|
|
@ -619,15 +619,107 @@ static expr prove_eqn_lemma_core(type_context & ctx, buffer<expr> const & Hs, ex
|
|||
return mk_eq_trans(ctx, H1, H2);
|
||||
}
|
||||
|
||||
expr lhs_body = lhs;
|
||||
/* Check if lhs =?= rhs, and create a reflexivity proof if this is the case.
|
||||
|
||||
We have to be careful to avoid performace problems when checking this proof in the kernel.
|
||||
We considered different options.
|
||||
|
||||
Option 1) (refl rhs) or (refl lhs)
|
||||
It will perform poorly in one of the following examples:
|
||||
|
||||
|
||||
| f x 0 := 1
|
||||
| f x (y+1) := f complex_term y
|
||||
|
||||
| g 0 y := 1
|
||||
| g (x+1) y := g x complex_term
|
||||
|
||||
If we use (refl rhs), we will generate the proofs
|
||||
|
||||
eq.refl (f complex_term y)
|
||||
eq.refl (g x complex_term)
|
||||
|
||||
These proofs trigger the following definitionally equality tests:
|
||||
|
||||
f x (y+1) =?= f complex_term y
|
||||
g (x+1) y =?= g x complex_term
|
||||
|
||||
Since, we have f/g on both sides, the type checker will try
|
||||
first to unify the arguments, and may timeout trying to solve
|
||||
|
||||
x =?= complex_term
|
||||
y =?= complex_term
|
||||
|
||||
since it may take a long time to reduce `complex_term`.
|
||||
|
||||
We have a similar problem if we use (refl lhs)
|
||||
|
||||
Commit 7ebf16ca26da82b3d0e458dbcf32cda374ec785d tried to address this issue
|
||||
by using Option 2).
|
||||
|
||||
Option 2) (refl (unfold_of lhs))
|
||||
This option fixes the performance problem above, but it is still
|
||||
inefficient for definitions that produce many equations.
|
||||
For example, the following definition produces 121 equations.
|
||||
|
||||
```
|
||||
universes u
|
||||
|
||||
inductive node (α : Type u)
|
||||
| leaf : node
|
||||
| red_node : node → α → node → node
|
||||
| black_node : node → α → node → node
|
||||
|
||||
namespace node
|
||||
variable {α : Type u}
|
||||
|
||||
def balance : node α → α → node α → node α
|
||||
| (red_node (red_node a x b) y c) k d := red_node (black_node a x b) y (black_node c k d)
|
||||
| (red_node a x (red_node b y c)) k d := red_node (black_node a x b) y (black_node c k d)
|
||||
| l k r := black_node l k r
|
||||
|
||||
end node
|
||||
```
|
||||
|
||||
In each equation we will have a big (unfold_of lhs) term. This increases the size of .olean
|
||||
files, and introduces an overhead in the mk_aux_lemma procedure.
|
||||
|
||||
Option 3) (refl (id_delta lhs))
|
||||
We are currently using this option.
|
||||
This approach relies on the fact that the kernel type checker has special support for id_delta.
|
||||
The kernel implements the following is_def_eq rules for id_delta.
|
||||
|
||||
1) (id_delta t) =?= t
|
||||
2) t =?= (id_delta t)
|
||||
3) (id_delta t) =?= s IF (unfold_of t) =?= s
|
||||
4) t =?= id_delta s IF t =?= (unfold_of s)
|
||||
|
||||
We can view it as a "lazy" version of Option 2. The .olean file contains `id_delta t`
|
||||
instead of the result of delta-reducing t. Similarly, no overhead is introduced to mk_aux_lemma
|
||||
since the proof is quite small in this case.
|
||||
|
||||
Finally, note that this optimization is only use when root = true.
|
||||
That is, it is not use if the equation compiler used if-then-else compilation trick for
|
||||
pattern matching scalar values, and/or the pack/unpack auxiliary definitions introduced
|
||||
for nested inductive datatype declarations.
|
||||
The problem described in Option 1 does not happen in this case since we unfold the left-hand-side
|
||||
while building the proof. However, the performance problem described in Option 2 may happen.
|
||||
*/
|
||||
if (root) {
|
||||
/* Remark: type_context currently does not have support for id_delta.
|
||||
So, we unfold lhs before invoking ctx.is_def_eq. */
|
||||
expr lhs_body = lhs;
|
||||
if (auto b = unfold_term(ctx.env(), lhs))
|
||||
lhs_body = *b;
|
||||
}
|
||||
|
||||
if (ctx.is_def_eq(lhs_body, rhs)) {
|
||||
// tout() << "DONE\n";
|
||||
return mk_eq_refl(ctx, lhs_body);
|
||||
if (ctx.is_def_eq(lhs_body, rhs)) {
|
||||
if (ctx.env().find(get_id_delta_name()))
|
||||
return mk_eq_refl(ctx, mk_id_delta(ctx, lhs));
|
||||
else
|
||||
return mk_eq_refl(ctx, lhs);
|
||||
}
|
||||
} else {
|
||||
if (ctx.is_def_eq(lhs, rhs))
|
||||
return mk_eq_refl(ctx, lhs);
|
||||
}
|
||||
|
||||
throw exception("equation compiler failed to prove equation lemma (workaround: "
|
||||
|
|
|
|||
18
tests/lean/eqn_proof.lean
Normal file
18
tests/lean/eqn_proof.lean
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
universes u
|
||||
|
||||
inductive node (α : Type u)
|
||||
| leaf : node
|
||||
| red_node : node → α → node → node
|
||||
| black_node : node → α → node → node
|
||||
|
||||
namespace node
|
||||
variable {α : Type u}
|
||||
|
||||
def balance : node α → α → node α → node α
|
||||
| (red_node (red_node a x b) y c) k d := red_node (black_node a x b) y (black_node c k d)
|
||||
| (red_node a x (red_node b y c)) k d := red_node (black_node a x b) y (black_node c k d)
|
||||
| l k r := black_node l k r
|
||||
|
||||
#print balance._main.equations._eqn_1
|
||||
|
||||
end node
|
||||
3
tests/lean/eqn_proof.lean.expected.out
Normal file
3
tests/lean/eqn_proof.lean.expected.out
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
@[_refl_lemma]
|
||||
theorem node.balance._main.equations._eqn_1 : ∀ {α : Type u} (k : α) (r : node α), balance._main (leaf α) k r = black_node (leaf α) k r :=
|
||||
λ {α : Type u} (k : α) (r : node α), eq.refl (id_delta (balance._main (leaf α) k r))
|
||||
|
|
@ -129,6 +129,7 @@ run_cmd script_check_id `heq_of_eq
|
|||
run_cmd script_check_id `hole_command
|
||||
run_cmd script_check_id `id_locked
|
||||
run_cmd script_check_id `id_rhs
|
||||
run_cmd script_check_id `id_delta
|
||||
run_cmd script_check_id `if_neg
|
||||
run_cmd script_check_id `if_pos
|
||||
run_cmd script_check_id `iff
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue