diff --git a/src/library/equations_compiler/elim_match.cpp b/src/library/equations_compiler/elim_match.cpp index 8b33b30f85..0ede5260b4 100644 --- a/src/library/equations_compiler/elim_match.cpp +++ b/src/library/equations_compiler/elim_match.cpp @@ -873,27 +873,11 @@ struct elim_match_fn { [&](expr const & c, buffer const & new_c_vars) { expr var = pattern; /* We are replacing `var` with `c` */ + buffer vars; to_buffer(eqn.m_vars, vars); + buffer new_vars; buffer from; buffer to; - buffer new_vars; - for (expr const & curr : eqn.m_vars) { - if (curr == var) { - from.push_back(var); - to.push_back(c); - new_vars.append(new_c_vars); - } else { - expr curr_type = ctx.infer(curr); - expr new_curr_type = replace_locals(curr_type, from, to); - if (curr_type == new_curr_type) { - new_vars.push_back(curr); - } else { - expr new_curr = ctx.push_local(local_pp_name(curr), new_curr_type); - from.push_back(curr); - to.push_back(new_curr); - new_vars.push_back(new_curr); - } - } - } + update_telescope(ctx, vars, var, c, new_c_vars, new_vars, from, to); equation new_eqn = eqn; new_eqn.m_lctx = ctx.lctx(); new_eqn.m_vars = to_list(new_vars); diff --git a/src/library/equations_compiler/structural_rec.cpp b/src/library/equations_compiler/structural_rec.cpp index bcf8dbf76e..a4f2175547 100644 --- a/src/library/equations_compiler/structural_rec.cpp +++ b/src/library/equations_compiler/structural_rec.cpp @@ -503,26 +503,100 @@ struct structural_rec_fn { } }; + /* Return true if we need to complete equations by expanding the recursive argument. + + For example, suppose we have, where the recursive argument is the second + + def f : nat → nat → nat + | (x+1) (y+1) := f (x+10) y + | _ _ := 1 + + this function returns true because + 1) We need to perform case analysis in the first argument (first equation), + (flag has_case_analysis_before in the followin procedure); and + 2) W have an equation (second) where the recursive argument is a variable + (flag incomplete). + */ + bool must_complete_rec_arg(type_context & ctx, unpack_eqns const & ues) { + if (m_arg_pos == 0) return false; + buffer const & eqns = ues.get_eqns_of(0); + bool has_case_analysis_before = false; + bool incomplete = false; + for (expr const & eqn : eqns) { + unpack_eqn ue(ctx, eqn); + buffer lhs_args; + get_app_args(ue.lhs(), lhs_args); + + if (!has_case_analysis_before) { + for (unsigned i = 0; i < m_arg_pos; i++) { + if (!is_local(lhs_args[i]) && !is_inaccessible(lhs_args[i])) { + has_case_analysis_before = true; + break; + } + } + } + + if (is_local(lhs_args[m_arg_pos])) + incomplete = true; + + if (has_case_analysis_before && incomplete) + return true; + } + return false; + } + void update_eqs(type_context & ctx, unpack_eqns & ues, expr const & fn, expr const & new_fn) { /* C is a temporary "abstract" motive, we use it to access the "brec_on dictionary". The "brec_on dictionary is an element of type below, and it is the last argument of the new function. */ expr C = mk_local(mk_fresh_name(), "_C", m_motive_type, binder_info()); buffer & eqns = ues.get_eqns_of(0); - for (expr & eqn : eqns) { + buffer new_eqns; + bool complete = must_complete_rec_arg(ctx, ues); + for (expr const & eqn : eqns) { unpack_eqn ue(ctx, eqn); expr lhs = ue.lhs(); expr rhs = ue.rhs(); buffer lhs_args; get_app_args(lhs, lhs_args); - expr new_lhs = mk_app(new_fn, lhs_args); - expr type = ctx.whnf(ctx.infer(new_lhs)); - lean_assert(is_pi(type)); - expr F = ue.add_var(binding_name(type), binding_domain(type)); - new_lhs = mk_app(new_lhs, F); - ue.lhs() = new_lhs; - ue.rhs() = elim_rec_apps_fn(ctx, fn, m_arg_pos, m_indices_pos, F, C)(rhs); - eqn = ue.repack(); + if (complete && is_local(lhs_args[m_arg_pos])) { + expr var = lhs_args[m_arg_pos]; + for_each_compatible_constructor(ctx, var, + [&](expr const & c, buffer const & new_c_vars) { + buffer new_vars; + buffer from; + buffer to; + update_telescope(ctx, ue.get_vars(), var, c, new_c_vars, + new_vars, from, to); + buffer new_lhs_args(lhs_args); + new_lhs_args[m_arg_pos] = c; + for (unsigned i = m_arg_pos + 1; i < new_lhs_args.size(); i++) + new_lhs_args[i] = replace_locals(new_lhs_args[i], from, to); + expr new_lhs = mk_app(new_fn, new_lhs_args); + expr type = ctx.whnf(ctx.infer(new_lhs)); + lean_assert(is_pi(type)); + type_context::tmp_locals extra(ctx); + expr F = extra.push_local(binding_name(type), binding_domain(type)); + new_vars.push_back(F); + new_lhs = mk_app(new_lhs, F); + /* The lhs was a variable, so we don't need to update the rhs using elim_rec_apps_fn. + Reason: the rhs should not contain recursive equations. + But, we need to update the locals. */ + expr new_rhs = replace_locals(ue.rhs(), from, to); + expr new_eqn = copy_tag(ue.get_nested_src(), mk_equation(new_lhs, new_rhs)); + new_eqns.push_back(copy_tag(eqn, ctx.mk_lambda(new_vars, new_eqn))); + }); + } else { + expr new_lhs = mk_app(new_fn, lhs_args); + expr type = ctx.whnf(ctx.infer(new_lhs)); + lean_assert(is_pi(type)); + expr F = ue.add_var(binding_name(type), binding_domain(type)); + new_lhs = mk_app(new_lhs, F); + ue.lhs() = new_lhs; + ue.rhs() = elim_rec_apps_fn(ctx, fn, m_arg_pos, m_indices_pos, F, C)(rhs); + new_eqns.push_back(ue.repack()); + } } + eqns = new_eqns; } optional elim_recursion(expr const & e) { diff --git a/src/library/equations_compiler/util.cpp b/src/library/equations_compiler/util.cpp index 51ec498e33..6dbadded8f 100644 --- a/src/library/equations_compiler/util.cpp +++ b/src/library/equations_compiler/util.cpp @@ -13,6 +13,7 @@ Author: Leonardo de Moura #include "library/trace.h" #include "library/app_builder.h" #include "library/private.h" +#include "library/locals.h" #include "library/idx_metavar.h" #include "library/constants.h" #include "library/annotation.h" @@ -676,6 +677,37 @@ void for_each_compatible_constructor(type_context & ctx, expr const & var, } } +/* Given the telescope vars [x_1, ..., x_i, ..., x_n] and var := x_i, + and t is a term containing variables t_vars := {y_1, ..., y_k} disjoint from {x_1, ..., x_n}, + Return [x_1, ..., x_{i-1}, y_1, ..., y_k, T(x_{i+1}), ..., T(x_n)}, + where T(x_j) updates the type of x_j (j > i) by replacing x_i with t. + + \remark The set of variables in t is a subset of {x_1, ..., x_{i-1}} union {y_1, ..., y_k} +*/ +void update_telescope(type_context & ctx, buffer const & vars, expr const & var, + expr const & t, buffer const & t_vars, buffer & new_vars, + buffer & from, buffer & to) { + /* We are replacing `var` with `c` */ + for (expr const & curr : vars) { + if (curr == var) { + from.push_back(var); + to.push_back(t); + new_vars.append(t_vars); + } else { + expr curr_type = ctx.infer(curr); + expr new_curr_type = replace_locals(curr_type, from, to); + if (curr_type == new_curr_type) { + new_vars.push_back(curr); + } else { + expr new_curr = ctx.push_local(local_pp_name(curr), new_curr_type); + from.push_back(curr); + to.push_back(new_curr); + new_vars.push_back(new_curr); + } + } + } +} + void initialize_eqn_compiler_util() { register_trace_class("eqn_compiler"); register_trace_class(name{"debug", "eqn_compiler"}); diff --git a/src/library/equations_compiler/util.h b/src/library/equations_compiler/util.h index a6c48db3a3..fd46cafded 100644 --- a/src/library/equations_compiler/util.h +++ b/src/library/equations_compiler/util.h @@ -62,9 +62,10 @@ class unpack_eqn { public: unpack_eqn(type_context & ctx, expr const & eqn); expr add_var(name const & n, expr const & type); - buffer const & get_vars() { return m_vars; } + buffer & get_vars() { return m_vars; } expr & lhs() { return m_lhs; } expr & rhs() { return m_rhs; } + expr const & get_nested_src() const { return m_nested_src; } expr repack(); }; @@ -110,6 +111,22 @@ bool is_nat_int_char_string_value(type_context & ctx, expr const & e); void for_each_compatible_constructor(type_context & ctx, expr const & var, std::function &)> const & fn); +/* Given the telescope vars [x_1, ..., x_i, ..., x_n] and var := x_i, + and t is a term containing variables t_vars := {y_1, ..., y_k} disjoint from {x_1, ..., x_n}, + Return [x_1, ..., x_{i-1}, y_1, ..., y_k, T(x_{i+1}), ..., T(x_n)}, + where T(x_j) updates the type of x_j (j > i) by replacing x_i with t. + + \remark The set of variables in t is a subset of {x_1, ..., x_{i-1}} union {y_1, ..., y_k} + + The output parameters from/to contain the replacement + [x_i, ... x_n] => [t, T(x_{i+1}), ..., T(x_n)] + + The replacement will suppress entries x_j => T(x_j) if T(x_j) is equal to x_j. +*/ +void update_telescope(type_context & ctx, buffer const & vars, expr const & var, + expr const & t, buffer const & t_vars, buffer & new_vars, + buffer & from, buffer & to); + void initialize_eqn_compiler_util(); void finalize_eqn_compiler_util(); } diff --git a/tests/lean/run/complete_rec_var.lean b/tests/lean/run/complete_rec_var.lean new file mode 100644 index 0000000000..9235f83a40 --- /dev/null +++ b/tests/lean/run/complete_rec_var.lean @@ -0,0 +1,17 @@ +def f : nat → nat → nat +| (x+1) (y+1) := f (x+10) y +| _ _ := 1 + +vm_eval f 1 1000 + +example (x y) : f (x+1) (y+1) = f (x+10) y := +rfl + +example (y) : f 0 (y+1) = 1 := +rfl + +example (x) : f (x+1) 0 = 1 := +rfl + +example : f 0 0 = 1 := +rfl