// Source: lean4-htt/src/library/compiler/csimp.cpp
// Snapshot date: 2018-09-26 17:54:11 -07:00
// Original size: 1088 lines, 46 KiB (C++).
// (Web viewer residue: "Raw Blame History".)
//
// Viewer note: this file contains ambiguous Unicode characters, i.e. Unicode characters
// that might be confused with other characters. If this is intentional, the warning can be
// safely ignored (the viewer's Escape button reveals them).

/*
Copyright (c) 2018 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include <unordered_set>
#include "runtime/flet.h"
#include "kernel/type_checker.h"
#include "kernel/for_each_fn.h"
#include "kernel/abstract.h"
#include "kernel/instantiate.h"
#include "library/util.h"
#include "library/constants.h"
#include "library/class.h"
#include "library/compiler/util.h"
#include "library/compiler/csimp.h"
#include "library/trace.h"
namespace lean {
/* Default configuration for the LCNF simplifier `csimp`. */
csimp_cfg::csimp_cfg() {
    /* Inlining options (not read by `csimp_fn` in this file). */
    m_inline                          = true;
    m_inline_threshold                = 4;
    m_inline_jp_threshold             = 4;
    /* Floating `cases_on` applications.
       - `m_float_cases`: float `cases_on` let-values in `csimp_fn::mk_let`.
       - `m_float_cases_app`: float `cases_on` out of applications in `csimp_fn::visit_app`.
       - `m_float_cases_jp_threshold`: terms with LCNF size at most this value are copied into
         the branches instead of being turned into join points.
       - `m_float_cases_jp_branch_threshold`: minor premises bigger than this value are
         replaced with join points. */
    m_float_cases                     = false;
    m_float_cases_app                 = true;
    m_float_cases_jp_threshold        = 3;
    m_float_cases_jp_branch_threshold = 3;
}
/* LCNF simplifier.

   The entry point is `operator()`. Expressions are traversed by `visit`, which dispatches on
   lambdas, lets, projections and applications. Let-declarations created during the traversal
   are stored both in `m_lctx` and in `m_fvars`. `mk_let(saved_fvars_size, e)` turns them back
   into `let`-expressions and drops unused ones.

   Transformations performed (see the individual methods):
   - beta reduction, and inlining of definitions with the `[inline]` attribute (`try_inline`);
   - `cases_on` over a constructor application (`reduce_cases_cnstr`);
   - projection of a constructor application, inlining instances if needed (`visit_proj`);
   - `lc_cast` simplification (`reduce_lc_cast`, `reduce_cast_app_app`);
   - floating `cases_on` out of applications and let-declarations (`float_cases_on_core`),
     which may create join points;
   - moving a let-declaration used in a single `cases_on` minor premise into it (`mk_let`). */
class csimp_fn {
    type_checker::state m_st;
    local_ctx           m_lctx;
    csimp_cfg const &   m_cfg;
    /* Let-declaration free variables created so far, in creation order.
       They are consumed by `mk_let(unsigned, expr)` and by `operator()`. */
    buffer<expr>        m_fvars;
    /* Prefix for auxiliary let-declaration names (initialized to `_x` by the constructor). */
    name                m_x;
    /* Prefix for auxiliary join-point names (initialized to `j` by the constructor). */
    name                m_j;
    unsigned            m_next_idx{1};
    unsigned            m_next_jp_idx{1};
    /* Expressions already in simplified form. `visit_lambda` and `visit_cases_default`
       return them unchanged. */
    expr_set            m_simplified;

    typedef std::unordered_set<name, name_hash> name_set;

    environment const & env() const { return m_st.env(); }
    name_generator & ngen() { return m_st.ngen(); }

    /* Type check `e` in the current local context. A type error is treated as unreachable. */
    void check(expr const & e) {
        try {
            type_checker(m_st, m_lctx).check(e);
        } catch (exception &) {
            lean_unreachable();
        }
    }

    void mark_simplified(expr const & e) {
        m_simplified.insert(e);
    }

    bool already_simplified(expr const & e) const {
        return m_simplified.find(e) != m_simplified.end();
    }

    /* Return true iff `e` is an application whose head is a free variable with a join-point name. */
    bool is_join_point_app(expr const & e) const {
        if (!is_app(e)) return false;
        expr const & fn = get_app_fn(e);
        return
            is_fvar(fn) &&
            is_join_point_name(m_lctx.get_local_decl(fn).get_user_name());
    }

    /* Very simple predicate used to decide whether we should inline join points or not.
       `e` is small iff, after skipping lambdas, it is an application that is not a `cases_on`.
       TODO(Leo): improve */
    bool is_small(expr const & e) const {
        if (is_app(e) && !is_cases_on_app(env(), e))
            return true;
        if (is_lambda(e))
            return is_small(binding_body(e));
        return false;
    }

    /* If `e` is a let-declaration free variable, return `find` of its value.
       The values of join points are only followed when they are small (see `is_small`).
       If `skip_mdata` is true, metadata annotations are skipped as well.
       Otherwise, return `e`. */
    expr find(expr const & e, bool skip_mdata = true) const {
        if (is_fvar(e)) {
            if (optional<local_decl> decl = m_lctx.find_local_decl(e)) {
                if (optional<expr> v = decl->get_value()) {
                    if (!is_join_point_name(decl->get_user_name()))
                        return find(*v, skip_mdata);
                    else if (is_small(*v))
                        return find(*v, skip_mdata);
                }
            }
        } else if (is_mdata(e) && skip_mdata) {
            return find(mdata_expr(e), true);
        }
        return e;
    }

    type_checker tc() { return type_checker(m_st, m_lctx); }
    expr infer_type(expr const & e) { return type_checker(m_st, m_lctx).infer(e); }
    expr whnf(expr const & e) { return type_checker(m_st, m_lctx).whnf(e); }
    expr whnf_infer_type(expr const & e) { type_checker tc(m_st, m_lctx); return tc.whnf(tc.infer(e)); }

    /* Fresh name for an auxiliary let-declaration. */
    name next_name() {
        /* Remark: we use `m_x.append_after(m_next_idx)` instead of `name(m_x, m_next_idx)`
           because the resulting name is confusing during debugging: it looks like a projection application.
           We should replace it with `name(m_x, m_next_idx)` when the compiler code gets more stable. */
        name r = m_x.append_after(m_next_idx);
        m_next_idx++;
        return r;
    }

    /* Fresh join-point name. */
    name next_jp_name() {
        name r = m_j.append_after(m_next_jp_idx);
        m_next_jp_idx++;
        return mk_join_point_name(r);
    }

