fix(library/compiler/simp_inductive): erase trivial structure bug
This commit is contained in:
parent
30cfcc0fa6
commit
96e02613fc
2 changed files with 149 additions and 35 deletions
|
|
@ -82,18 +82,9 @@ unsigned get_vm_supported_cases_num_minors(environment const & env, expr const &
|
|||
}
|
||||
}
|
||||
|
||||
class simp_inductive_fn : public compiler_step_visitor {
|
||||
class simp_inductive_core_fn : public compiler_step_visitor {
|
||||
name_map<list<bool>> m_constructor_info;
|
||||
|
||||
void get_constructor_info(name const & n, buffer<bool> & rel_fields) {
|
||||
if (auto r = m_constructor_info.find(n)) {
|
||||
to_buffer(*r, rel_fields);
|
||||
} else {
|
||||
get_constructor_relevant_fields(env(), n, rel_fields);
|
||||
m_constructor_info.insert(n, to_list(rel_fields));
|
||||
}
|
||||
}
|
||||
|
||||
protected:
|
||||
/* Return new minor premise and a flag indicating whether the body is unreachable or not */
|
||||
pair<expr, bool> visit_minor_premise(expr e, buffer<bool> const & rel_fields) {
|
||||
type_context::tmp_locals locals(ctx());
|
||||
|
|
@ -111,6 +102,27 @@ class simp_inductive_fn : public compiler_step_visitor {
|
|||
return mk_pair(locals.mk_lambda(e), unreachable);
|
||||
}
|
||||
|
||||
void get_constructor_info(name const & n, buffer<bool> & rel_fields) {
|
||||
if (auto r = m_constructor_info.find(n)) {
|
||||
to_buffer(*r, rel_fields);
|
||||
} else {
|
||||
get_constructor_relevant_fields(env(), n, rel_fields);
|
||||
m_constructor_info.insert(n, to_list(rel_fields));
|
||||
}
|
||||
}
|
||||
public:
|
||||
simp_inductive_core_fn(environment const & env):compiler_step_visitor(env) {}
|
||||
};
|
||||
|
||||
/*
|
||||
Remove constructor/projection/cases_on applications of trivial structures.
|
||||
|
||||
We say a structure is trivial if it has only constructor and
|
||||
the constructor has only one relevant field.
|
||||
In this case, we use a simple optimization where we represent elements of this inductive
|
||||
datatype as the only relevant element.
|
||||
*/
|
||||
class erase_trivial_structures_fn : public simp_inductive_core_fn {
|
||||
bool has_only_one_constructor(name const & I_name) const {
|
||||
if (auto r = inductive::get_num_intro_rules(env(), I_name))
|
||||
return *r == 1;
|
||||
|
|
@ -120,10 +132,7 @@ class simp_inductive_fn : public compiler_step_visitor {
|
|||
|
||||
/* Return true iff inductive datatype I_name has only one constructor,
|
||||
and this constructor has only one relevant field.
|
||||
The argument rel_fields is a bit-vector of relevant fields.
|
||||
|
||||
In this case, we use a simple optimization where we represent elements of this inductive
|
||||
datatype as the only relevant element. */
|
||||
The argument rel_fields is a bit-vector of relevant fields. */
|
||||
bool has_trivial_structure(name const & I_name, buffer<bool> const & rel_fields) const {
|
||||
if (!has_only_one_constructor(I_name))
|
||||
return false;
|
||||
|
|
@ -137,6 +146,98 @@ class simp_inductive_fn : public compiler_step_visitor {
|
|||
return num_rel == 1;
|
||||
}
|
||||
|
||||
expr visit_default(name const & fn, buffer<expr> const & args) {
|
||||
buffer<expr> new_args;
|
||||
for (expr const & arg : args)
|
||||
new_args.push_back(visit(arg));
|
||||
return mk_app(mk_constant(fn), new_args);
|
||||
}
|
||||
|
||||
expr visit_constructor(name const & fn, buffer<expr> const & args) {
|
||||
if (is_vm_builtin_function(fn))
|
||||
return visit_default(fn, args);
|
||||
|
||||
name I_name = *inductive::is_intro_rule(env(), fn);
|
||||
buffer<bool> rel_fields;
|
||||
get_constructor_info(fn, rel_fields);
|
||||
if (has_trivial_structure(I_name, rel_fields)) {
|
||||
unsigned nparams = *inductive::get_num_params(env(), I_name);
|
||||
for (unsigned i = 0; i < rel_fields.size(); i++) {
|
||||
if (rel_fields[i]) {
|
||||
return visit(args[nparams + i]);
|
||||
}
|
||||
}
|
||||
lean_unreachable();
|
||||
} else {
|
||||
return visit_default(fn, args);
|
||||
}
|
||||
}
|
||||
|
||||
expr visit_projection(name const & fn, buffer<expr> const & args) {
|
||||
if (is_vm_builtin_function(fn))
|
||||
return visit_default(fn, args);
|
||||
|
||||
projection_info const & info = *get_projection_info(env(), fn);
|
||||
name I_name = *inductive::is_intro_rule(env(), info.m_constructor);
|
||||
buffer<bool> rel_fields;
|
||||
get_constructor_info(info.m_constructor, rel_fields);
|
||||
if (has_trivial_structure(I_name, rel_fields)) {
|
||||
expr major = args[info.m_nparams];
|
||||
expr r = visit(major);
|
||||
/* Add additional arguments */
|
||||
for (unsigned i = info.m_nparams + 1; i < args.size(); i++)
|
||||
r = mk_app(r, visit(args[i]));
|
||||
return r;
|
||||
} else {
|
||||
return visit_default(fn, args);
|
||||
}
|
||||
}
|
||||
|
||||
expr visit_cases_on(name const & fn, buffer<expr> & args) {
|
||||
if (is_vm_builtin_function(fn))
|
||||
return visit_default(fn, args);
|
||||
|
||||
name const & I_name = fn.get_prefix();
|
||||
buffer<name> cnames;
|
||||
get_intro_rule_names(env(), I_name, cnames);
|
||||
if (cnames.size() != 1)
|
||||
return visit_default(fn, args);
|
||||
|
||||
buffer<bool> rel_fields;
|
||||
get_constructor_info(cnames[0], rel_fields);
|
||||
|
||||
if (has_trivial_structure(I_name, rel_fields)) {
|
||||
lean_assert(args.size() >= 2);
|
||||
expr major = visit(args[0]);
|
||||
expr minor = visit_minor_premise(args[1], rel_fields).first;
|
||||
for (unsigned i = 2; i < args.size(); i++)
|
||||
args[i] = visit(args[i]);
|
||||
return beta_reduce(mk_app(mk_app(minor, major), args.size() - 2, args.data() + 2));
|
||||
} else {
|
||||
return visit_default(fn, args);
|
||||
}
|
||||
}
|
||||
|
||||
virtual expr visit_app(expr const & e) override {
|
||||
buffer<expr> args;
|
||||
expr const & fn = get_app_args(e, args);
|
||||
if (is_constant(fn)) {
|
||||
name const & n = const_name(fn);
|
||||
if (is_cases_on_recursor(env(), n)) {
|
||||
return visit_cases_on(n, args);
|
||||
} else if (inductive::is_intro_rule(env(), n)) {
|
||||
return visit_constructor(n, args);
|
||||
} else if (is_projection(env(), n)) {
|
||||
return visit_projection(n, args);
|
||||
}
|
||||
}
|
||||
return compiler_step_visitor::visit_app(e);
|
||||
}
|
||||
public:
|
||||
erase_trivial_structures_fn(environment const & env):simp_inductive_core_fn(env) {}
|
||||
};
|
||||
|
||||
class simp_inductive_fn : public simp_inductive_core_fn {
|
||||
/* Given a cases_on application, distribute extra arguments over minor premisses.
|
||||
|
||||
cases_on major minor_1 ... minor_n a_1 ... a_n
|
||||
|
|
@ -187,10 +288,6 @@ class simp_inductive_fn : public compiler_step_visitor {
|
|||
get_constructor_info(cnames[i], rel_fields);
|
||||
auto p = visit_minor_premise(args[i+1], rel_fields);
|
||||
expr new_minor = p.first;
|
||||
if (i == 0 && !is_builtin && has_trivial_structure(I_name, rel_fields)) {
|
||||
/* Optimization for an inductive datatype that has a single constructor with only one relevant field */
|
||||
return beta_reduce(mk_app(new_minor, args[0]));
|
||||
}
|
||||
args[i+1] = new_minor;
|
||||
if (!p.second) {
|
||||
num_reachable++;
|
||||
|
|
@ -245,12 +342,7 @@ class simp_inductive_fn : public compiler_step_visitor {
|
|||
new_args.push_back(visit(args[nparams + i]));
|
||||
}
|
||||
}
|
||||
if (has_trivial_structure(I_name, rel_fields)) {
|
||||
lean_assert(new_args.size() == 1);
|
||||
return new_args[0];
|
||||
} else {
|
||||
return mk_app(mk_cnstr(cidx), new_args);
|
||||
}
|
||||
return mk_app(mk_cnstr(cidx), new_args);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -272,12 +364,7 @@ class simp_inductive_fn : public compiler_step_visitor {
|
|||
j++;
|
||||
}
|
||||
expr r;
|
||||
if (has_trivial_structure(I_name, rel_fields)) {
|
||||
lean_assert(j == 0);
|
||||
r = major;
|
||||
} else {
|
||||
r = mk_app(mk_proj(j), major);
|
||||
}
|
||||
r = mk_app(mk_proj(j), major);
|
||||
/* Add additional arguments */
|
||||
for (unsigned i = info.m_nparams + 1; i < args.size(); i++)
|
||||
r = mk_app(r, visit(args[i]));
|
||||
|
|
@ -313,17 +400,36 @@ class simp_inductive_fn : public compiler_step_visitor {
|
|||
}
|
||||
|
||||
public:
|
||||
simp_inductive_fn(environment const & env):compiler_step_visitor(env) {}
|
||||
simp_inductive_fn(environment const & env):simp_inductive_core_fn(env) {}
|
||||
};
|
||||
|
||||
/*
|
||||
Remark: we used to combine erase_trivial_structures_fn and simp_inductive_fn in
|
||||
a single pass. This is bad because the result may contain `cases` applications
|
||||
where the number of arguments is not equal to the number of case + 1 (major).
|
||||
The issue is that erase_trivial_structures_fn step may produce new opportunites
|
||||
for the distribute-arguments-over-minor-premises transformation.
|
||||
|
||||
Here is an small example that exposes the problem:
|
||||
```
|
||||
structure box (α : Type) :=
|
||||
(val : α)
|
||||
|
||||
def f (g h : box (ℕ → ℕ)) (b : bool) : ℕ → ℕ :=
|
||||
box.val (bool.cases_on b g h)
|
||||
```
|
||||
*/
|
||||
|
||||
expr simp_inductive(environment const & env, expr const & e) {
|
||||
return simp_inductive_fn(env)(e);
|
||||
expr e1 = erase_trivial_structures_fn(env)(e);
|
||||
return simp_inductive_fn(env)(e1);
|
||||
}
|
||||
|
||||
void simp_inductive(environment const & env, buffer<procedure> & procs) {
|
||||
simp_inductive_fn fn(env);
|
||||
erase_trivial_structures_fn eraser(env);
|
||||
simp_inductive_fn simplifier(env);
|
||||
for (auto & proc : procs)
|
||||
proc.m_code = fn(proc.m_code);
|
||||
proc.m_code = simplifier(eraser(proc.m_code));
|
||||
}
|
||||
|
||||
void initialize_simp_inductive() {
|
||||
|
|
|
|||
8
tests/lean/run/simp_inductive_compiler_issue.lean
Normal file
8
tests/lean/run/simp_inductive_compiler_issue.lean
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
structure box (α : Type) :=
|
||||
(val : α)
|
||||
|
||||
def f1 (g h : ℕ → ℕ) (b : bool) : ℕ → ℕ :=
|
||||
box.val (bool.cases_on b (box.mk g) (box.mk h))
|
||||
|
||||
def f2 (g h : box (ℕ → ℕ)) (b : bool) : ℕ → ℕ :=
|
||||
box.val (bool.cases_on b g h)
|
||||
Loading…
Add table
Reference in a new issue