feat(library/vm): add support for join-points in the old VM

Motivation: to make progress in the new compiler stack, we have
to preserve join points during lambda lifting. Right now, they are
lifted as regular lambdas. So, to keep them, we need some basic
support for them in the old VM. The implementation here is quick and
dirty. This is not an issue since this code will be deleted soon.
This commit is contained in:
Leonardo de Moura 2018-11-01 14:07:59 -07:00
parent ae30b16f0d
commit f22bdec775
4 changed files with 158 additions and 43 deletions

View file

@ -23,6 +23,18 @@ class emit_bytecode_fn {
environment m_env;
name_generator m_ngen;
buffer<vm_instr> & m_code;
unsigned m_arity;
struct vdecl {
bool m_is_jp;
unsigned m_idx;
unsigned m_pc; // relevant only if m_is_jp == true
vdecl():m_is_jp(false), m_idx(0) {}
vdecl(unsigned idx):m_is_jp(false), m_idx(idx) {}
vdecl(unsigned bp, unsigned pc):m_is_jp(true), m_idx(bp), m_pc(pc) {}
};
typedef name_map<vdecl> vdecls;
void emit(vm_instr const & i) {
m_code.push_back(i);
@ -32,17 +44,13 @@ class emit_bytecode_fn {
return m_code.size();
}
expr mk_local(name const & n) {
return ::lean::mk_local(n, mk_enf_neutral());
}
void compile_args(unsigned nargs, expr const * args, unsigned bpz, name_map<unsigned> const & m) {
void compile_args(unsigned nargs, expr const * args, unsigned bpz, vdecls const & m) {
for (unsigned i = 0; i < nargs; i++, bpz++) {
compile(args[i], bpz, m);
}
}
void compile_rev_args(unsigned nargs, expr const * args, unsigned bpz, name_map<unsigned> const & m) {
void compile_rev_args(unsigned nargs, expr const * args, unsigned bpz, vdecls const & m) {
unsigned i = nargs;
while (i > 0) {
--i;
@ -51,7 +59,7 @@ class emit_bytecode_fn {
}
}
void compile_global(vm_decl const & decl, unsigned nargs, expr const * args, unsigned bpz, name_map<unsigned> const & m) {
void compile_global(vm_decl const & decl, unsigned nargs, expr const * args, unsigned bpz, vdecls const & m) {
compile_rev_args(nargs, args, bpz, m);
if (decl.get_arity() <= nargs) {
if (decl.is_builtin())
@ -93,18 +101,19 @@ class emit_bytecode_fn {
if (ssz != 0) throw_no_unboxed_support();
emit(mk_sconstructor_instr(cidx));
} else if (optional<vm_decl> decl = get_vm_decl(m_env, n)) {
compile_global(*decl, 0, nullptr, 0, name_map<unsigned>());
compile_global(*decl, 0, nullptr, 0, vdecls());
} else {
throw_unknown_constant(n);
}
}
void compile_local(expr const & e, name_map<unsigned> const & m) {
unsigned idx = *m.find(local_name(e));
emit(mk_push_instr(idx));
void compile_fvar(expr const & e, vdecls const & m) {
vdecl d = *m.find(local_name(e));
lean_assert(!d.m_is_jp);
emit(mk_push_instr(d.m_idx));
}
void compile_cases_on(expr const & e, unsigned bpz, name_map<unsigned> const & m) {
void compile_cases_on(expr const & e, unsigned bpz, vdecls const & m) {
buffer<expr> args;
expr fn = get_app_args(e, args);
lean_assert(args.size() >= 3); /* major + at least 2 minor premises */
@ -124,12 +133,12 @@ class emit_bytecode_fn {
cases_args[i - 1] = next_pc();
expr b = args[i];
buffer<expr> locals;
name_map<unsigned> new_m = m;
vdecls new_m = m;
unsigned new_bpz = bpz;
while (is_lambda(b)) {
name n = m_ngen.next();
new_m.insert(n, new_bpz);
locals.push_back(mk_local(n));
new_m.insert(n, vdecl(new_bpz));
locals.push_back(mk_fvar(n));
new_bpz++;
b = binding_body(b);
}
@ -155,7 +164,7 @@ class emit_bytecode_fn {
}
}
void compile_cnstr(expr const & e, unsigned bpz, name_map<unsigned> const & m) {
void compile_cnstr(expr const & e, unsigned bpz, vdecls const & m) {
buffer<expr> args;
expr const & fn = get_app_args(e, args);
lean_assert(is_llnf_cnstr(fn));
@ -166,7 +175,7 @@ class emit_bytecode_fn {
emit(mk_constructor_instr(cidx, get_app_num_args(e)));
}
void compile_reuse(expr const & e, unsigned bpz, name_map<unsigned> const & m) {
void compile_reuse(expr const & e, unsigned bpz, vdecls const & m) {
buffer<expr> args;
expr const & fn = get_app_args(e, args);
lean_assert(is_llnf_reuse(fn));
@ -177,16 +186,15 @@ class emit_bytecode_fn {
emit(mk_reuse_instr(cidx, get_app_num_args(e) - 1));
}
void compile_external(name const & n, buffer<expr> & args, unsigned bpz, name_map<unsigned> const & m) {
void compile_external(name const & n, buffer<expr> & args, unsigned bpz, vdecls const & m) {
// Not sure if this is the best approach, trying to lazy load the required
// dynamic libraries.
std::cout << "external compile" << n << std::endl;
optional<vm_decl> decl = get_vm_decl(m_env, n);
lean_assert(decl);
compile_global(*decl, args.size(), args.data(), bpz, m);
}
void compile_fn_call(expr const & e, unsigned bpz, name_map<unsigned> const & m) {
void compile_fn_call(expr const & e, unsigned bpz, vdecls const & m) {
buffer<expr> args;
expr fn = get_app_args(e, args);
if (!is_constant(fn)) {
@ -209,7 +217,7 @@ class emit_bytecode_fn {
}
}
void compile_proj(expr const & e, unsigned bpz, name_map<unsigned> const & m) {
void compile_proj(expr const & e, unsigned bpz, vdecls const & m) {
expr const & p = app_fn(e);
lean_assert(is_llnf_proj(p));
expr const & a = app_arg(e);
@ -219,7 +227,7 @@ class emit_bytecode_fn {
emit(mk_proj_instr(idx));
}
void compile_reset(expr const & e, unsigned bpz, name_map<unsigned> const & m) {
void compile_reset(expr const & e, unsigned bpz, vdecls const & m) {
expr fn = app_fn(e);
expr s = app_arg(e);
unsigned n;
@ -228,7 +236,16 @@ class emit_bytecode_fn {
emit(mk_reset_instr(n));
}
void compile_app(expr const & e, unsigned bpz, name_map<unsigned> const & m) {
optional<vdecl> is_jp(vdecls const & m, expr const & fn) {
if (!is_fvar(fn)) return optional<vdecl>();
vdecl const * d = m.find(fvar_name(fn));
if (d && d->m_is_jp)
return optional<vdecl>(*d);
else
return optional<vdecl>();
}
void compile_app(expr const & e, unsigned bpz, vdecls const & m) {
expr const & fn = get_app_fn(e);
if (is_cases_on_app(m_env, fn)) {
compile_cases_on(e, bpz, m);
@ -242,6 +259,11 @@ class emit_bytecode_fn {
compile_reset(e, bpz, m);
} else if (is_sorry(e)) {
compile_global(*get_vm_decl(m_env, "sorry"), 0, nullptr, bpz, m);
} else if (optional<vdecl> d = is_jp(m, fn)) {
buffer<expr> args;
get_app_args(e, args);
compile_rev_args(args.size(), args.data(), bpz, m);
emit(mk_invoke_jp_instr(d->m_pc, d->m_idx, args.size()));
} else {
compile_fn_call(e, bpz, m);
}
@ -261,22 +283,56 @@ class emit_bytecode_fn {
}
}
void compile_let(expr e, unsigned bpz, name_map<unsigned> const & m) {
void emit_jp(expr e, unsigned bpz, vdecls m) {
unsigned init_bpz = bpz;
unsigned arity = get_arity(e);
lean_assert(arity > 0);
buffer<expr> locals;
unsigned i = arity;
while (is_lambda(e)) {
name n = m_ngen.next();
i--;
m.insert(n, vdecl(init_bpz+i));
locals.push_back(mk_fvar(n));
bpz++;
e = binding_body(e);
}
e = instantiate_rev(e, locals.size(), locals.data());
compile(e, bpz, m);
lean_assert(bpz >= m_arity);
unsigned to_drop = bpz - m_arity;
if (to_drop > 0)
emit(mk_drop_instr(to_drop));
emit(mk_ret_instr());
}
void compile_let(expr e, unsigned bpz, vdecls const & m) {
unsigned counter = 0;
buffer<expr> locals;
name_map<unsigned> new_m = m;
vdecls new_m = m;
while (is_let(e)) {
counter++;
compile(instantiate_rev(let_value(e), locals.size(), locals.data()), bpz, new_m);
name n = m_ngen.next();
new_m.insert(n, bpz);
locals.push_back(mk_local(n));
bpz++;
if (is_join_point_name(let_name(e))) {
unsigned goto_pc = next_pc();
emit(mk_goto_instr(0)); // fix later
name n = m_ngen.next();
new_m.insert(n, vdecl(bpz, next_pc()));
emit_jp(instantiate_rev(let_value(e), locals.size(), locals.data()), bpz, new_m);
locals.push_back(mk_fvar(n));
unsigned cont_pc = next_pc();
m_code[goto_pc].set_goto_pc(cont_pc);
} else {
counter++;
compile(instantiate_rev(let_value(e), locals.size(), locals.data()), bpz, new_m);
name n = m_ngen.next();
new_m.insert(n, vdecl(bpz));
locals.push_back(mk_fvar(n));
bpz++;
}
e = let_body(e);
}
lean_assert(counter > 0);
compile(instantiate_rev(e, locals.size(), locals.data()), bpz, new_m);
emit(mk_drop_instr(counter));
if (counter > 0)
emit(mk_drop_instr(counter));
}
void compile_lit(expr const & e) {
@ -290,7 +346,7 @@ class emit_bytecode_fn {
}
}
void compile(expr const & e, unsigned bpz, name_map<unsigned> const & m) {
void compile(expr const & e, unsigned bpz, vdecls const & m) {
switch (e.kind()) {
case expr_kind::BVar: lean_unreachable();
case expr_kind::Sort: lean_unreachable();
@ -300,7 +356,7 @@ class emit_bytecode_fn {
case expr_kind::Proj: lean_unreachable();
case expr_kind::MData: compile(mdata_expr(e), bpz, m); break;
case expr_kind::Const: compile_constant(e); break;
case expr_kind::FVar: compile_local(e, m); break;
case expr_kind::FVar: compile_fvar(e, m); break;
case expr_kind::App: compile_app(e, bpz, m); break;
case expr_kind::Let: compile_let(e, bpz, m); break;
case expr_kind::Lit: compile_lit(e); break;
@ -323,15 +379,15 @@ public:
pair<unsigned, list<vm_local_info>> operator()(expr e) {
buffer<expr> locals;
unsigned bpz = 0;
unsigned arity = get_arity(e);
unsigned i = arity;
name_map<unsigned> m;
m_arity = get_arity(e);
unsigned i = m_arity;
vdecls m;
list<vm_local_info> args_info;
while (is_lambda(e)) {
name n = m_ngen.next();
i--;
m.insert(n, i);
locals.push_back(mk_local(n));
m.insert(n, vdecl(i));
locals.push_back(mk_fvar(n));
bpz++;
args_info = cons(vm_local_info(binding_name(e), to_type_info(binding_domain(e))), args_info);
e = binding_body(e);
@ -339,7 +395,7 @@ public:
e = instantiate_rev(e, locals.size(), locals.data());
compile(e, bpz, m);
emit(mk_ret_instr());
return mk_pair(arity, args_info);
return mk_pair(m_arity, args_info);
}
};

View file

@ -105,6 +105,9 @@ class live_vars_fn {
case opcode::Reset: case opcode::Reuse:
s = collect(pc+1);
break;
case opcode::InvokeJP:
s = collect(instr.get_jp_pc());
break;
case opcode::Push: case opcode::Move:
s = collect(pc+1);
s.insert(instr.get_idx());

View file

@ -555,6 +555,8 @@ void vm_instr::display(std::ostream & out) const {
out << "cfun ";
display_fn(out, m_fn_idx);
break;
case opcode::InvokeJP:
out << "jp " << m_jp_pc << " " << m_jp_bp << " " << m_jp_arity; break;
case opcode::Closure:
out << "closure ";
display_fn(out, m_fn_idx);
@ -736,6 +738,14 @@ vm_instr mk_invoke_cfun_instr(unsigned fn_idx) {
return r;
}
vm_instr mk_invoke_jp_instr(unsigned pc, unsigned bp, unsigned arity) {
vm_instr r(opcode::InvokeJP);
r.m_jp_pc = pc;
r.m_jp_bp = bp;
r.m_jp_arity = arity;
return r;
}
vm_instr mk_closure_instr(unsigned fn_idx, unsigned n) {
vm_instr r(opcode::Closure);
r.m_fn_idx = fn_idx;
@ -770,6 +780,11 @@ void vm_instr::copy_args(vm_instr const & i) {
case opcode::InvokeGlobal: case opcode::InvokeBuiltin: case opcode::InvokeCFun:
m_fn_idx = i.m_fn_idx;
break;
case opcode::InvokeJP:
m_jp_pc = i.m_jp_pc;
m_jp_bp = i.m_jp_bp;
m_jp_arity = i.m_jp_arity;
break;
case opcode::Closure:
m_fn_idx = i.m_fn_idx;
m_nargs = i.m_nargs;
@ -884,6 +899,9 @@ void vm_instr::serialize(serializer & s, std::function<name(unsigned)> const & i
case opcode::InvokeGlobal: case opcode::InvokeBuiltin: case opcode::InvokeCFun:
s << idx2name(m_fn_idx);
break;
case opcode::InvokeJP:
s << m_jp_pc << m_jp_bp << m_jp_arity;
break;
case opcode::Closure:
s << idx2name(m_fn_idx) << m_nargs;
break;
@ -954,6 +972,10 @@ static vm_instr read_vm_instr(deserializer & d) {
case opcode::Closure:
idx = read_fn_idx(d);
return mk_closure_instr(idx, d.read_unsigned());
case opcode::InvokeJP:
pc = d.read_unsigned();
idx = d.read_unsigned();
return mk_invoke_jp_instr(pc, idx, d.read_unsigned());
case opcode::Push:
return mk_push_instr(d.read_unsigned());
case opcode::Move:
@ -2781,7 +2803,8 @@ void vm_state::run() {
lean_trace(name({"vm", "run"}),
tout() << m_decl_vector[m_fn_idx].get_name() << " @ " << m_pc << ": ";
instr.display(tout().get_stream());
tout() << "\n";);
tout() << "\n";
display_stack(tout().get_stream()););
});
switch (instr.op()) {
case opcode::Push:
@ -3186,6 +3209,28 @@ void vm_state::run() {
invoke_global(decl);
goto main_loop;
}
case opcode::InvokeJP: {
/* Join point call: jmp pc bp n
stack before after
...
bp=> ... a_1
... ...
a_1 a_n
...
a_n
*/
unsigned jp_bp = instr.get_jp_bp();
unsigned jp_arity = instr.get_jp_arity();
unsigned i = m_bp + jp_bp;
unsigned j = m_stack.size() - jp_arity;
for (; j < m_stack.size(); j++, i++) {
std::swap(m_stack[i], m_stack[j]);
}
m_pc = instr.get_jp_pc();
m_stack.resize(m_bp + jp_bp + jp_arity);
goto main_loop;
}
case opcode::InvokeBuiltin: {
check_interrupted();
check_heartbeat();

View file

@ -315,7 +315,7 @@ enum class opcode {
Cases2, CasesN, Proj,
Apply, InvokeGlobal, InvokeBuiltin, InvokeCFun,
Closure, Unreachable, Expr, LocalInfo,
Reset, Reuse
Reset, Reuse, InvokeJP
};
/** \brief VM instructions */
@ -326,6 +326,11 @@ class vm_instr {
unsigned m_fn_idx; /* InvokeGlobal, InvokeBuiltin, InvokeCFun and Closure. */
unsigned m_nargs; /* Closure */
};
struct { /* InvokeJP */
unsigned m_jp_pc;
unsigned m_jp_bp;
unsigned m_jp_arity;
};
/* Push, Move, Proj */
unsigned m_idx;
/* Drop, Reset */
@ -374,6 +379,7 @@ class vm_instr {
friend vm_instr mk_invoke_global_instr(unsigned fn_idx);
friend vm_instr mk_invoke_cfun_instr(unsigned fn_idx);
friend vm_instr mk_invoke_builtin_instr(unsigned fn_idx);
friend vm_instr mk_invoke_jp_instr(unsigned pc, unsigned bp, unsigned arity);
friend vm_instr mk_closure_instr(unsigned fn_idx, unsigned n);
friend vm_instr mk_expr_instr(expr const &e);
friend vm_instr mk_local_info_instr(unsigned idx, name const & n, optional<expr> const & e);
@ -399,6 +405,10 @@ public:
return m_fn_idx;
}
unsigned get_jp_pc() const { lean_assert(m_op == opcode::InvokeJP); return m_jp_pc; }
unsigned get_jp_bp() const { lean_assert(m_op == opcode::InvokeJP); return m_jp_bp; }
unsigned get_jp_arity() const { lean_assert(m_op == opcode::InvokeJP); return m_jp_arity; }
unsigned get_nargs() const {
lean_assert(m_op == opcode::Closure);
return m_nargs;
@ -515,6 +525,7 @@ vm_instr mk_apply_instr();
vm_instr mk_invoke_global_instr(unsigned fn_idx);
vm_instr mk_invoke_cfun_instr(unsigned fn_idx);
vm_instr mk_invoke_builtin_instr(unsigned fn_idx);
vm_instr mk_invoke_jp_instr(unsigned pc, unsigned bp, unsigned arity);
vm_instr mk_closure_instr(unsigned fn_idx, unsigned n);
vm_instr mk_expr_instr(expr const &e);
vm_instr mk_local_info_instr(unsigned idx, name const & n, optional<expr> const & e);