/*
Copyright (c) 2018 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.

Author: Leonardo de Moura
*/
#include "runtime/flet.h"
#include "kernel/instantiate.h"
#include "kernel/expr_maps.h"
#include "kernel/for_each_fn.h"
#include "library/module.h"
#include "library/trace.h"
#include "library/compiler/util.h"

namespace lean {
/* Cache closed term => global constant.
   TODO(Leo): use the to be implemented new module system. */
struct extract_closed_ext : public environment_extension {
    typedef rb_expr_map<name> cache;
    cache m_cache;
};

/* Registers `extract_closed_ext` with `environment` and remembers the extension id. */
struct extract_closed_ext_reg {
    unsigned m_ext_id;
    extract_closed_ext_reg() {
        m_ext_id = environment::register_extension(std::make_shared<extract_closed_ext>());
    }
};

static extract_closed_ext_reg * g_ext = nullptr;

/* Return the closed-term cache stored in `env`. */
static extract_closed_ext const & get_extension(environment const & env) {
    return static_cast<extract_closed_ext const &>(env.get_extension(g_ext->m_ext_id));
}

/* Return `env` with its closed-term cache replaced by a copy of `ext`. */
static environment update(environment const & env, extract_closed_ext const & ext) {
    return env.update(g_ext->m_ext_id, std::make_shared<extract_closed_ext>(ext));
}

/* Support for old module manager.
   Remark: this code will be deleted in the future */
/* Module modification recording one `closed term => auxiliary constant name` cache entry.
   `perform` inserts the entry into the environment's cache; `serialize` / `deserialize`
   write and read the `(expr, name)` pair. */
struct ec_cache_modification : public modification {
    LEAN_MODIFICATION("ecc")

    expr m_expr;
    name m_name;

    ec_cache_modification(expr const & e, name const & n): m_expr(e), m_name(n) {}

    void perform(environment & env) const override {
        extract_closed_ext ext = get_extension(env);
        ext.m_cache.insert(m_expr, m_name);
        env = update(env, ext);
    }

    void serialize(serializer & s) const override {
        s << m_expr << m_name;
    }

    static std::shared_ptr<modification const> deserialize(deserializer & d) {
        expr e; name n;
        d >> e >> n;
        return std::make_shared<ec_cache_modification>(e, n);
    }
};

/* Replaces closed subterms of a compiler declaration with references to auxiliary
   declarations (named by `next_name`), reusing constants already recorded in the
   closed-term cache. */
class extract_closed_fn {
    environment        m_env;
    extract_closed_ext m_ext;
    name_generator     m_ngen;
    local_ctx          m_lctx;
    buffer<comp_decl>  m_new_decls;   // auxiliary declarations created for this declaration
    name               m_base_name;   // name of the declaration being processed
    unsigned           m_next_idx{1}; // suffix index for the next auxiliary declaration
    expr_map<bool>     m_closed;      // memo table for `is_closed`

    environment const & env() const { return m_env; }
    name_generator & ngen() { return m_ngen; }

    /* Fresh auxiliary name derived from `m_base_name` and `_closed`. */
    name next_name() {
        name r = name(m_base_name, "_closed").append_after(m_next_idx);
        m_next_idx++;
        return r;
    }

    /* Look through let-bound free variables (following their values) and `mdata` wrappers. */
    expr find(expr const & e) {
        if (is_fvar(e)) {
            if (optional<local_decl> decl = m_lctx.find_local_decl(e)) {
                if (optional<expr> v = decl->get_value()) {
                    return find(*v);
                }
            }
        } else if (is_mdata(e)) {
            return find(mdata_expr(e));
        }
        return e;
    }

