fix(library/compiler/struct_cases_on): performance problem exposed by badupdate1.lean

This commit is contained in:
Leonardo de Moura 2019-04-26 16:30:19 -07:00
parent 240ca3fc68
commit e1a84d2f2c
3 changed files with 90 additions and 87 deletions

View file

@ -6,6 +6,7 @@ Author: Leonardo de Moura
*/
#include "runtime/flet.h"
#include "kernel/instantiate.h"
#include "kernel/abstract.h"
#include "kernel/type_checker.h"
#include "library/trace.h"
#include "library/suffixes.h"
@ -15,7 +16,9 @@ namespace lean {
class struct_cases_on_fn {
type_checker::state m_st;
local_ctx m_lctx;
name_set m_scrutinies;
name_set m_scrutinies; /* Set of variables `x` such that there is `casesOn x ...` in the context */
name_map<name> m_first_proj; /* Map from variable `x` to the first projection `y := x.i` in the context */
name_set m_updated; /* Set of variables `x` such that there is a `S.mk ... x.i ... */
name m_fld{"_d"};
unsigned m_next_idx{1};
@ -29,6 +32,20 @@ class struct_cases_on_fn {
return r;
}
expr find(expr const & e) const {
if (is_fvar(e)) {
if (optional<local_decl> decl = m_lctx.find_local_decl(e)) {
if (optional<expr> v = decl->get_value()) {
if (!is_join_point_name(decl->get_user_name()))
return find(*v);
}
}
} else if (is_mdata(e)) {
return find(mdata_expr(e));
}
return e;
}
expr visit_cases(expr const & e) {
flet<name_set> save(m_scrutinies, m_scrutinies);
buffer<expr> args;
@ -45,6 +62,18 @@ class struct_cases_on_fn {
expr visit_app(expr const & e) {
if (is_cases_on_app(env(), e)) {
return visit_cases(e);
} else if (is_constructor_app(env(), e)) {
buffer<expr> args;
expr const & k = get_app_args(e, args);
lean_assert(is_constant(k));
constructor_val k_val = env().get(const_name(k)).to_constructor_val();
for (unsigned i = k_val.get_nparams(), idx = 0; i < args.size(); i++, idx++) {
expr arg = find(args[i]);
if (is_proj(arg) && proj_idx(arg) == idx && is_fvar(proj_expr(arg))) {
m_updated.insert(fvar_name(proj_expr(arg)));
}
}
return e;
} else {
return e;
}
@ -63,50 +92,16 @@ class struct_cases_on_fn {
return m_lctx.mk_lambda(fvars, e);
}
bool is_candidate(expr const & rhs) {
if (!is_proj(rhs)) return false;
/* Return `some s` if `rhs` is of the form `s.i`, and `s` is a free variables that has not been
scrutinized yet, and `s.i` is the first time it is being projected. */
optional<name> is_candidate(expr const & rhs) {
if (!is_proj(rhs)) return optional<name>();
expr const & s = proj_expr(rhs);
if (!is_fvar(s)) return false;
if (m_scrutinies.contains(fvar_name(s))) return false;
return true;
}
/* Return true iff `e` is a constructor application of inductive type `S_name` and containg `x`. */
bool is_ctor_of(expr const & e, name const & S_name, expr const & x) {
if (!is_constructor_app(env(), e)) return false;
buffer<expr> args;
expr const & k = get_app_args(e, args);
lean_assert(is_constant(k));
constructor_val k_val = env().get(const_name(k)).to_constructor_val();
if (k_val.get_induct() != S_name) return false;
for (unsigned i = k_val.get_nparams(); i < args.size(); i++) {
if (args[i] == x)
return true;
}
return false;
}
/* Return true iff `e` contains a constructor application of inductive type `S_name` and containg `x`. */
bool has_ctor_with(expr e, name const & S_name, expr const & x) {
while (is_lambda(e)) {
e = binding_body(e);
}
while (is_let(e)) {
if (is_ctor_of(let_value(e), S_name, x))
return true;
e = let_body(e);
}
if (is_cases_on_app(env(), e)) {
buffer<expr> args;
get_app_args(e, args);
for (unsigned i = 1; i < args.size(); i++) {
if (has_ctor_with(args[i], S_name, x))
return true;
}
return false;
} else {
return is_ctor_of(e, S_name, x);
}
if (!is_fvar(s)) return optional<name>();
name const & s_name = fvar_name(s);
if (m_scrutinies.contains(s_name)) return optional<name>();
if (m_first_proj.contains(s_name)) return optional<name>();
return optional<name>(s_name);
}
static void get_struct_field_types(type_checker::state & st, name const & S_name, buffer<expr> & result) {
@ -144,7 +139,17 @@ class struct_cases_on_fn {
}
}
bool should_add_cases_on(local_decl const & decl) {
expr val = *decl.get_value();
if (!is_proj(val)) return false;
expr const & s = proj_expr(val);
if (!is_fvar(s) || !m_updated.contains(fvar_name(s))) return false;
name const * x = m_first_proj.find(fvar_name(s));
return x && *x == decl.get_name();
}
expr visit_let(expr e) {
flet<name_map<name>> save(m_first_proj, m_first_proj);
buffer<expr> fvars;
while (is_let(e)) {
lean_assert(!has_loose_bvars(let_type(e)));
@ -154,31 +159,37 @@ class struct_cases_on_fn {
e = let_body(e);
expr new_fvar = m_lctx.mk_local_decl(ngen(), n, type, val);
fvars.push_back(new_fvar);
if (is_candidate(val)) {
lean_assert(is_proj(val));
lean_assert(proj_idx(val).is_small());
name const & S_name = proj_sname(val);
e = instantiate_rev(e, fvars.size(), fvars.data());
if (has_ctor_with(e, S_name, new_fvar)) {
/* Introduce a casesOn application. */
e = m_lctx.mk_lambda(new_fvar, e);
fvars.pop_back();
expr major = proj_expr(val);
buffer<expr> field_types;
get_struct_field_types(m_st, S_name, field_types);
unsigned i = field_types.size();
while (i > 0) {
--i;
e = mk_lambda(next_field_name(), field_types[i], e);
}
e = mk_app(mk_constant(name(S_name, g_cases_on)), major, e);
}
e = visit(e);
return m_lctx.mk_lambda(fvars, e);
if (optional<name> s = is_candidate(val)) {
m_first_proj.insert(*s, fvar_name(new_fvar));
}
}
e = visit(instantiate_rev(e, fvars.size(), fvars.data()));
return m_lctx.mk_lambda(fvars, e);
e = abstract(e, fvars.size(), fvars.data());
unsigned i = fvars.size();
while (i > 0) {
--i;
expr const & x = fvars[i];
lean_assert(is_fvar(x));
local_decl decl = m_lctx.get_local_decl(x);
expr type = decl.get_type();
expr val = *decl.get_value();
expr aval = abstract(val, i, fvars.data());
e = mk_let(decl.get_user_name(), type, aval, e);
if (should_add_cases_on(decl)) {
lean_assert(is_proj(val));
expr major = proj_expr(val);
buffer<expr> field_types;
get_struct_field_types(m_st, proj_sname(val), field_types);
e = lift_loose_bvars(e, field_types.size());
unsigned i = field_types.size();
while (i > 0) {
--i;
e = mk_lambda(next_field_name(), field_types[i], e);
}
e = mk_app(mk_constant(name(proj_sname(val), g_cases_on)), major, e);
}
}
return e;
}
expr visit(expr const & e) {

View file

@ -1,21 +1,18 @@
structure S :=
(vals : Array Nat) (sz : Nat)
structure S (α : Type) :=
(vals : Array α) (sz : Nat)
@[noinline] def inc0 (a : Array Nat) : Array Nat :=
a.modify 0 (+1)
set_option pp.implicit true
-- set_option trace.compiler.boxed true
def f1 (s : S) : S :=
def f1 (s : S Nat) : S Nat :=
{ vals := inc0 s.vals, .. s}
def f2 : S → S
def f2 : S Nat → S Nat
| ⟨vals, sz⟩ := ⟨inc0 vals, sz⟩
def test (f : S → S) (n : Nat): IO Unit :=
let s : S := { vals := mkArray (n*100) n, sz := n*100 } in
let s := n.repeat f s in
def test (f : S Nat → S Nat) (n : Nat): IO Unit :=
let s : S Nat := { vals := mkArray (n*100) n, sz := n*100 } in
let s := n.repeat f s in
IO.println (s.vals.get 0)
def main (xs : List String) : IO Unit :=

View file

@ -64,27 +64,22 @@ d.errorMsg != none
d.stxStack.size
def ParserData.restore (d : ParserData) (iniStackSz : Nat) (iniPos : Nat) : ParserData :=
match d with
| ⟨stack, _, cache, _⟩ := ⟨stack.shrink iniStackSz, iniPos, cache, none⟩
{ stxStack := d.stxStack.shrink iniStackSz, errorMsg := none, pos := iniPos, .. d}
def ParserData.setPos (d : ParserData) (pos : Nat) : ParserData :=
match d with
| ⟨stack, _, cache, msg⟩ := ⟨stack, pos, cache, msg⟩
{ pos := pos, .. d }
def ParserData.setCache (d : ParserData) (cache : ParserCache) : ParserData :=
match d with
| ⟨stack, pos, _, msg⟩ := ⟨stack, pos, cache, msg⟩
{ cache := cache, .. d }
def ParserData.pushSyntax (d : ParserData) (n : Syntax) : ParserData :=
match d with
| ⟨stack, pos, cache, msg⟩ := ⟨stack.push n, pos, cache, msg⟩
{ stxStack := d.stxStack.push n, .. d }
def ParserData.shrinkStack (d : ParserData) (iniStackSz : Nat) : ParserData :=
match d with
| ⟨stack, pos, cache, msg⟩ := ⟨stack.shrink iniStackSz, pos, cache, msg⟩
{ stxStack := d.stxStack.shrink iniStackSz, .. d }
def ParserData.next (d : ParserData) (s : String) (pos : Nat) : ParserData :=
d.setPos (s.next pos)
{ pos := s.next pos, .. d }
def ParserData.toErrorMsg (d : ParserData) (cfg : ParserConfig) : String :=
match d.errorMsg with