fix(library/equations_compiler): improve pull_nested_rec_fn, and make sure it communicates local propositions to the well founded recursion module

The bin_tree and num_consts examples can now be encoded more naturally.
This commit is contained in:
Leonardo de Moura 2017-05-26 10:45:39 -07:00
parent 0cd5feed6e
commit 4bdb2da1b6
5 changed files with 89 additions and 3 deletions

View file

@ -36,9 +36,9 @@ expr parse_match(parser & p, unsigned, expr const *, pos_info const & pos) {
if (p.curr_is_token(get_colon_tk())) {
p.next();
expr type = p.parse_expr();
fn = mk_local(mk_fresh_name(), *g_match_name, type, binder_info());
fn = mk_local(mk_fresh_name(), *g_match_name, type, mk_rec_info(true));
} else {
fn = mk_local(mk_fresh_name(), *g_match_name, mk_expr_placeholder(), binder_info());
fn = mk_local(mk_fresh_name(), *g_match_name, mk_expr_placeholder(), mk_rec_info(true));
}
p.check_token_next(get_with_tk(), "invalid 'match' expression, 'with' expected");

View file

@ -216,7 +216,8 @@ class binder_info {
inferred by class-instance resolution. */
unsigned m_inst_implicit:1;
/** \brief Auxiliary internal attribute used to mark local constants representing recursive functions
in recursive equations */
in recursive equations. TODO(Leo): rename to eqn_decl since we also mark non recursive equations
(e.g., `match ... with ... end`) with this flag. */
unsigned m_rec:1;
public:
binder_info(bool implicit = false, bool strict_implicit = false, bool inst_implicit = false, bool rec = false):

View file

@ -135,6 +135,32 @@ struct pull_nested_rec_fn : public replace_visitor {
t = is_lam ? ctx.mk_lambda(locals, t) : ctx.mk_pi(locals, t);
m_mctx = ctx.mctx();
m_lctx_stack.pop_back();
/* We clear the cache whenever we visit a binder because of
collect_local_props.
When pulling a recursive call (f t), the resulting term
may contain local variables that do not occur in (f t).
Thus, the cached value for (f t) may not be valid
in other contexts.
By clearing the cache we conservatively fix this issue.
Here is an example:
def filter : list A list A
| nil := nil
| (a :: l) :=
match (H a) with
| (is_true h_1) := a :: filter l
| (is_false h_2) := filter l
end
The first (filter l) is replaced with a term
(_f_1 l h_1) where _f_1 is a new fresh local.
We cannot replace the second (filter l)
with (_f_1 l h_1), since h_1 is not in the scope.
*/
m_cache.clear();
return t;
}
@ -175,9 +201,31 @@ struct pull_nested_rec_fn : public replace_visitor {
});
}
/* Collect local propositions. This is needed when the nested recursive call will
be defined by well-founded recursion, and we don't know whether local propositions
are hints for helping the "decreasing tactic".
In the future, we should add a mechanism that will only include these propositions
if the recursive call will be defined using well founded recursion.
*/
void collect_local_props(name_set & found, buffer<expr> & R) {
type_context ctx = mk_type_context(lctx());
lctx().for_each([&](local_decl const & d) {
if (!base_lctx().find_local_decl(d.get_name()) &&
!found.contains(d.get_name()) &&
!d.get_info().is_rec() &&
ctx.is_prop(d.get_type())) {
found.insert(d.get_name());
R.push_back(d.mk_ref());
}
});
}
void collect_locals(expr const & e, buffer<expr> & R) {
name_set found;
/* Collect used local declarations. */
collect_locals_core(e, found, R);
/* Collect local propositions. */
collect_local_props(found, R);
for (unsigned i = 0; i < R.size(); i++) {
expr const & x = R[i];
collect_locals_core(lctx().get_local_decl(x).get_type(), found, R);

View file

@ -0,0 +1,25 @@
def pairs_with_sum' : Π (m n) {d}, m + n = d → list {p : × // p.1 + p.2 = d}
| 0 n d h := [⟨(0, n), h⟩]
| (m+1) n d h := ⟨(m+1, n), h⟩ :: pairs_with_sum' m (n+1) (by simp at h; simp [h])
def pairs_with_sum (n) : list {p : × // p.1 + p.2 = n} :=
pairs_with_sum' n 0 rfl
inductive bin_tree
| leaf : bin_tree
| branch : bin_tree → bin_tree → bin_tree
open bin_tree
def size : bin_tree →
| leaf := 0
| (branch l r) := size l + size r + 1
def trees_of_size : Π s, list {bt : bin_tree // size bt = s}
| 0 := [⟨leaf, rfl⟩]
| (n+1) :=
do ⟨(s1, s2), (h : s1 + s2 = n)⟩ ← pairs_with_sum n,
⟨t1, sz1⟩ ← have s1 < n+1, by apply nat.lt_succ_of_le; rw -h; apply nat.le_add_right,
trees_of_size s1,
⟨t2, sz2⟩ ← have s2 < n+1, by apply nat.lt_succ_of_le; rw -h; apply nat.le_add_left,
trees_of_size s2,
return ⟨branch t1 t2, by rw [-h, -sz1, -sz2]; refl⟩

View file

@ -54,3 +54,15 @@ def num_consts : term → nat
#eval num_consts (term.app "f" [term.const "x", term.app "g" [term.const "x", term.const "y"]])
#check num_consts.equations._eqn_2
def num_consts' : term → nat
| (term.const n) := 1
| (term.app n ts) :=
ts.attach.foldl
(λ r ⟨t, h⟩,
have sizeof t < 1 + (sizeof n + sizeof ts), from
nat.lt_one_add_of_lt (nat.lt_add_of_lt (list.sizeof_lt_sizeof_of_mem h)),
r + num_consts' t)
0
#check num_consts'.equations._eqn_2