fix(library/equations_compiler): structural recursion and partial equations

The equational compiler was failing to generate equational lemmas for
equations such as:

   def f : nat → nat → nat
   | (x+1) (y+1) := f (x+10) y
   | _     _     := 1

It would fail when trying to prove the following equation:

   forall x, f 0 x = 1

using a "refl" proof. This equation does not hold definitionally.
It is not blocked by the internal pattern matching based on the
cases_on recursor, but it is blocked by the outer most brec_on
used to implement structural recursion. The solution is to
"complete" the set of equations. So, the structural_rec
module will replace the equation above with

   def f : nat → nat → nat
   | (x+1) (y+1) := f (x+10) y
   | _     0     := 1
   | _     (y+1) := 1

and then (as before)

   def f : Pi (x y : nat), below y → nat
   | (x+1) (y+1) F := F^.fst^.fst (x+10)
   | _     0     F := 1
   | _     (y+1) F := 1
This commit is contained in:
Leonardo de Moura 2017-02-16 14:27:48 -08:00
parent 3428e9bd59
commit 769220fa4e
5 changed files with 153 additions and 29 deletions

View file

@ -873,27 +873,11 @@ struct elim_match_fn {
[&](expr const & c, buffer<expr> const & new_c_vars) {
expr var = pattern;
/* We are replacing `var` with `c` */
buffer<expr> vars; to_buffer(eqn.m_vars, vars);
buffer<expr> new_vars;
buffer<expr> from;
buffer<expr> to;
buffer<expr> new_vars;
for (expr const & curr : eqn.m_vars) {
if (curr == var) {
from.push_back(var);
to.push_back(c);
new_vars.append(new_c_vars);
} else {
expr curr_type = ctx.infer(curr);
expr new_curr_type = replace_locals(curr_type, from, to);
if (curr_type == new_curr_type) {
new_vars.push_back(curr);
} else {
expr new_curr = ctx.push_local(local_pp_name(curr), new_curr_type);
from.push_back(curr);
to.push_back(new_curr);
new_vars.push_back(new_curr);
}
}
}
update_telescope(ctx, vars, var, c, new_c_vars, new_vars, from, to);
equation new_eqn = eqn;
new_eqn.m_lctx = ctx.lctx();
new_eqn.m_vars = to_list(new_vars);

View file

@ -503,26 +503,100 @@ struct structural_rec_fn {
}
};
/* Return true if we need to complete equations by expanding the recursive argument.
For example, suppose we have, where the recursive argument is the second
def f : nat nat nat
| (x+1) (y+1) := f (x+10) y
| _ _ := 1
this function returns true because
1) We need to perform case analysis in the first argument (first equation),
(flag has_case_analysis_before in the followin procedure); and
2) W have an equation (second) where the recursive argument is a variable
(flag incomplete).
*/
bool must_complete_rec_arg(type_context & ctx, unpack_eqns const & ues) {
if (m_arg_pos == 0) return false;
buffer<expr> const & eqns = ues.get_eqns_of(0);
bool has_case_analysis_before = false;
bool incomplete = false;
for (expr const & eqn : eqns) {
unpack_eqn ue(ctx, eqn);
buffer<expr> lhs_args;
get_app_args(ue.lhs(), lhs_args);
if (!has_case_analysis_before) {
for (unsigned i = 0; i < m_arg_pos; i++) {
if (!is_local(lhs_args[i]) && !is_inaccessible(lhs_args[i])) {
has_case_analysis_before = true;
break;
}
}
}
if (is_local(lhs_args[m_arg_pos]))
incomplete = true;
if (has_case_analysis_before && incomplete)
return true;
}
return false;
}
void update_eqs(type_context & ctx, unpack_eqns & ues, expr const & fn, expr const & new_fn) {
/* C is a temporary "abstract" motive, we use it to access the "brec_on dictionary".
The "brec_on dictionary is an element of type below, and it is the last argument of the new function. */
expr C = mk_local(mk_fresh_name(), "_C", m_motive_type, binder_info());
buffer<expr> & eqns = ues.get_eqns_of(0);
for (expr & eqn : eqns) {
buffer<expr> new_eqns;
bool complete = must_complete_rec_arg(ctx, ues);
for (expr const & eqn : eqns) {
unpack_eqn ue(ctx, eqn);
expr lhs = ue.lhs();
expr rhs = ue.rhs();
buffer<expr> lhs_args;
get_app_args(lhs, lhs_args);
expr new_lhs = mk_app(new_fn, lhs_args);
expr type = ctx.whnf(ctx.infer(new_lhs));
lean_assert(is_pi(type));
expr F = ue.add_var(binding_name(type), binding_domain(type));
new_lhs = mk_app(new_lhs, F);
ue.lhs() = new_lhs;
ue.rhs() = elim_rec_apps_fn(ctx, fn, m_arg_pos, m_indices_pos, F, C)(rhs);
eqn = ue.repack();
if (complete && is_local(lhs_args[m_arg_pos])) {
expr var = lhs_args[m_arg_pos];
for_each_compatible_constructor(ctx, var,
[&](expr const & c, buffer<expr> const & new_c_vars) {
buffer<expr> new_vars;
buffer<expr> from;
buffer<expr> to;
update_telescope(ctx, ue.get_vars(), var, c, new_c_vars,
new_vars, from, to);
buffer<expr> new_lhs_args(lhs_args);
new_lhs_args[m_arg_pos] = c;
for (unsigned i = m_arg_pos + 1; i < new_lhs_args.size(); i++)
new_lhs_args[i] = replace_locals(new_lhs_args[i], from, to);
expr new_lhs = mk_app(new_fn, new_lhs_args);
expr type = ctx.whnf(ctx.infer(new_lhs));
lean_assert(is_pi(type));
type_context::tmp_locals extra(ctx);
expr F = extra.push_local(binding_name(type), binding_domain(type));
new_vars.push_back(F);
new_lhs = mk_app(new_lhs, F);
/* The lhs was a variable, so we don't need to update the rhs using elim_rec_apps_fn.
Reason: the rhs should not contain recursive equations.
But, we need to update the locals. */
expr new_rhs = replace_locals(ue.rhs(), from, to);
expr new_eqn = copy_tag(ue.get_nested_src(), mk_equation(new_lhs, new_rhs));
new_eqns.push_back(copy_tag(eqn, ctx.mk_lambda(new_vars, new_eqn)));
});
} else {
expr new_lhs = mk_app(new_fn, lhs_args);
expr type = ctx.whnf(ctx.infer(new_lhs));
lean_assert(is_pi(type));
expr F = ue.add_var(binding_name(type), binding_domain(type));
new_lhs = mk_app(new_lhs, F);
ue.lhs() = new_lhs;
ue.rhs() = elim_rec_apps_fn(ctx, fn, m_arg_pos, m_indices_pos, F, C)(rhs);
new_eqns.push_back(ue.repack());
}
}
eqns = new_eqns;
}
optional<expr> elim_recursion(expr const & e) {

View file

@ -13,6 +13,7 @@ Author: Leonardo de Moura
#include "library/trace.h"
#include "library/app_builder.h"
#include "library/private.h"
#include "library/locals.h"
#include "library/idx_metavar.h"
#include "library/constants.h"
#include "library/annotation.h"
@ -676,6 +677,37 @@ void for_each_compatible_constructor(type_context & ctx, expr const & var,
}
}
/* Given the telescope vars [x_1, ..., x_i, ..., x_n] and var := x_i,
and t is a term containing variables t_vars := {y_1, ..., y_k} disjoint from {x_1, ..., x_n},
Return [x_1, ..., x_{i-1}, y_1, ..., y_k, T(x_{i+1}), ..., T(x_n)},
where T(x_j) updates the type of x_j (j > i) by replacing x_i with t.
\remark The set of variables in t is a subset of {x_1, ..., x_{i-1}} union {y_1, ..., y_k}
*/
void update_telescope(type_context & ctx, buffer<expr> const & vars, expr const & var,
expr const & t, buffer<expr> const & t_vars, buffer<expr> & new_vars,
buffer<expr> & from, buffer<expr> & to) {
/* We are replacing `var` with `c` */
for (expr const & curr : vars) {
if (curr == var) {
from.push_back(var);
to.push_back(t);
new_vars.append(t_vars);
} else {
expr curr_type = ctx.infer(curr);
expr new_curr_type = replace_locals(curr_type, from, to);
if (curr_type == new_curr_type) {
new_vars.push_back(curr);
} else {
expr new_curr = ctx.push_local(local_pp_name(curr), new_curr_type);
from.push_back(curr);
to.push_back(new_curr);
new_vars.push_back(new_curr);
}
}
}
}
void initialize_eqn_compiler_util() {
register_trace_class("eqn_compiler");
register_trace_class(name{"debug", "eqn_compiler"});

View file

@ -62,9 +62,10 @@ class unpack_eqn {
public:
unpack_eqn(type_context & ctx, expr const & eqn);
expr add_var(name const & n, expr const & type);
buffer<expr> const & get_vars() { return m_vars; }
buffer<expr> & get_vars() { return m_vars; }
expr & lhs() { return m_lhs; }
expr & rhs() { return m_rhs; }
expr const & get_nested_src() const { return m_nested_src; }
expr repack();
};
@ -110,6 +111,22 @@ bool is_nat_int_char_string_value(type_context & ctx, expr const & e);
void for_each_compatible_constructor(type_context & ctx, expr const & var,
std::function<void(expr const &, buffer<expr> &)> const & fn);
/* Given the telescope vars [x_1, ..., x_i, ..., x_n] and var := x_i,
and t is a term containing variables t_vars := {y_1, ..., y_k} disjoint from {x_1, ..., x_n},
Return [x_1, ..., x_{i-1}, y_1, ..., y_k, T(x_{i+1}), ..., T(x_n)},
where T(x_j) updates the type of x_j (j > i) by replacing x_i with t.
\remark The set of variables in t is a subset of {x_1, ..., x_{i-1}} union {y_1, ..., y_k}
The output parameters from/to contain the replacement
[x_i, ... x_n] => [t, T(x_{i+1}), ..., T(x_n)]
The replacement will suppress entries x_j => T(x_j) if T(x_j) is equal to x_j.
*/
void update_telescope(type_context & ctx, buffer<expr> const & vars, expr const & var,
expr const & t, buffer<expr> const & t_vars, buffer<expr> & new_vars,
buffer<expr> & from, buffer<expr> & to);
void initialize_eqn_compiler_util();
void finalize_eqn_compiler_util();
}

View file

@ -0,0 +1,17 @@
def f : nat → nat → nat
| (x+1) (y+1) := f (x+10) y
| _ _ := 1
vm_eval f 1 1000
example (x y) : f (x+1) (y+1) = f (x+10) y :=
rfl
example (y) : f 0 (y+1) = 1 :=
rfl
example (x) : f (x+1) 0 = 1 :=
rfl
example : f 0 0 = 1 :=
rfl