From f222dc7cca36aeb4d5bfc3d879aa1bd893abe25c Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 22 Apr 2019 13:33:53 -0700 Subject: [PATCH] feat(library/compiler): destructive updates for `{x with ...}` expressions --- src/library/compiler/CMakeLists.txt | 3 +- src/library/compiler/compiler.cpp | 4 + src/library/compiler/llnf.cpp | 32 +--- src/library/compiler/struct_cases_on.cpp | 203 +++++++++++++++++++++++ src/library/compiler/struct_cases_on.h | 54 ++++++ src/library/compiler/util.cpp | 55 +++++- src/library/compiler/util.h | 10 ++ 7 files changed, 320 insertions(+), 41 deletions(-) create mode 100644 src/library/compiler/struct_cases_on.cpp create mode 100644 src/library/compiler/struct_cases_on.h diff --git a/src/library/compiler/CMakeLists.txt b/src/library/compiler/CMakeLists.txt index 1cced5bdf4..8bda0e0f39 100644 --- a/src/library/compiler/CMakeLists.txt +++ b/src/library/compiler/CMakeLists.txt @@ -4,4 +4,5 @@ add_library(compiler OBJECT emit_bytecode.cpp init_module.cpp ## New extract_closed.cpp simp_app_args.cpp llnf.cpp ll_infer_type.cpp reduce_arity.cpp closed_term_cache.cpp name_mangling.cpp emit_cpp.cpp export_attribute.cpp llnf_code.cpp extern_attribute.cpp - borrowed_annotation.cpp init_attribute.cpp eager_lambda_lifting.cpp) + borrowed_annotation.cpp init_attribute.cpp eager_lambda_lifting.cpp + struct_cases_on.cpp) diff --git a/src/library/compiler/compiler.cpp b/src/library/compiler/compiler.cpp index 046ab1bf96..b21ee90455 100644 --- a/src/library/compiler/compiler.cpp +++ b/src/library/compiler/compiler.cpp @@ -27,6 +27,7 @@ Author: Leonardo de Moura #include "library/compiler/llnf_code.h" #include "library/compiler/export_attribute.h" #include "library/compiler/extern_attribute.h" +#include "library/compiler/struct_cases_on.h" namespace lean { static name * g_codegen = nullptr; @@ -215,6 +216,8 @@ environment compile(environment const & env, options const & opts, names cs) { trace_compiler(name({"compiler", "elim_dead_let"}), ds); 
ds = apply(erase_irrelevant, new_env, ds); trace_compiler(name({"compiler", "erase_irrelevant"}), ds); + ds = apply(struct_cases_on, new_env, ds); + trace_compiler(name({"compiler", "struct_cases_on"}), ds); ds = apply(esimp, new_env, ds); trace_compiler(name({"compiler", "simp"}), ds); ds = reduce_arity(new_env, ds); @@ -269,6 +272,7 @@ void initialize_compiler() { register_trace_class({"compiler", "extract_closed"}); register_trace_class({"compiler", "reduce_arity"}); register_trace_class({"compiler", "simp_app_args"}); + register_trace_class({"compiler", "struct_cases_on"}); register_trace_class({"compiler", "llnf"}); register_trace_class({"compiler", "boxed"}); register_trace_class({"compiler", "optimize_bytecode"}); diff --git a/src/library/compiler/llnf.cpp b/src/library/compiler/llnf.cpp index 19f31f3cca..69f81c9c59 100644 --- a/src/library/compiler/llnf.cpp +++ b/src/library/compiler/llnf.cpp @@ -247,23 +247,6 @@ struct cnstr_info { } }; -static expr * g_usize = nullptr; -std::vector> * g_builtin_scalar_size = nullptr; - -static bool is_usize_type(expr const & e) { - return is_constant(e, get_usize_name()); -} - -static optional is_builtin_scalar(expr const & type) { - if (!is_constant(type)) return optional(); - for (pair const & p : *g_builtin_scalar_size) { - if (const_name(type) == p.first) { - return optional(p.second); - } - } - return optional(); -} - unsigned get_llnf_arity(environment const & env, name const & n) { /* First, try to infer arity from `_cstage2` auxiliary definition. 
*/ name c = mk_cstage2_name(n); @@ -297,12 +280,6 @@ static bool uses_borrowed(environment const & env, name const & n) { return borrowed_res; } -static optional is_enum_type(environment const & env, expr const & type) { - expr const & I = get_app_fn(type); - if (!is_constant(I)) return optional(); - return is_enum_type(env, const_name(I)); -} - static void get_cnstr_info_core(type_checker::state & st, bool unboxed, name const & n, buffer & result) { environment const & env = st.env(); constant_info info = env.get(n); @@ -1654,7 +1631,7 @@ class explicit_boxing_fn { expr visit_uset(expr const & fn, buffer & args) { lean_assert(args.size() == 2); - args[1] = cast_if_needed(args[1], get_arg_type(args[1]), *g_usize); + args[1] = cast_if_needed(args[1], get_arg_type(args[1]), mk_usize_type()); return mk_app(fn, args); } @@ -2499,7 +2476,6 @@ pair to_llnf(environment const & env, comp_decls const } void initialize_llnf() { - g_usize = new expr(mk_constant(get_usize_name())); g_apply = new expr(mk_constant("_apply")); g_closure = new expr(mk_constant("_closure")); g_reuse = new name("_reuse"); @@ -2515,15 +2491,9 @@ void initialize_llnf() { g_inc = new expr(mk_constant("_inc")); g_dec = new expr(mk_constant("_dec")); register_trace_class({"compiler", "lambda_pure"}); - g_builtin_scalar_size = new std::vector>(); - g_builtin_scalar_size->emplace_back(get_uint8_name(), 1); - g_builtin_scalar_size->emplace_back(get_uint16_name(), 2); - g_builtin_scalar_size->emplace_back(get_uint32_name(), 4); - g_builtin_scalar_size->emplace_back(get_uint64_name(), 8); } void finalize_llnf() { - delete g_usize; delete g_closure; delete g_apply; delete g_reuse; diff --git a/src/library/compiler/struct_cases_on.cpp b/src/library/compiler/struct_cases_on.cpp new file mode 100644 index 0000000000..17c702d04f --- /dev/null +++ b/src/library/compiler/struct_cases_on.cpp @@ -0,0 +1,203 @@ +/* +Copyright (c) 2019 Microsoft Corporation. All rights reserved. 
+Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#include "runtime/flet.h" +#include "kernel/instantiate.h" +#include "kernel/type_checker.h" +#include "library/trace.h" +#include "library/suffixes.h" +#include "library/compiler/util.h" + +namespace lean { +class struct_cases_on_fn { + type_checker::state m_st; + local_ctx m_lctx; + name_set m_scrutinies; + name m_fld{"_d"}; + unsigned m_next_idx{1}; + + environment const & env() { return m_st.env(); } + + name_generator & ngen() { return m_st.ngen(); } + + name next_field_name() { + name r = m_fld.append_after(m_next_idx); + m_next_idx++; + return r; + } + + expr visit_cases(expr const & e) { + flet save(m_scrutinies, m_scrutinies); + buffer args; + expr const & c = get_app_args(e, args); + expr const & major = args[0]; + if (is_fvar(major)) + m_scrutinies.insert(fvar_name(major)); + for (unsigned i = 1; i < args.size(); i++) { + args[i] = visit(args[i]); + } + return mk_app(c, args); + } + + expr visit_app(expr const & e) { + if (is_cases_on_app(env(), e)) { + return visit_cases(e); + } else { + return e; + } + } + + expr visit_lambda(expr e) { + buffer fvars; + while (is_lambda(e)) { + lean_assert(!has_loose_bvars(binding_domain(e))); + expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(e), binding_domain(e), binding_info(e)); + fvars.push_back(new_fvar); + e = binding_body(e); + } + e = instantiate_rev(e, fvars.size(), fvars.data()); + e = visit(e); + return m_lctx.mk_lambda(fvars, e); + } + + bool is_candidate(expr const & rhs) { + if (!is_proj(rhs)) return false; + expr const & s = proj_expr(rhs); + if (!is_fvar(s)) return false; + if (m_scrutinies.contains(fvar_name(s))) return false; + return true; + } + + /* Return true iff `e` is a constructor application of inductive type `S_name` and containg `x`. 
*/ + bool is_ctor_of(expr const & e, name const & S_name, expr const & x) { + if (!is_constructor_app(env(), e)) return false; + buffer args; + expr const & k = get_app_args(e, args); + lean_assert(is_constant(k)); + constructor_val k_val = env().get(const_name(k)).to_constructor_val(); + if (k_val.get_induct() != S_name) return false; + for (unsigned i = k_val.get_nparams(); i < args.size(); i++) { + if (args[i] == x) + return true; + } + return false; + } + + /* Return true iff `e` contains a constructor application of inductive type `S_name` and containg `x`. */ + bool has_ctor_with(expr e, name const & S_name, expr const & x) { + while (is_let(e)) { + if (is_ctor_of(let_value(e), S_name, x)) + return true; + e = let_body(e); + } + if (is_cases_on_app(env(), e)) { + buffer args; + get_app_args(e, args); + for (unsigned i = 1; i < args.size(); i++) { + if (has_ctor_with(args[i], S_name, x)) + return true; + } + return false; + } else { + return is_ctor_of(e, S_name, x); + } + } + + static void get_struct_field_types(type_checker::state & st, name const & S_name, buffer & result) { + environment const & env = st.env(); + constant_info info = env.get(S_name); + lean_assert(info.is_inductive()); + inductive_val I_val = info.to_inductive_val(); + lean_assert(length(I_val.get_cnstrs()) == 1); + constant_info ctor_info = env.get(head(I_val.get_cnstrs())); + expr type = ctor_info.get_type(); + unsigned nparams = I_val.get_nparams(); + local_ctx lctx; + buffer telescope; + to_telescope(env, lctx, st.ngen(), type, telescope); + lean_assert(telescope.size() >= nparams); + for (unsigned i = nparams; i < telescope.size(); i++) { + expr ftype = lctx.get_type(telescope[i]); + if (is_irrelevant_type(st, lctx, ftype)) { + result.push_back(mk_enf_neutral_type()); + } else { + type_checker tc(st, lctx); + ftype = tc.whnf(ftype); + if (is_usize_type(ftype)) { + result.push_back(ftype); + } else if (is_builtin_scalar(ftype)) { + result.push_back(ftype); + } else if (optional sz = 
is_enum_type(env, ftype)) { + optional uint = to_uint_type(*sz); + if (!uint) throw exception("code generation failed, enumeration type is too big"); + result.push_back(*uint); + } else { + result.push_back(mk_enf_object_type()); + } + } + } + } + + expr visit_let(expr e) { + buffer fvars; + while (is_let(e)) { + lean_assert(!has_loose_bvars(let_type(e))); + expr type = let_type(e); + expr val = instantiate_rev(let_value(e), fvars.size(), fvars.data()); + name n = let_name(e); + e = let_body(e); + expr new_fvar = m_lctx.mk_local_decl(ngen(), n, type, val); + fvars.push_back(new_fvar); + if (is_candidate(val)) { + lean_assert(is_proj(val)); + lean_assert(proj_idx(val).is_small()); + name const & S_name = proj_sname(val); + e = instantiate_rev(e, fvars.size(), fvars.data()); + if (has_ctor_with(e, S_name, new_fvar)) { + /* Introduce a casesOn application. */ + e = m_lctx.mk_lambda(new_fvar, e); + fvars.pop_back(); + expr major = proj_expr(val); + buffer field_types; + get_struct_field_types(m_st, S_name, field_types); + unsigned i = field_types.size(); + while (i > 0) { + --i; + e = mk_lambda(next_field_name(), field_types[i], e); + } + e = mk_app(mk_constant(name(S_name, g_cases_on)), major, e); + } + e = visit(e); + return m_lctx.mk_lambda(fvars, e); + } + } + e = visit(instantiate_rev(e, fvars.size(), fvars.data())); + return m_lctx.mk_lambda(fvars, e); + } + + expr visit(expr const & e) { + switch (e.kind()) { + case expr_kind::App: return visit_app(e); + case expr_kind::Lambda: return visit_lambda(e); + case expr_kind::Let: return visit_let(e); + default: return e; + } + } + +public: + struct_cases_on_fn(environment const & env): + m_st(env) { + } + + expr operator()(expr const & e) { + return visit(e); + } +}; + +expr struct_cases_on(environment const & env, expr const & e) { + return struct_cases_on_fn(env)(e); +} +} diff --git a/src/library/compiler/struct_cases_on.h b/src/library/compiler/struct_cases_on.h new file mode 100644 index 0000000000..5b534d95e3 
--- /dev/null +++ b/src/library/compiler/struct_cases_on.h @@ -0,0 +1,54 @@ +/* +Copyright (c) 2019 Microsoft Corporation. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. + +Author: Leonardo de Moura +*/ +#pragma once +#include "kernel/environment.h" + +namespace lean { +/* Insert `S.casesOn` applications for a structure `S` when + 1- There is a constructor application `S.mk ... x ...`, and + 2- `x := y.i`, and + 3- There is no `S.casesOn y ...` + + This transformation is useful because the `reset/reuse` insertion + procedure uses `casesOn` applications as a guide. + Moreover, Lean structure update expressions are not compiled using + `casesOn` applications. + + Example: given + ``` + fun x, + let y_1 := x.1 in + let y_2 := 0 in + (y_1, y_2) + ``` + this function returns + ``` + fun x, + Prod.casesOn x + (fun fst snd, + let y_1 := x.1 in + let y_2 := 0 in + (y_1, y_2)) + ``` + Note that we rely on the simplifier (csimp.cpp) to replace `x.1` with `fst`. + + Remark: this function assumes we have already erased irrelevant information. + + Remark: we have considered compiling the `{ x with ... }` expressions using `casesOn`, but + we lose useful definitional equalities. In the encoding we use, + `{x with field1 := v1, field2 := v2}.field1` is definitionally equal to `v1`. + If we compile this expression using `casesOn`, we would have + ``` + (match x with + | {field1 := _, field2 := _, field3 := v3} := {field1 := v1, field2 := v2, field3 := v3}).field1 + ``` + which is only definitionally equal to `v1` IF `x` is definitionally equal to a constructor application. + The missing definitional equalities are problematic. For example, the whole algebraic hierarchy + in Lean relies on them.
+*/ +expr struct_cases_on(environment const & env, expr const & e); +} diff --git a/src/library/compiler/util.cpp b/src/library/compiler/util.cpp index 2700778f5e..1636d96091 100644 --- a/src/library/compiler/util.cpp +++ b/src/library/compiler/util.cpp @@ -641,16 +641,52 @@ bool lcnf_check_let_decls(environment const & env, comp_decls const & ds) { return true; } +// ======================================= +// UInt and USize helper functions + +std::vector> * g_builtin_scalar_size = nullptr; + +expr mk_usize_type() { + return *g_usize; +} + +bool is_usize_type(expr const & e) { + return is_constant(e, get_usize_name()); +} + +optional is_builtin_scalar(expr const & type) { + if (!is_constant(type)) return optional(); + for (pair const & p : *g_builtin_scalar_size) { + if (const_name(type) == p.first) { + return optional(p.second); + } + } + return optional(); +} + +optional is_enum_type(environment const & env, expr const & type) { + expr const & I = get_app_fn(type); + if (!is_constant(I)) return optional(); + return is_enum_type(env, const_name(I)); +} + +// ======================================= + void initialize_compiler_util() { - g_neutral_expr = new expr(mk_constant("_neutral")); - g_unreachable_expr = new expr(mk_constant("_unreachable")); - g_object_type = new expr(mk_constant("_obj")); - g_void_type = new expr(mk_constant("_void")); - g_usize = new expr(mk_constant(get_usize_name())); - g_uint8 = new expr(mk_constant(get_uint8_name())); - g_uint16 = new expr(mk_constant(get_uint16_name())); - g_uint32 = new expr(mk_constant(get_uint32_name())); - g_uint64 = new expr(mk_constant(get_uint64_name())); + g_neutral_expr = new expr(mk_constant("_neutral")); + g_unreachable_expr = new expr(mk_constant("_unreachable")); + g_object_type = new expr(mk_constant("_obj")); + g_void_type = new expr(mk_constant("_void")); + g_usize = new expr(mk_constant(get_usize_name())); + g_uint8 = new expr(mk_constant(get_uint8_name())); + g_uint16 = new 
expr(mk_constant(get_uint16_name())); + g_uint32 = new expr(mk_constant(get_uint32_name())); + g_uint64 = new expr(mk_constant(get_uint64_name())); + g_builtin_scalar_size = new std::vector>(); + g_builtin_scalar_size->emplace_back(get_uint8_name(), 1); + g_builtin_scalar_size->emplace_back(get_uint16_name(), 2); + g_builtin_scalar_size->emplace_back(get_uint32_name(), 4); + g_builtin_scalar_size->emplace_back(get_uint64_name(), 8); register_system_attribute(basic_attribute::with_check( "inline", "mark definition to always be inlined", @@ -695,5 +731,6 @@ void finalize_compiler_util() { delete g_uint16; delete g_uint32; delete g_uint64; + delete g_builtin_scalar_size; } } diff --git a/src/library/compiler/util.h b/src/library/compiler/util.h index 75c0ec6058..35197a32a5 100644 --- a/src/library/compiler/util.h +++ b/src/library/compiler/util.h @@ -173,6 +173,16 @@ optional mk_enf_fix_core(unsigned n); bool lcnf_check_let_decls(environment const & env, comp_decl const & d); bool lcnf_check_let_decls(environment const & env, comp_decls const & ds); +// ======================================= +// UInt and USize helper functions + +expr mk_usize_type(); +bool is_usize_type(expr const & e); +optional is_builtin_scalar(expr const & type); +optional is_enum_type(environment const & env, expr const & type); + +// ======================================= + void initialize_compiler_util(); void finalize_compiler_util(); }