diff --git a/src/kernel/inductive.cpp b/src/kernel/inductive.cpp index a5acfc0219..f041479005 100644 --- a/src/kernel/inductive.cpp +++ b/src/kernel/inductive.cpp @@ -92,6 +92,8 @@ public: expr mk_pi(buffer const & fvars, expr const & e) const { return m_lctx.mk_pi(fvars, e); } expr mk_pi(expr const & fvar, expr const & e) const { return m_lctx.mk_pi(1, &fvar, e); } + expr mk_lambda(buffer const & fvars, expr const & e) const { return m_lctx.mk_lambda(fvars, e); } + expr mk_lambda(expr const & fvar, expr const & e) const { return m_lctx.mk_lambda(1, &fvar, e); } /** \brief Check whether the type of each datatype is well typed, and do not contain free variables or meta variables, @@ -546,13 +548,60 @@ public: ms.append(m_rec_infos[i].m_minors); } + recursor_rules mk_rec_rules(unsigned d_idx, buffer const & Cs, buffer const & minors, unsigned & minor_idx) { + inductive_type const & d = m_ind_types[d_idx]; + levels lvls = get_rec_level_params(); + buffer rules; + for (constructor const & cnstr : d.get_cnstrs()) { + buffer b_u; + buffer u; + expr t = constructor_type(cnstr); + unsigned i = 0; + while (is_pi(t)) { + if (i < m_nparams) { + t = instantiate(binding_body(t), m_params[i]); + } else { + expr l = mk_local_decl_for(t); + b_u.push_back(l); + if (is_rec_argument(binding_domain(t))) + u.push_back(l); + t = instantiate(binding_body(t), l); + } + i++; + } + buffer v; + for (unsigned i = 0; i < u.size(); i++) { + expr u_i = u[i]; + expr u_i_ty = whnf(infer_type(u_i)); + buffer xs; + while (is_pi(u_i_ty)) { + expr x = mk_local_decl_for(u_i_ty); + xs.push_back(x); + u_i_ty = whnf(instantiate(binding_body(u_i_ty), x)); + } + buffer it_indices; + unsigned it_idx = get_I_indices(u_i_ty, it_indices); + name rec_name = mk_rec_name(m_ind_types[it_idx].get_name()); + expr rec_app = mk_constant(rec_name, lvls); + rec_app = mk_app(mk_app(mk_app(mk_app(mk_app(rec_app, m_params), Cs), minors), it_indices), mk_app(u_i, xs)); + v.push_back(mk_lambda(xs, rec_app)); + } + expr e_app = mk_app(mk_app(minors[minor_idx], b_u), v); + expr comp_rhs = mk_lambda(m_params, mk_lambda(Cs, mk_lambda(minors, mk_lambda(b_u, e_app)))); + rules.push_back(recursor_rule(constructor_name(cnstr), b_u.size(), comp_rhs)); + minor_idx++; + } + return recursor_rules(rules); + } + /** \brief Declare recursors. */ void declare_recursors() { buffer Cs; collect_Cs(Cs); buffer minors; collect_minor_premises(minors); - unsigned nminors = minors.size(); - unsigned nmotives = Cs.size(); - names all = get_all_names(); + unsigned nminors = minors.size(); + unsigned nmotives = Cs.size(); + names all = get_all_names(); + unsigned minor_idx = 0; for (unsigned d_idx = 0; d_idx < m_ind_types.size(); d_idx++) { rec_info const & info = m_rec_infos[d_idx]; expr C_app = mk_app(mk_app(info.m_C, info.m_indices), info.m_major); @@ -562,12 +611,9 @@ public: rec_ty = mk_pi(Cs, rec_ty); rec_ty = mk_pi(m_params, rec_ty); rec_ty = infer_implicit(rec_ty, true /* strict */); - /* - TODO(Leo): gen reduction rule - */ - recursor_rules rules; - name rec_name = mk_rec_name(m_ind_types[d_idx].get_name()); - names rec_lparams = get_rec_level_param_names(); + recursor_rules rules = mk_rec_rules(d_idx, Cs, minors, minor_idx); + name rec_name = mk_rec_name(m_ind_types[d_idx].get_name()); + names rec_lparams = get_rec_level_param_names(); m_env.add_core(constant_info(recursor_val(rec_name, rec_lparams, rec_ty, all, m_nparams, m_nindices[d_idx], nmotives, nminors, rules, m_K_target, m_is_meta)));