/*
Copyright (c) 2017 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.

Author: Leonardo de Moura
*/
#include "util/sstream.h"
#include "library/util.h"
#include "library/constants.h"
#include "library/arith_instance.h"
#include "library/num.h"

namespace lean {
// TODO(Leo): pre compute arith_instance_info for nat, int and real

/** \brief Allocate a fresh (empty-cache) arith_instance_info for the given
    carrier type and universe level. */
arith_instance_info_ptr mk_arith_instance_info(expr const & type, level const & lvl) {
    return std::make_shared<arith_instance_info>(type, lvl);
}

/** \brief Create an instance for \c type living at universe \c level.
    The context is stored by address; it is used later for type class resolution. */
arith_instance::arith_instance(type_context & ctx, expr const & type, level const & level):
    m_ctx(&ctx), m_info(mk_arith_instance_info(type, level)) {}

/** \brief Create an instance for \c type; the universe level is computed by set_type.
    Throws if the level cannot be computed. */
arith_instance::arith_instance(type_context & ctx, expr const & type):
    m_ctx(&ctx) {
    set_type(type);
}

/** \brief Switch this object to the carrier \c type.

    The level is obtained by applying dec_level to get_level(*m_ctx, type).
    On success a brand new info object is installed (so nothing cached for the
    previous type is reused). Throws an exception when dec_level fails. */
void arith_instance::set_type(expr const & type) {
    if (auto lvl = dec_level(get_level(*m_ctx, type)))
        m_info = mk_arith_instance_info(type, *lvl);
    else
        throw exception("failed to infer universe level");
}

/** \brief Return the operation \c op of the type class \c s specialized to the
    current carrier type, i.e. the term <tt>op.{levels} type inst</tt> where
    \c inst is an instance of <tt>s.{levels} type</tt> found by type class resolution.

    \c r is the cache slot for the result: if it is already set its value is
    returned, otherwise it is filled on success.

    Throws an exception if there is no type context or no instance is found. */
expr arith_instance::mk_op(name const & op, name const & s, optional<expr> & r) {
    if (r) return *r;
    if (m_ctx) {
        expr inst_type = mk_app(mk_constant(s, m_info->m_levels), m_info->m_type);
        if (auto inst = m_ctx->mk_class_instance(inst_type)) {
            r = mk_app(mk_constant(op, m_info->m_levels), m_info->m_type, *inst);
            return *r;
        }
    }
    throw exception(sstream() << "failed to synthesize '" << s << "'");
}

/** \brief Return an instance of the structure/class <tt>s.{levels} type</tt>
    found by type class resolution.

    \c r is the cache slot for the result: if it is already set its value is
    returned, otherwise it is filled on success.

    Throws an exception if there is no type context or no instance is found. */
expr arith_instance::mk_structure(name const & s, optional<expr> & r) {
    if (r) return *r;
    if (m_ctx) {
        expr inst_type = mk_app(mk_constant(s, m_info->m_levels), m_info->m_type);
        if (auto inst = m_ctx->mk_class_instance(inst_type)) {
            r = *inst;
            return *r;
        }
    }
    throw exception(sstream() << "failed to synthesize '" << s << "'");
}

/** \brief Return <tt>bit1.{levels} type has_one_inst has_add_inst</tt>.
    bit1 needs two instances, so it cannot go through mk_op; the result is
    cached in the shared info object. */
expr arith_instance::mk_bit1() {
    if (!m_info->m_bit1)
        m_info->m_bit1 = mk_app(mk_constant(get_bit1_name(), m_info->m_levels), m_info->m_type, mk_has_one(), mk_has_add());
    return *m_info->m_bit1;
}

/* Operations: each one is <tt>op type inst</tt> for the corresponding has_* class,
   built (and cached in m_info) by mk_op. */
expr arith_instance::mk_zero() { return mk_op(get_has_zero_zero_name(), get_has_zero_name(), m_info->m_zero); }
expr arith_instance::mk_one() { return mk_op(get_has_one_one_name(), get_has_one_name(), m_info->m_one); }
expr arith_instance::mk_add() { return mk_op(get_has_add_add_name(), get_has_add_name(), m_info->m_add); }
expr arith_instance::mk_sub() { return mk_op(get_has_sub_sub_name(), get_has_sub_name(), m_info->m_sub); }
expr arith_instance::mk_neg() { return mk_op(get_has_neg_neg_name(), get_has_neg_name(), m_info->m_neg); }
expr arith_instance::mk_mul() { return mk_op(get_has_mul_mul_name(), get_has_mul_name(), m_info->m_mul); }
expr arith_instance::mk_div() { return mk_op(get_has_div_div_name(), get_has_div_name(), m_info->m_div); }
expr arith_instance::mk_inv() { return mk_op(get_has_inv_inv_name(), get_has_inv_name(), m_info->m_inv); }
expr arith_instance::mk_lt() { return mk_op(get_has_lt_lt_name(), get_has_lt_name(), m_info->m_lt); }
expr arith_instance::mk_le() { return mk_op(get_has_le_le_name(), get_has_le_name(), m_info->m_le); }

/* bit0 only depends on a has_add instance, so it fits the mk_op scheme. */
expr arith_instance::mk_bit0() { return mk_op(get_bit0_name(), get_has_add_name(), m_info->m_bit0); }

/* Structures: each one is the instance itself, found (and cached in m_info) by mk_structure. */
expr arith_instance::mk_partial_order() { return mk_structure(get_partial_order_name(), m_info->m_partial_order); }
expr arith_instance::mk_add_comm_semigroup() { return mk_structure(get_add_comm_semigroup_name(), m_info->m_add_comm_semigroup); }
expr arith_instance::mk_monoid() { return mk_structure(get_monoid_name(), m_info->m_monoid); }
expr arith_instance::mk_add_monoid() { return mk_structure(get_add_monoid_name(), m_info->m_add_monoid); }
expr arith_instance::mk_add_group() { return mk_structure(get_add_group_name(), m_info->m_add_group); }
expr arith_instance::mk_add_comm_group() { return mk_structure(get_add_comm_group_name(), m_info->m_add_comm_group); }
expr arith_instance::mk_distrib() { return mk_structure(get_distrib_name(), m_info->m_distrib); }
expr arith_instance::mk_mul_zero_class() { return mk_structure(get_mul_zero_class_name(), m_info->m_mul_zero_class); }
expr arith_instance::mk_semiring() { return mk_structure(get_semiring_name(), m_info->m_semiring); }
expr arith_instance::mk_linear_ordered_semiring() { return mk_structure(get_linear_ordered_semiring_name(), m_info->m_linear_ordered_semiring); }
expr arith_instance::mk_ring() { return mk_structure(get_ring_name(), m_info->m_ring); }
expr arith_instance::mk_linear_ordered_ring() { return mk_structure(get_linear_ordered_ring_name(), m_info->m_linear_ordered_ring); }
expr arith_instance::mk_field() { return mk_structure(get_field_name(), m_info->m_field); }

/** \brief Encode the positive integer \c n as a binary numeral built from
    \c one, \c bit0 and \c bit1.
    \pre n > 0 */
expr arith_instance::mk_pos_num(mpz const & n) {
    lean_assert(n > 0);
    if (n == 1)
        return mk_one();
    else if (n % mpz(2) == 1)
        return mk_app(mk_bit1(), mk_pos_num(n/2));
    else
        return mk_app(mk_bit0(), mk_pos_num(n/2));
}

/** \brief Encode the integer \c n: \c zero for 0, a binary numeral for n > 0,
    and <tt>neg (numeral for -n)</tt> for n < 0. */
expr arith_instance::mk_num(mpz const & n) {
    if (n < 0) {
        return mk_app(mk_neg(), mk_pos_num(0 - n));
    } else if (n == 0) {
        return mk_zero();
    } else {
        return mk_pos_num(n);
    }
}

/** \brief Encode the rational \c q.

    - denominator 1 or numerator 0: the integer numeral of the numerator;
    - positive numerator: <tt>div numer denom</tt>;
    - negative numerator: <tt>neg (div (-numer) denom)</tt>, so the sign is
      always outermost. */
expr arith_instance::mk_num(mpq const & q) {
    mpz numer = q.get_numerator();
    mpz denom = q.get_denominator();
    lean_assert(denom >= 0);
    if (denom == 1 || numer == 0) {
        return mk_num(numer);
    } else if (numer > 0) {
        return mk_app(mk_div(), mk_num(numer), mk_num(denom));
    } else {
        return mk_app(mk_neg(), mk_app(mk_div(), mk_num(neg(numer)), mk_num(denom)));
    }
}

/** \brief Return true iff the carrier type is the constant \c nat. */
bool arith_instance::is_nat() {
    return is_constant(m_info->m_type, get_nat_name());
}

/** \brief Try to evaluate the closed arithmetic term \c e to a rational number.

    Recognized heads: has_add.add, has_mul.mul, has_sub.sub, has_div.div
    (4 arguments each, operands at positions 2 and 3), has_neg.neg
    (3 arguments, operand at position 2), and anything accepted by to_num.

    Special cases:
    - subtraction over nat is truncated: the result is 0 when the right operand
      is greater than the left one;
    - division over nat is not supported, and division by zero is not
      evaluated; both produce none.

    \return the value, or none if \c e (or a subterm) is not supported.
    Throws an exception if the head of \c e (or of a visited subterm) is not a constant. */
optional<mpq> arith_instance::eval(expr const & e) {
    buffer<expr> args;
    expr f = get_app_args(e, args);
    if (!is_constant(f)) {
        throw exception("cannot find num of nonconstant");
    } else if (const_name(f) == get_has_add_add_name() && args.size() == 4) {
        if (auto r1 = eval(args[2]))
        if (auto r2 = eval(args[3]))
            return optional<mpq>(*r1 + *r2);
    } else if (const_name(f) == get_has_mul_mul_name() && args.size() == 4) {
        if (auto r1 = eval(args[2]))
        if (auto r2 = eval(args[3]))
            return optional<mpq>(*r1 * *r2);
    } else if (const_name(f) == get_has_sub_sub_name() && args.size() == 4) {
        if (auto r1 = eval(args[2]))
        if (auto r2 = eval(args[3])) {
            if (is_nat() && *r2 > *r1)
                return optional<mpq>(0);
            else
                return optional<mpq>(*r1 - *r2);
        }
    } else if (const_name(f) == get_has_div_div_name() && args.size() == 4) {
        if (auto r1 = eval(args[2]))
        if (auto r2 = eval(args[3])) {
            if (is_nat())
                return optional<mpq>(); // not supported yet
            else if (*r2 == 0)
                return optional<mpq>(); // division by zero, add support for x/0 = 0
            else
                return optional<mpq>(*r1 / *r2);
        }
    } else if (const_name(f) == get_has_neg_neg_name() && args.size() == 3) {
        if (auto r1 = eval(args[2]))
            return optional<mpq>(neg(*r1));
    } else if (auto r = to_num(e)) {
        return optional<mpq>(*r);
    }
    // Reached when an operand could not be evaluated or the head is not recognized.
    return optional<mpq>();
}
}