lean4-htt/src/library/unification_hint.cpp

214 lines
7.9 KiB
C++

/*
Copyright (c) 2015 Daniel Selsam. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Daniel Selsam
*/
#include <string>
#include "util/sexpr/format.h"
#include "kernel/expr.h"
#include "kernel/error_msgs.h"
#include "library/attribute_manager.h"
#include "library/constants.h"
#include "library/unification_hint.h"
#include "library/util.h"
#include "library/expr_lt.h"
#include "library/scoped_ext.h"
namespace lean {
/* Unification hints */
/* Build a hint from the two sides of its pattern (`lhs`, `rhs`), its list of side
   constraints, and `num_vars`, the number of variables the hint abstracts over. */
unification_hint::unification_hint(expr const & lhs, expr const & rhs,
                                   list<expr_pair> const & constraints, unsigned num_vars):
    m_lhs(lhs),
    m_rhs(rhs),
    m_constraints(constraints),
    m_num_vars(num_vars) {}
/* Three-way comparison of unification hints, used to order them inside priority queues.

   Hints are compared by pattern lhs, then by pattern rhs, then lexicographically by their
   constraint lists. Returns a negative value, zero, or a positive value. */
int unification_hint_cmp::operator()(unification_hint const & uh1, unification_hint const & uh2) const {
    if (uh1.get_lhs() != uh2.get_lhs())
        return expr_quick_cmp()(uh1.get_lhs(), uh2.get_lhs());
    if (uh1.get_rhs() != uh2.get_rhs())
        return expr_quick_cmp()(uh1.get_rhs(), uh2.get_rhs());

    auto const & cs1 = uh1.get_constraints();
    auto const & cs2 = uh2.get_constraints();
    auto it1  = cs1.begin();
    auto it2  = cs2.begin();
    auto end1 = cs1.end();
    auto end2 = cs2.end();
    for (; it1 != end1 && it2 != end2; ++it1, ++it2) {
        // Keep the result signed: routing it through `unsigned` relies on an
        // implementation-defined conversion to recover negative values.
        if (int cmp = expr_pair_quick_cmp()(*it1, *it2))
            return cmp;
    }
    // A strict prefix orders before the longer list; otherwise distinct hints would be
    // reported as equal and could overwrite one another in the queue.
    if (it1 != end1) return 1;
    if (it2 != end2) return -1;
    return 0;
}
/* Environment extension */
// Serialization key for this extension's entries, returned by
// unification_hint_config::get_serialization_key. Allocated in
// initialize_unification_hint and deleted in finalize_unification_hint.
static std::string * g_key = nullptr;
struct unification_hint_state {
unification_hints m_hints;
name_map<unsigned> m_decl_names_to_prio; // Note: redundant but convenient
void validate_type(expr const & decl_type) {
expr type = decl_type;
while (is_pi(type)) type = binding_body(type);
if (!is_app_of(type, get_unification_hint_name(), 0)) {
throw exception("invalid unification hint, must return element of type `unification hint`");
}
}
void register_hint(name const & decl_name, expr const & value, unsigned priority) {
m_decl_names_to_prio.insert(decl_name, priority);
expr e_hint = value;
unsigned num_vars = 0;
while (is_lambda(e_hint)) {
e_hint = binding_body(e_hint);
num_vars++;
}
if (!is_app_of(e_hint, get_unification_hint_mk_name(), 2)) {
throw exception("invalid unification hint, body must be application of 'unification_hint.mk' to two arguments");
}
// e_hint := unification_hint.mk pattern constraints
expr e_pattern = app_arg(app_fn(e_hint));
expr e_constraints = app_arg(e_hint);
// pattern := unification_constraint.mk _ lhs rhs
expr e_pattern_lhs = app_arg(app_fn(e_pattern));
expr e_pattern_rhs = app_arg(e_pattern);
expr e_pattern_lhs_fn = get_app_fn(e_pattern_lhs);
expr e_pattern_rhs_fn = get_app_fn(e_pattern_rhs);
if (!is_constant(e_pattern_lhs_fn) || !is_constant(e_pattern_rhs_fn)) {
throw exception("invalid unification hint, the heads of both sides of pattern must be constants");
}
name_pair key = mk_pair(const_name(e_pattern_lhs_fn), const_name(e_pattern_rhs_fn));
buffer<expr_pair> constraints;
while (is_app_of(e_constraints, get_list_cons_name(), 3)) {
// e_constraints := cons _ constraint rest
expr e_constraint = app_arg(app_fn(e_constraints));
expr e_constraint_lhs = app_arg(app_fn(e_constraint));
expr e_constraint_rhs = app_arg(e_constraint);
constraints.push_back(mk_pair(e_constraint_lhs, e_constraint_rhs));
e_constraints = app_arg(e_constraints);
}
if (!is_app_of(e_constraints, get_list_nil_name(), 1)) {
throw exception("invalid unification hint, must provide list of constraints explicitly");
}
unification_hint hint(e_pattern_lhs, e_pattern_rhs, to_list(constraints), num_vars);
unification_hint_queue q;
if (auto const & q_ptr = m_hints.find(key)) q = *q_ptr;
q.insert(hint, priority);
m_hints.insert(key, q);
}
};
/* One `[unify]` attribute application: the tagged declaration and its priority.
   This is what the scoped extension stores and (de)serializes. */
struct unification_hint_entry {
    name     m_decl_name;
    unsigned m_priority;

