feat(library/tactic/defeq_simplifier): reimplement defeq simp lemma cache

This commit is contained in:
Leonardo de Moura 2016-09-02 09:10:09 -07:00
parent 02316c39b8
commit 0afef31be6
4 changed files with 162 additions and 109 deletions

View file

@ -479,7 +479,7 @@ static void print_unification_hints(parser & p) {
static void print_defeq_lemmas(parser & p) {
type_checker tc(p.env());
auto out = regular(p.env(), p.ios(), tc);
out << pp_defeq_simp_lemmas(get_defeq_simp_lemmas(p.env()), out.get_formatter());
out << pp_defeq_simp_lemmas(*get_defeq_simp_lemmas(p.env()), out.get_formatter());
}
static void print_simp_rules(parser & p) {

View file

@ -1,14 +1,16 @@
/*
Copyright (c) 2015 Daniel Selsam. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Daniel Selsam
Author: Daniel Selsam, Leonardo de Moura
*/
#include <vector>
#include <string>
#include "util/sexpr/format.h"
#include "kernel/expr.h"
#include "kernel/error_msgs.h"
#include "kernel/instantiate.h"
#include "library/attribute_manager.h"
#include "library/trace.h"
#include "library/constants.h"
#include "library/util.h"
#include "library/scoped_ext.h"
@ -16,9 +18,6 @@ Author: Daniel Selsam
#include "library/tactic/defeq_simplifier/defeq_simp_lemmas.h"
namespace lean {
/* defeq simp lemmas */
defeq_simp_lemma::defeq_simp_lemma(name const & id, levels const & umetas, list<expr> const & emetas,
list<bool> const & instances, expr const & lhs, expr const & rhs, unsigned priority):
m_id(id), m_umetas(umetas), m_emetas(emetas), m_instances(instances), m_lhs(lhs), m_rhs(rhs), m_priority(priority) {}
@ -27,98 +26,150 @@ bool operator==(defeq_simp_lemma const & sl1, defeq_simp_lemma const & sl2) {
return sl1.get_lhs() == sl2.get_lhs() && sl1.get_rhs() == sl2.get_rhs();
}
/* Environment extension */
static std::string * g_key = nullptr;
struct defeq_simp_lemmas_state {
defeq_simp_lemmas m_defeq_simp_lemmas;
void register_defeq_simp_lemma(type_context & tctx, name const & decl_name, unsigned priority) {
declaration const & d = tctx.env().get(decl_name);
// TODO(dhs): once we refactor to register this attribute as "definitions-only", this can be an assert
if (!d.is_definition()) {
throw exception("invalid [defeq] simp lemma: must be a definition so that definitional equality can be verified");
}
buffer<level> us;
unsigned num_univs = d.get_num_univ_params();
for (unsigned i = 0; i < num_univs; i++) {
us.push_back(tctx.mk_tmp_univ_mvar());
}
levels ls = to_list(us);
expr type = instantiate_type_univ_params(d, ls);
expr proof = instantiate_value_univ_params(d, ls);
return register_defeq_simp_lemma_core(tctx, decl_name, ls, type, proof, priority);
}
void register_defeq_simp_lemma_core(type_context & tctx, name const & decl_name, levels const & umetas,
expr const & type, expr const & proof, unsigned priority) {
expr rule = type;
expr pf = proof;
buffer<expr> emetas;
buffer<bool> instances;
while (is_pi(rule)) {
expr mvar = tctx.mk_tmp_mvar(binding_domain(rule));
emetas.push_back(mvar);
instances.push_back(binding_info(rule).is_inst_implicit());
rule = instantiate(binding_body(rule), mvar);
pf = binding_body(pf);
}
expr lhs, rhs;
if (!is_eq(rule, lhs, rhs))
throw exception("invalid [defeq] simp lemma: body must be an equality");
if (!is_app_of(pf, get_eq_refl_name(), 2) && !is_app_of(pf, get_rfl_name(), 2))
throw exception("invalid [defeq] simp lemma: proof must be an application of either 'eq.refl' or 'rfl' to two arguments");
if (lhs == rhs)
throw exception("invalid [defeq] simp lemma: the two sides of the equality cannot be structurally equal");
defeq_simp_lemma lemma(decl_name, umetas, to_list(emetas), to_list(instances), lhs, rhs, priority);
m_defeq_simp_lemmas.insert(lhs, lemma);
}
};
struct defeq_simp_lemmas_entry {
name m_decl_name;
unsigned m_priority;
defeq_simp_lemmas_entry(name const & decl_name, unsigned priority):
m_decl_name(decl_name), m_priority(priority) {}
};
struct defeq_simp_lemmas_config {
typedef defeq_simp_lemmas_entry entry;
typedef defeq_simp_lemmas_state state;
static void add_entry(environment const & env, io_state const & ios, state & s, entry const & e) {
type_context tctx(env, ios.get_options());
type_context::tmp_mode_scope scope(tctx);
s.register_defeq_simp_lemma(tctx, e.m_decl_name, e.m_priority);
}
static std::string const & get_serialization_key() {
return *g_key;
}
static void write_entry(serializer & s, entry const & e) {
s << e.m_decl_name << e.m_priority;
}
static entry read_entry(deserializer & d) {
name decl_name; unsigned prio;
d >> decl_name >> prio;
return entry(decl_name, prio);
}
static optional<unsigned> get_fingerprint(entry const & e) {
return some(hash(e.m_decl_name.hash(), e.m_priority));
}
};
typedef scoped_ext<defeq_simp_lemmas_config> defeq_simp_lemmas_ext;
environment add_defeq_simp_lemma(environment const & env, io_state const & ios, name const & decl_name, unsigned prio, bool persistent) {
return defeq_simp_lemmas_ext::add_entry(env, ios, defeq_simp_lemmas_entry(decl_name, prio), persistent);
static void throw_non_rfl_proof() {
throw exception("invalid [defeq] simp lemma: proof must be an application of either 'eq.refl' or 'rfl'");
}
defeq_simp_lemmas get_defeq_simp_lemmas(environment const & env) {
return defeq_simp_lemmas_ext::get_state(env).m_defeq_simp_lemmas;
static void on_add_defeq_simp_lemma(environment const & env, name const & decl_name, bool) {
type_context ctx(env);
declaration const & d = env.get(decl_name);
if (!d.is_definition()) {
throw exception("invalid [defeq] simp lemma: must be a definition so that definitional equality can be verified");
}
expr type = d.get_type();
expr pf = d.get_value();
while (is_pi(type)) {
if (!is_lambda(pf))
throw_non_rfl_proof();
pf = binding_body(pf);
type = binding_body(type);
}
expr lhs, rhs;
if (!is_eq(type, lhs, rhs))
throw exception("invalid [defeq] simp lemma: body must be an equality");
if (!is_app_of(pf, get_eq_refl_name(), 2) && !is_app_of(pf, get_rfl_name(), 2))
throw_non_rfl_proof();
if (lhs == rhs)
throw exception("invalid [defeq] simp lemma: the two sides of the equality cannot be structurally equal");
}
/* Pretty printing */
static void add_lemma_core(tmp_type_context & ctx, name const & decl_name, unsigned priority,
defeq_simp_lemmas_ptr & result) {
environment const & env = ctx.env();
declaration const & d = env.get(decl_name);
lean_assert(d.is_definition());
buffer<level> umetas;
unsigned num_univs = d.get_num_univ_params();
for (unsigned i = 0; i < num_univs; i++) {
umetas.push_back(ctx.mk_tmp_univ_mvar());
}
levels ls = to_list(umetas);
expr type = instantiate_type_univ_params(d, ls);
buffer<expr> emetas;
buffer<bool> instances;
while (is_pi(type)) {
expr mvar = ctx.mk_tmp_mvar(binding_domain(type));
emetas.push_back(mvar);
instances.push_back(binding_info(type).is_inst_implicit());
type = instantiate(binding_body(type), mvar);
}
expr lhs, rhs;
lean_verify(is_eq(type, lhs, rhs));
defeq_simp_lemma lemma(decl_name, to_list(umetas), to_list(emetas), to_list(instances), lhs, rhs, priority);
result->insert(lhs, lemma);
}
static defeq_simp_lemmas_ptr get_defeq_simp_lemmas_from_attribute(type_context & ctx, name const & attr_name) {
environment const & env = ctx.env();
auto const & attr = get_attribute(env, attr_name);
buffer<name> lemmas;
attr.get_instances(env, lemmas);
defeq_simp_lemmas_ptr result = std::make_shared<defeq_simp_lemmas>();
unsigned i = lemmas.size();
while (i > 0) {
i--;
name const & id = lemmas[i];
unsigned prio = attr.get_prio(env, id);
tmp_type_context tmp_ctx(ctx);
add_lemma_core(tmp_ctx, id, prio, result);
}
return result;
}
static defeq_simp_lemmas_ptr get_defeq_simp_lemmas_from_attribute(environment const & env, name const & attr_name) {
type_context ctx(env);
return get_defeq_simp_lemmas_from_attribute(ctx, attr_name);
}
static std::vector<name> * g_defeq_simp_attributes = nullptr;
static defeq_lemmas_token g_default_token;
defeq_lemmas_token register_defeq_simp_attribute(name const & attr_name) {
unsigned r = g_defeq_simp_attributes->size();
g_defeq_simp_attributes->push_back(attr_name);
register_system_attribute(basic_attribute::with_check(attr_name, "[defeq] simplification lemma",
on_add_defeq_simp_lemma));
return r;
}
class defeq_simp_lemmas_cache {
struct entry {
environment m_env;
name m_attr_name;
unsigned m_attr_fingerprint;
defeq_simp_lemmas_ptr m_lemmas;
entry(environment const & env, name const & attr_name):
m_env(env), m_attr_name(attr_name), m_attr_fingerprint(0) {}
};
std::vector<entry> m_entries;
public:
void expand(environment const & env, unsigned new_sz) {
for (unsigned i = m_entries.size(); i < new_sz; i++) {
m_entries.emplace_back(env, (*g_defeq_simp_attributes)[i]);
}
}
defeq_simp_lemmas_ptr mk_lemmas(environment const & env, entry & C) {
lean_trace("defeq_simp_lemmas_cache", tout() << "make defeq simp lemmas [" << C.m_attr_name << "]\n";);
C.m_env = env;
C.m_lemmas = get_defeq_simp_lemmas_from_attribute(env, C.m_attr_name);
C.m_attr_fingerprint = get_attribute_fingerprint(env, C.m_attr_name);
return C.m_lemmas;
}
defeq_simp_lemmas_ptr lemmas_of(entry const & C) {
lean_trace("defeq_simp_lemmas_cache", tout() << "reusing cached defeq simp lemmas [" << C.m_attr_name << "]\n";);
return C.m_lemmas;
}
defeq_simp_lemmas_ptr get(environment const & env, defeq_lemmas_token token) {
lean_assert(token < g_defeq_simp_attributes->size());
if (token >= m_entries.size()) expand(env, token+1);
lean_assert(token < m_entries.size());
entry & C = m_entries[token];
if (!C.m_lemmas) return mk_lemmas(env, C);
if (is_eqp(env, C.m_env)) return lemmas_of(C);
if (!env.is_descendant(C.m_env) ||
get_attribute_fingerprint(env, C.m_attr_name) != C.m_attr_fingerprint) {
lean_trace("defeq_simp_lemmas_cache",
bool c = (get_attribute_fingerprint(env, C.m_attr_name) == C.m_attr_fingerprint);
tout() << "creating new cache, is_descendant: " << env.is_descendant(C.m_env)
<< ", attribute fingerprint compatibility: " << c << "\n";);
return mk_lemmas(env, C);
}
return lemmas_of(C);
}
};
MK_THREAD_LOCAL_GET_DEF(defeq_simp_lemmas_cache, get_cache);
defeq_simp_lemmas_ptr get_defeq_simp_lemmas(environment const & env) {
return get_cache().get(env, g_default_token);
}
defeq_simp_lemmas_ptr get_defeq_simp_lemmas(environment const & env, defeq_lemmas_token token) {
return get_cache().get(env, token);
}
format defeq_simp_lemma::pp(formatter const & fmt) const {
format r;
@ -140,19 +191,13 @@ format pp_defeq_simp_lemmas(defeq_simp_lemmas const & lemmas, formatter const &
return r;
}
/* Setup and teardown */
void initialize_defeq_simp_lemmas() {
g_key = new std::string("DEFEQ_SIMP_LEMMAS");
defeq_simp_lemmas_ext::initialize();
register_system_attribute(basic_attribute("defeq", "defeq simp lemma", add_defeq_simp_lemma));
register_trace_class("defeq_simp_lemmas_cache");
g_defeq_simp_attributes = new std::vector<name>();
g_default_token = register_defeq_simp_attribute("defeq");
}
void finalize_defeq_simp_lemmas() {
defeq_simp_lemmas_ext::finalize();
delete g_key;
delete g_defeq_simp_attributes;
}
}

View file

@ -4,8 +4,8 @@ Released under Apache 2.0 license as described in the file LICENSE.
Author: Daniel Selsam
*/
#pragma once
#include <memory>
#include "kernel/environment.h"
#include "library/io_state.h"
#include "library/head_map.h"
namespace lean {
@ -48,7 +48,15 @@ inline bool operator!=(defeq_simp_lemma const & sl1, defeq_simp_lemma const & sl
struct defeq_simp_lemma_prio_fn { unsigned operator()(defeq_simp_lemma const & sl) const { return sl.get_priority(); } };
typedef head_map_prio<defeq_simp_lemma, defeq_simp_lemma_prio_fn> defeq_simp_lemmas;
defeq_simp_lemmas get_defeq_simp_lemmas(environment const & env);
typedef unsigned defeq_lemmas_token;
/* Register a system level defeq attribute. This method must only be invoked during initialization.
It returns an internal "token" for retrieving the lemmas */
defeq_lemmas_token register_defeq_simp_attribute(name const & attr_name);
typedef std::shared_ptr<defeq_simp_lemmas> defeq_simp_lemmas_ptr;
defeq_simp_lemmas_ptr get_defeq_simp_lemmas(environment const & env);
defeq_simp_lemmas_ptr get_defeq_simp_lemmas(environment const & env, defeq_lemmas_token token);
format pp_defeq_simp_lemmas(defeq_simp_lemmas const & lemmas, formatter const & fmt);

View file

@ -354,15 +354,15 @@ vm_obj tactic_defeq_simp(vm_obj const & m, vm_obj const & e, vm_obj const & s0)
type_context ctx = mk_type_context_for(s0, m);
tactic_state const & s = to_tactic_state(s0);
LEAN_TACTIC_TRY;
defeq_simp_lemmas lemmas = get_defeq_simp_lemmas(s.env());
expr new_e = defeq_simplify(ctx, lemmas, to_expr(e));
defeq_simp_lemmas_ptr lemmas = get_defeq_simp_lemmas(s.env());
expr new_e = defeq_simplify(ctx, *lemmas, to_expr(e));
return mk_tactic_success(to_obj(new_e), s);
LEAN_TACTIC_CATCH(s);
}
expr defeq_simplify(type_context & ctx, expr const & e) {
defeq_simp_lemmas lemmas = get_defeq_simp_lemmas(ctx.env());
return defeq_simplify(ctx, lemmas, e);
defeq_simp_lemmas_ptr lemmas = get_defeq_simp_lemmas(ctx.env());
return defeq_simplify(ctx, *lemmas, e);
}
/* Setup and teardown */