diff --git a/src/library/compiler/csimp.cpp b/src/library/compiler/csimp.cpp index a79f47dd43..38de99d476 100644 --- a/src/library/compiler/csimp.cpp +++ b/src/library/compiler/csimp.cpp @@ -61,6 +61,14 @@ class csimp_fn { return m_simplified.find(e) != m_simplified.end(); } + bool is_join_point_app(expr const & e) const { + if (!is_app(e)) return false; + expr const & fn = get_app_fn(e); + return + is_fvar(fn) && + is_join_point_name(m_lctx.get_local_decl(fn).get_user_name()); + } + /* Very simple predicate used to decide whether we should inline joint-points or not. TODO(Leo): improve */ bool is_small(expr const & e) const { @@ -191,14 +199,14 @@ class csimp_fn { return m_lctx.mk_lambda(xs, e); } - expr get_minor_body(expr minor, buffer & xs) { - while (is_lambda(minor)) { - expr d = instantiate_rev(binding_domain(minor), xs.size(), xs.data()); - expr x = m_lctx.mk_local_decl(ngen(), binding_name(minor), d, binding_info(minor)); + expr get_lambda_body(expr e, buffer & xs) { + while (is_lambda(e)) { + expr d = instantiate_rev(binding_domain(e), xs.size(), xs.data()); + expr x = m_lctx.mk_local_decl(ngen(), binding_name(e), d, binding_info(e)); xs.push_back(x); - minor = binding_body(minor); + e = binding_body(e); } - return instantiate_rev(minor, xs.size(), xs.data()); + return instantiate_rev(e, xs.size(), xs.data()); } /* Move let-decl `fvar` to the minor premise at position `minor_idx` of cases_on-application `c`. */ @@ -209,7 +217,7 @@ class csimp_fn { expr minor = args[minor_idx]; flet save_lctx(m_lctx, m_lctx); buffer xs; - minor = get_minor_body(minor, xs); + minor = get_lambda_body(minor, xs); if (minor == fvar) { /* `let x := v in x` ==> `v` */ minor = *m_lctx.get_local_decl(fvar).get_value(); @@ -233,7 +241,7 @@ class csimp_fn { The term `e` is "copied" in each branch of the the `cases_on` expression `c`. This method creates one (or more) join-point(s) for `e` (if needed). This method main fail because of dependent types. */ - optional mk_join_point_float_cases_on(expr const & fvar, expr const & c, expr const & e) { + optional mk_join_point_float_cases_on(expr const & fvar, expr const & e, expr const & c, buffer & new_jps) { lean_assert(is_cases_on_app(env(), c)); expr const & c_fn = get_app_fn(c); inductive_val I_val = env().get(const_name(c_fn).get_prefix()).to_inductive_val(); @@ -254,17 +262,17 @@ class csimp_fn { for (unsigned minor_idx = begin_minors; minor_idx < end_minors; minor_idx++) { expr minor = args[minor_idx]; if (get_lcnf_size(env(), minor) > m_cfg.m_float_cases_jp_branch_threshold) { - buffer used_xs; /* used_xs[i] iff `minor` uses `xs[i]` */ + buffer used_zs; /* used_zs[i] iff `minor` uses `zs[i]` */ bool used_fvar = false; /* true iff `minor` uses `fvar` */ bool used_unit = false; /* true if we needed to add `unit ->` to joint point */ expr jp_val; /* Create join-point value: `jp-val` */ { flet save_lctx(m_lctx, m_lctx); - buffer xs; - minor = get_minor_body(minor, xs); - mark_used_fvars(minor, xs, used_xs); - lean_assert(xs.size() == used_xs.size()); + buffer zs; + minor = get_lambda_body(minor, zs); + mark_used_fvars(minor, zs, used_zs); + lean_assert(zs.size() == used_zs.size()); used_fvar = false; jp_val = minor; buffer jp_args; @@ -282,9 +290,9 @@ class csimp_fn { return none_expr(); } } - for (unsigned i = 0; i < used_xs.size(); i++) { - if (used_xs[i]) - jp_args.push_back(xs[i]); + for (unsigned i = 0; i < used_zs.size(); i++) { + if (used_zs[i]) + jp_args.push_back(zs[i]); } if (jp_args.empty()) { jp_args.push_back(m_lctx.mk_local_decl(ngen(), "_", mk_unit())); @@ -296,24 +304,24 @@ class csimp_fn { expr jp_type = cheap_beta_reduce(infer_type(jp_val)); mark_simplified(jp_val); expr jp_var = m_lctx.mk_local_decl(ngen(), next_jp_name(), jp_type, jp_val); - m_fvars.push_back(jp_var); + new_jps.push_back(jp_var); /* Replace minor with new jp */ { flet save_lctx(m_lctx, m_lctx); - buffer xs; + buffer zs; minor = args[minor_idx]; - minor = get_minor_body(minor, xs); - lean_assert(xs.size() == used_xs.size()); + minor = get_lambda_body(minor, zs); + lean_assert(zs.size() == used_zs.size()); expr new_minor = jp_var; if (used_unit) new_minor = mk_app(new_minor, mk_unit_mk()); if (used_fvar) new_minor = mk_app(new_minor, fvar); - for (unsigned i = 0; i < used_xs.size(); i++) { - if (used_xs[i]) - new_minor = mk_app(new_minor, xs[i]); + for (unsigned i = 0; i < used_zs.size(); i++) { + if (used_zs[i]) + new_minor = mk_app(new_minor, zs[i]); } - new_minor = mk_minor_lambda(xs, new_minor); + new_minor = mk_minor_lambda(zs, new_minor); args[minor_idx] = new_minor; modified = true; } @@ -350,7 +358,7 @@ class csimp_fn { expr jp_type = cheap_beta_reduce(infer_type(jp_val)); mark_simplified(jp_val); expr jp_var = m_lctx.mk_local_decl(ngen(), next_jp_name(), jp_type, jp_val); - m_fvars.push_back(jp_var); + new_jps.push_back(jp_var); lean_trace(name({"compiler", "simp"}), tout() << "mk_join " << fvar << "\n" << c << "\n---\n" << e << "\n======>\n" << mk_app(jp_var, fvar) << "\n";); @@ -358,24 +366,104 @@ class csimp_fn { } } - /* Float cases transformation + /* Given `e[x]`, create a let-decl `y := v`, and return `e[y]` + Casts are introduced if necessary. The result is `none` if it fails to produce type correct `e[y]`. */ + optional apply_at(expr const & x, expr const & e, expr const & v) { + local_decl x_decl = m_lctx.get_local_decl(x); + expr v_type = infer_type(v); + expr new_v = mk_cast(v_type, x_decl.get_type(), v); + expr y = m_lctx.mk_local_decl(ngen(), x_decl.get_user_name(), x_decl.get_type(), new_v); + optional e_y_opt = replace_fvar_with(m_st, m_lctx, e, x, y); + if (!e_y_opt) return none_expr(); /* Failed to produce type correct `e[y]` */ + expr e_y = *e_y_opt; + m_fvars.push_back(y); + return some_expr(visit(e_y, false)); + } + + /* + Given `e[x]` ``` - let x := cases_on m + let jp := fun z, let .... in e' + ``` + ==> + ``` + let jp' := fun z, let ... y := e' in e[y] + ``` + If `e'` is a `cases_on` application, we use `float_cases_on_core`. That is, + ``` + let jp := fun z, let ... in + cases_on m (fun y_1, let ... in e_1) ... (fun y_n, let ... in e_n) - in e + ``` + ==> + ``` + let jp := fun z, let ... in + cases_on m + (fun y_1, let ... y := e_1 in e[y]) + ... + (fun y_n, let ... y := e_n in e[y]) + ``` + + Remark: this method return `none` if the new join point cannot be created + due to type errors. */ + optional mk_new_join_point(expr const & x, expr const & e, expr const & jp, buffer & new_jps, expr_map & new_jp_cache) { + auto it = new_jp_cache.find(jp); + if (it != new_jp_cache.end()) + return some_expr(it->second); + local_decl jp_decl = m_lctx.get_local_decl(jp); + lean_assert(is_join_point_name(jp_decl.get_user_name())); + expr jp_val = *jp_decl.get_value(); + buffer zs; + unsigned saved_fvars_size = m_fvars.size(); + jp_val = visit(get_lambda_body(jp_val, zs), false); + expr new_jp_val; + expr e_y; + if (is_join_point_app(jp_val)) { + buffer jp2_args; + expr const & jp2 = get_app_args(jp_val, jp2_args); + optional new_jp2_opt = mk_new_join_point(x, e, jp2, new_jps, new_jp_cache); + if (!new_jp2_opt) return none_expr(); + e_y = mk_app(*new_jp2_opt, jp2_args); + } else if (is_cases_on_app(env(), jp_val)) { + optional e_y_opt = float_cases_on_core(x, e, jp_val, new_jps, new_jp_cache); + if (!e_y_opt) return none_expr(); + e_y = *e_y_opt; + } else { + optional e_y_opt = apply_at(x, e, jp_val); + if (!e_y_opt) return none_expr(); + e_y = *e_y_opt; + } + expr e_y_type = infer_type(e_y); + expr jp_val_type = infer_type(jp_val); + new_jp_val = mk_cast(e_y_type, jp_val_type, e_y); + new_jp_val = mk_let(saved_fvars_size, new_jp_val); + new_jp_val = m_lctx.mk_lambda(zs, new_jp_val); + mark_simplified(new_jp_val); + expr new_jp_var = m_lctx.mk_local_decl(ngen(), next_jp_name(), jp_decl.get_type(), new_jp_val); + new_jps.push_back(new_jp_var); + new_jp_cache.insert(mk_pair(jp, new_jp_var)); + return some_expr(new_jp_var); + } + + /* Given `e[x]` + ``` + cases_on m + (fun zs, let ... in e_1) + ... + (fun zs, let ... in e_n) ``` ==> ``` cases_on m - (fun y_1, let ... x := e_1 in e) + (fun zs, let ... y := e_1 in e[y]) ... - (fun y_n, let ... x := e_n in e) + (fun y_n, let ... y := e_n in e[y]) ``` */ - optional float_cases_on_core(expr const & fvar, expr const & c, expr e) { + optional float_cases_on_core(expr const & x, expr const & e, expr const & c, buffer & new_jps, expr_map & new_jp_cache) { lean_assert(is_cases_on_app(env(), c)); - local_decl fvar_decl = m_lctx.get_local_decl(fvar); + local_decl x_decl = m_lctx.get_local_decl(x); expr result_type = whnf_infer_type(e); buffer c_args; expr c_fn = get_app_args(c, c_args); @@ -389,17 +477,17 @@ class csimp_fn { /* Update motive */ { flet save_lctx(m_lctx, m_lctx); - buffer fvars; + buffer zs; expr motive = c_args[motive_idx]; expr motive_type = whnf_infer_type(motive); for (unsigned i = 0; i < nindices + 1; i++) { lean_assert(is_pi(motive_type)); - expr fvar = m_lctx.mk_local_decl(ngen(), binding_name(motive_type), binding_domain(motive_type), binding_info(motive_type)); - fvars.push_back(fvar); - motive_type = whnf(instantiate(binding_body(motive_type), fvar)); + expr z = m_lctx.mk_local_decl(ngen(), binding_name(motive_type), binding_domain(motive_type), binding_info(motive_type)); + zs.push_back(z); + motive_type = whnf(instantiate(binding_body(motive_type), z)); } level result_lvl = sort_level(tc().ensure_type(result_type)); - expr new_motive = m_lctx.mk_lambda(fvars, result_type); + expr new_motive = m_lctx.mk_lambda(zs, result_type); c_args[motive_idx] = new_motive; /* We need to update the resultant universe. */ levels new_cases_lvls = levels(result_lvl, tail(const_levels(c_fn))); @@ -407,26 +495,29 @@ class csimp_fn { } /* Update minor premises */ for (unsigned i = 0; i < nminors; i++) { - unsigned minor_idx = first_minor_idx + i; - expr minor = c_args[minor_idx]; - buffer xs; - flet save_lctx(m_lctx, m_lctx); - unsigned old_fvars_size = m_fvars.size(); - expr minor_val = visit(get_minor_body(minor, xs), false); - /* TODO(Leo): We need to preserve join points. */ - expr minor_val_type = infer_type(minor_val); - minor_val = mk_cast(minor_val_type, fvar_decl.get_type(), minor_val); - expr new_fvar = m_lctx.mk_local_decl(ngen(), fvar_decl.get_user_name(), fvar_decl.get_type(), minor_val); - optional new_minor_opt = replace_fvar_with(m_st, m_lctx, e, fvar, new_fvar); - if (!new_minor_opt) return none_expr(); /* Failed to produce type correct `new_minor` */ - expr new_minor = *new_minor_opt; - m_fvars.push_back(new_fvar); - new_minor = visit(new_minor, false); - expr new_minor_type = infer_type(new_minor); - new_minor = mk_cast(new_minor_type, result_type, new_minor); - new_minor = mk_let(old_fvars_size, new_minor); - new_minor = mk_minor_lambda(xs, new_minor); - c_args[minor_idx] = new_minor; + unsigned minor_idx = first_minor_idx + i; + expr minor = c_args[minor_idx]; + buffer zs; + unsigned saved_fvars_size = m_fvars.size(); + expr minor_val = visit(get_lambda_body(minor, zs), false); + expr new_minor; + if (is_join_point_app(minor_val)) { + buffer jp_args; + expr const & jp = get_app_args(minor_val, jp_args); + optional new_jp_opt = mk_new_join_point(x, e, jp, new_jps, new_jp_cache); + if (!new_jp_opt) return none_expr(); + expr new_jp = *new_jp_opt; + new_minor = mk_app(new_jp, jp_args); + } else { + optional e_y_opt = apply_at(x, e, minor_val); + if (!e_y_opt) return none_expr(); + expr e_y = *e_y_opt; + expr e_y_type = infer_type(e_y); + new_minor = mk_cast(e_y_type, result_type, e_y); + } + new_minor = mk_let(saved_fvars_size, new_minor); + new_minor = mk_minor_lambda(zs, new_minor); + c_args[minor_idx] = new_minor; } lean_trace(name({"compiler", "simp"}), tout() << "float_cases_on [" << get_lcnf_size(env(), e) << "]\n" << c << "\n----\n" << e << "\n=====>\n" @@ -436,11 +527,12 @@ class csimp_fn { /* Float cases transformation (see: `float_cases_on_core`). This version may create join points if `e` is big, or "good" join-points could not be created. */ - optional float_cases_on(expr const & fvar, expr const & c, expr const & e) { + optional float_cases_on(expr const & x, expr const & e, expr const & c, buffer & new_jps) { + expr_map new_jp_cache; unsigned saved_fvars_size = m_fvars.size(); local_ctx saved_lctx = m_lctx; - if (optional new_e = mk_join_point_float_cases_on(fvar, c, e)) { - if (optional r = float_cases_on_core(fvar, c, *new_e)) { + if (optional new_e = mk_join_point_float_cases_on(x, e, c, new_jps)) { + if (optional r = float_cases_on_core(x, *new_e, c, new_jps, new_jp_cache)) { return r; } } @@ -591,23 +683,23 @@ class csimp_fn { if (is_cases_on_app(env(), val)) { /* Float cases transformation. */ if (m_cfg.m_float_cases) { - unsigned m_fvars_init_size = m_fvars.size(); /* We first create a let-declaration with all entries that depends on the current `fvar` which is a cases_on application. */ buffer> entries_dep_curr; buffer> entries_ndep_curr; split_entries(entries, fvar, entries_dep_curr, entries_ndep_curr); expr new_e = mk_let(entries_dep_curr, e); - if (optional new_e_opt = float_cases_on(fvar, val, new_e)) { + buffer new_jps; + if (optional new_e_opt = float_cases_on(fvar, new_e, val, new_jps)) { e = *new_e_opt; /* Reset `e_fvars` and `entries_fvars`, we need to reconstruct them. */ e_fvars.clear(); entries_fvars.clear(); collect_used(e, e_fvars); /* Join points may have been generated, we move them to entries. */ entries.clear(); - while (m_fvars.size() > m_fvars_init_size) { - expr jp_fvar = m_fvars.back(); - m_fvars.pop_back(); + while (!new_jps.empty()) { + expr jp_fvar = new_jps.back(); + new_jps.pop_back(); local_decl jp_decl = m_lctx.get_local_decl(jp_fvar); expr jp_type = jp_decl.get_type(); expr jp_val = *jp_decl.get_value(); @@ -662,8 +754,12 @@ class csimp_fn { let_fvars.push_back(new_val); } else { name n = let_name(e); - if (is_internal_name(n) && !is_join_point_name(n)) - n = next_name(); + if (is_internal_name(n)) { + if (is_join_point_name(n)) + n = next_jp_name(); + else + n = next_name(); + } expr new_fvar = m_lctx.mk_local_decl(ngen(), n, new_type, new_val); let_fvars.push_back(new_fvar); m_fvars.push_back(new_fvar); @@ -801,11 +897,11 @@ class csimp_fn { expr minor = args[minor_idx]; unsigned saved_fvars_size = m_fvars.size(); flet save_lctx(m_lctx, m_lctx); - buffer xs; - minor = get_minor_body(minor, xs); + buffer zs; + minor = get_lambda_body(minor, zs); expr new_minor = visit(minor, false); new_minor = mk_let(saved_fvars_size, new_minor); - new_minor = mk_minor_lambda(xs, new_minor); + new_minor = mk_minor_lambda(zs, new_minor); args[minor_idx] = new_minor; } expr r = mk_app(c, args); @@ -943,9 +1039,12 @@ class csimp_fn { return beta_reduce(fn, e, is_let_val); } else if (is_cases_on_app(env(), fn) && m_cfg.m_float_cases_app) { lean_assert(is_fvar(get_app_fn(e))); + buffer new_jps; /* float cases_on from application */ - if (optional new_e = float_cases_on_core(get_app_fn(e), fn, e)) { + expr_map new_jp_cache; + if (optional new_e = float_cases_on_core(get_app_fn(e), e, fn, new_jps, new_jp_cache)) { mark_simplified(*new_e); + m_fvars.append(new_jps); return *new_e; } else { mark_simplified(e);