/* Copyright (c) 2016 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #include #include "kernel/find_fn.h" #include "kernel/instantiate.h" #include "library/trace.h" #include "library/locals.h" #include "library/replace_visitor.h" #include "library/equations_compiler/compiler.h" #include "library/equations_compiler/util.h" #include "library/equations_compiler/pack_domain.h" #include "library/equations_compiler/structural_rec.h" #include "library/equations_compiler/unbounded_rec.h" #include "library/equations_compiler/elim_match.h" namespace lean { #define trace_compiler(Code) lean_trace("eqn_compiler", scope_trace_env _scope1(ctx.env(), ctx); Code) static bool has_nested_rec(expr const & eqns) { return static_cast(find(eqns, [&](expr const & e, unsigned) { return is_local(e) && local_info(e).is_rec(); })); } static expr compile_equations_core(environment & env, options const & opts, metavar_context & mctx, local_context const & lctx, expr const & eqns) { type_context ctx(env, opts, mctx, lctx, transparency_mode::Semireducible); trace_compiler(tout() << "compiling\n" << eqns << "\n";); trace_compiler(tout() << "recursive: " << is_recursive_eqns(ctx, eqns) << "\n";); trace_compiler(tout() << "nested recursion: " << has_nested_rec(eqns) << "\n";); equations_header const & header = get_equations_header(eqns); lean_assert(header.m_is_meta || !has_nested_rec(eqns)); if (header.m_is_meta) { return mk_equations_result(unbounded_rec(env, opts, mctx, lctx, eqns)); } if (header.m_num_fns == 1) { if (!is_recursive_eqns(ctx, eqns)) { return mk_equations_result(mk_nonrec(env, opts, mctx, lctx, eqns)); } else if (optional r = try_structural_rec(env, opts, mctx, lctx, eqns)) { return mk_equations_result(*r); } } throw exception("support for well-founded recursion has not been implemented yet, " "use 'set_option trace.eqn_compiler true' for additional information"); // test pack_domain // pack_domain(ctx, eqns); } /** Auxiliary class for pulling nested recursive calls. Example: given definition f : nat → (nat × nat) → nat | 0 m := m.1 | (n+1) m := match m with | (a, b) := f n (b, a + 1) end when we compile match m with | (a, b) := f n (b, a + 1) end we consinder (f n (b, a + 1)) to be a nested recursive call. Then, we transform the expression to (fun g, match m with | (a, b) := g a b end) (fun a b, f n (b, a + 1)) Compile the match, and then beta-reduce. */ struct pull_nested_rec_fn : public replace_visitor { environment & m_env; options m_opts; metavar_context & m_mctx; buffer m_lctx_stack; buffer m_new_locals; buffer m_new_values; pull_nested_rec_fn(environment & env, options const & opts, metavar_context & mctx, local_context const & lctx): m_env(env), m_opts(opts), m_mctx(mctx) { m_lctx_stack.push_back(lctx); } local_context & base_lctx() { return m_lctx_stack[0]; } local_context & lctx() { return m_lctx_stack.back(); } type_context mk_type_context(local_context const & lctx) { return type_context(m_env, m_opts, m_mctx, lctx, transparency_mode::Semireducible); } expr visit_lambda_pi_let(bool is_lam, expr const & e) { buffer locals; m_lctx_stack.push_back(m_lctx_stack.back()); expr t = e; while (true) { if ((is_lam && is_lambda(t)) || (!is_lam && is_pi(t))) { expr d = instantiate_rev(binding_domain(t), locals.size(), locals.data()); d = visit(d); expr x = lctx().mk_local_decl(binding_name(t), d, binding_info(t)); locals.push_back(x); t = binding_body(t); } else if (is_let(t)) { expr d = instantiate_rev(let_type(t), locals.size(), locals.data()); expr v = instantiate_rev(let_value(t), locals.size(), locals.data()); d = visit(d); v = visit(v); expr x = lctx().mk_local_decl(let_name(t), d, v); locals.push_back(x); t = let_body(t); } else { break; } } t = instantiate_rev(t, locals.size(), locals.data()); t = visit(t); type_context ctx = mk_type_context(lctx()); t = is_lam ? ctx.mk_lambda(locals, t) : ctx.mk_pi(locals, t); m_mctx = ctx.mctx(); m_lctx_stack.pop_back(); return t; } virtual expr visit_lambda(expr const & e) override { return visit_lambda_pi_let(true, e); } virtual expr visit_let(expr const & e) override { return visit_lambda_pi_let(true, e); } virtual expr visit_pi(expr const & e) override { return visit_lambda_pi_let(false, e); } expr default_visit_app(expr const & e, expr const & fn, buffer & args) { expr new_fn = visit(fn); bool modified = !is_eqp(fn, new_fn); for (expr & arg : args) { expr new_arg = visit(arg); if (!is_eqp(new_arg, arg)) modified = true; arg = new_arg; } if (!modified) return e; else return mk_app(new_fn, args); } void collect_locals_core(expr const & e, name_set & found, buffer & R) { for_each(e, [&](expr const & e, unsigned) { if (is_local(e) && !base_lctx().get_local_decl(e) && !found.contains(mlocal_name(e))) { found.insert(mlocal_name(e)); R.push_back(e); } return true; }); } void collect_locals(expr const & e, buffer & R) { name_set found; collect_locals_core(e, found, R); for (unsigned i = 0; i < R.size(); i++) { expr const & x = R[i]; collect_locals_core(lctx().get_local_decl(x)->get_type(), found, R); } std::sort(R.begin(), R.end(), [&](expr const & x1, expr const & x2) { return lctx().get_local_decl(x1)->get_idx() < lctx().get_local_decl(x2)->get_idx(); }); } expr declare_new_local(name const & uid, name const & n, expr const & type) { lean_assert(m_lctx_stack.size() > 1); expr r; for (unsigned i = 0; i < m_lctx_stack.size(); i++) { local_context & lctx = m_lctx_stack[i]; if (i == 0) { r = lctx.mk_local_decl(uid, n, type); } else { DEBUG_CODE(expr r2 =) lctx.mk_local_decl(uid, n, type); lean_assert(r == r2); } } return r; } virtual expr visit_app(expr const & e) override { buffer args; expr const & fn = get_app_args(e, args); if (is_local(fn) && local_info(fn).is_rec() && base_lctx().get_local_decl(fn)) { buffer local_deps; collect_locals(e, local_deps); type_context ctx = mk_type_context(lctx()); expr val = ctx.mk_lambda(local_deps, e); expr val_type = ctx.infer(val); name fn_aux = name("_f").append_after(m_new_locals.size() + 1); name uid = mk_local_decl_name(); expr g = declare_new_local(uid, fn_aux, val_type); m_new_locals.push_back(g); m_new_values.push_back(val); return mk_app(g, local_deps); } else { return default_visit_app(e, fn, args); } } expr operator()(expr const & e) { expr new_e = visit(e); lean_assert(m_lctx_stack.size() == 1); local_context new_lctx = m_lctx_stack[0]; expr r = compile_equations_core(m_env, m_opts, m_mctx, new_lctx, new_e); type_context ctx = mk_type_context(new_lctx); expr new_r = replace_locals(r, m_new_locals, m_new_values); m_mctx = ctx.mctx(); return new_r; } }; expr compile_equations(environment & env, options const & opts, metavar_context & mctx, local_context const & lctx, expr const & eqns) { if (!get_equations_header(eqns).m_is_meta && has_nested_rec(eqns)) { return pull_nested_rec_fn(env, opts, mctx, lctx)(eqns); } else { return compile_equations_core(env, opts, mctx, lctx, eqns); } } void initialize_compiler() { } void finalize_compiler() { } }