diff --git a/src/kernel/local_ctx.h b/src/kernel/local_ctx.h index 631d42d1e1..3a02129360 100644 --- a/src/kernel/local_ctx.h +++ b/src/kernel/local_ctx.h @@ -49,14 +49,16 @@ protected: template expr mk_binding(unsigned num, expr const * fvars, expr const & b) const; - local_decl mk_local_decl(name const & n, name const & un, expr const & type, binder_info bi); - local_decl mk_local_decl(name const & n, name const & un, expr const & type, expr const & value); - public: local_ctx():m_next_idx(0) {} bool empty() const { return m_idx2local_decl.empty(); } + /* Low level `mk_local_decl` */ + local_decl mk_local_decl(name const & n, name const & un, expr const & type, binder_info bi); + /* Low level `mk_local_decl` */ + local_decl mk_local_decl(name const & n, name const & un, expr const & type, expr const & value); + expr mk_local_decl(name_generator & g, name const & un, expr const & type, binder_info bi = mk_binder_info()) { return mk_local_decl(g.next(), un, type, bi).mk_ref(); } diff --git a/src/library/compiler/compiler.cpp b/src/library/compiler/compiler.cpp index 55e55d5867..b55e0241ac 100644 --- a/src/library/compiler/compiler.cpp +++ b/src/library/compiler/compiler.cpp @@ -96,7 +96,8 @@ environment compile(environment const & env, options const & opts, names const & } comp_decls ds = to_comp_decls(env, cs); - auto simp = [&](environment const & env, expr const & e) { return csimp(env, e, csimp_cfg(opts)); }; + csimp_cfg cfg(opts); + auto simp = [&](environment const & env, expr const & e) { return csimp(env, e, cfg); }; trace_compiler(name({"compiler", "input"}), ds); ds = apply(eta_expand, env, ds); trace_compiler(name({"compiler", "eta_expand"}), ds); @@ -113,7 +114,7 @@ environment compile(environment const & env, options const & opts, names const & ds = apply(max_sharing, ds); trace_compiler(name({"compiler", "stage1"}), ds); environment new_env = cache_stage1(env, ds); - std::tie(new_env, ds) = specialize(new_env, ds); + std::tie(new_env, ds) = specialize(new_env, ds, cfg); trace_compiler(name({"compiler", "specialize"}), ds); ds = apply(elim_dead_let, ds); trace_compiler(name({"compiler", "elim_dead_let"}), ds); diff --git a/src/library/compiler/specialize.cpp b/src/library/compiler/specialize.cpp index 2ff9422972..beddd303bf 100644 --- a/src/library/compiler/specialize.cpp +++ b/src/library/compiler/specialize.cpp @@ -11,6 +11,7 @@ Author: Leonardo de Moura #include "library/module.h" #include "library/attribute_manager.h" #include "library/compiler/util.h" +#include "library/compiler/csimp.h" #include "library/trace.h" @@ -287,9 +288,14 @@ environment update_spec_info(environment const & env, comp_decls const & ds) { class specialize_fn { type_checker::state m_st; + csimp_cfg m_cfg; + specialize_ext m_ext; local_ctx m_lctx; buffer m_new_decls; name m_base_name; + name m_at; + name m_spec; + unsigned m_next_idx{1}; environment const & env() { return m_st.env(); } @@ -348,7 +354,135 @@ class specialize_fn { return e; } - void collect_deps(expr e, name_set & collected, buffer & new_params, buffer & let_vars) { + struct spec_ctx { + typedef rb_expr_map cache; + names m_mutual; + buffer m_params; + buffer m_let_vars; + cache m_cache; + buffer m_pre_decls; + + bool in_mutual_decl(name const & n) const { + return std::find(m_mutual.begin(), m_mutual.end(), n) != m_mutual.end(); + } + }; + + void get_arg_kinds(name const & fn, buffer & kinds) { + spec_info const * info = m_ext.m_spec_info.find(fn); + lean_assert(info); + to_buffer(info->get_arg_kinds(), kinds); + } + + static void to_bool_mask(buffer const & kinds, bool has_attr, buffer & mask) { + unsigned sz = kinds.size(); + mask.resize(sz, false); + unsigned i = sz; + bool found_inst = false; + bool first = true; + while (i > 0) { + --i; + switch (kinds[i]) { + case spec_arg_kind::Other: + break; + case spec_arg_kind::FixedInst: + mask[i] = true; + if (first) mask.shrink(i+1); + first = false; + found_inst = true; + break; + case spec_arg_kind::FixedHO: + case spec_arg_kind::FixedNeutral: + case spec_arg_kind::Fixed: + if (has_attr || found_inst) { + mask[i] = true; + if (first) + mask.shrink(i+1); + first = false; + } + break; + } + } + } + + void get_bool_mask(name const & fn, buffer & mask) { + buffer kinds; + get_arg_kinds(fn, kinds); + to_bool_mask(kinds, has_specialize_attribute(env(), fn), mask); + } + + name mk_spec_name(name const & fn) { + name r = fn + m_at + m_base_name + (m_spec.append_after(m_next_idx)); + m_next_idx++; + return r; + } + + static expr mk_cache_key(expr const & fn, buffer> const & mask) { + expr r = fn; + for (optional const & b : mask) { + if (b) + r = mk_app(r, *b); + else + r = mk_app(r, expr()); + } + return r; + } + + bool is_specialize_candidate(expr const & fn, buffer const & args) { + lean_assert(is_constant(fn)); + buffer kinds; + get_arg_kinds(const_name(fn), kinds); + if (!has_specialize_attribute(env(), const_name(fn)) && !has_fixed_inst_arg(kinds)) + return false; /* Nothing to specialize */ + unsigned spec_arity = get_specialization_arity(kinds); + if (spec_arity == 0) + return false; /* Nothing to specialize */ + if (spec_arity > args.size()) { + /* We do not perform partial specialization. + We only specialize if all fixed arguments have been provided. */ + return false; + } + type_checker tc(m_st, m_lctx); + for (unsigned i = 0; i < args.size(); i++) { + if (i >= kinds.size()) + break; + spec_arg_kind k = kinds[i]; + expr w; + switch (k) { + case spec_arg_kind::FixedNeutral: + break; + case spec_arg_kind::FixedInst: + /* We specialize this kind of argument if it reduces to a constructor application. + Type class instances arguments are usually free variables bound to lambda declarations, + or quickly reduce to constructor applications. So, the following `whnf` is probably + harmless. */ + w = tc.whnf(args[i]); + if (is_constructor_app(env(), w)) + return true; + break; + case spec_arg_kind::FixedHO: + /* We specialize higher-order arguments if they are lambda applications or + a constant application. + + Remark: it is not feasible to invoke whnf since it may consume a lot of time. */ + w = find(args[i]); + if (is_lambda(w) || is_constant(get_app_fn(w))) + return true; + break; + case spec_arg_kind::Fixed: + /* We specialize this kind of argument if they are constructor applications or literals. + Remark: it is not feasible to invoke whnf since it may consume a lot of time. */ + w = find(args[i]); + if (is_constructor_app(env(), w) || is_lit(w)) + return true; + break; + case spec_arg_kind::Other: + break; + } + } + return false; + } + + void collect_deps(expr e, name_set & collected, spec_ctx & ctx) { buffer todo; while (true) { for_each(e, [&](expr const & x, unsigned) { @@ -356,10 +490,10 @@ class specialize_fn { if (is_fvar(x) && !collected.contains(fvar_name(x))) { collected.insert(fvar_name(x)); if (optional v = m_lctx.get_local_decl(x).get_value()) { - let_vars.push_back(x); + ctx.m_let_vars.push_back(x); todo.push_back(*v); } else { - new_params.push_back(x); + ctx.m_params.push_back(x); } } return true; @@ -371,15 +505,22 @@ class specialize_fn { } } - optional specialize(expr const & fn, buffer const & args, names const & mutual, buffer const & kinds, bool has_attr) { + void sort_fvars(buffer & fvars) { + std::sort(fvars.begin(), fvars.end(), + [&](expr const & x, expr const & y) { + return m_lctx.get_local_decl(x).get_idx() < m_lctx.get_local_decl(y).get_idx(); + }); + } + + /* Initialize `spec_ctx` fields: `m_params`, `m_let_vars`. */ + void specialize_init_deps(expr const & fn, buffer const & args, spec_ctx & ctx) { + lean_assert(is_constant(fn)); + buffer kinds; + get_arg_kinds(const_name(fn), kinds); + bool has_attr = has_specialize_attribute(env(), const_name(fn)); name_set collected; - buffer new_params; - buffer let_vars; - unsigned sz = get_specialization_arity(kinds); - lean_assert(sz <= args.size()); + unsigned sz = kinds.size(); unsigned i = sz; - buffer mask; - mask.resize(args.size(), false); bool found_inst = false; while (i > 0) { --i; @@ -387,33 +528,234 @@ class specialize_fn { case spec_arg_kind::Other: break; case spec_arg_kind::FixedInst: - mask[i] = true; - collect_deps(args[i], collected, new_params, let_vars); + collect_deps(args[i], collected, ctx); found_inst = true; break; case spec_arg_kind::FixedHO: case spec_arg_kind::FixedNeutral: case spec_arg_kind::Fixed: if (has_attr || found_inst) { - mask[i] = true; - collect_deps(args[i], collected, new_params, let_vars); + collect_deps(args[i], collected, ctx); } break; } } - std::sort(new_params.begin(), new_params.end(), - [&](expr const & x, expr const & y) { return m_lctx.get_local_decl(x).get_idx() < m_lctx.get_local_decl(y).get_idx(); }); - std::sort(let_vars.begin(), let_vars.end(), - [&](expr const & x, expr const & y) { return m_lctx.get_local_decl(x).get_idx() < m_lctx.get_local_decl(y).get_idx(); }); + sort_fvars(ctx.m_params); + sort_fvars(ctx.m_let_vars); lean_trace(name({"compiler", "spec_candidate"}), tout() << "candidate: " << mk_app(fn, args) << "\nclosure:"; - for (expr const & p : new_params) tout() << " " << p; - for (expr const & x : let_vars) tout() << " " << x; - tout() << "\nmask:"; - for (bool m : mask) tout() << " " << m; + for (expr const & p : ctx.m_params) tout() << " " << p; + for (expr const & x : ctx.m_let_vars) tout() << " " << x; tout() << "\n";); - // TODO(Leo): - return none_expr(); + } + + static bool contains(buffer> const & mask, expr const & e) { + for (optional const & o : mask) { + if (o && *o == e) + return true; + } + return false; + } + + optional adjust_rec_apps(expr e, buffer> const & mask, spec_ctx & ctx) { + switch (e.kind()) { + case expr_kind::App: + if (is_cases_on_app(env(), e)) { + buffer args; + expr const & c = get_app_args(e, args); + /* visit minor premises */ + unsigned minor_idx; unsigned minors_end; + std::tie(minor_idx, minors_end) = get_cases_on_minors_range(env(), const_name(c)); + for (; minor_idx < minors_end; minor_idx++) { + optional new_arg = adjust_rec_apps(args[minor_idx], mask, ctx); + if (!new_arg) return none_expr(); + args[minor_idx] = *new_arg; + } + return some_expr(mk_app(c, args)); + } else { + expr const & fn = get_app_fn(e); + if (!is_constant(fn) || !ctx.in_mutual_decl(const_name(fn))) + return some_expr(e); + buffer args; + get_app_args(e, args); + buffer bmask; + get_bool_mask(const_name(fn), bmask); + lean_assert(bmask.size() <= args.size()); + buffer> new_mask; + bool found = false; + for (unsigned i = 0; i < bmask.size(); i++) { + if (bmask[i] && contains(mask, args[i])) { + found = true; + new_mask.push_back(some_expr(args[i])); + } else { + new_mask.push_back(none_expr()); + } + } + if (!found) + return some_expr(e); + optional new_fn_name = spec_preprocess(fn, new_mask, ctx); + if (!new_fn_name) return none_expr(); + expr r = mk_constant(*new_fn_name); + r = mk_app(r, ctx.m_params); + for (unsigned i = 0; i < bmask.size(); i++) { + if (!bmask[i] || !contains(mask, args[i])) + r = mk_app(r, args[i]); + } + for (unsigned i = bmask.size(); i < args.size(); i++) { + r = mk_app(r, args[i]); + } + return some_expr(r); + } + case expr_kind::Lambda: { + buffer entries; + while (is_lambda(e)) { + entries.push_back(e); + e = binding_body(e); + } + optional new_e = adjust_rec_apps(e, mask, ctx); + if (!new_e) return none_expr(); + expr r = *new_e; + unsigned i = entries.size(); + while (i > 0) { + --i; + expr l = entries[i]; + r = mk_lambda(binding_name(l), binding_domain(l), r); + } + return some_expr(r); + } + case expr_kind::Let: { + buffer> entries; + while (is_let(e)) { + optional v = adjust_rec_apps(let_value(e), mask, ctx); + if (!v) return none_expr(); + expr new_val = *v; + entries.emplace_back(e, new_val); + e = let_body(e); + } + optional new_e = adjust_rec_apps(e, mask, ctx); + if (!new_e) return none_expr(); + expr r = *new_e; + unsigned i = entries.size(); + while (i > 0) { + --i; + expr l = entries[i].first; + expr v = entries[i].second; + r = mk_let(let_name(l), let_type(l), v, r); + } + return some_expr(r); + } + default: + return some_expr(e); + } + } + + optional spec_preprocess(expr const & fn, buffer> const & mask, spec_ctx & ctx) { + lean_assert(is_constant(fn)); + lean_assert(ctx.in_mutual_decl(const_name(fn))); + expr key = mk_cache_key(fn, mask); + if (name const * r = ctx.m_cache.find(key)) { + return optional(*r); + } + optional info = env().find(mk_cstage1_name(const_name(fn))); + if (!info || !info->is_definition()) return optional(); // failed + name new_name = mk_spec_name(const_name(fn)); + ctx.m_cache.insert(key, new_name); + expr new_code = instantiate_value_lparams(*info, const_levels(fn)); + flet save_lctx(m_lctx, m_lctx); + buffer fvars; + buffer new_fvars; + for (optional const & b : mask) { + lean_assert(is_lambda(new_code)); + if (b) { + lean_assert(is_fvar(*b)); + fvars.push_back(*b); + } else { + expr type = instantiate_rev(binding_domain(new_code), fvars.size(), fvars.data()); + expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(new_code), type, binding_info(new_code)); + new_fvars.push_back(new_fvar); + fvars.push_back(new_fvar); + } + new_code = binding_body(new_code); + } + new_code = instantiate_rev(new_code, fvars.size(), fvars.data()); + optional c = adjust_rec_apps(new_code, mask, ctx); + if (!c) return optional(); + new_code = *c; + new_code = m_lctx.mk_lambda(new_fvars, new_code); + ctx.m_pre_decls.push_back(comp_decl(new_name, new_code)); + lean_trace(name({"compiler", "specialize"}), tout() << "new specialization " << new_name << " :=\n" << new_code << "\n";); + return optional(new_name); + } + + void mk_new_decl(comp_decl const & pre_decl, buffer const & fvars, buffer const & fvar_vals, spec_ctx & ctx) { + lean_assert(fvars.size() == fvar_vals.size()); + name n = pre_decl.fst(); + expr code = pre_decl.snd(); + flet save_lctx(m_lctx, m_lctx); + buffer new_fvars; + while (is_lambda(code)) { + expr type = instantiate_rev(binding_domain(code), new_fvars.size(), new_fvars.data()); + expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(code), type, binding_info(code)); + new_fvars.push_back(new_fvar); + code = binding_body(code); + } + code = instantiate_rev(code, new_fvars.size(), new_fvars.data()); + /* Add fvars decls */ + type_checker tc(m_st, m_lctx); + buffer new_let_decls; + name y("_y"); + for (unsigned i = 0; i < fvars.size(); i++) { + expr type = tc.infer(fvar_vals[i]); + expr new_fvar = m_lctx.mk_local_decl(fvar_name(fvars[i]), y.append_after(i+1), type, fvar_vals[i]).mk_ref(); + new_let_decls.push_back(new_fvar); + } + code = m_lctx.mk_lambda(new_let_decls, code); + code = m_lctx.mk_lambda(ctx.m_let_vars, code); + code = m_lctx.mk_lambda(new_fvars, code); + code = m_lctx.mk_lambda(ctx.m_params, code); + lean_assert(!has_fvar(code)); + code = csimp(env(), code, m_cfg); + code = visit(code); + tout() << n << " :=\n" << code << "\n"; + m_new_decls.push_back(comp_decl(n, code)); + } + + optional specialize(expr const & fn, buffer const & args, spec_ctx & ctx) { + if (!is_specialize_candidate(fn, args)) + return none_expr(); + specialize_init_deps(fn, args, ctx); + buffer bmask; + get_bool_mask(const_name(fn), bmask); + buffer> mask; + buffer fvars; + buffer fvar_vals; + for (unsigned i = 0; i < bmask.size(); i++) { + if (bmask[i]) { + name n = ngen().next(); + expr fvar = mk_fvar(n); + fvars.push_back(fvar); + fvar_vals.push_back(args[i]); + mask.push_back(some_expr(fvar)); + } else { + mask.push_back(none_expr()); + } + } + optional new_fn_name = spec_preprocess(fn, mask, ctx); + if (!new_fn_name) + return none_expr(); + for (comp_decl const & pre_decl : ctx.m_pre_decls) { + mk_new_decl(pre_decl, fvars, fvar_vals, ctx); + } + expr r = mk_constant(*new_fn_name); + r = mk_app(r, ctx.m_params); + for (unsigned i = 0; i < bmask.size(); i++) { + if (!bmask[i]) + r = mk_app(r, args[i]); + } + for (unsigned i = bmask.size(); i < args.size(); i++) { + r = mk_app(r, args[i]); + } + return some_expr(r); } expr visit_app(expr const & e) { @@ -427,66 +769,11 @@ class specialize_fn { || (is_instance(env(), const_name(fn)) && !has_specialize_attribute(env(), const_name(fn)))) { return e; } - specialize_ext ext = get_extension(env()); - spec_info const * info = ext.m_spec_info.find(const_name(fn)); + spec_info const * info = m_ext.m_spec_info.find(const_name(fn)); if (!info) return e; - bool has_attr = has_specialize_attribute(env(), const_name(fn)); - buffer kinds; - to_buffer(info->get_arg_kinds(), kinds); - if (!has_attr && !has_fixed_inst_arg(kinds)) - return e; /* Nothing to specialize */ - unsigned spec_arity = get_specialization_arity(kinds); - if (spec_arity == 0) - return e; /* Nothing to specialize */ - if (spec_arity > args.size()) { - /* We do not perform partial specialization. - We only specialize if all fixed arguments have been provided. */ - return e; - } - type_checker tc(m_st, m_lctx); - bool is_candidate = false; - for (unsigned i = 0; i < args.size(); i++) { - if (i >= kinds.size()) - break; - spec_arg_kind k = kinds[i]; - expr w; - switch (k) { - case spec_arg_kind::FixedNeutral: - break; - case spec_arg_kind::FixedInst: - /* We specialize this kind of argument if it reduces to a constructor application. - Type class instances arguments are usually free variables bound to lambda declarations, - or quickly reduce to constructor applications. So, the following `whnf` is probably - harmless. */ - w = tc.whnf(args[i]); - if (is_constructor_app(env(), w)) - is_candidate = true; - break; - case spec_arg_kind::FixedHO: - /* We specialize higher-order arguments if they are lambda applications or - a constant application. - - Remark: it is not feasible to invoke whnf since it may consume a lot of time. */ - w = find(args[i]); - if (is_lambda(w) || is_constant(get_app_fn(w))) - is_candidate = true; - break; - case spec_arg_kind::Fixed: - /* We specialize this kind of argument if they are constructor applications or literals. - Remark: it is not feasible to invoke whnf since it may consume a lot of time. */ - w = find(args[i]); - if (is_constructor_app(env(), w) || is_lit(w)) - is_candidate = true; - break; - case spec_arg_kind::Other: - break; - } - if (is_candidate) - break; - } - if (!is_candidate) - return e; - if (optional r = specialize(fn, args, info->get_mutual_decls(), kinds, has_attr)) + spec_ctx ctx; + ctx.m_mutual = info->get_mutual_decls(); + if (optional r = specialize(fn, args, ctx)) return *r; else return e; @@ -503,8 +790,8 @@ class specialize_fn { } public: - specialize_fn(environment const & env): - m_st(env) {} + specialize_fn(environment const & env, csimp_cfg const & cfg): + m_st(env), m_cfg(cfg), m_ext(get_extension(env)), m_at("_at"), m_spec("_spec") {} pair operator()(comp_decl const & d) { m_base_name = d.fst(); @@ -514,16 +801,18 @@ public: } }; -pair specialize_core(environment const & env, comp_decl const & d) { - return specialize_fn(env)(d); +pair specialize_core(environment const & env, comp_decl const & d, csimp_cfg const & cfg) { + // TODO(Leo): we still need to implement main cache. + // return specialize_fn(env, cfg)(d); + return mk_pair(env, comp_decls(d)); } -pair specialize(environment env, comp_decls const & ds) { +pair specialize(environment env, comp_decls const & ds, csimp_cfg const & cfg) { env = update_spec_info(env, ds); comp_decls r; for (comp_decl const & d : ds) { comp_decls new_ds; - std::tie(env, new_ds) = specialize_core(env, d); + std::tie(env, new_ds) = specialize_core(env, d, cfg); r = append(r, new_ds); } return mk_pair(env, r); diff --git a/src/library/compiler/specialize.h b/src/library/compiler/specialize.h index 1473f81f7c..a1651e39c9 100644 --- a/src/library/compiler/specialize.h +++ b/src/library/compiler/specialize.h @@ -7,8 +7,9 @@ Author: Leonardo de Moura #pragma once #include "kernel/environment.h" #include "library/compiler/util.h" +#include "library/compiler/csimp.h" namespace lean { -pair specialize(environment env, comp_decls const & ds); +pair specialize(environment env, comp_decls const & ds, csimp_cfg const & cfg); void initialize_specialize(); void finalize_specialize(); }