/* Copyright (c) 2018 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #include #include "runtime/flet.h" #include "kernel/instantiate.h" #include "kernel/for_each_fn.h" #include "kernel/abstract.h" #include "library/class.h" #include "library/trace.h" #include "library/module.h" #include "library/compiler/util.h" #include "library/compiler/csimp.h" namespace lean { uint8 has_specialize_attribute_core(object* env, object* n); uint8 has_nospecialize_attribute_core(object* env, object* n); bool has_specialize_attribute(environment const & env, name const & n) { return has_specialize_attribute_core(env.to_obj_arg(), n.to_obj_arg()); } bool has_nospecialize_attribute(environment const & env, name const & n) { return has_nospecialize_attribute_core(env.to_obj_arg(), n.to_obj_arg()); } /* IMPORTANT: We currently do NOT specialize Fixed arguments. Only FixedNeutral, FixedHO and FixedInst. We do not have good heuristics to decide when it is a good idea to do it. TODO(Leo): allow users to specify that they want to consider some Fixed arguments for specialization. */ enum class spec_arg_kind { Fixed, FixedNeutral, /* computationally neutral */ FixedHO, /* higher order */ FixedInst, /* type class instance */ Other }; static spec_arg_kind to_spec_arg_kind(object_ref const & r) { lean_assert(is_scalar(r.raw())); return static_cast(unbox(r.raw())); } typedef objects spec_arg_kinds; static spec_arg_kinds to_spec_arg_kinds(buffer const & ks) { spec_arg_kinds r; unsigned i = ks.size(); while (i > 0) { --i; r = spec_arg_kinds(object_ref(box(static_cast(ks[i]))), r); } return r; } static void to_buffer(spec_arg_kinds const & ks, buffer & r) { for (object_ref const & k : ks) { r.push_back(to_spec_arg_kind(k)); } } static bool has_fixed_inst_arg(buffer const & ks) { for (spec_arg_kind k : ks) { if (k == spec_arg_kind::FixedInst) return true; } return false; } /* Return true if `ks` contains kind != Other */ static bool has_kind_ne_other(buffer const & ks) { for (spec_arg_kind k : ks) { if (k != spec_arg_kind::Other) return true; } return false; } char const * to_str(spec_arg_kind k) { switch (k) { case spec_arg_kind::Fixed: return "F"; case spec_arg_kind::FixedNeutral: return "N"; case spec_arg_kind::FixedHO: return "H"; case spec_arg_kind::FixedInst: return "I"; case spec_arg_kind::Other: return "X"; } lean_unreachable(); } class spec_info : public object_ref { explicit spec_info(b_obj_arg o, bool b):object_ref(o, b) {} public: spec_info(names const & ns, spec_arg_kinds ks): object_ref(mk_cnstr(0, ns, ks)) {} spec_info():spec_info(names(), spec_arg_kinds()) {} spec_info(spec_info const & other):object_ref(other) {} spec_info(spec_info && other):object_ref(other) {} spec_info & operator=(spec_info const & other) { object_ref::operator=(other); return *this; } spec_info & operator=(spec_info && other) { object_ref::operator=(other); return *this; } names const & get_mutual_decls() const { return static_cast(cnstr_get_ref(*this, 0)); } spec_arg_kinds const & get_arg_kinds() const { return static_cast(cnstr_get_ref(*this, 1)); } void serialize(serializer & s) const { s.write_object(raw()); } static spec_info deserialize(deserializer & d) { return spec_info(d.read_object(), true); } }; serializer & operator<<(serializer & s, spec_info const & si) { si.serialize(s); return s; } deserializer & operator>>(deserializer & d, spec_info & si) { si = spec_info::deserialize(d); return d; } /* Information for executing code specialization. TODO(Leo): use the to be implemented new module system. */ struct specialize_ext : public environment_extension { typedef rb_expr_map cache; name_map m_spec_info; cache m_cache; }; struct specialize_ext_reg { unsigned m_ext_id; specialize_ext_reg() { m_ext_id = environment::register_extension(new specialize_ext()); } }; static specialize_ext_reg * g_ext = nullptr; static specialize_ext const & get_extension(environment const & env) { return static_cast(env.get_extension(g_ext->m_ext_id)); } static environment update(environment const & env, specialize_ext const & ext) { return env.update(g_ext->m_ext_id, new specialize_ext(ext)); } /* Support for old module manager. Remark: this code will be deleted in the future */ struct spec_info_modification : public modification { LEAN_MODIFICATION("speci") name m_name; spec_info m_spec_info; spec_info_modification(name const & n, spec_info const & s) : m_name(n), m_spec_info(s) {} void perform(environment & env) const override { specialize_ext ext = get_extension(env); ext.m_spec_info.insert(m_name, m_spec_info); env = update(env, ext); } void serialize(serializer & s) const override { s << m_name << m_spec_info; } static modification* deserialize(deserializer & d) { name n; spec_info s; d >> n >> s; return new spec_info_modification(n, s); } }; typedef buffer>> spec_info_buffer; /* We only specialize arguments that are "fixed" in mutual recursive declarations. The buffer `info_buffer` stores which arguments are fixed for each declaration in a mutual recursive declaration. This procedure traverses `e` and updates `info_buffer`. Remark: we only create free variables for the header of each declaration. Then, we assume an argument of a recursive call is fixed iff it is a free variable (see `update_spec_info`). */ static void update_info_buffer(environment const & env, expr e, name_set const & S, spec_info_buffer & info_buffer) { while (true) { switch (e.kind()) { case expr_kind::Lambda: e = binding_body(e); break; case expr_kind::Let: update_info_buffer(env, let_value(e), S, info_buffer); e = let_body(e); break; case expr_kind::App: if (is_cases_on_app(env, e)) { buffer args; expr const & c_fn = get_app_args(e, args); unsigned minors_begin; unsigned minors_end; std::tie(minors_begin, minors_end) = get_cases_on_minors_range(env, const_name(c_fn)); for (unsigned i = minors_begin; i < minors_end; i++) { update_info_buffer(env, args[i], S, info_buffer); } } else { buffer args; expr const & fn = get_app_args(e, args); if (is_constant(fn) && S.contains(const_name(fn))) { for (auto & entry : info_buffer) { if (entry.first == const_name(fn)) { unsigned sz = entry.second.size(); for (unsigned i = 0; i < sz; i++) { if (i >= args.size() || !is_fvar(args[i])) { entry.second[i] = spec_arg_kind::Other; } } break; } } } } return; default: return; } } } environment update_spec_info(environment const & env, comp_decls const & ds) { name_set S; spec_info_buffer d_infos; name_generator ngen; /* Initialzie d_infos and S */ for (comp_decl const & d : ds) { S.insert(d.fst()); d_infos.push_back(pair>()); auto & info = d_infos.back(); info.first = d.fst(); expr code = d.snd(); buffer fvars; local_ctx lctx; while (is_lambda(code)) { expr type = instantiate_rev(binding_domain(code), fvars.size(), fvars.data()); expr fvar = lctx.mk_local_decl(ngen, binding_name(code), type); fvars.push_back(fvar); if (is_inst_implicit(binding_info(code))) { info.second.push_back(spec_arg_kind::FixedInst); } else { type_checker tc(env, lctx); type = tc.whnf(type); if (is_sort(type) || tc.is_prop(type)) { info.second.push_back(spec_arg_kind::FixedNeutral); } else if (is_pi(type)) { while (is_pi(type)) { expr fvar = lctx.mk_local_decl(ngen, binding_name(type), binding_domain(type)); type = type_checker(env, lctx).whnf(instantiate(binding_body(type), fvar)); } if (is_sort(type)) { /* Functions that return types are not relevant */ info.second.push_back(spec_arg_kind::FixedNeutral); } else { info.second.push_back(spec_arg_kind::FixedHO); } } else { info.second.push_back(spec_arg_kind::Fixed); } } code = binding_body(code); } } /* Update d_infos */ name x("_x"); for (comp_decl const & d : ds) { buffer fvars; expr code = d.snd(); unsigned i = 1; /* Create free variables for header variables. */ while (is_lambda(code)) { fvars.push_back(mk_fvar(name(x, i))); code = binding_body(code); } code = instantiate_rev(code, fvars.size(), fvars.data()); update_info_buffer(env, code, S, d_infos); } /* Update extension */ environment new_env = env; specialize_ext ext = get_extension(env); names mutual_decls = map2(ds, [&](comp_decl const & d) { return d.fst(); }); for (pair> const & info : d_infos) { name const & n = info.first; spec_info si(mutual_decls, to_spec_arg_kinds(info.second)); lean_trace(name({"compiler", "spec_info"}), tout() << n; for (spec_arg_kind k : info.second) { tout() << " " << to_str(k); } tout() << "\n";); new_env = module::add(new_env, new spec_info_modification(n, si)); ext.m_spec_info.insert(n, si); } return update(new_env, ext); } /* Support for old module manager. Remark: this code will be deleted in the future */ struct spec_cache_modification : public modification { LEAN_MODIFICATION("specc") expr m_key; name m_fn_name; spec_cache_modification(expr const & k, name const & fn) : m_key(k), m_fn_name(fn) {} void perform(environment & env) const override { specialize_ext ext = get_extension(env); ext.m_cache.insert(m_key, m_fn_name); env = update(env, ext); } void serialize(serializer & s) const override { s << m_key << m_fn_name; } static modification* deserialize(deserializer & d) { expr k; name f; d >> k >> f; return new spec_cache_modification(k, f); } }; class specialize_fn { type_checker::state m_st; csimp_cfg m_cfg; specialize_ext m_ext; local_ctx m_lctx; buffer m_new_decls; name m_base_name; name m_at; name m_spec; unsigned m_next_idx{1}; environment const & env() { return m_st.env(); } name_generator & ngen() { return m_st.ngen(); } expr visit_lambda(expr e) { flet save_lctx(m_lctx, m_lctx); buffer fvars; while (is_lambda(e)) { expr new_type = instantiate_rev(binding_domain(e), fvars.size(), fvars.data()); expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(e), new_type); fvars.push_back(new_fvar); e = binding_body(e); } expr r = visit(instantiate_rev(e, fvars.size(), fvars.data())); return m_lctx.mk_lambda(fvars, r); } expr visit_let(expr e) { flet save_lctx(m_lctx, m_lctx); buffer fvars; while (is_let(e)) { expr new_type = instantiate_rev(let_type(e), fvars.size(), fvars.data()); expr new_val = visit(instantiate_rev(let_value(e), fvars.size(), fvars.data())); expr new_fvar = m_lctx.mk_local_decl(ngen(), let_name(e), new_type, new_val); fvars.push_back(new_fvar); e = let_body(e); } expr r = visit(instantiate_rev(e, fvars.size(), fvars.data())); return m_lctx.mk_lambda(fvars, r); } expr visit_cases_on(expr const & e) { lean_assert(is_cases_on_app(env(), e)); buffer args; expr const & c = get_app_args(e, args); /* visit minor premises */ unsigned minor_idx; unsigned minors_end; std::tie(minor_idx, minors_end) = get_cases_on_minors_range(env(), const_name(c)); for (; minor_idx < minors_end; minor_idx++) { args[minor_idx] = visit(args[minor_idx]); } return mk_app(c, args); } expr find(expr const & e) { if (is_fvar(e)) { if (optional decl = m_lctx.find_local_decl(e)) { if (optional v = decl->get_value()) { return find(*v); } } } else if (is_mdata(e)) { return find(mdata_expr(e)); } return e; } struct spec_ctx { typedef rb_expr_map cache; names m_mutual; /* `m_params` contains all variables that must be lambda abstracted in the specialization. It may contain let-variables that occurs inside of binders. Reason: avoid work duplication. Example: suppose we are trying to specialize the following map-application. ``` def f2 (n : nat) (xs : list nat) : list (list nat) := let ys := list.repeat 0 n in xs.map (λ x, x :: ys) ``` We don't want to copy `list.repeat 0 n` inside of the specialized code. */ buffer m_params; /* `m_vars` contains `m_params` plus all let-declarations. Remark: we used to keep m_params and let-declarations in separate buffers. This produced incorrect results when the type of a variable in `m_params` depended on a let-declaration. */ buffer m_vars; cache m_cache; buffer m_pre_decls; bool in_mutual_decl(name const & n) const { return std::find(m_mutual.begin(), m_mutual.end(), n) != m_mutual.end(); } }; void get_arg_kinds(name const & fn, buffer & kinds) { spec_info const * info = m_ext.m_spec_info.find(fn); lean_assert(info); to_buffer(info->get_arg_kinds(), kinds); } static void to_bool_mask(buffer const & kinds, bool has_attr, buffer & mask) { unsigned sz = kinds.size(); mask.resize(sz, false); unsigned i = sz; bool found_inst = false; bool first = true; while (i > 0) { --i; switch (kinds[i]) { case spec_arg_kind::Other: break; case spec_arg_kind::FixedInst: mask[i] = true; if (first) mask.shrink(i+1); first = false; found_inst = true; break; case spec_arg_kind::Fixed: // REMARK: We have disabled specialization for this kind of argument. break; case spec_arg_kind::FixedHO: case spec_arg_kind::FixedNeutral: if (has_attr || found_inst) { mask[i] = true; if (first) mask.shrink(i+1); first = false; } break; } } } void get_bool_mask(name const & fn, unsigned args_size, buffer & mask) { buffer kinds; get_arg_kinds(fn, kinds); if (kinds.size() > args_size) kinds.shrink(args_size); to_bool_mask(kinds, has_specialize_attribute(env(), fn), mask); } name mk_spec_name(name const & fn) { name r = fn + m_at + m_base_name + (m_spec.append_after(m_next_idx)); m_next_idx++; return r; } static expr mk_cache_key(expr const & fn, buffer> const & mask) { expr r = fn; for (optional const & b : mask) { if (b) r = mk_app(r, *b); else r = mk_app(r, expr()); } return r; } bool is_specialize_candidate(expr const & fn, buffer const & args) { lean_assert(is_constant(fn)); buffer kinds; get_arg_kinds(const_name(fn), kinds); if (!has_specialize_attribute(env(), const_name(fn)) && !has_fixed_inst_arg(kinds)) return false; /* Nothing to specialize */ if (!has_kind_ne_other(kinds)) return false; /* Nothing to specialize */ type_checker tc(m_st, m_lctx); for (unsigned i = 0; i < args.size(); i++) { if (i >= kinds.size()) break; spec_arg_kind k = kinds[i]; expr w; switch (k) { case spec_arg_kind::FixedNeutral: break; case spec_arg_kind::FixedInst: /* We specialize this kind of argument if it reduces to a constructor application or lambda. Type class instances arguments are usually free variables bound to lambda declarations, or quickly reduce to constructor application or lambda. So, the following `whnf` is probably harmless. We need to consider the lambda case because of arguments such as `[decidable_rel lt]` */ w = tc.whnf(args[i]); if (is_constructor_app(env(), w) || is_lambda(w)) return true; break; case spec_arg_kind::FixedHO: /* We specialize higher-order arguments if they are lambda applications or a constant application. Remark: it is not feasible to invoke whnf since it may consume a lot of time. */ w = find(args[i]); if (is_lambda(w) || is_constant(get_app_fn(w))) return true; break; case spec_arg_kind::Fixed: /* We specialize this kind of argument if they are constructor applications or literals. Remark: it is not feasible to invoke whnf since it may consume a lot of time. */ break; // We have disabled this kind of argument w = find(args[i]); if (is_constructor_app(env(), w) || is_lit(w)) return true; break; case spec_arg_kind::Other: break; } } return false; } /* Auxiliary class for collecting specialization dependencies. */ class dep_collector { type_checker::state & m_st; local_ctx m_lctx; name_set m_visited_not_in_binder; name_set m_visited_in_binder; spec_ctx & m_ctx; void collect_fvar(expr const & x, bool in_binder) { name const & x_name = fvar_name(x); if (!in_binder) { if (m_visited_not_in_binder.contains(x_name)) return; m_visited_not_in_binder.insert(x_name); local_decl decl = m_lctx.get_local_decl(x); optional v = decl.get_value(); if (m_visited_in_binder.contains(x_name)) { /* If `x` was already visited in context inside of a binder, then it is already in `m_ctx.m_vars` and `m_ctx.m_params`. */ } else { /* Recall that `m_ctx.m_vars` contains all variables (lambda and let) the specialization depends on, and `m_ctx.m_params` contains the ones that should be lambda abstracted. */ m_ctx.m_vars.push_back(x); /* Thus, a variable occuring outside of a binder is only lambda abstracted if it is not a let-variable. */ if (!v) m_ctx.m_params.push_back(x); } collect(decl.get_type(), false); if (v) collect(*v, false); } else { if (m_visited_in_binder.contains(x_name)) return; m_visited_in_binder.insert(x_name); local_decl decl = m_lctx.get_local_decl(x); optional v = decl.get_value(); /* Remark: we must not lambda abstract join points. There is no risk of work duplication in this case, only code duplication. */ bool is_jp = is_join_point_name(decl.get_user_name()); lean_assert(!v || !is_irrelevant_type(m_st, m_lctx, decl.get_type())); if (m_visited_not_in_binder.contains(x_name)) { /* If `x` was already visited in a context outside of a binder, then it is already in `m_ctx.m_vars`. If `x` is not a let-variable, then it is also already in `m_ctx.m_params`. */ if (v && !is_jp) { m_ctx.m_params.push_back(x); v = none_expr(); /* make sure we don't collect v's dependencies */ } } else { /* Recall that if `x` occurs inside of a binder, then it will always be lambda abstracted. Reason: avoid work duplication. Example: suppose we are trying to specialize the following map-application. ``` def f2 (n : nat) (xs : list nat) : list (list nat) := let ys := list.repeat 0 n in xs.map (λ x, x :: ys) ``` We don't want to copy `list.repeat 0 n` inside of the specialized code. See comment above about join points. Remark: if `x` is not a let-var, then we must insert it into m_ctx.m_params. */ m_ctx.m_vars.push_back(x); if (!v || (v && !is_jp)) { m_ctx.m_params.push_back(x); v = none_expr(); /* make sure we don't collect v's dependencies */ } } collect(decl.get_type(), true); if (v) collect(*v, true); } } void collect(expr e, bool in_binder) { while (true) { if (!has_fvar(e)) return; switch (e.kind()) { case expr_kind::Lit: case expr_kind::BVar: case expr_kind::Sort: case expr_kind::Const: return; case expr_kind::MVar: lean_unreachable(); case expr_kind::FVar: collect_fvar(e, in_binder); return; case expr_kind::App: collect(app_arg(e), in_binder); e = app_fn(e); break; case expr_kind::Lambda: case expr_kind::Pi: collect(binding_domain(e), in_binder); if (!in_binder) { collect(binding_body(e), true); return; } else { e = binding_body(e); break; } case expr_kind::Let: collect(let_type(e), in_binder); collect(let_value(e), in_binder); e = let_body(e); break; case expr_kind::MData: e = mdata_expr(e); break; case expr_kind::Proj: e = proj_expr(e); break; } } } public: dep_collector(type_checker::state & st, local_ctx const & lctx, spec_ctx & ctx): m_st(st), m_lctx(lctx), m_ctx(ctx) {} void operator()(expr const & e) { return collect(e, false); } }; void sort_fvars(buffer & fvars) { ::lean::sort_fvars(m_lctx, fvars); } /* Initialize `spec_ctx` fields: `m_vars`. */ void specialize_init_deps(expr const & fn, buffer const & args, spec_ctx & ctx) { lean_assert(is_constant(fn)); buffer kinds; get_arg_kinds(const_name(fn), kinds); bool has_attr = has_specialize_attribute(env(), const_name(fn)); dep_collector collect(m_st, m_lctx, ctx); unsigned sz = std::min(kinds.size(), args.size()); unsigned i = sz; bool found_inst = false; while (i > 0) { --i; if (is_fvar(args[i])) { lean_trace(name({"compiler", "spec_candidate"}), local_decl d = m_lctx.get_local_decl(args[i]); tout() << "specialize_init_deps [" << i << "]: " << args[i] << " : " << d.get_type(); if (auto v = d.get_value()) tout() << " := " << *v; tout() << "\n";); } switch (kinds[i]) { case spec_arg_kind::Other: break; case spec_arg_kind::FixedInst: collect(args[i]); found_inst = true; break; case spec_arg_kind::Fixed: break; // We have disabled this kind of argument case spec_arg_kind::FixedHO: case spec_arg_kind::FixedNeutral: if (has_attr || found_inst) { collect(args[i]); } break; } } sort_fvars(ctx.m_vars); sort_fvars(ctx.m_params); lean_trace(name({"compiler", "spec_candidate"}), tout() << "candidate: " << mk_app(fn, args) << "\nclosure:"; for (expr const & p : ctx.m_vars) tout() << " " << p; tout() << "\nparams:"; for (expr const & p : ctx.m_params) tout() << " " << p; tout() << "\n";); } static bool contains(buffer> const & mask, expr const & e) { for (optional const & o : mask) { if (o && *o == e) return true; } return false; } optional adjust_rec_apps(expr e, buffer> const & mask, spec_ctx & ctx) { switch (e.kind()) { case expr_kind::App: if (is_cases_on_app(env(), e)) { buffer args; expr const & c = get_app_args(e, args); /* visit minor premises */ unsigned minor_idx; unsigned minors_end; std::tie(minor_idx, minors_end) = get_cases_on_minors_range(env(), const_name(c)); for (; minor_idx < minors_end; minor_idx++) { optional new_arg = adjust_rec_apps(args[minor_idx], mask, ctx); if (!new_arg) return none_expr(); args[minor_idx] = *new_arg; } return some_expr(mk_app(c, args)); } else { expr const & fn = get_app_fn(e); if (!is_constant(fn) || !ctx.in_mutual_decl(const_name(fn))) return some_expr(e); buffer args; get_app_args(e, args); buffer bmask; get_bool_mask(const_name(fn), args.size(), bmask); lean_assert(bmask.size() <= args.size()); buffer> new_mask; bool found = false; for (unsigned i = 0; i < bmask.size(); i++) { if (bmask[i] && contains(mask, args[i])) { found = true; new_mask.push_back(some_expr(args[i])); } else { new_mask.push_back(none_expr()); } } if (!found) return some_expr(e); optional new_fn_name = spec_preprocess(fn, new_mask, ctx); if (!new_fn_name) return none_expr(); expr r = mk_constant(*new_fn_name); r = mk_app(r, ctx.m_params); for (unsigned i = 0; i < bmask.size(); i++) { if (!bmask[i] || !contains(mask, args[i])) r = mk_app(r, args[i]); } for (unsigned i = bmask.size(); i < args.size(); i++) { r = mk_app(r, args[i]); } return some_expr(r); } case expr_kind::Lambda: { buffer entries; while (is_lambda(e)) { entries.push_back(e); e = binding_body(e); } optional new_e = adjust_rec_apps(e, mask, ctx); if (!new_e) return none_expr(); expr r = *new_e; unsigned i = entries.size(); while (i > 0) { --i; expr l = entries[i]; r = mk_lambda(binding_name(l), binding_domain(l), r); } return some_expr(r); } case expr_kind::Let: { buffer> entries; while (is_let(e)) { optional v = adjust_rec_apps(let_value(e), mask, ctx); if (!v) return none_expr(); expr new_val = *v; entries.emplace_back(e, new_val); e = let_body(e); } optional new_e = adjust_rec_apps(e, mask, ctx); if (!new_e) return none_expr(); expr r = *new_e; unsigned i = entries.size(); while (i > 0) { --i; expr l = entries[i].first; expr v = entries[i].second; r = mk_let(let_name(l), let_type(l), v, r); } return some_expr(r); } default: return some_expr(e); } } optional spec_preprocess(expr const & fn, buffer> const & mask, spec_ctx & ctx) { lean_assert(is_constant(fn)); lean_assert(ctx.in_mutual_decl(const_name(fn))); expr key = mk_cache_key(fn, mask); if (name const * r = ctx.m_cache.find(key)) { return optional(*r); } optional info = env().find(mk_cstage1_name(const_name(fn))); if (!info || !info->is_definition()) return optional(); // failed name new_name = mk_spec_name(const_name(fn)); ctx.m_cache.insert(key, new_name); expr new_code = instantiate_value_lparams(*info, const_levels(fn)); flet save_lctx(m_lctx, m_lctx); buffer fvars; buffer new_fvars; for (optional const & b : mask) { lean_assert(is_lambda(new_code)); if (b) { lean_assert(is_fvar(*b)); fvars.push_back(*b); } else { expr type = instantiate_rev(binding_domain(new_code), fvars.size(), fvars.data()); expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(new_code), type, binding_info(new_code)); new_fvars.push_back(new_fvar); fvars.push_back(new_fvar); } new_code = binding_body(new_code); } new_code = instantiate_rev(new_code, fvars.size(), fvars.data()); optional c = adjust_rec_apps(new_code, mask, ctx); if (!c) return optional(); new_code = *c; new_code = m_lctx.mk_lambda(new_fvars, new_code); ctx.m_pre_decls.push_back(comp_decl(new_name, new_code)); // lean_trace(name({"compiler", "spec_info"}), tout() << "new specialization " << new_name << " :=\n" << new_code << "\n";); return optional(new_name); } expr eta_expand_specialization(expr e) { /* Remark: we do not use `type_checker.eta_expand` because it does not preserve LCNF */ try { buffer args; type_checker tc(m_st); expr e_type = tc.whnf(tc.infer(e)); local_ctx lctx; while (is_pi(e_type)) { expr arg = lctx.mk_local_decl(ngen(), binding_name(e_type), binding_domain(e_type), binding_info(e_type)); args.push_back(arg); e_type = type_checker(m_st, lctx).whnf(instantiate(binding_body(e_type), arg)); } if (args.empty()) return e; buffer fvars; while (is_let(e)) { expr type = instantiate_rev(let_type(e), fvars.size(), fvars.data()); expr val = instantiate_rev(let_value(e), fvars.size(), fvars.data()); expr fvar = lctx.mk_local_decl(ngen(), let_name(e), type, val); fvars.push_back(fvar); e = let_body(e); } e = instantiate_rev(e, fvars.size(), fvars.data()); if (!is_lcnf_atom(e)) { e = lctx.mk_local_decl(ngen(), "_e", type_checker(m_st, lctx).infer(e), e); fvars.push_back(e); } e = mk_app(e, args); return lctx.mk_lambda(args, lctx.mk_lambda(fvars, e)); } catch (exception &) { /* This can happen since previous compilation steps may have produced type incorrect terms. */ return e; } } expr abstract_spec_ctx(spec_ctx const & ctx, expr const & code) { /* Important: we cannot use ``` m_lctx.mk_lambda(ctx.m_vars, code) ``` because we may want to lambda abstract let-variables in `ctx.m_vars` to avoid code duplication. See comment at `spec_ctx` declaration. Remark: lambda-abstracting let-decls may introduce type errors when using dependent types. This is yet another place where typeability may be lost. */ name_set letvars_in_params; for (expr const & x : ctx.m_params) { if (m_lctx.get_local_decl(x).get_value()) letvars_in_params.insert(fvar_name(x)); } unsigned n = ctx.m_vars.size(); expr const * fvars = ctx.m_vars.data(); expr r = abstract(code, n, fvars); unsigned i = n; while (i > 0) { --i; local_decl const & decl = m_lctx.get_local_decl(fvar_name(fvars[i])); expr type = abstract(decl.get_type(), i, fvars); optional val = decl.get_value(); if (val && !letvars_in_params.contains(fvar_name(fvars[i]))) { r = ::lean::mk_let(decl.get_user_name(), type, abstract(*val, i, fvars), r); } else { r = ::lean::mk_lambda(decl.get_user_name(), type, r, decl.get_info()); } } return r; } void mk_new_decl(comp_decl const & pre_decl, buffer const & fvars, buffer const & fvar_vals, spec_ctx & ctx) { lean_assert(fvars.size() == fvar_vals.size()); name n = pre_decl.fst(); expr code = pre_decl.snd(); flet save_lctx(m_lctx, m_lctx); /* Add fvars decls */ type_checker tc(m_st, m_lctx); buffer new_let_decls; name y("_y"); for (unsigned i = 0; i < fvars.size(); i++) { expr type = tc.infer(fvar_vals[i]); if (is_irrelevant_type(m_st, m_lctx, type)) { /* In LCNF, the type `ty` at `let x : ty := v in t` must not be irrelevant. */ code = replace_fvar(code, fvars[i], fvar_vals[i]); } else { expr new_fvar = m_lctx.mk_local_decl(fvar_name(fvars[i]), y.append_after(i+1), type, fvar_vals[i]).mk_ref(); new_let_decls.push_back(new_fvar); } } code = m_lctx.mk_lambda(new_let_decls, code); // lean_trace(name("compiler", "spec_info"), tout() << "STEP 1 " << n << "\n" << code << "\n";); code = abstract_spec_ctx(ctx, code); lean_assert(!has_fvar(code)); /* We add the auxiliary declaration `n` as a "meta" axiom to the environment. This is a hack to make sure we can use `csimp` to simplify `code` and other definitions that use `n`. `csimp` uses the kernel type checker to infer types, and it will fail to infer the type of `n`-applications if we do not have an entry in the environment. Remark: we mark the axiom as `meta` to make sure it does not polute the environment regular definitions. We also considered the following cleaner solution: modify `csimp` to use a custom type checker that takes the types of auxiliary declarations such as `n` into account. A custom type checker would be extra work, but it has other benefits. For example, it could have better support for type errors introduced by `csimp`. */ { expr type = cheap_beta_reduce(type_checker(m_st).infer(code)); declaration aux_ax = mk_axiom(n, names(), type, true /* meta */); m_st.env() = env().add(aux_ax, false); } code = eta_expand_specialization(code); // lean_trace(name("compiler", "spec_info"), tout() << "STEP 2 " << n << "\n" << code << "\n";); code = csimp(env(), code, m_cfg); code = visit(code); // lean_trace(name("compiler", "spec_info"), tout() << "STEP 3 " << n << "\n" << code << "\n";); m_new_decls.push_back(comp_decl(n, code)); } optional get_closed(expr const & e) { if (has_univ_param(e)) return none_expr(); switch (e.kind()) { case expr_kind::MVar: lean_unreachable(); case expr_kind::Lit: return some_expr(e); case expr_kind::BVar: return some_expr(e); case expr_kind::Sort: return some_expr(e); case expr_kind::Const: return some_expr(e); case expr_kind::FVar: if (auto v = m_lctx.get_local_decl(e).get_value()) { return get_closed(*v); } else { return none_expr(); } case expr_kind::MData: return get_closed(mdata_expr(e)); case expr_kind::Proj: { optional new_s = get_closed(proj_expr(e)); if (!new_s) return none_expr(); return some_expr(update_proj(e, *new_s)); } case expr_kind::Pi: case expr_kind::Lambda: { optional dom = get_closed(binding_domain(e)); if (!dom) return none_expr(); optional body = get_closed(binding_body(e)); if (!body) return none_expr(); return some_expr(update_binding(e, *dom, *body)); } case expr_kind::App: { buffer args; expr const & fn = get_app_args(e, args); optional new_fn = get_closed(fn); if (!new_fn) return none_expr(); for (expr & arg : args) { optional new_arg = get_closed(arg); if (!new_arg) return none_expr(); arg = *new_arg; } return some_expr(mk_app(*new_fn, args)); } case expr_kind::Let: { optional type = get_closed(let_type(e)); if (!type) return none_expr(); optional val = get_closed(let_value(e)); if (!val) return none_expr(); optional body = get_closed(let_body(e)); if (!body) return none_expr(); return some_expr(update_let(e, *type, *val, *body)); } } lean_unreachable(); } optional specialize(expr const & fn, buffer const & args, spec_ctx & ctx) { if (!is_specialize_candidate(fn, args)) return none_expr(); // lean_trace(name("compiler", "specialize"), tout() << "specialize: " << fn << "\n";); specialize_init_deps(fn, args, ctx); buffer bmask; get_bool_mask(const_name(fn), args.size(), bmask); buffer> mask; buffer fvars; buffer fvar_vals; bool gcache_enabled = true; buffer gcache_key_args; for (unsigned i = 0; i < bmask.size(); i++) { if (bmask[i]) { if (gcache_enabled) { if (optional c = get_closed(args[i])) { gcache_key_args.push_back(*c); } else { /* We only cache specialization results if arguments (expanded by the specializer) are closed. */ gcache_enabled = false; } } name n = ngen().next(); expr fvar = mk_fvar(n); fvars.push_back(fvar); fvar_vals.push_back(args[i]); mask.push_back(some_expr(fvar)); } else { mask.push_back(none_expr()); if (gcache_enabled) gcache_key_args.push_back(expr()); } } optional new_fn_name; expr key; if (gcache_enabled) { key = mk_app(fn, gcache_key_args); if (name const * it = m_ext.m_cache.find(key)) new_fn_name = *it; } if (!new_fn_name) { /* Cache does not contain specialization result */ new_fn_name = spec_preprocess(fn, mask, ctx); if (!new_fn_name) return none_expr(); for (comp_decl const & pre_decl : ctx.m_pre_decls) { mk_new_decl(pre_decl, fvars, fvar_vals, ctx); } if (gcache_enabled) { m_ext.m_cache.insert(key, *new_fn_name); m_st.env() = module::add(env(), new spec_cache_modification(key, *new_fn_name)); } } expr r = mk_constant(*new_fn_name); r = mk_app(r, ctx.m_params); for (unsigned i = 0; i < bmask.size(); i++) { if (!bmask[i]) r = mk_app(r, args[i]); } for (unsigned i = bmask.size(); i < args.size(); i++) { r = mk_app(r, args[i]); } return some_expr(r); } expr visit_app(expr const & e) { if (is_cases_on_app(env(), e)) { return visit_cases_on(e); } else { buffer args; expr fn = get_app_args(e, args); if (!is_constant(fn) || has_nospecialize_attribute(env(), const_name(fn)) || (is_instance(env(), const_name(fn)) && !has_specialize_attribute(env(), const_name(fn)))) { return e; } spec_info const * info = m_ext.m_spec_info.find(const_name(fn)); if (!info) return e; spec_ctx ctx; ctx.m_mutual = info->get_mutual_decls(); if (optional r = specialize(fn, args, ctx)) return *r; else return e; } } expr visit(expr const & e) { switch (e.kind()) { case expr_kind::App: return visit_app(e); case expr_kind::Lambda: return visit_lambda(e); case expr_kind::Let: return visit_let(e); default: return e; } } public: specialize_fn(environment const & env, csimp_cfg const & cfg): m_st(env), m_cfg(cfg), m_ext(get_extension(env)), m_at("_at"), m_spec("_spec") {} pair operator()(comp_decl const & d) { m_base_name = d.fst(); lean_trace(name({"compiler", "specialize"}), tout() << "INPUT: " << d.fst() << "\n" << d.snd() << "\n";); expr new_v = visit(d.snd()); comp_decl new_d(d.fst(), new_v); environment new_env = update(env(), m_ext); return mk_pair(new_env, append(comp_decls(m_new_decls), comp_decls(new_d))); } }; pair specialize_core(environment const & env, comp_decl const & d, csimp_cfg const & cfg) { return specialize_fn(env, cfg)(d); } pair specialize(environment env, comp_decls const & ds, csimp_cfg const & cfg) { env = update_spec_info(env, ds); comp_decls r; for (comp_decl const & d : ds) { comp_decls new_ds; std::tie(env, new_ds) = specialize_core(env, d, cfg); r = append(r, new_ds); } return mk_pair(env, r); } void initialize_specialize() { g_ext = new specialize_ext_reg(); spec_info_modification::init(); spec_cache_modification::init(); register_trace_class({"compiler", "spec_info"}); register_trace_class({"compiler", "spec_candidate"}); } void finalize_specialize() { spec_info_modification::finalize(); spec_cache_modification::finalize(); delete g_ext; } }