feat(library/compiler): destructive updates for {x with ...} expressions

This commit is contained in:
Leonardo de Moura 2019-04-22 13:33:53 -07:00
parent ee0851921b
commit f222dc7cca
7 changed files with 320 additions and 41 deletions

View file

@ -4,4 +4,5 @@ add_library(compiler OBJECT emit_bytecode.cpp init_module.cpp ## New
extract_closed.cpp simp_app_args.cpp llnf.cpp ll_infer_type.cpp
reduce_arity.cpp closed_term_cache.cpp name_mangling.cpp
emit_cpp.cpp export_attribute.cpp llnf_code.cpp extern_attribute.cpp
borrowed_annotation.cpp init_attribute.cpp eager_lambda_lifting.cpp)
borrowed_annotation.cpp init_attribute.cpp eager_lambda_lifting.cpp
struct_cases_on.cpp)

View file

@ -27,6 +27,7 @@ Author: Leonardo de Moura
#include "library/compiler/llnf_code.h"
#include "library/compiler/export_attribute.h"
#include "library/compiler/extern_attribute.h"
#include "library/compiler/struct_cases_on.h"
namespace lean {
static name * g_codegen = nullptr;
@ -215,6 +216,8 @@ environment compile(environment const & env, options const & opts, names cs) {
trace_compiler(name({"compiler", "elim_dead_let"}), ds);
ds = apply(erase_irrelevant, new_env, ds);
trace_compiler(name({"compiler", "erase_irrelevant"}), ds);
ds = apply(struct_cases_on, new_env, ds);
trace_compiler(name({"compiler", "struct_cases_on"}), ds);
ds = apply(esimp, new_env, ds);
trace_compiler(name({"compiler", "simp"}), ds);
ds = reduce_arity(new_env, ds);
@ -269,6 +272,7 @@ void initialize_compiler() {
register_trace_class({"compiler", "extract_closed"});
register_trace_class({"compiler", "reduce_arity"});
register_trace_class({"compiler", "simp_app_args"});
register_trace_class({"compiler", "struct_cases_on"});
register_trace_class({"compiler", "llnf"});
register_trace_class({"compiler", "boxed"});
register_trace_class({"compiler", "optimize_bytecode"});

View file

@ -247,23 +247,6 @@ struct cnstr_info {
}
};
static expr * g_usize = nullptr;
std::vector<pair<name, unsigned>> * g_builtin_scalar_size = nullptr;
static bool is_usize_type(expr const & e) {
return is_constant(e, get_usize_name());
}
static optional<unsigned> is_builtin_scalar(expr const & type) {
if (!is_constant(type)) return optional<unsigned>();
for (pair<name, unsigned> const & p : *g_builtin_scalar_size) {
if (const_name(type) == p.first) {
return optional<unsigned>(p.second);
}
}
return optional<unsigned>();
}
unsigned get_llnf_arity(environment const & env, name const & n) {
/* First, try to infer arity from `_cstage2` auxiliary definition. */
name c = mk_cstage2_name(n);
@ -297,12 +280,6 @@ static bool uses_borrowed(environment const & env, name const & n) {
return borrowed_res;
}
static optional<unsigned> is_enum_type(environment const & env, expr const & type) {
expr const & I = get_app_fn(type);
if (!is_constant(I)) return optional<unsigned>();
return is_enum_type(env, const_name(I));
}
static void get_cnstr_info_core(type_checker::state & st, bool unboxed, name const & n, buffer<field_info> & result) {
environment const & env = st.env();
constant_info info = env.get(n);
@ -1654,7 +1631,7 @@ class explicit_boxing_fn {
expr visit_uset(expr const & fn, buffer<expr> & args) {
lean_assert(args.size() == 2);
args[1] = cast_if_needed(args[1], get_arg_type(args[1]), *g_usize);
args[1] = cast_if_needed(args[1], get_arg_type(args[1]), mk_usize_type());
return mk_app(fn, args);
}
@ -2499,7 +2476,6 @@ pair<environment, comp_decls> to_llnf(environment const & env, comp_decls const
}
void initialize_llnf() {
g_usize = new expr(mk_constant(get_usize_name()));
g_apply = new expr(mk_constant("_apply"));
g_closure = new expr(mk_constant("_closure"));
g_reuse = new name("_reuse");
@ -2515,15 +2491,9 @@ void initialize_llnf() {
g_inc = new expr(mk_constant("_inc"));
g_dec = new expr(mk_constant("_dec"));
register_trace_class({"compiler", "lambda_pure"});
g_builtin_scalar_size = new std::vector<pair<name, unsigned>>();
g_builtin_scalar_size->emplace_back(get_uint8_name(), 1);
g_builtin_scalar_size->emplace_back(get_uint16_name(), 2);
g_builtin_scalar_size->emplace_back(get_uint32_name(), 4);
g_builtin_scalar_size->emplace_back(get_uint64_name(), 8);
}
void finalize_llnf() {
delete g_usize;
delete g_closure;
delete g_apply;
delete g_reuse;

View file

@ -0,0 +1,203 @@
/*
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include "runtime/flet.h"
#include "kernel/instantiate.h"
#include "kernel/type_checker.h"
#include "library/trace.h"
#include "library/suffixes.h"
#include "library/compiler/util.h"
namespace lean {
class struct_cases_on_fn {
type_checker::state m_st;
local_ctx m_lctx;
name_set m_scrutinies;
name m_fld{"_d"};
unsigned m_next_idx{1};
environment const & env() { return m_st.env(); }
name_generator & ngen() { return m_st.ngen(); }
name next_field_name() {
name r = m_fld.append_after(m_next_idx);
m_next_idx++;
return r;
}
expr visit_cases(expr const & e) {
flet<name_set> save(m_scrutinies, m_scrutinies);
buffer<expr> args;
expr const & c = get_app_args(e, args);
expr const & major = args[0];
if (is_fvar(major))
m_scrutinies.insert(fvar_name(major));
for (unsigned i = 1; i < args.size(); i++) {
args[i] = visit(args[i]);
}
return mk_app(c, args);
}
expr visit_app(expr const & e) {
if (is_cases_on_app(env(), e)) {
return visit_cases(e);
} else {
return e;
}
}
expr visit_lambda(expr e) {
buffer<expr> fvars;
while (is_lambda(e)) {
lean_assert(!has_loose_bvars(binding_domain(e)));
expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(e), binding_domain(e), binding_info(e));
fvars.push_back(new_fvar);
e = binding_body(e);
}
e = instantiate_rev(e, fvars.size(), fvars.data());
e = visit(e);
return m_lctx.mk_lambda(fvars, e);
}
bool is_candidate(expr const & rhs) {
if (!is_proj(rhs)) return false;
expr const & s = proj_expr(rhs);
if (!is_fvar(s)) return false;
if (m_scrutinies.contains(fvar_name(s))) return false;
return true;
}
/* Return true iff `e` is a constructor application of inductive type `S_name` and containg `x`. */
bool is_ctor_of(expr const & e, name const & S_name, expr const & x) {
if (!is_constructor_app(env(), e)) return false;
buffer<expr> args;
expr const & k = get_app_args(e, args);
lean_assert(is_constant(k));
constructor_val k_val = env().get(const_name(k)).to_constructor_val();
if (k_val.get_induct() != S_name) return false;
for (unsigned i = k_val.get_nparams(); i < args.size(); i++) {
if (args[i] == x)
return true;
}
return false;
}
/* Return true iff `e` contains a constructor application of inductive type `S_name` and containg `x`. */
bool has_ctor_with(expr e, name const & S_name, expr const & x) {
while (is_let(e)) {
if (is_ctor_of(let_value(e), S_name, x))
return true;
e = let_body(e);
}
if (is_cases_on_app(env(), e)) {
buffer<expr> args;
get_app_args(e, args);
for (unsigned i = 1; i < args.size(); i++) {
if (has_ctor_with(args[i], S_name, x))
return true;
}
return false;
} else {
return is_ctor_of(e, S_name, x);
}
}
static void get_struct_field_types(type_checker::state & st, name const & S_name, buffer<expr> & result) {
environment const & env = st.env();
constant_info info = env.get(S_name);
lean_assert(info.is_inductive());
inductive_val I_val = info.to_inductive_val();
lean_assert(length(I_val.get_cnstrs()) == 1);
constant_info ctor_info = env.get(head(I_val.get_cnstrs()));
expr type = ctor_info.get_type();
unsigned nparams = I_val.get_nparams();
local_ctx lctx;
buffer<expr> telescope;
to_telescope(env, lctx, st.ngen(), type, telescope);
lean_assert(telescope.size() >= nparams);
for (unsigned i = nparams; i < telescope.size(); i++) {
expr ftype = lctx.get_type(telescope[i]);
if (is_irrelevant_type(st, lctx, ftype)) {
result.push_back(mk_enf_neutral_type());
} else {
type_checker tc(st, lctx);
ftype = tc.whnf(ftype);
if (is_usize_type(ftype)) {
result.push_back(ftype);
} else if (is_builtin_scalar(ftype)) {
result.push_back(ftype);
} else if (optional<unsigned> sz = is_enum_type(env, ftype)) {
optional<expr> uint = to_uint_type(*sz);
if (!uint) throw exception("code generation failed, enumeration type is too big");
result.push_back(*uint);
} else {
result.push_back(mk_enf_object_type());
}
}
}
}
expr visit_let(expr e) {
buffer<expr> fvars;
while (is_let(e)) {
lean_assert(!has_loose_bvars(let_type(e)));
expr type = let_type(e);
expr val = instantiate_rev(let_value(e), fvars.size(), fvars.data());
name n = let_name(e);
e = let_body(e);
expr new_fvar = m_lctx.mk_local_decl(ngen(), n, type, val);
fvars.push_back(new_fvar);
if (is_candidate(val)) {
lean_assert(is_proj(val));
lean_assert(proj_idx(val).is_small());
name const & S_name = proj_sname(val);
e = instantiate_rev(e, fvars.size(), fvars.data());
if (has_ctor_with(e, S_name, new_fvar)) {
/* Introduce a casesOn application. */
e = m_lctx.mk_lambda(new_fvar, e);
fvars.pop_back();
expr major = proj_expr(val);
buffer<expr> field_types;
get_struct_field_types(m_st, S_name, field_types);
unsigned i = field_types.size();
while (i > 0) {
--i;
e = mk_lambda(next_field_name(), field_types[i], e);
}
e = mk_app(mk_constant(name(S_name, g_cases_on)), major, e);
}
e = visit(e);
return m_lctx.mk_lambda(fvars, e);
}
}
e = visit(instantiate_rev(e, fvars.size(), fvars.data()));
return m_lctx.mk_lambda(fvars, e);
}
expr visit(expr const & e) {
switch (e.kind()) {
case expr_kind::App: return visit_app(e);
case expr_kind::Lambda: return visit_lambda(e);
case expr_kind::Let: return visit_let(e);
default: return e;
}
}
public:
struct_cases_on_fn(environment const & env):
m_st(env) {
}
expr operator()(expr const & e) {
return visit(e);
}
};
expr struct_cases_on(environment const & env, expr const & e) {
return struct_cases_on_fn(env)(e);
}
}

View file

@ -0,0 +1,54 @@
/*
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#pragma once
#include "kernel/environment.h"
namespace lean {
/* Insert `S.casesOn` applications for a structure `S` when
1- There is a constructor application `S.mk ... x ...`, and
2- `x := y.i`, and
3- There is no `S.casesOn y ...`
This transformation is useful because the `reset/reuse` insertion
procedure uses `casesOn` applications as a guide.
Moreover, Lean structure update expressions are not compiled using
`casesOn` applicactions.
Example: given
```
fun x,
let y_1 := x.1 in
let y_2 := 0 in
(y_1, y_2)
```
this function returns
```
fun x,
Prod.casesOn x
(fun fst snd,
let y_1 := x.1 in
let y_2 := 0 in
(y_1, y_2))
```
Note that, we rely on the simplifier (csimp.cpp) to replace `x.1` with `fst`.
Remark: this function assumes we have already erased irrelevant information.
Remark: we have considered compiling the `{ x with ... }` expressions using `casesOn`, but
we loose useful definitional equalities. In the encoding we use,
`{x with field1 := v1, field2 := v2}.field1` is definitional equal to `v1`.
If we compile this expression using `casesOn`, we would have
```
(match x with
| {field1 := _, field2 := _, field3 := v3} := {field1 := v1, field2 := v2, field3 := v3}).field1
```
as is only definitionally equal to `v1` IF `x` is definitionally equal to a constructor application.
The missing definitional equalities is problematic. For example, the whole algebraic hierarchy
in Lean relies on them.
*/
expr struct_cases_on(environment const & env, expr const & e);
}

View file

@ -641,16 +641,52 @@ bool lcnf_check_let_decls(environment const & env, comp_decls const & ds) {
return true;
}
// =======================================
// UInt and USize helper functions
std::vector<pair<name, unsigned>> * g_builtin_scalar_size = nullptr;
expr mk_usize_type() {
return *g_usize;
}
bool is_usize_type(expr const & e) {
return is_constant(e, get_usize_name());
}
optional<unsigned> is_builtin_scalar(expr const & type) {
if (!is_constant(type)) return optional<unsigned>();
for (pair<name, unsigned> const & p : *g_builtin_scalar_size) {
if (const_name(type) == p.first) {
return optional<unsigned>(p.second);
}
}
return optional<unsigned>();
}
optional<unsigned> is_enum_type(environment const & env, expr const & type) {
expr const & I = get_app_fn(type);
if (!is_constant(I)) return optional<unsigned>();
return is_enum_type(env, const_name(I));
}
// =======================================
void initialize_compiler_util() {
g_neutral_expr = new expr(mk_constant("_neutral"));
g_unreachable_expr = new expr(mk_constant("_unreachable"));
g_object_type = new expr(mk_constant("_obj"));
g_void_type = new expr(mk_constant("_void"));
g_usize = new expr(mk_constant(get_usize_name()));
g_uint8 = new expr(mk_constant(get_uint8_name()));
g_uint16 = new expr(mk_constant(get_uint16_name()));
g_uint32 = new expr(mk_constant(get_uint32_name()));
g_uint64 = new expr(mk_constant(get_uint64_name()));
g_neutral_expr = new expr(mk_constant("_neutral"));
g_unreachable_expr = new expr(mk_constant("_unreachable"));
g_object_type = new expr(mk_constant("_obj"));
g_void_type = new expr(mk_constant("_void"));
g_usize = new expr(mk_constant(get_usize_name()));
g_uint8 = new expr(mk_constant(get_uint8_name()));
g_uint16 = new expr(mk_constant(get_uint16_name()));
g_uint32 = new expr(mk_constant(get_uint32_name()));
g_uint64 = new expr(mk_constant(get_uint64_name()));
g_builtin_scalar_size = new std::vector<pair<name, unsigned>>();
g_builtin_scalar_size->emplace_back(get_uint8_name(), 1);
g_builtin_scalar_size->emplace_back(get_uint16_name(), 2);
g_builtin_scalar_size->emplace_back(get_uint32_name(), 4);
g_builtin_scalar_size->emplace_back(get_uint64_name(), 8);
register_system_attribute(basic_attribute::with_check(
"inline", "mark definition to always be inlined",
@ -695,5 +731,6 @@ void finalize_compiler_util() {
delete g_uint16;
delete g_uint32;
delete g_uint64;
delete g_builtin_scalar_size;
}
}

View file

@ -173,6 +173,16 @@ optional<expr> mk_enf_fix_core(unsigned n);
bool lcnf_check_let_decls(environment const & env, comp_decl const & d);
bool lcnf_check_let_decls(environment const & env, comp_decls const & ds);
// =======================================
// UInt and USize helper functions
expr mk_usize_type();
bool is_usize_type(expr const & e);
optional<unsigned> is_builtin_scalar(expr const & type);
optional<unsigned> is_enum_type(environment const & env, expr const & type);
// =======================================
void initialize_compiler_util();
void finalize_compiler_util();
}