From 7d56382baa4f7be1ebfd9bd2759220e465dd219d Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Wed, 7 Sep 2016 17:57:10 -0700 Subject: [PATCH] feat(library/equations_compiler/util): generate equation lemmas for equations using invertible functions --- src/library/equations_compiler/util.cpp | 98 ++++++++++++++++++++++++- tests/lean/run/pack_unpack1.lean | 11 +++ 2 files changed, 107 insertions(+), 2 deletions(-) diff --git a/src/library/equations_compiler/util.cpp b/src/library/equations_compiler/util.cpp index 00c1e5bdc2..5a3ce13831 100644 --- a/src/library/equations_compiler/util.cpp +++ b/src/library/equations_compiler/util.cpp @@ -11,9 +11,11 @@ Author: Leonardo de Moura #include "kernel/inductive/inductive.h" #include "library/util.h" #include "library/trace.h" +#include "library/app_builder.h" #include "library/private.h" #include "library/constants.h" #include "library/annotation.h" +#include "library/inverse.h" #include "library/replace_visitor.h" #include "library/aux_definition.h" #include "library/scope_pos_info_provider.h" @@ -299,6 +301,93 @@ static optional find_if_neg_hypothesis(type_context & ctx, expr const & c, return none_expr(); } + +/* + If `e` is of the form + + (@eq.rec B (f (g (f a))) C (h (g (f a))) (f a) (f_g_eq (f a)) + + such that + + f_g_eq : forall x, f (g x) = x + + and there is a lemma + + g_f_eq : forall x, g (f x) = x + + Return (h a) and a proof that (e = h a) + + The proof is of the form + + @eq.rec + A + a + (fun x : A, (forall H : f x = f a, @eq.rec B (f x) C (h x) (f a) H = h a)) + (fun H : f a = f a, eq.refl (h a)) + (g (f a)) + (eq.symm (g_f_eq a)) + (f_g_eq a) +*/ +static optional prove_eq_rec_invertible(type_context & ctx, expr const & e) { + buffer rec_args; + expr rec_fn = get_app_args(e, rec_args); + if (!is_constant(rec_fn, get_eq_rec_name()) || rec_args.size() != 6) return optional(); + expr B = rec_args[0]; + expr from = rec_args[1]; /* (f (g (f a))) */ + expr C = rec_args[2]; + expr minor = rec_args[3]; /* (h (g (f a))) */ + expr to = rec_args[4]; /* (f a) */ + expr major = rec_args[5]; /* (f_g_eq (f a)) */ + if (!is_app(from) || !is_app(minor)) return optional(); + if (!ctx.is_def_eq(app_arg(from), app_arg(minor))) return optional(); + expr h = app_fn(minor); + expr g_f_a = app_arg(from); + if (!is_app(g_f_a) || !ctx.is_def_eq(app_arg(g_f_a), to)) return optional(); + expr g = get_app_fn(g_f_a); + if (!is_constant(g)) return optional(); + expr f_a = to; + if (!is_app(f_a)) return optional(); + expr f = get_app_fn(f_a); + expr a = app_arg(f_a); + if (!is_constant(f)) return optional(); + optional info = has_inverse(ctx.env(), const_name(f)); + if (!info || info->m_inv != const_name(g)) return optional(); + name g_f_name = info->m_lemma; + optional info_inv = has_inverse(ctx.env(), const_name(g)); + if (!info_inv || info_inv->m_inv != const_name(f)) return optional(); + buffer major_args; + expr f_g_eq = get_app_args(major, major_args); + if (!is_constant(f_g_eq) || major_args.empty() || !ctx.is_def_eq(f_a, major_args.back())) return optional(); + if (const_name(f_g_eq) != info_inv->m_lemma) return optional(); + + expr A = ctx.infer(a); + level A_lvl = get_level(ctx, A); + expr h_a = mk_app(h, a); + expr refl_h_a = mk_eq_refl(ctx, h_a); + expr f_a_eq_f_a = mk_eq(ctx, f_a, f_a); + /* (fun H : f a = f a, eq.refl (h a)) */ + expr pr_minor = mk_lambda("_H", f_a_eq_f_a, refl_h_a); + type_context::tmp_locals aux_locals(ctx); + expr x = aux_locals.push_local("_x", A); + expr f_x = mk_app(app_fn(f_a), x); + expr f_x_eq_f_a = mk_eq(ctx, f_x, f_a); + expr H = aux_locals.push_local("_H", f_x_eq_f_a); + expr h_x = mk_app(h, x); + /* (@eq.rec B (f x) C (h x) (f a) H) */ + expr eq_rec2 = mk_app(rec_fn, {B, f_x, C, h_x, f_a, H}); + /* (@eq.rec B (f x) C (h x) (f a) H) = h a */ + expr eq_rec2_eq = mk_eq(ctx, eq_rec2, h_a); + /* (fun x : A, (forall H : f x = f a, @eq.rec B (f x) C (h x) (f a) H = h a)) */ + expr pr_motive = ctx.mk_lambda(x, ctx.mk_pi(H, eq_rec2_eq)); + expr g_f_eq_a = mk_app(ctx, g_f_name, a); + /* (eq.symm (g_f_eq a)) */ + expr pr_major = mk_eq_symm(ctx, g_f_eq_a); + expr pr = mk_app(mk_constant(get_eq_rec_name(), {mk_level_zero(), A_lvl}), + {A, a, pr_motive, pr_minor, g_f_a, pr_major, major}); + + return optional(mk_pair(h_a, pr)); +} + static expr prove_eqn_lemma_core(type_context & ctx, buffer const & Hs, expr const & lhs, expr const & rhs) { buffer ite_args; expr new_lhs = whnf_ite(ctx, lhs); @@ -325,12 +414,17 @@ static expr prove_eqn_lemma_core(type_context & ctx, buffer const & Hs, ex } } + if (optional p = prove_eq_rec_invertible(ctx, new_lhs)) { + expr new_new_lhs = p->first; + expr H1 = p->second; + expr H2 = prove_eqn_lemma_core(ctx, Hs, new_new_lhs, rhs); + return mk_eq_trans(ctx, H1, H2); + } + if (ctx.is_def_eq(lhs, rhs)) { return mk_eq_refl(ctx, rhs); } - /* TODO(Leo): add support for pack/unpack lemmas */ - throw exception("equation compiler failed to prove equation lemma (workaround: " "disable lemma generation using `set_option eqn_compiler.lemmas false`)"); } diff --git a/tests/lean/run/pack_unpack1.lean b/tests/lean/run/pack_unpack1.lean index 903bcd4360..45de0e21fa 100644 --- a/tests/lean/run/pack_unpack1.lean +++ b/tests/lean/run/pack_unpack1.lean @@ -64,3 +64,14 @@ constant mk2 {A : Type} (l : list (tree A)) : P (tree.node l) definition bla {A : Type} : ∀ n : tree A, P n | (tree.leaf a) := mk1 a | (tree.node l) := mk2 l + +definition foo {A : Type} : nat → tree A → nat +| 0 _ := 0 +| (n+1) (tree.leaf a) := 0 +| (n+1) (tree.node []) := foo n (tree.node []) +| (n+1) (tree.node (x::xs)) := foo n x + +check @foo._main.equations.eqn_1 +check @foo._main.equations.eqn_2 +check @foo._main.equations.eqn_3 +check @foo._main.equations.eqn_4