perf(library/compiler/cse): make sure we eliminate common sub expressions in match-cases associated with 0-ary constructors

The new test exposes the problem. Before this commit, the common
subexpressions at

```
def tst : tree → nat
| (tree.leaf v) := v
| (tree.node v l r) :=
  match f v with
  | tt := tst l + tst l - tst l   -- <<< HERE
  | ff := tst r
  end
```

were not converted into a let-exprs.
This commit is contained in:
Leonardo de Moura 2017-11-09 13:32:15 -08:00
parent 0bbe51615e
commit f2ef24696d
2 changed files with 64 additions and 0 deletions

View file

@ -149,6 +149,11 @@ class cse_fn : public compiler_step_visitor {
}
}
void collect_common_subexprs(expr const & e, expr_struct_set & r) {
buffer<expr> tmp;
collect_common_subexprs(tmp, e, r);
}
/* Helper functor for converting common subexpressions into fresh let-decls */
struct cse_processor {
unsigned & m_counter;
@ -264,6 +269,38 @@ class cse_fn : public compiler_step_visitor {
return visit_lambda_let(e);
}
expr visit_cases_on(expr const & e) {
buffer<expr> args;
expr const & fn = get_app_args(e, args);
args[0] = visit(args[0]); // major premise
for (unsigned i = 1; i < args.size(); i++) {
expr m = args[i];
if (is_lambda(m)) {
args[i] = visit(m);
} else {
m = visit(m);
expr_struct_set common_subexprs;
collect_common_subexprs(m, common_subexprs);
if (!common_subexprs.empty()) {
cse_processor proc(m_counter, m_ctx, common_subexprs);
m = proc.process(m);
m = copy_tag(args[i], proc.m_all_locals.mk_lambda(m));
}
args[i] = m;
}
}
return mk_app(fn, args);
}
virtual expr visit_app(expr const & e) override {
expr const & fn = get_app_fn(e);
if (is_vm_supported_cases(m_env, fn)) {
return visit_cases_on(e);
} else {
return compiler_step_visitor::visit_app(e);
}
}
public:
cse_fn(environment const & env):compiler_step_visitor(env) {}
};

View file

@ -0,0 +1,27 @@
def f : nat → bool
| 0 := ff
| _ := tt
inductive tree
| leaf : nat → tree
| node : nat → tree → tree → tree
def mk_tree : nat → nat → tree
| 0 v := tree.leaf v
| (n+1) v :=
let t := mk_tree n v in
tree.node v t t
def tst : tree → nat
| (tree.leaf v) := v
| (tree.node v l r) :=
match f v with
| tt := tst l + tst l - tst l
| ff := tst r
end
def tree.is_node : tree → bool
| (tree.leaf v) := ff
| _ := tt
#eval timeit "tst" $ tst (mk_tree 100 10)