feat(frontends/lean/definition_cmds): copy equational lemmas to top level definition
This commit is contained in:
parent
67190f565d
commit
729e798d6f
2 changed files with 194 additions and 157 deletions
|
|
@ -284,6 +284,189 @@ declare_definition(parser & p, environment const & env, def_cmd_kind kind, buffe
|
|||
return mk_pair(new_env, c_real_name);
|
||||
}
|
||||
|
||||
static void throw_unexpected_error_at_copy_lemmas() {
|
||||
throw exception("unexpected error, failed to generate equational lemmas in the front-end");
|
||||
}
|
||||
|
||||
/* Given e of the form Pi (a_1 : A_1) ... (a_n : A_n), lhs = rhs,
|
||||
return the pair (lhs, n) */
|
||||
static pair<expr, unsigned> get_lemma_lhs(expr e) {
|
||||
unsigned nparams = 0;
|
||||
while (is_pi(e)) {
|
||||
nparams++;
|
||||
e = binding_body(e);
|
||||
}
|
||||
expr lhs, rhs;
|
||||
if (!is_eq(e, lhs, rhs))
|
||||
throw_unexpected_error_at_copy_lemmas();
|
||||
return mk_pair(lhs, nparams);
|
||||
}
|
||||
|
||||
/*
|
||||
Given a lemma with parameters lp_names:
|
||||
[lp_1 ... lp_n]
|
||||
and the levels in the function application on the lemma left-hand-side lhs_fn_levels:
|
||||
[u_1 ... u_n] s.t. there is a permutation p s.t. p([u_1 ... u_n] = [(mk_univ_param lp_1) ... (mk_univ_param lp_n)],
|
||||
and levels fn_levels
|
||||
[v_1 ... v_n]
|
||||
Then, store
|
||||
p([v_1 ... v_n])
|
||||
in result.
|
||||
*/
|
||||
static void get_levels_for_instantiating_lemma(level_param_names const & lp_names,
|
||||
levels const & lhs_fn_levels,
|
||||
levels const & fn_levels,
|
||||
buffer<level> & result) {
|
||||
buffer<level> fn_levels_buffer;
|
||||
buffer<level> lhs_fn_levels_buffer;
|
||||
to_buffer(fn_levels, fn_levels_buffer);
|
||||
to_buffer(lhs_fn_levels, lhs_fn_levels_buffer);
|
||||
lean_assert(fn_levels_buffer.size() == lhs_fn_levels_buffer.size());
|
||||
for (name const & lp_name : lp_names) {
|
||||
unsigned j = 0;
|
||||
for (; j < lhs_fn_levels_buffer.size(); j++) {
|
||||
if (!is_param(lhs_fn_levels_buffer[j])) throw_unexpected_error_at_copy_lemmas();
|
||||
if (param_id(lhs_fn_levels_buffer[j]) == lp_name) {
|
||||
result.push_back(fn_levels_buffer[j]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
lean_assert(j < lhs_fn_levels_buffer.size());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
Given a lemma with the given arity (i.e., number of nested Pi-terms),
|
||||
n = args.size() <= lhs_args.size(), and the first n
|
||||
arguments in lhs_args.size() are a permutation p of
|
||||
(var #0) ... (var #n-1)
|
||||
Then, store in result p(args)
|
||||
*/
|
||||
static void get_args_for_instantiating_lemma(unsigned arity,
|
||||
buffer<expr> const & lhs_args,
|
||||
buffer<expr> const & args,
|
||||
buffer<expr> & result) {
|
||||
for (unsigned i = 0; i < args.size(); i++) {
|
||||
if (!is_var(lhs_args[i]) || var_idx(lhs_args[i]) >= arity)
|
||||
throw_unexpected_error_at_copy_lemmas();
|
||||
result.push_back(args[arity - var_idx(lhs_args[i]) - 1]);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
Given declarations d_1, ..., d_n defined as
|
||||
(fun (a_1 : A_1) ... (a_n : A_n), d_1._main a_1' ... a_n')
|
||||
where a_1' ... a_n' is a permutation of a subset of a_1 ... a_n.
|
||||
Moreover, the parameters A_1 ... A_n are the same in all d_i's.
|
||||
Then, copy the equation lemmas from d._main to d.
|
||||
*/
|
||||
static environment copy_equation_lemmas(environment const & env, buffer<name> const & d_names) {
|
||||
type_context ctx(env, transparency_mode::All);
|
||||
type_context::tmp_locals locals(ctx);
|
||||
level_param_names lps;
|
||||
levels ls;
|
||||
buffer<expr> vals;
|
||||
buffer<expr> new_vals;
|
||||
for (unsigned d_idx = 0; d_idx < d_names.size(); d_idx++) {
|
||||
declaration const & d = env.get(d_names[d_idx]);
|
||||
expr val;
|
||||
if (d_idx == 0) {
|
||||
lps = d.get_univ_params();
|
||||
ls = param_names_to_levels(lps);
|
||||
val = instantiate_value_univ_params(d, ls);
|
||||
while (is_lambda(val)) {
|
||||
expr local = locals.push_local_from_binding(val);
|
||||
val = instantiate(binding_body(val), local);
|
||||
}
|
||||
} else {
|
||||
val = instantiate_value_univ_params(d, ls);
|
||||
for (expr const & local : locals.as_buffer()) {
|
||||
lean_assert(is_lambda(val));
|
||||
lean_assert(ctx.is_def_eq(ctx.infer(local), binding_domain(val)));
|
||||
val = instantiate(binding_body(val), local);
|
||||
}
|
||||
}
|
||||
buffer<expr> args;
|
||||
expr const & fn = get_app_args(val, args);
|
||||
if (!is_constant(fn) ||
|
||||
!std::all_of(args.begin(), args.end(), is_local) ||
|
||||
length(const_levels(fn)) != length(ls)) {
|
||||
throw_unexpected_error_at_copy_lemmas();
|
||||
}
|
||||
vals.push_back(val);
|
||||
/* We want to create new equations where we replace val with new_val in the equations
|
||||
associated with fn. */
|
||||
expr new_val = mk_app(mk_constant(d_names[d_idx], ls), locals.as_buffer());
|
||||
new_vals.push_back(new_val);
|
||||
}
|
||||
/* Copy equations */
|
||||
environment new_env = env;
|
||||
for (unsigned d_idx = 0; d_idx < d_names.size(); d_idx++) {
|
||||
buffer<expr> args;
|
||||
expr const & fn = get_app_args(vals[d_idx], args);
|
||||
unsigned i = 1;
|
||||
while (true) {
|
||||
name eqn_name = mk_equation_name(const_name(fn), i);
|
||||
optional<declaration> eqn_decl = env.find(eqn_name);
|
||||
if (!eqn_decl) break;
|
||||
unsigned num_eqn_levels = eqn_decl->get_num_univ_params();
|
||||
if (num_eqn_levels != length(ls))
|
||||
throw_unexpected_error_at_copy_lemmas();
|
||||
expr lhs; unsigned num_eqn_params;
|
||||
std::tie(lhs, num_eqn_params) = get_lemma_lhs(eqn_decl->get_type());
|
||||
buffer<expr> lhs_args;
|
||||
expr const & lhs_fn = get_app_args(lhs, lhs_args);
|
||||
if (!is_constant(lhs_fn) || const_name(lhs_fn) != const_name(fn) || lhs_args.size() < args.size())
|
||||
throw_unexpected_error_at_copy_lemmas();
|
||||
/* Get levels for instantiating the lemma */
|
||||
buffer<level> eqn_level_buffer;
|
||||
get_levels_for_instantiating_lemma(eqn_decl->get_univ_params(),
|
||||
const_levels(lhs_fn),
|
||||
const_levels(fn),
|
||||
eqn_level_buffer);
|
||||
levels eqn_levels = to_list(eqn_level_buffer);
|
||||
/* Get arguments for instantiating the lemma */
|
||||
buffer<expr> eqn_args;
|
||||
get_args_for_instantiating_lemma(num_eqn_params, lhs_args, args, eqn_args);
|
||||
/* Convert type */
|
||||
expr eqn_type = instantiate_type_univ_params(*eqn_decl, eqn_levels);
|
||||
for (unsigned j = 0; j < eqn_args.size(); j++) eqn_type = binding_body(eqn_type);
|
||||
eqn_type = instantiate_rev(eqn_type, eqn_args);
|
||||
expr new_eqn_type = replace(eqn_type, [&](expr const & e, unsigned) {
|
||||
for (unsigned i = 0; i < vals.size(); i++) {
|
||||
if (e == vals[i])
|
||||
return some_expr(new_vals[i]);
|
||||
}
|
||||
return none_expr();
|
||||
});
|
||||
new_eqn_type = locals.mk_pi(new_eqn_type);
|
||||
name new_eqn_name = mk_equation_name(d_names[d_idx], i);
|
||||
expr new_eqn_value;
|
||||
new_eqn_value = mk_app(mk_constant(eqn_name, eqn_levels), args);
|
||||
new_eqn_value = locals.mk_lambda(new_eqn_value);
|
||||
declaration new_decl = mk_theorem(new_eqn_name, lps, new_eqn_type, new_eqn_value);
|
||||
new_env = module::add(new_env, check(new_env, new_decl, true));
|
||||
if (is_rfl_lemma(env, eqn_name))
|
||||
new_env = mark_rfl_lemma(new_env, new_eqn_name);
|
||||
new_env = add_eqn_lemma(new_env, new_eqn_name);
|
||||
i++;
|
||||
}
|
||||
}
|
||||
return new_env;
|
||||
}
|
||||
|
||||
/**
|
||||
Given a declaration d defined as
|
||||
(fun (a_1 : A_1) ... (a_n : A_n), d._main a_1' ... a_n')
|
||||
where a_1' ... a_n' is a permutation of a subset of a_1 ... a_n.
|
||||
Then, copy the equation lemmas from d._main to d.
|
||||
*/
|
||||
static environment copy_equation_lemmas(environment const & env, name const & d_name) {
|
||||
buffer<name> d_names;
|
||||
d_names.push_back(d_name);
|
||||
return copy_equation_lemmas(env, d_names);
|
||||
}
|
||||
|
||||
environment mutual_definition_cmd_core(parser & p, def_cmd_kind kind, decl_modifiers const & modifiers, decl_attributes attrs) {
|
||||
buffer<name> lp_names;
|
||||
buffer<expr> fns, params;
|
||||
|
|
@ -314,6 +497,7 @@ environment mutual_definition_cmd_core(parser & p, def_cmd_kind kind, decl_modif
|
|||
}
|
||||
unsigned num_defs = get_equations_result_size(val);
|
||||
lean_assert(fns.size() == num_defs);
|
||||
buffer<name> new_d_names;
|
||||
/* Define functions */
|
||||
for (unsigned i = 0; i < num_defs; i++) {
|
||||
expr curr = get_equations_result(val, i);
|
||||
|
|
@ -325,12 +509,11 @@ environment mutual_definition_cmd_core(parser & p, def_cmd_kind kind, decl_modif
|
|||
std::tie(env, c_real_name) = declare_definition(p, env, kind, lp_names, c_name,
|
||||
curr_type, some_expr(curr), {}, modifiers, attrs,
|
||||
doc_string, header_pos);
|
||||
new_d_names.push_back(c_real_name);
|
||||
elab.set_env(env);
|
||||
}
|
||||
/* Add lemmas */
|
||||
// TODO(Leo):
|
||||
|
||||
return elab.env();
|
||||
return copy_equation_lemmas(elab.env(), new_d_names);
|
||||
}
|
||||
|
||||
static expr_pair parse_definition(parser & p, buffer<name> & lp_names, buffer<expr> & params,
|
||||
|
|
@ -465,152 +648,6 @@ static expr fix_rec_fn_macro_args(elaborator & elab, name const & fn, buffer<exp
|
|||
return fix_rec_fn_macro_args_fn(params, fns)(val);
|
||||
}
|
||||
|
||||
static void throw_unexpected_error_at_copy_lemmas() {
|
||||
throw exception("unexpected error, failed to generate equational lemmas in the front-end");
|
||||
}
|
||||
|
||||
/* Given e of the form Pi (a_1 : A_1) ... (a_n : A_n), lhs = rhs,
|
||||
return the pair (lhs, n) */
|
||||
static pair<expr, unsigned> get_lemma_lhs(expr e) {
|
||||
unsigned nparams = 0;
|
||||
while (is_pi(e)) {
|
||||
nparams++;
|
||||
e = binding_body(e);
|
||||
}
|
||||
expr lhs, rhs;
|
||||
if (!is_eq(e, lhs, rhs))
|
||||
throw_unexpected_error_at_copy_lemmas();
|
||||
return mk_pair(lhs, nparams);
|
||||
}
|
||||
|
||||
/*
|
||||
Given a lemma with parameters lp_names:
|
||||
[lp_1 ... lp_n]
|
||||
and the levels in the function application on the lemma left-hand-side lhs_fn_levels:
|
||||
[u_1 ... u_n] s.t. there is a permutation p s.t. p([u_1 ... u_n] = [(mk_univ_param lp_1) ... (mk_univ_param lp_n)],
|
||||
and levels fn_levels
|
||||
[v_1 ... v_n]
|
||||
Then, store
|
||||
p([v_1 ... v_n])
|
||||
in result.
|
||||
*/
|
||||
static void get_levels_for_instantiating_lemma(level_param_names const & lp_names,
|
||||
levels const & lhs_fn_levels,
|
||||
levels const & fn_levels,
|
||||
buffer<level> & result) {
|
||||
buffer<level> fn_levels_buffer;
|
||||
buffer<level> lhs_fn_levels_buffer;
|
||||
to_buffer(fn_levels, fn_levels_buffer);
|
||||
to_buffer(lhs_fn_levels, lhs_fn_levels_buffer);
|
||||
lean_assert(fn_levels_buffer.size() == lhs_fn_levels_buffer.size());
|
||||
for (name const & lp_name : lp_names) {
|
||||
unsigned j = 0;
|
||||
for (; j < lhs_fn_levels_buffer.size(); j++) {
|
||||
if (!is_param(lhs_fn_levels_buffer[j])) throw_unexpected_error_at_copy_lemmas();
|
||||
if (param_id(lhs_fn_levels_buffer[j]) == lp_name) {
|
||||
result.push_back(fn_levels_buffer[j]);
|
||||
break;
|
||||
}
|
||||
}
|
||||
lean_assert(j < lhs_fn_levels_buffer.size());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
Given a lemma with the given arity (i.e., number of nested Pi-terms),
|
||||
n = args.size() <= lhs_args.size(), and the first n
|
||||
arguments in lhs_args.size() are a permutation p of
|
||||
(var #0) ... (var #n-1)
|
||||
Then, store in result p(args)
|
||||
*/
|
||||
static void get_args_for_instantiating_lemma(unsigned arity,
|
||||
buffer<expr> const & lhs_args,
|
||||
buffer<expr> const & args,
|
||||
buffer<expr> & result) {
|
||||
for (unsigned i = 0; i < args.size(); i++) {
|
||||
if (!is_var(lhs_args[i]) || var_idx(lhs_args[i]) >= arity)
|
||||
throw_unexpected_error_at_copy_lemmas();
|
||||
result.push_back(args[arity - var_idx(lhs_args[i]) - 1]);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
Given a declaration d defined as
|
||||
(fun (a_1 : A_1) ... (a_n : A_n), d._main a_1' ... a_n')
|
||||
where a_1' ... a_n' is a permutation of a_1 ... a_n.
|
||||
Then, copy the equation lemmas from d._main to d.
|
||||
*/
|
||||
static environment copy_equation_lemmas(environment const & env, name const & d_name) {
|
||||
declaration const & d = env.get(d_name);
|
||||
levels lps = param_names_to_levels(d.get_univ_params());
|
||||
expr val = instantiate_value_univ_params(d, lps);
|
||||
type_context ctx(env, transparency_mode::All);
|
||||
type_context::tmp_locals locals(ctx);
|
||||
while (is_lambda(val)) {
|
||||
expr local = locals.push_local_from_binding(val);
|
||||
val = instantiate(binding_body(val), local);
|
||||
}
|
||||
buffer<expr> args;
|
||||
expr const & fn = get_app_args(val, args);
|
||||
if (!is_constant(fn) ||
|
||||
!std::all_of(args.begin(), args.end(), is_local) ||
|
||||
length(const_levels(fn)) != length(lps)) {
|
||||
throw_unexpected_error_at_copy_lemmas();
|
||||
}
|
||||
/* We want to create new equations where we replace val with new_val in the equations
|
||||
associated with fn. */
|
||||
expr new_val = mk_app(mk_constant(d_name, lps), locals.as_buffer());
|
||||
/* Copy equations */
|
||||
environment new_env = env;
|
||||
unsigned i = 1;
|
||||
while (true) {
|
||||
name eqn_name = mk_equation_name(const_name(fn), i);
|
||||
optional<declaration> eqn_decl = env.find(eqn_name);
|
||||
if (!eqn_decl) break;
|
||||
unsigned num_eqn_levels = eqn_decl->get_num_univ_params();
|
||||
if (num_eqn_levels != length(lps))
|
||||
throw_unexpected_error_at_copy_lemmas();
|
||||
expr lhs; unsigned num_eqn_params;
|
||||
std::tie(lhs, num_eqn_params) = get_lemma_lhs(eqn_decl->get_type());
|
||||
buffer<expr> lhs_args;
|
||||
expr const & lhs_fn = get_app_args(lhs, lhs_args);
|
||||
if (!is_constant(lhs_fn) || const_name(lhs_fn) != const_name(fn) || lhs_args.size() < args.size())
|
||||
throw_unexpected_error_at_copy_lemmas();
|
||||
/* Get levels for instantiating the lemma */
|
||||
buffer<level> eqn_level_buffer;
|
||||
get_levels_for_instantiating_lemma(eqn_decl->get_univ_params(),
|
||||
const_levels(lhs_fn),
|
||||
const_levels(fn),
|
||||
eqn_level_buffer);
|
||||
levels eqn_levels = to_list(eqn_level_buffer);
|
||||
/* Get arguments for instantiating the lemma */
|
||||
buffer<expr> eqn_args;
|
||||
get_args_for_instantiating_lemma(num_eqn_params, lhs_args, args, eqn_args);
|
||||
/* Convert type */
|
||||
expr eqn_type = instantiate_type_univ_params(*eqn_decl, eqn_levels);
|
||||
for (unsigned j = 0; j < eqn_args.size(); j++) eqn_type = binding_body(eqn_type);
|
||||
eqn_type = instantiate_rev(eqn_type, eqn_args);
|
||||
expr new_eqn_type = replace(eqn_type, [&](expr const & e, unsigned) {
|
||||
if (e == val)
|
||||
return some_expr(new_val);
|
||||
else
|
||||
return none_expr();
|
||||
});
|
||||
new_eqn_type = locals.mk_pi(new_eqn_type);
|
||||
name new_eqn_name = mk_equation_name(d_name, i);
|
||||
expr new_eqn_value;
|
||||
new_eqn_value = mk_app(mk_constant(eqn_name, eqn_levels), args);
|
||||
new_eqn_value = locals.mk_lambda(new_eqn_value);
|
||||
declaration new_decl = mk_theorem(new_eqn_name, d.get_univ_params(), new_eqn_type, new_eqn_value);
|
||||
new_env = module::add(new_env, check(new_env, new_decl, true));
|
||||
if (is_rfl_lemma(env, eqn_name))
|
||||
new_env = mark_rfl_lemma(new_env, new_eqn_name);
|
||||
new_env = add_eqn_lemma(new_env, new_eqn_name);
|
||||
i++;
|
||||
}
|
||||
return new_env;
|
||||
}
|
||||
|
||||
static expr inline_new_defs(environment const & old_env, environment const & new_env, name const & n, expr const & e) {
|
||||
return replace(e, [=] (expr const & e, unsigned) -> optional<expr> {
|
||||
if (is_sorry(e)) {
|
||||
|
|
|
|||
|
|
@ -18,10 +18,10 @@ with odd : nat → bool
|
|||
#eval even 4
|
||||
#eval odd 3
|
||||
#eval odd 4
|
||||
#check even._main.equations._eqn_1
|
||||
#check even._main.equations._eqn_2
|
||||
#check odd._main.equations._eqn_1
|
||||
#check odd._main.equations._eqn_2
|
||||
#check even.equations._eqn_1
|
||||
#check even.equations._eqn_2
|
||||
#check odd.equations._eqn_1
|
||||
#check odd.equations._eqn_2
|
||||
|
||||
mutual def f, g {α β : Type u} (p : α × β)
|
||||
with f : Π n : nat, vector (α × β) n
|
||||
|
|
@ -31,7 +31,7 @@ with g : Π n : nat, α → vector β n
|
|||
| 0 a := vector.nil
|
||||
| (succ n) a := vector.cons p.2 $ (f n).map (λ p, p.2)
|
||||
|
||||
#check @f._main.equations._eqn_1
|
||||
#check @f._main.equations._eqn_2
|
||||
#check @g._main.equations._eqn_1
|
||||
#check @g._main.equations._eqn_2
|
||||
#check @f.equations._eqn_1
|
||||
#check @f.equations._eqn_2
|
||||
#check @g.equations._eqn_1
|
||||
#check @g.equations._eqn_2
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue