diff --git a/src/library/relation_manager.cpp b/src/library/relation_manager.cpp index f5ef3e642c..9b13d15a57 100644 --- a/src/library/relation_manager.cpp +++ b/src/library/relation_manager.cpp @@ -8,8 +8,10 @@ Author: Leonardo de Moura #include "util/optional.h" #include "util/name.h" #include "util/rb_map.h" +#include "util/sstream.h" #include "library/constants.h" #include "library/scoped_ext.h" +#include "library/relation_manager.h" namespace lean { // Check whether e is of the form (f ...) where f is a constant. If it is return f. @@ -48,12 +50,49 @@ struct rel_state { typedef name_map> subst_table; typedef name_map> symm_table; typedef rb_map, name_pair_quick_cmp> trans_table; + typedef name_map rop_table; trans_table m_trans_table; refl_table m_refl_table; subst_table m_subst_table; symm_table m_symm_table; + rop_table m_rop_table; rel_state() {} + bool is_equivalence(name const & rop) const { + return m_trans_table.contains(mk_pair(rop, rop)) && m_refl_table.contains(rop) && m_symm_table.contains(rop); + } + + static void throw_invalid_relation(name const & rop) { + throw exception(sstream() << "invalid binary relation declaration, relation '" << rop + << "' must have two explicit parameters"); + } + + void register_rop(environment const & env, name const & rop) { + if (m_rop_table.contains(rop)) + return; + declaration const & d = env.get(rop); + optional lhs_pos; + optional rhs_pos; + unsigned i = 0; + expr type = d.get_type(); + while (is_pi(type)) { + if (is_explicit(binding_info(type))) { + if (!lhs_pos) + lhs_pos = i; + else if (!rhs_pos) + rhs_pos = i; + else + throw_invalid_relation(rop); + } + type = binding_body(type); + } + if (lhs_pos && rhs_pos) { + m_rop_table.insert(rop, relation_info(i, *lhs_pos, *rhs_pos)); + } else { + throw_invalid_relation(rop); + } + } + void add_subst(environment const & env, name const & subst) { buffer arg_types; auto p = extract_arg_types_core(env, subst, arg_types); @@ -75,6 +114,7 @@ struct rel_state { if (nargs < 1) throw exception("invalid reflexivity rule, it must have at least 1 argument"); name const & rop = get_fn_const(r_type, "invalid reflexivity rule, result type must be an operator application"); + register_rop(env, rop); m_refl_table.insert(rop, std::make_tuple(refl, nargs, nunivs)); } @@ -87,6 +127,7 @@ struct rel_state { name const & rop = get_fn_const(r_type, "invalid transitivity rule, result type must be an operator application"); name const & op1 = get_fn_const(arg_types[nargs-2], "invalid transitivity rule, penultimate argument must be an operator application"); name const & op2 = get_fn_const(arg_types[nargs-1], "invalid transitivity rule, last argument must be an operator application"); + register_rop(env, rop); m_trans_table.insert(name_pair(op1, op2), std::make_tuple(trans, rop, nargs)); } @@ -99,6 +140,7 @@ struct rel_state { if (nargs < 1) throw exception("invalid symmetry rule, it must have at least 1 argument"); name const & rop = get_fn_const(r_type, "invalid symmetry rule, result type must be an operator application"); + register_rop(env, rop); m_symm_table.insert(rop, std::make_tuple(symm, nargs, nunivs)); } }; @@ -204,6 +246,14 @@ optional get_trans_info(environment const & env, name const & op) { return optional(); } +bool is_equivalence(environment const & env, name const & rop) { + return rel_ext::get_state(env).is_equivalence(rop); +} + +relation_info const * get_relation_info(environment const & env, name const & rop) { + return rel_ext::get_state(env).m_rop_table.find(rop); +} + void initialize_relation_manager() { g_rel_name = new name("rel"); g_key = new std::string("rel"); diff --git a/src/library/relation_manager.h b/src/library/relation_manager.h index d76243dca3..f1a6d14f16 100644 --- a/src/library/relation_manager.h +++ b/src/library/relation_manager.h @@ -8,6 +8,28 @@ Author: Leonardo de Moura #include "kernel/environment.h" namespace lean { +struct relation_info { + unsigned m_arity; + unsigned m_lhs_pos; + unsigned m_rhs_pos; +public: + relation_info() {} + relation_info(unsigned arity, unsigned lhs, unsigned rhs): + m_arity(arity), m_lhs_pos(lhs), m_rhs_pos(rhs) { + lean_assert(m_lhs_pos < m_arity); + lean_assert(m_rhs_pos < m_arity); + } + unsigned get_arity() const { return m_arity; } + unsigned get_lhs_pos() const { return m_lhs_pos; } + unsigned get_rhs_pos() const { return m_rhs_pos; } +}; + +/** \brief Return true if \c rop is a registered equivalence relation in the given manager */ +bool is_equivalence(environment const & env, name const & rop); + +/** \brief If \c rop is a registered relation, then return a non-null pointer to the associated information */ +relation_info const * get_relation_info(environment const & env, name const & rop); + environment add_subst(environment const & env, name const & n, bool persistent = true); environment add_refl(environment const & env, name const & n, bool persistent = true); environment add_symm(environment const & env, name const & n, bool persistent = true);