From 96e02613fcdda8449c1ed8fcdf8e58e2c4140731 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 11 Feb 2018 11:35:31 -0800 Subject: [PATCH] fix(library/compiler/simp_inductive): erase trivial structure bug --- src/library/compiler/simp_inductive.cpp | 176 ++++++++++++++---- .../run/simp_inductive_compiler_issue.lean | 8 + 2 files changed, 149 insertions(+), 35 deletions(-) create mode 100644 tests/lean/run/simp_inductive_compiler_issue.lean diff --git a/src/library/compiler/simp_inductive.cpp b/src/library/compiler/simp_inductive.cpp index fcfcf44847..0b26420435 100644 --- a/src/library/compiler/simp_inductive.cpp +++ b/src/library/compiler/simp_inductive.cpp @@ -82,18 +82,9 @@ unsigned get_vm_supported_cases_num_minors(environment const & env, expr const & } } -class simp_inductive_fn : public compiler_step_visitor { +class simp_inductive_core_fn : public compiler_step_visitor { name_map> m_constructor_info; - - void get_constructor_info(name const & n, buffer & rel_fields) { - if (auto r = m_constructor_info.find(n)) { - to_buffer(*r, rel_fields); - } else { - get_constructor_relevant_fields(env(), n, rel_fields); - m_constructor_info.insert(n, to_list(rel_fields)); - } - } - +protected: /* Return new minor premise and a flag indicating whether the body is unreachable or not */ pair visit_minor_premise(expr e, buffer const & rel_fields) { type_context::tmp_locals locals(ctx()); @@ -111,6 +102,27 @@ class simp_inductive_fn : public compiler_step_visitor { return mk_pair(locals.mk_lambda(e), unreachable); } + void get_constructor_info(name const & n, buffer & rel_fields) { + if (auto r = m_constructor_info.find(n)) { + to_buffer(*r, rel_fields); + } else { + get_constructor_relevant_fields(env(), n, rel_fields); + m_constructor_info.insert(n, to_list(rel_fields)); + } + } +public: + simp_inductive_core_fn(environment const & env):compiler_step_visitor(env) {} +}; + +/* + Remove constructor/projection/cases_on applications of trivial structures. 
+ + We say a structure is trivial if it has only one constructor and + the constructor has only one relevant field. + In this case, we use a simple optimization where we represent elements of this inductive + datatype as the only relevant element. +*/ +class erase_trivial_structures_fn : public simp_inductive_core_fn { bool has_only_one_constructor(name const & I_name) const { if (auto r = inductive::get_num_intro_rules(env(), I_name)) return *r == 1; @@ -120,10 +132,7 @@ class simp_inductive_fn : public compiler_step_visitor { /* Return true iff inductive datatype I_name has only one constructor, and this constructor has only one relevant field. - The argument rel_fields is a bit-vector of relevant fields. - - In this case, we use a simple optimization where we represent elements of this inductive - datatype as the only relevant element. */ + The argument rel_fields is a bit-vector of relevant fields. */ bool has_trivial_structure(name const & I_name, buffer const & rel_fields) const { if (!has_only_one_constructor(I_name)) return false; @@ -137,6 +146,98 @@ class simp_inductive_fn : public compiler_step_visitor { return num_rel == 1; } + expr visit_default(name const & fn, buffer const & args) { + buffer new_args; + for (expr const & arg : args) + new_args.push_back(visit(arg)); + return mk_app(mk_constant(fn), new_args); + } + + expr visit_constructor(name const & fn, buffer const & args) { + if (is_vm_builtin_function(fn)) + return visit_default(fn, args); + + name I_name = *inductive::is_intro_rule(env(), fn); + buffer rel_fields; + get_constructor_info(fn, rel_fields); + if (has_trivial_structure(I_name, rel_fields)) { + unsigned nparams = *inductive::get_num_params(env(), I_name); + for (unsigned i = 0; i < rel_fields.size(); i++) { + if (rel_fields[i]) { + return visit(args[nparams + i]); + } + } + lean_unreachable(); + } else { + return visit_default(fn, args); + } + } + + expr visit_projection(name const & fn, buffer const & args) { + if 
(is_vm_builtin_function(fn)) + return visit_default(fn, args); + + projection_info const & info = *get_projection_info(env(), fn); + name I_name = *inductive::is_intro_rule(env(), info.m_constructor); + buffer rel_fields; + get_constructor_info(info.m_constructor, rel_fields); + if (has_trivial_structure(I_name, rel_fields)) { + expr major = args[info.m_nparams]; + expr r = visit(major); + /* Add additional arguments */ + for (unsigned i = info.m_nparams + 1; i < args.size(); i++) + r = mk_app(r, visit(args[i])); + return r; + } else { + return visit_default(fn, args); + } + } + + expr visit_cases_on(name const & fn, buffer & args) { + if (is_vm_builtin_function(fn)) + return visit_default(fn, args); + + name const & I_name = fn.get_prefix(); + buffer cnames; + get_intro_rule_names(env(), I_name, cnames); + if (cnames.size() != 1) + return visit_default(fn, args); + + buffer rel_fields; + get_constructor_info(cnames[0], rel_fields); + + if (has_trivial_structure(I_name, rel_fields)) { + lean_assert(args.size() >= 2); + expr major = visit(args[0]); + expr minor = visit_minor_premise(args[1], rel_fields).first; + for (unsigned i = 2; i < args.size(); i++) + args[i] = visit(args[i]); + return beta_reduce(mk_app(mk_app(minor, major), args.size() - 2, args.data() + 2)); + } else { + return visit_default(fn, args); + } + } + + virtual expr visit_app(expr const & e) override { + buffer args; + expr const & fn = get_app_args(e, args); + if (is_constant(fn)) { + name const & n = const_name(fn); + if (is_cases_on_recursor(env(), n)) { + return visit_cases_on(n, args); + } else if (inductive::is_intro_rule(env(), n)) { + return visit_constructor(n, args); + } else if (is_projection(env(), n)) { + return visit_projection(n, args); + } + } + return compiler_step_visitor::visit_app(e); + } +public: + erase_trivial_structures_fn(environment const & env):simp_inductive_core_fn(env) {} +}; + +class simp_inductive_fn : public simp_inductive_core_fn { /* Given a cases_on application, 
distribute extra arguments over minor premisses. cases_on major minor_1 ... minor_n a_1 ... a_n @@ -187,10 +288,6 @@ class simp_inductive_fn : public compiler_step_visitor { get_constructor_info(cnames[i], rel_fields); auto p = visit_minor_premise(args[i+1], rel_fields); expr new_minor = p.first; - if (i == 0 && !is_builtin && has_trivial_structure(I_name, rel_fields)) { - /* Optimization for an inductive datatype that has a single constructor with only one relevant field */ - return beta_reduce(mk_app(new_minor, args[0])); - } args[i+1] = new_minor; if (!p.second) { num_reachable++; @@ -245,12 +342,7 @@ class simp_inductive_fn : public compiler_step_visitor { new_args.push_back(visit(args[nparams + i])); } } - if (has_trivial_structure(I_name, rel_fields)) { - lean_assert(new_args.size() == 1); - return new_args[0]; - } else { - return mk_app(mk_cnstr(cidx), new_args); - } + return mk_app(mk_cnstr(cidx), new_args); } } @@ -272,12 +364,7 @@ class simp_inductive_fn : public compiler_step_visitor { j++; } expr r; - if (has_trivial_structure(I_name, rel_fields)) { - lean_assert(j == 0); - r = major; - } else { - r = mk_app(mk_proj(j), major); - } + r = mk_app(mk_proj(j), major); /* Add additional arguments */ for (unsigned i = info.m_nparams + 1; i < args.size(); i++) r = mk_app(r, visit(args[i])); @@ -313,17 +400,36 @@ class simp_inductive_fn : public compiler_step_visitor { } public: - simp_inductive_fn(environment const & env):compiler_step_visitor(env) {} + simp_inductive_fn(environment const & env):simp_inductive_core_fn(env) {} }; +/* + Remark: we used to combine erase_trivial_structures_fn and simp_inductive_fn in + a single pass. This is bad because the result may contain `cases` applications + where the number of arguments is not equal to the number of cases + 1 (major). + The issue is that the erase_trivial_structures_fn step may produce new opportunities + for the distribute-arguments-over-minor-premises transformation. 
+ + Here is a small example that exposes the problem: + ``` + structure box (α : Type) := + (val : α) + + def f (g h : box (ℕ → ℕ)) (b : bool) : ℕ → ℕ := + box.val (bool.cases_on b g h) + ``` +*/ + expr simp_inductive(environment const & env, expr const & e) { - return simp_inductive_fn(env)(e); + expr e1 = erase_trivial_structures_fn(env)(e); + return simp_inductive_fn(env)(e1); } void simp_inductive(environment const & env, buffer & procs) { - simp_inductive_fn fn(env); + erase_trivial_structures_fn eraser(env); + simp_inductive_fn simplifier(env); for (auto & proc : procs) - proc.m_code = fn(proc.m_code); + proc.m_code = simplifier(eraser(proc.m_code)); } void initialize_simp_inductive() { diff --git a/tests/lean/run/simp_inductive_compiler_issue.lean b/tests/lean/run/simp_inductive_compiler_issue.lean new file mode 100644 index 0000000000..5e40fc5b26 --- /dev/null +++ b/tests/lean/run/simp_inductive_compiler_issue.lean @@ -0,0 +1,8 @@ +structure box (α : Type) := +(val : α) + +def f1 (g h : ℕ → ℕ) (b : bool) : ℕ → ℕ := +box.val (bool.cases_on b (box.mk g) (box.mk h)) + +def f2 (g h : box (ℕ → ℕ)) (b : bool) : ℕ → ℕ := +box.val (bool.cases_on b g h)