fix(library/compiler/struct_cases_on): performance problem exposed by badupdate1.lean
This commit is contained in:
parent
240ca3fc68
commit
e1a84d2f2c
3 changed files with 90 additions and 87 deletions
|
|
@ -6,6 +6,7 @@ Author: Leonardo de Moura
|
|||
*/
|
||||
#include "runtime/flet.h"
|
||||
#include "kernel/instantiate.h"
|
||||
#include "kernel/abstract.h"
|
||||
#include "kernel/type_checker.h"
|
||||
#include "library/trace.h"
|
||||
#include "library/suffixes.h"
|
||||
|
|
@ -15,7 +16,9 @@ namespace lean {
|
|||
class struct_cases_on_fn {
|
||||
type_checker::state m_st;
|
||||
local_ctx m_lctx;
|
||||
name_set m_scrutinies;
|
||||
name_set m_scrutinies; /* Set of variables `x` such that there is `casesOn x ...` in the context */
|
||||
name_map<name> m_first_proj; /* Map from variable `x` to the first projection `y := x.i` in the context */
|
||||
name_set m_updated; /* Set of variables `x` such that there is a `S.mk ... x.i ... */
|
||||
name m_fld{"_d"};
|
||||
unsigned m_next_idx{1};
|
||||
|
||||
|
|
@ -29,6 +32,20 @@ class struct_cases_on_fn {
|
|||
return r;
|
||||
}
|
||||
|
||||
expr find(expr const & e) const {
|
||||
if (is_fvar(e)) {
|
||||
if (optional<local_decl> decl = m_lctx.find_local_decl(e)) {
|
||||
if (optional<expr> v = decl->get_value()) {
|
||||
if (!is_join_point_name(decl->get_user_name()))
|
||||
return find(*v);
|
||||
}
|
||||
}
|
||||
} else if (is_mdata(e)) {
|
||||
return find(mdata_expr(e));
|
||||
}
|
||||
return e;
|
||||
}
|
||||
|
||||
expr visit_cases(expr const & e) {
|
||||
flet<name_set> save(m_scrutinies, m_scrutinies);
|
||||
buffer<expr> args;
|
||||
|
|
@ -45,6 +62,18 @@ class struct_cases_on_fn {
|
|||
expr visit_app(expr const & e) {
|
||||
if (is_cases_on_app(env(), e)) {
|
||||
return visit_cases(e);
|
||||
} else if (is_constructor_app(env(), e)) {
|
||||
buffer<expr> args;
|
||||
expr const & k = get_app_args(e, args);
|
||||
lean_assert(is_constant(k));
|
||||
constructor_val k_val = env().get(const_name(k)).to_constructor_val();
|
||||
for (unsigned i = k_val.get_nparams(), idx = 0; i < args.size(); i++, idx++) {
|
||||
expr arg = find(args[i]);
|
||||
if (is_proj(arg) && proj_idx(arg) == idx && is_fvar(proj_expr(arg))) {
|
||||
m_updated.insert(fvar_name(proj_expr(arg)));
|
||||
}
|
||||
}
|
||||
return e;
|
||||
} else {
|
||||
return e;
|
||||
}
|
||||
|
|
@ -63,50 +92,16 @@ class struct_cases_on_fn {
|
|||
return m_lctx.mk_lambda(fvars, e);
|
||||
}
|
||||
|
||||
bool is_candidate(expr const & rhs) {
|
||||
if (!is_proj(rhs)) return false;
|
||||
/* Return `some s` if `rhs` is of the form `s.i`, and `s` is a free variables that has not been
|
||||
scrutinized yet, and `s.i` is the first time it is being projected. */
|
||||
optional<name> is_candidate(expr const & rhs) {
|
||||
if (!is_proj(rhs)) return optional<name>();
|
||||
expr const & s = proj_expr(rhs);
|
||||
if (!is_fvar(s)) return false;
|
||||
if (m_scrutinies.contains(fvar_name(s))) return false;
|
||||
return true;
|
||||
}
|
||||
|
||||
/* Return true iff `e` is a constructor application of inductive type `S_name` and containg `x`. */
|
||||
bool is_ctor_of(expr const & e, name const & S_name, expr const & x) {
|
||||
if (!is_constructor_app(env(), e)) return false;
|
||||
buffer<expr> args;
|
||||
expr const & k = get_app_args(e, args);
|
||||
lean_assert(is_constant(k));
|
||||
constructor_val k_val = env().get(const_name(k)).to_constructor_val();
|
||||
if (k_val.get_induct() != S_name) return false;
|
||||
for (unsigned i = k_val.get_nparams(); i < args.size(); i++) {
|
||||
if (args[i] == x)
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/* Return true iff `e` contains a constructor application of inductive type `S_name` and containg `x`. */
|
||||
bool has_ctor_with(expr e, name const & S_name, expr const & x) {
|
||||
while (is_lambda(e)) {
|
||||
e = binding_body(e);
|
||||
}
|
||||
while (is_let(e)) {
|
||||
if (is_ctor_of(let_value(e), S_name, x))
|
||||
return true;
|
||||
e = let_body(e);
|
||||
}
|
||||
if (is_cases_on_app(env(), e)) {
|
||||
buffer<expr> args;
|
||||
get_app_args(e, args);
|
||||
for (unsigned i = 1; i < args.size(); i++) {
|
||||
if (has_ctor_with(args[i], S_name, x))
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
} else {
|
||||
return is_ctor_of(e, S_name, x);
|
||||
}
|
||||
if (!is_fvar(s)) return optional<name>();
|
||||
name const & s_name = fvar_name(s);
|
||||
if (m_scrutinies.contains(s_name)) return optional<name>();
|
||||
if (m_first_proj.contains(s_name)) return optional<name>();
|
||||
return optional<name>(s_name);
|
||||
}
|
||||
|
||||
static void get_struct_field_types(type_checker::state & st, name const & S_name, buffer<expr> & result) {
|
||||
|
|
@ -144,7 +139,17 @@ class struct_cases_on_fn {
|
|||
}
|
||||
}
|
||||
|
||||
bool should_add_cases_on(local_decl const & decl) {
|
||||
expr val = *decl.get_value();
|
||||
if (!is_proj(val)) return false;
|
||||
expr const & s = proj_expr(val);
|
||||
if (!is_fvar(s) || !m_updated.contains(fvar_name(s))) return false;
|
||||
name const * x = m_first_proj.find(fvar_name(s));
|
||||
return x && *x == decl.get_name();
|
||||
}
|
||||
|
||||
expr visit_let(expr e) {
|
||||
flet<name_map<name>> save(m_first_proj, m_first_proj);
|
||||
buffer<expr> fvars;
|
||||
while (is_let(e)) {
|
||||
lean_assert(!has_loose_bvars(let_type(e)));
|
||||
|
|
@ -154,31 +159,37 @@ class struct_cases_on_fn {
|
|||
e = let_body(e);
|
||||
expr new_fvar = m_lctx.mk_local_decl(ngen(), n, type, val);
|
||||
fvars.push_back(new_fvar);
|
||||
if (is_candidate(val)) {
|
||||
lean_assert(is_proj(val));
|
||||
lean_assert(proj_idx(val).is_small());
|
||||
name const & S_name = proj_sname(val);
|
||||
e = instantiate_rev(e, fvars.size(), fvars.data());
|
||||
if (has_ctor_with(e, S_name, new_fvar)) {
|
||||
/* Introduce a casesOn application. */
|
||||
e = m_lctx.mk_lambda(new_fvar, e);
|
||||
fvars.pop_back();
|
||||
expr major = proj_expr(val);
|
||||
buffer<expr> field_types;
|
||||
get_struct_field_types(m_st, S_name, field_types);
|
||||
unsigned i = field_types.size();
|
||||
while (i > 0) {
|
||||
--i;
|
||||
e = mk_lambda(next_field_name(), field_types[i], e);
|
||||
}
|
||||
e = mk_app(mk_constant(name(S_name, g_cases_on)), major, e);
|
||||
}
|
||||
e = visit(e);
|
||||
return m_lctx.mk_lambda(fvars, e);
|
||||
if (optional<name> s = is_candidate(val)) {
|
||||
m_first_proj.insert(*s, fvar_name(new_fvar));
|
||||
}
|
||||
}
|
||||
e = visit(instantiate_rev(e, fvars.size(), fvars.data()));
|
||||
return m_lctx.mk_lambda(fvars, e);
|
||||
e = abstract(e, fvars.size(), fvars.data());
|
||||
unsigned i = fvars.size();
|
||||
while (i > 0) {
|
||||
--i;
|
||||
expr const & x = fvars[i];
|
||||
lean_assert(is_fvar(x));
|
||||
local_decl decl = m_lctx.get_local_decl(x);
|
||||
expr type = decl.get_type();
|
||||
expr val = *decl.get_value();
|
||||
expr aval = abstract(val, i, fvars.data());
|
||||
e = mk_let(decl.get_user_name(), type, aval, e);
|
||||
if (should_add_cases_on(decl)) {
|
||||
lean_assert(is_proj(val));
|
||||
expr major = proj_expr(val);
|
||||
buffer<expr> field_types;
|
||||
get_struct_field_types(m_st, proj_sname(val), field_types);
|
||||
e = lift_loose_bvars(e, field_types.size());
|
||||
unsigned i = field_types.size();
|
||||
while (i > 0) {
|
||||
--i;
|
||||
e = mk_lambda(next_field_name(), field_types[i], e);
|
||||
}
|
||||
e = mk_app(mk_constant(name(proj_sname(val), g_cases_on)), major, e);
|
||||
}
|
||||
}
|
||||
return e;
|
||||
}
|
||||
|
||||
expr visit(expr const & e) {
|
||||
|
|
|
|||
|
|
@ -1,21 +1,18 @@
|
|||
structure S :=
|
||||
(vals : Array Nat) (sz : Nat)
|
||||
structure S (α : Type) :=
|
||||
(vals : Array α) (sz : Nat)
|
||||
|
||||
@[noinline] def inc0 (a : Array Nat) : Array Nat :=
|
||||
a.modify 0 (+1)
|
||||
|
||||
set_option pp.implicit true
|
||||
-- set_option trace.compiler.boxed true
|
||||
|
||||
def f1 (s : S) : S :=
|
||||
def f1 (s : S Nat) : S Nat :=
|
||||
{ vals := inc0 s.vals, .. s}
|
||||
|
||||
def f2 : S → S
|
||||
def f2 : S Nat → S Nat
|
||||
| ⟨vals, sz⟩ := ⟨inc0 vals, sz⟩
|
||||
|
||||
def test (f : S → S) (n : Nat): IO Unit :=
|
||||
let s : S := { vals := mkArray (n*100) n, sz := n*100 } in
|
||||
let s := n.repeat f s in
|
||||
def test (f : S Nat → S Nat) (n : Nat): IO Unit :=
|
||||
let s : S Nat := { vals := mkArray (n*100) n, sz := n*100 } in
|
||||
let s := n.repeat f s in
|
||||
IO.println (s.vals.get 0)
|
||||
|
||||
def main (xs : List String) : IO Unit :=
|
||||
|
|
|
|||
|
|
@ -64,27 +64,22 @@ d.errorMsg != none
|
|||
d.stxStack.size
|
||||
|
||||
def ParserData.restore (d : ParserData) (iniStackSz : Nat) (iniPos : Nat) : ParserData :=
|
||||
match d with
|
||||
| ⟨stack, _, cache, _⟩ := ⟨stack.shrink iniStackSz, iniPos, cache, none⟩
|
||||
{ stxStack := d.stxStack.shrink iniStackSz, errorMsg := none, pos := iniPos, .. d}
|
||||
|
||||
def ParserData.setPos (d : ParserData) (pos : Nat) : ParserData :=
|
||||
match d with
|
||||
| ⟨stack, _, cache, msg⟩ := ⟨stack, pos, cache, msg⟩
|
||||
{ pos := pos, .. d }
|
||||
|
||||
def ParserData.setCache (d : ParserData) (cache : ParserCache) : ParserData :=
|
||||
match d with
|
||||
| ⟨stack, pos, _, msg⟩ := ⟨stack, pos, cache, msg⟩
|
||||
{ cache := cache, .. d }
|
||||
|
||||
def ParserData.pushSyntax (d : ParserData) (n : Syntax) : ParserData :=
|
||||
match d with
|
||||
| ⟨stack, pos, cache, msg⟩ := ⟨stack.push n, pos, cache, msg⟩
|
||||
{ stxStack := d.stxStack.push n, .. d }
|
||||
|
||||
def ParserData.shrinkStack (d : ParserData) (iniStackSz : Nat) : ParserData :=
|
||||
match d with
|
||||
| ⟨stack, pos, cache, msg⟩ := ⟨stack.shrink iniStackSz, pos, cache, msg⟩
|
||||
{ stxStack := d.stxStack.shrink iniStackSz, .. d }
|
||||
|
||||
def ParserData.next (d : ParserData) (s : String) (pos : Nat) : ParserData :=
|
||||
d.setPos (s.next pos)
|
||||
{ pos := s.next pos, .. d }
|
||||
|
||||
def ParserData.toErrorMsg (d : ParserData) (cfg : ParserConfig) : String :=
|
||||
match d.errorMsg with
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue