diff --git a/library/init/core.lean b/library/init/core.lean index 85ef0cc727..99e086a255 100644 --- a/library/init/core.lean +++ b/library/init/core.lean @@ -80,6 +80,23 @@ reserve infixl `; `:1 universes u v w +/- +The kernel definitional equality test (t =?= s) has special support for id_delta applications. +It implements the following rules + + 1) (id_delta t) =?= t + 2) t =?= (id_delta t) + 3) (id_delta t) =?= s IF (unfold_of t) =?= s + 4) t =?= id_delta s IF t =?= (unfold_of s) + +This is mechanism for controlling the delta reduction (aka unfolding) used in the kernel. + +We use id_delta applications to address performance problems when type checking +lemmas generated by the equation compiler. +-/ +@[inline] def id_delta {α : Sort u} (a : α) : α := +a + /-- Gadget for optional parameter support. -/ @[reducible] def opt_param (α : Sort u) (default : α) : Sort u := α diff --git a/src/kernel/type_checker.cpp b/src/kernel/type_checker.cpp index 4283ea8f1a..f6e517b13b 100644 --- a/src/kernel/type_checker.cpp +++ b/src/kernel/type_checker.cpp @@ -578,6 +578,8 @@ void type_checker::cache_failure(expr const & t, expr const & s) { m_failure_cache.insert(mk_pair(s, t)); } +static name * g_id_delta = nullptr; + /** \brief Perform one lazy delta-reduction step. Return - l_true if t_n and s_n are definitionally equal. @@ -590,6 +592,20 @@ auto type_checker::lazy_delta_reduction_step(expr & t_n, expr & s_n) -> reductio auto d_s = is_delta(s_n); if (!d_t && !d_s) { return reduction_status::DefUnknown; + } else if (d_t && d_t->get_name() == *g_id_delta) { + t_n = whnf_core(*unfold_definition(t_n)); + if (t_n == s_n) + return reduction_status::DefEqual; /* id_delta t =?= t */ + if (auto u = unfold_definition(t_n)) /* id_delta t =?= s ===> unfold(t) =?= s */ + t_n = whnf_core(*u); + return reduction_status::Continue; + } else if (d_s && d_s->get_name() == *g_id_delta) { + s_n = whnf_core(*unfold_definition(s_n)); + if (t_n == s_n) + return reduction_status::DefEqual; /* t =?= id_delta t */ + if (auto u = unfold_definition(s_n)) /* t =?= id_delta s ===> t =?= unfold(s) */ + s_n = whnf_core(*u); + return reduction_status::Continue; } else if (d_t && !d_s) { t_n = whnf_core(*unfold_definition(t_n)); } else if (!d_t && d_s) { @@ -790,10 +806,12 @@ certified_declaration certify_unchecked::certify_or_check(environment const & en } void initialize_type_checker() { + g_id_delta = new name("id_delta"); g_dont_care = new expr(Const("dontcare")); } void finalize_type_checker() { delete g_dont_care; + delete g_id_delta; } } diff --git a/src/library/app_builder.cpp b/src/library/app_builder.cpp index 5af42fb570..7983ec9eda 100644 --- a/src/library/app_builder.cpp +++ b/src/library/app_builder.cpp @@ -1136,6 +1136,12 @@ expr mk_id_rhs(type_context & ctx, expr const & h) { return mk_app(mk_constant(get_id_rhs_name(), {lvl}), type, h); } +expr mk_id_delta(type_context & ctx, expr const & h) { + expr type = ctx.infer(h); + level lvl = get_level(ctx, type); + return mk_app(mk_constant(get_id_delta_name(), {lvl}), type, h); +} + static bool is_eq_trans(expr const & h, expr & h1, expr & h2) { if (is_app_of(h, get_eq_trans_name(), 6)) { h1 = app_arg(app_fn(h)); diff --git a/src/library/app_builder.h b/src/library/app_builder.h index edee66385e..92d9831dd5 100644 --- a/src/library/app_builder.h +++ b/src/library/app_builder.h @@ -186,6 +186,9 @@ expr mk_id_locked(type_context & ctx, expr const & type, expr const & h); /* (id_rhs h) */ expr mk_id_rhs(type_context & ctx, expr const & h); +/* (id_delta h) */ +expr mk_id_delta(type_context & ctx, expr const & h); + expr mk_iff_mp(type_context & ctx, expr const & h1, expr const & h2); expr mk_iff_mpr(type_context & ctx, expr const & h1, expr const & h2); expr mk_eq_mp(type_context & ctx, expr const & h1, expr const & h2); diff --git a/src/library/constants.cpp b/src/library/constants.cpp index 97dcb24124..1a8268b004 100644 --- a/src/library/constants.cpp +++ b/src/library/constants.cpp @@ -129,6 +129,7 @@ name const * g_heq_of_eq = nullptr; name const * g_hole_command = nullptr; name const * g_id_locked = nullptr; name const * g_id_rhs = nullptr; +name const * g_id_delta = nullptr; name const * g_if_neg = nullptr; name const * g_if_pos = nullptr; name const * g_iff = nullptr; @@ -510,6 +511,7 @@ void initialize_constants() { g_hole_command = new name{"hole_command"}; g_id_locked = new name{"id_locked"}; g_id_rhs = new name{"id_rhs"}; + g_id_delta = new name{"id_delta"}; g_if_neg = new name{"if_neg"}; g_if_pos = new name{"if_pos"}; g_iff = new name{"iff"}; @@ -892,6 +894,7 @@ void finalize_constants() { delete g_hole_command; delete g_id_locked; delete g_id_rhs; + delete g_id_delta; delete g_if_neg; delete g_if_pos; delete g_iff; @@ -1273,6 +1276,7 @@ name const & get_heq_of_eq_name() { return *g_heq_of_eq; } name const & get_hole_command_name() { return *g_hole_command; } name const & get_id_locked_name() { return *g_id_locked; } name const & get_id_rhs_name() { return *g_id_rhs; } +name const & get_id_delta_name() { return *g_id_delta; } name const & get_if_neg_name() { return *g_if_neg; } name const & get_if_pos_name() { return *g_if_pos; } name const & get_iff_name() { return *g_iff; } diff --git a/src/library/constants.h b/src/library/constants.h index 9a0331766a..77a117ec79 100644 --- a/src/library/constants.h +++ b/src/library/constants.h @@ -131,6 +131,7 @@ name const & get_heq_of_eq_name(); name const & get_hole_command_name(); name const & get_id_locked_name(); name const & get_id_rhs_name(); +name const & get_id_delta_name(); name const & get_if_neg_name(); name const & get_if_pos_name(); name const & get_iff_name(); diff --git a/src/library/constants.txt b/src/library/constants.txt index 5f82540709..90e7505fad 100644 --- a/src/library/constants.txt +++ b/src/library/constants.txt @@ -124,6 +124,7 @@ heq_of_eq hole_command id_locked id_rhs +id_delta if_neg if_pos iff diff --git a/src/library/equations_compiler/util.cpp b/src/library/equations_compiler/util.cpp index ef5a64c046..549b3d7c16 100644 --- a/src/library/equations_compiler/util.cpp +++ b/src/library/equations_compiler/util.cpp @@ -619,15 +619,107 @@ static expr prove_eqn_lemma_core(type_context & ctx, buffer const & Hs, ex return mk_eq_trans(ctx, H1, H2); } - expr lhs_body = lhs; + /* Check if lhs =?= rhs, and create a reflexivity proof if this is the case. + + We have to be careful to avoid performace problems when checking this proof in the kernel. + We considered different options. + + Option 1) (refl rhs) or (refl lhs) + It will perform poorly in one of the following examples: + + + | f x 0 := 1 + | f x (y+1) := f complex_term y + + | g 0 y := 1 + | g (x+1) y := g x complex_term + + If we use (refl rhs), we will generate the proofs + + eq.refl (f complex_term y) + eq.refl (g x complex_term) + + These proofs trigger the following definitionally equality tests: + + f x (y+1) =?= f complex_term y + g (x+1) y =?= g x complex_term + + Since, we have f/g on both sides, the type checker will try + first to unify the arguments, and may timeout trying to solve + + x =?= complex_term + y =?= complex_term + + since it may take a long time to reduce `complex_term`. + + We have a similar problem if we use (refl lhs) + + Commit 7ebf16ca26da82b3d0e458dbcf32cda374ec785d tried to address this issue + by using Option 2). + + Option 2) (refl (unfold_of lhs)) + This option fixes the performance problem above, but it is still + inefficient for definitions that produce many equations. + For example, the following definition produces 121 equations. + + ``` + universes u + + inductive node (α : Type u) + | leaf : node + | red_node : node → α → node → node + | black_node : node → α → node → node + + namespace node + variable {α : Type u} + + def balance : node α → α → node α → node α + | (red_node (red_node a x b) y c) k d := red_node (black_node a x b) y (black_node c k d) + | (red_node a x (red_node b y c)) k d := red_node (black_node a x b) y (black_node c k d) + | l k r := black_node l k r + + end node + ``` + + In each equation we will have a big (unfold_of lhs) term. This increases the size of .olean + files, and introduces an overhead in the mk_aux_lemma procedure. + + Option 3) (refl (id_delta lhs)) + We are currently using this option. + This approach relies on the fact that the kernel type checker has special support for id_delta. + The kernel implements the following is_def_eq rules for id_delta. + + 1) (id_delta t) =?= t + 2) t =?= (id_delta t) + 3) (id_delta t) =?= s IF (unfold_of t) =?= s + 4) t =?= id_delta s IF t =?= (unfold_of s) + + We can view it as a "lazy" version of Option 2. The .olean file contains `id_delta t` + instead of the result of delta-reducing t. Similarly, no overhead is introduced to mk_aux_lemma + since the proof is quite small in this case. + + Finally, note that this optimization is only use when root = true. + That is, it is not use if the equation compiler used if-then-else compilation trick for + pattern matching scalar values, and/or the pack/unpack auxiliary definitions introduced + for nested inductive datatype declarations. + The problem described in Option 1 does not happen in this case since we unfold the left-hand-side + while building the proof. However, the performance problem described in Option 2 may happen. + */ if (root) { + /* Remark: type_context currently does not have support for id_delta. + So, we unfold lhs before invoking ctx.is_def_eq. */ + expr lhs_body = lhs; if (auto b = unfold_term(ctx.env(), lhs)) lhs_body = *b; - } - - if (ctx.is_def_eq(lhs_body, rhs)) { - // tout() << "DONE\n"; - return mk_eq_refl(ctx, lhs_body); + if (ctx.is_def_eq(lhs_body, rhs)) { + if (ctx.env().find(get_id_delta_name())) + return mk_eq_refl(ctx, mk_id_delta(ctx, lhs)); + else + return mk_eq_refl(ctx, lhs); + } + } else { + if (ctx.is_def_eq(lhs, rhs)) + return mk_eq_refl(ctx, lhs); } throw exception("equation compiler failed to prove equation lemma (workaround: " diff --git a/tests/lean/eqn_proof.lean b/tests/lean/eqn_proof.lean new file mode 100644 index 0000000000..4665ce83fe --- /dev/null +++ b/tests/lean/eqn_proof.lean @@ -0,0 +1,18 @@ +universes u + +inductive node (α : Type u) +| leaf : node +| red_node : node → α → node → node +| black_node : node → α → node → node + +namespace node +variable {α : Type u} + +def balance : node α → α → node α → node α +| (red_node (red_node a x b) y c) k d := red_node (black_node a x b) y (black_node c k d) +| (red_node a x (red_node b y c)) k d := red_node (black_node a x b) y (black_node c k d) +| l k r := black_node l k r + +#print balance._main.equations._eqn_1 + +end node diff --git a/tests/lean/eqn_proof.lean.expected.out b/tests/lean/eqn_proof.lean.expected.out new file mode 100644 index 0000000000..0630c2a3d0 --- /dev/null +++ b/tests/lean/eqn_proof.lean.expected.out @@ -0,0 +1,3 @@ +@[_refl_lemma] +theorem node.balance._main.equations._eqn_1 : ∀ {α : Type u} (k : α) (r : node α), balance._main (leaf α) k r = black_node (leaf α) k r := +λ {α : Type u} (k : α) (r : node α), eq.refl (id_delta (balance._main (leaf α) k r)) diff --git a/tests/lean/run/check_constants.lean b/tests/lean/run/check_constants.lean index 9dcb0547af..fa8ce105fd 100644 --- a/tests/lean/run/check_constants.lean +++ b/tests/lean/run/check_constants.lean @@ -129,6 +129,7 @@ run_cmd script_check_id `heq_of_eq run_cmd script_check_id `hole_command run_cmd script_check_id `id_locked run_cmd script_check_id `id_rhs +run_cmd script_check_id `id_delta run_cmd script_check_id `if_neg run_cmd script_check_id `if_pos run_cmd script_check_id `iff