diff --git a/src/library/compiler/emit_cpp.cpp b/src/library/compiler/emit_cpp.cpp index 8d311a496e..a3bf77fb8b 100644 --- a/src/library/compiler/emit_cpp.cpp +++ b/src/library/compiler/emit_cpp.cpp @@ -680,7 +680,7 @@ struct emit_fn_fn { m_out << "goto "; emit_lbl(jp); m_out << ";\n"; } - optional is_tail_call(expr const & val) { + optional is_self_call(expr const & val) { expr fn = get_app_fn(val); if (is_constant(fn) && const_name(fn) == m_fn_name) return some_expr(val); @@ -688,13 +688,6 @@ struct emit_fn_fn { return none_expr(); } - optional is_tail_call_terminal(expr const & e) { - if (!is_fvar(e)) return none_expr(); - optional val = m_lctx.get_local_decl(e).get_value(); - if (!val) return none_expr(); - return is_tail_call(*val); - } - void emit_tail_call(expr const & e) { buffer args; expr fn = get_app_args(e, args); @@ -708,13 +701,13 @@ struct emit_fn_fn { m_out << "goto _start;\n"; } - void emit_terminal(expr const & e) { + void emit_terminal(expr const & e, bool tail_call) { if (is_cases_on_app(m_env, e)) { emit_cases(e); } else if (is_jmp(e)) { emit_jmp(e); - } else if (optional c = is_tail_call_terminal(e)) { - emit_tail_call(*c); + } else if (tail_call) { + emit_tail_call(*m_lctx.get_local_decl(e).get_value()); } else if (is_fvar(e)) { m_out << "return "; emit_fvar(e); m_out << ";\n"; } else { @@ -728,6 +721,7 @@ struct emit_fn_fn { buffer locals; buffer instrs; bool declared_vars = false; + bool tail_call = false; while (is_let(e)) { expr v = instantiate_rev(let_value(e), locals.size(), locals.data()); if (is_join_point_name(let_name(e))) { @@ -748,8 +742,9 @@ struct emit_fn_fn { } else { expr x = m_lctx.mk_local_decl(m_ngen, let_name(e), let_type(e), v); locals.push_back(x); - if (is_bvar(let_body(e), 0) && is_tail_call(v)) { + if (is_bvar(let_body(e), 0) && is_self_call(v)) { /* Ignore tail call, we will emit it at emit_terminal as a `goto`. */ + tail_call = true; } else { if (!is_llnf_void_type(let_type(e))) { /* Declare local variable. @@ -775,7 +770,7 @@ struct emit_fn_fn { emit_instr(d); } } - emit_terminal(e); + emit_terminal(e, tail_call); for (expr const & jp : jps) { emit_lbl(jp); m_out << ":\n"; emit(*m_lctx.get_local_decl(jp).get_value());