/*
Copyright (c) 2019 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.

Author: Leonardo de Moura
*/
#include "runtime/flet.h"
#include "kernel/instantiate.h"
#include "kernel/abstract.h"
#include "kernel/type_checker.h"
#include "library/trace.h"
#include "library/suffixes.h"
#include "library/compiler/util.h"

namespace lean {
/*
Functional object that inserts `S.casesOn` applications in front of let-declarations.

A `casesOn s (fun fields..., ...)` wrapper is placed around a let-declaration `y := s.i` when:
  - `s` is a free variable that is not already being scrutinized by an enclosing `casesOn`,
  - `y := s.i` is the first projection of `s` recorded in the enclosing let-block(s), and
  - some constructor application visited so far has `s.j` (possibly through let-variables
    and `mdata` wrappers) as the argument at field position `j`.
*/
class struct_cases_on_fn {
    type_checker::state m_st;
    local_ctx           m_lctx;
    name_set            m_scrutinies; /* Set of variables `x` such that there is `casesOn x ...` in the context */
    name_map<name>      m_first_proj; /* Map from variable `x` to the first projection `y := x.i` in the context */
    name_set            m_updated;    /* Set of variables `x` such that there is a `S.mk ... x.i ...` */
    name                m_fld{"_d"};  /* Base name used for the binders of the generated minor premise */
    unsigned            m_next_idx{1};

    environment const & env() { return m_st.env(); }

    name_generator & ngen() { return m_st.ngen(); }

    /* Return a binder name derived from `m_fld` and a counter that is incremented on every call. */
    name next_field_name() {
        name r = m_fld.append_after(m_next_idx);
        m_next_idx++;
        return r;
    }

    /* Follow let-variable values and `mdata` wrappers starting at `e`.
       A free variable is replaced by its value only if it has one in `m_lctx` and its user name
       is not a join point name. Return the first expression that cannot be unfolded further. */
    expr find(expr const & e) const {
        if (is_fvar(e)) {
            if (optional<local_decl> decl = m_lctx.find_local_decl(e)) {
                if (optional<expr> v = decl->get_value()) {
                    if (!is_join_point_name(decl->get_user_name()))
                        return find(*v);
                }
            }
        } else if (is_mdata(e)) {
            return find(mdata_expr(e));
        }
        return e;
    }

    /* Visit a `casesOn` application: if the major premise (argument 0) is a free variable, record
       it in `m_scrutinies` while the remaining arguments are visited. The `flet` guard scopes the
       change to this call (it is expected to restore the previous set on exit). */
    expr visit_cases(expr const & e) {
        flet<name_set> save(m_scrutinies, m_scrutinies);
        buffer<expr> args;
        expr const & c = get_app_args(e, args);
        expr const & major = args[0];
        if (is_fvar(major))
            m_scrutinies.insert(fvar_name(major));
        for (unsigned i = 1; i < args.size(); i++) {
            args[i] = visit(args[i]);
        }
        return mk_app(c, args);
    }

    /* Visit an application.
       - `casesOn` applications are handled by `visit_cases`.
       - For constructor applications, every argument after the inductive type parameters is
         inspected: if the argument at field position `idx` resolves (via `find`) to `x.idx` with
         `x` a free variable, then `x` is recorded in `m_updated`. The application itself is
         returned unchanged.
       - Any other application is returned unchanged. */
    expr visit_app(expr const & e) {
        if (is_cases_on_app(env(), e)) {
            return visit_cases(e);
        } else if (is_constructor_app(env(), e)) {
            buffer<expr> args;
            expr const & k = get_app_args(e, args);
            lean_assert(is_constant(k));
            constructor_val k_val = env().get(const_name(k)).to_constructor_val();
            /* `i` ranges over the argument positions, `idx` over the corresponding field positions. */
            for (unsigned i = k_val.get_nparams(), idx = 0; i < args.size(); i++, idx++) {
                expr arg = find(args[i]);
                if (is_proj(arg) && proj_idx(arg) == idx && is_fvar(proj_expr(arg))) {
                    m_updated.insert(fvar_name(proj_expr(arg)));
                }
            }
            return e;
        } else {
            return e;
        }
    }

    /* Visit a (possibly nested) lambda: introduce one local declaration per binder, visit the
       instantiated body, and rebuild the lambda with `m_lctx.mk_lambda`. */
    expr visit_lambda(expr e) {
        buffer<expr> fvars;
        while (is_lambda(e)) {
            lean_assert(!has_loose_bvars(binding_domain(e)));
            expr new_fvar = m_lctx.mk_local_decl(ngen(), binding_name(e), binding_domain(e), binding_info(e));
            fvars.push_back(new_fvar);
            e = binding_body(e);
        }
        e = instantiate_rev(e, fvars.size(), fvars.data());
        e = visit(e);
        return m_lctx.mk_lambda(fvars, e);
    }

    /* Return `some s` if `rhs` is of the form `s.i`, and `s` is a free variable that
       has not been scrutinized yet, and `s.i` is the first time it is being projected. */
    optional<name> is_candidate(expr const & rhs) {
        if (!is_proj(rhs)) return optional<name>();
        expr const & s = proj_expr(rhs);
        if (!is_fvar(s)) return optional<name>();
        name const & s_name = fvar_name(s);
        if (m_scrutinies.contains(s_name)) return optional<name>();
        if (m_first_proj.contains(s_name)) return optional<name>();
        return optional<name>(s_name);
    }

