fix(library/equations_compiler): structural recursion and partial equations
The equational compiler was failing to generate equational lemmas for equations such as: def f : nat → nat → nat | (x+1) (y+1) := f (x+10) y | _ _ := 1 It would fail when trying to prove the following equation: forall x, f 0 x = 1 using a "refl" proof. This equation does not hold definitionally. It is not blocked by the internal pattern matching based on the cases_on recursor, but it is blocked by the outer most brec_on used to implement structural recursion. The solution is to "complete" the set of equations. So, the structural_rec module will replace the equation above with def f : nat → nat → nat | (x+1) (y+1) := f (x+10) y | _ 0 := 1 | _ (y+1) := 1 and then (as before) def f : Pi (x y : nat), below y → nat | (x+1) (y+1) F := F^.fst^.fst (x+10) | _ 0 F := 1 | _ (y+1) F := 1
This commit is contained in:
parent
3428e9bd59
commit
769220fa4e
5 changed files with 153 additions and 29 deletions
|
|
@ -873,27 +873,11 @@ struct elim_match_fn {
|
|||
[&](expr const & c, buffer<expr> const & new_c_vars) {
|
||||
expr var = pattern;
|
||||
/* We are replacing `var` with `c` */
|
||||
buffer<expr> vars; to_buffer(eqn.m_vars, vars);
|
||||
buffer<expr> new_vars;
|
||||
buffer<expr> from;
|
||||
buffer<expr> to;
|
||||
buffer<expr> new_vars;
|
||||
for (expr const & curr : eqn.m_vars) {
|
||||
if (curr == var) {
|
||||
from.push_back(var);
|
||||
to.push_back(c);
|
||||
new_vars.append(new_c_vars);
|
||||
} else {
|
||||
expr curr_type = ctx.infer(curr);
|
||||
expr new_curr_type = replace_locals(curr_type, from, to);
|
||||
if (curr_type == new_curr_type) {
|
||||
new_vars.push_back(curr);
|
||||
} else {
|
||||
expr new_curr = ctx.push_local(local_pp_name(curr), new_curr_type);
|
||||
from.push_back(curr);
|
||||
to.push_back(new_curr);
|
||||
new_vars.push_back(new_curr);
|
||||
}
|
||||
}
|
||||
}
|
||||
update_telescope(ctx, vars, var, c, new_c_vars, new_vars, from, to);
|
||||
equation new_eqn = eqn;
|
||||
new_eqn.m_lctx = ctx.lctx();
|
||||
new_eqn.m_vars = to_list(new_vars);
|
||||
|
|
|
|||
|
|
@ -503,26 +503,100 @@ struct structural_rec_fn {
|
|||
}
|
||||
};
|
||||
|
||||
/* Return true if we need to complete equations by expanding the recursive argument.
|
||||
|
||||
For example, suppose we have, where the recursive argument is the second
|
||||
|
||||
def f : nat → nat → nat
|
||||
| (x+1) (y+1) := f (x+10) y
|
||||
| _ _ := 1
|
||||
|
||||
this function returns true because
|
||||
1) We need to perform case analysis in the first argument (first equation),
|
||||
(flag has_case_analysis_before in the followin procedure); and
|
||||
2) W have an equation (second) where the recursive argument is a variable
|
||||
(flag incomplete).
|
||||
*/
|
||||
bool must_complete_rec_arg(type_context & ctx, unpack_eqns const & ues) {
|
||||
if (m_arg_pos == 0) return false;
|
||||
buffer<expr> const & eqns = ues.get_eqns_of(0);
|
||||
bool has_case_analysis_before = false;
|
||||
bool incomplete = false;
|
||||
for (expr const & eqn : eqns) {
|
||||
unpack_eqn ue(ctx, eqn);
|
||||
buffer<expr> lhs_args;
|
||||
get_app_args(ue.lhs(), lhs_args);
|
||||
|
||||
if (!has_case_analysis_before) {
|
||||
for (unsigned i = 0; i < m_arg_pos; i++) {
|
||||
if (!is_local(lhs_args[i]) && !is_inaccessible(lhs_args[i])) {
|
||||
has_case_analysis_before = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (is_local(lhs_args[m_arg_pos]))
|
||||
incomplete = true;
|
||||
|
||||
if (has_case_analysis_before && incomplete)
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
void update_eqs(type_context & ctx, unpack_eqns & ues, expr const & fn, expr const & new_fn) {
|
||||
/* C is a temporary "abstract" motive, we use it to access the "brec_on dictionary".
|
||||
The "brec_on dictionary is an element of type below, and it is the last argument of the new function. */
|
||||
expr C = mk_local(mk_fresh_name(), "_C", m_motive_type, binder_info());
|
||||
buffer<expr> & eqns = ues.get_eqns_of(0);
|
||||
for (expr & eqn : eqns) {
|
||||
buffer<expr> new_eqns;
|
||||
bool complete = must_complete_rec_arg(ctx, ues);
|
||||
for (expr const & eqn : eqns) {
|
||||
unpack_eqn ue(ctx, eqn);
|
||||
expr lhs = ue.lhs();
|
||||
expr rhs = ue.rhs();
|
||||
buffer<expr> lhs_args;
|
||||
get_app_args(lhs, lhs_args);
|
||||
expr new_lhs = mk_app(new_fn, lhs_args);
|
||||
expr type = ctx.whnf(ctx.infer(new_lhs));
|
||||
lean_assert(is_pi(type));
|
||||
expr F = ue.add_var(binding_name(type), binding_domain(type));
|
||||
new_lhs = mk_app(new_lhs, F);
|
||||
ue.lhs() = new_lhs;
|
||||
ue.rhs() = elim_rec_apps_fn(ctx, fn, m_arg_pos, m_indices_pos, F, C)(rhs);
|
||||
eqn = ue.repack();
|
||||
if (complete && is_local(lhs_args[m_arg_pos])) {
|
||||
expr var = lhs_args[m_arg_pos];
|
||||
for_each_compatible_constructor(ctx, var,
|
||||
[&](expr const & c, buffer<expr> const & new_c_vars) {
|
||||
buffer<expr> new_vars;
|
||||
buffer<expr> from;
|
||||
buffer<expr> to;
|
||||
update_telescope(ctx, ue.get_vars(), var, c, new_c_vars,
|
||||
new_vars, from, to);
|
||||
buffer<expr> new_lhs_args(lhs_args);
|
||||
new_lhs_args[m_arg_pos] = c;
|
||||
for (unsigned i = m_arg_pos + 1; i < new_lhs_args.size(); i++)
|
||||
new_lhs_args[i] = replace_locals(new_lhs_args[i], from, to);
|
||||
expr new_lhs = mk_app(new_fn, new_lhs_args);
|
||||
expr type = ctx.whnf(ctx.infer(new_lhs));
|
||||
lean_assert(is_pi(type));
|
||||
type_context::tmp_locals extra(ctx);
|
||||
expr F = extra.push_local(binding_name(type), binding_domain(type));
|
||||
new_vars.push_back(F);
|
||||
new_lhs = mk_app(new_lhs, F);
|
||||
/* The lhs was a variable, so we don't need to update the rhs using elim_rec_apps_fn.
|
||||
Reason: the rhs should not contain recursive equations.
|
||||
But, we need to update the locals. */
|
||||
expr new_rhs = replace_locals(ue.rhs(), from, to);
|
||||
expr new_eqn = copy_tag(ue.get_nested_src(), mk_equation(new_lhs, new_rhs));
|
||||
new_eqns.push_back(copy_tag(eqn, ctx.mk_lambda(new_vars, new_eqn)));
|
||||
});
|
||||
} else {
|
||||
expr new_lhs = mk_app(new_fn, lhs_args);
|
||||
expr type = ctx.whnf(ctx.infer(new_lhs));
|
||||
lean_assert(is_pi(type));
|
||||
expr F = ue.add_var(binding_name(type), binding_domain(type));
|
||||
new_lhs = mk_app(new_lhs, F);
|
||||
ue.lhs() = new_lhs;
|
||||
ue.rhs() = elim_rec_apps_fn(ctx, fn, m_arg_pos, m_indices_pos, F, C)(rhs);
|
||||
new_eqns.push_back(ue.repack());
|
||||
}
|
||||
}
|
||||
eqns = new_eqns;
|
||||
}
|
||||
|
||||
optional<expr> elim_recursion(expr const & e) {
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ Author: Leonardo de Moura
|
|||
#include "library/trace.h"
|
||||
#include "library/app_builder.h"
|
||||
#include "library/private.h"
|
||||
#include "library/locals.h"
|
||||
#include "library/idx_metavar.h"
|
||||
#include "library/constants.h"
|
||||
#include "library/annotation.h"
|
||||
|
|
@ -676,6 +677,37 @@ void for_each_compatible_constructor(type_context & ctx, expr const & var,
|
|||
}
|
||||
}
|
||||
|
||||
/* Given the telescope vars [x_1, ..., x_i, ..., x_n] and var := x_i,
|
||||
and t is a term containing variables t_vars := {y_1, ..., y_k} disjoint from {x_1, ..., x_n},
|
||||
Return [x_1, ..., x_{i-1}, y_1, ..., y_k, T(x_{i+1}), ..., T(x_n)},
|
||||
where T(x_j) updates the type of x_j (j > i) by replacing x_i with t.
|
||||
|
||||
\remark The set of variables in t is a subset of {x_1, ..., x_{i-1}} union {y_1, ..., y_k}
|
||||
*/
|
||||
void update_telescope(type_context & ctx, buffer<expr> const & vars, expr const & var,
|
||||
expr const & t, buffer<expr> const & t_vars, buffer<expr> & new_vars,
|
||||
buffer<expr> & from, buffer<expr> & to) {
|
||||
/* We are replacing `var` with `c` */
|
||||
for (expr const & curr : vars) {
|
||||
if (curr == var) {
|
||||
from.push_back(var);
|
||||
to.push_back(t);
|
||||
new_vars.append(t_vars);
|
||||
} else {
|
||||
expr curr_type = ctx.infer(curr);
|
||||
expr new_curr_type = replace_locals(curr_type, from, to);
|
||||
if (curr_type == new_curr_type) {
|
||||
new_vars.push_back(curr);
|
||||
} else {
|
||||
expr new_curr = ctx.push_local(local_pp_name(curr), new_curr_type);
|
||||
from.push_back(curr);
|
||||
to.push_back(new_curr);
|
||||
new_vars.push_back(new_curr);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void initialize_eqn_compiler_util() {
|
||||
register_trace_class("eqn_compiler");
|
||||
register_trace_class(name{"debug", "eqn_compiler"});
|
||||
|
|
|
|||
|
|
@ -62,9 +62,10 @@ class unpack_eqn {
|
|||
public:
|
||||
unpack_eqn(type_context & ctx, expr const & eqn);
|
||||
expr add_var(name const & n, expr const & type);
|
||||
buffer<expr> const & get_vars() { return m_vars; }
|
||||
buffer<expr> & get_vars() { return m_vars; }
|
||||
expr & lhs() { return m_lhs; }
|
||||
expr & rhs() { return m_rhs; }
|
||||
expr const & get_nested_src() const { return m_nested_src; }
|
||||
expr repack();
|
||||
};
|
||||
|
||||
|
|
@ -110,6 +111,22 @@ bool is_nat_int_char_string_value(type_context & ctx, expr const & e);
|
|||
void for_each_compatible_constructor(type_context & ctx, expr const & var,
|
||||
std::function<void(expr const &, buffer<expr> &)> const & fn);
|
||||
|
||||
/* Given the telescope vars [x_1, ..., x_i, ..., x_n] and var := x_i,
|
||||
and t is a term containing variables t_vars := {y_1, ..., y_k} disjoint from {x_1, ..., x_n},
|
||||
Return [x_1, ..., x_{i-1}, y_1, ..., y_k, T(x_{i+1}), ..., T(x_n)},
|
||||
where T(x_j) updates the type of x_j (j > i) by replacing x_i with t.
|
||||
|
||||
\remark The set of variables in t is a subset of {x_1, ..., x_{i-1}} union {y_1, ..., y_k}
|
||||
|
||||
The output parameters from/to contain the replacement
|
||||
[x_i, ... x_n] => [t, T(x_{i+1}), ..., T(x_n)]
|
||||
|
||||
The replacement will suppress entries x_j => T(x_j) if T(x_j) is equal to x_j.
|
||||
*/
|
||||
void update_telescope(type_context & ctx, buffer<expr> const & vars, expr const & var,
|
||||
expr const & t, buffer<expr> const & t_vars, buffer<expr> & new_vars,
|
||||
buffer<expr> & from, buffer<expr> & to);
|
||||
|
||||
void initialize_eqn_compiler_util();
|
||||
void finalize_eqn_compiler_util();
|
||||
}
|
||||
|
|
|
|||
17
tests/lean/run/complete_rec_var.lean
Normal file
17
tests/lean/run/complete_rec_var.lean
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
def f : nat → nat → nat
|
||||
| (x+1) (y+1) := f (x+10) y
|
||||
| _ _ := 1
|
||||
|
||||
vm_eval f 1 1000
|
||||
|
||||
example (x y) : f (x+1) (y+1) = f (x+10) y :=
|
||||
rfl
|
||||
|
||||
example (y) : f 0 (y+1) = 1 :=
|
||||
rfl
|
||||
|
||||
example (x) : f (x+1) 0 = 1 :=
|
||||
rfl
|
||||
|
||||
example : f 0 0 = 1 :=
|
||||
rfl
|
||||
Loading…
Add table
Reference in a new issue