diff --git a/src/library/compiler/compiler.cpp b/src/library/compiler/compiler.cpp index ef752fd233..47c4de62fe 100644 --- a/src/library/compiler/compiler.cpp +++ b/src/library/compiler/compiler.cpp @@ -152,13 +152,13 @@ environment compile(environment const & env, options const & opts, names const & trace_compiler(name({"compiler", "stage2"}), ds); ds = apply(esimp, new_env, ds); trace_compiler(name({"compiler", "simp"}), ds); - ds = apply(simp_inductive, new_env, ds); /* TODO(Leo): llnf is not integrated yet. We are only using it here for debugging. */ auto to_llnf_box = [&](environment const & env, expr const & e) { return to_llnf(env, e, true); }; comp_decls aux_ds = apply(to_llnf_box, new_env, ds); trace_compiler(name({"compiler", "llnf"}), aux_ds); + ds = apply(simp_inductive, new_env, ds); trace_compiler(name({"compiler", "simplify_inductive"}), ds); new_env = emit_bytecode(new_env, ds); return new_env; diff --git a/src/library/compiler/llnf.cpp b/src/library/compiler/llnf.cpp index b789496f13..bf915ea890 100644 --- a/src/library/compiler/llnf.cpp +++ b/src/library/compiler/llnf.cpp @@ -4,13 +4,14 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ +#include #include +#include "runtime/sstream.h" #include "kernel/instantiate.h" #include "library/util.h" #include "library/compiler/util.h" namespace lean { -static expr * g_cases = nullptr; static name * g_cnstr = nullptr; static name * g_updt = nullptr; static name * g_updt_cidx = nullptr; @@ -18,19 +19,12 @@ static name * g_updt_u8 = nullptr; static name * g_updt_u16 = nullptr; static name * g_updt_u32 = nullptr; static name * g_updt_u64 = nullptr; +static name * g_proj = nullptr; static name * g_proj_u8 = nullptr; static name * g_proj_u16 = nullptr; static name * g_proj_u32 = nullptr; static name * g_proj_u64 = nullptr; -expr mk_llnf_cases() { - return *g_cases; -} - -bool is_llnf_cases(expr const & e) { - return e == *g_cases; -} - expr mk_llnf_cnstr(unsigned cidx, unsigned scalar_sz) { return mk_constant(name(name(*g_cnstr, cidx), scalar_sz)); } @@ -72,6 +66,9 @@ bool is_llnf_updt_u32(expr const & e, unsigned & offset) { return is_llnf_primit expr mk_llnf_updt_u64(unsigned offset) { return mk_constant(name(*g_updt_u64, offset)); } bool is_llnf_updt_u64(expr const & e, unsigned & offset) { return is_llnf_primitive(e, *g_updt_u64, offset); } +expr mk_llnf_proj(unsigned idx) { return mk_constant(name(*g_proj, idx)); } +bool is_llnf_proj(expr const & e, unsigned & idx) { return is_llnf_primitive(e, *g_proj, idx); } + expr mk_llnf_proj_u8(unsigned offset) { return mk_constant(name(*g_proj_u8, offset)); } bool is_llnf_proj_u8(expr const & e, unsigned & offset) { return is_llnf_primitive(e, *g_proj_u8, offset); } @@ -84,6 +81,10 @@ bool is_llnf_proj_u32(expr const & e, unsigned & offset) { return is_llnf_primit expr mk_llnf_proj_u64(unsigned offset) { return mk_constant(name(*g_proj_u64, offset)); } bool is_llnf_proj_u64(expr const & e, unsigned & offset) { return is_llnf_primitive(e, *g_proj_u64, offset); } +[[ noreturn ]] static void throw_unsupported_field_size() { + throw exception("code generation failed, unsupported field size"); +} + struct field_info { enum kind { Irrelevant, Object, Scalar }; kind m_kind; @@ -97,14 +98,34 @@ struct field_info { field_info():m_kind(Irrelevant), m_idx(0) {} field_info(unsigned idx):m_kind(Object), m_idx(idx) {} field_info(unsigned offset, unsigned sz):m_kind(Scalar), m_offset(offset), m_size(sz) {} + expr get_type() const { + if (m_kind == Scalar) { + switch (m_size) { + case 1: return mk_constant(get_uint8_name()); + case 2: return mk_constant(get_uint16_name()); + case 4: return mk_constant(get_uint32_name()); + case 8: return mk_constant(get_uint64_name()); + default: throw_unsupported_field_size(); + } + } else { + return mk_enf_object_type(); + } + } }; struct cnstr_info { unsigned m_cidx; list m_field_info; + unsigned m_num_objs; unsigned m_scalar_sz; - cnstr_info(unsigned cidx, list const & finfo, unsigned scalar_sz): - m_cidx(cidx), m_field_info(finfo), m_scalar_sz(scalar_sz) { + cnstr_info(unsigned cidx, list const & finfo): + m_cidx(cidx), m_field_info(finfo), m_num_objs(0), m_scalar_sz(0) { + for (field_info const & info : finfo) { + if (info.m_kind == field_info::Object) + m_num_objs++; + else if (info.m_kind == field_info::Scalar) + m_scalar_sz += info.m_size; + } } }; @@ -141,10 +162,10 @@ class to_llnf_fn { optional is_enum_type_core(name const & I) { constant_info info = env().get(I); - if (info.is_inductive()) return optional(); + if (!info.is_inductive()) return optional(); unsigned n = 0; for (name const & c : info.to_inductive_val().get_cnstrs()) { - if (!empty(get_cnstr_info(c).m_field_info)) + if (is_pi(env().get(c).get_type())) return optional(); if (n == std::numeric_limits::max()) return optional(); @@ -170,7 +191,7 @@ class to_llnf_fn { return r; } - unsigned get_cnstr_info_core(name const & n, buffer & result) { + void get_cnstr_info_core(name const & n, buffer & result) { constant_info info = env().get(n); lean_assert(info.is_constructor()); constructor_val val = info.to_constructor_val(); @@ -206,7 +227,6 @@ class to_llnf_fn { } } unsigned nobjs = next_idx; - unsigned scalar_sz = next_offset; if (m_unboxed) { /* Remark: scalar data is stored after object pointers */ for (field_info & info : result) { @@ -215,7 +235,6 @@ class to_llnf_fn { } } } - return scalar_sz; } cnstr_info get_cnstr_info(name const & n) { @@ -223,9 +242,9 @@ class to_llnf_fn { if (it != m_cnstr_info_cache.end()) return it->second; buffer finfos; - unsigned scalar_sz = get_cnstr_info_core(n, finfos); + get_cnstr_info_core(n, finfos); unsigned cidx = get_constructor_idx(env(), n); - cnstr_info r(cidx, to_list(finfos), scalar_sz); + cnstr_info r(cidx, to_list(finfos)); m_cnstr_info_cache.insert(mk_pair(n, r)); return r; } @@ -260,6 +279,7 @@ class to_llnf_fn { collect_used(val, used_fvars); used.push_back(x); } + std::reverse(used.begin(), used.end()); return m_lctx.mk_lambda(used, r); } @@ -302,19 +322,170 @@ class to_llnf_fn { return m_lctx.mk_lambda(binding_fvars, r); } + expr mk_let_decl(expr const & type, expr const & e) { + expr fvar = m_lctx.mk_local_decl(ngen(), next_name(), type, e); + m_fvars.push_back(fvar); + return fvar; + } + + expr mk_scalar_proj(expr const & major, unsigned size, unsigned offset) { + switch (size) { + case 1: + return mk_app(mk_llnf_proj_u8(offset), major); + case 2: + return mk_app(mk_llnf_proj_u16(offset), major); + case 4: + return mk_app(mk_llnf_proj_u32(offset), major); + case 8: + return mk_app(mk_llnf_proj_u64(offset), major); + default: + throw_unsupported_field_size(); + } + } + + expr mk_scalar_updt(expr const & major, unsigned size, unsigned offset, expr const & v) { + switch (size) { + case 1: + return mk_app(mk_llnf_updt_u8(offset), major, v); + case 2: + return mk_app(mk_llnf_updt_u16(offset), major, v); + case 4: + return mk_app(mk_llnf_updt_u32(offset), major, v); + case 8: + return mk_app(mk_llnf_updt_u64(offset), major, v); + default: + throw_unsupported_field_size(); + } + } + expr visit_cases(expr const & e) { - // TODO(Leo): - return e; + buffer args; + expr const & fn = get_app_args(e, args); + lean_assert(is_constant(fn)); + name const & I_name = const_name(fn).get_prefix(); + if (is_inductive_predicate(env(), I_name)) + throw exception(sstream() << "code generation failed, inductive predicate '" << I_name << "' is not supported"); + buffer cnames; + get_constructor_names(env(), I_name, cnames); + lean_assert(args.size() == cnames.size() + 1); + /* Process major premise */ + expr major = visit(args[0]); + args[0] = major; + expr reachable_case; + unsigned num_reachable = 0; + expr some_reachable; + /* Process minor premises */ + for (unsigned i = 0; i < cnames.size(); i++) { + unsigned saved_fvars_size = m_fvars.size(); + expr minor = args[i+1]; + cnstr_info cinfo = get_cnstr_info(cnames[i]); + unsigned next_idx = 0; + unsigned next_offset = cinfo.m_num_objs * sizeof(void*); + buffer fields; + for (field_info const & info : cinfo.m_field_info) { + lean_assert(is_lambda(minor)); + switch (info.m_kind) { + case field_info::Irrelevant: + fields.push_back(mk_enf_neutral()); + break; + case field_info::Object: + fields.push_back(mk_let_decl(mk_enf_object_type(), mk_app(mk_llnf_proj(next_idx), major))); + next_idx++; + break; + case field_info::Scalar: + fields.push_back(mk_let_decl(binding_domain(minor), mk_scalar_proj(major, info.m_size, next_offset))); + next_offset += info.m_size; + break; + } + minor = binding_body(minor); + } + minor = instantiate_rev(minor, fields.size(), fields.data()); + minor = visit(minor); + if (!is_enf_unreachable(minor)) { + num_reachable++; + minor = mk_let(saved_fvars_size, minor); + some_reachable = minor; + args[i+1] = minor; + } else { + args[i+1] = minor; + } + } + /* TODO(Leo): check whether all reachable cases are equal or not. */ + if (num_reachable == 0) { + return mk_enf_unreachable(); + } else if (num_reachable == 1) { + return some_reachable; + } else { + return mk_app(fn, args); + } } expr visit_constructor(expr const & e) { - // TODO(Leo): - return e; + buffer args; + expr const & k = get_app_args(e, args); + lean_assert(is_constant(k)); + constructor_val k_val = env().get(const_name(k)).to_constructor_val(); + cnstr_info k_info = get_cnstr_info(const_name(k)); + unsigned nparams = k_val.get_nparams(); + unsigned cidx = k_info.m_cidx; + buffer obj_args; + unsigned j = nparams; + for (field_info const & info : k_info.m_field_info) { + if (info.m_kind != field_info::Irrelevant) + args[j] = visit(args[j]); + + if (info.m_kind == field_info::Object) { + obj_args.push_back(args[j]); + } + j++; + } + expr r = mk_app(mk_llnf_cnstr(cidx, k_info.m_scalar_sz), obj_args); + j = nparams; + unsigned offset = k_info.m_num_objs * sizeof(void*); + bool first = true; + for (field_info const & info : k_info.m_field_info) { + if (info.m_kind == field_info::Scalar) { + if (first && obj_args.size() > 0) { + r = mk_let_decl(mk_enf_object_type(), r); + } + r = mk_let_decl(info.get_type(), mk_scalar_updt(r, info.m_size, offset, args[j])); + offset += info.m_size; + first = false; + } + j++; + } + return r; } expr visit_proj(expr const & e) { - // TODO(Leo): - return e; + name S_name = proj_sname(e); + inductive_val S_val = env().get(S_name).to_inductive_val(); + lean_assert(S_val.get_ncnstrs() == 1); + name k_name = head(S_val.get_cnstrs()); + cnstr_info k_info = get_cnstr_info(k_name); + unsigned idx = 0; + unsigned offset = k_info.m_num_objs * sizeof(void*); + unsigned i = 0; + for (field_info const & info : k_info.m_field_info) { + switch (info.m_kind) { + case field_info::Irrelevant: + if (proj_idx(e) == i) + return mk_enf_neutral(); + break; + case field_info::Object: + if (proj_idx(e) == i) + return mk_app(mk_llnf_proj(idx), visit(proj_expr(e))); + idx++; + break; + case field_info::Scalar: + if (proj_idx(e) == i) + return mk_scalar_proj(visit(proj_expr(e)), info.m_size, offset); + offset += info.m_size; + break; + } + i++; + } + lean_unreachable(); } expr visit_constant(expr const & e) { @@ -361,7 +532,6 @@ expr to_llnf(environment const & env, expr const & e, bool unboxed) { } void initialize_llnf() { - g_cases = new expr(mk_constant("_cases")); g_cnstr = new name("_cnstr"); g_updt = new name("_updt"); g_updt_cidx = new name("_updt_cidx"); @@ -369,6 +539,7 @@ void initialize_llnf() { g_updt_u16 = new name("_updt_u16"); g_updt_u32 = new name("_updt_u32"); g_updt_u64 = new name("_updt_u64"); + g_proj = new name("_proj"); g_proj_u8 = new name("_proj_u8"); g_proj_u16 = new name("_proj_u16"); g_proj_u32 = new name("_proj_u32"); @@ -381,13 +552,13 @@ void initialize_llnf() { } void finalize_llnf() { - delete g_cases; delete g_cnstr; delete g_updt; delete g_updt_u8; delete g_updt_u16; delete g_updt_u32; delete g_updt_u64; + delete g_proj; delete g_proj_u8; delete g_proj_u16; delete g_proj_u32; diff --git a/src/library/compiler/llnf.h b/src/library/compiler/llnf.h index 0fbac66996..5822bfeef0 100644 --- a/src/library/compiler/llnf.h +++ b/src/library/compiler/llnf.h @@ -11,6 +11,10 @@ namespace lean { /* Convert expression to Low Level Normal Form (LLNF). This is the last normal form before converting to the IR. */ expr to_llnf(environment const & env, expr const & e, bool unboxed_data = false); + +bool is_llnf_cnstr(expr const & e, unsigned & cidx, unsigned & ssz); +bool is_llnf_proj(expr const & e, unsigned & idx); + void initialize_llnf(); void finalize_llnf(); }