From f22bdec775b9369df30d3728908caebfa5b9c8b9 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Thu, 1 Nov 2018 14:07:59 -0700 Subject: [PATCH] feat(library/vm): add support for join-points in the old VM Motivation: to make progress in the new compiler stack, we have to preserve join points during lambda lifting. Right now, they are lifted as regular lambdas. So, to keep them, we need some basic support for them in the old VM. The implementation here is quick and dirty. This is not an issue since this code will be deleted soon. --- src/library/compiler/emit_bytecode.cpp | 138 +++++++++++++++++-------- src/library/vm/optimize.cpp | 3 + src/library/vm/vm.cpp | 47 ++++++++- src/library/vm/vm.h | 13 ++- 4 files changed, 158 insertions(+), 43 deletions(-) diff --git a/src/library/compiler/emit_bytecode.cpp b/src/library/compiler/emit_bytecode.cpp index 3c244fa0d4..8e5c06062f 100644 --- a/src/library/compiler/emit_bytecode.cpp +++ b/src/library/compiler/emit_bytecode.cpp @@ -23,6 +23,18 @@ class emit_bytecode_fn { environment m_env; name_generator m_ngen; buffer & m_code; + unsigned m_arity; + + struct vdecl { + bool m_is_jp; + unsigned m_idx; + unsigned m_pc; // relevant only if m_is_jp == true + vdecl():m_is_jp(false), m_idx(0) {} + vdecl(unsigned idx):m_is_jp(false), m_idx(idx) {} + vdecl(unsigned bp, unsigned pc):m_is_jp(true), m_idx(bp), m_pc(pc) {} + }; + + typedef name_map vdecls; void emit(vm_instr const & i) { m_code.push_back(i); @@ -32,17 +44,13 @@ class emit_bytecode_fn { return m_code.size(); } - expr mk_local(name const & n) { - return ::lean::mk_local(n, mk_enf_neutral()); - } - - void compile_args(unsigned nargs, expr const * args, unsigned bpz, name_map const & m) { + void compile_args(unsigned nargs, expr const * args, unsigned bpz, vdecls const & m) { for (unsigned i = 0; i < nargs; i++, bpz++) { compile(args[i], bpz, m); } } - void compile_rev_args(unsigned nargs, expr const * args, unsigned bpz, name_map const & m) { + void compile_rev_args(unsigned nargs, expr const * args, unsigned bpz, vdecls const & m) { unsigned i = nargs; while (i > 0) { --i; @@ -51,7 +59,7 @@ class emit_bytecode_fn { } } - void compile_global(vm_decl const & decl, unsigned nargs, expr const * args, unsigned bpz, name_map const & m) { + void compile_global(vm_decl const & decl, unsigned nargs, expr const * args, unsigned bpz, vdecls const & m) { compile_rev_args(nargs, args, bpz, m); if (decl.get_arity() <= nargs) { if (decl.is_builtin()) @@ -93,18 +101,19 @@ class emit_bytecode_fn { if (ssz != 0) throw_no_unboxed_support(); emit(mk_sconstructor_instr(cidx)); } else if (optional decl = get_vm_decl(m_env, n)) { - compile_global(*decl, 0, nullptr, 0, name_map()); + compile_global(*decl, 0, nullptr, 0, vdecls()); } else { throw_unknown_constant(n); } } - void compile_local(expr const & e, name_map const & m) { - unsigned idx = *m.find(local_name(e)); - emit(mk_push_instr(idx)); + void compile_fvar(expr const & e, vdecls const & m) { + vdecl d = *m.find(local_name(e)); + lean_assert(!d.m_is_jp); + emit(mk_push_instr(d.m_idx)); } - void compile_cases_on(expr const & e, unsigned bpz, name_map const & m) { + void compile_cases_on(expr const & e, unsigned bpz, vdecls const & m) { buffer args; expr fn = get_app_args(e, args); lean_assert(args.size() >= 3); /* major + at least 2 minor premises */ @@ -124,12 +133,12 @@ class emit_bytecode_fn { cases_args[i - 1] = next_pc(); expr b = args[i]; buffer locals; - name_map new_m = m; + vdecls new_m = m; unsigned new_bpz = bpz; while (is_lambda(b)) { name n = m_ngen.next(); - new_m.insert(n, new_bpz); - locals.push_back(mk_local(n)); + new_m.insert(n, vdecl(new_bpz)); + locals.push_back(mk_fvar(n)); new_bpz++; b = binding_body(b); } @@ -155,7 +164,7 @@ class emit_bytecode_fn { } } - void compile_cnstr(expr const & e, unsigned bpz, name_map const & m) { + void compile_cnstr(expr const & e, unsigned bpz, vdecls const & m) { buffer args; expr const & fn = get_app_args(e, args); lean_assert(is_llnf_cnstr(fn)); @@ -166,7 +175,7 @@ class emit_bytecode_fn { emit(mk_constructor_instr(cidx, get_app_num_args(e))); } - void compile_reuse(expr const & e, unsigned bpz, name_map const & m) { + void compile_reuse(expr const & e, unsigned bpz, vdecls const & m) { buffer args; expr const & fn = get_app_args(e, args); lean_assert(is_llnf_reuse(fn)); @@ -177,16 +186,15 @@ class emit_bytecode_fn { emit(mk_reuse_instr(cidx, get_app_num_args(e) - 1)); } - void compile_external(name const & n, buffer & args, unsigned bpz, name_map const & m) { + void compile_external(name const & n, buffer & args, unsigned bpz, vdecls const & m) { // Not sure if this is the best approach, trying to lazy load the required // dynamic libraries. - std::cout << "external compile" << n << std::endl; optional decl = get_vm_decl(m_env, n); lean_assert(decl); compile_global(*decl, args.size(), args.data(), bpz, m); } - void compile_fn_call(expr const & e, unsigned bpz, name_map const & m) { + void compile_fn_call(expr const & e, unsigned bpz, vdecls const & m) { buffer args; expr fn = get_app_args(e, args); if (!is_constant(fn)) { @@ -209,7 +217,7 @@ class emit_bytecode_fn { } } - void compile_proj(expr const & e, unsigned bpz, name_map const & m) { + void compile_proj(expr const & e, unsigned bpz, vdecls const & m) { expr const & p = app_fn(e); lean_assert(is_llnf_proj(p)); expr const & a = app_arg(e); @@ -219,7 +227,7 @@ class emit_bytecode_fn { emit(mk_proj_instr(idx)); } - void compile_reset(expr const & e, unsigned bpz, name_map const & m) { + void compile_reset(expr const & e, unsigned bpz, vdecls const & m) { expr fn = app_fn(e); expr s = app_arg(e); unsigned n; @@ -228,7 +236,16 @@ class emit_bytecode_fn { emit(mk_reset_instr(n)); } - void compile_app(expr const & e, unsigned bpz, name_map const & m) { + optional is_jp(vdecls const & m, expr const & fn) { + if (!is_fvar(fn)) return optional(); + vdecl const * d = m.find(fvar_name(fn)); + if (d && d->m_is_jp) + return optional(*d); + else + return optional(); + } + + void compile_app(expr const & e, unsigned bpz, vdecls const & m) { expr const & fn = get_app_fn(e); if (is_cases_on_app(m_env, fn)) { compile_cases_on(e, bpz, m); @@ -242,6 +259,11 @@ class emit_bytecode_fn { compile_reset(e, bpz, m); } else if (is_sorry(e)) { compile_global(*get_vm_decl(m_env, "sorry"), 0, nullptr, bpz, m); + } else if (optional d = is_jp(m, fn)) { + buffer args; + get_app_args(e, args); + compile_rev_args(args.size(), args.data(), bpz, m); + emit(mk_invoke_jp_instr(d->m_pc, d->m_idx, args.size())); } else { compile_fn_call(e, bpz, m); } @@ -261,22 +283,56 @@ class emit_bytecode_fn { } } - void compile_let(expr e, unsigned bpz, name_map const & m) { + void emit_jp(expr e, unsigned bpz, vdecls m) { + unsigned init_bpz = bpz; + unsigned arity = get_arity(e); + lean_assert(arity > 0); + buffer locals; + unsigned i = arity; + while (is_lambda(e)) { + name n = m_ngen.next(); + i--; + m.insert(n, vdecl(init_bpz+i)); + locals.push_back(mk_fvar(n)); + bpz++; + e = binding_body(e); + } + e = instantiate_rev(e, locals.size(), locals.data()); + compile(e, bpz, m); + lean_assert(bpz >= m_arity); + unsigned to_drop = bpz - m_arity; + if (to_drop > 0) + emit(mk_drop_instr(to_drop)); + emit(mk_ret_instr()); + } + + void compile_let(expr e, unsigned bpz, vdecls const & m) { unsigned counter = 0; buffer locals; - name_map new_m = m; + vdecls new_m = m; while (is_let(e)) { - counter++; - compile(instantiate_rev(let_value(e), locals.size(), locals.data()), bpz, new_m); - name n = m_ngen.next(); - new_m.insert(n, bpz); - locals.push_back(mk_local(n)); - bpz++; + if (is_join_point_name(let_name(e))) { + unsigned goto_pc = next_pc(); + emit(mk_goto_instr(0)); // fix later + name n = m_ngen.next(); + new_m.insert(n, vdecl(bpz, next_pc())); + emit_jp(instantiate_rev(let_value(e), locals.size(), locals.data()), bpz, new_m); + locals.push_back(mk_fvar(n)); + unsigned cont_pc = next_pc(); + m_code[goto_pc].set_goto_pc(cont_pc); + } else { + counter++; + compile(instantiate_rev(let_value(e), locals.size(), locals.data()), bpz, new_m); + name n = m_ngen.next(); + new_m.insert(n, vdecl(bpz)); + locals.push_back(mk_fvar(n)); + bpz++; + } e = let_body(e); } - lean_assert(counter > 0); compile(instantiate_rev(e, locals.size(), locals.data()), bpz, new_m); - emit(mk_drop_instr(counter)); + if (counter > 0) + emit(mk_drop_instr(counter)); } void compile_lit(expr const & e) { @@ -290,7 +346,7 @@ class emit_bytecode_fn { } } - void compile(expr const & e, unsigned bpz, name_map const & m) { + void compile(expr const & e, unsigned bpz, vdecls const & m) { switch (e.kind()) { case expr_kind::BVar: lean_unreachable(); case expr_kind::Sort: lean_unreachable(); @@ -300,7 +356,7 @@ class emit_bytecode_fn { case expr_kind::Proj: lean_unreachable(); case expr_kind::MData: compile(mdata_expr(e), bpz, m); break; case expr_kind::Const: compile_constant(e); break; - case expr_kind::FVar: compile_local(e, m); break; + case expr_kind::FVar: compile_fvar(e, m); break; case expr_kind::App: compile_app(e, bpz, m); break; case expr_kind::Let: compile_let(e, bpz, m); break; case expr_kind::Lit: compile_lit(e); break; @@ -323,15 +379,15 @@ public: pair> operator()(expr e) { buffer locals; unsigned bpz = 0; - unsigned arity = get_arity(e); - unsigned i = arity; - name_map m; + m_arity = get_arity(e); + unsigned i = m_arity; + vdecls m; list args_info; while (is_lambda(e)) { name n = m_ngen.next(); i--; - m.insert(n, i); - locals.push_back(mk_local(n)); + m.insert(n, vdecl(i)); + locals.push_back(mk_fvar(n)); bpz++; args_info = cons(vm_local_info(binding_name(e), to_type_info(binding_domain(e))), args_info); e = binding_body(e); @@ -339,7 +395,7 @@ public: e = instantiate_rev(e, locals.size(), locals.data()); compile(e, bpz, m); emit(mk_ret_instr()); - return mk_pair(arity, args_info); + return mk_pair(m_arity, args_info); } }; diff --git a/src/library/vm/optimize.cpp b/src/library/vm/optimize.cpp index e30151e1da..0d3acab490 100644 --- a/src/library/vm/optimize.cpp +++ b/src/library/vm/optimize.cpp @@ -105,6 +105,9 @@ class live_vars_fn { case opcode::Reset: case opcode::Reuse: s = collect(pc+1); break; + case opcode::InvokeJP: + s = collect(instr.get_jp_pc()); + break; case opcode::Push: case opcode::Move: s = collect(pc+1); s.insert(instr.get_idx()); diff --git a/src/library/vm/vm.cpp b/src/library/vm/vm.cpp index 417f3e8e67..adb272071d 100644 --- a/src/library/vm/vm.cpp +++ b/src/library/vm/vm.cpp @@ -555,6 +555,8 @@ void vm_instr::display(std::ostream & out) const { out << "cfun "; display_fn(out, m_fn_idx); break; + case opcode::InvokeJP: + out << "jp " << m_jp_pc << " " << m_jp_bp << " " << m_jp_arity; break; case opcode::Closure: out << "closure "; display_fn(out, m_fn_idx); @@ -736,6 +738,14 @@ vm_instr mk_invoke_cfun_instr(unsigned fn_idx) { return r; } +vm_instr mk_invoke_jp_instr(unsigned pc, unsigned bp, unsigned arity) { + vm_instr r(opcode::InvokeJP); + r.m_jp_pc = pc; + r.m_jp_bp = bp; + r.m_jp_arity = arity; + return r; +} + vm_instr mk_closure_instr(unsigned fn_idx, unsigned n) { vm_instr r(opcode::Closure); r.m_fn_idx = fn_idx; @@ -770,6 +780,11 @@ void vm_instr::copy_args(vm_instr const & i) { case opcode::InvokeGlobal: case opcode::InvokeBuiltin: case opcode::InvokeCFun: m_fn_idx = i.m_fn_idx; break; + case opcode::InvokeJP: + m_jp_pc = i.m_jp_pc; + m_jp_bp = i.m_jp_bp; + m_jp_arity = i.m_jp_arity; + break; case opcode::Closure: m_fn_idx = i.m_fn_idx; m_nargs = i.m_nargs; @@ -884,6 +899,9 @@ void vm_instr::serialize(serializer & s, std::function const & i case opcode::InvokeGlobal: case opcode::InvokeBuiltin: case opcode::InvokeCFun: s << idx2name(m_fn_idx); break; + case opcode::InvokeJP: + s << m_jp_pc << m_jp_bp << m_jp_arity; + break; case opcode::Closure: s << idx2name(m_fn_idx) << m_nargs; break; @@ -954,6 +972,10 @@ static vm_instr read_vm_instr(deserializer & d) { case opcode::Closure: idx = read_fn_idx(d); return mk_closure_instr(idx, d.read_unsigned()); + case opcode::InvokeJP: + pc = d.read_unsigned(); + idx = d.read_unsigned(); + return mk_invoke_jp_instr(pc, idx, d.read_unsigned()); case opcode::Push: return mk_push_instr(d.read_unsigned()); case opcode::Move: @@ -2781,7 +2803,8 @@ void vm_state::run() { lean_trace(name({"vm", "run"}), tout() << m_decl_vector[m_fn_idx].get_name() << " @ " << m_pc << ": "; instr.display(tout().get_stream()); - tout() << "\n";); + tout() << "\n"; + display_stack(tout().get_stream());); }); switch (instr.op()) { case opcode::Push: @@ -3186,6 +3209,28 @@ void vm_state::run() { invoke_global(decl); goto main_loop; } + case opcode::InvokeJP: { + /* Join point call: jmp pc bp n + + stack before after + ... + bp=> ... a_1 + ... ... + a_1 a_n + ... + a_n + */ + unsigned jp_bp = instr.get_jp_bp(); + unsigned jp_arity = instr.get_jp_arity(); + unsigned i = m_bp + jp_bp; + unsigned j = m_stack.size() - jp_arity; + for (; j < m_stack.size(); j++, i++) { + std::swap(m_stack[i], m_stack[j]); + } + m_pc = instr.get_jp_pc(); + m_stack.resize(m_bp + jp_bp + jp_arity); + goto main_loop; + } case opcode::InvokeBuiltin: { check_interrupted(); check_heartbeat(); diff --git a/src/library/vm/vm.h b/src/library/vm/vm.h index e718a57196..d01a05965d 100644 --- a/src/library/vm/vm.h +++ b/src/library/vm/vm.h @@ -315,7 +315,7 @@ enum class opcode { Cases2, CasesN, Proj, Apply, InvokeGlobal, InvokeBuiltin, InvokeCFun, Closure, Unreachable, Expr, LocalInfo, - Reset, Reuse + Reset, Reuse, InvokeJP }; /** \brief VM instructions */ @@ -326,6 +326,11 @@ class vm_instr { unsigned m_fn_idx; /* InvokeGlobal, InvokeBuiltin, InvokeCFun and Closure. */ unsigned m_nargs; /* Closure */ }; + struct { /* InvokeJP */ + unsigned m_jp_pc; + unsigned m_jp_bp; + unsigned m_jp_arity; + }; /* Push, Move, Proj */ unsigned m_idx; /* Drop, Reset */ @@ -374,6 +379,7 @@ class vm_instr { friend vm_instr mk_invoke_global_instr(unsigned fn_idx); friend vm_instr mk_invoke_cfun_instr(unsigned fn_idx); friend vm_instr mk_invoke_builtin_instr(unsigned fn_idx); + friend vm_instr mk_invoke_jp_instr(unsigned pc, unsigned bp, unsigned arity); friend vm_instr mk_closure_instr(unsigned fn_idx, unsigned n); friend vm_instr mk_expr_instr(expr const &e); friend vm_instr mk_local_info_instr(unsigned idx, name const & n, optional const & e); @@ -399,6 +405,10 @@ public: return m_fn_idx; } + unsigned get_jp_pc() const { lean_assert(m_op == opcode::InvokeJP); return m_jp_pc; } + unsigned get_jp_bp() const { lean_assert(m_op == opcode::InvokeJP); return m_jp_bp; } + unsigned get_jp_arity() const { lean_assert(m_op == opcode::InvokeJP); return m_jp_arity; } + unsigned get_nargs() const { lean_assert(m_op == opcode::Closure); return m_nargs; @@ -515,6 +525,7 @@ vm_instr mk_apply_instr(); vm_instr mk_invoke_global_instr(unsigned fn_idx); vm_instr mk_invoke_cfun_instr(unsigned fn_idx); vm_instr mk_invoke_builtin_instr(unsigned fn_idx); +vm_instr mk_invoke_jp_instr(unsigned pc, unsigned bp, unsigned arity); vm_instr mk_closure_instr(unsigned fn_idx, unsigned n); vm_instr mk_expr_instr(expr const &e); vm_instr mk_local_info_instr(unsigned idx, name const & n, optional const & e);