diff --git a/src/kernel/type_checker.cpp b/src/kernel/type_checker.cpp index cec32fa2ec..bee563c993 100644 --- a/src/kernel/type_checker.cpp +++ b/src/kernel/type_checker.cpp @@ -407,6 +407,20 @@ public: return is_convertible(t1, t2, ctx, mk_justification); } + bool is_eq_convertible(expr const & t1, expr const & t2, context const & ctx) { + set_ctx(ctx); + update_menv(none_menv()); + if (t1 == t2) + return true; + expr new_t1 = normalize(t1, ctx, false); + expr new_t2 = normalize(t2, ctx, false); + if (new_t1 == new_t2) + return true; + new_t1 = normalize(new_t1, ctx, true); + new_t2 = normalize(new_t2, ctx, true); + return new_t1 == new_t2; + } + void check_type(expr const & e, context const & ctx) { set_ctx(ctx); update_menv(none_menv()); @@ -486,6 +500,9 @@ expr type_checker::check(expr const & e, context const & ctx) { bool type_checker::is_convertible(expr const & t1, expr const & t2, context const & ctx) { return m_ptr->is_convertible(t1, t2, ctx); } +bool type_checker::is_eq_convertible(expr const & t1, expr const & t2, context const & ctx) { + return m_ptr->is_eq_convertible(t1, t2, ctx); +} void type_checker::check_type(expr const & e, context const & ctx) { m_ptr->check_type(e, ctx); } diff --git a/src/kernel/type_checker.h b/src/kernel/type_checker.h index dec926739b..2edb24f746 100644 --- a/src/kernel/type_checker.h +++ b/src/kernel/type_checker.h @@ -86,6 +86,13 @@ public: /** \brief Return true iff \c t1 is convertible to \c t2 in the context \c ctx. */ bool is_convertible(expr const & t1, expr const & t2, context const & ctx = context()); + /** \brief Return true iff \c t1 is convertible to \c t2 in the context \c ctx, but does not consider + universe commutativity. + + \remark is_eq_convertible(t1, t2, ctx) implies is_convertible(t1, t2, ctx) + */ + bool is_eq_convertible(expr const & t1, expr const & t2, context const & ctx = context()); + /** \brief Return true iff \c e is a proposition (i.e., it has type Bool) */ bool is_proposition(expr const & e, context const & ctx, optional const & menv); bool is_proposition(expr const & e, context const & ctx, metavar_env const & menv); diff --git a/src/library/simplifier/simplifier.cpp b/src/library/simplifier/simplifier.cpp index 82a30d3254..8ed34e7e38 100644 --- a/src/library/simplifier/simplifier.cpp +++ b/src/library/simplifier/simplifier.cpp @@ -34,6 +34,10 @@ Author: Leonardo de Moura #define LEAN_SIMPLIFIER_BETA true #endif +#ifndef LEAN_SIMPLIFIER_ETA +#define LEAN_SIMPLIFIER_ETA true +#endif + #ifndef LEAN_SIMPLIFIER_UNFOLD #define LEAN_SIMPLIFIER_UNFOLD false #endif @@ -51,6 +55,7 @@ static name g_simplifier_proofs {"simplifier", "proofs"}; static name g_simplifier_contextual {"simplifier", "contextual"}; static name g_simplifier_single_pass {"simplifier", "single_pass"}; static name g_simplifier_beta {"simplifier", "beta"}; +static name g_simplifier_eta {"simplifier", "eta"}; static name g_simplifier_unfold {"simplifier", "unfold"}; static name g_simplifier_conditional {"simplifier", "conditional"}; static name g_simplifier_max_steps {"simplifier", "max_steps"}; @@ -58,7 +63,8 @@ static name g_simplifier_max_steps {"simplifier", "max_steps"}; RegisterBoolOption(g_simplifier_proofs, LEAN_SIMPLIFIER_PROOFS, "(simplifier) generate proofs"); RegisterBoolOption(g_simplifier_contextual, LEAN_SIMPLIFIER_CONTEXTUAL, "(simplifier) contextual simplification"); RegisterBoolOption(g_simplifier_single_pass, LEAN_SIMPLIFIER_SINGLE_PASS, "(simplifier) if false then the simplifier keeps applying simplifications as long as possible"); -RegisterBoolOption(g_simplifier_beta, LEAN_SIMPLIFIER_BETA, "(simplifier) use beta-reductions"); +RegisterBoolOption(g_simplifier_beta, LEAN_SIMPLIFIER_BETA, "(simplifier) use beta-reduction"); +RegisterBoolOption(g_simplifier_eta, LEAN_SIMPLIFIER_ETA, "(simplifier) use eta-reduction"); RegisterBoolOption(g_simplifier_unfold, LEAN_SIMPLIFIER_UNFOLD, "(simplifier) unfolds non-opaque definitions"); RegisterBoolOption(g_simplifier_conditional, LEAN_SIMPLIFIER_CONDITIONAL, "(simplifier) conditional rewriting"); RegisterUnsignedOption(g_simplifier_max_steps, LEAN_SIMPLIFIER_MAX_STEPS, "(simplifier) maximum number of steps"); @@ -75,6 +81,9 @@ bool get_simplifier_single_pass(options const & opts) { bool get_simplifier_beta(options const & opts) { return opts.get_bool(g_simplifier_beta, LEAN_SIMPLIFIER_BETA); } +bool get_simplifier_eta(options const & opts) { + return opts.get_bool(g_simplifier_eta, LEAN_SIMPLIFIER_ETA); +} bool get_simplifier_unfold(options const & opts) { return opts.get_bool(g_simplifier_unfold, LEAN_SIMPLIFIER_UNFOLD); } @@ -98,6 +107,7 @@ class simplifier_fn { bool m_contextual; bool m_single_pass; bool m_beta; + bool m_eta; bool m_unfold; bool m_conditional; unsigned m_max_steps; @@ -460,9 +470,41 @@ class simplifier_fn { return rewrite(e, result(e)); } - result rewrite_lambda(expr const & e, result const & r) { + bool is_eta_target(expr const & e) const { + if (is_lambda(e)) { + expr b = abst_body(e); + return + is_app(b) && is_var(arg(b, num_args(b) - 1), 0) && + std::all_of(begin_args(b), end_args(b) - 1, [](expr const & a) { return !has_free_var(a, 0); }); + } else { + return false; + } + } - rewrite(e, r); + result rewrite_lambda(expr const & e, result const & r) { + lean_assert(is_lambda(r.m_out)); + if (m_eta && is_eta_target(r.m_out)) { + expr b = abst_body(r.m_out); + expr new_rhs; + if (num_args(b) > 2) { + new_rhs = mk_app(num_args(b) - 1, &arg(b, 0)); + } else { + new_rhs = arg(b, 0); + } + new_rhs = lower_free_vars(new_rhs, 1, 1); + expr new_rhs_type = ensure_pi(infer_type(new_rhs)); + if (m_tc.is_eq_convertible(abst_domain(new_rhs_type), abst_domain(r.m_out), m_ctx)) { + if (m_proofs_enabled) { + expr new_proof = mk_eta_th(abst_domain(r.m_out), + mk_lambda(r.m_out, abst_body(new_rhs_type)), + new_rhs); + return rewrite(e, mk_trans_result(e, r, new_rhs, new_proof)); + } else { + return rewrite(e, result(new_rhs)); + } + } + } + return rewrite(e, r); } result simplify_lambda(expr const & e) { @@ -534,6 +576,7 @@ class simplifier_fn { m_contextual = get_simplifier_contextual(o); m_single_pass = get_simplifier_single_pass(o); m_beta = get_simplifier_beta(o); + m_eta = get_simplifier_eta(o); m_unfold = get_simplifier_unfold(o); m_conditional = get_simplifier_conditional(o); m_max_steps = get_simplifier_max_steps(o); diff --git a/tests/lua/simp1.lua b/tests/lua/simp1.lua index 57ce1293be..ec03a1d040 100644 --- a/tests/lua/simp1.lua +++ b/tests/lua/simp1.lua @@ -34,3 +34,9 @@ print(e) print(pr) local env = get_environment() print(env:type_check(pr)) + +e, pr = simplify(parse_lean('(fun x y, f x y) = f')) +print(e) +print(pr) +local env = get_environment() +print(env:type_check(pr))