diff --git a/src/library/compiler/extract_closed.cpp b/src/library/compiler/extract_closed.cpp index ed3e10facc..9621438c7b 100644 --- a/src/library/compiler/extract_closed.cpp +++ b/src/library/compiler/extract_closed.cpp @@ -25,6 +25,7 @@ bool is_extract_closed_aux_fn(name const & n) { class extract_closed_fn { environment m_env; + comp_decls m_input_decls; name_generator m_ngen; local_ctx m_lctx; buffer m_new_decls; @@ -54,6 +55,13 @@ class extract_closed_fn { return e; } + bool in_current_mutual_block(name const & decl_name) { + for (auto d : m_input_decls) + if (d.fst() == decl_name) + return true; + return false; + } + bool is_closed(expr e) { switch (e.kind()) { case expr_kind::MVar: lean_unreachable(); @@ -61,7 +69,7 @@ class extract_closed_fn { case expr_kind::Sort: lean_unreachable(); case expr_kind::Lit: return true; case expr_kind::BVar: return true; - case expr_kind::Const: return true; + case expr_kind::Const: return !in_current_mutual_block(const_name(e)); case expr_kind::MData: return is_closed(mdata_expr(e)); case expr_kind::Proj: return is_closed(proj_expr(e)); default: @@ -277,8 +285,8 @@ class extract_closed_fn { } public: - extract_closed_fn(environment const & env): - m_env(env) { + extract_closed_fn(environment const & env, comp_decls const & ds): + m_env(env), m_input_decls(ds) { } pair operator()(comp_decl const & d) { @@ -299,15 +307,15 @@ public: } }; -pair extract_closed_core(environment const & env, comp_decl const & d) { - return extract_closed_fn(env)(d); +pair extract_closed_core(environment const & env, comp_decls const & input_ds, comp_decl const & d) { + return extract_closed_fn(env, input_ds)(d); } pair extract_closed(environment env, comp_decls const & ds) { comp_decls r; for (comp_decl const & d : ds) { comp_decls new_ds; - std::tie(env, new_ds) = extract_closed_core(env, d); + std::tie(env, new_ds) = extract_closed_core(env, ds, d); r = append(r, new_ds); } return mk_pair(env, r); diff --git a/tests/compiler/extractClosedMutualBlock.lean b/tests/compiler/extractClosedMutualBlock.lean new file mode 100644 index 0000000000..1af9b8b333 --- /dev/null +++ b/tests/compiler/extractClosedMutualBlock.lean @@ -0,0 +1,20 @@ +inductive Foo +| mk: (Int -> Foo) -> Foo +| terminal: Foo +deriving Inhabited + +mutual + partial def even (_: Unit) : Foo := + Foo.mk (fun i => odd () ) + partial def odd (_: Unit) : Foo := + Foo.mk (fun i => even ()) +end + +def hasLayer (f: Foo) : Bool := + match f with + | Foo.mk _ => true + | Foo.terminal => false + +def main : IO Unit := do + IO.println (if hasLayer (odd ()) then "LAYER" else "TERMINAL") + return () diff --git a/tests/compiler/extractClosedMutualBlock.lean.expected.out b/tests/compiler/extractClosedMutualBlock.lean.expected.out new file mode 100644 index 0000000000..c60a1a7572 --- /dev/null +++ b/tests/compiler/extractClosedMutualBlock.lean.expected.out @@ -0,0 +1 @@ +LAYER