    /* Return true iff `e` depends only on free variables that have a value in `m_lctx`
       and whose values are themselves closed. Results for compound terms are memoized
       in `m_closed`. */
    bool is_closed(expr const & e) {
        switch (e.kind()) {
        case expr_kind::MVar:  lean_unreachable();
        case expr_kind::Pi:    lean_unreachable();
        case expr_kind::Sort:  lean_unreachable();
        case expr_kind::Lit:   return true;
        case expr_kind::BVar:  return true;
        case expr_kind::Const: return true;
        case expr_kind::MData: return is_closed(mdata_expr(e));
        case expr_kind::Proj:  return is_closed(proj_expr(e));
        default:               break;
        }

        auto it = m_closed.find(e);
        if (it != m_closed.end())
            return it->second;

        bool r;
        switch (e.kind()) {
        case expr_kind::FVar:
            if (auto v = m_lctx.get_local_decl(e).get_value()) {
                r = is_closed(*v);
            } else {
                r = false;
            }
            break;
        case expr_kind::App: {
            buffer<expr> args;
            expr const & fn = get_app_args(e, args);
            r = true;
            if (!is_closed(fn)) {
                r = false;
            } else {
                for (expr const & arg : args) {
                    if (!is_closed(arg)) {
                        r = false;
                        break;
                    }
                }
            }
            break;
        }
        case expr_kind::Lambda: {
            /* Walk a separate cursor: `e` must stay intact so that the result is
               memoized under the term that was actually queried. */
            expr body = e;
            while (is_lambda(body))
                body = binding_body(body);
            r = is_closed(body);
            break;
        }
        case expr_kind::Let: {
            /* Same as above: do not overwrite `e` while walking the let-chain. */
            r = true;
            expr cur = e;
            while (is_let(cur)) {
                if (!is_closed(let_value(cur))) {
                    r = false;
                    break;
                }
                cur = let_body(cur);
            }
            if (r && !is_closed(cur)) {
                r = false;
            }
            break;
        }
        default:
            lean_unreachable();
        }
        m_closed.insert(mk_pair(e, r));
        return r;
    }

    /* Instantiate the lambda binders with fresh fvars, visit the body as a root, and rebuild the lambda. */
    expr visit_lambda(expr e) {
        flet<local_ctx> save_lctx(m_lctx, m_lctx);
        buffer<expr> fvars;
        while (is_lambda(e)) {
            lean_assert(!has_loose_bvars(binding_domain(e)));
            expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(e), binding_domain(e));
            fvars.push_back(new_fvar);
            e = binding_body(e);
        }
        expr r = visit(instantiate_rev(e, fvars.size(), fvars.data()), true);
        return m_lctx.mk_lambda(fvars, r);
    }

    /* Visit let values (as non-roots) and the let body (as a root), binding each value to a fresh fvar.
       NOTE(review): `fvars` are let-declared, so `mk_lambda` is assumed to rebuild `let` binders
       for them -- confirm against `local_ctx`. */
    expr visit_let(expr e) {
        flet<local_ctx> save_lctx(m_lctx, m_lctx);
        buffer<expr> fvars;
        while (is_let(e)) {
            lean_assert(!has_loose_bvars(let_type(e)));
            expr new_val  = visit(instantiate_rev(let_value(e), fvars.size(), fvars.data()), false);
            expr new_fvar = m_lctx.mk_local_decl(ngen(), let_name(e), let_type(e), new_val);
            fvars.push_back(new_fvar);
            e = let_body(e);
        }
        expr r = visit(instantiate_rev(e, fvars.size(), fvars.data()), true);
        return m_lctx.mk_lambda(fvars, r);
    }

    /* Constructor application whose arguments are all `is_enf_neutral`. */
    bool is_neutral_constructor_app(expr const & e) {
        if (!is_constructor_app(env(), e)) return false;
        buffer<expr> args;
        get_app_args(e, args);
        for (expr const & arg : args) {
            if (!is_enf_neutral(arg))
                return false;
        }
        return true;
    }

    /* Collect (transitively, through their values) the free variables `e` depends on. */
    void collect_deps(expr e, name_set & collected, buffer<expr> & fvars) {
        buffer<expr> todo;
        while (true) {
            for_each(e, [&](expr const & x, unsigned) {
                    if (!has_fvar(x)) return false;
                    if (is_fvar(x) && !collected.contains(fvar_name(x))) {
                        collected.insert(fvar_name(x));
                        optional<expr> v = m_lctx.get_local_decl(x).get_value();
                        lean_assert(v);
                        fvars.push_back(x);
                        todo.push_back(*v);
                    }
                    return true;
                });
            if (todo.empty())
                return;
            e = todo.back();
            todo.pop_back();
        }
    }

    /* Dependencies of `e`, sorted by `sort_fvars`. */
    void collect_deps(expr e, buffer<expr> & fvars) {
        name_set collected;
        collect_deps(e, collected, fvars);
        sort_fvars(m_lctx, fvars);
    }

    /* True iff the stage-2 version of `c` is a definition whose value is not a lambda. */
    bool arity_eq_0(name c) {
        c = mk_cstage2_name(c);
        optional<constant_info> info = env().find(c);
        if (!info || !info->is_definition()) return false;
        return !is_lambda(info->get_value());
    }