    /* Create a new let-declaration `x : t := e`, add `x` to `m_fvars` and return `x`. */
    expr mk_let_decl(expr const & e) {
        lean_assert(!is_lcnf_atom(e));
        expr type = cheap_beta_reduce(infer_type(e));
        expr fvar = m_lctx.mk_local_decl(ngen(), next_name(), type, e);
        m_fvars.push_back(fvar);
        return fvar;
    }

    /* Given the name of a `cases_on` recursor, return the argument range
       [first_minor_idx, first_minor_idx + nminors) of its minor premises. */
    pair<unsigned, unsigned> get_cases_on_minors_range(name const & cases) {
        inductive_val I_val = env().get(cases.get_prefix()).to_inductive_val();
        unsigned nparams = I_val.get_nparams();
        unsigned nindices = I_val.get_nindices();
        unsigned nminors = I_val.get_ncnstrs();
        unsigned first_minor_idx = nparams + 1 /*motive*/ + nindices + 1 /* major */;
        return mk_pair(first_minor_idx, first_minor_idx + nminors);
    }

    /* Given a `cases_on` application `c`, return `some idx` iff `fvar` occurs only in the
       argument `idx`, this argument is a minor premise, and `fvar` does not occur in the
       types of the minor premise fields. */
    optional<unsigned> used_in_one_minor(expr const & c, expr const & fvar) {
        lean_assert(is_cases_on_app(env(), c));
        lean_assert(is_fvar(fvar));
        buffer<expr> args;
        expr const & c_fn = get_app_args(c, args);
        unsigned minors_begin; unsigned minors_end;
        std::tie(minors_begin, minors_end) = get_cases_on_minors_range(const_name(c_fn));
        unsigned i = 0;
        for (; i < minors_begin; i++) {
            if (has_fvar(args[i], fvar)) {
                /* Free variable occurs in a term that is not a minor premise. */
                return optional<unsigned>();
            }
        }
        lean_assert(i == minors_begin);
        for (unsigned j = minors_end; j < args.size(); j++) {
            if (has_fvar(args[j], fvar)) {
                /* Free variable occurs in an extra argument after the minor premises.
                   Moving the let-declaration into a minor premise would leave this occurrence dangling. */
                return optional<unsigned>();
            }
        }
        /* The following #pragma is to disable a bogus g++ 4.9 warning at `optional<unsigned> r`.
           Remark: clang predefines `__clang__` (lowercase), and does not support `-Wmaybe-uninitialized`. */
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
#endif
        optional<unsigned> r;
        for (; i < minors_end; i++) {
            expr minor = args[i];
            while (is_lambda(minor)) {
                if (has_fvar(binding_domain(minor), fvar)) {
                    /* Free variable occurs in the type of a field */
                    return optional<unsigned>();
                }
                minor = binding_body(minor);
            }
            if (has_fvar(minor, fvar)) {
                if (r) {
                    /* Free variable occurs in more than one minor premise. */
                    return optional<unsigned>();
                }
                r = i;
            }
        }
        return r;
    }

    /* Return `let _x := e in _x` */
    expr mk_trivial_let(expr const & e) {
        expr type = infer_type(e);
        return ::lean::mk_let("_x", type, e, mk_bvar(0));
    }

    /* Create a minor premise in LCNF.
       The minor premise is of the form `fun xs, e`.
       However, if `e` is a lambda, we create `fun xs, let _x := e in _x`.
       Thus, we don't "mix" the `xs` variables with the variables of the lambda `e`. */
    expr mk_minor_lambda(buffer<expr> const & xs, expr e) {
        if (is_lambda(e)) {
            /* We don't want to "mix" `xs` variables with
               the variables of the `new_minor` lambda */
            e = mk_trivial_let(e);
        }
        return m_lctx.mk_lambda(xs, e);
    }

    /* Create local declarations (appended to `xs`) for all leading lambdas of `e`,
       and return the instantiated body. */
    expr get_lambda_body(expr e, buffer<expr> & xs) {
        while (is_lambda(e)) {
            expr d = instantiate_rev(binding_domain(e), xs.size(), xs.data());
            expr x = m_lctx.mk_local_decl(ngen(), binding_name(e), d, binding_info(e));
            xs.push_back(x);
            e = binding_body(e);
        }
        return instantiate_rev(e, xs.size(), xs.data());
    }

    /* Move let-decl `fvar` to the minor premise at position `minor_idx` of cases_on-application `c`. */
    expr move_let_to_minor(expr const & c, unsigned minor_idx, expr const & fvar) {
        lean_assert(is_cases_on_app(env(), c));
        buffer<expr> args;
        expr const & c_fn = get_app_args(c, args);
        expr minor = args[minor_idx];
        flet<local_ctx> save_lctx(m_lctx, m_lctx);
        buffer<expr> xs;
        minor = get_lambda_body(minor, xs);
        if (minor == fvar) {
            /* `let x := v in x` ==> `v` */
            minor = *m_lctx.get_local_decl(fvar).get_value();
        } else {
            xs.push_back(fvar);
        }
        args[minor_idx] = mk_minor_lambda(xs, minor);
        return mk_app(c_fn, args);
    }

    /* Add the names of all free variables occurring in `e` to `S`. */
    static void collect_used(expr const & e, name_set & S) {
        if (!has_fvar(e)) return;
        for_each(e, [&](expr const & e, unsigned) {
                if (!has_fvar(e)) return false;
                if (is_fvar(e)) { S.insert(fvar_name(e)); return false; }
                return true;
            });
    }

    /* The `float_cases_on` transformation may produce code duplication:
       the term `e` is "copied" in each branch of the `cases_on` expression `c`.
       This method creates one (or more) join point(s) for `e` (if needed), appends them to
       `new_jps`, and returns the term that should be copied into the branches.
       - If the inductive type of `c` has a single constructor, or the LCNF size of `e` is at most
         `m_float_cases_jp_threshold`, `e` itself is returned.
       - If `e` is a `cases_on` application, each of its minor premises bigger than
         `m_float_cases_jp_branch_threshold` is replaced with an application of a new join point.
       - Otherwise, a join point `jp := fun fvar, e` is created, and `jp fvar` is returned.
       This method may fail (return `none`) because of dependent types. */
    optional<expr> mk_join_point_float_cases_on(expr const & fvar, expr const & e, expr const & c, buffer<expr> & new_jps) {
        lean_assert(is_cases_on_app(env(), c));
        expr const & c_fn = get_app_fn(c);
        inductive_val I_val = env().get(const_name(c_fn).get_prefix()).to_inductive_val();
        if (I_val.get_ncnstrs() == 1) {
            /* `c` has only one case. So, only one copy of `e` may be created. */
            return some_expr(e);
        } else if (get_lcnf_size(env(), e) <= m_cfg.m_float_cases_jp_threshold) {
            /* `e` is "small", copying should be ok. */
            return some_expr(e);
        } else if (is_cases_on_app(env(), e)) {
            local_decl fvar_decl = m_lctx.get_local_decl(fvar);
            buffer<expr> args;
            expr const & fn = get_app_args(e, args);
            bool modified = false;
            unsigned saved_fvars_size = m_fvars.size();
            unsigned begin_minors; unsigned end_minors;
            std::tie(begin_minors, end_minors) = get_cases_on_minors_range(const_name(fn));
            for (unsigned minor_idx = begin_minors; minor_idx < end_minors; minor_idx++) {
                expr minor = args[minor_idx];
                if (get_lcnf_size(env(), minor) > m_cfg.m_float_cases_jp_branch_threshold) {
                    buffer<bool> used_zs;   /* used_zs[i] iff `minor` uses `zs[i]` */
                    bool used_fvar = false; /* true iff `minor` uses `fvar` */
                    bool used_unit = false; /* true iff we needed to add `unit ->` to the join point */
                    expr jp_val;
                    /* Create the join-point value `jp_val` */
                    {
                        flet<local_ctx> save_lctx(m_lctx, m_lctx);
                        buffer<expr> zs;
                        minor = get_lambda_body(minor, zs);
                        mark_used_fvars(minor, zs, used_zs);
                        lean_assert(zs.size() == used_zs.size());
                        used_fvar = false;
                        jp_val = minor;
                        buffer<expr> jp_args;
                        if (has_fvar(minor, fvar)) {
                            /* `fvar` is a let-decl variable, we need to convert it into a lambda variable.
                               Remark: we need to use `replace_fvar_with` because replacing the let-decl variable `fvar` with
                               the lambda variable `new_fvar` may produce a type incorrect term. */
                            used_fvar = true;
                            expr new_fvar = m_lctx.mk_local_decl(ngen(), fvar_decl.get_user_name(), fvar_decl.get_type());
                            jp_args.push_back(new_fvar);
                            if (optional<expr> jp_val_opt = replace_fvar_with(m_st, m_lctx, jp_val, fvar, new_fvar)) {
                                jp_val = *jp_val_opt;
                            } else {
                                m_fvars.resize(saved_fvars_size);
                                return none_expr();
                            }
                        }
                        for (unsigned i = 0; i < used_zs.size(); i++) {
                            if (used_zs[i])
                                jp_args.push_back(zs[i]);
                        }
                        if (jp_args.empty()) {
                            /* Join points must take at least one argument: use `unit`. */
                            jp_args.push_back(m_lctx.mk_local_decl(ngen(), "_", mk_unit()));
                            used_unit = true;
                        }
                        jp_val = m_lctx.mk_lambda(jp_args, jp_val);
                    }
                    /* Create new jp */
                    expr jp_type = cheap_beta_reduce(infer_type(jp_val));
                    mark_simplified(jp_val);
                    expr jp_var = m_lctx.mk_local_decl(ngen(), next_jp_name(), jp_type, jp_val);
                    new_jps.push_back(jp_var);
                    /* Replace minor with an application of the new jp.
                       The arguments are passed in the same order used to build `jp_args` above. */
                    {
                        flet<local_ctx> save_lctx(m_lctx, m_lctx);
                        buffer<expr> zs;
                        minor = args[minor_idx];
                        minor = get_lambda_body(minor, zs);
                        lean_assert(zs.size() == used_zs.size());
                        expr new_minor = jp_var;
                        if (used_unit)
                            new_minor = mk_app(new_minor, mk_unit_mk());
                        if (used_fvar)
                            new_minor = mk_app(new_minor, fvar);
                        for (unsigned i = 0; i < used_zs.size(); i++) {
                            if (used_zs[i])
                                new_minor = mk_app(new_minor, zs[i]);
                        }
                        new_minor = mk_minor_lambda(zs, new_minor);
                        args[minor_idx] = new_minor;
                        modified = true;
                    }
                }
            }
            if (!modified) {
                return some_expr(e);
            } else {
                lean_trace(name({"compiler", "simp"}),
                           tout() << "mk_join " << fvar << "\n" << c << "\n---\n"
                           << e << "\n======>\n" << mk_app(fn, args) << "\n";);
                return some_expr(mk_app(fn, args));
            }
        } else {
            /* Create jp value for `e`
               This kind of join point is not very useful. It will only help if we decide
               to inline the join point later in some of the branches. */
            local_decl fvar_decl = m_lctx.get_local_decl(fvar);
            expr jp_val = e;
            {
                flet<local_ctx> save_lctx(m_lctx, m_lctx);
                /* `fvar` is a let-decl variable, we need to convert it into a lambda variable.
                   Remark: we need to use `replace_fvar_with` because replacing the let-decl variable `fvar` with
                   the lambda variable `new_fvar` may produce a type incorrect term. */
                expr new_fvar = m_lctx.mk_local_decl(ngen(), fvar_decl.get_user_name(), fvar_decl.get_type());
                if (optional<expr> jp_val_opt = replace_fvar_with(m_st, m_lctx, jp_val, fvar, new_fvar)) {
                    jp_val = *jp_val_opt;
                } else {
                    return none_expr();
                }
                jp_val = m_lctx.mk_lambda(new_fvar, jp_val);
            }
            /* Create new jp */
            expr jp_type = cheap_beta_reduce(infer_type(jp_val));
            mark_simplified(jp_val);
            expr jp_var = m_lctx.mk_local_decl(ngen(), next_jp_name(), jp_type, jp_val);
            new_jps.push_back(jp_var);
            lean_trace(name({"compiler", "simp"}),
                       tout() << "mk_join " << fvar << "\n" << c << "\n---\n"
                       << e << "\n======>\n" << mk_app(jp_var, fvar) << "\n";);
            return some_expr(mk_app(jp_var, fvar));
        }
    }

    /* Given `e[x]`, create a let-decl `y := v`, and return `visit(e[y])`.
       Casts are introduced if necessary. The result is `none` if it fails to produce a type correct `e[y]`. */
    optional<expr> apply_at(expr const & x, expr const & e, expr const & v) {
        local_decl x_decl = m_lctx.get_local_decl(x);
        expr v_type = infer_type(v);
        expr new_v  = mk_cast(v_type, x_decl.get_type(), v);
        expr y      = m_lctx.mk_local_decl(ngen(), x_decl.get_user_name(), x_decl.get_type(), new_v);
        optional<expr> e_y_opt = replace_fvar_with(m_st, m_lctx, e, x, y);
        if (!e_y_opt) return none_expr(); /* Failed to produce type correct `e[y]` */
        expr e_y = *e_y_opt;
        m_fvars.push_back(y);
        return some_expr(visit(e_y, false));
    }

