/* Copyright (c) 2018 Microsoft Corporation. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #include #include "runtime/flet.h" #include "kernel/type_checker.h" #include "kernel/for_each_fn.h" #include "kernel/abstract.h" #include "kernel/instantiate.h" #include "library/util.h" #include "library/constants.h" #include "library/class.h" #include "library/compiler/util.h" #include "library/compiler/csimp.h" #include "library/trace.h" namespace lean { csimp_cfg::csimp_cfg() { m_inline = true; m_inline_threshold = 4; m_float_cases_app = true; m_float_cases = false; m_float_cases_jp_threshold = 3; m_float_cases_jp_branch_threshold = 3; m_inline_jp_threshold = 4; } class csimp_fn { type_checker::state m_st; local_ctx m_lctx; csimp_cfg const & m_cfg; buffer m_fvars; name m_x; name m_j; unsigned m_next_idx{1}; unsigned m_next_jp_idx{1}; expr_set m_simplified; typedef std::unordered_set name_set; environment const & env() const { return m_st.env(); } name_generator & ngen() { return m_st.ngen(); } void check(expr const & e) { try { type_checker(m_st, m_lctx).check(e); } catch (exception &) { lean_unreachable(); } } void mark_simplified(expr const & e) { m_simplified.insert(e); } bool already_simplified(expr const & e) const { return m_simplified.find(e) != m_simplified.end(); } bool is_join_point_app(expr const & e) const { if (!is_app(e)) return false; expr const & fn = get_app_fn(e); return is_fvar(fn) && is_join_point_name(m_lctx.get_local_decl(fn).get_user_name()); } /* Very simple predicate used to decide whether we should inline joint-points or not. TODO(Leo): improve */ bool is_small(expr const & e) const { if (is_app(e) && !is_cases_on_app(env(), e)) return true; if (is_lambda(e)) return is_small(binding_body(e)); return false; } expr find(expr const & e, bool skip_mdata = true) const { if (is_fvar(e)) { if (optional decl = m_lctx.find_local_decl(e)) { if (optional v = decl->get_value()) { if (!is_join_point_name(decl->get_user_name())) return find(*v, skip_mdata); else if (is_small(*v)) return find(*v, skip_mdata); } } } else if (is_mdata(e) && skip_mdata) { return find(mdata_expr(e), true); } return e; } type_checker tc() { return type_checker(m_st, m_lctx); } expr infer_type(expr const & e) { return type_checker(m_st, m_lctx).infer(e); } expr whnf(expr const & e) { return type_checker(m_st, m_lctx).whnf(e); } expr whnf_infer_type(expr const & e) { type_checker tc(m_st, m_lctx); return tc.whnf(tc.infer(e)); } name next_name() { /* Remark: we use `m_x.append_after(m_next_idx)` instead of `name(m_x, m_next_idx)` because the resulting name is confusing during debugging: it looks like a projection application. We should replace it with `name(m_x, m_next_idx)` when the compiler code gets more stable. */ name r = m_x.append_after(m_next_idx); m_next_idx++; return r; } name next_jp_name() { name r = m_j.append_after(m_next_jp_idx); m_next_jp_idx++; return mk_join_point_name(r); } /* Create a new let-declaration `x : t := e`, add `x` to `m_fvars` and return `x`. */ expr mk_let_decl(expr const & e) { lean_assert(!is_lcnf_atom(e)); expr type = cheap_beta_reduce(infer_type(e)); expr fvar = m_lctx.mk_local_decl(ngen(), next_name(), type, e); m_fvars.push_back(fvar); return fvar; } /* Given the `cases_on` application, return [first_minor_idx, first_minor_idx + nminors) */ pair get_cases_on_minors_range(name const & cases) { inductive_val I_val = env().get(cases.get_prefix()).to_inductive_val(); unsigned nparams = I_val.get_nparams(); unsigned nindices = I_val.get_nindices(); unsigned nminors = I_val.get_ncnstrs(); unsigned first_minor_idx = nparams + 1 /*motive*/ + nindices + 1 /* major */; return mk_pair(first_minor_idx, first_minor_idx + nminors); } /* Given a cases_on application `c`, return `some idx` iff `fvar` only occurs in the argument `idx`, this argument is a minor premise. */ optional used_in_one_minor(expr const & c, expr const & fvar) { lean_assert(is_cases_on_app(env(), c)); lean_assert(is_fvar(fvar)); buffer args; expr const & c_fn = get_app_args(c, args); unsigned minors_begin; unsigned minors_end; std::tie(minors_begin, minors_end) = get_cases_on_minors_range(const_name(c_fn)); unsigned i = 0; for (; i < minors_begin; i++) { if (has_fvar(args[i], fvar)) { /* Free variable occurs in a term that is a not a minor premise. */ return optional(); } } lean_assert(i == minors_begin); /* The following #pragma is to disable a bogus g++ 4.9 warning at `optional r` */ #if defined(__GNUC__) && !defined(__CLANG__) #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" #endif optional r; for (; i < minors_end; i++) { expr minor = args[i]; while (is_lambda(minor)) { if (has_fvar(binding_domain(minor), fvar)) { /* Free variable occurs in the type of a field */ return optional(); } minor = binding_body(minor); } if (has_fvar(minor, fvar)) { if (r) { /* Free variable occur in more than one minor premise. */ return optional(); } r = i; } } return r; } /* Return `let _x := e in _x` */ expr mk_trivial_let(expr const & e) { expr type = infer_type(e); return ::lean::mk_let("_x", type, e, mk_bvar(0)); } /* Create minor premise in LCNF. The minor premise is of the form `fun xs, e`. However, if `e` is a lambda, we create `fun xs, let _x := e in _x`. Thus, we don't "mix" `xs` variables with the variables of the `new_minor` lambda */ expr mk_minor_lambda(buffer const & xs, expr e) { if (is_lambda(e)) { /* We don't want to "mix" `xs` variables with the variables of the `new_minor` lambda */ e = mk_trivial_let(e); } return m_lctx.mk_lambda(xs, e); } expr get_lambda_body(expr e, buffer & xs) { while (is_lambda(e)) { expr d = instantiate_rev(binding_domain(e), xs.size(), xs.data()); expr x = m_lctx.mk_local_decl(ngen(), binding_name(e), d, binding_info(e)); xs.push_back(x); e = binding_body(e); } return instantiate_rev(e, xs.size(), xs.data()); } /* Move let-decl `fvar` to the minor premise at position `minor_idx` of cases_on-application `c`. */ expr move_let_to_minor(expr const & c, unsigned minor_idx, expr const & fvar) { lean_assert(is_cases_on_app(env(), c)); buffer args; expr const & c_fn = get_app_args(c, args); expr minor = args[minor_idx]; flet save_lctx(m_lctx, m_lctx); buffer xs; minor = get_lambda_body(minor, xs); if (minor == fvar) { /* `let x := v in x` ==> `v` */ minor = *m_lctx.get_local_decl(fvar).get_value(); } else { xs.push_back(fvar); } args[minor_idx] = mk_minor_lambda(xs, minor); return mk_app(c_fn, args); } static void collect_used(expr const & e, name_set & S) { if (!has_fvar(e)) return; for_each(e, [&](expr const & e, unsigned) { if (!has_fvar(e)) return false; if (is_fvar(e)) { S.insert(fvar_name(e)); return false; } return true; }); } /* The `float_cases_on` transformation may produce code duplication. The term `e` is "copied" in each branch of the the `cases_on` expression `c`. This method creates one (or more) join-point(s) for `e` (if needed). This method main fail because of dependent types. */ optional mk_join_point_float_cases_on(expr const & fvar, expr const & e, expr const & c, buffer & new_jps) { lean_assert(is_cases_on_app(env(), c)); expr const & c_fn = get_app_fn(c); inductive_val I_val = env().get(const_name(c_fn).get_prefix()).to_inductive_val(); if (I_val.get_ncnstrs() == 1) { /* `c` has only one case. So, only one copy of `e` may be created. */ return some_expr(e); } else if (get_lcnf_size(env(), e) <= m_cfg.m_float_cases_jp_threshold) { /* `e` is "small", copying should be ok. */ return some_expr(e); } else if (is_cases_on_app(env(), e)) { local_decl fvar_decl = m_lctx.get_local_decl(fvar); buffer args; expr const & fn = get_app_args(e, args); bool modified = false; unsigned saved_fvars_size = m_fvars.size(); unsigned begin_minors; unsigned end_minors; std::tie(begin_minors, end_minors) = get_cases_on_minors_range(const_name(fn)); for (unsigned minor_idx = begin_minors; minor_idx < end_minors; minor_idx++) { expr minor = args[minor_idx]; if (get_lcnf_size(env(), minor) > m_cfg.m_float_cases_jp_branch_threshold) { buffer used_zs; /* used_zs[i] iff `minor` uses `zs[i]` */ bool used_fvar = false; /* true iff `minor` uses `fvar` */ bool used_unit = false; /* true if we needed to add `unit ->` to joint point */ expr jp_val; /* Create join-point value: `jp-val` */ { flet save_lctx(m_lctx, m_lctx); buffer zs; minor = get_lambda_body(minor, zs); mark_used_fvars(minor, zs, used_zs); lean_assert(zs.size() == used_zs.size()); used_fvar = false; jp_val = minor; buffer jp_args; if (has_fvar(minor, fvar)) { /* `fvar` is a let-decl variable, we need to convert into a lambda variable. Remark: we need to use `replace_fvar_with` because replacing the let-decl variable `fvar` with the lambda variable `new_fvar` may produce a type incorrect term. */ used_fvar = true; expr new_fvar = m_lctx.mk_local_decl(ngen(), fvar_decl.get_user_name(), fvar_decl.get_type()); jp_args.push_back(new_fvar); if (optional jp_val_opt = replace_fvar_with(m_st, m_lctx, jp_val, fvar, new_fvar)) { jp_val = *jp_val_opt; } else { m_fvars.resize(saved_fvars_size); return none_expr(); } } for (unsigned i = 0; i < used_zs.size(); i++) { if (used_zs[i]) jp_args.push_back(zs[i]); } if (jp_args.empty()) { jp_args.push_back(m_lctx.mk_local_decl(ngen(), "_", mk_unit())); used_unit = true; } jp_val = m_lctx.mk_lambda(jp_args, jp_val); } /* Create new jp */ expr jp_type = cheap_beta_reduce(infer_type(jp_val)); mark_simplified(jp_val); expr jp_var = m_lctx.mk_local_decl(ngen(), next_jp_name(), jp_type, jp_val); new_jps.push_back(jp_var); /* Replace minor with new jp */ { flet save_lctx(m_lctx, m_lctx); buffer zs; minor = args[minor_idx]; minor = get_lambda_body(minor, zs); lean_assert(zs.size() == used_zs.size()); expr new_minor = jp_var; if (used_unit) new_minor = mk_app(new_minor, mk_unit_mk()); if (used_fvar) new_minor = mk_app(new_minor, fvar); for (unsigned i = 0; i < used_zs.size(); i++) { if (used_zs[i]) new_minor = mk_app(new_minor, zs[i]); } new_minor = mk_minor_lambda(zs, new_minor); args[minor_idx] = new_minor; modified = true; } } } if (!modified) { return some_expr(e); } else { lean_trace(name({"compiler", "simp"}), tout() << "mk_join " << fvar << "\n" << c << "\n---\n" << e << "\n======>\n" << mk_app(fn, args) << "\n";); return some_expr(mk_app(fn, args)); } } else { /* Create jp value for `e` This kind of join-point is not very useful. It will only help if we decide to inline the join-point later in some of the branches. */ local_decl fvar_decl = m_lctx.get_local_decl(fvar); expr jp_val = e; { flet save_lctx(m_lctx, m_lctx); /* `fvar` is a let-decl variable, we need to convert into a lambda variable. Remark: we need to use `replace_fvar_with` because replacing the let-decl variable `fvar` with the lambda variable `new_fvar` may produce a type incorrect term. */ expr new_fvar = m_lctx.mk_local_decl(ngen(), fvar_decl.get_user_name(), fvar_decl.get_type()); if (optional jp_val_opt = replace_fvar_with(m_st, m_lctx, jp_val, fvar, new_fvar)) { jp_val = *jp_val_opt; } else { return none_expr(); } jp_val = m_lctx.mk_lambda(new_fvar, jp_val); } /* Create new jp */ expr jp_type = cheap_beta_reduce(infer_type(jp_val)); mark_simplified(jp_val); expr jp_var = m_lctx.mk_local_decl(ngen(), next_jp_name(), jp_type, jp_val); new_jps.push_back(jp_var); lean_trace(name({"compiler", "simp"}), tout() << "mk_join " << fvar << "\n" << c << "\n---\n" << e << "\n======>\n" << mk_app(jp_var, fvar) << "\n";); return some_expr(mk_app(jp_var, fvar)); } } /* Given `e[x]`, create a let-decl `y := v`, and return `e[y]` Casts are introduced if necessary. The result is `none` if it fails to produce type correct `e[y]`. */ optional apply_at(expr const & x, expr const & e, expr const & v) { local_decl x_decl = m_lctx.get_local_decl(x); expr v_type = infer_type(v); expr new_v = mk_cast(v_type, x_decl.get_type(), v); expr y = m_lctx.mk_local_decl(ngen(), x_decl.get_user_name(), x_decl.get_type(), new_v); optional e_y_opt = replace_fvar_with(m_st, m_lctx, e, x, y); if (!e_y_opt) return none_expr(); /* Failed to produce type correct `e[y]` */ expr e_y = *e_y_opt; m_fvars.push_back(y); return some_expr(visit(e_y, false)); } /* Given `e[x]` ``` let jp := fun z, let .... in e' ``` ==> ``` let jp' := fun z, let ... y := e' in e[y] ``` If `e'` is a `cases_on` application, we use `float_cases_on_core`. That is, ``` let jp := fun z, let ... in cases_on m (fun y_1, let ... in e_1) ... (fun y_n, let ... in e_n) ``` ==> ``` let jp := fun z, let ... in cases_on m (fun y_1, let ... y := e_1 in e[y]) ... (fun y_n, let ... y := e_n in e[y]) ``` Remark: this method return `none` if the new join point cannot be created due to type errors. */ optional mk_new_join_point(expr const & x, expr const & e, expr const & jp, buffer & new_jps, expr_map & new_jp_cache) { auto it = new_jp_cache.find(jp); if (it != new_jp_cache.end()) return some_expr(it->second); local_decl jp_decl = m_lctx.get_local_decl(jp); lean_assert(is_join_point_name(jp_decl.get_user_name())); expr jp_val = *jp_decl.get_value(); buffer zs; unsigned saved_fvars_size = m_fvars.size(); jp_val = visit(get_lambda_body(jp_val, zs), false); expr new_jp_val; expr e_y; if (is_join_point_app(jp_val)) { buffer jp2_args; expr const & jp2 = get_app_args(jp_val, jp2_args); optional new_jp2_opt = mk_new_join_point(x, e, jp2, new_jps, new_jp_cache); if (!new_jp2_opt) return none_expr(); e_y = mk_app(*new_jp2_opt, jp2_args); } else if (is_cases_on_app(env(), jp_val)) { optional e_y_opt = float_cases_on_core(x, e, jp_val, new_jps, new_jp_cache); if (!e_y_opt) return none_expr(); e_y = *e_y_opt; } else { optional e_y_opt = apply_at(x, e, jp_val); if (!e_y_opt) return none_expr(); e_y = *e_y_opt; } expr e_y_type = infer_type(e_y); expr jp_val_type = infer_type(jp_val); new_jp_val = mk_cast(e_y_type, jp_val_type, e_y); new_jp_val = mk_let(saved_fvars_size, new_jp_val); new_jp_val = m_lctx.mk_lambda(zs, new_jp_val); mark_simplified(new_jp_val); expr new_jp_var = m_lctx.mk_local_decl(ngen(), next_jp_name(), jp_decl.get_type(), new_jp_val); new_jps.push_back(new_jp_var); new_jp_cache.insert(mk_pair(jp, new_jp_var)); return some_expr(new_jp_var); } /* Given `e[x]` ``` cases_on m (fun zs, let ... in e_1) ... (fun zs, let ... in e_n) ``` ==> ``` cases_on m (fun zs, let ... y := e_1 in e[y]) ... (fun y_n, let ... y := e_n in e[y]) ``` */ optional float_cases_on_core(expr const & x, expr const & e, expr const & c, buffer & new_jps, expr_map & new_jp_cache) { lean_assert(is_cases_on_app(env(), c)); local_decl x_decl = m_lctx.get_local_decl(x); expr result_type = whnf_infer_type(e); buffer c_args; expr c_fn = get_app_args(c, c_args); inductive_val I_val = env().get(const_name(c_fn).get_prefix()).to_inductive_val(); unsigned motive_idx = I_val.get_nparams(); unsigned first_index = motive_idx + 1; unsigned nindices = I_val.get_nindices(); unsigned major_idx = first_index + nindices; unsigned first_minor_idx = major_idx + 1; unsigned nminors = I_val.get_ncnstrs(); /* Update motive */ { flet save_lctx(m_lctx, m_lctx); buffer zs; expr motive = c_args[motive_idx]; expr motive_type = whnf_infer_type(motive); for (unsigned i = 0; i < nindices + 1; i++) { lean_assert(is_pi(motive_type)); expr z = m_lctx.mk_local_decl(ngen(), binding_name(motive_type), binding_domain(motive_type), binding_info(motive_type)); zs.push_back(z); motive_type = whnf(instantiate(binding_body(motive_type), z)); } level result_lvl = sort_level(tc().ensure_type(result_type)); expr new_motive = m_lctx.mk_lambda(zs, result_type); c_args[motive_idx] = new_motive; /* We need to update the resultant universe. */ levels new_cases_lvls = levels(result_lvl, tail(const_levels(c_fn))); c_fn = update_constant(c_fn, new_cases_lvls); } /* Update minor premises */ for (unsigned i = 0; i < nminors; i++) { unsigned minor_idx = first_minor_idx + i; expr minor = c_args[minor_idx]; buffer zs; unsigned saved_fvars_size = m_fvars.size(); expr minor_val = visit(get_lambda_body(minor, zs), false); expr new_minor; if (is_join_point_app(minor_val)) { buffer jp_args; expr const & jp = get_app_args(minor_val, jp_args); optional new_jp_opt = mk_new_join_point(x, e, jp, new_jps, new_jp_cache); if (!new_jp_opt) return none_expr(); expr new_jp = *new_jp_opt; new_minor = mk_app(new_jp, jp_args); } else { optional e_y_opt = apply_at(x, e, minor_val); if (!e_y_opt) return none_expr(); expr e_y = *e_y_opt; expr e_y_type = infer_type(e_y); new_minor = mk_cast(e_y_type, result_type, e_y); } new_minor = mk_let(saved_fvars_size, new_minor); new_minor = mk_minor_lambda(zs, new_minor); c_args[minor_idx] = new_minor; } lean_trace(name({"compiler", "simp"}), tout() << "float_cases_on [" << get_lcnf_size(env(), e) << "]\n" << c << "\n----\n" << e << "\n=====>\n" << mk_app(c_fn, c_args) << "\n";); return some_expr(mk_app(c_fn, c_args)); } /* Float cases transformation (see: `float_cases_on_core`). This version may create join points if `e` is big, or "good" join-points could not be created. */ optional float_cases_on(expr const & x, expr const & e, expr const & c, buffer & new_jps) { expr_map new_jp_cache; unsigned saved_fvars_size = m_fvars.size(); local_ctx saved_lctx = m_lctx; if (optional new_e = mk_join_point_float_cases_on(x, e, c, new_jps)) { if (optional r = float_cases_on_core(x, *new_e, c, new_jps, new_jp_cache)) { return r; } } m_fvars.shrink(saved_fvars_size); m_lctx = saved_lctx; return none_expr(); } /* Given the buffer `entries`: `[(x_1, w_1), ..., (x_n, w_n)]`, and `e`. Create the let-expression ``` let x_n := w_n ... x_1 := w_1 in e ``` The values `w_i` are the "simplified values" for the let-declaration `x_i`. */ expr mk_let(buffer> const & entries, expr e) { buffer fvars; buffer user_names; buffer types; buffer vals; unsigned i = entries.size(); while (i > 0) { --i; expr const & fvar = entries[i].first; fvars.push_back(fvar); vals.push_back(entries[i].second); local_decl fvar_decl = m_lctx.get_local_decl(fvar); user_names.push_back(fvar_decl.get_user_name()); types.push_back(fvar_decl.get_type()); } e = abstract(e, fvars.size(), fvars.data()); i = fvars.size(); while (i > 0) { --i; expr new_value = abstract(vals[i], i, fvars.data()); expr new_type = abstract(types[i], i, fvars.data()); e = ::lean::mk_let(user_names[i], new_type, new_value, e); } return e; } /* Return true iff `e` contains a free variable in `s` */ bool depends_on(expr const & e, name_set const & s) { if (!has_fvar(e)) return false; bool found = false; for_each(e, [&](expr const & e, unsigned) { if (!has_fvar(e)) return false; if (found) return false; if (is_fvar(e) && s.find(fvar_name(e)) != s.end()) { found = true; } return true; }); return found; } /* Split `entries` into two groups: `entries_dep_x` and `entries_ndep_x`. The first group contains the entries that depend on `x` and the second the ones that doesn't. This auxiliary method is used to float cases_on over expressions. `entries` is of the form `[(x_1, w_1), ..., (x_n, w_n)]`, where `x_i`s are let-decl free variables, and `w_i`s their new values. We use `entries` and an expression `e` to create a `let` expression: ``` let x_n := w_n ... x_1 := w_1 in e ``` */ void split_entries(buffer> const & entries, expr const & x, buffer> & entries_dep_x, buffer> & entries_ndep_x) { if (entries.empty()) return; name_set deps; deps.insert(fvar_name(x)); /* Recall that `entries` are in reverse order. That is, pos 0 is the inner most variable. */ unsigned i = entries.size(); while (i > 0) { --i; expr const & fvar = entries[i].first; expr fvar_type = m_lctx.get_type(fvar); expr fvar_new_val = entries[i].second; if (depends_on(fvar_type, deps) || depends_on(fvar_new_val, deps)) { deps.insert(fvar_name(fvar)); entries_dep_x.push_back(entries[i]); } else { entries_ndep_x.push_back(entries[i]); } } std::reverse(entries_dep_x.begin(), entries_dep_x.end()); std::reverse(entries_ndep_x.begin(), entries_ndep_x.end()); } /* Create a let-expression with body `e`, and all "used" let-declarations `m_fvars[i]` for `i in [saved_fvars_size, m_fvars.size)`. BTW, we also visit the lambda expressions in used let-declarations of the form `x : t := fun ...` Note that, we don't visit them when we have visit let-expressions. */ expr mk_let(unsigned saved_fvars_size, expr e) { if (saved_fvars_size == m_fvars.size()) return e; /* `entries` contains pairs (let-decl fvar, new value) for building the resultant let-declaration. We simplify the value of some let-declarations in this method, but we don't want to create a new temporary declaration just for this. */ buffer> entries; name_set e_fvars; /* Set of free variables names used in `e` */ name_set entries_fvars; /* Set of free variable names used in `entries` */ collect_used(e, e_fvars); bool e_is_cases = is_cases_on_app(env(), e); while (m_fvars.size() > saved_fvars_size) { expr fvar = m_fvars.back(); m_fvars.pop_back(); bool used_in_e = (e_fvars.find(fvar_name(fvar)) != e_fvars.end()); bool used_in_entries = (entries_fvars.find(fvar_name(fvar)) != entries_fvars.end()); if (!used_in_e && !used_in_entries) { /* Skip unused variables */ continue; } local_decl decl = m_lctx.get_local_decl(fvar); expr type = decl.get_type(); expr val = *decl.get_value(); bool modified_val = false; if (is_lambda(val)) { /* We don't simplify lambdas when we visit `let`-expressions. */ DEBUG_CODE(unsigned saved_fvars_size = m_fvars.size();); val = visit_lambda(val); modified_val = true; lean_assert(m_fvars.size() == saved_fvars_size); } if (is_fvar(e) && entries.empty() && fvar_name(e) == fvar_name(fvar)) { /* `let x := v in x` ==> `v` */ e = val; collect_used(val, e_fvars); e_is_cases = is_cases_on_app(env(), e); continue; } if (is_cases_on_app(env(), val)) { /* Float cases transformation. */ if (m_cfg.m_float_cases) { /* We first create a let-declaration with all entries that depends on the current `fvar` which is a cases_on application. */ buffer> entries_dep_curr; buffer> entries_ndep_curr; split_entries(entries, fvar, entries_dep_curr, entries_ndep_curr); expr new_e = mk_let(entries_dep_curr, e); buffer new_jps; if (optional new_e_opt = float_cases_on(fvar, new_e, val, new_jps)) { e = *new_e_opt; /* Reset `e_fvars` and `entries_fvars`, we need to reconstruct them. */ e_fvars.clear(); entries_fvars.clear(); collect_used(e, e_fvars); /* Join points may have been generated, we move them to entries. */ entries.clear(); while (!new_jps.empty()) { expr jp_fvar = new_jps.back(); new_jps.pop_back(); local_decl jp_decl = m_lctx.get_local_decl(jp_fvar); expr jp_type = jp_decl.get_type(); expr jp_val = *jp_decl.get_value(); collect_used(jp_type, entries_fvars); collect_used(jp_val, entries_fvars); entries.emplace_back(jp_fvar, jp_val); } /* Copy `entries_ndep_curr` to `entries` */ for (unsigned i = 0; i < entries_ndep_curr.size(); i++) { pair const & ndep_entry = entries_ndep_curr[i]; entries.push_back(ndep_entry); collect_used(ndep_entry.second, entries_fvars); } continue; } } val = visit_cases_default(val); modified_val = true; } if (e_is_cases && used_in_e) { optional minor_idx = used_in_one_minor(e, fvar); if (minor_idx && !used_in_entries) { /* If fvar is only used in only one minor declaration, and is *not* used in any expression at entries */ if (modified_val) { /* We need to create a new free variable since the new simplified value `val` */ expr new_fvar = m_lctx.mk_local_decl(ngen(), decl.get_user_name(), type, val); e = replace_fvar(e, fvar, new_fvar); fvar = new_fvar; } collect_used(type, e_fvars); collect_used(val, e_fvars); e = move_let_to_minor(e, *minor_idx, fvar); continue; } } collect_used(type, entries_fvars); collect_used(val, entries_fvars); entries.emplace_back(fvar, val); } return mk_let(entries, e); } expr visit_let(expr e) { buffer let_fvars; while (is_let(e)) { expr new_type = instantiate_rev(let_type(e), let_fvars.size(), let_fvars.data()); expr new_val = visit(instantiate_rev(let_value(e), let_fvars.size(), let_fvars.data()), true); if (is_lcnf_atom(new_val)) { let_fvars.push_back(new_val); } else { name n = let_name(e); if (is_internal_name(n)) { if (is_join_point_name(n)) n = next_jp_name(); else n = next_name(); } expr new_fvar = m_lctx.mk_local_decl(ngen(), n, new_type, new_val); let_fvars.push_back(new_fvar); m_fvars.push_back(new_fvar); } e = let_body(e); } return visit(instantiate_rev(e, let_fvars.size(), let_fvars.data()), false); } expr visit_lambda(expr e) { lean_assert(is_lambda(e)); if (already_simplified(e)) return e; flet save_lctx(m_lctx, m_lctx); unsigned saved_fvars_size = m_fvars.size(); buffer binding_fvars; while (is_lambda(e)) { /* Types are ignored in compilation steps. So, we do not invoke visit for d. */ expr new_d = instantiate_rev(binding_domain(e), binding_fvars.size(), binding_fvars.data()); expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(e), new_d, binding_info(e)); binding_fvars.push_back(new_fvar); e = binding_body(e); } expr new_body = visit(instantiate_rev(e, binding_fvars.size(), binding_fvars.data()), false); new_body = mk_let(saved_fvars_size, new_body); expr r = m_lctx.mk_lambda(binding_fvars, new_body); mark_simplified(r); return r; } bool should_inline_instance(name const & n) const { if (is_instance(env(), n)) return !has_noinline_attribute(env(), n); else return false; } static unsigned get_num_nested_lambdas(expr e) { unsigned r = 0; while (is_lambda(e)) { r++; e = binding_body(e); } return r; } expr beta_reduce(expr fn, unsigned nargs, expr const * args, bool is_let_val) { unsigned i = 0; while (is_lambda(fn) && i < nargs) { i++; fn = binding_body(fn); } expr r = instantiate_rev(fn, i, args); if (is_lambda(r)) { lean_assert(i == nargs); return visit(r, is_let_val); } else { r = visit(r, false); if (!is_lcnf_atom(r)) r = mk_let_decl(r); return visit(mk_app(r, nargs - i, args + i), is_let_val); } } /* Remark: if `fn` is not a lambda expression, then this function will simply create the application `fn args_of(e)` */ expr beta_reduce(expr fn, expr const & e, bool is_let_val) { buffer args; get_app_args(e, args); return beta_reduce(fn, args.size(), args.data(), is_let_val); } optional try_inline_instance(expr const & fn, expr const & e) { lean_assert(is_constant(fn)); optional info = env().find(mk_cstage1_name(const_name(fn))); if (!info || !info->is_definition()) return none_expr(); if (get_app_num_args(e) < get_num_nested_lambdas(info->get_value())) return none_expr(); local_ctx saved_lctx = m_lctx; unsigned saved_fvars_size = m_fvars.size(); expr new_fn = instantiate_value_lparams(*info, const_levels(fn)); expr r = find(beta_reduce(new_fn, e, false)); if (!is_constructor_app(env(), r)) { m_lctx = saved_lctx; m_fvars.resize(saved_fvars_size); return none_expr(); } return some_expr(r); } expr proj_constructor(expr const & k_app, unsigned proj_idx) { lean_assert(is_constructor_app(env(), k_app)); buffer args; expr const & k = get_app_args(k_app, args); constructor_val k_val = env().get(const_name(k)).to_constructor_val(); lean_assert(k_val.get_nparams() + proj_idx < args.size()); return args[k_val.get_nparams() + proj_idx]; } expr visit_proj(expr const & e, bool is_let_val) { expr s = find(proj_expr(e)); if (is_constructor_app(env(), s)) return proj_constructor(s, proj_idx(e).get_small_value()); expr const & s_fn = get_app_fn(s); if (is_constant(s_fn) && should_inline_instance(const_name(s_fn))) { if (optional k_app = try_inline_instance(s_fn, s)) return visit(proj_constructor(*k_app, proj_idx(e).get_small_value()), is_let_val); } return e; } expr reduce_cases_cnstr(buffer const & args, inductive_val const & I_val, expr const & major, bool is_let_val) { lean_assert(is_constructor_app(env(), major)); unsigned nparams = I_val.get_nparams(); buffer k_args; expr const & k = get_app_args(major, k_args); lean_assert(is_constant(k)); lean_assert(nparams <= k_args.size()); unsigned first_minor_idx = nparams + 1 /* typeformer/motive */ + I_val.get_nindices() + 1 /* major */; constructor_val k_val = env().get(const_name(k)).to_constructor_val(); expr const & minor = args[first_minor_idx + k_val.get_cidx()]; return beta_reduce(minor, k_args.size() - nparams, k_args.data() + nparams, is_let_val); } /* Just simplify minor premises. */ expr visit_cases_default(expr const & e) { if (already_simplified(e)) return e; lean_assert(is_cases_on_app(env(), e)); buffer args; expr const & c = get_app_args(e, args); /* simplify minor premises */ unsigned minor_idx; unsigned minors_end; std::tie(minor_idx, minors_end) = get_cases_on_minors_range(const_name(c)); for (; minor_idx < minors_end; minor_idx++) { expr minor = args[minor_idx]; unsigned saved_fvars_size = m_fvars.size(); flet save_lctx(m_lctx, m_lctx); buffer zs; minor = get_lambda_body(minor, zs); expr new_minor = visit(minor, false); new_minor = mk_let(saved_fvars_size, new_minor); new_minor = mk_minor_lambda(zs, new_minor); args[minor_idx] = new_minor; } expr r = mk_app(c, args); mark_simplified(r); return r; } expr mk_cast(type_checker & tc, expr const & A, expr const & B, expr t) { if (tc.is_def_eq(A, B)) { return t; } else if (is_lc_proof_app(t)) { return mk_app(mk_constant(get_lc_proof_name()), B); } else { /* lc_cast.{u_1 u_2} : Π {α : Sort u_2} {β : Sort u_1}, α → β */ level u_2 = sort_level(tc.ensure_type(A)); level u_1 = sort_level(tc.ensure_type(B)); if (!is_lcnf_atom(t)) t = mk_let_decl(t); return mk_app(mk_constant(get_lc_cast_name(), {u_1, u_2}), A, B, t); } } expr mk_cast(expr const & A, expr const & B, expr const & t) { type_checker tc(m_st, m_lctx); return mk_cast(tc, A, B, t); } expr visit_cases(expr const & e, bool is_let_val) { buffer args; expr const & c = get_app_args(e, args); lean_assert(is_constant(c)); inductive_val I_val = env().get(const_name(c).get_prefix()).to_inductive_val(); unsigned major_idx = I_val.get_nparams() + 1 /* typeformer/motive */ + I_val.get_nindices(); lean_assert(major_idx < args.size()); expr const & major = find(args[major_idx]); if (is_constructor_app(env(), major)) { return reduce_cases_cnstr(args, I_val, major, is_let_val); } else if (!is_let_val) { return visit_cases_default(e); } else { return e; } } expr reduce_lc_cast(expr const & e) { buffer args; expr const & cast_fn1 = get_app_args(e, args); lean_assert(args.size() == 3); if (type_checker(m_st, m_lctx).is_def_eq(args[0], args[1])) { /* (lc_cast A A t) ==> t */ return args[2]; } expr major = find(args[2]); if (is_lc_cast_app(major)) { /* Cast transitivity: (lc_cast B C (lc_cast A B t)) ==> (lc_cast A C t) lc_cast.{u_1 u_2} : Π {α : Sort u_2} {β : Sort u_1}, α → β */ buffer nested_args; expr const & cast_fn2 = get_app_args(major, nested_args); expr const & C = args[1]; expr const & A = nested_args[0]; level u1 = head(const_levels(cast_fn1)); level u2 = head(tail(const_levels(cast_fn2))); return reduce_lc_cast(mk_app(mk_constant(get_lc_cast_name(), {u1, u2}), A, C, nested_args[2])); } return e; } expr merge_app_app(expr const & fn, expr const & e, bool is_let_val) { lean_assert(is_app(fn)); lean_assert(is_eqp(find(get_app_fn(e)), fn)); if (!is_lc_cast_app(fn) && !is_cases_on_app(env(), fn)) { buffer args; get_app_args(e, args); return visit_app(mk_app(fn, args), is_let_val); } else { return e; } } expr reduce_cast_app_app(expr const & fn, expr const & e, bool is_let_val) { lean_assert(is_lc_cast_app(fn)); lean_assert(is_eqp(find(get_app_fn(e)), fn)); /* f := lc_cast g e := f a_1 ... a_n ==> b_1 := lc_cast a_1 ... b_n := lc_cast a_n e := g b_1 ... b_n */ expr const & g = app_arg(fn); buffer args; get_app_args(e, args); expr g_type = whnf_infer_type(g); for (expr & arg : args) { lean_assert(is_pi(g_type)); expr expected_type = binding_domain(g_type); expr arg_type = infer_type(arg); expr new_arg = mk_cast(arg_type, expected_type, arg); arg = new_arg; g_type = whnf(instantiate(binding_body(g_type), new_arg)); } expr r = visit_app(mk_app(g, args), is_let_val); type_checker tc(m_st, m_lctx); expr r_type = tc.infer(r); expr e_type = tc.infer(e); return mk_cast(tc, r_type, e_type, r); } expr try_inline(expr const & fn, expr const & e, bool is_let_val) { lean_assert(is_constant(fn)); lean_assert(is_eqp(find(get_app_fn(e)), fn)); if (has_noinline_attribute(env(), const_name(fn))) return e; optional info = env().find(mk_cstage1_name(const_name(fn))); if (!info || !info->is_definition()) return e; if (get_app_num_args(e) < get_num_nested_lambdas(info->get_value())) return e; /* TODO(Leo): check size and whether function is boring or not. */ if (!has_inline_attribute(env(), const_name(fn))) return e; expr new_fn = instantiate_value_lparams(*info, const_levels(fn)); return beta_reduce(new_fn, e, is_let_val); } expr visit_app(expr const & e, bool is_let_val) { if (is_cases_on_app(env(), e)) { return visit_cases(e, is_let_val); } else if (is_lc_cast_app(e)) { return reduce_lc_cast(e); } expr fn = find(get_app_fn(e)); if (is_lambda(fn)) { return beta_reduce(fn, e, is_let_val); } else if (is_cases_on_app(env(), fn) && m_cfg.m_float_cases_app) { lean_assert(is_fvar(get_app_fn(e))); buffer new_jps; /* float cases_on from application */ expr_map new_jp_cache; if (optional new_e = float_cases_on_core(get_app_fn(e), e, fn, new_jps, new_jp_cache)) { mark_simplified(*new_e); m_fvars.append(new_jps); return *new_e; } else { mark_simplified(e); return e; } } else if (is_lc_cast_app(fn)) { return reduce_cast_app_app(fn, e, is_let_val); } else if (is_lc_unreachable_app(fn)) { expr type = infer_type(e); return mk_lc_unreachable(m_st, m_lctx, type); } else if (is_app(fn)) { return merge_app_app(fn, e, is_let_val); } else if (is_constant(fn)) { return try_inline(fn, e, is_let_val); } return e; } expr visit(expr const & e, bool is_let_val) { switch (e.kind()) { case expr_kind::Lambda: return is_let_val ? e : visit_lambda(e); case expr_kind::Let: return visit_let(e); case expr_kind::Proj: return visit_proj(e, is_let_val); case expr_kind::App: return visit_app(e, is_let_val); default: return e; } } public: csimp_fn(environment const & env, local_ctx const & lctx, csimp_cfg const & cfg): m_st(env), m_lctx(lctx), m_cfg(cfg), m_x("_x"), m_j("j") {} expr operator()(expr const & e) { expr r = visit(e, false); return m_lctx.mk_lambda(m_fvars, r); } }; expr csimp(environment const & env, local_ctx const & lctx, expr const & e, csimp_cfg const & cfg) { return csimp_fn(env, lctx, cfg)(e); } }