diff --git a/src/library/compiler/emit_cpp.cpp b/src/library/compiler/emit_cpp.cpp index ccf69b18c8..d245a2f8dd 100644 --- a/src/library/compiler/emit_cpp.cpp +++ b/src/library/compiler/emit_cpp.cpp @@ -1083,7 +1083,10 @@ static void emit_initialize(std::ostream & out, environment const & env, module_ for (comp_decl const & d : ds) { name const & n = d.fst(); expr const & code = d.snd(); - if (!is_lambda(code)) { + if (is_io_unit_init_fn(env, n)) { + out << "w = " << to_cpp_name(env, n) << "(w);\n"; + out << "if (io_result_is_error(w)) return w;\n"; + } else if (!is_lambda(code)) { if (optional init_fn = get_init_fn_name_for(env, d.fst())) { out << "w = " << to_cpp_name(env, *init_fn) << "(w);\n"; out << "if (io_result_is_error(w)) return w;\n"; diff --git a/src/library/compiler/init_attribute.cpp b/src/library/compiler/init_attribute.cpp index 1f786d3c62..8ba975fb64 100644 --- a/src/library/compiler/init_attribute.cpp +++ b/src/library/compiler/init_attribute.cpp @@ -18,6 +18,8 @@ struct init_attr_data : public attr_data { virtual unsigned hash() const override { return m_init_fn.hash(); } virtual void parse(expr const & e) override { buffer args; get_app_args(e, args); + if (args.size() == 0) + return; if (args.size() != 1 || !is_const(extract_mdata(args[0]))) throw parser_error("constant expected", get_pos_info_provider()->get_pos_info_or_some(e)); m_init_fn = const_name(extract_mdata(args[0])); @@ -33,9 +35,19 @@ static init_attr const & get_init_attr() { return static_cast(get_system_attribute("init")); } +bool is_io_unit_init_fn(environment const & env, name const & n) { + if (auto const & data = get_init_attr().get(env, n)) + return data->m_init_fn.is_anonymous(); + else + return false; +} + optional get_init_fn_name_for(environment const & env, name const & n) { if (auto const & data = get_init_attr().get(env, n)) { - return optional(data->m_init_fn); + if (data->m_init_fn.is_anonymous()) + return optional(); + else + return optional(data->m_init_fn); } else { return optional(); } @@ -53,19 +65,28 @@ void initialize_init_attribute() { if (!persistent) throw exception("invalid [init] attribute, it must be persistent"); auto const & data = *get_init_attr().get(env, n); name init_fn = data.m_init_fn; - optional init_fn_info = env.find(init_fn); - if (!init_fn_info) throw exception(sstream() << "invalid [init] attribute, initialization function '" << init_fn << "' not found"); - constant_info n_info = env.get(n); - if (!n_info.is_opaque()) throw exception(sstream() << "invalid [init] attribute, '" << n << "' must be a constant"); - expr type = n_info.get_type(); - expr init_fn_type = init_fn_info->get_type(); - optional io_arg_type = get_io_type_arg(init_fn_type); - if (!io_arg_type) throw exception(sstream() << "invalid [init] attribute, initialization function '" << init_fn << "' must have type of the form 'io '"); - if (type != *io_arg_type) throw exception(sstream() << "invalid [init] attribute, initialization function '" << init_fn << "' must have type of the form 'io ' " - << "where '' is the type of '" << n << "'"); - /* During code generation, we check whether constants tagged with the `[init]` attribute have arity 0. - We cannot perform this check here because attributes are registered before code generation. */ - return env; + if (init_fn.is_anonymous()) { + expr type = env.get(n).get_type(); + optional io_arg_type = get_io_type_arg(type); + if (!io_arg_type || !is_constant(*io_arg_type, get_unit_name())) + throw exception(sstream() << "invalid [init] attribute, initialization function '" << + n << "' must have type of the form 'IO Unit'"); + return env; + } else { + optional init_fn_info = env.find(init_fn); + if (!init_fn_info) throw exception(sstream() << "invalid [init] attribute, initialization function '" << init_fn << "' not found"); + constant_info n_info = env.get(n); + if (!n_info.is_opaque()) throw exception(sstream() << "invalid [init] attribute, '" << n << "' must be a constant"); + expr type = n_info.get_type(); + expr init_fn_type = init_fn_info->get_type(); + optional io_arg_type = get_io_type_arg(init_fn_type); + if (!io_arg_type) throw exception(sstream() << "invalid [init] attribute, initialization function '" << init_fn << "' must have type of the form 'IO '"); + if (type != *io_arg_type) throw exception(sstream() << "invalid [init] attribute, initialization function '" << init_fn << "' must have type of the form 'IO ' " + << "where '' is the type of '" << n << "'"); + /* During code generation, we check whether constants tagged with the `[init]` attribute have arity 0. + We cannot perform this check here because attributes are registered before code generation. */ + return env; + } })); } diff --git a/src/library/compiler/init_attribute.h b/src/library/compiler/init_attribute.h index 0941d9c7b3..2c4b4cab22 100644 --- a/src/library/compiler/init_attribute.h +++ b/src/library/compiler/init_attribute.h @@ -8,6 +8,7 @@ Authors: Leonardo de Moura #include "kernel/environment.h" namespace lean { +bool is_io_unit_init_fn(environment const & env, name const & n); optional get_init_fn_name_for(environment const & env, name const & n); inline bool has_init_attribute(environment const & env, name const & n) { return static_cast(get_init_fn_name_for(env, n)); diff --git a/tests/playground/opts.lean b/tests/playground/opts.lean index 74ef19b7a8..0c449a9f3d 100644 --- a/tests/playground/opts.lean +++ b/tests/playground/opts.lean @@ -2,25 +2,17 @@ import init.lean.options open Lean -def initRegopt1 : IO Unit := +@[init] def initRegopt1 : IO Unit := registerOption `myNatOpt {defValue := DataValue.ofNat 0, descr := "my Nat option" } -@[init initRegopt1] -constant regopt1 : Unit := default _ -def initRegopt2 : IO Unit := +@[init] def initRegopt2 : IO Unit := registerOption `myBoolOpt {defValue := DataValue.ofBool true, descr := "my Bool option" } -@[init initRegopt2] -constant regopt2 : Unit := default _ -def initRegopt3 : IO Unit := +@[init] def initRegopt3 : IO Unit := registerOption `myStringOpt {defValue := DataValue.ofString "", descr := "my String option" } -@[init initRegopt3] -constant regopt3 : Unit := default _ -def initRegopt4 : IO Unit := +@[init] def initRegopt4 : IO Unit := registerOption `myIntOpt {defValue := DataValue.ofInt 0, descr := "my Int option" } -@[init initRegopt4] -constant regopt4 : Unit := default _ def main : IO Unit :=