    /*
      Given `e[x]` and the join point
      ```
      let jp := fun z, let .... in e'
      ```
      create
      ```
      let jp' := fun z, let ... y := e' in e[y]
      ```
      If `e'` is a `cases_on` application, we use `float_cases_on_core`. That is,
      ```
      let jp := fun z, let ... in
         cases_on m
           (fun y_1, let ... in e_1)
           ...
           (fun y_n, let ... in e_n)
      ```
      ==>
      ```
      let jp := fun z, let ... in
         cases_on m
           (fun y_1, let ... y := e_1 in e[y])
           ...
           (fun y_n, let ... y := e_n in e[y])
      ```
      If `e'` is itself a join-point application, the transformation is applied recursively
      to that join point. New join points are appended to `new_jps`, and `new_jp_cache` maps
      already processed join points to their new versions.

      Remark: this method returns `none` if the new join point cannot be created
      due to type errors. */
    optional<expr> mk_new_join_point(expr const & x, expr const & e, expr const & jp, buffer<expr> & new_jps, expr_map<expr> & new_jp_cache) {
        auto it = new_jp_cache.find(jp);
        if (it != new_jp_cache.end())
            return some_expr(it->second);
        local_decl jp_decl = m_lctx.get_local_decl(jp);
        lean_assert(is_join_point_name(jp_decl.get_user_name()));
        expr jp_val = *jp_decl.get_value();
        buffer<expr> zs;
        unsigned saved_fvars_size = m_fvars.size();
        jp_val = visit(get_lambda_body(jp_val, zs), false);
        expr new_jp_val;
        expr e_y;
        if (is_join_point_app(jp_val)) {
            buffer<expr> jp2_args;
            expr const & jp2 = get_app_args(jp_val, jp2_args);
            optional<expr> new_jp2_opt = mk_new_join_point(x, e, jp2, new_jps, new_jp_cache);
            if (!new_jp2_opt) return none_expr();
            e_y = mk_app(*new_jp2_opt, jp2_args);
        } else if (is_cases_on_app(env(), jp_val)) {
            optional<expr> e_y_opt = float_cases_on_core(x, e, jp_val, new_jps, new_jp_cache);
            if (!e_y_opt) return none_expr();
            e_y = *e_y_opt;
        } else {
            optional<expr> e_y_opt = apply_at(x, e, jp_val);
            if (!e_y_opt) return none_expr();
            e_y = *e_y_opt;
        }
        expr e_y_type    = infer_type(e_y);
        expr jp_val_type = infer_type(jp_val);
        new_jp_val = mk_cast(e_y_type, jp_val_type, e_y);
        new_jp_val = mk_let(saved_fvars_size, new_jp_val);
        new_jp_val = m_lctx.mk_lambda(zs, new_jp_val);
        mark_simplified(new_jp_val);
        expr new_jp_var = m_lctx.mk_local_decl(ngen(), next_jp_name(), jp_decl.get_type(), new_jp_val);
        new_jps.push_back(new_jp_var);
        new_jp_cache.insert(mk_pair(jp, new_jp_var));
        return some_expr(new_jp_var);
    }

    /* Given `e[x]` and the `cases_on` application `c`
       ```
       cases_on m
         (fun zs, let ... in e_1)
         ...
         (fun zs, let ... in e_n)
       ```
       ==>
       ```
       cases_on m
         (fun zs, let ... y := e_1 in e[y])
         ...
         (fun zs, let ... y := e_n in e[y])
       ```
       The motive and the resulting universe level of `c` are updated to the type of `e`.
       Minor premises that are join-point applications are handled by `mk_new_join_point`.
       Returns `none` if a type correct result cannot be produced. */
    optional<expr> float_cases_on_core(expr const & x, expr const & e, expr const & c, buffer<expr> & new_jps, expr_map<expr> & new_jp_cache) {
        lean_assert(is_cases_on_app(env(), c));
        local_decl x_decl = m_lctx.get_local_decl(x);
        expr result_type = whnf_infer_type(e);
        buffer<expr> c_args;
        expr c_fn = get_app_args(c, c_args);
        inductive_val I_val = env().get(const_name(c_fn).get_prefix()).to_inductive_val();
        unsigned motive_idx      = I_val.get_nparams();
        unsigned first_index     = motive_idx + 1;
        unsigned nindices        = I_val.get_nindices();
        unsigned major_idx       = first_index + nindices;
        unsigned first_minor_idx = major_idx + 1;
        unsigned nminors         = I_val.get_ncnstrs();
        /* Update motive */
        {
            flet<local_ctx> save_lctx(m_lctx, m_lctx);
            buffer<expr> zs;
            expr motive      = c_args[motive_idx];
            expr motive_type = whnf_infer_type(motive);
            for (unsigned i = 0; i < nindices + 1; i++) {
                lean_assert(is_pi(motive_type));
                expr z = m_lctx.mk_local_decl(ngen(), binding_name(motive_type), binding_domain(motive_type), binding_info(motive_type));
                zs.push_back(z);
                motive_type = whnf(instantiate(binding_body(motive_type), z));
            }
            level result_lvl = sort_level(tc().ensure_type(result_type));
            expr new_motive  = m_lctx.mk_lambda(zs, result_type);
            c_args[motive_idx] = new_motive;
            /* We need to update the resultant universe. */
            levels new_cases_lvls = levels(result_lvl, tail(const_levels(c_fn)));
            c_fn = update_constant(c_fn, new_cases_lvls);
        }
        /* Update minor premises */
        for (unsigned i = 0; i < nminors; i++) {
            unsigned minor_idx = first_minor_idx + i;
            expr minor = c_args[minor_idx];
            buffer<expr> zs;
            unsigned saved_fvars_size = m_fvars.size();
            expr minor_val = visit(get_lambda_body(minor, zs), false);
            expr new_minor;
            if (is_join_point_app(minor_val)) {
                buffer<expr> jp_args;
                expr const & jp = get_app_args(minor_val, jp_args);
                optional<expr> new_jp_opt = mk_new_join_point(x, e, jp, new_jps, new_jp_cache);
                if (!new_jp_opt) return none_expr();
                expr new_jp = *new_jp_opt;
                new_minor = mk_app(new_jp, jp_args);
            } else {
                optional<expr> e_y_opt = apply_at(x, e, minor_val);
                if (!e_y_opt) return none_expr();
                expr e_y      = *e_y_opt;
                expr e_y_type = infer_type(e_y);
                new_minor = mk_cast(e_y_type, result_type, e_y);
            }
            new_minor = mk_let(saved_fvars_size, new_minor);
            new_minor = mk_minor_lambda(zs, new_minor);
            c_args[minor_idx] = new_minor;
        }
        lean_trace(name({"compiler", "simp"}),
                   tout() << "float_cases_on [" << get_lcnf_size(env(), e) << "]\n" << c << "\n----\n" << e << "\n=====>\n"
                   << mk_app(c_fn, c_args) << "\n";);
        return some_expr(mk_app(c_fn, c_args));
    }

