feat(library/compiler/csimp): add fix_core_n => fix_core_m "eta-expansion-like" optimization
After this commit, `fix_1.lean` is not slower than `fix.lean` anymore.
This commit is contained in:
parent
d55a439542
commit
609b8e87e5
3 changed files with 64 additions and 0 deletions
|
|
@ -1366,6 +1366,45 @@ class csimp_fn {
|
|||
return mk_app(mk_constant(get_nat_add_name()), arg, mk_lit(literal(nat(1))));
|
||||
}
|
||||
|
||||
/*
|
||||
Replace `fix_core_n f a_1 ... a_m`
|
||||
with `fix_core_m f a_1 ... a_m` whenever `n < m`.
|
||||
|
||||
This optimization is for writing reusable/generic code. For
|
||||
example, we cannot write an efficient `rec_t` monad transformer
|
||||
without it because we don't know the arity of `m A` when we write `rec_t`.
|
||||
|
||||
Remark: the runtime provides a small set of `fix_core_i` implementations (`i in [1, 6]`).
|
||||
This methods does nothing if `m > 6`. */
|
||||
expr visit_fix_core(expr const & e, unsigned n) {
|
||||
if (m_before_erasure) return visit_app_default(e);
|
||||
buffer<expr> args;
|
||||
expr fn = get_app_args(e, args);
|
||||
lean_assert(is_constant(fn) && is_fix_core(const_name(fn)));
|
||||
unsigned arity =
|
||||
n + /* α_1 ... α_n Type arguments */
|
||||
1 + /* β : Type */
|
||||
1 + /* (base : α_1 → ... → α_n → β) */
|
||||
1 + /* (rec : (α_1 → ... → α_n → β) → α_1 → ... → α_n → β) */
|
||||
n; /* α_1 → ... → α_n */
|
||||
if (args.size() <= arity) return visit_app_default(e);
|
||||
/* This `fix_core_n` application is an overapplication.
|
||||
The `fix_core_n` is implemented by the runtime, and the result
|
||||
is a closure. This is bad for performance. We should
|
||||
replace it with `fix_core_m` (if the runtime contains one) */
|
||||
unsigned num_extra = args.size() - arity;
|
||||
unsigned m = n + num_extra;
|
||||
optional<expr> fix_core_m = mk_enf_fix_core(m);
|
||||
if (!fix_core_m) return visit_app_default(e);
|
||||
buffer<expr> new_args;
|
||||
/* Add α_1 ... α_n and β */
|
||||
for (unsigned i = 0; i < m+1; i++) {
|
||||
new_args.push_back(mk_enf_neutral());
|
||||
}
|
||||
new_args.append(args.size() - n - 1, args.data() + n + 1);
|
||||
return mk_app(*fix_core_m, new_args);
|
||||
}
|
||||
|
||||
expr visit_app(expr const & e, bool is_let_val) {
|
||||
if (is_cases_on_app(env(), e)) {
|
||||
return visit_cases(e, is_let_val);
|
||||
|
|
@ -1406,6 +1445,8 @@ class csimp_fn {
|
|||
return mk_lit(literal(nat(0)));
|
||||
} else if (optional<expr> r = try_inline(fn, e, is_let_val)) {
|
||||
return *r;
|
||||
} else if (optional<unsigned> i = is_fix_core(n)) {
|
||||
return visit_fix_core(e, *i);
|
||||
} else {
|
||||
return visit_app_default(e);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ Author: Leonardo de Moura
|
|||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <limits>
|
||||
#include <cctype>
|
||||
#include "util/name_hash_set.h"
|
||||
#include "kernel/type_checker.h"
|
||||
#include "kernel/for_each_fn.h"
|
||||
|
|
@ -512,6 +513,20 @@ optional<nat> get_num_lit_ext(expr const & e) {
|
|||
return to_optional_nat(get_num_lit_core(e.raw()));
|
||||
}
|
||||
|
||||
optional<unsigned> is_fix_core(name const & n) {
|
||||
if (!n.is_atomic() || !n.is_string()) return optional<unsigned>();
|
||||
string_ref const & r = n.get_string();
|
||||
if (r.length() != 10) return optional<unsigned>();
|
||||
char const * s = r.data();
|
||||
if (std::strncmp(s, "fix_core_", 9) != 0 || !std::isdigit(s[9])) return optional<unsigned>();
|
||||
return optional<unsigned>(s[9] - '0');
|
||||
}
|
||||
|
||||
optional<expr> mk_enf_fix_core(unsigned n) {
|
||||
if (n == 0 || n > 6) return none_expr();
|
||||
return some_expr(mk_constant(name("fix_core").append_after(n)));
|
||||
}
|
||||
|
||||
void initialize_compiler_util() {
|
||||
g_neutral_expr = new expr(mk_constant("_neutral"));
|
||||
g_unreachable_expr = new expr(mk_constant("_unreachable"));
|
||||
|
|
|
|||
|
|
@ -161,6 +161,14 @@ environment register_stage2_decl(environment const & env, name const & n, expr c
|
|||
optional<nat> get_num_lit_ext(expr const & e);
|
||||
inline bool is_morally_num_lit(expr const & e) { return static_cast<bool>(get_num_lit_ext(e)); }
|
||||
|
||||
/* Return `some n` if `c` is of the form `fix_core_n` where `n in [1, 6]`.
|
||||
Remark: this function is assuming the core library contains `fix_core_1` ... `fix_core_6` definitions. */
|
||||
optional<unsigned> is_fix_core(name const & c);
|
||||
/* Return the `fix_core_n` constant, and `none` if `n` is not in `[1, 6]`.
|
||||
Remark: this function is assuming the core library contains `fix_core_1` ... `fix_core_6` definitions.
|
||||
Remark: this function assumes universe levels have already been erased. */
|
||||
optional<expr> mk_enf_fix_core(unsigned n);
|
||||
|
||||
void initialize_compiler_util();
|
||||
void finalize_compiler_util();
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue