diff --git a/src/compiler/simp_pr1_rec.cpp b/src/compiler/simp_pr1_rec.cpp index 670b21cd06..46b0983df3 100644 --- a/src/compiler/simp_pr1_rec.cpp +++ b/src/compiler/simp_pr1_rec.cpp @@ -30,11 +30,17 @@ class simp_pr1_rec_fn : public compiler_step_visitor { } bool is_rec_arg(expr const & e) { - if (!is_local(e)) + buffer e_args; + expr const & fn = get_app_args(e, e_args); + if (!is_local(fn)) return false; for (unsigned i = 0; i < minor_ctx.size(); i++) { - if (minor_is_rec_arg[i] && mlocal_name(minor_ctx[i]) == mlocal_name(e)) + if (minor_is_rec_arg[i] && mlocal_name(minor_ctx[i]) == mlocal_name(fn)) { + /* make sure arguments contain only valid occurrences */ + for (expr const & e_arg : e_args) + visit(e_arg); return true; + } } return false; } @@ -54,8 +60,9 @@ class simp_pr1_rec_fn : public compiler_step_visitor { } virtual expr visit_local(expr const & e) { - if (is_rec_arg(e)) + if (is_rec_arg(e)) { throw failed(); + } return replace_visitor::visit_local(e); } }; @@ -125,19 +132,27 @@ class simp_pr1_rec_fn : public compiler_step_visitor { for (unsigned k = 0; k < minor_ctx.size(); k++) { if (minor_is_rec_arg[k]) { expr type = ctx().whnf(ctx().infer(minor_ctx[k])); + type_context::tmp_locals locals(ctx()); + while (is_pi(type)) { + expr l = locals.push_local(binding_name(type), binding_domain(type), binding_info(type)); + type = instantiate(binding_body(type), l); + type = ctx().whnf(type); + } buffer type_args; expr type_fn = get_app_args(type, type_args); if (!is_constant(type_fn) || const_name(type_fn) != get_prod_name() || type_args.size() != 2) { return none_expr(); } - minor_ctx[k] = update_mlocal(minor_ctx[k], type_args[0]); + minor_ctx[k] = update_mlocal(minor_ctx[k], locals.mk_pi(type_args[0])); } } // Step 2 buffer minor_body_args; expr minor_body_fn = get_app_args(minor_body, minor_body_args); - if (!is_constant(minor_body_fn) || const_name(minor_body_fn) != get_prod_mk_name() || minor_body_args.size() != 4) { + if (!is_constant(minor_body_fn) || + const_name(minor_body_fn) != get_prod_mk_name() || + minor_body_args.size() != 4) { return none_expr(); } minor_body = minor_body_args[2];