    /* Float cases transformation (see: `float_cases_on_core`).
       This version may create join points if `e` is big, or "good" join points could not be created.
       On failure, `m_fvars`, `m_lctx` and `new_jps` are restored to their state on entry. */
    optional<expr> float_cases_on(expr const & x, expr const & e, expr const & c, buffer<expr> & new_jps) {
        expr_map<expr> new_jp_cache;
        unsigned  saved_fvars_size   = m_fvars.size();
        unsigned  saved_new_jps_size = new_jps.size();
        local_ctx saved_lctx         = m_lctx;
        if (optional<expr> new_e = mk_join_point_float_cases_on(x, e, c, new_jps)) {
            if (optional<expr> r = float_cases_on_core(x, *new_e, c, new_jps, new_jp_cache)) {
                return r;
            }
        }
        /* Undo everything, including join points that refer to discarded declarations. */
        m_fvars.shrink(saved_fvars_size);
        new_jps.shrink(saved_new_jps_size);
        m_lctx = saved_lctx;
        return none_expr();
    }

    /* Given the buffer `entries`: `[(x_1, w_1), ..., (x_n, w_n)]`, and `e`.
       Create the let-expression
       ```
       let x_n := w_n
           ...
           x_1 := w_1
       in e
       ```
       The values `w_i` are the "simplified values" for the let-declaration `x_i`.
       The types and user names are taken from the local context.
    */
    expr mk_let(buffer<pair<expr, expr>> const & entries, expr e) {
        buffer<expr> fvars;
        buffer<name> user_names;
        buffer<expr> types;
        buffer<expr> vals;
        unsigned i = entries.size();
        while (i > 0) {
            --i;
            expr const & fvar = entries[i].first;
            fvars.push_back(fvar);
            vals.push_back(entries[i].second);
            local_decl fvar_decl = m_lctx.get_local_decl(fvar);
            user_names.push_back(fvar_decl.get_user_name());
            types.push_back(fvar_decl.get_type());
        }
        e = abstract(e, fvars.size(), fvars.data());
        i = fvars.size();
        while (i > 0) {
            --i;
            expr new_value = abstract(vals[i], i, fvars.data());
            expr new_type  = abstract(types[i], i, fvars.data());
            e = ::lean::mk_let(user_names[i], new_type, new_value, e);
        }
        return e;
    }

    /* Return true iff `e` contains a free variable in `s` */
    bool depends_on(expr const & e, name_set const & s) {
        if (!has_fvar(e)) return false;
        bool found = false;
        for_each(e, [&](expr const & e, unsigned) {
                if (!has_fvar(e)) return false;
                if (found) return false;
                if (is_fvar(e) && s.find(fvar_name(e)) != s.end()) {
                    found = true;
                }
                return true;
            });
        return found;
    }

    /* Split `entries` into two groups: `entries_dep_x` and `entries_ndep_x`.
       The first group contains the entries that (transitively) depend on `x` and the second the ones that don't.
       Both groups preserve the relative order of `entries`.
       This auxiliary method is used to float cases_on over expressions.
       `entries` is of the form `[(x_1, w_1), ..., (x_n, w_n)]`, where `x_i`s are
       let-decl free variables, and `w_i`s their new values. We use `entries`
       and an expression `e` to create a `let` expression:
       ```
       let x_n := w_n
           ...
           x_1 := w_1
       in e
       ```
    */
    void split_entries(buffer<pair<expr, expr>> const & entries,
                       expr const & x,
                       buffer<pair<expr, expr>> & entries_dep_x,
                       buffer<pair<expr, expr>> & entries_ndep_x) {
        if (entries.empty())
            return;
        name_set deps;
        deps.insert(fvar_name(x));
        /* Recall that `entries` are in reverse order. That is, pos 0 is the innermost variable. */
        unsigned i = entries.size();
        while (i > 0) {
            --i;
            expr const & fvar = entries[i].first;
            expr fvar_type    = m_lctx.get_type(fvar);
            expr fvar_new_val = entries[i].second;
            if (depends_on(fvar_type, deps) ||
                depends_on(fvar_new_val, deps)) {
                deps.insert(fvar_name(fvar));
                entries_dep_x.push_back(entries[i]);
            } else {
                entries_ndep_x.push_back(entries[i]);
            }
        }
        std::reverse(entries_dep_x.begin(), entries_dep_x.end());
        std::reverse(entries_ndep_x.begin(), entries_ndep_x.end());
    }

