diff --git a/src/library/arith_instance.cpp b/src/library/arith_instance.cpp index ba88c324ee..2c4b55e9dd 100644 --- a/src/library/arith_instance.cpp +++ b/src/library/arith_instance.cpp @@ -21,7 +21,11 @@ arith_instance::arith_instance(type_context & ctx, expr const & type, level cons arith_instance::arith_instance(type_context & ctx, expr const & type): m_ctx(&ctx) { - if (optional lvl = dec_level(get_level(ctx, type))) + set_type(type); +} + +void arith_instance::set_type(expr const & type) { + if (optional lvl = dec_level(get_level(*m_ctx, type))) m_info = mk_arith_instance_info(type, *lvl); else throw exception("failed to infer universe level"); @@ -30,9 +34,9 @@ arith_instance::arith_instance(type_context & ctx, expr const & type): expr arith_instance::mk_op(name const & op, name const & s, optional & r) { if (r) return *r; if (m_ctx) { - expr inst_type = mk_app(mk_constant(s, {m_info->m_level}), m_info->m_type); + expr inst_type = mk_app(mk_constant(s, m_info->m_levels), m_info->m_type); if (auto inst = m_ctx->mk_class_instance(inst_type)) { - r = mk_app(mk_constant(op, {m_info->m_level}), m_info->m_type, *inst); + r = mk_app(mk_constant(op, m_info->m_levels), m_info->m_type, *inst); return *r; } } @@ -42,7 +46,7 @@ expr arith_instance::mk_op(name const & op, name const & s, optional & r) expr arith_instance::mk_structure(name const & s, optional & r) { if (r) return *r; if (m_ctx) { - expr inst_type = mk_app(mk_constant(s, {m_info->m_level}), m_info->m_type); + expr inst_type = mk_app(mk_constant(s, m_info->m_levels), m_info->m_type); if (auto inst = m_ctx->mk_class_instance(inst_type)) { r = *inst; return *r; @@ -53,7 +57,7 @@ expr arith_instance::mk_structure(name const & s, optional & r) { expr arith_instance::mk_bit1() { if (!m_info->m_bit1) - m_info->m_bit1 = mk_app(mk_constant(get_bit1_name(), {m_info->m_level}), m_info->m_type, mk_has_one(), mk_has_add()); + m_info->m_bit1 = mk_app(mk_constant(get_bit1_name(), m_info->m_levels), m_info->m_type, mk_has_one(), mk_has_add()); return *m_info->m_bit1; } @@ -83,4 +87,39 @@ expr arith_instance::mk_linear_ordered_semiring() { return mk_structure(get_line expr arith_instance::mk_ring() { return mk_structure(get_ring_name(), m_info->m_ring); } expr arith_instance::mk_linear_ordered_ring() { return mk_structure(get_linear_ordered_ring_name(), m_info->m_linear_ordered_ring); } expr arith_instance::mk_field() { return mk_structure(get_field_name(), m_info->m_field); } + +expr arith_instance::mk_pos_num(mpz const & n) { + lean_assert(n > 0); + if (n == 1) + return mk_one(); + else if (n % mpz(2) == 1) + return mk_app(mk_bit1(), mk_pos_num(n/2)); + else + return mk_app(mk_bit0(), mk_pos_num(n/2)); +} + +expr arith_instance::mk_num(mpz const & n) { + if (n < 0) { + return mk_app(mk_neg(), mk_pos_num(0 - n)); + } else if (n == 0) { + return mk_zero(); + } else { + return mk_pos_num(n); + } +} + +expr arith_instance::mk_num(mpq const & q) { + mpz numer = q.get_numerator(); + mpz denom = q.get_denominator(); + lean_assert(denom >= 0); + if (denom == 1 || numer == 0) { + return mk_num(numer); + } else if (numer > 0) { + return mk_app(mk_div(), mk_num(numer), mk_num(denom)); + } else { + return mk_app(mk_neg(), mk_app(mk_div(), mk_num(neg(numer)), mk_num(denom))); + } +} + + } diff --git a/src/library/arith_instance.h b/src/library/arith_instance.h index 95195f1055..a950a40a03 100644 --- a/src/library/arith_instance.h +++ b/src/library/arith_instance.h @@ -5,13 +5,14 @@ Released under Apache 2.0 license as described in the file LICENSE. Author: Leonardo de Moura */ #pragma once +#include "util/numerics/mpq.h" #include "library/type_context.h" namespace lean { class arith_instance_info { friend class arith_instance; - expr m_type; - level m_level; + expr m_type; + levels m_levels; /* Partial applications */ optional m_zero, m_one; @@ -30,7 +31,7 @@ class arith_instance_info { optional m_ring, m_linear_ordered_ring; optional m_field; public: - arith_instance_info(expr const & type, level const & lvl):m_type(type), m_level(lvl) {} + arith_instance_info(expr const & type, level const & lvl):m_type(type), m_levels(lvl) {} }; typedef std::shared_ptr arith_instance_info_ptr; @@ -43,11 +44,20 @@ class arith_instance { expr mk_structure(name const & s, optional & r); expr mk_op(name const & op, name const & s, optional & r); + expr mk_pos_num(mpz const & n); + public: arith_instance(type_context & ctx, arith_instance_info_ptr const & info):m_ctx(&ctx), m_info(info) {} arith_instance(type_context & ctx, expr const & type, level const & level); arith_instance(type_context & ctx, expr const & type); arith_instance(arith_instance_info_ptr const & info):m_ctx(nullptr), m_info(info) {} + arith_instance(type_context & ctx):m_ctx(&ctx) {} + + void set_info(arith_instance_info_ptr const & info) { m_info = info; } + void set_type(expr const & type); + + expr const & get_type() const { return m_info->m_type; } + levels const & get_levels() const { return m_info->m_levels; } expr mk_zero(); expr mk_one(); @@ -87,5 +97,8 @@ public: expr mk_ring(); expr mk_linear_ordered_ring(); expr mk_field(); + + expr mk_num(mpz const & n); + expr mk_num(mpq const & n); }; }; diff --git a/src/library/mpq_macro.cpp b/src/library/mpq_macro.cpp index fd2dae7d04..8ec2edc9c3 100644 --- a/src/library/mpq_macro.cpp +++ b/src/library/mpq_macro.cpp @@ -16,51 +16,6 @@ Author: Daniel Selsam #include "library/arith_instance.h" namespace lean { - -struct mpq2expr_fn { - arith_instance & m_ainst; - - mpq2expr_fn(arith_instance & ainst): m_ainst(ainst) {} - - expr operator()(mpq const & q) { - mpz numer = q.get_numerator(); - if (numer.is_zero()) - return m_ainst.mk_zero(); - - mpz denom = q.get_denominator(); - lean_assert(denom > 0); - - bool flip_sign = false; - if (numer.is_neg()) { - numer.neg(); - flip_sign = true; - } - - expr e; - if (denom == 1) { - e = pos_mpz_to_expr(numer); - } else { - e = mk_app(m_ainst.mk_div(), pos_mpz_to_expr(numer), pos_mpz_to_expr(denom)); - } - - if (flip_sign) { - return mk_app(m_ainst.mk_neg(), e); - } else { - return e; - } - } - - expr pos_mpz_to_expr(mpz const & n) { - lean_assert(n > 0); - if (n == 1) - return m_ainst.mk_one(); - if (n % mpz(2) == 1) - return mk_app(m_ainst.mk_bit1(), pos_mpz_to_expr(n/2)); - else - return mk_app(m_ainst.mk_bit0(), pos_mpz_to_expr(n/2)); - } -}; - static name * g_mpq_macro_name = nullptr; static std::string * g_mpq_opcode = nullptr; @@ -93,7 +48,7 @@ public: throw exception(sstream() << "trying to expand invalid 'mpq' macro"); type_context ctx(actx.env()); arith_instance ainst(ctx, type); - return some_expr(mpq2expr_fn(ainst)(m_q)); + return some_expr(ainst.mk_num(m_q)); } virtual void write(serializer & s) const {