feat(library/vm/vm): store arguments in reverse order on the stack

It simplifies the code for handling closures.
This commit is contained in:
Leonardo de Moura 2016-05-13 10:54:29 -07:00
parent dbcd609aff
commit 2bd400964c
4 changed files with 62 additions and 52 deletions

View file

@ -40,8 +40,17 @@ class vm_compiler_fn {
}
}
/* Compile the `nargs` expressions in `args` from the last one down to the first.
   The `bpz` value handed to `compile` starts at the caller's `bpz` and grows by
   one after every compiled argument (presumably because each argument leaves one
   value on the stack -- confirm against `compile`). */
void compile_rev_args(unsigned nargs, expr const * args, unsigned bpz, name_map<unsigned> const & m) {
    unsigned depth = bpz;
    for (unsigned remaining = nargs; remaining > 0; --remaining, ++depth) {
        /* `remaining - 1` walks the indices nargs-1, nargs-2, ..., 0 */
        compile(args[remaining - 1], depth, m);
    }
}
void compile_global(vm_decl const & decl, unsigned nargs, expr const * args, unsigned bpz, name_map<unsigned> const & m) {
compile_args(nargs, args, bpz, m);
compile_rev_args(nargs, args, bpz, m);
lean_assert(nargs <= decl.get_arity());
if (decl.get_arity() == nargs) {
if (decl.is_builtin())
@ -157,21 +166,20 @@ class vm_compiler_fn {
lean_assert(is_internal_proj(fn));
unsigned idx = *is_internal_proj(fn);
lean_assert(args.size() >= 1);
compile_rev_args(args.size() - 1, args.data() + 1, bpz, m);
bpz += args.size() - 1;
compile(args[0], bpz, m);
emit(mk_proj_instr(idx));
if (args.size() > 1) {
bpz++; /* function returned by proj is on the stack */
compile_args(args.size() - 1, args.data() + 1, bpz, m);
if (args.size() > 1)
emit(mk_invoke_instr(args.size() - 1));
}
}
void compile_fn_call(expr const & e, unsigned bpz, name_map<unsigned> const & m) {
buffer<expr> args;
expr fn = get_app_args(e, args);
if (!is_constant(fn)) {
compile_rev_args(args.size(), args.data(), bpz+1, m);
compile(fn, bpz, m);
compile_args(args.size(), args.data(), bpz+1, m);
emit(mk_invoke_instr(args.size()));
return;
} else if (is_constant(fn)) {
@ -239,17 +247,29 @@ class vm_compiler_fn {
}
}
/* Count the leading lambda binders of `e`: keep stepping into the binding body
   while the current expression is a lambda, and return how many steps were taken. */
unsigned get_arity(expr e) {
    unsigned num_binders = 0;
    for (; is_lambda(e); e = binding_body(e)) {
        ++num_binders;
    }
    return num_binders;
}
public:
vm_compiler_fn(environment const & env, buffer<vm_instr> & code):
m_env(env), m_code(code) {}
void operator()(expr e) {
unsigned operator()(expr e) {
buffer<expr> locals;
unsigned bpz = 0;
unsigned bpz = 0;
unsigned arity = get_arity(e);
unsigned i = arity;
name_map<unsigned> m;
while (is_lambda(e)) {
name n = mk_fresh_name();
m.insert(n, bpz);
i--;
m.insert(n, i);
locals.push_back(mk_local(n));
bpz++;
e = binding_body(e);
@ -257,6 +277,7 @@ public:
e = instantiate_rev(e, locals.size(), locals.data());
compile(e, bpz, m);
emit(mk_ret_instr());
return arity;
}
};
@ -268,9 +289,9 @@ environment vm_compile(environment const & env, buffer<pair<name, expr>> const &
for (auto const & p : procs) {
buffer<vm_instr> code;
vm_compiler_fn gen(new_env, code);
gen(p.second);
unsigned arity = gen(p.second);
optimize(new_env, code);
lean_trace(name({"compiler", "code_gen"}), tout() << " " << p.first << "\n";
lean_trace(name({"compiler", "code_gen"}), tout() << " " << p.first << " " << arity << "\n";
display_vm_code(tout().get_stream(), new_env, code.size(), code.data()););
new_env = update_vm_code(new_env, p.first, code.size(), code.data());
}

View file

@ -598,7 +598,7 @@ void vm_state::push_fields(vm_obj const & obj) {
void vm_state::invoke_builtin(vm_decl const & d) {
unsigned saved_bp = m_bp;
unsigned sz = m_stack.size();
m_bp = sz - d.get_arity();
m_bp = sz;
d.get_fn()(*this);
lean_assert(m_stack.size() == sz + 1);
m_bp = saved_bp;
@ -625,11 +625,6 @@ void vm_state::invoke_global_builtin(vm_decl const & d) {
}
void vm_state::run() {
/*
TODO(Leo): we can improve performance using the following tricks:
- Function arguments in reverse order.
- Function pointer after arguments.
*/
lean_assert(m_code);
unsigned init_call_stack_sz = m_call_stack.size();
m_pc = 0;
@ -767,27 +762,15 @@ void vm_state::run() {
case opcode::Invoke: {
unsigned nargs = instr.get_num();
unsigned sz = m_stack.size();
vm_obj closure = m_stack[sz - nargs - 1];
vm_obj closure = m_stack.back();
m_stack.pop_back();
unsigned fn_idx = cfn_idx(closure);
vm_decl const & d = m_decls[fn_idx];
unsigned csz = csize(closure);
unsigned arity = d.get_arity();
unsigned new_nargs = nargs + csz;
lean_assert(new_nargs <= arity);
if (csz == 0) {
/* Closure has size 0, then we just move arguments down 1 position */
m_stack.erase(m_stack.end() - nargs - 1); /* remove closure object */
} else if (csz == 1) {
/* Closure has size 1, then we replace the closure object with
the data stored in the closure */
*(m_stack.end() - nargs - 1) = cfield(closure, 0);
} else {
lean_assert(csz > 1);
/* Closure has size > 1, then we need to open space on the stack */
m_stack.resize(sz + csz - 1);
std::move_backward(m_stack.end() - nargs + 1 - csz, m_stack.end() + 1 - csz, m_stack.end());
std::copy(cfields(closure), cfields(closure) + csz, m_stack.end() - nargs - csz);
}
std::copy(cfields(closure), cfields(closure) + csz, m_stack.end());
/* Now, stack contains closure arguments + original stack arguments */
if (new_nargs == arity) {
invoke_global_builtin(d);

View file

@ -420,10 +420,16 @@ public:
vm_state(environment const & env);
environment const & env() const { return m_env; }
/** \brief Push object into the data stack */
void push(vm_obj const & o) { m_stack.push_back(o); }
/** \brief Retrieve object from the call stack */
vm_obj const & get(unsigned idx) const { lean_assert(m_bp + idx < m_stack.size()); return m_stack[m_bp + idx]; }
vm_obj const & get(int idx) const {
    /* `idx` is a signed offset relative to m_bp; it must not point below the
       start of the stack ... */
    lean_assert(idx + static_cast<int>(m_bp) >= 0);
    /* ... nor at or past its current end. */
    auto const pos = m_bp + idx;
    lean_assert(pos < m_stack.size());
    return m_stack[pos];
}
void invoke_global(name const & fn);
void invoke_global(unsigned fn_idx);

View file

@ -49,7 +49,7 @@ static mpz const & to_mpz2(vm_obj const & o) {
}
static void nat_succ(vm_state & s) {
vm_obj const & a = s.get(0);
vm_obj const & a = s.get(-1);
if (is_simple(a)) {
s.push(mk_vm_nat(cidx(a) + 1));
} else {
@ -58,8 +58,8 @@ static void nat_succ(vm_state & s) {
}
static void nat_add(vm_state & s) {
vm_obj const & a1 = s.get(0);
vm_obj const & a2 = s.get(1);
vm_obj const & a1 = s.get(-1);
vm_obj const & a2 = s.get(-2);
if (is_simple(a1) && is_simple(a2)) {
s.push(mk_vm_nat(cidx(a1) + cidx(a2)));
} else {
@ -68,8 +68,8 @@ static void nat_add(vm_state & s) {
}
static void nat_mul(vm_state & s) {
vm_obj const & a1 = s.get(0);
vm_obj const & a2 = s.get(1);
vm_obj const & a1 = s.get(-1);
vm_obj const & a2 = s.get(-2);
if (is_simple(a1) && is_simple(a2)) {
unsigned long long r = static_cast<unsigned long long>(cidx(a1)) * static_cast<unsigned long long>(cidx(a2));
if (r < LEAN_MAX_SMALL_NAT) {
@ -81,8 +81,8 @@ static void nat_mul(vm_state & s) {
}
static void nat_sub(vm_state & s) {
vm_obj const & a1 = s.get(0);
vm_obj const & a2 = s.get(1);
vm_obj const & a1 = s.get(-1);
vm_obj const & a2 = s.get(-2);
if (is_simple(a1) && is_simple(a2)) {
unsigned v1 = cidx(a1);
unsigned v2 = cidx(a2);
@ -101,8 +101,8 @@ static void nat_sub(vm_state & s) {
}
static void nat_div(vm_state & s) {
vm_obj const & a1 = s.get(0);
vm_obj const & a2 = s.get(1);
vm_obj const & a1 = s.get(-1);
vm_obj const & a2 = s.get(-2);
if (is_simple(a1) && is_simple(a2)) {
unsigned v1 = cidx(a1);
unsigned v2 = cidx(a2);
@ -121,8 +121,8 @@ static void nat_div(vm_state & s) {
}
static void nat_mod(vm_state & s) {
vm_obj const & a1 = s.get(0);
vm_obj const & a2 = s.get(1);
vm_obj const & a1 = s.get(-1);
vm_obj const & a2 = s.get(-2);
if (is_simple(a1) && is_simple(a2)) {
unsigned v1 = cidx(a1);
unsigned v2 = cidx(a2);
@ -141,16 +141,16 @@ static void nat_mod(vm_state & s) {
}
/* VM builtin: fetches two operands through s.get(), passes them (converted with
   to_mpz1/to_mpz2) to gcd together with the local `r`, and pushes mk_vm_nat(r)
   onto the VM stack. */
static void nat_gcd(vm_state & s) {
/* NOTE(review): a1 and a2 are each bound twice below, once with offsets 0/1 and
   once with offsets -1/-2. This looks like the removed and added lines of a diff
   shown together -- confirm which pair is the intended one. */
vm_obj const & a1 = s.get(0);
vm_obj const & a2 = s.get(1);
vm_obj const & a1 = s.get(-1);
vm_obj const & a2 = s.get(-2);
mpz r;
/* presumably gcd stores its result in `r` (first argument) -- verify against the mpz API */
gcd(r, to_mpz1(a1), to_mpz2(a2));
s.push(mk_vm_nat(r));
}
static void nat_has_decidable_eq(vm_state & s) {
vm_obj const & a1 = s.get(0);
vm_obj const & a2 = s.get(1);
vm_obj const & a1 = s.get(-1);
vm_obj const & a2 = s.get(-2);
if (is_simple(a1) && is_simple(a2)) {
return s.push(mk_vm_bool(cidx(a1) == cidx(a2)));
} else {
@ -159,8 +159,8 @@ static void nat_has_decidable_eq(vm_state & s) {
}
static void nat_decidable_le(vm_state & s) {
vm_obj const & a1 = s.get(0);
vm_obj const & a2 = s.get(1);
vm_obj const & a1 = s.get(-1);
vm_obj const & a2 = s.get(-2);
if (is_simple(a1) && is_simple(a2)) {
return s.push(mk_vm_bool(cidx(a1) <= cidx(a2)));
} else {
@ -169,8 +169,8 @@ static void nat_decidable_le(vm_state & s) {
}
static void nat_decidable_lt(vm_state & s) {
vm_obj const & a1 = s.get(0);
vm_obj const & a2 = s.get(1);
vm_obj const & a1 = s.get(-1);
vm_obj const & a2 = s.get(-2);
if (is_simple(a1) && is_simple(a2)) {
return s.push(mk_vm_bool(cidx(a1) < cidx(a2)));
} else {