/* Copyright (c) 2016 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #include "kernel/type_checker.h" #include "kernel/instantiate.h" #include "kernel/abstract.h" #include "kernel/for_each_fn.h" #include "kernel/inductive/inductive.h" #include "library/locals.h" #include "library/util.h" #include "library/trace.h" #include "compiler/fresh_constant.h" #include "compiler/util.h" #include "compiler/comp_irrelevant.h" #include "compiler/compiler_step_visitor.h" namespace lean { class lambda_lifting_fn : public compiler_step_visitor { buffer & m_new_decls; bool m_trusted; protected: expr declare_aux_def(expr const & value) { pair ep = mk_fresh_constant(m_env); m_env = ep.first; name n = ep.second; m_new_decls.push_back(n); level_param_names ps = to_level_param_names(collect_univ_params(value)); /* We should use a new type checker because m_env is updated by this object. It is safe to use type_checker because value does not contain local_decl_ref objects. */ type_checker tc(m_env); expr type = tc.infer(value); bool use_conv_opt = false; declaration new_decl = mk_definition(m_env, n, ps, type, value, use_conv_opt, m_trusted); m_env = m_env.add(check(m_env, new_decl)); return mk_constant(n, param_names_to_levels(ps)); } typedef rb_map idx2decls; void collect_locals(expr const & e, idx2decls & r) { local_context const & lctx = m_ctx.lctx(); for_each(e, [&](expr const & e, unsigned) { if (is_local_decl_ref(e)) { local_decl d = *lctx.get_local_decl(e); r.insert(d.get_idx(), d); } return true; }); } expr visit_lambda_core(expr const & e) { type_context::tmp_locals locals(m_ctx); expr t = e; while (is_lambda(t)) { expr d = instantiate_rev(binding_domain(t), locals.size(), locals.data()); locals.push_local(binding_name(t), d, binding_info(t)); t = binding_body(t); } t = instantiate_rev(t, locals.size(), locals.data()); t = visit(t); return locals.mk_lambda(t); } virtual expr visit_lambda(expr const & e) override { expr new_e = visit_lambda_core(e); idx2decls map; collect_locals(new_e, map); if (map.empty()) { return declare_aux_def(new_e); } else { buffer args; while (!map.empty()) { /* remove local_decl with biggest idx */ local_decl d = map.erase_min(); expr l = d.mk_ref(); if (auto v = d.get_value()) { collect_locals(*v, map); new_e = instantiate(abstract_local(new_e, l), *v); } else { collect_locals(d.get_type(), map); if (is_comp_irrelevant(ctx(), l)) args.push_back(mark_comp_irrelevant(l)); else args.push_back(l); new_e = abstract_local(new_e, l); new_e = mk_lambda(d.get_name(), d.get_type(), new_e); } } expr c = declare_aux_def(new_e); return mk_rev_app(c, args); } } virtual expr visit_let(expr const & e) override { type_context::tmp_locals locals(m_ctx); expr t = e; while (is_let(t)) { expr d = instantiate_rev(let_type(t), locals.size(), locals.data()); expr v = visit(instantiate_rev(let_value(t), locals.size(), locals.data())); locals.push_let(let_name(t), d, v); t = let_body(t); } t = instantiate_rev(t, locals.size(), locals.data()); t = visit(t); return locals.mk_let(t); } expr visit_recursor_app(expr const & e) { // TODO(Leo): process recursor args return compiler_step_visitor::visit_app(e); } unsigned get_constructor_arity(name const & n) { declaration d = env().get(n); expr e = d.get_type(); unsigned r = 0; while (is_pi(e)) { r++; e = binding_body(e); } return r; } expr visit_cases_on_minor(unsigned data_sz, expr e) { type_context::tmp_locals locals(ctx()); for (unsigned i = 0; i < data_sz; i++) { e = ctx().whnf(e); lean_assert(is_lambda(e)); expr l = locals.push_local(binding_name(e), binding_domain(e), binding_info(e)); e = instantiate(binding_body(e), l); } e = visit(e); return locals.mk_lambda(e); } expr visit_cases_on_app(expr const & e) { buffer args; expr const & fn = get_app_args(e, args); lean_assert(is_constant(fn) && is_cases_on_recursor(env(), const_name(fn))); name const & cases_on_name = const_name(fn); name const & I_name = cases_on_name.get_prefix(); unsigned nparams = *inductive::get_num_params(env(), I_name); unsigned nminors = *inductive::get_num_minor_premises(env(), I_name); unsigned nindices = *inductive::get_num_indices(env(), I_name); unsigned cases_on_arity = nparams + 1 /* typeformer/motive */ + nindices + 1 /* major premise */ + nminors; unsigned major_idx = nparams + 1 + nindices; /* This transformation assumes eta-expansion have already been applied. So, we should have a sufficient number of arguments. */ lean_assert(args.size() >= cases_on_arity); buffer cnames; get_intro_rule_names(env(), I_name, cnames); /* Process major premise */ args[major_idx] = visit(args[major_idx]); /* Process extra arguments */ for (unsigned i = cases_on_arity; i < args.size(); i++) args[i] = visit(args[i]); /* Process minor premise */ for (unsigned i = 0, j = major_idx + 1; i < cnames.size(); i++, j++) { unsigned carity = get_constructor_arity(cnames[i]); lean_assert(carity >= nparams); unsigned cdata_sz = carity - nparams; args[j] = visit_cases_on_minor(cdata_sz, args[j]); } return mk_app(fn, args); } virtual expr visit_app(expr const & e) override { expr const & fn = get_app_fn(e); if (is_constant(fn)) { name const & n = const_name(fn); if (is_cases_on_recursor(env(), n)) { return visit_cases_on_app(e); } else if (inductive::is_elim_rule(env(), n)) { return visit_recursor_app(e); } } return compiler_step_visitor::visit_app(e); } public: lambda_lifting_fn(environment const & env, buffer & new_decls, bool trusted): compiler_step_visitor(env), m_new_decls(new_decls), m_trusted(trusted) {} environment const & env() const { return m_env; } expr operator()(expr const & e) { if (is_lambda(e)) return visit_lambda_core(e); else return visit(e); } }; expr lambda_lifting(environment & env, expr const & e, buffer & new_decls, bool trusted) { lambda_lifting_fn fn(env, new_decls, trusted); expr new_e = fn(e); env = fn.env(); return new_e; } }