diff --git a/src/library/compiler/cse.cpp b/src/library/compiler/cse.cpp index d603cd9a8e..729f98febc 100644 --- a/src/library/compiler/cse.cpp +++ b/src/library/compiler/cse.cpp @@ -149,6 +149,11 @@ class cse_fn : public compiler_step_visitor { } } + void collect_common_subexprs(expr const & e, expr_struct_set & r) { + buffer tmp; + collect_common_subexprs(tmp, e, r); + } + /* Helper functor for converting common subexpressions into fresh let-decls */ struct cse_processor { unsigned & m_counter; @@ -264,6 +269,38 @@ class cse_fn : public compiler_step_visitor { return visit_lambda_let(e); } + expr visit_cases_on(expr const & e) { + buffer args; + expr const & fn = get_app_args(e, args); + args[0] = visit(args[0]); // major premise + for (unsigned i = 1; i < args.size(); i++) { + expr m = args[i]; + if (is_lambda(m)) { + args[i] = visit(m); + } else { + m = visit(m); + expr_struct_set common_subexprs; + collect_common_subexprs(m, common_subexprs); + if (!common_subexprs.empty()) { + cse_processor proc(m_counter, m_ctx, common_subexprs); + m = proc.process(m); + m = copy_tag(args[i], proc.m_all_locals.mk_lambda(m)); + } + args[i] = m; + } + } + return mk_app(fn, args); + } + + virtual expr visit_app(expr const & e) override { + expr const & fn = get_app_fn(e); + if (is_vm_supported_cases(m_env, fn)) { + return visit_cases_on(e); + } else { + return compiler_step_visitor::visit_app(e); + } + } + public: cse_fn(environment const & env):compiler_step_visitor(env) {} }; diff --git a/tests/lean/run/cse_perf_issue.lean b/tests/lean/run/cse_perf_issue.lean new file mode 100644 index 0000000000..b2001985dd --- /dev/null +++ b/tests/lean/run/cse_perf_issue.lean @@ -0,0 +1,27 @@ +def f : nat → bool +| 0 := ff +| _ := tt + +inductive tree +| leaf : nat → tree +| node : nat → tree → tree → tree + +def mk_tree : nat → nat → tree +| 0 v := tree.leaf v +| (n+1) v := + let t := mk_tree n v in + tree.node v t t + +def tst : tree → nat +| (tree.leaf v) := v +| (tree.node v l r) := + match f v with + | tt := tst l + tst l - tst l + | ff := tst r + end + +def tree.is_node : tree → bool +| (tree.leaf v) := ff +| _ := tt + +#eval timeit "tst" $ tst (mk_tree 100 10)