diff --git a/src/library/compiler/csimp.cpp b/src/library/compiler/csimp.cpp index 88cc315e26..12015afa68 100644 --- a/src/library/compiler/csimp.cpp +++ b/src/library/compiler/csimp.cpp @@ -30,12 +30,25 @@ class csimp_fn { name_generator & ngen() { return m_st.ngen(); } + /* Very simple predicate used to decide whether we should inline joint-points or not. + TODO(Leo): improve */ + bool is_small(expr const & e) const { + if (is_app(e) && !is_cases_on_app(env(), e)) + return true; + if (is_lambda(e)) + return is_small(binding_body(e)); + return false; + } + expr find(expr const & e, bool skip_mdata = true) const { if (is_fvar(e)) { if (optional decl = m_lctx.find_local_decl(e)) { - if (optional v = decl->get_value()) + if (optional v = decl->get_value()) { if (!is_join_point_name(decl->get_user_name())) return find(*v, skip_mdata); + else if (is_small(*v)) + return find(*v, skip_mdata); + } } } else if (is_mdata(e) && skip_mdata) { return find(mdata_expr(e), true); @@ -43,6 +56,8 @@ class csimp_fn { return e; } + type_checker tc() { return type_checker(m_st, m_lctx); } + expr infer_type(expr const & e) { return type_checker(m_st, m_lctx).infer(e); } expr whnf(expr const & e) { return type_checker(m_st, m_lctx).whnf(e); } @@ -166,7 +181,7 @@ class csimp_fn { The main `mk_let` does not modify the user-name/type, but we store them at `entries` just for convenience and to avoid additional accesses to `m_lctx`. - Remark: the buffers `fvars` and `entries` are reverted by this method. */ + Remark: `fvars` and `entries` are not modified. */ expr mk_let(buffer & fvars, buffer> & entries, expr e) { lean_assert(fvars.size() == entries.size()); std::reverse(fvars.begin(), fvars.end()); @@ -179,9 +194,94 @@ class csimp_fn { expr new_type = abstract(std::get<1>(entries[i]), i, fvars.data()); e = ::lean::mk_let(std::get<0>(entries[i]), new_type, new_value, e); } + /* Restore `fvars` and `entries` */ + std::reverse(fvars.begin(), fvars.end()); + std::reverse(entries.begin(), entries.end()); return e; } + /* Float cases transformation + ``` + let x := cases_on m (fun y_1, let ... in e_1) + ... + (fun y_n, let ... in e_n)) + in e + ``` + ==> + ``` + cases_on m (fun y_1, let ... x := e_1 in e) + ... + (fun y_n, let ... x := e_n in e) + ``` */ + optional float_cases_on(expr const & fvar, expr const & c, expr const & e) { + lean_assert(is_cases_on_app(env(), c)); + local_decl fvar_decl = m_lctx.get_local_decl(fvar); + expr result_type = whnf_infer_type(e); + buffer c_args; + expr c_fn = get_app_args(c, c_args); + inductive_val I_val = env().get(const_name(c_fn).get_prefix()).to_inductive_val(); + unsigned motive_idx = I_val.get_nparams(); + unsigned first_index = motive_idx + 1; + unsigned nindices = I_val.get_nindices(); + unsigned major_idx = first_index + nindices; + unsigned first_minor_idx = major_idx + 1; + unsigned nminors = I_val.get_ncnstrs(); + /* Update motive */ + { + flet save_lctx(m_lctx, m_lctx); + buffer fvars; + expr motive = c_args[motive_idx]; + expr motive_type = whnf_infer_type(motive); + for (unsigned i = 0; i < nindices + 1; i++) { + lean_assert(is_pi(motive_type)); + expr fvar = m_lctx.mk_local_decl(ngen(), binding_name(motive_type), binding_domain(motive_type), binding_info(motive_type)); + fvars.push_back(fvar); + motive_type = whnf(instantiate(binding_body(motive_type), fvar)); + } + level result_lvl = sort_level(tc().ensure_type(result_type)); + expr new_motive = m_lctx.mk_lambda(fvars, result_type); + c_args[motive_idx] = new_motive; + /* We need to update the resultant universe. */ + levels new_cases_lvls = levels(result_lvl, tail(const_levels(c_fn))); + c_fn = update_constant(c_fn, new_cases_lvls); + } + flet save_lctx(m_lctx, m_lctx); + unsigned saved_fvars_size = m_fvars.size(); + /* Update minor premises */ + for (unsigned i = 0; i < nminors; i++) { + unsigned minor_idx = first_minor_idx + i; + expr minor = c_args[minor_idx]; + buffer minor_fvars; + unsigned old_fvars_size = m_fvars.size(); + while (is_lambda(minor)) { + expr new_d = instantiate_rev(binding_domain(minor), minor_fvars.size(), minor_fvars.data()); + expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(minor), new_d, binding_info(minor)); + minor_fvars.push_back(new_fvar); + minor = binding_body(minor); + } + expr minor_val = visit(instantiate_rev(minor, minor_fvars.size(), minor_fvars.data()), true); + /* TODO(Leo): We need to preserve join points. */ + expr minor_val_type = infer_type(minor_val); + minor_val = mk_cast(minor_val_type, fvar_decl.get_type(), minor_val); + expr new_fvar = m_lctx.mk_local_decl(ngen(), fvar_decl.get_user_name(), fvar_decl.get_type(), minor_val); + m_fvars.push_back(new_fvar); + expr new_minor; + if (optional new_minor_opt = replace_fvar_with(m_st, m_lctx, e, fvar, new_fvar)) { + new_minor = *new_minor_opt; + } else { + m_fvars.resize(saved_fvars_size); + return none_expr(); /* Failed to produce type correct `new_minor` */ + } + new_minor = visit(new_minor, false); + expr new_minor_type = infer_type(new_minor); + new_minor = mk_cast(new_minor_type, result_type, new_minor); + new_minor = mk_let(old_fvars_size, new_minor); + new_minor = m_lctx.mk_lambda(minor_fvars, new_minor); + c_args[minor_idx] = new_minor; + } + return some_expr(mk_app(c_fn, c_args)); + } + /* Create a let-expression with body `e`, and all "used" let-declarations `m_fvars[i]` for `i in [old_fvars_size, m_fvars.size)`. @@ -229,9 +329,25 @@ class csimp_fn { continue; } - if (is_cases_on_app(env(), val)) { - // TODO(Leo); - } else if (e_is_cases && used_in_e) { + /* TODO(Leo): enable "float cases" as soon as it has descent performance */ + if (false && is_cases_on_app(env(), val)) { + /* Float cases transformation. */ + DEBUG_CODE(unsigned saved_fvars_size = m_fvars.size();); + expr new_e = mk_let(fvars, entries, e); + /* TODO(Leo): create new joint point if `e` is too big */ + if (optional new_e_opt = float_cases_on(fvar, val, new_e)) { + e = *new_e_opt; + fvars.clear(); entries.clear(); + e_fvars.clear(); entries_fvars.clear(); + collect_used(e, e_fvars); + lean_assert(m_fvars.size() == saved_fvars_size); + continue; + } else { + lean_assert(m_fvars.size() == saved_fvars_size); + } + } + + if (e_is_cases && used_in_e) { optional minor_idx = used_in_one_minor(e, fvar); if (minor_idx && !used_in_entries) { /* If fvar is only used in only one minor declaration, @@ -249,6 +365,7 @@ class csimp_fn { continue; } } + fvars.push_back(fvar); collect_used(type, entries_fvars); collect_used(val, entries_fvars); @@ -501,78 +618,6 @@ class csimp_fn { } } - // expr distrib_app_cases(expr const & fn, expr const & e) { - // lean_assert(is_cases_on_app(env(), fn)); - // lean_assert(is_eqp(find(get_app_fn(e)), fn)); - // expr result_type = infer_type(e); - // buffer args; - // get_app_args(e, args); - // buffer cases_args; - // expr cases = get_app_args(fn, cases_args); - // lean_assert(is_constant(cases)); - // inductive_val I_val = env().get(const_name(cases).get_prefix()).to_inductive_val(); - // unsigned motive_idx = I_val.get_nparams(); - // unsigned first_index = motive_idx + 1; - // unsigned nindices = I_val.get_nindices(); - // unsigned major_idx = first_index + nindices; - // unsigned first_minor_idx = major_idx + 1; - // unsigned nminors = length(I_val.get_cnstrs()); - // /* Infer argument types */ - // buffer arg_types; - // { - // type_checker tc(m_st, m_lctx); - // for (expr const & arg : args) { - // arg_types.push_back(tc.infer(arg)); - // } - // } - // /* Update motive */ - // { - // flet save_lctx(m_lctx, m_lctx); - // buffer fvars; - // expr motive = cases_args[motive_idx]; - // expr motive_type = whnf_infer_type(motive); - // for (unsigned i = 0; i < nindices + 1; i++) { - // lean_assert(is_pi(motive_type)); - // expr fvar = m_lctx.mk_local_decl(ngen(), binding_name(motive_type), binding_domain(motive_type), binding_info(motive_type)); - // fvars.push_back(fvar); - // motive_type = whnf(instantiate(binding_body(motive_type), fvar)); - // } - // level result_lvl = sort_level(type_checker(env(), m_lctx).ensure_type(result_type)); - // expr new_motive = m_lctx.mk_lambda(fvars, result_type); - // cases_args[motive_idx] = new_motive; - // /* We need to update the resultant universe. */ - // levels new_cases_lvls = levels(result_lvl, tail(const_levels(cases))); - // cases = update_constant(cases, new_cases_lvls); - // } - // /* Update minor premises */ - // for (unsigned i = 0; i < nminors; i++) { - // unsigned minor_idx = first_minor_idx + i; - // expr minor = cases_args[minor_idx]; - // flet save_lctx(m_lctx, m_lctx); - // buffer minor_fvars; - // unsigned old_fvars_size = m_fvars.size(); - // while (is_lambda(minor)) { - // expr new_d = instantiate_rev(binding_domain(minor), minor_fvars.size(), minor_fvars.data()); - // expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(minor), new_d, binding_info(minor)); - // minor_fvars.push_back(new_fvar); - // minor = binding_body(minor); - // } - // expr new_minor = visit(instantiate_rev(minor, minor_fvars.size(), minor_fvars.data())); - // for (unsigned i = 0; i < args.size(); i++) { - // expr new_minor_type = whnf_infer_type(new_minor); - // lean_assert(is_pi(new_minor_type)); - // new_minor = mk_app(new_minor, mk_cast(arg_types[i], binding_domain(new_minor_type), args[i])); - // } - // new_minor = visit_let_value(new_minor); - // type_checker tc(m_st, m_lctx); - // new_minor = mk_cast(tc, tc.infer(new_minor), result_type, new_minor); - // new_minor = mk_let(old_fvars_size, new_minor); - // new_minor = m_lctx.mk_lambda(minor_fvars, new_minor); - // cases_args[minor_idx] = new_minor; - // } - // return mk_let_decl(mk_app(cases, cases_args)); - // } - expr reduce_lc_cast(expr const & e) { buffer args; expr const & cast_fn1 = get_app_args(e, args); diff --git a/tests/lean/run/new_compiler.lean b/tests/lean/run/new_compiler.lean index 54af257f63..a0e9b70eb9 100644 --- a/tests/lean/run/new_compiler.lean +++ b/tests/lean/run/new_compiler.lean @@ -2,7 +2,7 @@ import init.lean.parser.parsec import init.control.coroutine universes u v w r s -set_option trace.compiler.lcnf true +set_option trace.compiler.stage1 true -- set_option pp.implicit true set_option pp.binder_types false set_option pp.proofs true