diff --git a/src/frontends/lean/match_expr.cpp b/src/frontends/lean/match_expr.cpp index cfa1996a17..e0cc227d62 100644 --- a/src/frontends/lean/match_expr.cpp +++ b/src/frontends/lean/match_expr.cpp @@ -36,9 +36,9 @@ expr parse_match(parser & p, unsigned, expr const *, pos_info const & pos) { if (p.curr_is_token(get_colon_tk())) { p.next(); expr type = p.parse_expr(); - fn = mk_local(mk_fresh_name(), *g_match_name, type, binder_info()); + fn = mk_local(mk_fresh_name(), *g_match_name, type, mk_rec_info(true)); } else { - fn = mk_local(mk_fresh_name(), *g_match_name, mk_expr_placeholder(), binder_info()); + fn = mk_local(mk_fresh_name(), *g_match_name, mk_expr_placeholder(), mk_rec_info(true)); } p.check_token_next(get_with_tk(), "invalid 'match' expression, 'with' expected"); diff --git a/src/kernel/expr.h b/src/kernel/expr.h index 31c239b69e..3ae8717baf 100644 --- a/src/kernel/expr.h +++ b/src/kernel/expr.h @@ -216,7 +216,8 @@ class binder_info { inferred by class-instance resolution. */ unsigned m_inst_implicit:1; /** \brief Auxiliary internal attribute used to mark local constants representing recursive functions - in recursive equations */ + in recursive equations. TODO(Leo): rename to eqn_decl since we also mark non recursive equations + (e.g., `match ... with ... end`) with this flag. */ unsigned m_rec:1; public: binder_info(bool implicit = false, bool strict_implicit = false, bool inst_implicit = false, bool rec = false): diff --git a/src/library/equations_compiler/compiler.cpp b/src/library/equations_compiler/compiler.cpp index 1fad0d97d1..abf8132928 100644 --- a/src/library/equations_compiler/compiler.cpp +++ b/src/library/equations_compiler/compiler.cpp @@ -135,6 +135,32 @@ struct pull_nested_rec_fn : public replace_visitor { t = is_lam ? ctx.mk_lambda(locals, t) : ctx.mk_pi(locals, t); m_mctx = ctx.mctx(); m_lctx_stack.pop_back(); + /* We clear the cache whenever we visit a binder because of + collect_local_props. + + When pulling a recursive call (f t), the resulting term + may contain local variables that do not occur in (f t). + Thus, the cached value for (f t) may not be valid + in other contexts. + + By clearing the cache we conservatively fix this issue. + + Here is an example: + + def filter : list A → list A + | nil := nil + | (a :: l) := + match (H a) with + | (is_true h_1) := a :: filter l + | (is_false h_2) := filter l + end + + The first (filter l) is replaced with a term + (_f_1 l h_1) where _f_1 is a new fresh local. + We cannot replace the second (filter l) + with (_f_1 l h_1), since h_1 is not in the scope. + */ + m_cache.clear(); return t; } @@ -175,9 +201,31 @@ struct pull_nested_rec_fn : public replace_visitor { }); } + /* Collect local propositions. This is needed when the nested recursive call will + be defined by well-founded recursion, and we don't know whether local propositions + are hints for helping the "decreasing tactic". + In the future, we should add a mechanism that will only include these propositions + if the recursive call will be defined using well founded recursion. + */ + void collect_local_props(name_set & found, buffer & R) { + type_context ctx = mk_type_context(lctx()); + lctx().for_each([&](local_decl const & d) { + if (!base_lctx().find_local_decl(d.get_name()) && + !found.contains(d.get_name()) && + !d.get_info().is_rec() && + ctx.is_prop(d.get_type())) { + found.insert(d.get_name()); + R.push_back(d.mk_ref()); + } + }); + } + void collect_locals(expr const & e, buffer & R) { name_set found; + /* Collect used local declarations. */ collect_locals_core(e, found, R); + /* Collect local propositions. */ + collect_local_props(found, R); for (unsigned i = 0; i < R.size(); i++) { expr const & x = R[i]; collect_locals_core(lctx().get_local_decl(x).get_type(), found, R); diff --git a/tests/lean/run/bin_tree.lean b/tests/lean/run/bin_tree.lean new file mode 100644 index 0000000000..a3b75d664b --- /dev/null +++ b/tests/lean/run/bin_tree.lean @@ -0,0 +1,25 @@ +def pairs_with_sum' : Π (m n) {d}, m + n = d → list {p : ℕ × ℕ // p.1 + p.2 = d} +| 0 n d h := [⟨(0, n), h⟩] +| (m+1) n d h := ⟨(m+1, n), h⟩ :: pairs_with_sum' m (n+1) (by simp at h; simp [h]) + +def pairs_with_sum (n) : list {p : ℕ × ℕ // p.1 + p.2 = n} := +pairs_with_sum' n 0 rfl + +inductive bin_tree +| leaf : bin_tree +| branch : bin_tree → bin_tree → bin_tree +open bin_tree + +def size : bin_tree → ℕ +| leaf := 0 +| (branch l r) := size l + size r + 1 + +def trees_of_size : Π s, list {bt : bin_tree // size bt = s} +| 0 := [⟨leaf, rfl⟩] +| (n+1) := + do ⟨(s1, s2), (h : s1 + s2 = n)⟩ ← pairs_with_sum n, + ⟨t1, sz1⟩ ← have s1 < n+1, by apply nat.lt_succ_of_le; rw -h; apply nat.le_add_right, + trees_of_size s1, + ⟨t2, sz2⟩ ← have s2 < n+1, by apply nat.lt_succ_of_le; rw -h; apply nat.le_add_left, + trees_of_size s2, + return ⟨branch t1 t2, by rw [-h, -sz1, -sz2]; refl⟩ diff --git a/tests/lean/run/term_app2.lean b/tests/lean/run/term_app2.lean index 593813c84c..f7117159f6 100644 --- a/tests/lean/run/term_app2.lean +++ b/tests/lean/run/term_app2.lean @@ -54,3 +54,15 @@ def num_consts : term → nat #eval num_consts (term.app "f" [term.const "x", term.app "g" [term.const "x", term.const "y"]]) #check num_consts.equations._eqn_2 + +def num_consts' : term → nat +| (term.const n) := 1 +| (term.app n ts) := + ts.attach.foldl + (λ r ⟨t, h⟩, + have sizeof t < 1 + (sizeof n + sizeof ts), from + nat.lt_one_add_of_lt (nat.lt_add_of_lt (list.sizeof_lt_sizeof_of_mem h)), + r + num_consts' t) + 0 + +#check num_consts'.equations._eqn_2