diff --git a/src/library/compiler/emit_bytecode.cpp b/src/library/compiler/emit_bytecode.cpp index 3c244fa0d4..8e5c06062f 100644 --- a/src/library/compiler/emit_bytecode.cpp +++ b/src/library/compiler/emit_bytecode.cpp @@ -23,6 +23,18 @@ class emit_bytecode_fn { environment m_env; name_generator m_ngen; buffer & m_code; + unsigned m_arity; + + struct vdecl { + bool m_is_jp; + unsigned m_idx; + unsigned m_pc; // relevant only if m_is_jp == true + vdecl():m_is_jp(false), m_idx(0) {} + vdecl(unsigned idx):m_is_jp(false), m_idx(idx) {} + vdecl(unsigned bp, unsigned pc):m_is_jp(true), m_idx(bp), m_pc(pc) {} + }; + + typedef name_map vdecls; void emit(vm_instr const & i) { m_code.push_back(i); @@ -32,17 +44,13 @@ class emit_bytecode_fn { return m_code.size(); } - expr mk_local(name const & n) { - return ::lean::mk_local(n, mk_enf_neutral()); - } - - void compile_args(unsigned nargs, expr const * args, unsigned bpz, name_map const & m) { + void compile_args(unsigned nargs, expr const * args, unsigned bpz, vdecls const & m) { for (unsigned i = 0; i < nargs; i++, bpz++) { compile(args[i], bpz, m); } } - void compile_rev_args(unsigned nargs, expr const * args, unsigned bpz, name_map const & m) { + void compile_rev_args(unsigned nargs, expr const * args, unsigned bpz, vdecls const & m) { unsigned i = nargs; while (i > 0) { --i; @@ -51,7 +59,7 @@ class emit_bytecode_fn { } } - void compile_global(vm_decl const & decl, unsigned nargs, expr const * args, unsigned bpz, name_map const & m) { + void compile_global(vm_decl const & decl, unsigned nargs, expr const * args, unsigned bpz, vdecls const & m) { compile_rev_args(nargs, args, bpz, m); if (decl.get_arity() <= nargs) { if (decl.is_builtin()) @@ -93,18 +101,19 @@ class emit_bytecode_fn { if (ssz != 0) throw_no_unboxed_support(); emit(mk_sconstructor_instr(cidx)); } else if (optional decl = get_vm_decl(m_env, n)) { - compile_global(*decl, 0, nullptr, 0, name_map()); + compile_global(*decl, 0, nullptr, 0, vdecls()); } else { throw_unknown_constant(n); } } - void compile_local(expr const & e, name_map const & m) { - unsigned idx = *m.find(local_name(e)); - emit(mk_push_instr(idx)); + void compile_fvar(expr const & e, vdecls const & m) { + vdecl d = *m.find(local_name(e)); + lean_assert(!d.m_is_jp); + emit(mk_push_instr(d.m_idx)); } - void compile_cases_on(expr const & e, unsigned bpz, name_map const & m) { + void compile_cases_on(expr const & e, unsigned bpz, vdecls const & m) { buffer args; expr fn = get_app_args(e, args); lean_assert(args.size() >= 3); /* major + at least 2 minor premises */ @@ -124,12 +133,12 @@ class emit_bytecode_fn { cases_args[i - 1] = next_pc(); expr b = args[i]; buffer locals; - name_map new_m = m; + vdecls new_m = m; unsigned new_bpz = bpz; while (is_lambda(b)) { name n = m_ngen.next(); - new_m.insert(n, new_bpz); - locals.push_back(mk_local(n)); + new_m.insert(n, vdecl(new_bpz)); + locals.push_back(mk_fvar(n)); new_bpz++; b = binding_body(b); } @@ -155,7 +164,7 @@ class emit_bytecode_fn { } } - void compile_cnstr(expr const & e, unsigned bpz, name_map const & m) { + void compile_cnstr(expr const & e, unsigned bpz, vdecls const & m) { buffer args; expr const & fn = get_app_args(e, args); lean_assert(is_llnf_cnstr(fn)); @@ -166,7 +175,7 @@ class emit_bytecode_fn { emit(mk_constructor_instr(cidx, get_app_num_args(e))); } - void compile_reuse(expr const & e, unsigned bpz, name_map const & m) { + void compile_reuse(expr const & e, unsigned bpz, vdecls const & m) { buffer args; expr const & fn = get_app_args(e, args); lean_assert(is_llnf_reuse(fn)); @@ -177,16 +186,15 @@ class emit_bytecode_fn { emit(mk_reuse_instr(cidx, get_app_num_args(e) - 1)); } - void compile_external(name const & n, buffer & args, unsigned bpz, name_map const & m) { + void compile_external(name const & n, buffer & args, unsigned bpz, vdecls const & m) { // Not sure if this is the best approach, trying to lazy load the required // dynamic libraries. - std::cout << "external compile" << n << std::endl; optional decl = get_vm_decl(m_env, n); lean_assert(decl); compile_global(*decl, args.size(), args.data(), bpz, m); } - void compile_fn_call(expr const & e, unsigned bpz, name_map const & m) { + void compile_fn_call(expr const & e, unsigned bpz, vdecls const & m) { buffer args; expr fn = get_app_args(e, args); if (!is_constant(fn)) { @@ -209,7 +217,7 @@ class emit_bytecode_fn { } } - void compile_proj(expr const & e, unsigned bpz, name_map const & m) { + void compile_proj(expr const & e, unsigned bpz, vdecls const & m) { expr const & p = app_fn(e); lean_assert(is_llnf_proj(p)); expr const & a = app_arg(e); @@ -219,7 +227,7 @@ class emit_bytecode_fn { emit(mk_proj_instr(idx)); } - void compile_reset(expr const & e, unsigned bpz, name_map const & m) { + void compile_reset(expr const & e, unsigned bpz, vdecls const & m) { expr fn = app_fn(e); expr s = app_arg(e); unsigned n; @@ -228,7 +236,16 @@ class emit_bytecode_fn { emit(mk_reset_instr(n)); } - void compile_app(expr const & e, unsigned bpz, name_map const & m) { + optional is_jp(vdecls const & m, expr const & fn) { + if (!is_fvar(fn)) return optional(); + vdecl const * d = m.find(fvar_name(fn)); + if (d && d->m_is_jp) + return optional(*d); + else + return optional(); + } + + void compile_app(expr const & e, unsigned bpz, vdecls const & m) { expr const & fn = get_app_fn(e); if (is_cases_on_app(m_env, fn)) { compile_cases_on(e, bpz, m); @@ -242,6 +259,11 @@ class emit_bytecode_fn { compile_reset(e, bpz, m); } else if (is_sorry(e)) { compile_global(*get_vm_decl(m_env, "sorry"), 0, nullptr, bpz, m); + } else if (optional d = is_jp(m, fn)) { + buffer args; + get_app_args(e, args); + compile_rev_args(args.size(), args.data(), bpz, m); + emit(mk_invoke_jp_instr(d->m_pc, d->m_idx, args.size())); } else { compile_fn_call(e, bpz, m); } @@ -261,22 +283,56 @@ class emit_bytecode_fn { } } - void compile_let(expr e, unsigned bpz, name_map const & m) { + void emit_jp(expr e, unsigned bpz, vdecls m) { + unsigned init_bpz = bpz; + unsigned arity = get_arity(e); + lean_assert(arity > 0); + buffer locals; + unsigned i = arity; + while (is_lambda(e)) { + name n = m_ngen.next(); + i--; + m.insert(n, vdecl(init_bpz+i)); + locals.push_back(mk_fvar(n)); + bpz++; + e = binding_body(e); + } + e = instantiate_rev(e, locals.size(), locals.data()); + compile(e, bpz, m); + lean_assert(bpz >= m_arity); + unsigned to_drop = bpz - m_arity; + if (to_drop > 0) + emit(mk_drop_instr(to_drop)); + emit(mk_ret_instr()); + } + + void compile_let(expr e, unsigned bpz, vdecls const & m) { unsigned counter = 0; buffer locals; - name_map new_m = m; + vdecls new_m = m; while (is_let(e)) { - counter++; - compile(instantiate_rev(let_value(e), locals.size(), locals.data()), bpz, new_m); - name n = m_ngen.next(); - new_m.insert(n, bpz); - locals.push_back(mk_local(n)); - bpz++; + if (is_join_point_name(let_name(e))) { + unsigned goto_pc = next_pc(); + emit(mk_goto_instr(0)); // fix later + name n = m_ngen.next(); + new_m.insert(n, vdecl(bpz, next_pc())); + emit_jp(instantiate_rev(let_value(e), locals.size(), locals.data()), bpz, new_m); + locals.push_back(mk_fvar(n)); + unsigned cont_pc = next_pc(); + m_code[goto_pc].set_goto_pc(cont_pc); + } else { + counter++; + compile(instantiate_rev(let_value(e), locals.size(), locals.data()), bpz, new_m); + name n = m_ngen.next(); + new_m.insert(n, vdecl(bpz)); + locals.push_back(mk_fvar(n)); + bpz++; + } e = let_body(e); } - lean_assert(counter > 0); compile(instantiate_rev(e, locals.size(), locals.data()), bpz, new_m); - emit(mk_drop_instr(counter)); + if (counter > 0) + emit(mk_drop_instr(counter)); } void compile_lit(expr const & e) { @@ -290,7 +346,7 @@ class emit_bytecode_fn { } } - void compile(expr const & e, unsigned bpz, name_map const & m) { + void compile(expr const & e, unsigned bpz, vdecls const & m) { switch (e.kind()) { case expr_kind::BVar: lean_unreachable(); case expr_kind::Sort: lean_unreachable(); @@ -300,7 +356,7 @@ class emit_bytecode_fn { case expr_kind::Proj: lean_unreachable(); case expr_kind::MData: compile(mdata_expr(e), bpz, m); break; case expr_kind::Const: compile_constant(e); break; - case expr_kind::FVar: compile_local(e, m); break; + case expr_kind::FVar: compile_fvar(e, m); break; case expr_kind::App: compile_app(e, bpz, m); break; case expr_kind::Let: compile_let(e, bpz, m); break; case expr_kind::Lit: compile_lit(e); break; @@ -323,15 +379,15 @@ public: pair> operator()(expr e) { buffer locals; unsigned bpz = 0; - unsigned arity = get_arity(e); - unsigned i = arity; - name_map m; + m_arity = get_arity(e); + unsigned i = m_arity; + vdecls m; list args_info; while (is_lambda(e)) { name n = m_ngen.next(); i--; - m.insert(n, i); - locals.push_back(mk_local(n)); + m.insert(n, vdecl(i)); + locals.push_back(mk_fvar(n)); bpz++; args_info = cons(vm_local_info(binding_name(e), to_type_info(binding_domain(e))), args_info); e = binding_body(e); @@ -339,7 +395,7 @@ public: e = instantiate_rev(e, locals.size(), locals.data()); compile(e, bpz, m); emit(mk_ret_instr()); - return mk_pair(arity, args_info); + return mk_pair(m_arity, args_info); } }; diff --git a/src/library/vm/optimize.cpp b/src/library/vm/optimize.cpp index e30151e1da..0d3acab490 100644 --- a/src/library/vm/optimize.cpp +++ b/src/library/vm/optimize.cpp @@ -105,6 +105,9 @@ class live_vars_fn { case opcode::Reset: case opcode::Reuse: s = collect(pc+1); break; + case opcode::InvokeJP: + s = collect(instr.get_jp_pc()); + break; case opcode::Push: case opcode::Move: s = collect(pc+1); s.insert(instr.get_idx()); diff --git a/src/library/vm/vm.cpp b/src/library/vm/vm.cpp index 417f3e8e67..adb272071d 100644 --- a/src/library/vm/vm.cpp +++ b/src/library/vm/vm.cpp @@ -555,6 +555,8 @@ void vm_instr::display(std::ostream & out) const { out << "cfun "; display_fn(out, m_fn_idx); break; + case opcode::InvokeJP: + out << "jp " << m_jp_pc << " " << m_jp_bp << " " << m_jp_arity; break; case opcode::Closure: out << "closure "; display_fn(out, m_fn_idx); @@ -736,6 +738,14 @@ vm_instr mk_invoke_cfun_instr(unsigned fn_idx) { return r; } +vm_instr mk_invoke_jp_instr(unsigned pc, unsigned bp, unsigned arity) { + vm_instr r(opcode::InvokeJP); + r.m_jp_pc = pc; + r.m_jp_bp = bp; + r.m_jp_arity = arity; + return r; +} + vm_instr mk_closure_instr(unsigned fn_idx, unsigned n) { vm_instr r(opcode::Closure); r.m_fn_idx = fn_idx; @@ -770,6 +780,11 @@ void vm_instr::copy_args(vm_instr const & i) { case opcode::InvokeGlobal: case opcode::InvokeBuiltin: case opcode::InvokeCFun: m_fn_idx = i.m_fn_idx; break; + case opcode::InvokeJP: + m_jp_pc = i.m_jp_pc; + m_jp_bp = i.m_jp_bp; + m_jp_arity = i.m_jp_arity; + break; case opcode::Closure: m_fn_idx = i.m_fn_idx; m_nargs = i.m_nargs; @@ -884,6 +899,9 @@ void vm_instr::serialize(serializer & s, std::function const & i case opcode::InvokeGlobal: case opcode::InvokeBuiltin: case opcode::InvokeCFun: s << idx2name(m_fn_idx); break; + case opcode::InvokeJP: + s << m_jp_pc << m_jp_bp << m_jp_arity; + break; case opcode::Closure: s << idx2name(m_fn_idx) << m_nargs; break; @@ -954,6 +972,10 @@ static vm_instr read_vm_instr(deserializer & d) { case opcode::Closure: idx = read_fn_idx(d); return mk_closure_instr(idx, d.read_unsigned()); + case opcode::InvokeJP: + pc = d.read_unsigned(); + idx = d.read_unsigned(); + return mk_invoke_jp_instr(pc, idx, d.read_unsigned()); case opcode::Push: return mk_push_instr(d.read_unsigned()); case opcode::Move: @@ -2781,7 +2803,8 @@ void vm_state::run() { lean_trace(name({"vm", "run"}), tout() << m_decl_vector[m_fn_idx].get_name() << " @ " << m_pc << ": "; instr.display(tout().get_stream()); - tout() << "\n";); + tout() << "\n"; + display_stack(tout().get_stream());); }); switch (instr.op()) { case opcode::Push: @@ -3186,6 +3209,28 @@ void vm_state::run() { invoke_global(decl); goto main_loop; } + case opcode::InvokeJP: { + /* Join point call: jmp pc bp n + + stack before after + ... + bp=> ... a_1 + ... ... + a_1 a_n + ... + a_n + */ + unsigned jp_bp = instr.get_jp_bp(); + unsigned jp_arity = instr.get_jp_arity(); + unsigned i = m_bp + jp_bp; + unsigned j = m_stack.size() - jp_arity; + for (; j < m_stack.size(); j++, i++) { + std::swap(m_stack[i], m_stack[j]); + } + m_pc = instr.get_jp_pc(); + m_stack.resize(m_bp + jp_bp + jp_arity); + goto main_loop; + } case opcode::InvokeBuiltin: { check_interrupted(); check_heartbeat(); diff --git a/src/library/vm/vm.h b/src/library/vm/vm.h index e718a57196..d01a05965d 100644 --- a/src/library/vm/vm.h +++ b/src/library/vm/vm.h @@ -315,7 +315,7 @@ enum class opcode { Cases2, CasesN, Proj, Apply, InvokeGlobal, InvokeBuiltin, InvokeCFun, Closure, Unreachable, Expr, LocalInfo, - Reset, Reuse + Reset, Reuse, InvokeJP }; /** \brief VM instructions */ @@ -326,6 +326,11 @@ class vm_instr { unsigned m_fn_idx; /* InvokeGlobal, InvokeBuiltin, InvokeCFun and Closure. */ unsigned m_nargs; /* Closure */ }; + struct { /* InvokeJP */ + unsigned m_jp_pc; + unsigned m_jp_bp; + unsigned m_jp_arity; + }; /* Push, Move, Proj */ unsigned m_idx; /* Drop, Reset */ @@ -374,6 +379,7 @@ class vm_instr { friend vm_instr mk_invoke_global_instr(unsigned fn_idx); friend vm_instr mk_invoke_cfun_instr(unsigned fn_idx); friend vm_instr mk_invoke_builtin_instr(unsigned fn_idx); + friend vm_instr mk_invoke_jp_instr(unsigned pc, unsigned bp, unsigned arity); friend vm_instr mk_closure_instr(unsigned fn_idx, unsigned n); friend vm_instr mk_expr_instr(expr const &e); friend vm_instr mk_local_info_instr(unsigned idx, name const & n, optional const & e); @@ -399,6 +405,10 @@ public: return m_fn_idx; } + unsigned get_jp_pc() const { lean_assert(m_op == opcode::InvokeJP); return m_jp_pc; } + unsigned get_jp_bp() const { lean_assert(m_op == opcode::InvokeJP); return m_jp_bp; } + unsigned get_jp_arity() const { lean_assert(m_op == opcode::InvokeJP); return m_jp_arity; } + unsigned get_nargs() const { lean_assert(m_op == opcode::Closure); return m_nargs; @@ -515,6 +525,7 @@ vm_instr mk_apply_instr(); vm_instr mk_invoke_global_instr(unsigned fn_idx); vm_instr mk_invoke_cfun_instr(unsigned fn_idx); vm_instr mk_invoke_builtin_instr(unsigned fn_idx); +vm_instr mk_invoke_jp_instr(unsigned pc, unsigned bp, unsigned arity); vm_instr mk_closure_instr(unsigned fn_idx, unsigned n); vm_instr mk_expr_instr(expr const &e); vm_instr mk_local_info_instr(unsigned idx, name const & n, optional const & e);