diff --git a/src/library/vm/vm.cpp b/src/library/vm/vm.cpp index 9054d0cb4f..d31959e7df 100644 --- a/src/library/vm/vm.cpp +++ b/src/library/vm/vm.cpp @@ -543,13 +543,10 @@ static vm_instr read_vm_instr(deserializer & d, name_map const & name2 } vm_decl_cell::vm_decl_cell(name const & n, unsigned idx, unsigned arity, vm_function fn): - m_rc(0), m_kind(vm_decl_kind::Builtin), m_name(n), m_idx(idx), m_arity(arity), m_fn(fn), m_cfn(nullptr) {} + m_rc(0), m_kind(vm_decl_kind::Builtin), m_name(n), m_idx(idx), m_arity(arity), m_fn(fn) {} vm_decl_cell::vm_decl_cell(name const & n, unsigned idx, unsigned arity, vm_cfunction fn): - m_rc(0), m_kind(vm_decl_kind::CFun), m_name(n), m_idx(idx), m_arity(arity), m_fn(nullptr), m_cfn(fn) {} - -vm_decl_cell::vm_decl_cell(name const & n, unsigned idx, unsigned arity, vm_function fn1, vm_cfunction fn2): - m_rc(0), m_kind(vm_decl_kind::BuiltinCFun), m_name(n), m_idx(idx), m_arity(arity), m_fn(fn1), m_cfn(fn2) {} + m_rc(0), m_kind(vm_decl_kind::CFun), m_name(n), m_idx(idx), m_arity(arity), m_cfn(fn) {} vm_decl_cell::vm_decl_cell(name const & n, unsigned idx, expr const & e, unsigned code_sz, vm_instr const * code): m_rc(0), m_kind(vm_decl_kind::Bytecode), m_name(n), m_idx(idx), m_expr(e), m_arity(0), @@ -683,6 +680,54 @@ static environment update(environment const & env, vm_decls const & ext) { return env.update(g_ext->m_ext_id, std::make_shared(ext)); } +static environment declare_vm_builtin(environment const & env, name const & n, unsigned arity, vm_cfunction fn) { + auto ext = get_extension(env); + if (auto idx = ext.m_name2idx.find(n)) { + vm_decl d = ext.m_decls[*idx]; + lean_assert(d.get_arity() == arity); + ext.m_decls.set(*idx, vm_decl(n, *idx, arity, fn)); + } else { + ext.add(vm_decl(n, ext.m_decls.size(), arity, fn)); + } + return update(env, ext); +} + +environment declare_vm_builtin(environment const & env, name const & n, vm_cfunction_1 fn) { + return declare_vm_builtin(env, n, 1, reinterpret_cast(fn)); +} + +environment declare_vm_builtin(environment const & env, name const & n, vm_cfunction_2 fn) { + return declare_vm_builtin(env, n, 2, reinterpret_cast(fn)); +} + +environment declare_vm_builtin(environment const & env, name const & n, vm_cfunction_3 fn) { + return declare_vm_builtin(env, n, 3, reinterpret_cast(fn)); +} + +environment declare_vm_builtin(environment const & env, name const & n, vm_cfunction_4 fn) { + return declare_vm_builtin(env, n, 4, reinterpret_cast(fn)); +} + +environment declare_vm_builtin(environment const & env, name const & n, vm_cfunction_5 fn) { + return declare_vm_builtin(env, n, 5, reinterpret_cast(fn)); +} + +environment declare_vm_builtin(environment const & env, name const & n, vm_cfunction_6 fn) { + return declare_vm_builtin(env, n, 6, reinterpret_cast(fn)); +} + +environment declare_vm_builtin(environment const & env, name const & n, vm_cfunction_7 fn) { + return declare_vm_builtin(env, n, 7, reinterpret_cast(fn)); +} + +environment declare_vm_builtin(environment const & env, name const & n, vm_cfunction_8 fn) { + return declare_vm_builtin(env, n, 8, reinterpret_cast(fn)); +} + +environment declare_vm_builtin(environment const & env, name const & n, unsigned arity, vm_cfunction_N fn) { + return declare_vm_builtin(env, n, arity, reinterpret_cast(fn)); +} + bool is_vm_function(environment const & env, name const & fn) { auto const & ext = get_extension(env); return ext.m_name2idx.contains(fn) || g_vm_builtins->contains(fn); @@ -1393,7 +1438,7 @@ void vm_state::invoke(vm_decl const & d) { switch (d.kind()) { case vm_decl_kind::Bytecode: invoke_global(d); break; - case vm_decl_kind::Builtin: case vm_decl_kind::BuiltinCFun: + case vm_decl_kind::Builtin: invoke_builtin(d); break; case vm_decl_kind::CFun: invoke_cfun(d); break; diff --git a/src/library/vm/vm.h b/src/library/vm/vm.h index 67bb37cac9..ecafc0e3b6 100644 --- a/src/library/vm/vm.h +++ b/src/library/vm/vm.h @@ -376,7 +376,7 @@ vm_instr mk_closure_instr(unsigned fn_idx, unsigned n); class vm_state; class vm_instr; -enum class vm_decl_kind { Bytecode, Builtin, CFun, BuiltinCFun }; +enum class vm_decl_kind { Bytecode, Builtin, CFun }; /** \brief VM function/constant declaration cell */ struct vm_decl_cell { @@ -391,14 +391,11 @@ struct vm_decl_cell { unsigned m_code_size; vm_instr * m_code; }; - struct { - vm_function m_fn; - vm_cfunction m_cfn; - }; + vm_function m_fn; + vm_cfunction m_cfn; }; vm_decl_cell(name const & n, unsigned idx, unsigned arity, vm_function fn); vm_decl_cell(name const & n, unsigned idx, unsigned arity, vm_cfunction fn); - vm_decl_cell(name const & n, unsigned idx, unsigned arity, vm_function fn1, vm_cfunction fn2); vm_decl_cell(name const & n, unsigned idx, expr const & e, unsigned code_sz, vm_instr const * code); ~vm_decl_cell(); void dealloc(); @@ -414,8 +411,6 @@ public: vm_decl(new vm_decl_cell(n, idx, arity, fn)) {} vm_decl(name const & n, unsigned idx, unsigned arity, vm_cfunction fn): vm_decl(new vm_decl_cell(n, idx, arity, fn)) {} - vm_decl(name const & n, unsigned idx, unsigned arity, vm_function fn1, vm_cfunction fn2): - vm_decl(new vm_decl_cell(n, idx, arity, fn1, fn2)) {} vm_decl(name const & n, unsigned idx, expr const & e, unsigned code_sz, vm_instr const * code): vm_decl(new vm_decl_cell(n, idx, e, code_sz, code)) {} vm_decl(vm_decl const & s):m_ptr(s.m_ptr) { if (m_ptr) m_ptr->inc_ref(); } @@ -429,14 +424,8 @@ public: vm_decl_kind kind() const { return m_ptr->m_kind; } bool is_bytecode() const { lean_assert(m_ptr); return m_ptr->m_kind == vm_decl_kind::Bytecode; } - bool is_builtin() const { - lean_assert(m_ptr); - return m_ptr->m_kind == vm_decl_kind::Builtin || m_ptr->m_kind == vm_decl_kind::BuiltinCFun; - } - bool is_cfun() const { - lean_assert(m_ptr); - return m_ptr->m_kind == vm_decl_kind::CFun || m_ptr->m_kind == vm_decl_kind::BuiltinCFun; - } + bool is_builtin() const { lean_assert(m_ptr); return m_ptr->m_kind == vm_decl_kind::Builtin; } + bool is_cfun() const { lean_assert(m_ptr); return m_ptr->m_kind == vm_decl_kind::CFun; } unsigned get_idx() const { lean_assert(m_ptr); return m_ptr->m_idx; } name get_name() const { lean_assert(m_ptr); return m_ptr->m_name; } unsigned get_arity() const { lean_assert(m_ptr); return m_ptr->m_arity; } @@ -531,7 +520,8 @@ public: }; /** \brief Add builtin implementation for the function named \c n. - All environment objects will contain this builtin. */ + All environment objects will contain this builtin. + \pre These procedures can only be invoked at initialization time. */ void declare_vm_builtin(name const & n, unsigned arity, vm_function fn); void declare_vm_builtin(name const & n, vm_cfunction_1 fn); void declare_vm_builtin(name const & n, vm_cfunction_2 fn); @@ -543,6 +533,17 @@ void declare_vm_builtin(name const & n, vm_cfunction_7 fn); void declare_vm_builtin(name const & n, vm_cfunction_8 fn); void declare_vm_builtin(name const & n, unsigned arity, vm_cfunction_N fn); +/** Register in the given environment \c fn as the implementation for function \c n. */ +environment declare_vm_builtin(environment const & env, name const & n, vm_cfunction_1 fn); +environment declare_vm_builtin(environment const & env, name const & n, vm_cfunction_2 fn); +environment declare_vm_builtin(environment const & env, name const & n, vm_cfunction_3 fn); +environment declare_vm_builtin(environment const & env, name const & n, vm_cfunction_4 fn); +environment declare_vm_builtin(environment const & env, name const & n, vm_cfunction_5 fn); +environment declare_vm_builtin(environment const & env, name const & n, vm_cfunction_6 fn); +environment declare_vm_builtin(environment const & env, name const & n, vm_cfunction_7 fn); +environment declare_vm_builtin(environment const & env, name const & n, vm_cfunction_8 fn); +environment declare_vm_builtin(environment const & env, name const & n, unsigned arity, vm_cfunction_N fn); + /** \brief Reserve an index for the given function in the VM, the expression \c e is the value of \c fn after preprocessing. See library/compiler/pre_proprocess_rec.cpp for details. */