    unification_hint_entry(name const & decl_name, unsigned priority):
        m_decl_name(decl_name),
        m_priority(priority) {}
};
struct unification_hint_config {
typedef unification_hint_entry entry;
typedef unification_hint_state state;
static void add_entry(environment const & env, io_state const &, state & s, entry const & e) {
declaration decl = env.get(e.m_decl_name);
s.validate_type(decl.get_type());
// Note: only definitions should be tagged as [unify], so if it is not a definition,
// there must have been an error when processing the definition. We return immediately
// so as not to hide the original error.
// TODO(dhs): the downside to this approach is that a [unify] tag on an actual axiom will be silently ignored.
if (decl.is_definition()) s.register_hint(e.m_decl_name, decl.get_value(), e.m_priority);
}
static std::string const & get_serialization_key() {
return *g_key;
}
static void write_entry(serializer & s, entry const & e) {
s << e.m_decl_name << e.m_priority;
}
static entry read_entry(deserializer & d) {
name decl_name; unsigned prio;
d >> decl_name >> prio;
return entry(decl_name, prio);
}
static optional<unsigned> get_fingerprint(entry const & e) {
return some(hash(e.m_decl_name.hash(), e.m_priority));
}
};
/* The scoped environment extension that holds `unification_hint_state`. */
typedef scoped_ext<unification_hint_config> unification_hint_ext;

/* Handler for the `[unify]` attribute: records `decl_name` with priority `prio` by adding an
   entry to the extension; `persistent` is forwarded unchanged to scoped_ext::add_entry. */
environment add_unification_hint(environment const & env, io_state const & ios, name const & decl_name, unsigned prio,
                                 bool persistent) {
    unification_hint_entry hint_entry(decl_name, prio);
    return unification_hint_ext::add_entry(env, ios, hint_entry, persistent);
}
/* Return all unification hints registered in `env`, keyed by head-constant pair. */
unification_hints get_unification_hints(environment const & env) {
    unification_hint_state const & st = unification_hint_ext::get_state(env);
    return st.m_hints;
}
/* Append to `uhints` every hint whose pattern heads are (n1, n2) or (n2, n1).

   The (n1, n2) hints are appended first, then the (n2, n1) ones. When n1 == n2 the two
   keys coincide, so that queue is appended only once. */
void get_unification_hints(environment const & env, name const & n1, name const & n2, buffer<unification_hint> & uhints) {
    // Bind by reference instead of copying the whole hint map.
    unification_hint_state const & st = unification_hint_ext::get_state(env);
    if (auto const & q_ptr = st.m_hints.find(mk_pair(n1, n2))) {
        q_ptr->to_buffer(uhints);
    }
    // Without this guard, a symmetric key would be queried twice and every hint duplicated.
    if (n1 != n2) {
        if (auto const & q_ptr = st.m_hints.find(mk_pair(n2, n1))) {
            q_ptr->to_buffer(uhints);
        }
    }
}
/* Pretty-printing */
// TODO(dhs): I may not be using all the formatting functions correctly.
/* Render the hint as `[(prio) ]lhs =?= rhs {c1_lhs =?= c1_rhs, ...}`; the priority prefix
   is omitted when `prio` equals LEAN_DEFAULT_PRIORITY. */
format unification_hint::pp(unsigned prio, formatter const & fmt) const {
    format out;
    if (prio != LEAN_DEFAULT_PRIORITY)
        out += paren(format(prio)) + space();

    // Pattern plus the opening brace, grouped so the layout engine treats it as one unit.
    format pattern = fmt(get_lhs()) + space() + format("=?=") + pp_indent_expr(fmt, get_rhs());
    pattern += space() + lcurly();
    out += group(pattern);

    // Side constraints, separated by ", ".
    bool need_sep = false;
    for (expr_pair const & c : m_constraints) {
        if (need_sep)
            out += comma() + space();
        out += fmt(c.first) + space() + format("=?=") + space() + fmt(c.second);
        need_sep = true;
    }

    out += rcurly();
    return out;
}
/* Render every hint under a "unification hints:" heading, one per line, each prefixed by
   its (lhs head, rhs head) key and printed with its priority from the queue. */
format pp_unification_hints(unification_hints const & hints, formatter const & fmt) {
    format r;
    r += format("unification hints") + colon() + line();
    hints.for_each([&](name_pair const & key, unification_hint_queue const & queue) {
        format key_fmt = lp() + format(key.first) + comma() + space() + format(key.second) + rp() + space();
        queue.for_each([&](unification_hint const & uh) {
            r += key_fmt;
            r += uh.pp(*queue.get_prio(uh), fmt) + line();
        });
    });
    return r;
}
/* Module initialization. Allocates the serialization key, initializes the scoped
   extension, and registers the `unify` attribute with add_unification_hint as its handler.
   NOTE(review): the key is allocated before unification_hint_ext::initialize(), presumably
   because initialization may consult get_serialization_key(). Confirm before reordering. */
void initialize_unification_hint() {
g_key = new std::string("UNIFICATION_HINT");
unification_hint_ext::initialize();
register_system_attribute(basic_attribute("unify", "unification hint", add_unification_hint));
}
/* Module teardown, in reverse order of initialize_unification_hint. It finalizes the
   scoped extension and then frees the serialization key. */
void finalize_unification_hint() {
unification_hint_ext::finalize();
delete g_key;
}
}