diff --git a/src/library/equations_compiler/compiler.cpp b/src/library/equations_compiler/compiler.cpp index 63be719aa6..34ddd77086 100644 --- a/src/library/equations_compiler/compiler.cpp +++ b/src/library/equations_compiler/compiler.cpp @@ -13,6 +13,7 @@ Author: Leonardo de Moura #include "library/locals.h" #include "library/util.h" #include "library/replace_visitor.h" +#include "library/replace_visitor_with_tc.h" #include "library/equations_compiler/compiler.h" #include "library/equations_compiler/util.h" #include "library/equations_compiler/structural_rec.h" @@ -320,9 +321,35 @@ static expr remove_aux_main_name(expr const & e) { return e; } +struct eta_expand_rec_apps_fn : public replace_visitor_with_tc { + eta_expand_rec_apps_fn(type_context_old & ctx): replace_visitor_with_tc(ctx) {} + + virtual expr visit_local(expr const & e) { + if (is_rec(local_info(e))) { + expr e2 = m_ctx.eta_expand(e); + lean_assert(!is_local(e2)); + return visit(e2); + } + return e; + } + + virtual expr visit_app(expr const & e) { + expr const & fn = app_fn(e); + if (is_local(fn) && is_rec(local_info(fn))) { + // do not eta-expand `fn` + expr arg = visit(app_arg(e)); + return mk_app(fn, arg); + } else { + return replace_visitor::visit_app(e); + } + } +}; + static expr compile_equations_main(environment & env, elaborator & elab, metavar_context & mctx, local_context const & lctx, expr const & _eqns, bool report_cexs) { - expr eqns = _eqns; + // all following code assumes that all recursive occurrences are applications + type_context_old ctx(env, mctx, lctx, elab.get_cache(), transparency_mode::Semireducible); + expr eqns = eta_expand_rec_apps_fn(ctx)(_eqns); equations_header const & header = get_equations_header(eqns); eqn_compiler_result r; if (!header.m_is_meta && has_nested_rec(eqns)) { diff --git a/src/library/equations_compiler/wf_rec.cpp b/src/library/equations_compiler/wf_rec.cpp index 863c3ea298..af023d5d0c 100644 --- a/src/library/equations_compiler/wf_rec.cpp +++ b/src/library/equations_compiler/wf_rec.cpp @@ -167,9 +167,8 @@ struct wf_rec_fn { virtual expr visit_local(expr const & e) { if (local_name(e) == local_name(m_fn)) { - expr e2 = m_ctx.eta_expand(e); - lean_assert(!is_local(e2)); - return visit(e2); + /* unexpected occurrence of recursive function */ + throw generic_exception(e, "unexpected occurrence of recursive function\n"); } return e; }