refactor(library/vm): use global indices for declarations and cases

This commit is contained in:
Gabriel Ebner 2016-12-04 08:49:28 -05:00 committed by Leonardo de Moura
parent 4c01cb503c
commit 41e8a1712e
2 changed files with 137 additions and 124 deletions

View file

@ -16,6 +16,7 @@ Author: Leonardo de Moura
#include "util/sstream.h"
#include "util/small_object_allocator.h"
#include "util/sexpr/option_declarations.h"
#include "util/shared_mutex.h"
#include "library/constants.h"
#include "library/kernel_serializer.h"
#include "library/trace.h"
@ -272,14 +273,14 @@ void vm_obj_cell::dealloc() {
}
}
void display(std::ostream & out, vm_obj const & o, std::function<optional<name>(unsigned)> const & idx2name) {
void display(std::ostream & out, vm_obj const & o) {
if (is_simple(o)) {
out << cidx(o);
} else if (is_constructor(o)) {
out << "(#" << cidx(o);
for (unsigned i = 0; i < csize(o); i++) {
out << " ";
display(out, cfield(o, i), idx2name);
display(out, cfield(o, i));
}
out << ")";
} else if (is_mpz(o)) {
@ -287,14 +288,14 @@ void display(std::ostream & out, vm_obj const & o, std::function<optional<name>(
} else if (is_external(o)) {
out << "[external]";
} else if (is_closure(o)) {
if (auto n = idx2name(cfn_idx(o))) {
if (auto n = find_vm_name(cfn_idx(o))) {
out << "(" << *n;
} else {
out << "(fn#" << cfn_idx(o);
}
for (unsigned i = 0; i < csize(o); i++) {
out << " ";
display(out, cfield(o, i), idx2name);
display(out, cfield(o, i));
}
out << ")";
} else if (is_native_closure(o)) {
@ -302,7 +303,7 @@ void display(std::ostream & out, vm_obj const & o, std::function<optional<name>(
vm_obj const * args = to_native_closure(o)->get_args();
for (unsigned i = 0; i < to_native_closure(o)->get_num_args(); i++) {
out << " ";
display(out, args[i], idx2name);
display(out, args[i]);
}
out << ")";
} else {
@ -310,24 +311,18 @@ void display(std::ostream & out, vm_obj const & o, std::function<optional<name>(
}
}
void display(std::ostream & out, vm_obj const & o) {
display(out, o, [](unsigned) { return optional<name>(); });
}
static void display_fn(std::ostream & out, std::function<optional<name>(unsigned)> const & idx2name, unsigned fn_idx) {
if (auto r = idx2name(fn_idx))
static void display_fn(std::ostream & out, unsigned fn_idx) {
if (auto r = find_vm_name(fn_idx))
out << *r;
else
out << fn_idx;
}
static void display_builtin_cases(std::ostream & out, std::function<optional<name>(unsigned)> const & idx2name, unsigned cases_idx) {
display_fn(out, idx2name, cases_idx);
static void display_builtin_cases(std::ostream & out, unsigned cases_idx) {
display_fn(out, cases_idx);
}
void vm_instr::display(std::ostream & out,
std::function<optional<name>(unsigned)> const & idx2name,
std::function<optional<name>(unsigned)> const & cases_idx2name) const {
void vm_instr::display(std::ostream & out) const {
switch (m_op) {
case opcode::Push: out << "push " << m_idx; break;
case opcode::Ret: out << "ret"; break;
@ -346,7 +341,7 @@ void vm_instr::display(std::ostream & out,
break;
case opcode::BuiltinCases:
out << "builtin_cases ";
display_builtin_cases(out, cases_idx2name, get_cases_idx());
display_builtin_cases(out, get_cases_idx());
out << ",";
for (unsigned i = 0; i < get_casesn_size(); i++)
out << " " << get_casesn_pc(i);
@ -356,19 +351,19 @@ void vm_instr::display(std::ostream & out,
case opcode::Apply: out << "apply"; break;
case opcode::InvokeGlobal:
out << "ginvoke ";
display_fn(out, idx2name, m_fn_idx);
display_fn(out, m_fn_idx);
break;
case opcode::InvokeBuiltin:
out << "builtin ";
display_fn(out, idx2name, m_fn_idx);
display_fn(out, m_fn_idx);
break;
case opcode::InvokeCFun:
out << "cfun ";
display_fn(out, idx2name, m_fn_idx);
display_fn(out, m_fn_idx);
break;
case opcode::Closure:
out << "closure ";
display_fn(out, idx2name, m_fn_idx);
display_fn(out, m_fn_idx);
out << " " << m_nargs;
break;
case opcode::Pexpr:
@ -378,10 +373,6 @@ void vm_instr::display(std::ostream & out,
}
}
void vm_instr::display(std::ostream & out) const {
display(out, [](unsigned) { return optional<name>(); }, [](unsigned) { return optional<name>(); });
}
unsigned vm_instr::get_num_pcs() const {
switch (m_op) {
case opcode::Goto:
@ -723,13 +714,10 @@ void vm_instr::serialize(serializer & s, std::function<name(unsigned)> const & i
}
}
static unsigned read_fn_idx(deserializer & d, name_map<unsigned> const & name2idx) {
static unsigned read_fn_idx(deserializer & d) {
name n;
d >> n;
if (auto r = name2idx.find(n))
return *r;
else
throw corrupted_stream_exception();
return get_vm_index(n);
}
static void read_cases_pcs(deserializer & d, buffer<unsigned> & pcs) {
@ -738,18 +726,18 @@ static void read_cases_pcs(deserializer & d, buffer<unsigned> & pcs) {
pcs.push_back(d.read_unsigned());
}
static vm_instr read_vm_instr(deserializer & d, name_map<unsigned> const & name2idx) {
static vm_instr read_vm_instr(deserializer & d) {
opcode op = static_cast<opcode>(d.read_char());
unsigned pc, idx;
switch (op) {
case opcode::InvokeGlobal:
return mk_invoke_global_instr(read_fn_idx(d, name2idx));
return mk_invoke_global_instr(read_fn_idx(d));
case opcode::InvokeBuiltin:
return mk_invoke_builtin_instr(read_fn_idx(d, name2idx));
return mk_invoke_builtin_instr(read_fn_idx(d));
case opcode::InvokeCFun:
return mk_invoke_cfun_instr(read_fn_idx(d, name2idx));
return mk_invoke_cfun_instr(read_fn_idx(d));
case opcode::Closure:
idx = read_fn_idx(d, name2idx);
idx = read_fn_idx(d);
return mk_closure_instr(idx, d.read_unsigned());
case opcode::Push:
return mk_push_instr(d.read_unsigned());
@ -906,58 +894,40 @@ void declare_vm_cases_builtin(name const & n, char const * i, vm_cases_function
/** \brief VM function/constant declarations are stored in an environment extension. */
struct vm_decls : public environment_extension {
name_map<unsigned> m_name2idx;
unsigned_map<vm_decl> m_decls;
unsigned m_next_decl_idx{0};
name_map<unsigned> m_cases2idx;
unsigned_map<vm_decl> m_decls;
unsigned_map<vm_cases_function> m_cases;
unsigned_map<name> m_cases_names;
unsigned m_next_cases_idx{0};
name m_monitor;
vm_decls() {
g_vm_builtins->for_each([&](name const & n, std::tuple<unsigned, char const *, vm_function> const & p) {
add_core(vm_decl(n, m_next_decl_idx, std::get<0>(p), std::get<2>(p)));
m_next_decl_idx++;
add_core(vm_decl(n, get_vm_index(n), std::get<0>(p), std::get<2>(p)));
});
g_vm_cbuiltins->for_each([&](name const & n, std::tuple<unsigned, char const *, vm_cfunction> const & p) {
add_core(vm_decl(n, m_next_decl_idx, std::get<0>(p), std::get<2>(p)));
m_next_decl_idx++;
add_core(vm_decl(n, get_vm_index(n), std::get<0>(p), std::get<2>(p)));
});
g_vm_cases_builtins->for_each([&](name const & n, std::tuple<char const *, vm_cases_function> const & p) {
unsigned idx = m_next_cases_idx;
m_cases2idx.insert(n, idx);
unsigned idx = get_vm_index(n);
m_cases.insert(idx, std::get<1>(p));
m_cases_names.insert(idx, n);
m_next_cases_idx++;
});
}
void add_core(vm_decl const & d) {
if (m_name2idx.contains(d.get_name()))
if (m_decls.contains(d.get_idx()))
throw exception(sstream() << "VM already contains code for '" << d.get_name() << "'");
m_name2idx.insert(d.get_name(), d.get_idx());
m_decls.insert(d.get_idx(), d);
}
void add_native(name const & n, unsigned arity, vm_cfunction fn) {
if (auto idx = m_name2idx.find(n)) {
lean_assert(m_decls.find(*idx)->get_arity() == arity);
m_decls.insert(*idx, vm_decl(n, *idx, arity, fn));
} else {
add_core(vm_decl(n, m_next_decl_idx, arity, fn));
m_next_decl_idx++;
}
auto idx = get_vm_index(n);
DEBUG_CODE(if (auto decl = m_decls.find(idx)) lean_assert(decl->get_arity() == arity);)
m_decls.insert(idx, vm_decl(n, idx, arity, fn));
}
unsigned reserve(name const & n, expr const & e) {
if (m_name2idx.contains(n))
unsigned idx = get_vm_index(n);
if (m_decls.contains(idx))
throw exception(sstream() << "VM already contains code for '" << n << "'");
unsigned idx = m_next_decl_idx;
m_next_decl_idx++;
m_name2idx.insert(n, idx);
m_decls.insert(idx, vm_decl(n, idx, e, 0, nullptr, list<vm_local_info>(), optional<pos_info>()));
return idx;
}
@ -965,9 +935,9 @@ struct vm_decls : public environment_extension {
void update(name const & n, unsigned code_sz, vm_instr const * code,
list<vm_local_info> const & args_info, optional<pos_info> const & pos,
optional<std::string> const & olean = optional<std::string>()) {
lean_assert(m_name2idx.contains(n));
unsigned idx = *m_name2idx.find(n);
unsigned idx = get_vm_index(n);
vm_decl const * d = m_decls.find(idx);
lean_assert(d);
m_decls.insert(idx, vm_decl(n, idx, d->get_expr(), code_sz, code, args_info, pos, olean));
}
};
@ -1037,21 +1007,23 @@ environment add_native(environment const & env, name const & n, unsigned arity,
bool is_vm_function(environment const & env, name const & fn) {
auto const & ext = get_extension(env);
return ext.m_name2idx.contains(fn) || g_vm_builtins->contains(fn);
return ext.m_decls.contains(get_vm_index(fn)) || g_vm_builtins->contains(fn);
}
optional<unsigned> get_vm_constant_idx(environment const & env, name const & n) {
auto const & ext = get_extension(env);
if (auto r = ext.m_name2idx.find(n))
return optional<unsigned>(*r);
auto idx = get_vm_index(n);
if (ext.m_decls.contains(idx))
return optional<unsigned>(idx);
else
return optional<unsigned>();
}
optional<unsigned> get_vm_builtin_idx(name const & n) {
lean_assert(g_ext);
if (auto r = g_ext->m_init_decls->m_name2idx.find(n))
return optional<unsigned>(*r);
auto idx = get_vm_index(n);
if (g_ext->m_init_decls->m_decls.contains(idx))
return optional<unsigned>(idx);
else
return optional<unsigned>();
}
@ -1094,7 +1066,7 @@ static void code_reader(deserializer & d, environment & env) {
vm_decls ext = get_extension(env);
buffer<vm_instr> code;
for (unsigned i = 0; i < code_sz; i++) {
code.push_back(read_vm_instr(d, ext.m_name2idx));
code.push_back(read_vm_instr(d));
}
ext.update(fn, code_sz, code.data(), args_info, pos, d.get_fname());
env = update(env, ext);
@ -1105,7 +1077,7 @@ environment update_vm_code(environment const & env, name const & fn, unsigned co
vm_decls ext = get_extension(env);
ext.update(fn, code_sz, code, args_info, pos);
environment new_env = update(env, ext);
unsigned fidx = *ext.m_name2idx.find(fn);
unsigned fidx = get_vm_index(fn);
return module::add(new_env, *g_vm_code_key, [=](environment const & env, serializer & s) {
serialize_code(s, fidx, get_extension(env).m_decls);
});
@ -1119,16 +1091,17 @@ environment add_vm_code(environment const & env, name const & fn, expr const & e
optional<vm_decl> get_vm_decl(environment const & env, name const & n) {
vm_decls const & ext = get_extension(env);
if (auto idx = ext.m_name2idx.find(n))
return optional<vm_decl>(*ext.m_decls.find(*idx));
if (auto decl = ext.m_decls.find(get_vm_index(n)))
return optional<vm_decl>(*decl);
else
return optional<vm_decl>();
}
optional<unsigned> get_vm_builtin_cases_idx(environment const & env, name const & n) {
vm_decls const & ext = get_extension(env);
if (auto idx = ext.m_cases2idx.find(n))
return optional<unsigned>(*idx);
auto idx = get_vm_index(n);
if (ext.m_cases.contains(idx))
return optional<unsigned>(idx);
else
return optional<unsigned>();
}
@ -1149,11 +1122,9 @@ vm_state::vm_state(environment const & env, options const & opts):
m_env(env),
m_options(opts),
m_decl_map(get_extension(m_env).m_decls),
m_decl_vector(get_extension(m_env).m_next_decl_idx),
m_decl_vector(get_vm_index_bound()),
m_builtin_cases_map(get_extension(m_env).m_cases),
m_builtin_cases_vector(get_extension(m_env).m_next_cases_idx),
m_builtin_cases_names(get_extension(m_env).m_cases_names),
m_fn_name2idx(get_extension(m_env).m_name2idx),
m_builtin_cases_vector(get_vm_index_bound()),
m_code(nullptr),
m_fn_idx(g_null_fn_idx),
m_bp(0) {
@ -1220,9 +1191,7 @@ void vm_state::update_env(environment const & env) {
m_env = env;
auto ext = get_extension(env);
m_decl_map = ext.m_decls;
lean_assert(ext.m_next_decl_idx >= m_decl_vector.size());
m_decl_vector.resize(ext.m_next_decl_idx);
m_fn_name2idx = ext.m_name2idx;
m_decl_vector.resize(get_vm_index_bound());
lean_assert(is_eqp(m_builtin_cases_map, ext.m_cases));
}
@ -1404,8 +1373,9 @@ vm_obj vm_state::invoke(unsigned fn_idx, unsigned nargs, vm_obj const * as) {
}
vm_obj vm_state::invoke(name const & fn, unsigned nargs, vm_obj const * as) {
if (auto r = m_fn_name2idx.find(fn)) {
return invoke(*r, nargs, as);
auto idx = get_vm_index(fn);
if (m_decl_map.contains(idx)) {
return invoke(idx, nargs, as);
} else {
throw exception(sstream() << "VM does not have code for '" << fn << "'");
}
@ -2524,13 +2494,7 @@ void vm_state::run() {
/* We only trace VM in debug mode */
lean_trace(name({"vm", "run"}),
tout() << m_pc << ": ";
instr.display(tout().get_stream(),
[&](unsigned idx) {
return optional<name>(get_decl(idx).get_name());
},
[&](unsigned idx) {
return optional<name>(*m_builtin_cases_names.find(idx));
});
instr.display(tout().get_stream());
tout() << "\n";
display_stack(tout().get_stream());
tout() << "\n";)
@ -2965,8 +2929,9 @@ void vm_state::run() {
}
void vm_state::invoke_fn(name const & fn) {
if (auto r = m_fn_name2idx.find(fn)) {
invoke_fn(*r);
auto idx = get_vm_index(fn);
if (m_decl_map.contains(idx)) {
invoke_fn(idx);
} else {
throw exception(sstream() << "VM does not have code for '" << fn << "'");
}
@ -2982,8 +2947,9 @@ void vm_state::invoke_fn(unsigned fn_idx) {
}
vm_obj vm_state::get_constant(name const & cname) {
if (auto fn_idx = m_fn_name2idx.find(cname)) {
vm_decl d = get_decl(*fn_idx);
auto fn_idx = get_vm_index(cname);
if (m_decl_map.contains(fn_idx)) {
vm_decl d = get_decl(fn_idx);
if (d.get_arity() == 0) {
DEBUG_CODE(unsigned stack_sz = m_stack.size(););
unsigned saved_pc = m_pc;
@ -2995,7 +2961,7 @@ vm_obj vm_state::get_constant(name const & cname) {
lean_assert(m_stack.size() == stack_sz);
return r;
} else {
return mk_vm_closure(*fn_idx, 0, nullptr);
return mk_vm_closure(fn_idx, 0, nullptr);
}
} else {
throw exception(sstream() << "VM does not have code for '" << cname << "'");
@ -3019,13 +2985,13 @@ void vm_state::apply(unsigned n) {
}
void vm_state::display(std::ostream & out, vm_obj const & o) const {
::lean::display(out, o,
[&](unsigned idx) { return optional<name>(get_decl(idx).get_name()); });
::lean::display(out, o);
}
optional<vm_decl> vm_state::get_decl(name const & n) const {
if (auto idx = m_fn_name2idx.find(n))
return optional<vm_decl>(get_decl(*idx));
auto idx = get_vm_index(n);
if (m_decl_map.contains(idx))
return optional<vm_decl>(get_decl(idx));
else
return optional<vm_decl>();
}
@ -3177,25 +3143,9 @@ void vm_state::profiler::snapshots::display(std::ostream & out) const {
}
void display_vm_code(std::ostream & out, environment const & env, unsigned code_sz, vm_instr const * code) {
vm_decls const & ext = get_extension(env);
auto idx2name = [&](unsigned idx) {
if (idx < ext.m_decls.size()) {
return optional<name>(ext.m_decls.find(idx)->get_name());
} else {
return optional<name>();
}
};
auto cases2name = [&](unsigned idx) {
if (idx < ext.m_cases_names.size()) {
return optional<name>(*ext.m_cases_names.find(idx));
} else {
return optional<name>();
}
};
for (unsigned i = 0; i < code_sz; i++) {
out << i << ": ";
code[i].display(out, idx2name, cases2name);
code[i].display(out);
out << "\n";
}
}
@ -3295,7 +3245,70 @@ static void vm_monitor_reader(deserializer & d, environment & env) {
env = update(env, ext);
}
class vm_index_manager {
shared_mutex m_mutex;
std::unordered_map<name, unsigned, name_hash> m_name2idx;
std::vector<name> m_idx2name;
public:
unsigned get_index(name const & n) {
{
shared_lock lock(m_mutex);
auto it = m_name2idx.find(n);
if (it != m_name2idx.end())
return it->second;
}
{
exclusive_lock lock(m_mutex);
auto it = m_name2idx.find(n);
if (it != m_name2idx.end()) {
return it->second;
} else {
auto i = static_cast<unsigned>(m_idx2name.size());
m_idx2name.push_back(n);
m_name2idx[n] = i;
return i;
}
}
}
unsigned get_index_bound() {
shared_lock _(m_mutex);
return static_cast<unsigned>(m_idx2name.size());
}
name const & get_name(unsigned idx) {
shared_lock lock(m_mutex);
lean_assert(idx < m_idx2name.size());
return m_idx2name.at(idx);
}
optional<name> find_name(unsigned idx) {
shared_lock lock(m_mutex);
if (idx < m_idx2name.size()) {
return optional<name>(m_idx2name.at(idx));
} else {
return optional<name>();
}
}
};
static vm_index_manager * g_vm_index_manager = nullptr;
/** \brief Return the global VM index of \c n. The lookup is forwarded to the
    index manager, which allocates a fresh index for names it has not seen. */
unsigned get_vm_index(name const & n) {
    vm_index_manager & manager = *g_vm_index_manager;
    return manager.get_index(n);
}
unsigned get_vm_index_bound() {
return g_vm_index_manager->get_index_bound();
}
/** \brief Return the name that owns the global VM index \c idx. The index
    must already have been allocated; the index manager asserts this. */
name const & get_vm_name(unsigned idx) {
    vm_index_manager & manager = *g_vm_index_manager;
    return manager.get_name(idx);
}
/** \brief Return the name that owns the global VM index \c idx, or none when
    the index manager has not allocated \c idx. */
optional<name> find_vm_name(unsigned idx) {
    vm_index_manager & manager = *g_vm_index_manager;
    return manager.find_name(idx);
}
void initialize_vm_core() {
g_vm_index_manager = new vm_index_manager;
g_vm_builtins = new name_map<std::tuple<unsigned, char const *, vm_function>>();
g_vm_cbuiltins = new name_map<std::tuple<unsigned, char const *, vm_cfunction>>();
g_vm_cases_builtins = new name_map<std::tuple<char const *, vm_cases_function>>();
@ -3311,6 +3324,7 @@ void finalize_vm_core() {
delete g_vm_builtins;
delete g_vm_cbuiltins;
delete g_vm_cases_builtins;
delete g_vm_index_manager;
}
void initialize_vm() {

View file

@ -49,7 +49,6 @@ public:
#define LEAN_VM_BOX(num) (reinterpret_cast<vm_obj_cell*>((num << 1) | 1))
#define LEAN_VM_UNBOX(obj) (reinterpret_cast<size_t>(obj) >> 1)
void display(std::ostream & out, vm_obj const & o, std::function<optional<name>(unsigned)> const & idx2name);
void display(std::ostream & out, vm_obj const & o);
/** \brief VM object */
@ -454,9 +453,6 @@ public:
unsigned get_pc(unsigned i) const;
void set_pc(unsigned i, unsigned pc);
void display(std::ostream & out,
std::function<optional<name>(unsigned)> const & idx2name,
std::function<optional<name>(unsigned)> const & cases_idx2name) const;
void display(std::ostream & out) const;
void serialize(serializer & s, std::function<name(unsigned)> const & idx2name) const;
@ -573,8 +569,6 @@ class vm_state {
cache_vector m_cache_vector; /* for 0-ary declarations */
builtin_cases_map m_builtin_cases_map;
builtin_cases_vector m_builtin_cases_vector;
unsigned_map<name> m_builtin_cases_names;
name_map<unsigned> m_fn_name2idx;
vm_instr const * m_code; /* code of the current function being executed */
unsigned m_fn_idx; /* function idx being executed */
unsigned m_pc; /* program counter */
@ -832,6 +826,11 @@ environment add_native(environment const & env, name const & n, vm_cfunction_7 f
environment add_native(environment const & env, name const & n, vm_cfunction_8 fn);
environment add_native(environment const & env, name const & n, unsigned arity, vm_cfunction_N fn);
/** \brief Return the global VM index associated with \c n.
    If \c n has no index yet, the next unused index is allocated for it, so
    this function always succeeds and repeated calls with the same name
    return the same value. */
unsigned get_vm_index(name const & n);
/** \brief Return the number of global VM indices allocated so far.
    Every index returned by \c get_vm_index up to this point is strictly
    smaller than the result. */
unsigned get_vm_index_bound();
/** \brief Return the name that was assigned the global VM index \c idx.
    \pre idx < get_vm_index_bound() */
name const & get_vm_name(unsigned idx);
/** \brief Return the name that was assigned the global VM index \c idx,
    or none if no name has been assigned that index. */
optional<name> find_vm_name(unsigned idx);
/** \brief Reserve an index for the given function in the VM, the expression
\c e is the value of \c fn after preprocessing.
See library/compiler/pre_proprocess_rec.cpp for details. */