diff --git a/src/interval/interval.h b/src/interval/interval.h index b2f4bf7d38..dd5f327ecb 100644 --- a/src/interval/interval.h +++ b/src/interval/interval.h @@ -139,14 +139,24 @@ public: void neg(); friend interval neg(interval o) { o.neg(); return o; } - interval & operator+=(interval const & o); - interval & operator-=(interval const & o); - interval & operator*=(interval const & o); - interval & operator/=(interval const & o); + interval & operator+=(interval const & o); + interval & operator-=(interval const & o); + interval & operator*=(interval const & o); + interval & operator/=(interval const & o); + + interval & operator+=(T const & o); + interval & operator-=(T const & o); + interval & operator*=(T const & o); + interval & operator/=(T const & o); void inv(); friend interval inv(interval o) { o.inv(); return o; } + void fmod(interval y); + void fmod(T y); + friend interval inv(interval o, interval y) { o.fmod(y); return o; } + friend interval inv(interval o, T y) { o.fmod(y); return o; } + void power(unsigned n); void exp (); void exp2 (); @@ -193,10 +203,21 @@ public: friend interval acosh(interval o) { o.acosh(); return o; } friend interval atanh(interval o) { o.atanh(); return o; } - friend interval operator+(interval a, interval const & b) { return a += b; } - friend interval operator-(interval a, interval const & b) { return a -= b; } - friend interval operator*(interval a, interval const & b) { return a *= b; } - friend interval operator/(interval a, interval const & b) { return a /= b; } + friend interval operator+(interval a, interval const & b) { return a += b; } + friend interval operator-(interval a, interval const & b) { return a -= b; } + friend interval operator*(interval a, interval const & b) { return a *= b; } + friend interval operator/(interval a, interval const & b) { return a /= b; } + + friend interval operator+(interval a, T const & b) { return a += b; } + friend interval operator-(interval a, T const & b) { return a -= b; } + friend interval operator*(interval a, T const & b) { return a *= b; } + friend interval operator/(interval a, T const & b) { return a /= b; } + + friend interval operator+(T const & a, interval b) { return b += a; } + friend interval operator-(T const & a, interval b) { return b += -a; } + friend interval operator*(T const & a, interval b) { return b *= a; } + friend interval operator/(T const & a, interval b) { b = b / a; return b; } + bool check_invariant() const; diff --git a/src/interval/interval_def.h b/src/interval/interval_def.h index 332e30744b..047a6e4489 100644 --- a/src/interval/interval_def.h +++ b/src/interval/interval_def.h @@ -560,6 +560,75 @@ interval & interval::operator/=(interval const & o) { return *this; } +template +interval & interval::operator+=(T const & o) { + xnumeral_kind new_l_kind, new_u_kind; + round_to_minus_inf(); + add(m_lower, new_l_kind, m_lower, lower_kind(), o, XN_NUMERAL); + round_to_plus_inf(); + add(m_upper, new_u_kind, m_upper, upper_kind(), o, XN_NUMERAL); + m_lower_inf = new_l_kind == XN_MINUS_INFINITY; + m_upper_inf = new_u_kind == XN_PLUS_INFINITY; + lean_assert(check_invariant()); + return *this; +} + +template +interval & interval::operator-=(T const & o) { + xnumeral_kind new_l_kind, new_u_kind; + round_to_minus_inf(); + sub(m_lower, new_l_kind, m_lower, lower_kind(), o, XN_NUMERAL); + round_to_plus_inf(); + sub(m_upper, new_u_kind, m_upper, upper_kind(), o, XN_NUMERAL); + m_lower_inf = new_l_kind == XN_MINUS_INFINITY; + m_upper_inf = new_u_kind == XN_PLUS_INFINITY; + lean_assert(check_invariant()); + return *this; +} + +template +interval & interval::operator*=(T const & o) { + xnumeral_kind new_l_kind, new_u_kind; + static thread_local T tmp1; + if (this->is_zero()) { + return *this; + } + if (numeric_traits::is_zero(o)) { + numeric_traits::reset(m_lower); + numeric_traits::reset(m_upper); + m_lower_open = m_upper_open = false; + m_lower_inf = m_upper_inf = false; + return *this; + } + + if(numeric_traits::is_pos(o)) { + // [a, b] * c = [a*c, b*c] when c > 0 + round_to_minus_inf(); + mul(m_lower, new_l_kind, m_lower, lower_kind(), o, XN_NUMERAL); + round_to_plus_inf(); + mul(m_upper, new_u_kind, m_upper, upper_kind(), o, XN_NUMERAL); + m_lower_inf = new_l_kind == XN_MINUS_INFINITY; + m_upper_inf = new_u_kind == XN_PLUS_INFINITY; + } + else { + // [a, b] * c = [b*c, a*c] when c < 0 + round_to_minus_inf(); + mul(tmp1, new_l_kind, m_upper, upper_kind(), o, XN_NUMERAL); + round_to_plus_inf(); + mul(m_upper, new_u_kind, m_lower, lower_kind(), o, XN_NUMERAL); + m_lower = tmp1; + m_lower_inf = new_l_kind == XN_MINUS_INFINITY; + m_upper_inf = new_u_kind == XN_PLUS_INFINITY; + } + return *this; +} + +template +interval & interval::operator/=(T const & o) { + return *this; +} + + template void interval::inv() { // If the interval [l,u] does not contain 0, then 1/[l,u] = [1/u, 1/l] @@ -756,7 +825,15 @@ void interval::display(std::ostream & out) const { out << (m_upper_open ? ")" : "]"); } -template void interval::exp () { + +template void interval::fmod(interval y) { +} + +template void interval::fmod(T y) { + +} + +template void interval::exp() { if(is_empty()) return; if(m_lower_inf) { @@ -774,7 +851,7 @@ template void interval::exp () { lean_assert(check_invariant()); return; } -template void interval::exp2 () { +template void interval::exp2() { if(is_empty()) return; if(m_lower_inf) { @@ -810,7 +887,7 @@ template void interval::exp10() { lean_assert(check_invariant()); return; } -template void interval::log () { +template void interval::log() { if(is_empty()) return; if(is_N0()) { @@ -833,7 +910,7 @@ template void interval::log () { lean_assert(check_invariant()); return; } -template void interval::log2 () { +template void interval::log2() { if(is_empty()) return; if(is_N0()) { @@ -879,7 +956,7 @@ template void interval::log10() { lean_assert(check_invariant()); return; } -template void interval::sin () { +template void interval::sin() { *this -= numeric_traits::pi_half_lower(); cos(); }