diff --git a/src/library/compiler/vm_compiler.cpp b/src/library/compiler/vm_compiler.cpp index e6852873aa..60ce7977fb 100644 --- a/src/library/compiler/vm_compiler.cpp +++ b/src/library/compiler/vm_compiler.cpp @@ -40,8 +40,17 @@ class vm_compiler_fn { } } + void compile_rev_args(unsigned nargs, expr const * args, unsigned bpz, name_map const & m) { + unsigned i = nargs; + while (i > 0) { + --i; + compile(args[i], bpz, m); + bpz++; + } + } + void compile_global(vm_decl const & decl, unsigned nargs, expr const * args, unsigned bpz, name_map const & m) { - compile_args(nargs, args, bpz, m); + compile_rev_args(nargs, args, bpz, m); lean_assert(nargs <= decl.get_arity()); if (decl.get_arity() == nargs) { if (decl.is_builtin()) @@ -157,21 +166,20 @@ class vm_compiler_fn { lean_assert(is_internal_proj(fn)); unsigned idx = *is_internal_proj(fn); lean_assert(args.size() >= 1); + compile_rev_args(args.size() - 1, args.data() + 1, bpz, m); + bpz += args.size() - 1; compile(args[0], bpz, m); emit(mk_proj_instr(idx)); - if (args.size() > 1) { - bpz++; /* function returned by proj is on the stack */ - compile_args(args.size() - 1, args.data() + 1, bpz, m); + if (args.size() > 1) emit(mk_invoke_instr(args.size() - 1)); - } } void compile_fn_call(expr const & e, unsigned bpz, name_map const & m) { buffer args; expr fn = get_app_args(e, args); if (!is_constant(fn)) { + compile_rev_args(args.size(), args.data(), bpz+1, m); compile(fn, bpz, m); - compile_args(args.size(), args.data(), bpz+1, m); emit(mk_invoke_instr(args.size())); return; } else if (is_constant(fn)) { @@ -239,17 +247,29 @@ class vm_compiler_fn { } } + unsigned get_arity(expr e) { + unsigned r = 0; + while (is_lambda(e)) { + r++; + e = binding_body(e); + } + return r; + } + public: vm_compiler_fn(environment const & env, buffer & code): m_env(env), m_code(code) {} - void operator()(expr e) { + unsigned operator()(expr e) { buffer locals; - unsigned bpz = 0; + unsigned bpz = 0; + unsigned arity = get_arity(e); + unsigned i = arity; name_map m; while (is_lambda(e)) { name n = mk_fresh_name(); - m.insert(n, bpz); + i--; + m.insert(n, i); locals.push_back(mk_local(n)); bpz++; e = binding_body(e); @@ -257,6 +277,7 @@ public: e = instantiate_rev(e, locals.size(), locals.data()); compile(e, bpz, m); emit(mk_ret_instr()); + return arity; } }; @@ -268,9 +289,9 @@ environment vm_compile(environment const & env, buffer> const & for (auto const & p : procs) { buffer code; vm_compiler_fn gen(new_env, code); - gen(p.second); + unsigned arity = gen(p.second); optimize(new_env, code); - lean_trace(name({"compiler", "code_gen"}), tout() << " " << p.first << "\n"; + lean_trace(name({"compiler", "code_gen"}), tout() << " " << p.first << " " << arity << "\n"; display_vm_code(tout().get_stream(), new_env, code.size(), code.data());); new_env = update_vm_code(new_env, p.first, code.size(), code.data()); } diff --git a/src/library/vm/vm.cpp b/src/library/vm/vm.cpp index 3ea5d7fa44..75712ecb04 100644 --- a/src/library/vm/vm.cpp +++ b/src/library/vm/vm.cpp @@ -598,7 +598,7 @@ void vm_state::push_fields(vm_obj const & obj) { void vm_state::invoke_builtin(vm_decl const & d) { unsigned saved_bp = m_bp; unsigned sz = m_stack.size(); - m_bp = sz - d.get_arity(); + m_bp = sz; d.get_fn()(*this); lean_assert(m_stack.size() == sz + 1); m_bp = saved_bp; @@ -625,11 +625,6 @@ void vm_state::invoke_global_builtin(vm_decl const & d) { } void vm_state::run() { - /* - TODO(Leo): we can improve performance using the following tricks: - - Function arguments in reverse order. - - Function pointer after arguments. - */ lean_assert(m_code); unsigned init_call_stack_sz = m_call_stack.size(); m_pc = 0; @@ -767,27 +762,15 @@ void vm_state::run() { case opcode::Invoke: { unsigned nargs = instr.get_num(); unsigned sz = m_stack.size(); - vm_obj closure = m_stack[sz - nargs - 1]; + vm_obj closure = m_stack.back(); + m_stack.pop_back(); unsigned fn_idx = cfn_idx(closure); vm_decl const & d = m_decls[fn_idx]; unsigned csz = csize(closure); unsigned arity = d.get_arity(); unsigned new_nargs = nargs + csz; lean_assert(new_nargs <= arity); - if (csz == 0) { - /* Closure has size 0, then we just move arguments down 1 position */ - m_stack.erase(m_stack.end() - nargs - 1); /* remove closure object */ - } else if (csz == 1) { - /* Closure has size 1, then we replace the closure object with - the data stored in the closure */ - *(m_stack.end() - nargs - 1) = cfield(closure, 0); - } else { - lean_assert(csz > 1); - /* Closure has size > 1, then we need to open space on the stack */ - m_stack.resize(sz + csz - 1); - std::move_backward(m_stack.end() - nargs + 1 - csz, m_stack.end() + 1 - csz, m_stack.end()); - std::copy(cfields(closure), cfields(closure) + csz, m_stack.end() - nargs - csz); - } + std::copy(cfields(closure), cfields(closure) + csz, m_stack.end()); /* Now, stack contains closure arguments + original stack arguments */ if (new_nargs == arity) { invoke_global_builtin(d); diff --git a/src/library/vm/vm.h b/src/library/vm/vm.h index f6cc63ecf8..ea2877d060 100644 --- a/src/library/vm/vm.h +++ b/src/library/vm/vm.h @@ -420,10 +420,16 @@ public: vm_state(environment const & env); environment const & env() const { return m_env; } + /** \brief Push object into the data stack */ void push(vm_obj const & o) { m_stack.push_back(o); } + /** \brief Retrieve object from the call stack */ - vm_obj const & get(unsigned idx) const { lean_assert(m_bp + idx < m_stack.size()); return m_stack[m_bp + idx]; } + vm_obj const & get(int idx) const { + lean_assert(idx + static_cast(m_bp) >= 0); + lean_assert(m_bp + idx < m_stack.size()); + return m_stack[m_bp + idx]; + } void invoke_global(name const & fn); void invoke_global(unsigned fn_idx); diff --git a/src/library/vm/vm_nat.cpp b/src/library/vm/vm_nat.cpp index dd854954a1..ff92463952 100644 --- a/src/library/vm/vm_nat.cpp +++ b/src/library/vm/vm_nat.cpp @@ -49,7 +49,7 @@ static mpz const & to_mpz2(vm_obj const & o) { } static void nat_succ(vm_state & s) { - vm_obj const & a = s.get(0); + vm_obj const & a = s.get(-1); if (is_simple(a)) { s.push(mk_vm_nat(cidx(a) + 1)); } else { @@ -58,8 +58,8 @@ static void nat_succ(vm_state & s) { } static void nat_add(vm_state & s) { - vm_obj const & a1 = s.get(0); - vm_obj const & a2 = s.get(1); + vm_obj const & a1 = s.get(-1); + vm_obj const & a2 = s.get(-2); if (is_simple(a1) && is_simple(a2)) { s.push(mk_vm_nat(cidx(a1) + cidx(a2))); } else { @@ -68,8 +68,8 @@ static void nat_add(vm_state & s) { } static void nat_mul(vm_state & s) { - vm_obj const & a1 = s.get(0); - vm_obj const & a2 = s.get(1); + vm_obj const & a1 = s.get(-1); + vm_obj const & a2 = s.get(-2); if (is_simple(a1) && is_simple(a2)) { unsigned long long r = static_cast(cidx(a1)) * static_cast(cidx(a2)); if (r < LEAN_MAX_SMALL_NAT) { @@ -81,8 +81,8 @@ static void nat_mul(vm_state & s) { } static void nat_sub(vm_state & s) { - vm_obj const & a1 = s.get(0); - vm_obj const & a2 = s.get(1); + vm_obj const & a1 = s.get(-1); + vm_obj const & a2 = s.get(-2); if (is_simple(a1) && is_simple(a2)) { unsigned v1 = cidx(a1); unsigned v2 = cidx(a2); @@ -101,8 +101,8 @@ static void nat_sub(vm_state & s) { } static void nat_div(vm_state & s) { - vm_obj const & a1 = s.get(0); - vm_obj const & a2 = s.get(1); + vm_obj const & a1 = s.get(-1); + vm_obj const & a2 = s.get(-2); if (is_simple(a1) && is_simple(a2)) { unsigned v1 = cidx(a1); unsigned v2 = cidx(a2); @@ -121,8 +121,8 @@ static void nat_div(vm_state & s) { } static void nat_mod(vm_state & s) { - vm_obj const & a1 = s.get(0); - vm_obj const & a2 = s.get(1); + vm_obj const & a1 = s.get(-1); + vm_obj const & a2 = s.get(-2); if (is_simple(a1) && is_simple(a2)) { unsigned v1 = cidx(a1); unsigned v2 = cidx(a2); @@ -141,16 +141,16 @@ static void nat_mod(vm_state & s) { } static void nat_gcd(vm_state & s) { - vm_obj const & a1 = s.get(0); - vm_obj const & a2 = s.get(1); + vm_obj const & a1 = s.get(-1); + vm_obj const & a2 = s.get(-2); mpz r; gcd(r, to_mpz1(a1), to_mpz2(a2)); s.push(mk_vm_nat(r)); } static void nat_has_decidable_eq(vm_state & s) { - vm_obj const & a1 = s.get(0); - vm_obj const & a2 = s.get(1); + vm_obj const & a1 = s.get(-1); + vm_obj const & a2 = s.get(-2); if (is_simple(a1) && is_simple(a2)) { return s.push(mk_vm_bool(cidx(a1) == cidx(a2))); } else { @@ -159,8 +159,8 @@ static void nat_has_decidable_eq(vm_state & s) { } static void nat_decidable_le(vm_state & s) { - vm_obj const & a1 = s.get(0); - vm_obj const & a2 = s.get(1); + vm_obj const & a1 = s.get(-1); + vm_obj const & a2 = s.get(-2); if (is_simple(a1) && is_simple(a2)) { return s.push(mk_vm_bool(cidx(a1) <= cidx(a2))); } else { @@ -169,8 +169,8 @@ static void nat_decidable_le(vm_state & s) { } static void nat_decidable_lt(vm_state & s) { - vm_obj const & a1 = s.get(0); - vm_obj const & a2 = s.get(1); + vm_obj const & a1 = s.get(-1); + vm_obj const & a2 = s.get(-2); if (is_simple(a1) && is_simple(a2)) { return s.push(mk_vm_bool(cidx(a1) < cidx(a2))); } else {