    /* Create a let-expression with body `e`, and
       all "used" let-declarations `m_fvars[i]` for `i in [saved_fvars_size, m_fvars.size)`.
       These declarations are removed from `m_fvars`; unused ones are dropped.
       We also visit the lambda expressions in used let-declarations of the form
       `x : t := fun ...`, since we don't visit them when we visit let-expressions.
       Used `cases_on` values are floated (if `m_cfg.m_float_cases`) or simplified with
       `visit_cases_default`, and declarations used in a single minor premise of `e`
       are moved into that minor premise. */
    expr mk_let(unsigned saved_fvars_size, expr e) {
        if (saved_fvars_size == m_fvars.size())
            return e;
        /* `entries` contains pairs (let-decl fvar, new value) for building the resultant let-declaration.
           We simplify the value of some let-declarations in this method, but we don't want to create
           a new temporary declaration just for this. */
        buffer<pair<expr, expr>> entries;
        name_set e_fvars;       /* Set of free variable names used in `e` */
        name_set entries_fvars; /* Set of free variable names used in `entries` */
        collect_used(e, e_fvars);
        bool e_is_cases = is_cases_on_app(env(), e);
        while (m_fvars.size() > saved_fvars_size) {
            expr fvar = m_fvars.back();
            m_fvars.pop_back();
            bool used_in_e       = (e_fvars.find(fvar_name(fvar)) != e_fvars.end());
            bool used_in_entries = (entries_fvars.find(fvar_name(fvar)) != entries_fvars.end());
            if (!used_in_e && !used_in_entries) {
                /* Skip unused variables */
                continue;
            }
            local_decl decl = m_lctx.get_local_decl(fvar);
            expr type = decl.get_type();
            expr val  = *decl.get_value();
            bool modified_val = false;
            if (is_lambda(val)) {
                /* We don't simplify lambdas when we visit `let`-expressions. */
                DEBUG_CODE(unsigned saved_fvars_size = m_fvars.size(););
                val = visit_lambda(val);
                modified_val = true;
                lean_assert(m_fvars.size() == saved_fvars_size);
            }
            if (is_fvar(e) && entries.empty() && fvar_name(e) == fvar_name(fvar)) {
                /* `let x := v in x` ==> `v` */
                e = val;
                collect_used(val, e_fvars);
                e_is_cases = is_cases_on_app(env(), e);
                continue;
            }
            if (is_cases_on_app(env(), val)) {
                /* Float cases transformation. */
                if (m_cfg.m_float_cases) {
                    /* We first create a let-declaration with all entries that depend on the current
                       `fvar` which is a cases_on application. */
                    buffer<pair<expr, expr>> entries_dep_curr;
                    buffer<pair<expr, expr>> entries_ndep_curr;
                    split_entries(entries, fvar, entries_dep_curr, entries_ndep_curr);
                    expr new_e = mk_let(entries_dep_curr, e);
                    buffer<expr> new_jps;
                    if (optional<expr> new_e_opt = float_cases_on(fvar, new_e, val, new_jps)) {
                        e = *new_e_opt;
                        /* `e` is now the floated `cases_on` application. */
                        e_is_cases = is_cases_on_app(env(), e);
                        /* Reset `e_fvars` and `entries_fvars`, we need to reconstruct them. */
                        e_fvars.clear(); entries_fvars.clear();
                        collect_used(e, e_fvars);
                        /* Join points may have been generated, we move them to entries.
                           Join points created later may use earlier ones, so the last one is the innermost. */
                        entries.clear();
                        while (!new_jps.empty()) {
                            expr jp_fvar = new_jps.back();
                            new_jps.pop_back();
                            local_decl jp_decl = m_lctx.get_local_decl(jp_fvar);
                            expr jp_type = jp_decl.get_type();
                            expr jp_val  = *jp_decl.get_value();
                            collect_used(jp_type, entries_fvars);
                            collect_used(jp_val, entries_fvars);
                            entries.emplace_back(jp_fvar, jp_val);
                        }
                        /* Copy `entries_ndep_curr` to `entries`. As in the general case below,
                           both the type and the value of each entry contribute to `entries_fvars`. */
                        for (pair<expr, expr> const & ndep_entry : entries_ndep_curr) {
                            entries.push_back(ndep_entry);
                            collect_used(m_lctx.get_type(ndep_entry.first), entries_fvars);
                            collect_used(ndep_entry.second, entries_fvars);
                        }
                        continue;
                    }
                }
                val = visit_cases_default(val);
                modified_val = true;
            }
            if (e_is_cases && used_in_e) {
                optional<unsigned> minor_idx = used_in_one_minor(e, fvar);
                if (minor_idx && !used_in_entries) {
                    /* `fvar` is used in only one minor premise,
                       and is *not* used in any expression at entries */
                    if (modified_val) {
                        /* We need to create a new free variable since the
                           value `val` was simplified */
                        expr new_fvar = m_lctx.mk_local_decl(ngen(), decl.get_user_name(), type, val);
                        e    = replace_fvar(e, fvar, new_fvar);
                        fvar = new_fvar;
                    }
                    collect_used(type, e_fvars);
                    collect_used(val, e_fvars);
                    e = move_let_to_minor(e, *minor_idx, fvar);
                    continue;
                }
            }
            collect_used(type, entries_fvars);
            collect_used(val, entries_fvars);
            entries.emplace_back(fvar, val);
        }
        return mk_let(entries, e);
    }

    /* Visit the values of a (nested) let-expression. Atomic values are substituted directly;
       the other ones become let-declarations in `m_fvars` (internal names are refreshed).
       Then the body is visited. */
    expr visit_let(expr e) {
        buffer<expr> let_fvars;
        while (is_let(e)) {
            expr new_type = instantiate_rev(let_type(e), let_fvars.size(), let_fvars.data());
            expr new_val  = visit(instantiate_rev(let_value(e), let_fvars.size(), let_fvars.data()), true);
            if (is_lcnf_atom(new_val)) {
                let_fvars.push_back(new_val);
            } else {
                name n = let_name(e);
                if (is_internal_name(n)) {
                    if (is_join_point_name(n))
                        n = next_jp_name();
                    else
                        n = next_name();
                }
                expr new_fvar = m_lctx.mk_local_decl(ngen(), n, new_type, new_val);
                let_fvars.push_back(new_fvar);
                m_fvars.push_back(new_fvar);
            }
            e = let_body(e);
        }
        return visit(instantiate_rev(e, let_fvars.size(), let_fvars.data()), false);
    }

