feat(frontends/lean/definition_cmds): copy equational lemmas to top level definition

This commit is contained in:
Leonardo de Moura 2017-05-22 14:51:06 -07:00
parent 67190f565d
commit 729e798d6f
2 changed files with 194 additions and 157 deletions

View file

@ -284,6 +284,189 @@ declare_definition(parser & p, environment const & env, def_cmd_kind kind, buffe
return mk_pair(new_env, c_real_name);
}
static void throw_unexpected_error_at_copy_lemmas() {
throw exception("unexpected error, failed to generate equational lemmas in the front-end");
}
/* Given e of the form Pi (a_1 : A_1) ... (a_n : A_n), lhs = rhs,
return the pair (lhs, n) */
static pair<expr, unsigned> get_lemma_lhs(expr e) {
unsigned nparams = 0;
while (is_pi(e)) {
nparams++;
e = binding_body(e);
}
expr lhs, rhs;
if (!is_eq(e, lhs, rhs))
throw_unexpected_error_at_copy_lemmas();
return mk_pair(lhs, nparams);
}
/*
Given a lemma with parameters lp_names:
[lp_1 ... lp_n]
and the levels in the function application on the lemma left-hand-side lhs_fn_levels:
[u_1 ... u_n] s.t. there is a permutation p s.t. p([u_1 ... u_n] = [(mk_univ_param lp_1) ... (mk_univ_param lp_n)],
and levels fn_levels
[v_1 ... v_n]
Then, store
p([v_1 ... v_n])
in result.
*/
static void get_levels_for_instantiating_lemma(level_param_names const & lp_names,
levels const & lhs_fn_levels,
levels const & fn_levels,
buffer<level> & result) {
buffer<level> fn_levels_buffer;
buffer<level> lhs_fn_levels_buffer;
to_buffer(fn_levels, fn_levels_buffer);
to_buffer(lhs_fn_levels, lhs_fn_levels_buffer);
lean_assert(fn_levels_buffer.size() == lhs_fn_levels_buffer.size());
for (name const & lp_name : lp_names) {
unsigned j = 0;
for (; j < lhs_fn_levels_buffer.size(); j++) {
if (!is_param(lhs_fn_levels_buffer[j])) throw_unexpected_error_at_copy_lemmas();
if (param_id(lhs_fn_levels_buffer[j]) == lp_name) {
result.push_back(fn_levels_buffer[j]);
break;
}
}
lean_assert(j < lhs_fn_levels_buffer.size());
}
}
/**
Given a lemma with the given arity (i.e., number of nested Pi-terms),
n = args.size() <= lhs_args.size(), and the first n
arguments in lhs_args.size() are a permutation p of
(var #0) ... (var #n-1)
Then, store in result p(args)
*/
static void get_args_for_instantiating_lemma(unsigned arity,
buffer<expr> const & lhs_args,
buffer<expr> const & args,
buffer<expr> & result) {
for (unsigned i = 0; i < args.size(); i++) {
if (!is_var(lhs_args[i]) || var_idx(lhs_args[i]) >= arity)
throw_unexpected_error_at_copy_lemmas();
result.push_back(args[arity - var_idx(lhs_args[i]) - 1]);
}
}
/**
Given declarations d_1, ..., d_n defined as
(fun (a_1 : A_1) ... (a_n : A_n), d_1._main a_1' ... a_n')
where a_1' ... a_n' is a permutation of a subset of a_1 ... a_n.
Moreover, the parameters A_1 ... A_n are the same in all d_i's.
Then, copy the equation lemmas from d._main to d.
*/
static environment copy_equation_lemmas(environment const & env, buffer<name> const & d_names) {
type_context ctx(env, transparency_mode::All);
type_context::tmp_locals locals(ctx);
level_param_names lps;
levels ls;
buffer<expr> vals;
buffer<expr> new_vals;
for (unsigned d_idx = 0; d_idx < d_names.size(); d_idx++) {
declaration const & d = env.get(d_names[d_idx]);
expr val;
if (d_idx == 0) {
lps = d.get_univ_params();
ls = param_names_to_levels(lps);
val = instantiate_value_univ_params(d, ls);
while (is_lambda(val)) {
expr local = locals.push_local_from_binding(val);
val = instantiate(binding_body(val), local);
}
} else {
val = instantiate_value_univ_params(d, ls);
for (expr const & local : locals.as_buffer()) {
lean_assert(is_lambda(val));
lean_assert(ctx.is_def_eq(ctx.infer(local), binding_domain(val)));
val = instantiate(binding_body(val), local);
}
}
buffer<expr> args;
expr const & fn = get_app_args(val, args);
if (!is_constant(fn) ||
!std::all_of(args.begin(), args.end(), is_local) ||
length(const_levels(fn)) != length(ls)) {
throw_unexpected_error_at_copy_lemmas();
}
vals.push_back(val);
/* We want to create new equations where we replace val with new_val in the equations
associated with fn. */
expr new_val = mk_app(mk_constant(d_names[d_idx], ls), locals.as_buffer());
new_vals.push_back(new_val);
}
/* Copy equations */
environment new_env = env;
for (unsigned d_idx = 0; d_idx < d_names.size(); d_idx++) {
buffer<expr> args;
expr const & fn = get_app_args(vals[d_idx], args);
unsigned i = 1;
while (true) {
name eqn_name = mk_equation_name(const_name(fn), i);
optional<declaration> eqn_decl = env.find(eqn_name);
if (!eqn_decl) break;
unsigned num_eqn_levels = eqn_decl->get_num_univ_params();
if (num_eqn_levels != length(ls))
throw_unexpected_error_at_copy_lemmas();
expr lhs; unsigned num_eqn_params;
std::tie(lhs, num_eqn_params) = get_lemma_lhs(eqn_decl->get_type());
buffer<expr> lhs_args;
expr const & lhs_fn = get_app_args(lhs, lhs_args);
if (!is_constant(lhs_fn) || const_name(lhs_fn) != const_name(fn) || lhs_args.size() < args.size())
throw_unexpected_error_at_copy_lemmas();
/* Get levels for instantiating the lemma */
buffer<level> eqn_level_buffer;
get_levels_for_instantiating_lemma(eqn_decl->get_univ_params(),
const_levels(lhs_fn),
const_levels(fn),
eqn_level_buffer);
levels eqn_levels = to_list(eqn_level_buffer);
/* Get arguments for instantiating the lemma */
buffer<expr> eqn_args;
get_args_for_instantiating_lemma(num_eqn_params, lhs_args, args, eqn_args);
/* Convert type */
expr eqn_type = instantiate_type_univ_params(*eqn_decl, eqn_levels);
for (unsigned j = 0; j < eqn_args.size(); j++) eqn_type = binding_body(eqn_type);
eqn_type = instantiate_rev(eqn_type, eqn_args);
expr new_eqn_type = replace(eqn_type, [&](expr const & e, unsigned) {
for (unsigned i = 0; i < vals.size(); i++) {
if (e == vals[i])
return some_expr(new_vals[i]);
}
return none_expr();
});
new_eqn_type = locals.mk_pi(new_eqn_type);
name new_eqn_name = mk_equation_name(d_names[d_idx], i);
expr new_eqn_value;
new_eqn_value = mk_app(mk_constant(eqn_name, eqn_levels), args);
new_eqn_value = locals.mk_lambda(new_eqn_value);
declaration new_decl = mk_theorem(new_eqn_name, lps, new_eqn_type, new_eqn_value);
new_env = module::add(new_env, check(new_env, new_decl, true));
if (is_rfl_lemma(env, eqn_name))
new_env = mark_rfl_lemma(new_env, new_eqn_name);
new_env = add_eqn_lemma(new_env, new_eqn_name);
i++;
}
}
return new_env;
}
/**
Given a declaration d defined as
(fun (a_1 : A_1) ... (a_n : A_n), d._main a_1' ... a_n')
where a_1' ... a_n' is a permutation of a subset of a_1 ... a_n.
Then, copy the equation lemmas from d._main to d.
*/
static environment copy_equation_lemmas(environment const & env, name const & d_name) {
buffer<name> d_names;
d_names.push_back(d_name);
return copy_equation_lemmas(env, d_names);
}
environment mutual_definition_cmd_core(parser & p, def_cmd_kind kind, decl_modifiers const & modifiers, decl_attributes attrs) {
buffer<name> lp_names;
buffer<expr> fns, params;
@ -314,6 +497,7 @@ environment mutual_definition_cmd_core(parser & p, def_cmd_kind kind, decl_modif
}
unsigned num_defs = get_equations_result_size(val);
lean_assert(fns.size() == num_defs);
buffer<name> new_d_names;
/* Define functions */
for (unsigned i = 0; i < num_defs; i++) {
expr curr = get_equations_result(val, i);
@ -325,12 +509,11 @@ environment mutual_definition_cmd_core(parser & p, def_cmd_kind kind, decl_modif
std::tie(env, c_real_name) = declare_definition(p, env, kind, lp_names, c_name,
curr_type, some_expr(curr), {}, modifiers, attrs,
doc_string, header_pos);
new_d_names.push_back(c_real_name);
elab.set_env(env);
}
/* Add lemmas */
// TODO(Leo):
return elab.env();
return copy_equation_lemmas(elab.env(), new_d_names);
}
static expr_pair parse_definition(parser & p, buffer<name> & lp_names, buffer<expr> & params,
@ -465,152 +648,6 @@ static expr fix_rec_fn_macro_args(elaborator & elab, name const & fn, buffer<exp
return fix_rec_fn_macro_args_fn(params, fns)(val);
}
static void throw_unexpected_error_at_copy_lemmas() {
throw exception("unexpected error, failed to generate equational lemmas in the front-end");
}
/* Given e of the form Pi (a_1 : A_1) ... (a_n : A_n), lhs = rhs,
return the pair (lhs, n) */
static pair<expr, unsigned> get_lemma_lhs(expr e) {
unsigned nparams = 0;
while (is_pi(e)) {
nparams++;
e = binding_body(e);
}
expr lhs, rhs;
if (!is_eq(e, lhs, rhs))
throw_unexpected_error_at_copy_lemmas();
return mk_pair(lhs, nparams);
}
/*
Given a lemma with parameters lp_names:
[lp_1 ... lp_n]
and the levels in the function application on the lemma left-hand-side lhs_fn_levels:
[u_1 ... u_n] s.t. there is a permutation p s.t. p([u_1 ... u_n] = [(mk_univ_param lp_1) ... (mk_univ_param lp_n)],
and levels fn_levels
[v_1 ... v_n]
Then, store
p([v_1 ... v_n])
in result.
*/
static void get_levels_for_instantiating_lemma(level_param_names const & lp_names,
levels const & lhs_fn_levels,
levels const & fn_levels,
buffer<level> & result) {
buffer<level> fn_levels_buffer;
buffer<level> lhs_fn_levels_buffer;
to_buffer(fn_levels, fn_levels_buffer);
to_buffer(lhs_fn_levels, lhs_fn_levels_buffer);
lean_assert(fn_levels_buffer.size() == lhs_fn_levels_buffer.size());
for (name const & lp_name : lp_names) {
unsigned j = 0;
for (; j < lhs_fn_levels_buffer.size(); j++) {
if (!is_param(lhs_fn_levels_buffer[j])) throw_unexpected_error_at_copy_lemmas();
if (param_id(lhs_fn_levels_buffer[j]) == lp_name) {
result.push_back(fn_levels_buffer[j]);
break;
}
}
lean_assert(j < lhs_fn_levels_buffer.size());
}
}
/**
Given a lemma with the given arity (i.e., number of nested Pi-terms),
n = args.size() <= lhs_args.size(), and the first n
arguments in lhs_args.size() are a permutation p of
(var #0) ... (var #n-1)
Then, store in result p(args)
*/
static void get_args_for_instantiating_lemma(unsigned arity,
buffer<expr> const & lhs_args,
buffer<expr> const & args,
buffer<expr> & result) {
for (unsigned i = 0; i < args.size(); i++) {
if (!is_var(lhs_args[i]) || var_idx(lhs_args[i]) >= arity)
throw_unexpected_error_at_copy_lemmas();
result.push_back(args[arity - var_idx(lhs_args[i]) - 1]);
}
}
/**
Given a declaration d defined as
(fun (a_1 : A_1) ... (a_n : A_n), d._main a_1' ... a_n')
where a_1' ... a_n' is a permutation of a_1 ... a_n.
Then, copy the equation lemmas from d._main to d.
*/
static environment copy_equation_lemmas(environment const & env, name const & d_name) {
declaration const & d = env.get(d_name);
levels lps = param_names_to_levels(d.get_univ_params());
expr val = instantiate_value_univ_params(d, lps);
type_context ctx(env, transparency_mode::All);
type_context::tmp_locals locals(ctx);
while (is_lambda(val)) {
expr local = locals.push_local_from_binding(val);
val = instantiate(binding_body(val), local);
}
buffer<expr> args;
expr const & fn = get_app_args(val, args);
if (!is_constant(fn) ||
!std::all_of(args.begin(), args.end(), is_local) ||
length(const_levels(fn)) != length(lps)) {
throw_unexpected_error_at_copy_lemmas();
}
/* We want to create new equations where we replace val with new_val in the equations
associated with fn. */
expr new_val = mk_app(mk_constant(d_name, lps), locals.as_buffer());
/* Copy equations */
environment new_env = env;
unsigned i = 1;
while (true) {
name eqn_name = mk_equation_name(const_name(fn), i);
optional<declaration> eqn_decl = env.find(eqn_name);
if (!eqn_decl) break;
unsigned num_eqn_levels = eqn_decl->get_num_univ_params();
if (num_eqn_levels != length(lps))
throw_unexpected_error_at_copy_lemmas();
expr lhs; unsigned num_eqn_params;
std::tie(lhs, num_eqn_params) = get_lemma_lhs(eqn_decl->get_type());
buffer<expr> lhs_args;
expr const & lhs_fn = get_app_args(lhs, lhs_args);
if (!is_constant(lhs_fn) || const_name(lhs_fn) != const_name(fn) || lhs_args.size() < args.size())
throw_unexpected_error_at_copy_lemmas();
/* Get levels for instantiating the lemma */
buffer<level> eqn_level_buffer;
get_levels_for_instantiating_lemma(eqn_decl->get_univ_params(),
const_levels(lhs_fn),
const_levels(fn),
eqn_level_buffer);
levels eqn_levels = to_list(eqn_level_buffer);
/* Get arguments for instantiating the lemma */
buffer<expr> eqn_args;
get_args_for_instantiating_lemma(num_eqn_params, lhs_args, args, eqn_args);
/* Convert type */
expr eqn_type = instantiate_type_univ_params(*eqn_decl, eqn_levels);
for (unsigned j = 0; j < eqn_args.size(); j++) eqn_type = binding_body(eqn_type);
eqn_type = instantiate_rev(eqn_type, eqn_args);
expr new_eqn_type = replace(eqn_type, [&](expr const & e, unsigned) {
if (e == val)
return some_expr(new_val);
else
return none_expr();
});
new_eqn_type = locals.mk_pi(new_eqn_type);
name new_eqn_name = mk_equation_name(d_name, i);
expr new_eqn_value;
new_eqn_value = mk_app(mk_constant(eqn_name, eqn_levels), args);
new_eqn_value = locals.mk_lambda(new_eqn_value);
declaration new_decl = mk_theorem(new_eqn_name, d.get_univ_params(), new_eqn_type, new_eqn_value);
new_env = module::add(new_env, check(new_env, new_decl, true));
if (is_rfl_lemma(env, eqn_name))
new_env = mark_rfl_lemma(new_env, new_eqn_name);
new_env = add_eqn_lemma(new_env, new_eqn_name);
i++;
}
return new_env;
}
static expr inline_new_defs(environment const & old_env, environment const & new_env, name const & n, expr const & e) {
return replace(e, [=] (expr const & e, unsigned) -> optional<expr> {
if (is_sorry(e)) {

View file

@ -18,10 +18,10 @@ with odd : nat → bool
#eval even 4
#eval odd 3
#eval odd 4
#check even._main.equations._eqn_1
#check even._main.equations._eqn_2
#check odd._main.equations._eqn_1
#check odd._main.equations._eqn_2
#check even.equations._eqn_1
#check even.equations._eqn_2
#check odd.equations._eqn_1
#check odd.equations._eqn_2
mutual def f, g {α β : Type u} (p : α × β)
with f : Π n : nat, vector (α × β) n
@ -31,7 +31,7 @@ with g : Π n : nat, α → vector β n
| 0 a := vector.nil
| (succ n) a := vector.cons p.2 $ (f n).map (λ p, p.2)
#check @f._main.equations._eqn_1
#check @f._main.equations._eqn_2
#check @g._main.equations._eqn_1
#check @g._main.equations._eqn_2
#check @f.equations._eqn_1
#check @f.equations._eqn_2
#check @g.equations._eqn_1
#check @g.equations._eqn_2