feat(library/type_context): initialize type class resolution

This commit is contained in:
Leonardo de Moura 2016-03-27 13:41:07 -07:00
parent c5122223e1
commit ee27480210
5 changed files with 174 additions and 21 deletions

View file

@ -107,14 +107,18 @@ expr local_context::mk_local_decl(name const & ppn, expr const & type, expr cons
return mk_local_decl(mk_local_decl_name(), ppn, type, some_expr(value), binder_info());
}
optional<local_decl> local_context::get_local_decl(expr const & e) const {
lean_assert(is_local_decl_ref(e));
if (auto r = m_name2local_decl.find(mlocal_name(e)))
optional<local_decl> local_context::get_local_decl(name const & n) const {
if (auto r = m_name2local_decl.find(n))
return optional<local_decl>(*r);
else
return optional<local_decl>();
}
optional<local_decl> local_context::get_local_decl(expr const & e) const {
lean_assert(is_local_decl_ref(e));
return get_local_decl(mlocal_name(e));
}
void local_context::for_each(std::function<void(local_decl const &)> const & fn) const {
m_idx2local_decl.for_each([&](unsigned, local_decl const & d) { fn(d); });
}

View file

@ -87,6 +87,7 @@ public:
/** \brief Return the local declarations for the given reference.
\pre is_local_decl_ref(e) */
optional<local_decl> get_local_decl(expr const & e) const;
optional<local_decl> get_local_decl(name const & n) const;
/** \brief Traverse local declarations based on the order they were created */
void for_each(std::function<void(local_decl const &)> const & fn) const;
optional<local_decl> find_if(std::function<bool(local_decl const &)> const & pred) const; // NOLINT

View file

@ -40,13 +40,8 @@ static name * g_internal_prefix = nullptr;
static name * g_class_instance_max_depth = nullptr;
static name * g_class_trans_instances = nullptr;
unsigned get_class_instance_max_depth(options const & o) {
return o.get_unsigned(*g_class_instance_max_depth, LEAN_DEFAULT_CLASS_INSTANCE_MAX_DEPTH);
}
bool get_class_trans_instances(options const & o) {
return o.get_bool(*g_class_trans_instances, LEAN_DEFAULT_CLASS_TRANS_INSTANCES);
}
unsigned get_class_instance_max_depth(options const & o);
bool get_class_trans_instances(options const & o);
old_type_context::old_type_context(environment const & env, options const & o, bool multiple_instances):
m_env(env),

View file

@ -34,6 +34,14 @@ namespace lean {
static name * g_class_instance_max_depth = nullptr;
static name * g_class_trans_instances = nullptr;
unsigned get_class_instance_max_depth(options const & o) {
return o.get_unsigned(*g_class_instance_max_depth, LEAN_DEFAULT_CLASS_INSTANCE_MAX_DEPTH);
}
bool get_class_trans_instances(options const & o) {
return o.get_bool(*g_class_trans_instances, LEAN_DEFAULT_CLASS_TRANS_INSTANCES);
}
/* =====================
type_context_cache
===================== */
@ -44,7 +52,8 @@ type_context_cache::type_context_cache(environment const & env, options const &
m_proj_info(get_projection_info_map(env)),
m_frozen_mode(false),
m_local_instances_initialized(false) {
m_ci_max_depth = 12; // TODO(Leo): fix
m_ci_max_depth = get_class_instance_max_depth(opts);
m_ci_trans_instances = get_class_trans_instances(opts);
}
bool type_context_cache::is_transparent(transparency_mode m, declaration const & d) {
@ -86,6 +95,47 @@ bool type_context_cache::should_unfold_macro(expr const &) {
return true;
}
static void collect_local_decls(expr const & e, buffer<name> & r, name_set & s) {
for_each(e, [&](expr const & e, unsigned) {
if (is_local_decl_ref(e)) {
name const & n = mlocal_name(e);
if (!s.contains(n)) {
r.push_back(n);
s.insert(n);
}
}
return true;
});
}
local_context type_context_cache::freeze_local_instances(metavar_context & mctx, local_context const & lctx) {
lean_assert(!m_frozen_mode);
type_context ctx(mctx, lctx, *this);
m_instance_cache.clear();
m_local_instances.clear();
buffer<name> to_freeze;
name_set to_freeze_set;
lctx.for_each([&](local_decl const & decl) {
if (auto cls_name = ctx.is_class(decl.get_type())) {
m_local_instances.emplace_back(*cls_name, decl.mk_ref());
to_freeze.push_back(decl.get_name());
to_freeze_set.insert(decl.get_name());
}
});
local_context new_lctx = lctx;
for (unsigned i = 0; i < to_freeze.size(); i++) {
new_lctx.freeze(to_freeze[i]);
/* freeze dependencies */
if (auto decl = lctx.get_local_decl(to_freeze[i])) {
collect_local_decls(decl->get_type(), to_freeze, to_freeze_set);
if (auto v = decl->get_value())
collect_local_decls(*v, to_freeze, to_freeze_set);
}
}
m_frozen_mode = true;
return new_lctx;
}
/* =====================
type_context::tmp_locals
===================== */
@ -99,11 +149,16 @@ type_context::tmp_locals::~tmp_locals() {
===================== */
void type_context::init_core(transparency_mode m) {
m_used_assignment = false;
m_transparency_mode = m;
m_in_is_def_eq = false;
m_is_def_eq_depth = 0;
m_tmp_mode = false;
m_used_assignment = false;
m_transparency_mode = m;
m_in_is_def_eq = false;
m_is_def_eq_depth = 0;
m_tmp_mode = false;
m_cache->m_init_local_context = m_lctx;
if (!m_cache->m_frozen_mode) {
/* default type class resolution mode */
m_cache->m_local_instances_initialized = false;
}
}
type_context::type_context(metavar_context & mctx, local_context const & lctx, type_context_cache & cache,
@ -1752,6 +1807,10 @@ bool type_context::try_unification_hints(expr const & e1, expr const & e2) {
return false;
}
/* -------------
Type classes
------------- */
/** \brief If the constant \c e is a class, return its name */
optional<name> type_context::constant_is_class(expr const & e) {
name const & cls_name = const_name(e);
@ -1832,6 +1891,89 @@ optional<name> type_context::is_class(expr const & type) {
return is_full_class(type);
}
bool type_context::compatible_local_instances(bool frozen_only) {
unsigned i = 0;
bool failed = false;
m_cache->m_init_local_context.for_each([&](local_decl const & decl) {
if (failed) return;
if (frozen_only && !m_cache->m_init_local_context.is_frozen(decl.get_name()))
return;
if (auto cname = is_class(decl.get_type())) {
if (i == m_cache->m_local_instances.size()) {
/* initial local context has more local instances than the ones cached at found m_local_instances */
failed = true;
return;
}
if (decl.get_name() != mlocal_name(m_cache->m_local_instances[i].second)) {
/* local instance in initial local constext is not compatible with the one cached at m_local_instances */
failed = true;
return;
}
i++;
}
});
return i == m_cache->m_local_instances.size();
}
void type_context::set_local_instances() {
m_cache->m_instance_cache.clear();
m_cache->m_local_instances.clear();
m_cache->m_init_local_context.for_each([&](local_decl const & decl) {
if (auto cls_name = is_class(decl.get_type())) {
m_cache->m_local_instances.emplace_back(*cls_name, decl.mk_ref());
}
});
}
void type_context::init_local_instances() {
if (m_cache->m_frozen_mode) {
lean_assert(m_cache->m_local_instances_initialized);
/* Check if the local instances are really compatible.
See comment at type_context_cache. */
lean_cond_assert("type_context", compatible_local_instances(true));
} else if (!m_cache->m_local_instances_initialized) {
/* default type class resolution mode */
bool frozen_only = false;
if (!compatible_local_instances(frozen_only)) {
set_local_instances();
}
m_cache->m_local_instances_initialized = true;
}
lean_assert(m_cache->m_local_instances_initialized);
}
struct instance_synthesizer {
struct stack_entry {
/* We only use transitive instances when we can solve the problem in a single step.
That is, the transitive instance does not have any instance argument, OR
it uses local instances to fill them.
We accomplish that by not considering global instances when solving
transitive instance subproblems. */
expr m_mvar;
unsigned m_depth;
bool m_trans_inst_subproblem;
stack_entry(expr const & m, unsigned d, bool s = false):
m_mvar(m), m_depth(d), m_trans_inst_subproblem(s) {}
};
struct state {
bool m_trans_inst_subproblem;
list<stack_entry> m_stack; // stack of meta-variables that need to be synthesized;
};
struct choice {
list<expr> m_local_instances;
list<name> m_trans_instances;
list<name> m_instances;
state m_state;
};
type_context & m_ctx;
expr m_main_mvar;
state m_state; // active state
std::vector<choice> m_choices;
};
void initialize_type_context() {
register_trace_class("class_instances");

View file

@ -65,13 +65,12 @@ class type_context_cache {
we do nothing. Otherwise, we reset m_local_instances with the new local_instances, and
reset the cache m_local_instances.
When frozen mode is set, we reset m_local_instances_initialized, the instance cache,
and the vector local_instances. Then, whenever a type_context object is created
(and debugging code is enabled) we store a copy of the initial local context.
When frozen mode is set, we reset m_local_instances_initialized.
Then, whenever a type_context object is created we store a copy of the initial local context.
Then, whenever type class resolution is invoked and m_local_instances_initialized is false,
we copy the set of frozen local_decls instances to m_local_instances.
If m_local_instances_initialized is true, and we are in debug mode, then
we check if the froze local_decls instances in the initial local context are indeed
we check if the frozen local_decls instances in the initial local context are indeed
equal to the ones in m_local_instances. If they are not, it is an assertion violation.
We use the same cache policy for m_subsingleton_cache. */
@ -86,7 +85,7 @@ class type_context_cache {
/* Maximum search depth when performing type class resolution. */
unsigned m_ci_max_depth;
bool m_ci_trans_instances;
friend class type_context;
void init(local_context const & lctx);
@ -95,6 +94,15 @@ class type_context_cache {
bool should_unfold_macro(expr const & e);
public:
type_context_cache(environment const & env, options const & opts);
/* Enable frozen mode for type class resolution, and free local instances.
Local declarations used by the local instances are also frozen.
This method returns a new local_context where the local decls have been marked as frozen.
\pre !frozen_mode() */
local_context freeze_local_instances(metavar_context & mctx, local_context const & lctx);
bool frozen_mode() const { return m_frozen_mode; }
};
class type_context : public abstract_type_context {
@ -324,6 +332,9 @@ private:
optional<name> constant_is_class(expr const & e);
optional<name> is_full_class(expr type);
lbool is_quick_class(expr const & type, name & result);
bool compatible_local_instances(bool frozen_only);
void set_local_instances();
void init_local_instances();
public:
/* Helper class for creating pushing local declarations into the local context m_lctx */