feat(inductive_compiler): APIs for simulated constructor offsets
This commit is contained in:
parent
f460cbdf2e
commit
e9c05f727c
5 changed files with 164 additions and 7 deletions
|
|
@ -7,18 +7,42 @@ Author: Daniel Selsam
|
|||
#include <utility>
|
||||
#include <string>
|
||||
#include "util/serializer.h"
|
||||
#include "util/list_fn.h"
|
||||
#include "kernel/environment.h"
|
||||
#include "library/inductive_compiler/ginductive.h"
|
||||
#include "library/module.h"
|
||||
#include "library/constants.h"
|
||||
#include "library/kernel_serializer.h"
|
||||
|
||||
namespace lean {
|
||||
|
||||
static unsigned compute_idx_number(expr const & e) {
|
||||
buffer<expr> args;
|
||||
unsigned idx = 0;
|
||||
expr it = e;
|
||||
while (true) {
|
||||
args.clear();
|
||||
expr fn = get_app_args(it, args);
|
||||
if (is_constant(fn) && const_name(fn) == get_psum_inl_name()) {
|
||||
return idx;
|
||||
} else if (is_constant(fn) && const_name(fn) == get_psum_inr_name()) {
|
||||
idx++;
|
||||
it = args[2];
|
||||
} else {
|
||||
return idx;
|
||||
}
|
||||
}
|
||||
lean_unreachable();
|
||||
}
|
||||
|
||||
struct ginductive_entry {
|
||||
ginductive_kind m_kind;
|
||||
bool m_from_mutual;
|
||||
unsigned m_num_params;
|
||||
list<name> m_inds;
|
||||
list<list<name> > m_intro_rules;
|
||||
list<unsigned> m_ir_offsets;
|
||||
list<pair<unsigned, unsigned> > m_idx_to_ir_range;
|
||||
};
|
||||
|
||||
inline serializer & operator<<(serializer & s, ginductive_kind k) {
|
||||
|
|
@ -45,16 +69,22 @@ inline deserializer & operator>>(deserializer & d, ginductive_entry & entry);
|
|||
|
||||
serializer & operator<<(serializer & s, ginductive_entry const & entry) {
|
||||
s << entry.m_kind;
|
||||
s << entry.m_from_mutual;
|
||||
s << entry.m_num_params;
|
||||
write_list<name>(s, entry.m_inds);
|
||||
for (list<name> const & irs : reverse(entry.m_intro_rules))
|
||||
write_list<name>(s, irs);
|
||||
|
||||
write_list<unsigned>(s, entry.m_ir_offsets);
|
||||
write_list<pair<unsigned, unsigned> >(s, entry.m_idx_to_ir_range);
|
||||
|
||||
return s;
|
||||
}
|
||||
|
||||
ginductive_entry read_ginductive_entry(deserializer & d) {
|
||||
ginductive_entry entry;
|
||||
d >> entry.m_kind;
|
||||
d >> entry.m_from_mutual;
|
||||
d >> entry.m_num_params;
|
||||
entry.m_inds = read_list<name>(d, read_name);
|
||||
|
||||
|
|
@ -62,6 +92,9 @@ ginductive_entry read_ginductive_entry(deserializer & d) {
|
|||
for (unsigned i = 0; i < num_inds; ++i) {
|
||||
entry.m_intro_rules = list<list<name> >(read_list<name>(d, read_name), entry.m_intro_rules);
|
||||
}
|
||||
|
||||
entry.m_ir_offsets = read_list<unsigned>(d);
|
||||
entry.m_idx_to_ir_range = read_list<pair<unsigned, unsigned> >(d);
|
||||
return entry;
|
||||
}
|
||||
|
||||
|
|
@ -79,26 +112,43 @@ struct ginductive_env_ext : public environment_extension {
|
|||
name_map<unsigned> m_num_params;
|
||||
name_map<name> m_ir_to_ind;
|
||||
|
||||
name_set m_from_mutual;
|
||||
name_map<unsigned> m_ir_to_simulated_ir_offset;
|
||||
name_map<list<pair<unsigned, unsigned> > > m_ind_to_ir_ranges;
|
||||
|
||||
ginductive_env_ext() {}
|
||||
|
||||
void register_ginductive_entry(ginductive_entry const & entry) {
|
||||
buffer<list<name> > intro_rules;
|
||||
to_buffer(entry.m_intro_rules, intro_rules);
|
||||
|
||||
buffer<unsigned> ir_offsets;
|
||||
to_buffer(entry.m_ir_offsets, ir_offsets);
|
||||
|
||||
unsigned ind_idx = 0;
|
||||
unsigned acc_ir_idx = 0;
|
||||
for (name const & ind : entry.m_inds) {
|
||||
switch (entry.m_kind) {
|
||||
case ginductive_kind::BASIC: break;
|
||||
case ginductive_kind::MUTUAL: m_all_mutual_inds = list<name>(ind, m_all_mutual_inds); break;
|
||||
case ginductive_kind::NESTED: m_all_nested_inds = list<name>(ind, m_all_nested_inds); break;
|
||||
}
|
||||
|
||||
if (entry.m_from_mutual)
|
||||
m_from_mutual.insert(ind);
|
||||
|
||||
m_ind_to_irs.insert(ind, intro_rules[ind_idx]);
|
||||
m_ind_to_mut_inds.insert(ind, entry.m_inds);
|
||||
m_ind_to_kind.insert(ind, entry.m_kind);
|
||||
m_num_params.insert(ind, entry.m_num_params);
|
||||
for (name const & ir : intro_rules[ind_idx]) {
|
||||
m_ir_to_ind.insert(ir, ind);
|
||||
m_ir_to_simulated_ir_offset.insert(ir, ir_offsets[acc_ir_idx]);
|
||||
acc_ir_idx++;
|
||||
}
|
||||
|
||||
m_ind_to_ir_ranges.insert(ind, entry.m_idx_to_ir_range);
|
||||
|
||||
ind_idx++;
|
||||
}
|
||||
}
|
||||
|
|
@ -137,6 +187,24 @@ struct ginductive_env_ext : public environment_extension {
|
|||
return *mut_ind_names;
|
||||
}
|
||||
|
||||
unsigned ir_to_simulated_ir_offset(name basic_ir_name) const {
|
||||
unsigned const * offset = m_ir_to_simulated_ir_offset.find(basic_ir_name);
|
||||
lean_assert(assert);
|
||||
return *offset;
|
||||
}
|
||||
|
||||
pair<unsigned, unsigned> ind_indices_to_ir_range(name const & basic_ind_name, buffer<expr> const & idxs) const {
|
||||
if (!m_from_mutual.contains(basic_ind_name))
|
||||
return mk_pair(0, length(get_intro_rules(basic_ind_name)));
|
||||
|
||||
lean_assert(idxs.size == 1);
|
||||
unsigned idx_number = compute_idx_number(idxs[0]);
|
||||
|
||||
list<pair<unsigned, unsigned> > const * ranges = m_ind_to_ir_ranges.find(basic_ind_name);
|
||||
lean_assert(ranges);
|
||||
return get_ith(*ranges, idx_number);
|
||||
}
|
||||
|
||||
list<name> get_all_nested_inds() const {
|
||||
return m_all_nested_inds;
|
||||
}
|
||||
|
|
@ -229,6 +297,14 @@ list<name> get_ginductive_mut_ind_names(environment const & env, name const & in
|
|||
return get_extension(env).get_mut_ind_names(ind_name);
|
||||
}
|
||||
|
||||
unsigned ir_to_simulated_ir_offset(environment const & env, name basic_ir_name) {
|
||||
return get_extension(env).ir_to_simulated_ir_offset(basic_ir_name);
|
||||
}
|
||||
|
||||
pair<unsigned, unsigned> ind_indices_to_ir_range(environment const & env, name const & basic_ind_name, buffer<expr> const & idxs) {
|
||||
return get_extension(env).ind_indices_to_ir_range(basic_ind_name, idxs);
|
||||
}
|
||||
|
||||
list<name> get_ginductive_all_mutual_inds(environment const & env) {
|
||||
return get_extension(env).get_all_mutual_inds();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -29,6 +29,46 @@ unsigned get_ginductive_num_params(environment const & env, name const & ind_nam
|
|||
/* \brief Returns the names of all types that are mutually inductive with \e ind_name */
|
||||
list<name> get_ginductive_mut_ind_names(environment const & env, name const & ind_name);
|
||||
|
||||
/* \brief Returns the offset of a simulated introduction rule.
|
||||
|
||||
Example:
|
||||
|
||||
inductive foo
|
||||
| mk1 : list foo -> foo
|
||||
| mk2 : foo
|
||||
|
||||
0. foo.basic.foo_mk1
|
||||
1. foo.basic.foo_mk2
|
||||
2. foo.basic.list_nil ==> list.nil
|
||||
3. foo.basic.list_cons ==> list.cons
|
||||
|
||||
ir_to_simulated_ir_offset("list.nil") = 0
|
||||
ir_to_simulated_ir_offset("list.cons") = 0
|
||||
ir_to_simulated_ir_offset("foo.basic.foo_mk1") = 0
|
||||
ir_to_simulated_ir_offset("foo.basic.foo_mk2") = 0
|
||||
ir_to_simulated_ir_offset("foo.basic.list_nil") = 2
|
||||
ir_to_simulated_ir_offset("foo.basic.list_cons") = 2
|
||||
*/
|
||||
unsigned ir_to_simulated_ir_offset(environment const & env, name basic_ir_name);
|
||||
|
||||
/* \brief Returns the range, i.e. (start, number), of the simulated inductive name corresponding to the idxs.
|
||||
Example:
|
||||
|
||||
inductive foo
|
||||
| mk1 : list foo -> foo
|
||||
| mk2 : foo
|
||||
|
||||
0. foo.basic.foo_mk1
|
||||
1. foo.basic.foo_mk2
|
||||
2. foo.basic.list_nil ==> list.nil
|
||||
3. foo.basic.list_cons ==> list.cons
|
||||
|
||||
ind_indices_to_ir_range("list", {}) = (0, 2)
|
||||
ind_indices_to_ir_range("foo.basic", {sum.inl ()}) = (0, 2)
|
||||
ind_indices_to_ir_range("foo.basic", {sum.inr ()}) = (2, 2)
|
||||
*/
|
||||
pair<unsigned, unsigned> ind_indices_to_ir_range(environment const & env, name const & basic_ind_name, buffer<expr> const & idxs);
|
||||
|
||||
/* \brief Returns the names of all mutual ginductive types */
|
||||
list<name> get_ginductive_all_mutual_inds(environment const & env);
|
||||
|
||||
|
|
|
|||
|
|
@ -13,28 +13,44 @@ namespace lean {
|
|||
|
||||
class ginductive_decl {
|
||||
unsigned m_nest_depth{0};
|
||||
bool m_from_mutual;
|
||||
buffer<name> m_lp_names;
|
||||
buffer<expr> m_params;
|
||||
buffer<expr> m_inds;
|
||||
buffer<buffer<expr> > m_intro_rules;
|
||||
|
||||
buffer<unsigned> m_ir_offsets; // # total intro rules @ basic
|
||||
buffer<pair<unsigned, unsigned> > m_idx_to_ir_range; // # total inds @ mutual
|
||||
|
||||
optional<simp_lemmas> m_sizeof_lemmas;
|
||||
public:
|
||||
ginductive_decl() {}
|
||||
ginductive_decl(unsigned nest_depth, buffer<name> const & lp_names, buffer<expr> const & params):
|
||||
m_nest_depth(nest_depth), m_lp_names(lp_names), m_params(params) {}
|
||||
|
||||
ginductive_decl(unsigned nest_depth, buffer<name> const & lp_names, buffer<expr> const & params, buffer<unsigned> const & ir_offsets):
|
||||
m_nest_depth(nest_depth), m_from_mutual(true), m_lp_names(lp_names), m_params(params), m_ir_offsets(ir_offsets) {}
|
||||
|
||||
ginductive_decl(unsigned nest_depth, buffer<name> const & lp_names, buffer<expr> const & params,
|
||||
buffer<expr> const & inds, buffer<buffer<expr> > const & intro_rules):
|
||||
m_nest_depth(nest_depth), m_lp_names(lp_names), m_params(params), m_inds(inds), m_intro_rules(intro_rules) {}
|
||||
m_nest_depth(nest_depth), m_from_mutual(false), m_lp_names(lp_names), m_params(params), m_inds(inds), m_intro_rules(intro_rules) {
|
||||
for (unsigned ind_idx = 0; ind_idx < inds.size(); ++ind_idx) {
|
||||
for (unsigned ir_idx = 0; ir_idx < intro_rules[ind_idx].size(); ++ir_idx) {
|
||||
m_ir_offsets.emplace_back(0);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void set_sizeof_lemmas(simp_lemmas const & sizeof_lemmas) {
|
||||
m_sizeof_lemmas = optional<simp_lemmas>(sizeof_lemmas);
|
||||
}
|
||||
|
||||
bool is_from_mutual() const { return m_from_mutual; }
|
||||
|
||||
bool has_sizeof_lemmas() const { return static_cast<bool>(m_sizeof_lemmas); }
|
||||
simp_lemmas get_sizeof_lemmas() const { return *m_sizeof_lemmas; }
|
||||
|
||||
unsigned get_nest_depth() const { return m_nest_depth; }
|
||||
bool from_mutual() const { return m_from_mutual; }
|
||||
|
||||
bool is_mutual() const { return m_inds.size() > 1; }
|
||||
unsigned get_num_params() const { return m_params.size(); }
|
||||
unsigned get_num_inds() const { return m_inds.size(); }
|
||||
|
|
@ -56,6 +72,12 @@ public:
|
|||
buffer<expr> & get_inds() { return m_inds; }
|
||||
buffer<buffer<expr> > & get_intro_rules() { return m_intro_rules; }
|
||||
|
||||
buffer<unsigned> const & get_ir_offsets() const { return m_ir_offsets; }
|
||||
buffer<unsigned> & get_ir_offsets() { return m_ir_offsets; }
|
||||
|
||||
buffer<pair<unsigned, unsigned> > const & get_idx_to_ir_range() const { return m_idx_to_ir_range; }
|
||||
buffer<pair<unsigned, unsigned> > & get_idx_to_ir_range() { return m_idx_to_ir_range; }
|
||||
|
||||
expr mk_const(name const & n) const { return mk_constant(n, get_levels()); }
|
||||
expr mk_const_params(name const & n) const { return mk_app(mk_const(n), m_params); }
|
||||
expr get_c_ind(unsigned ind_idx) const { return mk_const(mlocal_name(m_inds[ind_idx])); }
|
||||
|
|
|
|||
|
|
@ -177,6 +177,16 @@ class add_mutual_inductive_decl_fn {
|
|||
m_basic_prefix = prefix;
|
||||
}
|
||||
|
||||
void compute_idx_to_ir_range() {
|
||||
unsigned offset = 0;
|
||||
for (unsigned ind_idx = 0; ind_idx < m_mut_decl.get_num_inds(); ++ind_idx) {
|
||||
unsigned num_irs = m_mut_decl.get_num_intro_rules(ind_idx);
|
||||
m_basic_decl.get_idx_to_ir_range().emplace_back(mk_pair(offset, num_irs));
|
||||
lean_trace(name({"inductive_compiler", "mutual", "range"}), tout() << ind_idx << " ==> (" << offset << ", " << num_irs << ")\n";);
|
||||
offset += num_irs;
|
||||
}
|
||||
}
|
||||
|
||||
void compute_new_ind() {
|
||||
expr ind = mk_local(m_basic_ind_name, mk_arrow(m_full_index_type, get_ind_result_type(m_tctx, m_mut_decl.get_ind(0))));
|
||||
lean_trace(name({"inductive_compiler", "mutual", "basic_ind"}), tout() << mlocal_name(ind) << " : " << mlocal_type(ind) << "\n";);
|
||||
|
|
@ -753,7 +763,7 @@ public:
|
|||
bool is_trusted):
|
||||
m_env(env), m_opts(opts), m_implicit_infer_map(implicit_infer_map),
|
||||
m_mut_decl(mut_decl), m_is_trusted(is_trusted),
|
||||
m_basic_decl(m_mut_decl.get_nest_depth() + 1, m_mut_decl.get_lp_names(), m_mut_decl.get_params()),
|
||||
m_basic_decl(m_mut_decl.get_nest_depth() + 1, m_mut_decl.get_lp_names(), m_mut_decl.get_params(), m_mut_decl.get_ir_offsets()),
|
||||
m_tctx(env, opts) {}
|
||||
|
||||
environment operator()() {
|
||||
|
|
@ -766,6 +776,8 @@ public:
|
|||
compute_new_ind();
|
||||
compute_new_intro_rules();
|
||||
|
||||
compute_idx_to_ir_range();
|
||||
|
||||
try {
|
||||
m_env = add_inner_inductive_declaration(m_env, m_opts, m_implicit_infer_map, m_basic_decl, m_is_trusted);
|
||||
} catch (exception & ex) {
|
||||
|
|
@ -799,6 +811,7 @@ void initialize_inductive_compiler_mutual() {
|
|||
register_trace_class(name({"inductive_compiler", "mutual", "new_inds"}));
|
||||
register_trace_class(name({"inductive_compiler", "mutual", "rec"}));
|
||||
register_trace_class(name({"inductive_compiler", "mutual", "sizeof"}));
|
||||
register_trace_class(name({"inductive_compiler", "mutual", "range"}));
|
||||
|
||||
g_mutual_suffix = new name("_mut_");
|
||||
}
|
||||
|
|
|
|||
|
|
@ -597,6 +597,7 @@ class add_nested_inductive_decl_fn {
|
|||
expr unpack_type(expr const & e) { return unpack_constants(unpack_nested_occs(e)); }
|
||||
|
||||
void construct_inner_decl() {
|
||||
unsigned offset = 0;
|
||||
// Construct inner inds for each of the nested inds
|
||||
for (unsigned ind_idx = 0; ind_idx < m_nested_decl.get_num_inds(); ++ind_idx) {
|
||||
expr const & ind = m_nested_decl.get_ind(ind_idx);
|
||||
|
|
@ -608,11 +609,12 @@ class add_nested_inductive_decl_fn {
|
|||
|
||||
m_inner_decl.get_intro_rules().emplace_back();
|
||||
for (expr const & ir : m_nested_decl.get_intro_rules(ind_idx)) {
|
||||
offset++;
|
||||
expr inner_ir = mk_local(mk_inner_name(mlocal_name(ir)), pack_type(mlocal_type(ir)));
|
||||
m_inner_decl.get_intro_rules().back().push_back(inner_ir);
|
||||
|
||||
lean_trace(name({"inductive_compiler", "nested", "inner", "ir"}),
|
||||
tout() << mlocal_name(inner_ir) << " : " << mlocal_type(inner_ir) << "\n";);
|
||||
lean_trace(name({"inductive_compiler", "nested", "inner", "ir"}),
|
||||
tout() << mlocal_name(inner_ir) << " : " << mlocal_type(inner_ir) << "\n";);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -636,8 +638,11 @@ class add_nested_inductive_decl_fn {
|
|||
expr c_mimic_ir = mk_app(mk_constant(ir, const_levels(nested_occ_fn)), nested_occ_params);
|
||||
expr mimic_ir = mk_local(mk_inner_name(ir), pack_type(m_tctx.infer(c_mimic_ir)));
|
||||
m_inner_decl.get_intro_rules().back().push_back(mimic_ir);
|
||||
m_inner_decl.get_ir_offsets().emplace_back(offset);
|
||||
lean_trace(name({"inductive_compiler", "nested", "mimic", "ir"}),
|
||||
tout() << mlocal_name(mimic_ir) << " : " << mlocal_type(mimic_ir) << "\n";);
|
||||
lean_trace(name({"inductive_compiler", "nested", "mimic", "ir", "offset"}),
|
||||
tout() << mlocal_name(mimic_ir) << " ==> " << offset << "\n";);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -1799,7 +1804,7 @@ public:
|
|||
ginductive_decl & nested_decl, bool is_trusted):
|
||||
m_env(env), m_opts(opts), m_implicit_infer_map(implicit_infer_map),
|
||||
m_nested_decl(nested_decl), m_is_trusted(is_trusted),
|
||||
m_inner_decl(m_nested_decl.get_nest_depth() + 1, m_nested_decl.get_lp_names(), m_nested_decl.get_params()),
|
||||
m_inner_decl(m_nested_decl.get_nest_depth() + 1, m_nested_decl.get_lp_names(), m_nested_decl.get_params(), m_nested_decl.get_ir_offsets()),
|
||||
m_tctx(env, opts, transparency_mode::Semireducible) { }
|
||||
|
||||
optional<environment> operator()() {
|
||||
|
|
@ -1841,6 +1846,7 @@ void initialize_inductive_compiler_nested() {
|
|||
register_trace_class(name({"inductive_compiler", "nested", "mimic"}));
|
||||
register_trace_class(name({"inductive_compiler", "nested", "mimic", "ind"}));
|
||||
register_trace_class(name({"inductive_compiler", "nested", "mimic", "ir"}));
|
||||
register_trace_class(name({"inductive_compiler", "nested", "mimic", "ir", "offset"}));
|
||||
|
||||
register_trace_class(name({"inductive_compiler", "nested", "inner"}));
|
||||
register_trace_class(name({"inductive_compiler", "nested", "inner", "ind"}));
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue