From 5ef892bb4591ad983bf1bcc515681c92ef56bd95 Mon Sep 17 00:00:00 2001 From: Daniel Selsam Date: Mon, 27 Feb 2017 16:53:42 -0800 Subject: [PATCH] feat(inductive_compiler): cases_on for mutual and nested --- src/library/inductive_compiler/mutual.cpp | 82 +++++++++++++++- src/library/inductive_compiler/nested.cpp | 112 +++++++++++++++------- tests/lean/1334b.lean | 6 ++ 3 files changed, 161 insertions(+), 39 deletions(-) diff --git a/src/library/inductive_compiler/mutual.cpp b/src/library/inductive_compiler/mutual.cpp index b1a056cbb6..7aaf3695f8 100644 --- a/src/library/inductive_compiler/mutual.cpp +++ b/src/library/inductive_compiler/mutual.cpp @@ -512,7 +512,7 @@ class add_mutual_inductive_decl_fn { return Fun(index, construct_inner_C_core(C, index, 0, ind_idx)); } - expr introduce_locals_for_rec_args(unsigned ind_idx, expr & C, buffer & minor_premises, buffer & indices, expr & major_premise) { + expr introduce_locals_for_rec_args(unsigned ind_idx, expr & C, buffer & minor_premises, buffer & indices, expr & major_premise, bool cases_on) { expr const & ind = m_mut_decl.get_ind(ind_idx); { buffer C_args; @@ -545,7 +545,7 @@ class add_mutual_inductive_decl_fn { } buffer inner_indices; - if (m_mut_decl.is_ind_app(ir_arg, ind_idx, inner_indices)) { + if (!cases_on && m_mut_decl.is_ind_app(ir_arg, ind_idx, inner_indices)) { expr rec_arg_type = Pi(ir_arg_args, mk_app(mk_app(C, inner_indices), mk_app(minor_premise_arg, ir_arg_args))); expr rec_arg = mk_local_pp("x", rec_arg_type); rec_args.push_back(rec_arg); @@ -584,7 +584,7 @@ class add_mutual_inductive_decl_fn { expr C; buffer minor_premises, indices; expr major_premise; - expr rec_type = introduce_locals_for_rec_args(ind_idx, C, minor_premises, indices, major_premise); + expr rec_type = introduce_locals_for_rec_args(ind_idx, C, minor_premises, indices, major_premise, false); expr inner_C = construct_inner_C(C, ind_idx); lean_trace(name({"inductive_compiler", "mutual", "rec"}), tout() << "inner C: " << inner_C << "\n";); @@ -656,6 +656,80 @@ class add_mutual_inductive_decl_fn { m_env = module::add(m_env, check(m_env, mk_definition_inferring_trusted(m_env, get_dep_recursor(m_env, mlocal_name(ind)), rec_lp_names, rec_type, rec_val, true))); } + void define_cases_on(name const & rec_name, level_param_names const & rec_lp_names, unsigned ind_idx) { + expr const & ind = m_mut_decl.get_ind(ind_idx); + + expr C; + buffer minor_premises, indices; + expr major_premise; + expr cases_on_type = introduce_locals_for_rec_args(ind_idx, C, minor_premises, indices, major_premise, true); + + expr inner_C = construct_inner_C(C, ind_idx); + lean_trace(name({"inductive_compiler", "mutual", "cases_on"}), tout() << "inner C: " << inner_C << "\n";); + + buffer inner_minor_premises; + for (unsigned i = 0; i < m_mut_decl.get_inds().size(); ++i) { + buffer const & irs = m_mut_decl.get_intro_rules(i); + for (unsigned ir_idx = 0; ir_idx < irs.size(); ++ir_idx) { + expr const & ir = irs[ir_idx]; + buffer locals; + buffer rec_args; + buffer return_args; + expr ir_type = mlocal_type(ir); + while (is_pi(ir_type)) { + expr l = mk_local_for(ir_type); + locals.push_back(l); + + buffer ir_arg_args; + expr ir_arg = binding_domain(ir_type); + + while (is_pi(ir_arg)) { + expr ir_arg_arg = mk_local_for(ir_arg); + ir_arg_args.push_back(ir_arg_arg); + ir_arg = instantiate(binding_body(ir_arg), ir_arg_arg); + } + + buffer inner_indices; + if (m_mut_decl.is_ind_app(ir_arg, inner_indices)) { + bool this_ind_app = m_mut_decl.is_ind_app(ir_arg, ind_idx); + expr C_term = mk_app(mk_app(C, inner_indices), mk_app(l, ir_arg_args)); + expr rec_arg_type = Pi(ir_arg_args, this_ind_app ? C_term : punit()); + expr l2 = mk_local_pp("x", rec_arg_type); + rec_args.push_back(l2); + } + ir_type = m_tctx.whnf(instantiate(binding_body(ir_type), l)); + return_args.push_back(l); + } + locals.append(rec_args); + expr return_value; + if (i == ind_idx) { + return_value = mk_app(minor_premises[ir_idx], return_args); + } else { + return_value = punit_star(); + } + expr inner_minor_premise = Fun(locals, return_value); + lean_trace(name({"inductive_compiler", "mutual", "cases_on"}), tout() << "inner minor premise: " << inner_minor_premise << "\n";); + inner_minor_premises.push_back(inner_minor_premise); + } + } + + expr inner_index = mk_app(m_putters[ind_idx], mk_app(m_makers[ind_idx], indices)); + lean_trace(name({"inductive_compiler", "mutual", "cases_on"}), tout() << "inner index: " << inner_index << "\n";); + expr inner_major_premise = major_premise; + expr cases_on_val = mk_app(mk_app(mk_app(mk_app(mk_app(mk_constant(rec_name, param_names_to_levels(rec_lp_names)), m_mut_decl.get_params()), inner_C), + inner_minor_premises), inner_index), inner_major_premise); + + cases_on_type = Pi(m_mut_decl.get_params(), Pi(C, Pi(indices, Pi(major_premise, Pi(minor_premises, cases_on_type))))); + cases_on_val = Fun(m_mut_decl.get_params(), Fun(C, Fun(indices, Fun(major_premise, Fun(minor_premises, cases_on_val))))); + + lean_trace(name({"inductive_compiler", "mutual", "cases_on"}), tout() << "cases_on type: " << cases_on_type << "\n";); + lean_trace(name({"inductive_compiler", "mutual", "cases_on"}), tout() << "cases_on val: " << cases_on_val << "\n";); + + lean_assert(!has_local(cases_on_type)); + lean_assert(!has_local(cases_on_val)); + m_env = module::add(m_env, check(m_env, mk_definition_inferring_trusted(m_env, name(mlocal_name(ind), "cases_on"), rec_lp_names, cases_on_type, cases_on_val, true))); + } + void define_recursors() { name rec_name = get_dep_recursor(m_env, mlocal_name(m_basic_decl.get_ind(0))); declaration rec_decl = m_env.get(rec_name); @@ -669,8 +743,10 @@ class add_mutual_inductive_decl_fn { for (unsigned i = 0; i < m_mut_decl.get_inds().size(); ++i) { define_recursor(rec_name, rec_lp_names, i); + define_cases_on(rec_name, rec_lp_names, i); } } + public: add_mutual_inductive_decl_fn(environment const & env, options const & opts, name_map const & implicit_infer_map, ginductive_decl const & mut_decl, diff --git a/src/library/inductive_compiler/nested.cpp b/src/library/inductive_compiler/nested.cpp index 567de32792..a84e0c0e83 100644 --- a/src/library/inductive_compiler/nested.cpp +++ b/src/library/inductive_compiler/nested.cpp @@ -1711,49 +1711,88 @@ class add_nested_inductive_decl_fn { } }; - ///////////////////////////////////////////////// - ///// Stage 8: sizeof lemmas for nested ind ///// - ///////////////////////////////////////////////// - - void define_nested_sizeof_lemmas() { + void define_nested_cases_on() { for (unsigned ind_idx = 0; ind_idx < m_nested_decl.get_num_inds(); ++ind_idx) { - name sizeof_name = mk_sizeof_name(mlocal_name(m_nested_decl.get_ind(ind_idx))); - for (unsigned ir_idx = 0; ir_idx < m_nested_decl.get_num_intro_rules(ind_idx); ++ir_idx) { - type_context tctx_synth(m_env, m_tctx.get_options(), m_synth_lctx); + expr const & nested_ind = m_nested_decl.get_ind(ind_idx); + expr const & inner_ind = m_inner_decl.get_ind(ind_idx); - expr c_nested_ir = m_nested_decl.get_c_ir_params(ind_idx, ir_idx); - expr ir_ty = tctx_synth.infer(c_nested_ir); + declaration d = m_env.get(inductive::get_elim_name(mlocal_name(inner_ind))); + level_param_names lp_names = d.get_univ_params(); + levels lvls = param_names_to_levels(lp_names); - buffer locals; - while (is_pi(ir_ty)) { - expr local = mk_local_for(ir_ty); - locals.push_back(local); - ir_ty = safe_whnf(tctx_synth, instantiate(binding_body(ir_ty), local)); - } - buffer fully_packed_args; - get_app_args(tctx_synth.relaxed_whnf(mk_app(c_nested_ir, locals)), fully_packed_args); + expr inner_cases_on = mk_app(mk_constant(name(mlocal_name(inner_ind), "cases_on"), lvls), m_nested_decl.get_params()); + expr inner_cases_on_type = m_tctx.infer(inner_cases_on); - expr rhs = mk_nat_one(); - for (expr const & rhs_arg : fully_packed_args) { - rhs = mk_nat_add(rhs, mk_app(tctx_synth, get_sizeof_name(), rhs_arg)); - } - - expr lhs = mk_app(tctx_synth, sizeof_name, {mk_app(c_nested_ir, locals)}); - expr dsimp_rule_type = Pi(m_nested_decl.get_params(), tctx_synth.mk_pi(m_param_insts, Pi(locals, mk_eq(tctx_synth, lhs, rhs)))); - expr dsimp_rule_val = Fun(m_nested_decl.get_params(), tctx_synth.mk_lambda(m_param_insts, Fun(locals, mk_eq_refl(tctx_synth, lhs)))); - name dsimp_rule_name = mk_sizeof_spec_name(mlocal_name(m_nested_decl.get_intro_rule(ind_idx, ir_idx))); - - lean_trace(name({"inductive_compiler", "nested", "sizeof"}), tout() << "[rfl]: " << dsimp_rule_type << "\n";); - - define_theorem(dsimp_rule_name, dsimp_rule_type, dsimp_rule_val); - m_env = mark_rfl_lemma(m_env, dsimp_rule_name); - m_env = add_eqn_lemma(m_env, dsimp_rule_name); - m_env = add_protected(m_env, dsimp_rule_name); - m_tctx.set_env(m_env); - } + expr outer_cases_on_type = Pi(m_nested_decl.get_params(), unpack_type(inner_cases_on_type)); + expr outer_cases_on_val = Fun(m_nested_decl.get_params(), build_nested_cases_on(ind_idx, inner_cases_on, unpack_type(inner_cases_on_type))); + define(name(mlocal_name(nested_ind), "cases_on"), outer_cases_on_type, outer_cases_on_val, lp_names); } } + expr build_nested_cases_on(unsigned ind_idx, expr const & inner_cases_on, expr const & outer_cases_on_type) { + expr C; + buffer indices; + expr major_premise; + buffer minor_premises; + expr goal = introduce_locals_for_nested_cases_on(ind_idx, outer_cases_on_type, C, indices, major_premise, minor_premises); + + // Only the minor premises need to change + lean_assert(m_nested_decl.get_num_intro_rules(ind_idx) == minor_premises.size()); + buffer inner_minor_premises; + for (unsigned ir_idx = 0; ir_idx < minor_premises.size(); ++ir_idx) { + expr const & minor_premise = minor_premises[ir_idx]; + expr ty = safe_whnf(m_tctx, pack_type(mlocal_type(minor_premise))); + + buffer inner_minor_premise_args; + buffer inner_minor_premise_rec_args; + while (is_pi(ty)) { + expr arg = mk_local_for(ty); + if (get_app_fn(mlocal_type(arg)) != pack_type(C)) { + lean_assert(inner_minor_premise_rec_args.empty()); + inner_minor_premise_args.push_back(arg); + } else { + inner_minor_premise_rec_args.push_back(arg); + } + ty = safe_whnf(m_tctx, instantiate(binding_body(ty), arg)); + } + inner_minor_premises.push_back(build_nested_minor_premise_fn(*this, ind_idx, ir_idx, minor_premise, inner_minor_premise_args, + inner_minor_premise_rec_args, ty)()); + } + + return Fun(C, + Fun(indices, + Fun(major_premise, + Fun(minor_premises, + mk_app(mk_app(mk_app(mk_app(inner_cases_on, C), indices), major_premise), inner_minor_premises))))); + } + + expr introduce_locals_for_nested_cases_on(unsigned ind_idx, expr const & outer_cases_on_type, + expr & C, buffer & indices, expr & major_premise, + buffer & minor_premises) { + expr ty = safe_whnf(m_tctx, outer_cases_on_type); + + C = mk_local_for(ty, "C"); + ty = safe_whnf(m_tctx, instantiate(binding_body(ty), C)); + + while (true) { + expr l = mk_local_for(ty); + ty = safe_whnf(m_tctx, instantiate(binding_body(ty), l)); + if (m_nested_decl.is_ind_app(mlocal_type(l), ind_idx)) { + major_premise = l; + break; + } else { + indices.push_back(l); + } + } + + for (unsigned ir_idx = 0; ir_idx < m_nested_decl.get_num_intro_rules(ind_idx); ++ir_idx) { + expr minor_premise = mk_local_for(ty); + minor_premises.push_back(minor_premise); + ty = safe_whnf(m_tctx, instantiate(binding_body(ty), minor_premise)); + } + + return ty; + } public: add_nested_inductive_decl_fn(environment const & env, options const & opts, name_map const & implicit_infer_map, @@ -1783,6 +1822,7 @@ public: build_primitive_pack_unpack(); define_nested_irs(); define_nested_recursors(); + define_nested_cases_on(); return optional(m_env); } diff --git a/tests/lean/1334b.lean b/tests/lean/1334b.lean index 630b97154c..d0aebbb993 100644 --- a/tests/lean/1334b.lean +++ b/tests/lean/1334b.lean @@ -18,3 +18,9 @@ def list_of : term → list term check list_of.equations._eqn_1 check list_of.equations._eqn_2 check list_of.equations._eqn_3 + +example (a : nat) (ls : list term) : term.var a = term.app ls → false := +by contradiction + +example (a : nat) (s : string) : ¬ term.var a = term.cnst s := +by contradiction