feat: allow attributes to be applied before elaboration

This is useful when the attribute may influence the elaboration of the declaration.
This commit is contained in:
Leonardo de Moura 2019-11-13 15:24:30 -08:00
parent 92316dff89
commit c4d974eb89
4 changed files with 41 additions and 139 deletions

View file

@ -10,7 +10,7 @@ import Init.Lean.Syntax
namespace Lean
inductive AttributeApplicationTime
| afterTypeChecking | afterCompilation
| afterTypeChecking | afterCompilation | beforeElaboration
structure AttributeImpl :=
(name : Name)

View file

@ -41,11 +41,15 @@ environment erase_attribute(environment const & env, name const & decl, name con
}
/*
inductive AttributeApplicationTime
| afterTypeChecking | afterCompilation
| afterTypeChecking | afterCompilation | beforeElaboration
*/
bool is_after_compilation_attribute(name const & n) {
return get_io_scalar_result<uint8>(lean_attribute_application_time(n.to_obj_arg(), io_mk_world())) == 1;
}
bool is_before_elaboration_attribute(name const & n) {
return get_io_scalar_result<uint8>(lean_attribute_application_time(n.to_obj_arg(), io_mk_world())) == 2;
}
// ==========================================
expr decl_attributes::parse_attr_arg(parser & p, name const & attr_id) {
@ -139,6 +143,8 @@ void decl_attributes::parse_core(parser & p, bool compact) {
}
if (is_after_compilation_attribute(id)) {
m_after_comp_entries = cons({id, deleted, scoped, args}, m_after_comp_entries);
} else if (is_before_elaboration_attribute(id)) {
m_before_elab_entries = cons({id, deleted, scoped, args}, m_before_elab_entries);
} else {
m_after_tc_entries = cons({id, deleted, scoped, args}, m_after_tc_entries);
}
@ -176,6 +182,8 @@ void decl_attributes::set_attribute(environment const & /* env */, name const &
syntax args(box(0));
if (is_after_compilation_attribute(attr_name)) {
m_after_comp_entries = cons({attr_name, false, false, args}, m_after_comp_entries);
} else if (is_before_elaboration_attribute(attr_name)) {
m_before_elab_entries = cons({attr_name, false, false, args}, m_before_elab_entries);
} else {
m_after_tc_entries = cons({attr_name, false, false, args}, m_after_tc_entries);
}
@ -194,7 +202,7 @@ bool decl_attributes::has_attribute(list<new_entry> const & entries, name const
bool decl_attributes::has_attribute(environment const & /* env */, name const & attr_name) const {
if (is_new_attribute(attr_name)) {
return has_attribute(m_after_tc_entries, attr_name) || has_attribute(m_after_comp_entries, attr_name);
return has_attribute(m_after_tc_entries, attr_name) || has_attribute(m_after_comp_entries, attr_name) || has_attribute(m_before_elab_entries, attr_name);
} else {
throw exception(sstream() << "unknown attribute [" << attr_name << "]");
}
@ -226,8 +234,13 @@ environment decl_attributes::apply_after_comp(environment env, name const & d) c
return apply_new_entries(env, m_after_comp_entries, d);
}
environment decl_attributes::apply_before_elab(environment env, io_state const & /* ios */, name const & d) const {
return apply_new_entries(env, m_before_elab_entries, d);
}
environment decl_attributes::apply_all(environment env, io_state const & ios, name const & d) const {
environment new_env = apply_after_tc(env, ios, d);
environment new_env = apply_before_elab(env, ios, d);
new_env = apply_after_tc(env, ios, d);
return apply_after_comp(new_env, d);
}

View file

@ -23,6 +23,7 @@ public:
};
private:
bool m_persistent;
list<new_entry> m_before_elab_entries;
list<new_entry> m_after_tc_entries;
list<new_entry> m_after_comp_entries;
void parse_core(parser & p, bool compact);
@ -38,6 +39,7 @@ public:
/* Parse attributes after `@[` ... ] */
void parse_compact(parser & p);
void set_attribute(environment const & env, name const & attr_name);
environment apply_before_elab(environment env, io_state const & ios, name const & d) const;
environment apply_after_tc(environment env, io_state const & ios, name const & d) const;
environment apply_after_comp(environment env, name const & d) const;
environment apply_all(environment env, io_state const & ios, name const & d) const;

View file

@ -115,46 +115,6 @@ void check_valid_end_of_equations(parser & p) {
}
}
static expr parse_mutual_definition(parser & p, buffer<name> & lp_names, buffer<expr> & fns, buffer<name> & prv_names,
buffer<expr> & params) {
parser::local_scope scope1(p);
auto header_pos = p.pos();
buffer<expr> pre_fns;
parse_mutual_header(p, lp_names, pre_fns, params);
buffer<expr> eqns;
buffer<name> full_names;
buffer<name> full_actual_names;
for (expr const & pre_fn : pre_fns) {
// TODO(leo, dhs): make use of attributes
expr fn_type = parse_inner_header(p, local_pp_name_p(pre_fn)).first;
declaration_name_scope scope2(local_pp_name_p(pre_fn));
declaration_name_scope scope3("_main");
full_names.push_back(scope3.get_name());
full_actual_names.push_back(scope3.get_actual_name());
prv_names.push_back(scope2.get_actual_name());
if (p.curr_is_token(get_period_tk())) {
auto period_pos = p.pos();
p.next();
eqns.push_back(p.save_pos(mk_no_equation(), period_pos));
} else {
while (p.curr_is_token(get_bar_tk())) {
eqns.push_back(parse_equation(p, pre_fn));
}
check_valid_end_of_equations(p);
}
expr fn = mk_local(local_name_p(pre_fn), local_pp_name_p(pre_fn), fn_type, mk_rec_info());
fns.push_back(fn);
}
if (p.curr_is_token(get_with_tk()))
p.maybe_throw_error({"unexpected 'with' clause", p.pos()});
optional<expr> wf_tacs = parse_using_well_founded(p);
for (expr & eq : eqns) {
eq = replace_locals_preserving_pos_info(eq, pre_fns, fns);
}
expr r = mk_equations(p, fns, full_names, full_actual_names, eqns, wf_tacs, header_pos);
return r;
}
static void finalize_definition(elaborator & elab, buffer<expr> const & params, expr & type,
expr & val, buffer<name> & lp_names) {
type = elab.mk_pi(params, type);
@ -184,16 +144,19 @@ static environment compile_decl(parser & p, environment const & env,
}
}
static pair<environment, name>
static name get_real_name(environment const & env, name const & c_name, name const & prv_name) {
if (is_private(prv_name)) {
return prv_name;
} else {
return get_namespace(env) + c_name;
}
}
environment
declare_definition(environment const & env, decl_cmd_kind kind, buffer<name> const & lp_names,
name const & c_name, name const & prv_name, expr type, optional<expr> val, cmd_meta const & meta) {
name c_real_name;
environment new_env = env;
if (is_private(prv_name)) {
c_real_name = prv_name;
} else {
c_real_name = get_namespace(env) + c_name;
}
name c_real_name = get_real_name(env, c_name, prv_name);
if (env.find(c_real_name)) {
throw exception(sstream() << "invalid definition, a declaration named '" << c_real_name << "' has already been declared");
}
@ -223,81 +186,7 @@ declare_definition(environment const & env, decl_cmd_kind kind, buffer<name> con
if (!meta.m_modifiers.m_is_private) {
new_env = ensure_decl_namespaces(new_env, c_real_name);
}
return mk_pair(new_env, c_real_name);
}
static environment elab_defs_core(parser & p, decl_cmd_kind kind, cmd_meta const & meta, buffer<name> lp_names, buffer<expr> const & fns,
buffer<name> const & prv_names, buffer<expr> params, expr val, pos_info const & header_pos) {
collect_implicit_locals(p, lp_names, params, val);
/* TODO(Leo): allow a different doc string for each function in a mutual definition. */
optional<std::string> doc_string = meta.m_doc_string;
environment env = p.env();
// skip elaboration of definitions during reparsing
if (p.get_break_at_pos())
return p.env();
bool recover_from_errors = true;
elaborator elab(env, p.get_options(), metavar_context(), local_context(), recover_from_errors);
buffer<expr> new_params;
elaborate_params(elab, params, new_params);
val = replace_locals_preserving_pos_info(val, params, new_params);
val = elab.elaborate(val);
if (!is_equations_result(val)) {
/* Failed to elaborate mutual recursion.
TODO(Leo): better error recovery. */
return p.env();
}
unsigned num_defs = get_equations_result_size(val);
lean_assert(num_defs == prv_names.size());
lean_assert(fns.size() == num_defs);
buffer<name> new_d_names;
/* Define functions */
for (unsigned i = 0; i < num_defs; i++) {
expr curr = get_equations_result(val, i);
expr curr_type = head_beta_reduce(elab.infer_type(curr));
finalize_definition(elab, new_params, curr_type, curr, lp_names);
environment env = elab.env();
name c_name = local_name_p(fns[i]);
name c_real_name;
std::tie(env, c_real_name) = declare_definition(env, kind, lp_names, c_name, prv_names[i],
curr_type, some_expr(curr), meta);
new_d_names.push_back(c_real_name);
elab.set_env(env);
}
/* Apply attributes last so that they may access any information on the new decl */
for (auto const & c_real_name : new_d_names) {
elab.set_env(meta.m_attrs.apply_after_tc(elab.env(), p.ios(), c_real_name));
}
/* Compile functions */
if (!meta.m_modifiers.m_is_noncomputable) {
lean_assert(new_d_names.size() == fns.size());
for (unsigned i = 0; i < fns.size(); i++) {
name c_name = local_name_p(fns[i]);
name c_real_name = new_d_names[i];
elab.set_env(compile_decl(p, elab.env(), c_name, c_real_name, header_pos));
}
}
/* Apply attributes last so that they may access any information on the new decl */
for (auto const & c_real_name : new_d_names) {
elab.set_env(meta.m_attrs.apply_after_comp(elab.env(), c_real_name));
}
return elab.env();
}
static environment mutual_definition_cmd_core(parser & p, decl_cmd_kind kind, cmd_meta const & meta) {
declaration_info_scope scope(p, kind, meta);
buffer<name> lp_names;
buffer<expr> fns, params;
/* TODO(Leo): allow a different doc string for each function in a mutual definition. */
optional<std::string> doc_string = meta.m_doc_string;
environment env = p.env();
private_name_scope prv_scope(meta.m_modifiers.m_is_private, env);
p.set_env(env);
buffer<name> prv_names;
auto header_pos = p.pos();
expr val = parse_mutual_definition(p, lp_names, fns, prv_names, params);
return elab_defs(p, kind, meta, lp_names, fns, prv_names, params, val, header_pos);
return new_env;
}
/* Return tuple (fn, val, actual_name) where
@ -524,8 +413,10 @@ static environment elab_single_def(parser & p, decl_cmd_kind const & kind, cmd_m
optional<expr> opt_val;
bool eqns = false;
name c_name = local_name_p(fn);
pair<environment, name> env_n;
environment new_env;
name c_real_name = get_real_name(elab.env(), c_name, prv_name);
if (kind == decl_cmd_kind::Theorem) {
elab.set_env(meta.m_attrs.apply_before_elab(elab.env(), p.ios(), c_real_name));
is_rfl = is_rfl_preexpr(val);
type = elab.elaborate_type(local_type_p(fn));
elab.ensure_no_unassigned_metavars(type);
@ -544,7 +435,7 @@ static environment elab_single_def(parser & p, decl_cmd_kind const & kind, cmd_m
opt_val = elaborate_proof(decl_env, opts, header_pos, new_params_list,
new_fn, val, thm_finfo, is_rfl, type,
mctx, lctx, pos_provider, use_info_manager, file_name);
env_n = declare_definition(elab.env(), kind, lp_names, c_name, prv_name, type, opt_val, meta);
new_env = declare_definition(elab.env(), kind, lp_names, c_name, prv_name, type, opt_val, meta);
} else if (kind == decl_cmd_kind::Example) {
auto env = p.env();
auto opts = p.get_options();
@ -561,6 +452,7 @@ static environment elab_single_def(parser & p, decl_cmd_kind const & kind, cmd_m
pos_provider, use_info_manager, file_name);
return p.env();
} else {
elab.set_env(meta.m_attrs.apply_before_elab(elab.env(), p.ios(), c_real_name));
std::tie(val, type) = elaborate_definition(p, elab, kind, fn, val, header_pos);
eqns = is_equations_result(val);
if (eqns) {
@ -569,11 +461,9 @@ static environment elab_single_def(parser & p, decl_cmd_kind const & kind, cmd_m
val = get_equations_result(val, 0);
}
finalize_definition(elab, new_params, type, val, lp_names);
env_n = declare_definition(elab.env(), kind, lp_names, c_name, prv_name, type, some_expr(val), meta);
new_env = declare_definition(elab.env(), kind, lp_names, c_name, prv_name, type, some_expr(val), meta);
}
time_task _("decl post-processing", p.mk_message(header_pos, INFORMATION), c_name);
environment new_env = env_n.first;
name c_real_name = env_n.second;
if (!meta.m_modifiers.m_is_unsafe && !meta.m_modifiers.m_is_partial &&
(kind == decl_cmd_kind::Definition || kind == decl_cmd_kind::Instance)) {
new_env = mk_smart_unfolding_definition(new_env, p.get_options(), c_real_name);
@ -619,13 +509,10 @@ static environment elab_single_def(parser & p, decl_cmd_kind const & kind, cmd_m
}
}
environment elab_defs(parser & p, decl_cmd_kind kind, cmd_meta const & meta, buffer <name> lp_names, buffer <expr> const & fns,
buffer <name> const & prv_names, buffer <expr> const & params, expr val, pos_info const & header_pos) {
if (meta.m_modifiers.m_is_mutual)
return elab_defs_core(p, kind, meta, lp_names, fns, prv_names,
params, val, header_pos);
else
return elab_single_def(p, kind, meta, lp_names, params, fns[0], val, header_pos, prv_names[0]);
environment elab_defs(parser & /* p */, decl_cmd_kind /* kind */, cmd_meta const & /* meta */, buffer <name> /* lp_names */,
buffer <expr> const & /* fns */, buffer <name> const & /* prv_names */,
buffer <expr> const & /* params */, expr /* val */, pos_info const & /* header_pos */) {
throw exception("lean_elaborator.cpp has been disabled");
}
environment single_definition_cmd_core(parser & p, decl_cmd_kind kind, cmd_meta meta) {
@ -652,7 +539,7 @@ environment single_definition_cmd_core(parser & p, decl_cmd_kind kind, cmd_meta
environment definition_cmd_core(parser & p, decl_cmd_kind kind, cmd_meta const & meta) {
if (meta.m_modifiers.m_is_mutual)
return mutual_definition_cmd_core(p, kind, meta);
throw exception("mutual definitions have been disabled");
else
return single_definition_cmd_core(p, kind, meta);
}