From 729e798d6f25dc8a43a564b3fc843133ca8b223e Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 22 May 2017 14:51:06 -0700 Subject: [PATCH] feat(frontends/lean/definition_cmds): copy equational lemmas to top level definition --- src/frontends/lean/definition_cmds.cpp | 335 ++++++++++++++----------- tmp/even_odd.lean | 16 +- 2 files changed, 194 insertions(+), 157 deletions(-) diff --git a/src/frontends/lean/definition_cmds.cpp b/src/frontends/lean/definition_cmds.cpp index e23daf70ee..34745210e2 100644 --- a/src/frontends/lean/definition_cmds.cpp +++ b/src/frontends/lean/definition_cmds.cpp @@ -284,6 +284,189 @@ declare_definition(parser & p, environment const & env, def_cmd_kind kind, buffe return mk_pair(new_env, c_real_name); } +static void throw_unexpected_error_at_copy_lemmas() { + throw exception("unexpected error, failed to generate equational lemmas in the front-end"); +} + +/* Given e of the form Pi (a_1 : A_1) ... (a_n : A_n), lhs = rhs, + return the pair (lhs, n) */ +static pair get_lemma_lhs(expr e) { + unsigned nparams = 0; + while (is_pi(e)) { + nparams++; + e = binding_body(e); + } + expr lhs, rhs; + if (!is_eq(e, lhs, rhs)) + throw_unexpected_error_at_copy_lemmas(); + return mk_pair(lhs, nparams); +} + +/* + Given a lemma with parameters lp_names: + [lp_1 ... lp_n] + and the levels in the function application on the lemma left-hand-side lhs_fn_levels: + [u_1 ... u_n] s.t. there is a permutation p s.t. p([u_1 ... u_n] = [(mk_univ_param lp_1) ... (mk_univ_param lp_n)], + and levels fn_levels + [v_1 ... v_n] + Then, store + p([v_1 ... v_n]) + in result. +*/ +static void get_levels_for_instantiating_lemma(level_param_names const & lp_names, + levels const & lhs_fn_levels, + levels const & fn_levels, + buffer & result) { + buffer fn_levels_buffer; + buffer lhs_fn_levels_buffer; + to_buffer(fn_levels, fn_levels_buffer); + to_buffer(lhs_fn_levels, lhs_fn_levels_buffer); + lean_assert(fn_levels_buffer.size() == lhs_fn_levels_buffer.size()); + for (name const & lp_name : lp_names) { + unsigned j = 0; + for (; j < lhs_fn_levels_buffer.size(); j++) { + if (!is_param(lhs_fn_levels_buffer[j])) throw_unexpected_error_at_copy_lemmas(); + if (param_id(lhs_fn_levels_buffer[j]) == lp_name) { + result.push_back(fn_levels_buffer[j]); + break; + } + } + lean_assert(j < lhs_fn_levels_buffer.size()); + } +} + +/** + Given a lemma with the given arity (i.e., number of nested Pi-terms), + n = args.size() <= lhs_args.size(), and the first n + arguments in lhs_args.size() are a permutation p of + (var #0) ... (var #n-1) + Then, store in result p(args) +*/ +static void get_args_for_instantiating_lemma(unsigned arity, + buffer const & lhs_args, + buffer const & args, + buffer & result) { + for (unsigned i = 0; i < args.size(); i++) { + if (!is_var(lhs_args[i]) || var_idx(lhs_args[i]) >= arity) + throw_unexpected_error_at_copy_lemmas(); + result.push_back(args[arity - var_idx(lhs_args[i]) - 1]); + } +} + +/** + Given declarations d_1, ..., d_n defined as + (fun (a_1 : A_1) ... (a_n : A_n), d_1._main a_1' ... a_n') + where a_1' ... a_n' is a permutation of a subset of a_1 ... a_n. + Moreover, the parameters A_1 ... A_n are the same in all d_i's. + Then, copy the equation lemmas from d._main to d. +*/ +static environment copy_equation_lemmas(environment const & env, buffer const & d_names) { + type_context ctx(env, transparency_mode::All); + type_context::tmp_locals locals(ctx); + level_param_names lps; + levels ls; + buffer vals; + buffer new_vals; + for (unsigned d_idx = 0; d_idx < d_names.size(); d_idx++) { + declaration const & d = env.get(d_names[d_idx]); + expr val; + if (d_idx == 0) { + lps = d.get_univ_params(); + ls = param_names_to_levels(lps); + val = instantiate_value_univ_params(d, ls); + while (is_lambda(val)) { + expr local = locals.push_local_from_binding(val); + val = instantiate(binding_body(val), local); + } + } else { + val = instantiate_value_univ_params(d, ls); + for (expr const & local : locals.as_buffer()) { + lean_assert(is_lambda(val)); + lean_assert(ctx.is_def_eq(ctx.infer(local), binding_domain(val))); + val = instantiate(binding_body(val), local); + } + } + buffer args; + expr const & fn = get_app_args(val, args); + if (!is_constant(fn) || + !std::all_of(args.begin(), args.end(), is_local) || + length(const_levels(fn)) != length(ls)) { + throw_unexpected_error_at_copy_lemmas(); + } + vals.push_back(val); + /* We want to create new equations where we replace val with new_val in the equations + associated with fn. */ + expr new_val = mk_app(mk_constant(d_names[d_idx], ls), locals.as_buffer()); + new_vals.push_back(new_val); + } + /* Copy equations */ + environment new_env = env; + for (unsigned d_idx = 0; d_idx < d_names.size(); d_idx++) { + buffer args; + expr const & fn = get_app_args(vals[d_idx], args); + unsigned i = 1; + while (true) { + name eqn_name = mk_equation_name(const_name(fn), i); + optional eqn_decl = env.find(eqn_name); + if (!eqn_decl) break; + unsigned num_eqn_levels = eqn_decl->get_num_univ_params(); + if (num_eqn_levels != length(ls)) + throw_unexpected_error_at_copy_lemmas(); + expr lhs; unsigned num_eqn_params; + std::tie(lhs, num_eqn_params) = get_lemma_lhs(eqn_decl->get_type()); + buffer lhs_args; + expr const & lhs_fn = get_app_args(lhs, lhs_args); + if (!is_constant(lhs_fn) || const_name(lhs_fn) != const_name(fn) || lhs_args.size() < args.size()) + throw_unexpected_error_at_copy_lemmas(); + /* Get levels for instantiating the lemma */ + buffer eqn_level_buffer; + get_levels_for_instantiating_lemma(eqn_decl->get_univ_params(), + const_levels(lhs_fn), + const_levels(fn), + eqn_level_buffer); + levels eqn_levels = to_list(eqn_level_buffer); + /* Get arguments for instantiating the lemma */ + buffer eqn_args; + get_args_for_instantiating_lemma(num_eqn_params, lhs_args, args, eqn_args); + /* Convert type */ + expr eqn_type = instantiate_type_univ_params(*eqn_decl, eqn_levels); + for (unsigned j = 0; j < eqn_args.size(); j++) eqn_type = binding_body(eqn_type); + eqn_type = instantiate_rev(eqn_type, eqn_args); + expr new_eqn_type = replace(eqn_type, [&](expr const & e, unsigned) { + for (unsigned i = 0; i < vals.size(); i++) { + if (e == vals[i]) + return some_expr(new_vals[i]); + } + return none_expr(); + }); + new_eqn_type = locals.mk_pi(new_eqn_type); + name new_eqn_name = mk_equation_name(d_names[d_idx], i); + expr new_eqn_value; + new_eqn_value = mk_app(mk_constant(eqn_name, eqn_levels), args); + new_eqn_value = locals.mk_lambda(new_eqn_value); + declaration new_decl = mk_theorem(new_eqn_name, lps, new_eqn_type, new_eqn_value); + new_env = module::add(new_env, check(new_env, new_decl, true)); + if (is_rfl_lemma(env, eqn_name)) + new_env = mark_rfl_lemma(new_env, new_eqn_name); + new_env = add_eqn_lemma(new_env, new_eqn_name); + i++; + } + } + return new_env; +} + +/** + Given a declaration d defined as + (fun (a_1 : A_1) ... (a_n : A_n), d._main a_1' ... a_n') + where a_1' ... a_n' is a permutation of a subset of a_1 ... a_n. + Then, copy the equation lemmas from d._main to d. +*/ +static environment copy_equation_lemmas(environment const & env, name const & d_name) { + buffer d_names; + d_names.push_back(d_name); + return copy_equation_lemmas(env, d_names); +} + environment mutual_definition_cmd_core(parser & p, def_cmd_kind kind, decl_modifiers const & modifiers, decl_attributes attrs) { buffer lp_names; buffer fns, params; @@ -314,6 +497,7 @@ environment mutual_definition_cmd_core(parser & p, def_cmd_kind kind, decl_modif } unsigned num_defs = get_equations_result_size(val); lean_assert(fns.size() == num_defs); + buffer new_d_names; /* Define functions */ for (unsigned i = 0; i < num_defs; i++) { expr curr = get_equations_result(val, i); @@ -325,12 +509,11 @@ environment mutual_definition_cmd_core(parser & p, def_cmd_kind kind, decl_modif std::tie(env, c_real_name) = declare_definition(p, env, kind, lp_names, c_name, curr_type, some_expr(curr), {}, modifiers, attrs, doc_string, header_pos); + new_d_names.push_back(c_real_name); elab.set_env(env); } /* Add lemmas */ - // TODO(Leo): - - return elab.env(); + return copy_equation_lemmas(elab.env(), new_d_names); } static expr_pair parse_definition(parser & p, buffer & lp_names, buffer & params, @@ -465,152 +648,6 @@ static expr fix_rec_fn_macro_args(elaborator & elab, name const & fn, buffer get_lemma_lhs(expr e) { - unsigned nparams = 0; - while (is_pi(e)) { - nparams++; - e = binding_body(e); - } - expr lhs, rhs; - if (!is_eq(e, lhs, rhs)) - throw_unexpected_error_at_copy_lemmas(); - return mk_pair(lhs, nparams); -} - -/* - Given a lemma with parameters lp_names: - [lp_1 ... lp_n] - and the levels in the function application on the lemma left-hand-side lhs_fn_levels: - [u_1 ... u_n] s.t. there is a permutation p s.t. p([u_1 ... u_n] = [(mk_univ_param lp_1) ... (mk_univ_param lp_n)], - and levels fn_levels - [v_1 ... v_n] - Then, store - p([v_1 ... v_n]) - in result. -*/ -static void get_levels_for_instantiating_lemma(level_param_names const & lp_names, - levels const & lhs_fn_levels, - levels const & fn_levels, - buffer & result) { - buffer fn_levels_buffer; - buffer lhs_fn_levels_buffer; - to_buffer(fn_levels, fn_levels_buffer); - to_buffer(lhs_fn_levels, lhs_fn_levels_buffer); - lean_assert(fn_levels_buffer.size() == lhs_fn_levels_buffer.size()); - for (name const & lp_name : lp_names) { - unsigned j = 0; - for (; j < lhs_fn_levels_buffer.size(); j++) { - if (!is_param(lhs_fn_levels_buffer[j])) throw_unexpected_error_at_copy_lemmas(); - if (param_id(lhs_fn_levels_buffer[j]) == lp_name) { - result.push_back(fn_levels_buffer[j]); - break; - } - } - lean_assert(j < lhs_fn_levels_buffer.size()); - } -} - -/** - Given a lemma with the given arity (i.e., number of nested Pi-terms), - n = args.size() <= lhs_args.size(), and the first n - arguments in lhs_args.size() are a permutation p of - (var #0) ... (var #n-1) - Then, store in result p(args) -*/ -static void get_args_for_instantiating_lemma(unsigned arity, - buffer const & lhs_args, - buffer const & args, - buffer & result) { - for (unsigned i = 0; i < args.size(); i++) { - if (!is_var(lhs_args[i]) || var_idx(lhs_args[i]) >= arity) - throw_unexpected_error_at_copy_lemmas(); - result.push_back(args[arity - var_idx(lhs_args[i]) - 1]); - } -} - -/** - Given a declaration d defined as - (fun (a_1 : A_1) ... (a_n : A_n), d._main a_1' ... a_n') - where a_1' ... a_n' is a permutation of a_1 ... a_n. - Then, copy the equation lemmas from d._main to d. -*/ -static environment copy_equation_lemmas(environment const & env, name const & d_name) { - declaration const & d = env.get(d_name); - levels lps = param_names_to_levels(d.get_univ_params()); - expr val = instantiate_value_univ_params(d, lps); - type_context ctx(env, transparency_mode::All); - type_context::tmp_locals locals(ctx); - while (is_lambda(val)) { - expr local = locals.push_local_from_binding(val); - val = instantiate(binding_body(val), local); - } - buffer args; - expr const & fn = get_app_args(val, args); - if (!is_constant(fn) || - !std::all_of(args.begin(), args.end(), is_local) || - length(const_levels(fn)) != length(lps)) { - throw_unexpected_error_at_copy_lemmas(); - } - /* We want to create new equations where we replace val with new_val in the equations - associated with fn. */ - expr new_val = mk_app(mk_constant(d_name, lps), locals.as_buffer()); - /* Copy equations */ - environment new_env = env; - unsigned i = 1; - while (true) { - name eqn_name = mk_equation_name(const_name(fn), i); - optional eqn_decl = env.find(eqn_name); - if (!eqn_decl) break; - unsigned num_eqn_levels = eqn_decl->get_num_univ_params(); - if (num_eqn_levels != length(lps)) - throw_unexpected_error_at_copy_lemmas(); - expr lhs; unsigned num_eqn_params; - std::tie(lhs, num_eqn_params) = get_lemma_lhs(eqn_decl->get_type()); - buffer lhs_args; - expr const & lhs_fn = get_app_args(lhs, lhs_args); - if (!is_constant(lhs_fn) || const_name(lhs_fn) != const_name(fn) || lhs_args.size() < args.size()) - throw_unexpected_error_at_copy_lemmas(); - /* Get levels for instantiating the lemma */ - buffer eqn_level_buffer; - get_levels_for_instantiating_lemma(eqn_decl->get_univ_params(), - const_levels(lhs_fn), - const_levels(fn), - eqn_level_buffer); - levels eqn_levels = to_list(eqn_level_buffer); - /* Get arguments for instantiating the lemma */ - buffer eqn_args; - get_args_for_instantiating_lemma(num_eqn_params, lhs_args, args, eqn_args); - /* Convert type */ - expr eqn_type = instantiate_type_univ_params(*eqn_decl, eqn_levels); - for (unsigned j = 0; j < eqn_args.size(); j++) eqn_type = binding_body(eqn_type); - eqn_type = instantiate_rev(eqn_type, eqn_args); - expr new_eqn_type = replace(eqn_type, [&](expr const & e, unsigned) { - if (e == val) - return some_expr(new_val); - else - return none_expr(); - }); - new_eqn_type = locals.mk_pi(new_eqn_type); - name new_eqn_name = mk_equation_name(d_name, i); - expr new_eqn_value; - new_eqn_value = mk_app(mk_constant(eqn_name, eqn_levels), args); - new_eqn_value = locals.mk_lambda(new_eqn_value); - declaration new_decl = mk_theorem(new_eqn_name, d.get_univ_params(), new_eqn_type, new_eqn_value); - new_env = module::add(new_env, check(new_env, new_decl, true)); - if (is_rfl_lemma(env, eqn_name)) - new_env = mark_rfl_lemma(new_env, new_eqn_name); - new_env = add_eqn_lemma(new_env, new_eqn_name); - i++; - } - return new_env; -} - static expr inline_new_defs(environment const & old_env, environment const & new_env, name const & n, expr const & e) { return replace(e, [=] (expr const & e, unsigned) -> optional { if (is_sorry(e)) { diff --git a/tmp/even_odd.lean b/tmp/even_odd.lean index 50a4545d1e..0c96dc17d1 100644 --- a/tmp/even_odd.lean +++ b/tmp/even_odd.lean @@ -18,10 +18,10 @@ with odd : nat → bool #eval even 4 #eval odd 3 #eval odd 4 -#check even._main.equations._eqn_1 -#check even._main.equations._eqn_2 -#check odd._main.equations._eqn_1 -#check odd._main.equations._eqn_2 +#check even.equations._eqn_1 +#check even.equations._eqn_2 +#check odd.equations._eqn_1 +#check odd.equations._eqn_2 mutual def f, g {α β : Type u} (p : α × β) with f : Π n : nat, vector (α × β) n @@ -31,7 +31,7 @@ with g : Π n : nat, α → vector β n | 0 a := vector.nil | (succ n) a := vector.cons p.2 $ (f n).map (λ p, p.2) -#check @f._main.equations._eqn_1 -#check @f._main.equations._eqn_2 -#check @g._main.equations._eqn_1 -#check @g._main.equations._eqn_2 +#check @f.equations._eqn_1 +#check @f.equations._eqn_2 +#check @g.equations._eqn_1 +#check @g.equations._eqn_2