    /* Replace the closed term `e0` by a constant whose body is `e0` abstracted over the
       let-bound fvars it depends on. Cached terms reuse their existing constant; terms that
       are not worth extracting are returned unchanged. */
    expr mk_aux_constant(expr const & e0) {
        expr e = find(e0);
        if (is_enf_neutral(e) || is_enf_unreachable(e)) {
            return e0;
        }
        if (is_constant(e) && arity_eq_0(const_name(e))) {
            /* Remark: if a constant `C` has arity > 0, then it is worth creating a new constant with arity 0 that
               just returns `C`. In this way, we cache the closure allocation.
               To implement this optimization we need to first store the definitions after erasure. */
            return e0;
        }
        if (is_neutral_constructor_app(e)) {
            /* We don't create auxiliary constants for constructor applications such as:
               `none ◾` and `list.nil ◾` */
            return e0;
        }
        if (is_lit(e) && lit_value(e).kind() == literal_kind::Nat && lit_value(e).get_nat().is_small()) {
            /* We don't create auxiliary constants for small nat literals. Reason: they are cheap. */
            return e0;
        }
        buffer<expr> fvars;
        collect_deps(e, fvars);
        e = m_lctx.mk_lambda(fvars, e);
        lean_assert(!has_loose_bvars(e));
        if (name const * c = m_ext.m_cache.find(e)) {
            return mk_constant(*c);
        }
        name c = next_name();
        m_new_decls.push_back(comp_decl(c, e));
        m_ext.m_cache.insert(e, c);
        m_env = module::add(env(), std::make_shared<ec_cache_modification>(e, c));
        return mk_constant(c);
    }

    expr visit_app(expr const & e, bool root) {
        buffer<expr> args;
        expr const & fn = get_app_args(e, args);
        bool is_cases = is_cases_on_app(env(), fn);
        for (unsigned i = 0; i < args.size(); i++) {
            /* arg is a "root" iff `e` is a cases, and `arg` is a minor premise */
            args[i] = visit(args[i], is_cases && i > 0);
        }
        expr r = mk_app(fn, args);
        bool c = is_closed(r);
        if (root && c) {
            return mk_aux_constant(r);
        } else if (!c && !is_cases) {
            /* Create auxiliary constants for closed arguments. */
            bool modified = false;
            for (expr & arg : args) {
                if (is_closed(arg)) {
                    expr new_arg = mk_aux_constant(arg);
                    if (new_arg != arg) modified = true;
                    arg = new_arg;
                }
            }
            return modified ? mk_app(fn, args) : r;
        } else {
            return r;
        }
    }

    /* Literals and constants are extracted only at root positions. */
    expr visit_atom(expr const & e, bool root) {
        if (!root) return e;
        return mk_aux_constant(e);
    }

    expr visit(expr const & e, bool root) {
        switch (e.kind()) {
        case expr_kind::Lit:    return visit_atom(e, root);
        case expr_kind::Const:  return visit_atom(e, root);
        case expr_kind::App:    return visit_app(e, root);
        case expr_kind::Lambda: return visit_lambda(e);
        case expr_kind::Let:    return visit_let(e);
        default:                return e;
        }
    }

public:
    explicit extract_closed_fn(environment const & env):
        m_env(env), m_ext(get_extension(env)) {
    }

    /* Process `d`; returns the updated environment and the rewritten `d` followed by the
       auxiliary declarations created for it. */
    pair<environment, comp_decls> operator()(comp_decl const & d) {
        expr v = d.snd();
        if (!is_lambda(v)) {
            /* `d` already has arity 0. */
            return mk_pair(env(), comp_decls(d));
        }
        m_base_name = d.fst();
        expr new_v = visit(v, true);
        comp_decl new_d(d.fst(), new_v);
        environment new_env = update(env(), m_ext);
        return mk_pair(new_env, comp_decls(new_d, comp_decls(m_new_decls)));
    }
};

pair<environment, comp_decls> extract_closed_core(environment const & env, comp_decl const & d) {
    return extract_closed_fn(env)(d);
}

/* Apply `extract_closed_core` to each declaration, threading the environment through. */
pair<environment, comp_decls> extract_closed(environment env, comp_decls const & ds) {
    comp_decls r;
    for (comp_decl const & d : ds) {
        comp_decls new_ds;
        std::tie(env, new_ds) = extract_closed_core(env, d);
        r = append(r, new_ds);
    }
    return mk_pair(env, r);
}

void initialize_extract_closed() {
    g_ext = new extract_closed_ext_reg();
    ec_cache_modification::init();
}

void finalize_extract_closed() {
    ec_cache_modification::finalize();
    delete g_ext;
}
}