diff --git a/src/library/inductive_compiler/ginductive.cpp b/src/library/inductive_compiler/ginductive.cpp index 1aee91f0a8..ce77abdf8c 100644 --- a/src/library/inductive_compiler/ginductive.cpp +++ b/src/library/inductive_compiler/ginductive.cpp @@ -7,18 +7,42 @@ Author: Daniel Selsam #include #include #include "util/serializer.h" +#include "util/list_fn.h" #include "kernel/environment.h" #include "library/inductive_compiler/ginductive.h" #include "library/module.h" +#include "library/constants.h" #include "library/kernel_serializer.h" namespace lean { +static unsigned compute_idx_number(expr const & e) { + buffer args; + unsigned idx = 0; + expr it = e; + while (true) { + args.clear(); + expr fn = get_app_args(it, args); + if (is_constant(fn) && const_name(fn) == get_psum_inl_name()) { + return idx; + } else if (is_constant(fn) && const_name(fn) == get_psum_inr_name()) { + idx++; + it = args[2]; + } else { + return idx; + } + } + lean_unreachable(); +} + struct ginductive_entry { ginductive_kind m_kind; + bool m_from_mutual; unsigned m_num_params; list m_inds; list > m_intro_rules; + list m_ir_offsets; + list > m_idx_to_ir_range; }; inline serializer & operator<<(serializer & s, ginductive_kind k) { @@ -45,16 +69,22 @@ inline deserializer & operator>>(deserializer & d, ginductive_entry & entry); serializer & operator<<(serializer & s, ginductive_entry const & entry) { s << entry.m_kind; + s << entry.m_from_mutual; s << entry.m_num_params; write_list(s, entry.m_inds); for (list const & irs : reverse(entry.m_intro_rules)) write_list(s, irs); + + write_list(s, entry.m_ir_offsets); + write_list >(s, entry.m_idx_to_ir_range); + return s; } ginductive_entry read_ginductive_entry(deserializer & d) { ginductive_entry entry; d >> entry.m_kind; + d >> entry.m_from_mutual; d >> entry.m_num_params; entry.m_inds = read_list(d, read_name); @@ -62,6 +92,9 @@ ginductive_entry read_ginductive_entry(deserializer & d) { for (unsigned i = 0; i < num_inds; ++i) { entry.m_intro_rules = list >(read_list(d, read_name), entry.m_intro_rules); } + + entry.m_ir_offsets = read_list(d); + entry.m_idx_to_ir_range = read_list >(d); return entry; } @@ -79,26 +112,43 @@ struct ginductive_env_ext : public environment_extension { name_map m_num_params; name_map m_ir_to_ind; + name_set m_from_mutual; + name_map m_ir_to_simulated_ir_offset; + name_map > > m_ind_to_ir_ranges; + ginductive_env_ext() {} void register_ginductive_entry(ginductive_entry const & entry) { buffer > intro_rules; to_buffer(entry.m_intro_rules, intro_rules); + buffer ir_offsets; + to_buffer(entry.m_ir_offsets, ir_offsets); + unsigned ind_idx = 0; + unsigned acc_ir_idx = 0; for (name const & ind : entry.m_inds) { switch (entry.m_kind) { case ginductive_kind::BASIC: break; case ginductive_kind::MUTUAL: m_all_mutual_inds = list(ind, m_all_mutual_inds); break; case ginductive_kind::NESTED: m_all_nested_inds = list(ind, m_all_nested_inds); break; } + + if (entry.m_from_mutual) + m_from_mutual.insert(ind); + m_ind_to_irs.insert(ind, intro_rules[ind_idx]); m_ind_to_mut_inds.insert(ind, entry.m_inds); m_ind_to_kind.insert(ind, entry.m_kind); m_num_params.insert(ind, entry.m_num_params); for (name const & ir : intro_rules[ind_idx]) { m_ir_to_ind.insert(ir, ind); + m_ir_to_simulated_ir_offset.insert(ir, ir_offsets[acc_ir_idx]); + acc_ir_idx++; } + + m_ind_to_ir_ranges.insert(ind, entry.m_idx_to_ir_range); + ind_idx++; } } @@ -137,6 +187,24 @@ struct ginductive_env_ext : public environment_extension { return *mut_ind_names; } + unsigned ir_to_simulated_ir_offset(name basic_ir_name) const { + unsigned const * offset = m_ir_to_simulated_ir_offset.find(basic_ir_name); + lean_assert(assert); + return *offset; + } + + pair ind_indices_to_ir_range(name const & basic_ind_name, buffer const & idxs) const { + if (!m_from_mutual.contains(basic_ind_name)) + return mk_pair(0, length(get_intro_rules(basic_ind_name))); + + lean_assert(idxs.size == 1); + unsigned idx_number = compute_idx_number(idxs[0]); + + list > const * ranges = m_ind_to_ir_ranges.find(basic_ind_name); + lean_assert(ranges); + return get_ith(*ranges, idx_number); + } + list get_all_nested_inds() const { return m_all_nested_inds; } @@ -229,6 +297,14 @@ list get_ginductive_mut_ind_names(environment const & env, name const & in return get_extension(env).get_mut_ind_names(ind_name); } +unsigned ir_to_simulated_ir_offset(environment const & env, name basic_ir_name) { + return get_extension(env).ir_to_simulated_ir_offset(basic_ir_name); +} + +pair ind_indices_to_ir_range(environment const & env, name const & basic_ind_name, buffer const & idxs) { + return get_extension(env).ind_indices_to_ir_range(basic_ind_name, idxs); +} + list get_ginductive_all_mutual_inds(environment const & env) { return get_extension(env).get_all_mutual_inds(); } diff --git a/src/library/inductive_compiler/ginductive.h b/src/library/inductive_compiler/ginductive.h index fde021eee0..61712cad91 100644 --- a/src/library/inductive_compiler/ginductive.h +++ b/src/library/inductive_compiler/ginductive.h @@ -29,6 +29,46 @@ unsigned get_ginductive_num_params(environment const & env, name const & ind_nam /* \brief Returns the names of all types that are mutually inductive with \e ind_name */ list get_ginductive_mut_ind_names(environment const & env, name const & ind_name); +/* \brief Returns the offset of a simulated introduction rule. + +Example: + +inductive foo +| mk1 : list foo -> foo +| mk2 : foo + +0. foo.basic.foo_mk1 +1. foo.basic.foo_mk2 +2. foo.basic.list_nil ==> list.nil +3. foo.basic.list_cons ==> list.cons + +ir_to_simulated_ir_offset("list.nil") = 0 +ir_to_simulated_ir_offset("list.cons") = 0 +ir_to_simulated_ir_offset("foo.basic.foo_mk1") = 0 +ir_to_simulated_ir_offset("foo.basic.foo_mk2") = 0 +ir_to_simulated_ir_offset("foo.basic.list_nil") = 2 +ir_to_simulated_ir_offset("foo.basic.list_cons") = 2 +*/ +unsigned ir_to_simulated_ir_offset(environment const & env, name basic_ir_name); + +/* \brief Returns the range, i.e. (start, number), of the simulated inductive name corresponding to the idxs. +Example: + +inductive foo +| mk1 : list foo -> foo +| mk2 : foo + +0. foo.basic.foo_mk1 +1. foo.basic.foo_mk2 +2. foo.basic.list_nil ==> list.nil +3. foo.basic.list_cons ==> list.cons + +ind_indices_to_ir_range("list", {}) = (0, 2) +ind_indices_to_ir_range("foo.basic", {sum.inl ()}) = (0, 2) +ind_indices_to_ir_range("foo.basic", {sum.inr ()}) = (2, 2) +*/ +pair ind_indices_to_ir_range(environment const & env, name const & basic_ind_name, buffer const & idxs); + /* \brief Returns the names of all mutual ginductive types */ list get_ginductive_all_mutual_inds(environment const & env); diff --git a/src/library/inductive_compiler/ginductive_decl.h b/src/library/inductive_compiler/ginductive_decl.h index ba515a5031..9b25b28cb2 100644 --- a/src/library/inductive_compiler/ginductive_decl.h +++ b/src/library/inductive_compiler/ginductive_decl.h @@ -13,28 +13,44 @@ namespace lean { class ginductive_decl { unsigned m_nest_depth{0}; + bool m_from_mutual; buffer m_lp_names; buffer m_params; buffer m_inds; buffer > m_intro_rules; + buffer m_ir_offsets; // # total intro rules @ basic + buffer > m_idx_to_ir_range; // # total inds @ mutual + optional m_sizeof_lemmas; public: ginductive_decl() {} - ginductive_decl(unsigned nest_depth, buffer const & lp_names, buffer const & params): - m_nest_depth(nest_depth), m_lp_names(lp_names), m_params(params) {} + + ginductive_decl(unsigned nest_depth, buffer const & lp_names, buffer const & params, buffer const & ir_offsets): + m_nest_depth(nest_depth), m_from_mutual(true), m_lp_names(lp_names), m_params(params), m_ir_offsets(ir_offsets) {} + ginductive_decl(unsigned nest_depth, buffer const & lp_names, buffer const & params, buffer const & inds, buffer > const & intro_rules): - m_nest_depth(nest_depth), m_lp_names(lp_names), m_params(params), m_inds(inds), m_intro_rules(intro_rules) {} + m_nest_depth(nest_depth), m_from_mutual(false), m_lp_names(lp_names), m_params(params), m_inds(inds), m_intro_rules(intro_rules) { + for (unsigned ind_idx = 0; ind_idx < inds.size(); ++ind_idx) { + for (unsigned ir_idx = 0; ir_idx < intro_rules[ind_idx].size(); ++ir_idx) { + m_ir_offsets.emplace_back(0); + } + } + } void set_sizeof_lemmas(simp_lemmas const & sizeof_lemmas) { m_sizeof_lemmas = optional(sizeof_lemmas); } + bool is_from_mutual() const { return m_from_mutual; } + bool has_sizeof_lemmas() const { return static_cast(m_sizeof_lemmas); } simp_lemmas get_sizeof_lemmas() const { return *m_sizeof_lemmas; } unsigned get_nest_depth() const { return m_nest_depth; } + bool from_mutual() const { return m_from_mutual; } + bool is_mutual() const { return m_inds.size() > 1; } unsigned get_num_params() const { return m_params.size(); } unsigned get_num_inds() const { return m_inds.size(); } @@ -56,6 +72,12 @@ public: buffer & get_inds() { return m_inds; } buffer > & get_intro_rules() { return m_intro_rules; } + buffer const & get_ir_offsets() const { return m_ir_offsets; } + buffer & get_ir_offsets() { return m_ir_offsets; } + + buffer > const & get_idx_to_ir_range() const { return m_idx_to_ir_range; } + buffer > & get_idx_to_ir_range() { return m_idx_to_ir_range; } + expr mk_const(name const & n) const { return mk_constant(n, get_levels()); } expr mk_const_params(name const & n) const { return mk_app(mk_const(n), m_params); } expr get_c_ind(unsigned ind_idx) const { return mk_const(mlocal_name(m_inds[ind_idx])); } diff --git a/src/library/inductive_compiler/mutual.cpp b/src/library/inductive_compiler/mutual.cpp index 7aaf3695f8..76d56a26ef 100644 --- a/src/library/inductive_compiler/mutual.cpp +++ b/src/library/inductive_compiler/mutual.cpp @@ -177,6 +177,16 @@ class add_mutual_inductive_decl_fn { m_basic_prefix = prefix; } + void compute_idx_to_ir_range() { + unsigned offset = 0; + for (unsigned ind_idx = 0; ind_idx < m_mut_decl.get_num_inds(); ++ind_idx) { + unsigned num_irs = m_mut_decl.get_num_intro_rules(ind_idx); + m_basic_decl.get_idx_to_ir_range().emplace_back(mk_pair(offset, num_irs)); + lean_trace(name({"inductive_compiler", "mutual", "range"}), tout() << ind_idx << " ==> (" << offset << ", " << num_irs << ")\n";); + offset += num_irs; + } + } + void compute_new_ind() { expr ind = mk_local(m_basic_ind_name, mk_arrow(m_full_index_type, get_ind_result_type(m_tctx, m_mut_decl.get_ind(0)))); lean_trace(name({"inductive_compiler", "mutual", "basic_ind"}), tout() << mlocal_name(ind) << " : " << mlocal_type(ind) << "\n";); @@ -753,7 +763,7 @@ public: bool is_trusted): m_env(env), m_opts(opts), m_implicit_infer_map(implicit_infer_map), m_mut_decl(mut_decl), m_is_trusted(is_trusted), - m_basic_decl(m_mut_decl.get_nest_depth() + 1, m_mut_decl.get_lp_names(), m_mut_decl.get_params()), + m_basic_decl(m_mut_decl.get_nest_depth() + 1, m_mut_decl.get_lp_names(), m_mut_decl.get_params(), m_mut_decl.get_ir_offsets()), m_tctx(env, opts) {} environment operator()() { @@ -766,6 +776,8 @@ public: compute_new_ind(); compute_new_intro_rules(); + compute_idx_to_ir_range(); + try { m_env = add_inner_inductive_declaration(m_env, m_opts, m_implicit_infer_map, m_basic_decl, m_is_trusted); } catch (exception & ex) { @@ -799,6 +811,7 @@ void initialize_inductive_compiler_mutual() { register_trace_class(name({"inductive_compiler", "mutual", "new_inds"})); register_trace_class(name({"inductive_compiler", "mutual", "rec"})); register_trace_class(name({"inductive_compiler", "mutual", "sizeof"})); + register_trace_class(name({"inductive_compiler", "mutual", "range"})); g_mutual_suffix = new name("_mut_"); } diff --git a/src/library/inductive_compiler/nested.cpp b/src/library/inductive_compiler/nested.cpp index a84e0c0e83..4d9f31cf2d 100644 --- a/src/library/inductive_compiler/nested.cpp +++ b/src/library/inductive_compiler/nested.cpp @@ -597,6 +597,7 @@ class add_nested_inductive_decl_fn { expr unpack_type(expr const & e) { return unpack_constants(unpack_nested_occs(e)); } void construct_inner_decl() { + unsigned offset = 0; // Construct inner inds for each of the nested inds for (unsigned ind_idx = 0; ind_idx < m_nested_decl.get_num_inds(); ++ind_idx) { expr const & ind = m_nested_decl.get_ind(ind_idx); @@ -608,11 +609,12 @@ class add_nested_inductive_decl_fn { m_inner_decl.get_intro_rules().emplace_back(); for (expr const & ir : m_nested_decl.get_intro_rules(ind_idx)) { + offset++; expr inner_ir = mk_local(mk_inner_name(mlocal_name(ir)), pack_type(mlocal_type(ir))); m_inner_decl.get_intro_rules().back().push_back(inner_ir); - lean_trace(name({"inductive_compiler", "nested", "inner", "ir"}), - tout() << mlocal_name(inner_ir) << " : " << mlocal_type(inner_ir) << "\n";); + lean_trace(name({"inductive_compiler", "nested", "inner", "ir"}), + tout() << mlocal_name(inner_ir) << " : " << mlocal_type(inner_ir) << "\n";); } } @@ -636,8 +638,11 @@ class add_nested_inductive_decl_fn { expr c_mimic_ir = mk_app(mk_constant(ir, const_levels(nested_occ_fn)), nested_occ_params); expr mimic_ir = mk_local(mk_inner_name(ir), pack_type(m_tctx.infer(c_mimic_ir))); m_inner_decl.get_intro_rules().back().push_back(mimic_ir); + m_inner_decl.get_ir_offsets().emplace_back(offset); lean_trace(name({"inductive_compiler", "nested", "mimic", "ir"}), tout() << mlocal_name(mimic_ir) << " : " << mlocal_type(mimic_ir) << "\n";); + lean_trace(name({"inductive_compiler", "nested", "mimic", "ir", "offset"}), + tout() << mlocal_name(mimic_ir) << " ==> " << offset << "\n";); } } @@ -1799,7 +1804,7 @@ public: ginductive_decl & nested_decl, bool is_trusted): m_env(env), m_opts(opts), m_implicit_infer_map(implicit_infer_map), m_nested_decl(nested_decl), m_is_trusted(is_trusted), - m_inner_decl(m_nested_decl.get_nest_depth() + 1, m_nested_decl.get_lp_names(), m_nested_decl.get_params()), + m_inner_decl(m_nested_decl.get_nest_depth() + 1, m_nested_decl.get_lp_names(), m_nested_decl.get_params(), m_nested_decl.get_ir_offsets()), m_tctx(env, opts, transparency_mode::Semireducible) { } optional operator()() { @@ -1841,6 +1846,7 @@ void initialize_inductive_compiler_nested() { register_trace_class(name({"inductive_compiler", "nested", "mimic"})); register_trace_class(name({"inductive_compiler", "nested", "mimic", "ind"})); register_trace_class(name({"inductive_compiler", "nested", "mimic", "ir"})); + register_trace_class(name({"inductive_compiler", "nested", "mimic", "ir", "offset"})); register_trace_class(name({"inductive_compiler", "nested", "inner"})); register_trace_class(name({"inductive_compiler", "nested", "inner", "ind"}));