feat(library/vm/vm): store arguments in reverse order on the stack
It simplifies the code for handling closures.
This commit is contained in:
parent
dbcd609aff
commit
2bd400964c
4 changed files with 62 additions and 52 deletions
|
|
@ -40,8 +40,17 @@ class vm_compiler_fn {
|
|||
}
|
||||
}
|
||||
|
||||
void compile_rev_args(unsigned nargs, expr const * args, unsigned bpz, name_map<unsigned> const & m) {
|
||||
unsigned i = nargs;
|
||||
while (i > 0) {
|
||||
--i;
|
||||
compile(args[i], bpz, m);
|
||||
bpz++;
|
||||
}
|
||||
}
|
||||
|
||||
void compile_global(vm_decl const & decl, unsigned nargs, expr const * args, unsigned bpz, name_map<unsigned> const & m) {
|
||||
compile_args(nargs, args, bpz, m);
|
||||
compile_rev_args(nargs, args, bpz, m);
|
||||
lean_assert(nargs <= decl.get_arity());
|
||||
if (decl.get_arity() == nargs) {
|
||||
if (decl.is_builtin())
|
||||
|
|
@ -157,21 +166,20 @@ class vm_compiler_fn {
|
|||
lean_assert(is_internal_proj(fn));
|
||||
unsigned idx = *is_internal_proj(fn);
|
||||
lean_assert(args.size() >= 1);
|
||||
compile_rev_args(args.size() - 1, args.data() + 1, bpz, m);
|
||||
bpz += args.size() - 1;
|
||||
compile(args[0], bpz, m);
|
||||
emit(mk_proj_instr(idx));
|
||||
if (args.size() > 1) {
|
||||
bpz++; /* function returned by proj is on the stack */
|
||||
compile_args(args.size() - 1, args.data() + 1, bpz, m);
|
||||
if (args.size() > 1)
|
||||
emit(mk_invoke_instr(args.size() - 1));
|
||||
}
|
||||
}
|
||||
|
||||
void compile_fn_call(expr const & e, unsigned bpz, name_map<unsigned> const & m) {
|
||||
buffer<expr> args;
|
||||
expr fn = get_app_args(e, args);
|
||||
if (!is_constant(fn)) {
|
||||
compile_rev_args(args.size(), args.data(), bpz+1, m);
|
||||
compile(fn, bpz, m);
|
||||
compile_args(args.size(), args.data(), bpz+1, m);
|
||||
emit(mk_invoke_instr(args.size()));
|
||||
return;
|
||||
} else if (is_constant(fn)) {
|
||||
|
|
@ -239,17 +247,29 @@ class vm_compiler_fn {
|
|||
}
|
||||
}
|
||||
|
||||
unsigned get_arity(expr e) {
|
||||
unsigned r = 0;
|
||||
while (is_lambda(e)) {
|
||||
r++;
|
||||
e = binding_body(e);
|
||||
}
|
||||
return r;
|
||||
}
|
||||
|
||||
public:
|
||||
vm_compiler_fn(environment const & env, buffer<vm_instr> & code):
|
||||
m_env(env), m_code(code) {}
|
||||
|
||||
void operator()(expr e) {
|
||||
unsigned operator()(expr e) {
|
||||
buffer<expr> locals;
|
||||
unsigned bpz = 0;
|
||||
unsigned bpz = 0;
|
||||
unsigned arity = get_arity(e);
|
||||
unsigned i = arity;
|
||||
name_map<unsigned> m;
|
||||
while (is_lambda(e)) {
|
||||
name n = mk_fresh_name();
|
||||
m.insert(n, bpz);
|
||||
i--;
|
||||
m.insert(n, i);
|
||||
locals.push_back(mk_local(n));
|
||||
bpz++;
|
||||
e = binding_body(e);
|
||||
|
|
@ -257,6 +277,7 @@ public:
|
|||
e = instantiate_rev(e, locals.size(), locals.data());
|
||||
compile(e, bpz, m);
|
||||
emit(mk_ret_instr());
|
||||
return arity;
|
||||
}
|
||||
};
|
||||
|
||||
|
|
@ -268,9 +289,9 @@ environment vm_compile(environment const & env, buffer<pair<name, expr>> const &
|
|||
for (auto const & p : procs) {
|
||||
buffer<vm_instr> code;
|
||||
vm_compiler_fn gen(new_env, code);
|
||||
gen(p.second);
|
||||
unsigned arity = gen(p.second);
|
||||
optimize(new_env, code);
|
||||
lean_trace(name({"compiler", "code_gen"}), tout() << " " << p.first << "\n";
|
||||
lean_trace(name({"compiler", "code_gen"}), tout() << " " << p.first << " " << arity << "\n";
|
||||
display_vm_code(tout().get_stream(), new_env, code.size(), code.data()););
|
||||
new_env = update_vm_code(new_env, p.first, code.size(), code.data());
|
||||
}
|
||||
|
|
|
|||
|
|
@ -598,7 +598,7 @@ void vm_state::push_fields(vm_obj const & obj) {
|
|||
void vm_state::invoke_builtin(vm_decl const & d) {
|
||||
unsigned saved_bp = m_bp;
|
||||
unsigned sz = m_stack.size();
|
||||
m_bp = sz - d.get_arity();
|
||||
m_bp = sz;
|
||||
d.get_fn()(*this);
|
||||
lean_assert(m_stack.size() == sz + 1);
|
||||
m_bp = saved_bp;
|
||||
|
|
@ -625,11 +625,6 @@ void vm_state::invoke_global_builtin(vm_decl const & d) {
|
|||
}
|
||||
|
||||
void vm_state::run() {
|
||||
/*
|
||||
TODO(Leo): we can improve performance using the following tricks:
|
||||
- Function arguments in reverse order.
|
||||
- Function pointer after arguments.
|
||||
*/
|
||||
lean_assert(m_code);
|
||||
unsigned init_call_stack_sz = m_call_stack.size();
|
||||
m_pc = 0;
|
||||
|
|
@ -767,27 +762,15 @@ void vm_state::run() {
|
|||
case opcode::Invoke: {
|
||||
unsigned nargs = instr.get_num();
|
||||
unsigned sz = m_stack.size();
|
||||
vm_obj closure = m_stack[sz - nargs - 1];
|
||||
vm_obj closure = m_stack.back();
|
||||
m_stack.pop_back();
|
||||
unsigned fn_idx = cfn_idx(closure);
|
||||
vm_decl const & d = m_decls[fn_idx];
|
||||
unsigned csz = csize(closure);
|
||||
unsigned arity = d.get_arity();
|
||||
unsigned new_nargs = nargs + csz;
|
||||
lean_assert(new_nargs <= arity);
|
||||
if (csz == 0) {
|
||||
/* Closure has size 0, then we just move arguments down 1 position */
|
||||
m_stack.erase(m_stack.end() - nargs - 1); /* remove closure object */
|
||||
} else if (csz == 1) {
|
||||
/* Closure has size 1, then we replace the closure object with
|
||||
the data stored in the closure */
|
||||
*(m_stack.end() - nargs - 1) = cfield(closure, 0);
|
||||
} else {
|
||||
lean_assert(csz > 1);
|
||||
/* Closure has size > 1, then we need to open space on the stack */
|
||||
m_stack.resize(sz + csz - 1);
|
||||
std::move_backward(m_stack.end() - nargs + 1 - csz, m_stack.end() + 1 - csz, m_stack.end());
|
||||
std::copy(cfields(closure), cfields(closure) + csz, m_stack.end() - nargs - csz);
|
||||
}
|
||||
std::copy(cfields(closure), cfields(closure) + csz, m_stack.end());
|
||||
/* Now, stack contains closure arguments + original stack arguments */
|
||||
if (new_nargs == arity) {
|
||||
invoke_global_builtin(d);
|
||||
|
|
|
|||
|
|
@ -420,10 +420,16 @@ public:
|
|||
vm_state(environment const & env);
|
||||
|
||||
environment const & env() const { return m_env; }
|
||||
|
||||
/** \brief Push object into the data stack */
|
||||
void push(vm_obj const & o) { m_stack.push_back(o); }
|
||||
|
||||
/** \brief Retrieve object from the call stack */
|
||||
vm_obj const & get(unsigned idx) const { lean_assert(m_bp + idx < m_stack.size()); return m_stack[m_bp + idx]; }
|
||||
vm_obj const & get(int idx) const {
|
||||
lean_assert(idx + static_cast<int>(m_bp) >= 0);
|
||||
lean_assert(m_bp + idx < m_stack.size());
|
||||
return m_stack[m_bp + idx];
|
||||
}
|
||||
|
||||
void invoke_global(name const & fn);
|
||||
void invoke_global(unsigned fn_idx);
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ static mpz const & to_mpz2(vm_obj const & o) {
|
|||
}
|
||||
|
||||
static void nat_succ(vm_state & s) {
|
||||
vm_obj const & a = s.get(0);
|
||||
vm_obj const & a = s.get(-1);
|
||||
if (is_simple(a)) {
|
||||
s.push(mk_vm_nat(cidx(a) + 1));
|
||||
} else {
|
||||
|
|
@ -58,8 +58,8 @@ static void nat_succ(vm_state & s) {
|
|||
}
|
||||
|
||||
static void nat_add(vm_state & s) {
|
||||
vm_obj const & a1 = s.get(0);
|
||||
vm_obj const & a2 = s.get(1);
|
||||
vm_obj const & a1 = s.get(-1);
|
||||
vm_obj const & a2 = s.get(-2);
|
||||
if (is_simple(a1) && is_simple(a2)) {
|
||||
s.push(mk_vm_nat(cidx(a1) + cidx(a2)));
|
||||
} else {
|
||||
|
|
@ -68,8 +68,8 @@ static void nat_add(vm_state & s) {
|
|||
}
|
||||
|
||||
static void nat_mul(vm_state & s) {
|
||||
vm_obj const & a1 = s.get(0);
|
||||
vm_obj const & a2 = s.get(1);
|
||||
vm_obj const & a1 = s.get(-1);
|
||||
vm_obj const & a2 = s.get(-2);
|
||||
if (is_simple(a1) && is_simple(a2)) {
|
||||
unsigned long long r = static_cast<unsigned long long>(cidx(a1)) * static_cast<unsigned long long>(cidx(a2));
|
||||
if (r < LEAN_MAX_SMALL_NAT) {
|
||||
|
|
@ -81,8 +81,8 @@ static void nat_mul(vm_state & s) {
|
|||
}
|
||||
|
||||
static void nat_sub(vm_state & s) {
|
||||
vm_obj const & a1 = s.get(0);
|
||||
vm_obj const & a2 = s.get(1);
|
||||
vm_obj const & a1 = s.get(-1);
|
||||
vm_obj const & a2 = s.get(-2);
|
||||
if (is_simple(a1) && is_simple(a2)) {
|
||||
unsigned v1 = cidx(a1);
|
||||
unsigned v2 = cidx(a2);
|
||||
|
|
@ -101,8 +101,8 @@ static void nat_sub(vm_state & s) {
|
|||
}
|
||||
|
||||
static void nat_div(vm_state & s) {
|
||||
vm_obj const & a1 = s.get(0);
|
||||
vm_obj const & a2 = s.get(1);
|
||||
vm_obj const & a1 = s.get(-1);
|
||||
vm_obj const & a2 = s.get(-2);
|
||||
if (is_simple(a1) && is_simple(a2)) {
|
||||
unsigned v1 = cidx(a1);
|
||||
unsigned v2 = cidx(a2);
|
||||
|
|
@ -121,8 +121,8 @@ static void nat_div(vm_state & s) {
|
|||
}
|
||||
|
||||
static void nat_mod(vm_state & s) {
|
||||
vm_obj const & a1 = s.get(0);
|
||||
vm_obj const & a2 = s.get(1);
|
||||
vm_obj const & a1 = s.get(-1);
|
||||
vm_obj const & a2 = s.get(-2);
|
||||
if (is_simple(a1) && is_simple(a2)) {
|
||||
unsigned v1 = cidx(a1);
|
||||
unsigned v2 = cidx(a2);
|
||||
|
|
@ -141,16 +141,16 @@ static void nat_mod(vm_state & s) {
|
|||
}
|
||||
|
||||
static void nat_gcd(vm_state & s) {
|
||||
vm_obj const & a1 = s.get(0);
|
||||
vm_obj const & a2 = s.get(1);
|
||||
vm_obj const & a1 = s.get(-1);
|
||||
vm_obj const & a2 = s.get(-2);
|
||||
mpz r;
|
||||
gcd(r, to_mpz1(a1), to_mpz2(a2));
|
||||
s.push(mk_vm_nat(r));
|
||||
}
|
||||
|
||||
static void nat_has_decidable_eq(vm_state & s) {
|
||||
vm_obj const & a1 = s.get(0);
|
||||
vm_obj const & a2 = s.get(1);
|
||||
vm_obj const & a1 = s.get(-1);
|
||||
vm_obj const & a2 = s.get(-2);
|
||||
if (is_simple(a1) && is_simple(a2)) {
|
||||
return s.push(mk_vm_bool(cidx(a1) == cidx(a2)));
|
||||
} else {
|
||||
|
|
@ -159,8 +159,8 @@ static void nat_has_decidable_eq(vm_state & s) {
|
|||
}
|
||||
|
||||
static void nat_decidable_le(vm_state & s) {
|
||||
vm_obj const & a1 = s.get(0);
|
||||
vm_obj const & a2 = s.get(1);
|
||||
vm_obj const & a1 = s.get(-1);
|
||||
vm_obj const & a2 = s.get(-2);
|
||||
if (is_simple(a1) && is_simple(a2)) {
|
||||
return s.push(mk_vm_bool(cidx(a1) <= cidx(a2)));
|
||||
} else {
|
||||
|
|
@ -169,8 +169,8 @@ static void nat_decidable_le(vm_state & s) {
|
|||
}
|
||||
|
||||
static void nat_decidable_lt(vm_state & s) {
|
||||
vm_obj const & a1 = s.get(0);
|
||||
vm_obj const & a2 = s.get(1);
|
||||
vm_obj const & a1 = s.get(-1);
|
||||
vm_obj const & a2 = s.get(-2);
|
||||
if (is_simple(a1) && is_simple(a2)) {
|
||||
return s.push(mk_vm_bool(cidx(a1) < cidx(a2)));
|
||||
} else {
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue