From 41e8a1712e0a0ae69284a73767829263e92f3a10 Mon Sep 17 00:00:00 2001 From: Gabriel Ebner Date: Sun, 4 Dec 2016 08:49:28 -0500 Subject: [PATCH] refactor(library/vm): use global indices for declarations and cases --- src/library/vm/vm.cpp | 250 ++++++++++++++++++++++-------------------- src/library/vm/vm.h | 11 +- 2 files changed, 137 insertions(+), 124 deletions(-) diff --git a/src/library/vm/vm.cpp b/src/library/vm/vm.cpp index b4f3fbc9d1..e63ce414bb 100644 --- a/src/library/vm/vm.cpp +++ b/src/library/vm/vm.cpp @@ -16,6 +16,7 @@ Author: Leonardo de Moura #include "util/sstream.h" #include "util/small_object_allocator.h" #include "util/sexpr/option_declarations.h" +#include "util/shared_mutex.h" #include "library/constants.h" #include "library/kernel_serializer.h" #include "library/trace.h" @@ -272,14 +273,14 @@ void vm_obj_cell::dealloc() { } } -void display(std::ostream & out, vm_obj const & o, std::function(unsigned)> const & idx2name) { +void display(std::ostream & out, vm_obj const & o) { if (is_simple(o)) { out << cidx(o); } else if (is_constructor(o)) { out << "(#" << cidx(o); for (unsigned i = 0; i < csize(o); i++) { out << " "; - display(out, cfield(o, i), idx2name); + display(out, cfield(o, i)); } out << ")"; } else if (is_mpz(o)) { @@ -287,14 +288,14 @@ void display(std::ostream & out, vm_obj const & o, std::function( } else if (is_external(o)) { out << "[external]"; } else if (is_closure(o)) { - if (auto n = idx2name(cfn_idx(o))) { + if (auto n = find_vm_name(cfn_idx(o))) { out << "(" << *n; } else { out << "(fn#" << cfn_idx(o); } for (unsigned i = 0; i < csize(o); i++) { out << " "; - display(out, cfield(o, i), idx2name); + display(out, cfield(o, i)); } out << ")"; } else if (is_native_closure(o)) { @@ -302,7 +303,7 @@ void display(std::ostream & out, vm_obj const & o, std::function( vm_obj const * args = to_native_closure(o)->get_args(); for (unsigned i = 0; i < to_native_closure(o)->get_num_args(); i++) { out << " "; - display(out, args[i], idx2name); + display(out, args[i]); } out << ")"; } else { @@ -310,24 +311,18 @@ void display(std::ostream & out, vm_obj const & o, std::function( } } -void display(std::ostream & out, vm_obj const & o) { - display(out, o, [](unsigned) { return optional(); }); -} - -static void display_fn(std::ostream & out, std::function(unsigned)> const & idx2name, unsigned fn_idx) { - if (auto r = idx2name(fn_idx)) +static void display_fn(std::ostream & out, unsigned fn_idx) { + if (auto r = find_vm_name(fn_idx)) out << *r; else out << fn_idx; } -static void display_builtin_cases(std::ostream & out, std::function(unsigned)> const & idx2name, unsigned cases_idx) { - display_fn(out, idx2name, cases_idx); +static void display_builtin_cases(std::ostream & out, unsigned cases_idx) { + display_fn(out, cases_idx); } -void vm_instr::display(std::ostream & out, - std::function(unsigned)> const & idx2name, - std::function(unsigned)> const & cases_idx2name) const { +void vm_instr::display(std::ostream & out) const { switch (m_op) { case opcode::Push: out << "push " << m_idx; break; case opcode::Ret: out << "ret"; break; @@ -346,7 +341,7 @@ void vm_instr::display(std::ostream & out, break; case opcode::BuiltinCases: out << "builtin_cases "; - display_builtin_cases(out, cases_idx2name, get_cases_idx()); + display_builtin_cases(out, get_cases_idx()); out << ","; for (unsigned i = 0; i < get_casesn_size(); i++) out << " " << get_casesn_pc(i); @@ -356,19 +351,19 @@ void vm_instr::display(std::ostream & out, case opcode::Apply: out << "apply"; break; case opcode::InvokeGlobal: out << "ginvoke "; - display_fn(out, idx2name, m_fn_idx); + display_fn(out, m_fn_idx); break; case opcode::InvokeBuiltin: out << "builtin "; - display_fn(out, idx2name, m_fn_idx); + display_fn(out, m_fn_idx); break; case opcode::InvokeCFun: out << "cfun "; - display_fn(out, idx2name, m_fn_idx); + display_fn(out, m_fn_idx); break; case opcode::Closure: out << "closure "; - display_fn(out, idx2name, m_fn_idx); + display_fn(out, m_fn_idx); out << " " << m_nargs; break; case opcode::Pexpr: @@ -378,10 +373,6 @@ void vm_instr::display(std::ostream & out, } } -void vm_instr::display(std::ostream & out) const { - display(out, [](unsigned) { return optional(); }, [](unsigned) { return optional(); }); -} - unsigned vm_instr::get_num_pcs() const { switch (m_op) { case opcode::Goto: @@ -723,13 +714,10 @@ void vm_instr::serialize(serializer & s, std::function const & i } } -static unsigned read_fn_idx(deserializer & d, name_map const & name2idx) { +static unsigned read_fn_idx(deserializer & d) { name n; d >> n; - if (auto r = name2idx.find(n)) - return *r; - else - throw corrupted_stream_exception(); + return get_vm_index(n); } static void read_cases_pcs(deserializer & d, buffer & pcs) { @@ -738,18 +726,18 @@ static void read_cases_pcs(deserializer & d, buffer & pcs) { pcs.push_back(d.read_unsigned()); } -static vm_instr read_vm_instr(deserializer & d, name_map const & name2idx) { +static vm_instr read_vm_instr(deserializer & d) { opcode op = static_cast(d.read_char()); unsigned pc, idx; switch (op) { case opcode::InvokeGlobal: - return mk_invoke_global_instr(read_fn_idx(d, name2idx)); + return mk_invoke_global_instr(read_fn_idx(d)); case opcode::InvokeBuiltin: - return mk_invoke_builtin_instr(read_fn_idx(d, name2idx)); + return mk_invoke_builtin_instr(read_fn_idx(d)); case opcode::InvokeCFun: - return mk_invoke_cfun_instr(read_fn_idx(d, name2idx)); + return mk_invoke_cfun_instr(read_fn_idx(d)); case opcode::Closure: - idx = read_fn_idx(d, name2idx); + idx = read_fn_idx(d); return mk_closure_instr(idx, d.read_unsigned()); case opcode::Push: return mk_push_instr(d.read_unsigned()); @@ -906,58 +894,40 @@ void declare_vm_cases_builtin(name const & n, char const * i, vm_cases_function /** \brief VM function/constant declarations are stored in an environment extension. */ struct vm_decls : public environment_extension { - name_map m_name2idx; - unsigned_map m_decls; - unsigned m_next_decl_idx{0}; - - name_map m_cases2idx; + unsigned_map m_decls; unsigned_map m_cases; - unsigned_map m_cases_names; - unsigned m_next_cases_idx{0}; name m_monitor; vm_decls() { g_vm_builtins->for_each([&](name const & n, std::tuple const & p) { - add_core(vm_decl(n, m_next_decl_idx, std::get<0>(p), std::get<2>(p))); - m_next_decl_idx++; + add_core(vm_decl(n, get_vm_index(n), std::get<0>(p), std::get<2>(p))); }); g_vm_cbuiltins->for_each([&](name const & n, std::tuple const & p) { - add_core(vm_decl(n, m_next_decl_idx, std::get<0>(p), std::get<2>(p))); - m_next_decl_idx++; + add_core(vm_decl(n, get_vm_index(n), std::get<0>(p), std::get<2>(p))); }); g_vm_cases_builtins->for_each([&](name const & n, std::tuple const & p) { - unsigned idx = m_next_cases_idx; - m_cases2idx.insert(n, idx); + unsigned idx = get_vm_index(n); m_cases.insert(idx, std::get<1>(p)); - m_cases_names.insert(idx, n); - m_next_cases_idx++; }); } void add_core(vm_decl const & d) { - if (m_name2idx.contains(d.get_name())) + if (m_decls.contains(d.get_idx())) throw exception(sstream() << "VM already contains code for '" << d.get_name() << "'"); - m_name2idx.insert(d.get_name(), d.get_idx()); m_decls.insert(d.get_idx(), d); } void add_native(name const & n, unsigned arity, vm_cfunction fn) { - if (auto idx = m_name2idx.find(n)) { - lean_assert(m_decls.find(*idx)->get_arity() == arity); - m_decls.insert(*idx, vm_decl(n, *idx, arity, fn)); - } else { - add_core(vm_decl(n, m_next_decl_idx, arity, fn)); - m_next_decl_idx++; - } + auto idx = get_vm_index(n); + DEBUG_CODE(if (auto decl = m_decls.find(idx)) lean_assert(decl->get_arity() == arity);) + m_decls.insert(idx, vm_decl(n, idx, arity, fn)); } unsigned reserve(name const & n, expr const & e) { - if (m_name2idx.contains(n)) + unsigned idx = get_vm_index(n); + if (m_decls.contains(idx)) throw exception(sstream() << "VM already contains code for '" << n << "'"); - unsigned idx = m_next_decl_idx; - m_next_decl_idx++; - m_name2idx.insert(n, idx); m_decls.insert(idx, vm_decl(n, idx, e, 0, nullptr, list(), optional())); return idx; } @@ -965,9 +935,9 @@ struct vm_decls : public environment_extension { void update(name const & n, unsigned code_sz, vm_instr const * code, list const & args_info, optional const & pos, optional const & olean = optional()) { - lean_assert(m_name2idx.contains(n)); - unsigned idx = *m_name2idx.find(n); + unsigned idx = get_vm_index(n); vm_decl const * d = m_decls.find(idx); + lean_assert(d); m_decls.insert(idx, vm_decl(n, idx, d->get_expr(), code_sz, code, args_info, pos, olean)); } }; @@ -1037,21 +1007,23 @@ environment add_native(environment const & env, name const & n, unsigned arity, bool is_vm_function(environment const & env, name const & fn) { auto const & ext = get_extension(env); - return ext.m_name2idx.contains(fn) || g_vm_builtins->contains(fn); + return ext.m_decls.contains(get_vm_index(fn)) || g_vm_builtins->contains(fn); } optional get_vm_constant_idx(environment const & env, name const & n) { auto const & ext = get_extension(env); - if (auto r = ext.m_name2idx.find(n)) - return optional(*r); + auto idx = get_vm_index(n); + if (ext.m_decls.contains(idx)) + return optional(idx); else return optional(); } optional get_vm_builtin_idx(name const & n) { lean_assert(g_ext); - if (auto r = g_ext->m_init_decls->m_name2idx.find(n)) - return optional(*r); + auto idx = get_vm_index(n); + if (g_ext->m_init_decls->m_decls.contains(idx)) + return optional(idx); else return optional(); } @@ -1094,7 +1066,7 @@ static void code_reader(deserializer & d, environment & env) { vm_decls ext = get_extension(env); buffer code; for (unsigned i = 0; i < code_sz; i++) { - code.push_back(read_vm_instr(d, ext.m_name2idx)); + code.push_back(read_vm_instr(d)); } ext.update(fn, code_sz, code.data(), args_info, pos, d.get_fname()); env = update(env, ext); @@ -1105,7 +1077,7 @@ environment update_vm_code(environment const & env, name const & fn, unsigned co vm_decls ext = get_extension(env); ext.update(fn, code_sz, code, args_info, pos); environment new_env = update(env, ext); - unsigned fidx = *ext.m_name2idx.find(fn); + unsigned fidx = get_vm_index(fn); return module::add(new_env, *g_vm_code_key, [=](environment const & env, serializer & s) { serialize_code(s, fidx, get_extension(env).m_decls); }); @@ -1119,16 +1091,17 @@ environment add_vm_code(environment const & env, name const & fn, expr const & e optional get_vm_decl(environment const & env, name const & n) { vm_decls const & ext = get_extension(env); - if (auto idx = ext.m_name2idx.find(n)) - return optional(*ext.m_decls.find(*idx)); + if (auto decl = ext.m_decls.find(get_vm_index(n))) + return optional(*decl); else return optional(); } optional get_vm_builtin_cases_idx(environment const & env, name const & n) { vm_decls const & ext = get_extension(env); - if (auto idx = ext.m_cases2idx.find(n)) - return optional(*idx); + auto idx = get_vm_index(n); + if (ext.m_cases.contains(idx)) + return optional(idx); else return optional(); } @@ -1149,11 +1122,9 @@ vm_state::vm_state(environment const & env, options const & opts): m_env(env), m_options(opts), m_decl_map(get_extension(m_env).m_decls), - m_decl_vector(get_extension(m_env).m_next_decl_idx), + m_decl_vector(get_vm_index_bound()), m_builtin_cases_map(get_extension(m_env).m_cases), - m_builtin_cases_vector(get_extension(m_env).m_next_cases_idx), - m_builtin_cases_names(get_extension(m_env).m_cases_names), - m_fn_name2idx(get_extension(m_env).m_name2idx), + m_builtin_cases_vector(get_vm_index_bound()), m_code(nullptr), m_fn_idx(g_null_fn_idx), m_bp(0) { @@ -1220,9 +1191,7 @@ void vm_state::update_env(environment const & env) { m_env = env; auto ext = get_extension(env); m_decl_map = ext.m_decls; - lean_assert(ext.m_next_decl_idx >= m_decl_vector.size()); - m_decl_vector.resize(ext.m_next_decl_idx); - m_fn_name2idx = ext.m_name2idx; + m_decl_vector.resize(get_vm_index_bound()); lean_assert(is_eqp(m_builtin_cases_map, ext.m_cases)); } @@ -1404,8 +1373,9 @@ vm_obj vm_state::invoke(unsigned fn_idx, unsigned nargs, vm_obj const * as) { } vm_obj vm_state::invoke(name const & fn, unsigned nargs, vm_obj const * as) { - if (auto r = m_fn_name2idx.find(fn)) { - return invoke(*r, nargs, as); + auto idx = get_vm_index(fn); + if (m_decl_map.contains(idx)) { + return invoke(idx, nargs, as); } else { throw exception(sstream() << "VM does not have code for '" << fn << "'"); } @@ -2524,13 +2494,7 @@ void vm_state::run() { /* We only trace VM in debug mode */ lean_trace(name({"vm", "run"}), tout() << m_pc << ": "; - instr.display(tout().get_stream(), - [&](unsigned idx) { - return optional(get_decl(idx).get_name()); - }, - [&](unsigned idx) { - return optional(*m_builtin_cases_names.find(idx)); - }); + instr.display(tout().get_stream()); tout() << "\n"; display_stack(tout().get_stream()); tout() << "\n";) @@ -2965,8 +2929,9 @@ void vm_state::run() { } void vm_state::invoke_fn(name const & fn) { - if (auto r = m_fn_name2idx.find(fn)) { - invoke_fn(*r); + auto idx = get_vm_index(fn); + if (m_decl_map.contains(idx)) { + invoke_fn(idx); } else { throw exception(sstream() << "VM does not have code for '" << fn << "'"); } @@ -2982,8 +2947,9 @@ void vm_state::invoke_fn(unsigned fn_idx) { } vm_obj vm_state::get_constant(name const & cname) { - if (auto fn_idx = m_fn_name2idx.find(cname)) { - vm_decl d = get_decl(*fn_idx); + auto fn_idx = get_vm_index(cname); + if (m_decl_map.contains(fn_idx)) { + vm_decl d = get_decl(fn_idx); if (d.get_arity() == 0) { DEBUG_CODE(unsigned stack_sz = m_stack.size();); unsigned saved_pc = m_pc; @@ -2995,7 +2961,7 @@ vm_obj vm_state::get_constant(name const & cname) { lean_assert(m_stack.size() == stack_sz); return r; } else { - return mk_vm_closure(*fn_idx, 0, nullptr); + return mk_vm_closure(fn_idx, 0, nullptr); } } else { throw exception(sstream() << "VM does not have code for '" << cname << "'"); @@ -3019,13 +2985,13 @@ void vm_state::apply(unsigned n) { } void vm_state::display(std::ostream & out, vm_obj const & o) const { - ::lean::display(out, o, - [&](unsigned idx) { return optional(get_decl(idx).get_name()); }); + ::lean::display(out, o); } optional vm_state::get_decl(name const & n) const { - if (auto idx = m_fn_name2idx.find(n)) - return optional(get_decl(*idx)); + auto idx = get_vm_index(n); + if (m_decl_map.contains(idx)) + return optional(get_decl(idx)); else return optional(); } @@ -3177,25 +3143,9 @@ void vm_state::profiler::snapshots::display(std::ostream & out) const { } void display_vm_code(std::ostream & out, environment const & env, unsigned code_sz, vm_instr const * code) { - vm_decls const & ext = get_extension(env); - auto idx2name = [&](unsigned idx) { - if (idx < ext.m_decls.size()) { - return optional(ext.m_decls.find(idx)->get_name()); - } else { - return optional(); - } - }; - auto cases2name = [&](unsigned idx) { - if (idx < ext.m_cases_names.size()) { - return optional(*ext.m_cases_names.find(idx)); - } else { - return optional(); - } - }; - for (unsigned i = 0; i < code_sz; i++) { out << i << ": "; - code[i].display(out, idx2name, cases2name); + code[i].display(out); out << "\n"; } } @@ -3295,7 +3245,70 @@ static void vm_monitor_reader(deserializer & d, environment & env) { env = update(env, ext); } +class vm_index_manager { + shared_mutex m_mutex; + std::unordered_map m_name2idx; + std::vector m_idx2name; + +public: + unsigned get_index(name const & n) { + { + shared_lock lock(m_mutex); + auto it = m_name2idx.find(n); + if (it != m_name2idx.end()) + return it->second; + } + { + exclusive_lock lock(m_mutex); + auto it = m_name2idx.find(n); + if (it != m_name2idx.end()) { + return it->second; + } else { + auto i = static_cast(m_idx2name.size()); + m_idx2name.push_back(n); + m_name2idx[n] = i; + return i; + } + } + } + + unsigned get_index_bound() { + shared_lock _(m_mutex); + return static_cast(m_idx2name.size()); + } + + name const & get_name(unsigned idx) { + shared_lock lock(m_mutex); + lean_assert(idx < m_idx2name.size()); + return m_idx2name.at(idx); + } + + optional find_name(unsigned idx) { + shared_lock lock(m_mutex); + if (idx < m_idx2name.size()) { + return optional(m_idx2name.at(idx)); + } else { + return optional(); + } + } +}; +static vm_index_manager * g_vm_index_manager = nullptr; + +unsigned get_vm_index(name const & n) { + return g_vm_index_manager->get_index(n); +} +unsigned get_vm_index_bound() { + return g_vm_index_manager->get_index_bound(); +} +name const & get_vm_name(unsigned idx) { + return g_vm_index_manager->get_name(idx); +} +optional find_vm_name(unsigned idx) { + return g_vm_index_manager->find_name(idx); +} + void initialize_vm_core() { + g_vm_index_manager = new vm_index_manager; g_vm_builtins = new name_map>(); g_vm_cbuiltins = new name_map>(); g_vm_cases_builtins = new name_map>(); @@ -3311,6 +3324,7 @@ void finalize_vm_core() { delete g_vm_builtins; delete g_vm_cbuiltins; delete g_vm_cases_builtins; + delete g_vm_index_manager; } void initialize_vm() { diff --git a/src/library/vm/vm.h b/src/library/vm/vm.h index 16102b04f7..77b9746e8c 100644 --- a/src/library/vm/vm.h +++ b/src/library/vm/vm.h @@ -49,7 +49,6 @@ public: #define LEAN_VM_BOX(num) (reinterpret_cast((num << 1) | 1)) #define LEAN_VM_UNBOX(obj) (reinterpret_cast(obj) >> 1) -void display(std::ostream & out, vm_obj const & o, std::function(unsigned)> const & idx2name); void display(std::ostream & out, vm_obj const & o); /** \brief VM object */ @@ -454,9 +453,6 @@ public: unsigned get_pc(unsigned i) const; void set_pc(unsigned i, unsigned pc); - void display(std::ostream & out, - std::function(unsigned)> const & idx2name, - std::function(unsigned)> const & cases_idx2name) const; void display(std::ostream & out) const; void serialize(serializer & s, std::function const & idx2name) const; @@ -573,8 +569,6 @@ class vm_state { cache_vector m_cache_vector; /* for 0-ary declarations */ builtin_cases_map m_builtin_cases_map; builtin_cases_vector m_builtin_cases_vector; - unsigned_map m_builtin_cases_names; - name_map m_fn_name2idx; vm_instr const * m_code; /* code of the current function being executed */ unsigned m_fn_idx; /* function idx being executed */ unsigned m_pc; /* program counter */ @@ -832,6 +826,11 @@ environment add_native(environment const & env, name const & n, vm_cfunction_7 f environment add_native(environment const & env, name const & n, vm_cfunction_8 fn); environment add_native(environment const & env, name const & n, unsigned arity, vm_cfunction_N fn); +unsigned get_vm_index(name const & n); +unsigned get_vm_index_bound(); +name const & get_vm_name(unsigned idx); +optional find_vm_name(unsigned idx); + /** \brief Reserve an index for the given function in the VM, the expression \c e is the value of \c fn after preprocessing. See library/compiler/pre_proprocess_rec.cpp for details. */