    /* Append to `result` one type per field of the single-constructor inductive type `S_name`
       (the constructor arguments that follow the inductive type parameters):
       - the neutral type for fields whose type is irrelevant,
       - the weak-head-normalized field type itself for `usize` and builtin scalar types,
       - the unsigned integer type returned by `to_uint_type` for enumeration types,
       - the object type for everything else.
       Throws an exception if an enumeration type has no `to_uint_type` representation. */
    static void get_struct_field_types(type_checker::state & st, name const & S_name, buffer<expr> & result) {
        environment const & env = st.env();
        constant_info info      = env.get(S_name);
        lean_assert(info.is_inductive());
        inductive_val I_val     = info.to_inductive_val();
        lean_assert(length(I_val.get_cnstrs()) == 1);
        constant_info ctor_info = env.get(head(I_val.get_cnstrs()));
        expr type               = ctor_info.get_type();
        unsigned nparams        = I_val.get_nparams();
        local_ctx lctx;
        buffer<expr> telescope;
        to_telescope(env, lctx, st.ngen(), type, telescope);
        lean_assert(telescope.size() >= nparams);
        for (unsigned i = nparams; i < telescope.size(); i++) {
            expr ftype = lctx.get_type(telescope[i]);
            if (is_irrelevant_type(st, lctx, ftype)) {
                result.push_back(mk_enf_neutral_type());
            } else {
                type_checker tc(st, lctx);
                ftype = tc.whnf(ftype);
                if (is_usize_type(ftype)) {
                    result.push_back(ftype);
                } else if (is_builtin_scalar(ftype)) {
                    result.push_back(ftype);
                } else if (optional<unsigned> sz = is_enum_type(env, ftype)) {
                    optional<expr> uint = to_uint_type(*sz);
                    if (!uint) throw exception("code generation failed, enumeration type is too big");
                    result.push_back(*uint);
                } else {
                    result.push_back(mk_enf_object_type());
                }
            }
        }
    }

    /* Return true iff `decl` is a let-declaration `y := s.i` such that `s` is a free variable in
       `m_updated` and `y` is the projection recorded for `s` in `m_first_proj`.
       `decl` must have a value. */
    bool should_add_cases_on(local_decl const & decl) {
        expr val = *decl.get_value();
        if (!is_proj(val))
            return false;
        expr const & s = proj_expr(val);
        if (!is_fvar(s) || !m_updated.contains(fvar_name(s)))
            return false;
        name const * x = m_first_proj.find(fvar_name(s));
        return x && *x == decl.get_name();
    }

    /* Visit a block of nested let-declarations.
       1. Every declaration is added to `m_lctx`, and candidate first projections (see
          `is_candidate`) are recorded in `m_first_proj`, whose changes are scoped to this call
          by the `flet` guard.
       2. The body is visited, which may add entries to `m_updated`.
       3. The let-block is rebuilt from the innermost declaration outwards. Whenever
          `should_add_cases_on` holds for a declaration `y := s.i`, the term built so far
          (starting at `let y := s.i; ...`) becomes the body of a lambda with one binder per
          structure field, and that lambda is the minor premise of `S.casesOn s`. */
    expr visit_let(expr e) {
        flet<name_map<name>> save(m_first_proj, m_first_proj);
        buffer<expr> fvars;
        while (is_let(e)) {
            lean_assert(!has_loose_bvars(let_type(e)));
            expr type     = let_type(e);
            expr val      = instantiate_rev(let_value(e), fvars.size(), fvars.data());
            name n        = let_name(e);
            e             = let_body(e);
            expr new_fvar = m_lctx.mk_local_decl(ngen(), n, type, val);
            fvars.push_back(new_fvar);
            if (optional<name> s = is_candidate(val)) {
                m_first_proj.insert(*s, fvar_name(new_fvar));
            }
        }
        e = visit(instantiate_rev(e, fvars.size(), fvars.data()));
        e = abstract(e, fvars.size(), fvars.data());
        unsigned i = fvars.size();
        while (i > 0) {
            --i;
            expr const & x = fvars[i];
            lean_assert(is_fvar(x));
            local_decl decl = m_lctx.get_local_decl(x);
            expr type       = decl.get_type();
            expr val        = *decl.get_value();
            /* Only the `i` declarations preceding `x` can occur in its value. */
            expr aval       = abstract(val, i, fvars.data());
            e               = mk_let(decl.get_user_name(), type, aval, e);
            if (should_add_cases_on(decl)) {
                lean_assert(is_proj(val));
                expr major = proj_expr(val);
                buffer<expr> field_types;
                get_struct_field_types(m_st, proj_sname(val), field_types);
                /* Make room for the new binders that are going to be placed around `e`. */
                e = lift_loose_bvars(e, field_types.size());
                /* Add the binders from the last field to the first one, so that the first
                   field ends up as the outermost binder. */
                unsigned j = field_types.size();
                while (j > 0) {
                    --j;
                    e = mk_lambda(next_field_name(), field_types[j], e);
                }
                e = mk_app(mk_constant(name(proj_sname(val), g_cases_on)), major, e);
            }
        }
        return e;
    }

    /* Dispatch on the kind of `e`. Only applications, lambdas and let-expressions are
       processed; every other kind of expression is returned unchanged. */
    expr visit(expr const & e) {
        switch (e.kind()) {
        case expr_kind::App:    return visit_app(e);
        case expr_kind::Lambda: return visit_lambda(e);
        case expr_kind::Let:    return visit_let(e);
        default:                return e;
        }
    }

public:
    struct_cases_on_fn(environment const & env):
        m_st(env) {
    }

    /* Apply the transformation to `e` and return the resulting expression. */
    expr operator()(expr const & e) {
        return visit(e);
    }
};

/* Entry point: run `struct_cases_on_fn` on `e` using a fresh state created from `env`. */
expr struct_cases_on(environment const & env, expr const & e) {
    return struct_cases_on_fn(env)(e);
}
}