    /* Simplify the body of the lambda `e`, and rebuild the lambda, including the
       let-declarations created while simplifying the body. */
    expr visit_lambda(expr e) {
        lean_assert(is_lambda(e));
        if (already_simplified(e))
            return e;
        flet<local_ctx> save_lctx(m_lctx, m_lctx);
        unsigned saved_fvars_size = m_fvars.size();
        buffer<expr> binding_fvars;
        while (is_lambda(e)) {
            /* Types are ignored in compilation steps. So, we do not invoke visit for d. */
            expr new_d    = instantiate_rev(binding_domain(e), binding_fvars.size(), binding_fvars.data());
            expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(e), new_d, binding_info(e));
            binding_fvars.push_back(new_fvar);
            e = binding_body(e);
        }
        expr new_body = visit(instantiate_rev(e, binding_fvars.size(), binding_fvars.data()), false);
        new_body = mk_let(saved_fvars_size, new_body);
        expr r = m_lctx.mk_lambda(binding_fvars, new_body);
        mark_simplified(r);
        return r;
    }

    /* Return true iff `n` is an instance that is not tagged with `[noinline]`. */
    bool should_inline_instance(name const & n) const {
        if (is_instance(env(), n))
            return !has_noinline_attribute(env(), n);
        else
            return false;
    }

    /* Return the number of leading lambdas of `e`. */
    static unsigned get_num_nested_lambdas(expr e) {
        unsigned r = 0;
        while (is_lambda(e)) {
            r++;
            e = binding_body(e);
        }
        return r;
    }

    /* Beta reduce `fn args[0] ... args[nargs-1]` and visit the result.
       If `fn` has fewer leading lambdas than `nargs`, the instantiated body is bound to a
       let-declaration (when not atomic) and applied to the remaining arguments. */
    expr beta_reduce(expr fn, unsigned nargs, expr const * args, bool is_let_val) {
        unsigned i = 0;
        while (is_lambda(fn) && i < nargs) {
            i++;
            fn = binding_body(fn);
        }
        expr r = instantiate_rev(fn, i, args);
        if (is_lambda(r)) {
            lean_assert(i == nargs);
            return visit(r, is_let_val);
        } else {
            r = visit(r, false);
            if (!is_lcnf_atom(r))
                r = mk_let_decl(r);
            return visit(mk_app(r, nargs - i, args + i), is_let_val);
        }
    }

    /* Remark: if `fn` is not a lambda expression, then this function
       will simply create the application `fn args_of(e)` */
    expr beta_reduce(expr fn, expr const & e, bool is_let_val) {
        buffer<expr> args;
        get_app_args(e, args);
        return beta_reduce(fn, args.size(), args.data(), is_let_val);
    }

    /* Try to inline the instance `fn` in the application `e`. Succeeds only if the result
       is a constructor application; otherwise the state (`m_lctx`, `m_fvars`) is restored. */
    optional<expr> try_inline_instance(expr const & fn, expr const & e) {
        lean_assert(is_constant(fn));
        optional<constant_info> info = env().find(mk_cstage1_name(const_name(fn)));
        if (!info || !info->is_definition()) return none_expr();
        if (get_app_num_args(e) < get_num_nested_lambdas(info->get_value())) return none_expr();
        local_ctx saved_lctx      = m_lctx;
        unsigned saved_fvars_size = m_fvars.size();
        expr new_fn = instantiate_value_lparams(*info, const_levels(fn));
        expr r = find(beta_reduce(new_fn, e, false));
        if (!is_constructor_app(env(), r)) {
            m_lctx = saved_lctx;
            m_fvars.resize(saved_fvars_size);
            return none_expr();
        }
        return some_expr(r);
    }

    /* Return the field `proj_idx` of the constructor application `k_app`. */
    expr proj_constructor(expr const & k_app, unsigned proj_idx) {
        lean_assert(is_constructor_app(env(), k_app));
        buffer<expr> args;
        expr const & k = get_app_args(k_app, args);
        constructor_val k_val = env().get(const_name(k)).to_constructor_val();
        lean_assert(k_val.get_nparams() + proj_idx < args.size());
        return args[k_val.get_nparams() + proj_idx];
    }

    /* Reduce projections of constructor applications, inlining instances if needed. */
    expr visit_proj(expr const & e, bool is_let_val) {
        expr s = find(proj_expr(e));
        if (is_constructor_app(env(), s))
            return proj_constructor(s, proj_idx(e).get_small_value());
        expr const & s_fn = get_app_fn(s);
        if (is_constant(s_fn) && should_inline_instance(const_name(s_fn))) {
            if (optional<expr> k_app = try_inline_instance(s_fn, s))
                return visit(proj_constructor(*k_app, proj_idx(e).get_small_value()), is_let_val);
        }
        return e;
    }

    /* `cases_on` application (with arguments `args`) whose major premise is the constructor
       application `major`: select the corresponding minor premise and beta reduce it with the
       constructor fields. */
    expr reduce_cases_cnstr(buffer<expr> const & args, inductive_val const & I_val, expr const & major, bool is_let_val) {
        lean_assert(is_constructor_app(env(), major));
        unsigned nparams = I_val.get_nparams();
        buffer<expr> k_args;
        expr const & k = get_app_args(major, k_args);
        lean_assert(is_constant(k));
        lean_assert(nparams <= k_args.size());
        unsigned first_minor_idx = nparams + 1 /* typeformer/motive */ + I_val.get_nindices() + 1 /* major */;
        constructor_val k_val = env().get(const_name(k)).to_constructor_val();
        expr const & minor = args[first_minor_idx + k_val.get_cidx()];
        return beta_reduce(minor, k_args.size() - nparams, k_args.data() + nparams, is_let_val);
    }

    /* Just simplify minor premises. */
    expr visit_cases_default(expr const & e) {
        if (already_simplified(e))
            return e;
        lean_assert(is_cases_on_app(env(), e));
        buffer<expr> args;
        expr const & c = get_app_args(e, args);
        /* simplify minor premises */
        unsigned minor_idx; unsigned minors_end;
        std::tie(minor_idx, minors_end) = get_cases_on_minors_range(const_name(c));
        for (; minor_idx < minors_end; minor_idx++) {
            expr minor = args[minor_idx];
            unsigned saved_fvars_size = m_fvars.size();
            flet<local_ctx> save_lctx(m_lctx, m_lctx);
            buffer<expr> zs;
            minor = get_lambda_body(minor, zs);
            expr new_minor = visit(minor, false);
            new_minor = mk_let(saved_fvars_size, new_minor);
            new_minor = mk_minor_lambda(zs, new_minor);
            args[minor_idx] = new_minor;
        }
        expr r = mk_app(c, args);
        mark_simplified(r);
        return r;
    }

    /* Return `t` cast from type `A` to type `B`:
       - `t` if `A` and `B` are definitionally equal;
       - `lc_proof B` if `t` is an `lc_proof` application;
       - `lc_cast A B t` otherwise (`t` is bound to a let-declaration if it is not atomic). */
    expr mk_cast(type_checker & tc, expr const & A, expr const & B, expr t) {
        if (tc.is_def_eq(A, B)) {
            return t;
        } else if (is_lc_proof_app(t)) {
            return mk_app(mk_constant(get_lc_proof_name()), B);
        } else {
            /* lc_cast.{u_1 u_2} : Π {α : Sort u_2} {β : Sort u_1}, α → β */
            level u_2 = sort_level(tc.ensure_type(A));
            level u_1 = sort_level(tc.ensure_type(B));
            if (!is_lcnf_atom(t))
                t = mk_let_decl(t);
            return mk_app(mk_constant(get_lc_cast_name(), {u_1, u_2}), A, B, t);
        }
    }

    expr mk_cast(expr const & A, expr const & B, expr const & t) {
        type_checker tc(m_st, m_lctx);
        return mk_cast(tc, A, B, t);
    }

    /* Reduce `cases_on` over a constructor. Otherwise, simplify the minor premises unless
       `e` is the value of a let-declaration (those are handled later by `mk_let`). */
    expr visit_cases(expr const & e, bool is_let_val) {
        buffer<expr> args;
        expr const & c = get_app_args(e, args);
        lean_assert(is_constant(c));
        inductive_val I_val = env().get(const_name(c).get_prefix()).to_inductive_val();
        unsigned major_idx = I_val.get_nparams() + 1 /* typeformer/motive */ + I_val.get_nindices();
        lean_assert(major_idx < args.size());
        expr const & major = find(args[major_idx]);
        if (is_constructor_app(env(), major)) {
            return reduce_cases_cnstr(args, I_val, major, is_let_val);
        } else if (!is_let_val) {
            return visit_cases_default(e);
        } else {
            return e;
        }
    }

    /* Simplify the `lc_cast` application `e`:
       remove trivial casts, and combine nested casts (cast transitivity). */
    expr reduce_lc_cast(expr const & e) {
        buffer<expr> args;
        expr const & cast_fn1 = get_app_args(e, args);
        lean_assert(args.size() == 3);
        if (type_checker(m_st, m_lctx).is_def_eq(args[0], args[1])) {
            /* (lc_cast A A t) ==> t */
            return args[2];
        }
        expr major = find(args[2]);
        if (is_lc_cast_app(major)) {
            /* Cast transitivity:
               (lc_cast B C (lc_cast A B t)) ==> (lc_cast A C t)
               lc_cast.{u_1 u_2} : Π {α : Sort u_2} {β : Sort u_1}, α → β */
            buffer<expr> nested_args;
            expr const & cast_fn2 = get_app_args(major, nested_args);
            expr const & C = args[1];
            expr const & A = nested_args[0];
            level u1 = head(const_levels(cast_fn1));
            level u2 = head(tail(const_levels(cast_fn2)));
            return reduce_lc_cast(mk_app(mk_constant(get_lc_cast_name(), {u1, u2}), A, C, nested_args[2]));
        }
        return e;
    }

    /* `e` is `f a_1 ... a_n` where `f` is bound to the application `fn`.
       Merge the two applications into `fn a_1 ... a_n`, unless `fn` is an `lc_cast` or `cases_on`. */
    expr merge_app_app(expr const & fn, expr const & e, bool is_let_val) {
        lean_assert(is_app(fn));
        lean_assert(is_eqp(find(get_app_fn(e)), fn));
        if (!is_lc_cast_app(fn) && !is_cases_on_app(env(), fn)) {
            buffer<expr> args;
            get_app_args(e, args);
            return visit_app(mk_app(fn, args), is_let_val);
        } else {
            return e;
        }
    }

    expr reduce_cast_app_app(expr const & fn, expr const & e, bool is_let_val) {
        lean_assert(is_lc_cast_app(fn));
        lean_assert(is_eqp(find(get_app_fn(e)), fn));
        /*
          f := lc_cast g
          e := f a_1 ... a_n
          ==>
          b_1 := lc_cast a_1
          ...
          b_n := lc_cast a_n
          e   := g b_1 ... b_n

          The result is cast back to the type of `e` if needed.
        */
        expr const & g = app_arg(fn);
        buffer<expr> args;
        get_app_args(e, args);
        expr g_type = whnf_infer_type(g);
        for (expr & arg : args) {
            lean_assert(is_pi(g_type));
            expr expected_type = binding_domain(g_type);
            expr arg_type      = infer_type(arg);
            expr new_arg       = mk_cast(arg_type, expected_type, arg);
            arg = new_arg;
            g_type = whnf(instantiate(binding_body(g_type), new_arg));
        }
        expr r = visit_app(mk_app(g, args), is_let_val);
        type_checker tc(m_st, m_lctx);
        expr r_type = tc.infer(r);
        expr e_type = tc.infer(e);
        return mk_cast(tc, r_type, e_type, r);
    }

    /* Inline the constant `fn` in `e` if it has the `[inline]` attribute (and not `[noinline]`),
       its stage-1 definition exists, and `e` provides enough arguments. */
    expr try_inline(expr const & fn, expr const & e, bool is_let_val) {
        lean_assert(is_constant(fn));
        lean_assert(is_eqp(find(get_app_fn(e)), fn));
        if (has_noinline_attribute(env(), const_name(fn))) return e;
        optional<constant_info> info = env().find(mk_cstage1_name(const_name(fn)));
        if (!info || !info->is_definition()) return e;
        if (get_app_num_args(e) < get_num_nested_lambdas(info->get_value())) return e;
        /* TODO(Leo): check size and whether function is boring or not. */
        if (!has_inline_attribute(env(), const_name(fn))) return e;
        expr new_fn = instantiate_value_lparams(*info, const_levels(fn));
        return beta_reduce(new_fn, e, is_let_val);
    }

    expr visit_app(expr const & e, bool is_let_val) {
        if (is_cases_on_app(env(), e)) {
            return visit_cases(e, is_let_val);
        } else if (is_lc_cast_app(e)) {
            return reduce_lc_cast(e);
        }
        expr fn = find(get_app_fn(e));
        if (is_lambda(fn)) {
            return beta_reduce(fn, e, is_let_val);
        } else if (is_cases_on_app(env(), fn) && m_cfg.m_float_cases_app) {
            lean_assert(is_fvar(get_app_fn(e)));
            /* float cases_on from application.
               `float_cases_on_core` may create local and let-declarations before failing,
               so we restore the state on failure (as `float_cases_on` does). Otherwise,
               leftover declarations could end up in the result produced by `operator()`. */
            local_ctx saved_lctx       = m_lctx;
            unsigned  saved_fvars_size = m_fvars.size();
            buffer<expr> new_jps;
            expr_map<expr> new_jp_cache;
            if (optional<expr> new_e = float_cases_on_core(get_app_fn(e), e, fn, new_jps, new_jp_cache)) {
                mark_simplified(*new_e);
                m_fvars.append(new_jps);
                return *new_e;
            } else {
                m_fvars.shrink(saved_fvars_size);
                m_lctx = saved_lctx;
                mark_simplified(e);
                return e;
            }
        } else if (is_lc_cast_app(fn)) {
            return reduce_cast_app_app(fn, e, is_let_val);
        } else if (is_lc_unreachable_app(fn)) {
            expr type = infer_type(e);
            return mk_lc_unreachable(m_st, m_lctx, type);
        } else if (is_app(fn)) {
            return merge_app_app(fn, e, is_let_val);
        } else if (is_constant(fn)) {
            return try_inline(fn, e, is_let_val);
        }
        return e;
    }

    /* Main dispatcher. `is_let_val` is true iff `e` is the value of a let-declaration;
       lambdas in that position are not visited here (see `mk_let`). */
    expr visit(expr const & e, bool is_let_val) {
        switch (e.kind()) {
        case expr_kind::Lambda: return is_let_val ? e : visit_lambda(e);
        case expr_kind::Let:    return visit_let(e);
        case expr_kind::Proj:   return visit_proj(e, is_let_val);
        case expr_kind::App:    return visit_app(e, is_let_val);
        default:                return e;
        }
    }

public:
    csimp_fn(environment const & env, local_ctx const & lctx, csimp_cfg const & cfg):
        m_st(env), m_lctx(lctx), m_cfg(cfg), m_x("_x"), m_j("j") {}

    /* Simplify `e`. Let-declarations remaining in `m_fvars` are abstracted with `mk_lambda`. */
    expr operator()(expr const & e) {
        expr r = visit(e, false);
        return m_lctx.mk_lambda(m_fvars, r);
    }
};
/* Simplify the LCNF expression `e` in the local context `lctx` using the options in `cfg`.
   A fresh `csimp_fn` simplifier is used for each call. */
expr csimp(environment const & env, local_ctx const & lctx, expr const & e, csimp_cfg const & cfg) {
    csimp_fn simplifier(env, lctx, cfg);
    return simplifier(e);
}
}