diff --git a/library/init/meta/congr_tactic.lean b/library/init/meta/congr_tactic.lean index 2d850fb749..49dfe671bf 100644 --- a/library/init/meta/congr_tactic.lean +++ b/library/init/meta/congr_tactic.lean @@ -59,11 +59,4 @@ do focus1 (try assumption >> congr_core >> all_goals (try reflexivity >> try con meta def rel_congr : tactic unit := do focus1 (try assumption >> rel_congr_core >> all_goals (try reflexivity)) -namespace interactive - -meta def congr := tactic.congr -meta def rel_congr := tactic.rel_congr - -end interactive - end tactic diff --git a/library/init/meta/interactive.lean b/library/init/meta/interactive.lean index 31fed53dbe..b1d2b828d9 100644 --- a/library/init/meta/interactive.lean +++ b/library/init/meta/interactive.lean @@ -7,6 +7,7 @@ prelude import init.meta.tactic init.meta.rewrite_tactic init.meta.simp_tactic import init.meta.smt.congruence_closure init.category.combinators import init.meta.interactive_base init.meta.derive init.meta.match_tactic +import init.meta.congr_tactic open lean open lean.parser @@ -1530,6 +1531,9 @@ do e ← i_to_expr p, then tactic.note h.local_pp_name none e >> try (tactic.clear h) else tactic.fail "specialize requires a term of the form `h x_1 .. x_n` where `h` appears in the local context" +meta def congr := tactic.congr +meta def rel_congr := tactic.rel_congr + end interactive end tactic @@ -1576,3 +1580,64 @@ mwhen has_dup $ do intron n end add_interactive + +namespace tactic +/- Helper tactic for `mk_inj_eq -/ +protected meta def apply_inj_lemma : tactic unit := +do h ← intro `h, + some (lhs, rhs) ← expr.is_eq <$> infer_type h, + (expr.const C _) ← return lhs.get_app_fn, + applyc (name.mk_string "inj" C), + assumption + +/- Auxiliary tactic for proving `I.C.inj_eq` lemmas. + These lemmas are automatically generated by the equation compiler. + Example: + ``` + list.cons.inj_eq : forall h1 h2 t1 t2, (h1::t1 = h2::t2) = (h1 = h2 ∧ t1 = t2) := + by mk_inj_eq + ``` +-/ +meta def mk_inj_eq : tactic unit := +`[ + intros, + apply propext, + apply iff.intro, + { tactic.apply_inj_lemma }, + { intro _, try { cases_matching* _ ∧ _ }, refl <|> { congr; { assumption <|> subst_vars } } } +] +end tactic + +/- Define inj_eq lemmas for inductive datatypes that were declared before `mk_inj_eq` -/ + +universes u v + +lemma sum.inl.inj_eq {α : Type u} (β : Type v) (a₁ a₂ : α) : (@sum.inl α β a₁ = sum.inl a₂) = (a₁ = a₂) := +by tactic.mk_inj_eq + +lemma sum.inr.inj_eq (α : Type u) {β : Type v} (b₁ b₂ : β) : (@sum.inr α β b₁ = sum.inr b₂) = (b₁ = b₂) := +by tactic.mk_inj_eq + +lemma psum.inl.inj_eq {α : Sort u} (β : Sort v) (a₁ a₂ : α) : (@psum.inl α β a₁ = psum.inl a₂) = (a₁ = a₂) := +by tactic.mk_inj_eq + +lemma psum.inr.inj_eq (α : Sort u) {β : Sort v} (b₁ b₂ : β) : (@psum.inr α β b₁ = psum.inr b₂) = (b₁ = b₂) := +by tactic.mk_inj_eq + +lemma sigma.mk.inj_eq {α : Type u} {β : α → Type v} (a₁ : α) (b₁ : β a₁) (a₂ : α) (b₂ : β a₂) : (sigma.mk a₁ b₁ = sigma.mk a₂ b₂) = (a₁ = a₂ ∧ b₁ == b₂) := +by tactic.mk_inj_eq + +lemma psigma.mk.inj_eq {α : Sort u} {β : α → Sort v} (a₁ : α) (b₁ : β a₁) (a₂ : α) (b₂ : β a₂) : (psigma.mk a₁ b₁ = psigma.mk a₂ b₂) = (a₁ = a₂ ∧ b₁ == b₂) := +by tactic.mk_inj_eq + +lemma subtype.mk.inj_eq {α : Type u} {p : α → Prop} (a₁ : α) (h₁ : p a₁) (a₂ : α) (h₂ : p a₂) : (subtype.mk a₁ h₁ = subtype.mk a₂ h₂) = (a₁ = a₂) := +by tactic.mk_inj_eq + +lemma option.some.inj_eq {α : Type u} (a₁ a₂ : α) : (some a₁ = some a₂) = (a₁ = a₂) := +by tactic.mk_inj_eq + +lemma list.cons.inj_eq {α : Type u} (h₁ : α) (t₁ : list α) (h₂ : α) (t₂ : list α) : (list.cons h₁ t₁ = list.cons h₂ t₂) = (h₁ = h₂ ∧ t₁ = t₂) := +by tactic.mk_inj_eq + +lemma nat.succ.inj_eq (n₁ n₂ : nat) : (nat.succ n₁ = nat.succ n₂) = (n₁ = n₂) := +by tactic.mk_inj_eq diff --git a/src/frontends/lean/structure_cmd.cpp b/src/frontends/lean/structure_cmd.cpp index 2f9ea6c986..d5aec307d6 100644 --- a/src/frontends/lean/structure_cmd.cpp +++ b/src/frontends/lean/structure_cmd.cpp @@ -1258,7 +1258,10 @@ struct structure_cmd_fn { return; if (!has_and_decls(m_env)) return; - m_env = mk_injective_lemmas(m_env, m_name); + /* We do not generate `*.inj_eq` lemmas for classes since they can be quite expensive to + generate for big classes, and they don't seem to be useful in this case. */ + bool gen_inj_eq = !m_meta_info.m_attrs.has_class(); + m_env = mk_injective_lemmas(m_env, m_name, gen_inj_eq); add_alias(mk_injective_name(m_name)); add_alias(mk_injective_arrow_name(m_name)); } diff --git a/src/library/constants.cpp b/src/library/constants.cpp index 97e84f1939..6f98110dc9 100644 --- a/src/library/constants.cpp +++ b/src/library/constants.cpp @@ -357,6 +357,7 @@ name const * g_psum_inr = nullptr; name const * g_tactic = nullptr; name const * g_tactic_try = nullptr; name const * g_tactic_triv = nullptr; +name const * g_tactic_mk_inj_eq = nullptr; name const * g_thunk = nullptr; name const * g_to_fmt = nullptr; name const * g_trans_rel_left = nullptr; @@ -738,6 +739,7 @@ void initialize_constants() { g_tactic = new name{"tactic"}; g_tactic_try = new name{"tactic", "try"}; g_tactic_triv = new name{"tactic", "triv"}; + g_tactic_mk_inj_eq = new name{"tactic", "mk_inj_eq"}; g_thunk = new name{"thunk"}; g_to_fmt = new name{"to_fmt"}; g_trans_rel_left = new name{"trans_rel_left"}; @@ -1120,6 +1122,7 @@ void finalize_constants() { delete g_tactic; delete g_tactic_try; delete g_tactic_triv; + delete g_tactic_mk_inj_eq; delete g_thunk; delete g_to_fmt; delete g_trans_rel_left; @@ -1501,6 +1504,7 @@ name const & get_psum_inr_name() { return *g_psum_inr; } name const & get_tactic_name() { return *g_tactic; } name const & get_tactic_try_name() { return *g_tactic_try; } name const & get_tactic_triv_name() { return *g_tactic_triv; } +name const & get_tactic_mk_inj_eq_name() { return *g_tactic_mk_inj_eq; } name const & get_thunk_name() { return *g_thunk; } name const & get_to_fmt_name() { return *g_to_fmt; } name const & get_trans_rel_left_name() { return *g_trans_rel_left; } diff --git a/src/library/constants.h b/src/library/constants.h index 503972faa0..a2d7bfc999 100644 --- a/src/library/constants.h +++ b/src/library/constants.h @@ -359,6 +359,7 @@ name const & get_psum_inr_name(); name const & get_tactic_name(); name const & get_tactic_try_name(); name const & get_tactic_triv_name(); +name const & get_tactic_mk_inj_eq_name(); name const & get_thunk_name(); name const & get_to_fmt_name(); name const & get_trans_rel_left_name(); diff --git a/src/library/constants.txt b/src/library/constants.txt index ff03f86278..6d4e8f774e 100644 --- a/src/library/constants.txt +++ b/src/library/constants.txt @@ -352,6 +352,7 @@ psum.inr tactic tactic.try tactic.triv +tactic.mk_inj_eq thunk to_fmt trans_rel_left diff --git a/src/library/constructions/injective.cpp b/src/library/constructions/injective.cpp index 917af6c34b..eacf4f6ed3 100644 --- a/src/library/constructions/injective.cpp +++ b/src/library/constructions/injective.cpp @@ -19,6 +19,7 @@ Author: Daniel Selsam, Leonardo de Moura #include "library/tactic/tactic_state.h" #include "library/tactic/intro_tactic.h" #include "library/tactic/subst_tactic.h" +#include "library/tactic/tactic_evaluator.h" namespace lean { @@ -57,23 +58,23 @@ static void collect_args(type_context & tctx, expr const & type, unsigned num_pa lean_assert(!is_pi(ty)); } -expr mk_injective_type(environment const & env, name const & ir_name, expr const & ir_type, unsigned num_params, level_param_names const & lp_names) { +expr mk_injective_type_core(environment const & env, name const & ir_name, expr const & ir_type, unsigned num_params, level_param_names const & lp_names, bool use_eq) { // The transparency needs to match the kernel since we need to be consistent with the no_confusion construction. - type_context tctx(env, transparency_mode::All); + type_context ctx(env, transparency_mode::All); buffer params, args1, args2, new_args; - collect_args(tctx, ir_type, num_params, params, args1, args2, new_args); + collect_args(ctx, ir_type, num_params, params, args1, args2, new_args); expr c_ir_params = mk_app(mk_constant(ir_name, param_names_to_levels(lp_names)), params); expr lhs = mk_app(c_ir_params, args1); expr rhs = mk_app(c_ir_params, args2); - expr eq_type = mk_eq(tctx, lhs, rhs); + expr eq_type = mk_eq(ctx, lhs, rhs); buffer eqs; for (unsigned arg_idx = 0; arg_idx < args1.size(); ++arg_idx) { - if (!tctx.is_prop(tctx.infer(args1[arg_idx])) && args1[arg_idx] != args2[arg_idx]) { - if (tctx.is_def_eq(tctx.infer(args1[arg_idx]), tctx.infer(args2[arg_idx]))) { - eqs.push_back(mk_eq(tctx, args1[arg_idx], args2[arg_idx])); + if (!ctx.is_prop(ctx.infer(args1[arg_idx])) && args1[arg_idx] != args2[arg_idx]) { + if (ctx.is_def_eq(ctx.infer(args1[arg_idx]), ctx.infer(args2[arg_idx]))) { + eqs.push_back(mk_eq(ctx, args1[arg_idx], args2[arg_idx])); } else { - eqs.push_back(mk_heq(tctx, args1[arg_idx], args2[arg_idx])); + eqs.push_back(mk_heq(ctx, args1[arg_idx], args2[arg_idx])); } } } @@ -90,7 +91,16 @@ expr mk_injective_type(environment const & env, name const & ir_name, expr const } } - return tctx.mk_pi(params, tctx.mk_pi(args1, tctx.mk_pi(new_args, mk_arrow(eq_type, and_type)))); + expr result = use_eq ? mk_eq(ctx, eq_type, and_type) : mk_arrow(eq_type, and_type); + return ctx.mk_pi(params, ctx.mk_pi(args1, ctx.mk_pi(new_args, result))); +} + +expr mk_injective_type(environment const & env, name const & ir_name, expr const & ir_type, unsigned num_params, level_param_names const & lp_names) { + return mk_injective_type_core(env, ir_name, ir_type, num_params, lp_names, false); +} + +expr mk_injective_eq_type(environment const & env, name const & ir_name, expr const & ir_type, unsigned num_params, level_param_names const & lp_names) { + return mk_injective_type_core(env, ir_name, ir_type, num_params, lp_names, true); } expr prove_by_assumption(type_context & tctx, expr const & ty, expr const & eq) { @@ -237,7 +247,23 @@ environment mk_injective_arrow(environment const & env, name const & ir_name) { return new_env; } -environment mk_injective_lemmas(environment const & _env, name const & ind_name) { +expr prove_injective_eq(environment const & env, expr const & inj_eq_type, name const & inj_eq_name) { + try { + type_context ctx(env, transparency_mode::Semireducible); + expr dummy_ref; + tactic_state s = mk_tactic_state_for(env, options(), inj_eq_name, metavar_context(), local_context(), inj_eq_type); + vm_obj r = tactic_evaluator(ctx, options(), dummy_ref)(mk_constant(get_tactic_mk_inj_eq_name()), s); + if (auto new_s = tactic::is_success(r)) { + metavar_context mctx = new_s->mctx(); + return mctx.instantiate_mvars(new_s->main()); + } + } catch (exception & ex) { + throw nested_exception(sstream() << "failed to generate auxiliary lemma '" << inj_eq_name << "'", ex); + } + throw exception(sstream() << "failed to generate auxiliary lemma '" << inj_eq_name << "'"); +} + +environment mk_injective_lemmas(environment const & _env, name const & ind_name, bool gen_inj_eq) { environment env = _env; auto idecls = inductive::is_inductive_decl(env, ind_name); @@ -262,6 +288,12 @@ environment mk_injective_lemmas(environment const & _env, name const & ind_name) lean_trace(name({"constructions", "injective"}), tout() << ir_name << " : " << inj_type << " :=\n " << inj_val << "\n";); env = module::add(env, check(env, mk_definition_inferring_trusted(env, mk_injective_name(ir_name), lp_names, inj_type, inj_val, true))); env = mk_injective_arrow(env, ir_name); + if (gen_inj_eq && env.find(get_tactic_mk_inj_eq_name())) { + name inj_eq_name = mk_injective_eq_name(ir_name); + expr inj_eq_type = mk_injective_eq_type(env, ir_name, ir_type, num_params, lp_names); + expr inj_eq_value = prove_injective_eq(env, inj_eq_type, inj_eq_name); + env = module::add(env, check(env, mk_definition_inferring_trusted(env, inj_eq_name, lp_names, inj_eq_type, inj_eq_value, true))); + } } return env; } @@ -270,6 +302,10 @@ name mk_injective_name(name const & ir_name) { return name(ir_name, "inj"); } +name mk_injective_eq_name(name const & ir_name) { + return name(ir_name, "inj_eq"); +} + name mk_injective_arrow_name(name const & ir_name) { return name(ir_name, "inj_arrow"); } diff --git a/src/library/constructions/injective.h b/src/library/constructions/injective.h index 5b4b98c1fd..267fb64c91 100644 --- a/src/library/constructions/injective.h +++ b/src/library/constructions/injective.h @@ -8,13 +8,21 @@ Author: Daniel Selsam, Leonardo de Moura namespace lean { -environment mk_injective_lemmas(environment const & env, name const & ind_name); +/* Generate injectivity lemmas `*.inj`, `*.inj_arrow` and `*.inj_eq`. + If `gen_inj_eq` is false, then `*.inj_eq` lemma is not generated. + The `*.inj_eq` lemma is used by the simplifier. + We don't generate it for classes because they can be expensive to generate and are rarely used in this case. +*/ +environment mk_injective_lemmas(environment const & env, name const & ind_name, bool gen_inj_eq = true); environment mk_injective_arrow(environment const & env, name const & ir_name); expr mk_injective_type(environment const & env, name const & ir_name, expr const & ir_type, unsigned num_params, level_param_names const & lp_names); +expr mk_injective_eq_type(environment const & env, name const & ir_name, expr const & ir_type, unsigned num_params, level_param_names const & lp_names); +expr prove_injective_eq(environment const & env, expr const & inj_eq_type, name const & inj_eq_name); name mk_injective_name(name const & ir_name); +name mk_injective_eq_name(name const & ir_name); name mk_injective_arrow_name(name const & ir_name); void initialize_injective(); diff --git a/src/library/inductive_compiler/mutual.cpp b/src/library/inductive_compiler/mutual.cpp index 23ba35f66b..2b1261f65e 100644 --- a/src/library/inductive_compiler/mutual.cpp +++ b/src/library/inductive_compiler/mutual.cpp @@ -461,11 +461,23 @@ class add_mutual_inductive_decl_fn { if (!static_cast(m_env.find(mk_injective_name(mlocal_name(m_basic_decl.get_intro_rule(0, basic_ir_idx)))))) { return; } - expr inj_and_type = mk_injective_type(m_env, mlocal_name(ir), Pi(m_mut_decl.get_params(), mlocal_type(ir)), m_mut_decl.get_num_params(), to_list(m_mut_decl.get_lp_names())); + level_param_names lp_names = to_list(m_mut_decl.get_lp_names()); + unsigned num_params = m_mut_decl.get_num_params(); + name ir_name = mlocal_name(ir); + expr ir_type = Pi(m_mut_decl.get_params(), mlocal_type(ir)); + expr inj_and_type = mk_injective_type(m_env, ir_name, ir_type, num_params, lp_names); expr inj_and_val = mk_constant(mk_injective_name(mlocal_name(m_basic_decl.get_intro_rule(0, basic_ir_idx))), m_mut_decl.get_levels()); - lean_trace(name({"inductive_compiler", "mutual", "injective"}), tout() << mk_injective_name(mlocal_name(ir)) << " : " << inj_and_type << " :=\n " << inj_and_val << "\n";); - m_env = module::add(m_env, check(m_env, mk_definition_inferring_trusted(m_env, mk_injective_name(mlocal_name(ir)), to_list(m_mut_decl.get_lp_names()), inj_and_type, inj_and_val, true))); - m_env = mk_injective_arrow(m_env, mlocal_name(ir)); + lean_trace(name({"inductive_compiler", "mutual", "injective"}), tout() << mk_injective_name(ir_name) << " : " << inj_and_type << " :=\n " << inj_and_val << "\n";); + m_env = module::add(m_env, check(m_env, mk_definition_inferring_trusted(m_env, mk_injective_name(ir_name), lp_names, inj_and_type, inj_and_val, true))); + m_env = mk_injective_arrow(m_env, ir_name); + + if (m_env.find(get_tactic_mk_inj_eq_name())) { + name inj_eq_name = mk_injective_eq_name(ir_name); + expr inj_eq_type = mk_injective_eq_type(m_env, ir_name, ir_type, num_params, lp_names); + expr inj_eq_value = prove_injective_eq(m_env, inj_eq_type, inj_eq_name); + m_env = module::add(m_env, check(m_env, mk_definition_inferring_trusted(m_env, inj_eq_name, lp_names, inj_eq_type, inj_eq_value, true))); + } + m_tctx.set_env(m_env); basic_ir_idx++; } diff --git a/src/library/inductive_compiler/nested.cpp b/src/library/inductive_compiler/nested.cpp index 2dc3861d8f..45c7e70848 100644 --- a/src/library/inductive_compiler/nested.cpp +++ b/src/library/inductive_compiler/nested.cpp @@ -2272,17 +2272,25 @@ class add_nested_inductive_decl_fn { for (unsigned ind_idx = 0; ind_idx < m_nested_decl.get_num_inds(); ++ind_idx) { for (unsigned ir_idx = 0; ir_idx < m_nested_decl.get_num_intro_rules(ind_idx); ++ir_idx) { expr const & ir = m_nested_decl.get_intro_rule(ind_idx, ir_idx); + level_param_names lp_names = to_list(m_nested_decl.get_lp_names()); + name ir_name = mlocal_name(ir); + expr ir_type = Pi(m_nested_decl.get_params(), mlocal_type(ir)); + unsigned num_params = m_nested_decl.get_num_params(); name inj_name = mk_injective_name(mlocal_name(ir)); - expr inj_type = mk_injective_type(m_env, mlocal_name(ir), Pi(m_nested_decl.get_params(), mlocal_type(ir)), - m_nested_decl.get_num_params(), to_list(m_nested_decl.get_lp_names())); - + expr inj_type = mk_injective_type(m_env, ir_name, ir_type, num_params, lp_names); name inj_arrow_name = mk_injective_arrow_name(mlocal_name(m_inner_decl.get_intro_rule(ind_idx, ir_idx))); expr inj_val = prove_nested_injective(inj_type, m_inj_lemmas, inj_arrow_name); m_env = module::add(m_env, check(m_env, - mk_definition_inferring_trusted(m_env, inj_name, to_list(m_nested_decl.get_lp_names()), inj_type, inj_val, true))); - m_env = mk_injective_arrow(m_env, mlocal_name(ir)); + mk_definition_inferring_trusted(m_env, inj_name, lp_names, inj_type, inj_val, true))); + m_env = mk_injective_arrow(m_env, ir_name); + if (m_env.find(get_tactic_mk_inj_eq_name())) { + name inj_eq_name = mk_injective_eq_name(ir_name); + expr inj_eq_type = mk_injective_eq_type(m_env, ir_name, ir_type, num_params, lp_names); + expr inj_eq_value = prove_injective_eq(m_env, inj_eq_type, inj_eq_name); + m_env = module::add(m_env, check(m_env, mk_definition_inferring_trusted(m_env, inj_eq_name, lp_names, inj_eq_type, inj_eq_value, true))); + } } } m_tctx.set_env(m_env); diff --git a/tests/lean/run/check_constants.lean b/tests/lean/run/check_constants.lean index 80dcfbacb3..5cb6567e05 100644 --- a/tests/lean/run/check_constants.lean +++ b/tests/lean/run/check_constants.lean @@ -357,6 +357,7 @@ run_cmd script_check_id `psum.inr run_cmd script_check_id `tactic run_cmd script_check_id `tactic.try run_cmd script_check_id `tactic.triv +run_cmd script_check_id `tactic.mk_inj_eq run_cmd script_check_id `thunk run_cmd script_check_id `to_fmt run_cmd script_check_id `trans_rel_left diff --git a/tests/lean/run/simp_constructor.lean b/tests/lean/run/simp_constructor.lean index 224fa2b3c0..748a89d7dd 100644 --- a/tests/lean/run/simp_constructor.lean +++ b/tests/lean/run/simp_constructor.lean @@ -10,3 +10,28 @@ end example : ¬ term.var "a" = term.app "f" [] := by simp + +#check @term.app.inj_eq + +universes u + +inductive vec (α : Type u) : nat → Type u +| nil : vec 0 +| cons : Π {n}, α → vec n → vec (nat.succ n) + +#check @vec.cons.inj_eq + +example (a b : nat) (h : a == b) : a + 1 = b + 1 := +begin + subst h +end + +mutual inductive Expr, Expr_list +with Expr : Type +| var : string → Expr +| app : string → Expr_list → Expr +with Expr_list : Type +| nil : Expr_list +| cons : Expr → Expr_list → Expr_list + +#check @Expr.app.inj_eq diff --git a/tests/lean/run/simp_univ_metavars.lean b/tests/lean/run/simp_univ_metavars.lean index 74aa45569e..278f993344 100644 --- a/tests/lean/run/simp_univ_metavars.lean +++ b/tests/lean/run/simp_univ_metavars.lean @@ -7,6 +7,8 @@ structure { u v } Category := (compose : Π ⦃X Y Z : Obj⦄, Hom X Y → Hom Y Z → Hom X Z) (left_identity : ∀ ⦃X Y : Obj⦄ (f : Hom X Y), compose (identity _) f = f) +#check @Category.mk.inj_eq + structure Functor (C : Category) (D : Category) := (onObjects : C^.Obj → D^.Obj) (onMorphisms : Π ⦃X Y : C^.Obj⦄,