diff --git a/src/library/compiler/struct_cases_on.cpp b/src/library/compiler/struct_cases_on.cpp index fa8d322660..790faf7a9c 100644 --- a/src/library/compiler/struct_cases_on.cpp +++ b/src/library/compiler/struct_cases_on.cpp @@ -6,6 +6,7 @@ Author: Leonardo de Moura */ #include "runtime/flet.h" #include "kernel/instantiate.h" +#include "kernel/abstract.h" #include "kernel/type_checker.h" #include "library/trace.h" #include "library/suffixes.h" @@ -15,7 +16,9 @@ namespace lean { class struct_cases_on_fn { type_checker::state m_st; local_ctx m_lctx; - name_set m_scrutinies; + name_set m_scrutinies; /* Set of variables `x` such that there is `casesOn x ...` in the context */ + name_map m_first_proj; /* Map from variable `x` to the first projection `y := x.i` in the context */ + name_set m_updated; /* Set of variables `x` such that there is a `S.mk ... x.i ... */ name m_fld{"_d"}; unsigned m_next_idx{1}; @@ -29,6 +32,20 @@ class struct_cases_on_fn { return r; } + expr find(expr const & e) const { + if (is_fvar(e)) { + if (optional decl = m_lctx.find_local_decl(e)) { + if (optional v = decl->get_value()) { + if (!is_join_point_name(decl->get_user_name())) + return find(*v); + } + } + } else if (is_mdata(e)) { + return find(mdata_expr(e)); + } + return e; + } + expr visit_cases(expr const & e) { flet save(m_scrutinies, m_scrutinies); buffer args; @@ -45,6 +62,18 @@ class struct_cases_on_fn { expr visit_app(expr const & e) { if (is_cases_on_app(env(), e)) { return visit_cases(e); + } else if (is_constructor_app(env(), e)) { + buffer args; + expr const & k = get_app_args(e, args); + lean_assert(is_constant(k)); + constructor_val k_val = env().get(const_name(k)).to_constructor_val(); + for (unsigned i = k_val.get_nparams(), idx = 0; i < args.size(); i++, idx++) { + expr arg = find(args[i]); + if (is_proj(arg) && proj_idx(arg) == idx && is_fvar(proj_expr(arg))) { + m_updated.insert(fvar_name(proj_expr(arg))); + } + } + return e; } else { return e; } @@ -63,50 +92,16 @@ class struct_cases_on_fn { return m_lctx.mk_lambda(fvars, e); } - bool is_candidate(expr const & rhs) { - if (!is_proj(rhs)) return false; + /* Return `some s` if `rhs` is of the form `s.i`, and `s` is a free variables that has not been + scrutinized yet, and `s.i` is the first time it is being projected. */ + optional is_candidate(expr const & rhs) { + if (!is_proj(rhs)) return optional(); expr const & s = proj_expr(rhs); - if (!is_fvar(s)) return false; - if (m_scrutinies.contains(fvar_name(s))) return false; - return true; - } - - /* Return true iff `e` is a constructor application of inductive type `S_name` and containg `x`. */ - bool is_ctor_of(expr const & e, name const & S_name, expr const & x) { - if (!is_constructor_app(env(), e)) return false; - buffer args; - expr const & k = get_app_args(e, args); - lean_assert(is_constant(k)); - constructor_val k_val = env().get(const_name(k)).to_constructor_val(); - if (k_val.get_induct() != S_name) return false; - for (unsigned i = k_val.get_nparams(); i < args.size(); i++) { - if (args[i] == x) - return true; - } - return false; - } - - /* Return true iff `e` contains a constructor application of inductive type `S_name` and containg `x`. */ - bool has_ctor_with(expr e, name const & S_name, expr const & x) { - while (is_lambda(e)) { - e = binding_body(e); - } - while (is_let(e)) { - if (is_ctor_of(let_value(e), S_name, x)) - return true; - e = let_body(e); - } - if (is_cases_on_app(env(), e)) { - buffer args; - get_app_args(e, args); - for (unsigned i = 1; i < args.size(); i++) { - if (has_ctor_with(args[i], S_name, x)) - return true; - } - return false; - } else { - return is_ctor_of(e, S_name, x); - } + if (!is_fvar(s)) return optional(); + name const & s_name = fvar_name(s); + if (m_scrutinies.contains(s_name)) return optional(); + if (m_first_proj.contains(s_name)) return optional(); + return optional(s_name); } static void get_struct_field_types(type_checker::state & st, name const & S_name, buffer & result) { @@ -144,7 +139,17 @@ class struct_cases_on_fn { } } + bool should_add_cases_on(local_decl const & decl) { + expr val = *decl.get_value(); + if (!is_proj(val)) return false; + expr const & s = proj_expr(val); + if (!is_fvar(s) || !m_updated.contains(fvar_name(s))) return false; + name const * x = m_first_proj.find(fvar_name(s)); + return x && *x == decl.get_name(); + } + expr visit_let(expr e) { + flet> save(m_first_proj, m_first_proj); buffer fvars; while (is_let(e)) { lean_assert(!has_loose_bvars(let_type(e))); @@ -154,31 +159,37 @@ class struct_cases_on_fn { e = let_body(e); expr new_fvar = m_lctx.mk_local_decl(ngen(), n, type, val); fvars.push_back(new_fvar); - if (is_candidate(val)) { - lean_assert(is_proj(val)); - lean_assert(proj_idx(val).is_small()); - name const & S_name = proj_sname(val); - e = instantiate_rev(e, fvars.size(), fvars.data()); - if (has_ctor_with(e, S_name, new_fvar)) { - /* Introduce a casesOn application. */ - e = m_lctx.mk_lambda(new_fvar, e); - fvars.pop_back(); - expr major = proj_expr(val); - buffer field_types; - get_struct_field_types(m_st, S_name, field_types); - unsigned i = field_types.size(); - while (i > 0) { - --i; - e = mk_lambda(next_field_name(), field_types[i], e); - } - e = mk_app(mk_constant(name(S_name, g_cases_on)), major, e); - } - e = visit(e); - return m_lctx.mk_lambda(fvars, e); + if (optional s = is_candidate(val)) { + m_first_proj.insert(*s, fvar_name(new_fvar)); } } e = visit(instantiate_rev(e, fvars.size(), fvars.data())); - return m_lctx.mk_lambda(fvars, e); + e = abstract(e, fvars.size(), fvars.data()); + unsigned i = fvars.size(); + while (i > 0) { + --i; + expr const & x = fvars[i]; + lean_assert(is_fvar(x)); + local_decl decl = m_lctx.get_local_decl(x); + expr type = decl.get_type(); + expr val = *decl.get_value(); + expr aval = abstract(val, i, fvars.data()); + e = mk_let(decl.get_user_name(), type, aval, e); + if (should_add_cases_on(decl)) { + lean_assert(is_proj(val)); + expr major = proj_expr(val); + buffer field_types; + get_struct_field_types(m_st, proj_sname(val), field_types); + e = lift_loose_bvars(e, field_types.size()); + unsigned i = field_types.size(); + while (i > 0) { + --i; + e = mk_lambda(next_field_name(), field_types[i], e); + } + e = mk_app(mk_constant(name(proj_sname(val), g_cases_on)), major, e); + } + } + return e; } expr visit(expr const & e) { diff --git a/tests/playground/badupdate1.lean b/tests/playground/badupdate1.lean index 794a59363c..29dddac232 100644 --- a/tests/playground/badupdate1.lean +++ b/tests/playground/badupdate1.lean @@ -1,21 +1,18 @@ -structure S := -(vals : Array Nat) (sz : Nat) +structure S (α : Type) := +(vals : Array α) (sz : Nat) @[noinline] def inc0 (a : Array Nat) : Array Nat := a.modify 0 (+1) -set_option pp.implicit true --- set_option trace.compiler.boxed true - -def f1 (s : S) : S := +def f1 (s : S Nat) : S Nat := { vals := inc0 s.vals, .. s} -def f2 : S → S +def f2 : S Nat → S Nat | ⟨vals, sz⟩ := ⟨inc0 vals, sz⟩ -def test (f : S → S) (n : Nat): IO Unit := -let s : S := { vals := mkArray (n*100) n, sz := n*100 } in -let s := n.repeat f s in +def test (f : S Nat → S Nat) (n : Nat): IO Unit := +let s : S Nat := { vals := mkArray (n*100) n, sz := n*100 } in +let s := n.repeat f s in IO.println (s.vals.get 0) def main (xs : List String) : IO Unit := diff --git a/tests/playground/parser/parser.lean b/tests/playground/parser/parser.lean index 0ffc373b60..213c5137a3 100644 --- a/tests/playground/parser/parser.lean +++ b/tests/playground/parser/parser.lean @@ -64,27 +64,22 @@ d.errorMsg != none d.stxStack.size def ParserData.restore (d : ParserData) (iniStackSz : Nat) (iniPos : Nat) : ParserData := -match d with -| ⟨stack, _, cache, _⟩ := ⟨stack.shrink iniStackSz, iniPos, cache, none⟩ +{ stxStack := d.stxStack.shrink iniStackSz, errorMsg := none, pos := iniPos, .. d} def ParserData.setPos (d : ParserData) (pos : Nat) : ParserData := -match d with -| ⟨stack, _, cache, msg⟩ := ⟨stack, pos, cache, msg⟩ +{ pos := pos, .. d } def ParserData.setCache (d : ParserData) (cache : ParserCache) : ParserData := -match d with -| ⟨stack, pos, _, msg⟩ := ⟨stack, pos, cache, msg⟩ +{ cache := cache, .. d } def ParserData.pushSyntax (d : ParserData) (n : Syntax) : ParserData := -match d with -| ⟨stack, pos, cache, msg⟩ := ⟨stack.push n, pos, cache, msg⟩ +{ stxStack := d.stxStack.push n, .. d } def ParserData.shrinkStack (d : ParserData) (iniStackSz : Nat) : ParserData := -match d with -| ⟨stack, pos, cache, msg⟩ := ⟨stack.shrink iniStackSz, pos, cache, msg⟩ +{ stxStack := d.stxStack.shrink iniStackSz, .. d } def ParserData.next (d : ParserData) (s : String) (pos : Nat) : ParserData := -d.setPos (s.next pos) +{ pos := s.next pos, .. d } def ParserData.toErrorMsg (d : ParserData) (cfg : ParserConfig) : String := match d.errorMsg with