feat: define Int8 (#5790)

This commit is contained in:
Henrik Böving 2024-10-25 08:06:40 +02:00 committed by GitHub
parent 19ce2040a2
commit 193b6f2bec
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 520 additions and 38 deletions

View file

@ -155,6 +155,10 @@ endif ()
# We want explicit stack probes in huge Lean stack frames for robust stack overflow detection
string(APPEND LEANC_EXTRA_FLAGS " -fstack-clash-protection")
# This makes signed integer overflow guaranteed to match 2's complement.
string(APPEND CMAKE_CXX_FLAGS " -fwrapv")
string(APPEND LEANC_EXTRA_FLAGS " -fwrapv")
if(NOT MULTI_THREAD)
message(STATUS "Disabled multi-thread support, it will not be safe to run multiple threads in parallel")
set(AUTO_THREAD_FINALIZATION OFF)

View file

@ -19,6 +19,7 @@ import Init.Data.ByteArray
import Init.Data.FloatArray
import Init.Data.Fin
import Init.Data.UInt
import Init.Data.SInt
import Init.Data.Float
import Init.Data.Option
import Init.Data.Ord

11
src/Init/Data/SInt.lean Normal file
View file

@ -0,0 +1,11 @@
/-
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Henrik Böving
-/
prelude
import Init.Data.SInt.Basic
/-!
This module contains the definitions and basic theory about signed fixed width integer types.
-/

View file

@ -0,0 +1,116 @@
/-
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Henrik Böving
-/
prelude
import Init.Data.UInt.Basic
/-!
This module contains the definition of signed fixed width integer types as well as basic arithmetic
and bitwise operations on top of it.
-/
/--
The type of signed 8-bit integers. This type has special support in the
compiler to make it actually 8 bits rather than wrapping a `Nat`.
-/
structure Int8 where
/--
Obtain the `UInt8` that is 2's complement equivalent to the `Int8`.
-/
toUInt8 : UInt8
/-- The size of type `Int8`, that is, `2^8 = 256`. -/
abbrev Int8.size : Nat := 256
/--
Obtain the `BitVec` that contains the 2's complement representation of the `Int8`.
-/
@[inline] def Int8.toBitVec (x : Int8) : BitVec 8 := x.toUInt8.toBitVec
@[extern "lean_int8_of_int"]
def Int8.ofInt (i : @& Int) : Int8 := ⟨⟨BitVec.ofInt 8 i⟩⟩
@[extern "lean_int8_of_int"]
def Int8.ofNat (n : @& Nat) : Int8 := ⟨⟨BitVec.ofNat 8 n⟩⟩
abbrev Int.toInt8 := Int8.ofInt
abbrev Nat.toInt8 := Int8.ofNat
@[extern "lean_int8_to_int"]
def Int8.toInt (i : Int8) : Int := i.toBitVec.toInt
@[inline] def Int8.toNat (i : Int8) : Nat := i.toInt.toNat
@[extern "lean_int8_neg"]
def Int8.neg (i : Int8) : Int8 := ⟨⟨-i.toBitVec⟩⟩
instance : ToString Int8 where
toString i := toString i.toInt
instance : OfNat Int8 n := ⟨Int8.ofNat n⟩
instance : Neg Int8 where
neg := Int8.neg
@[extern "lean_int8_add"]
def Int8.add (a b : Int8) : Int8 := ⟨⟨a.toBitVec + b.toBitVec⟩⟩
@[extern "lean_int8_sub"]
def Int8.sub (a b : Int8) : Int8 := ⟨⟨a.toBitVec - b.toBitVec⟩⟩
@[extern "lean_int8_mul"]
def Int8.mul (a b : Int8) : Int8 := ⟨⟨a.toBitVec * b.toBitVec⟩⟩
@[extern "lean_int8_div"]
def Int8.div (a b : Int8) : Int8 := ⟨⟨BitVec.sdiv a.toBitVec b.toBitVec⟩⟩
@[extern "lean_int8_mod"]
def Int8.mod (a b : Int8) : Int8 := ⟨⟨BitVec.smod a.toBitVec b.toBitVec⟩⟩
@[extern "lean_int8_land"]
def Int8.land (a b : Int8) : Int8 := ⟨⟨a.toBitVec &&& b.toBitVec⟩⟩
@[extern "lean_int8_lor"]
def Int8.lor (a b : Int8) : Int8 := ⟨⟨a.toBitVec ||| b.toBitVec⟩⟩
@[extern "lean_int8_xor"]
def Int8.xor (a b : Int8) : Int8 := ⟨⟨a.toBitVec ^^^ b.toBitVec⟩⟩
@[extern "lean_int8_shift_left"]
def Int8.shiftLeft (a b : Int8) : Int8 := ⟨⟨a.toBitVec <<< (mod b 8).toBitVec⟩⟩
@[extern "lean_int8_shift_right"]
def Int8.shiftRight (a b : Int8) : Int8 := ⟨⟨BitVec.sshiftRight' a.toBitVec (mod b 8).toBitVec⟩⟩
@[extern "lean_int8_complement"]
def Int8.complement (a : Int8) : Int8 := ⟨⟨~~~a.toBitVec⟩⟩
@[extern "lean_int8_dec_eq"]
def Int8.decEq (a b : Int8) : Decidable (a = b) :=
match a, b with
| ⟨n⟩, ⟨m⟩ =>
if h : n = m then
isTrue <| h ▸ rfl
else
isFalse (fun h' => Int8.noConfusion h' (fun h' => absurd h' h))
def Int8.lt (a b : Int8) : Prop := a.toBitVec.slt b.toBitVec
def Int8.le (a b : Int8) : Prop := a.toBitVec.sle b.toBitVec
instance : Inhabited Int8 where
default := 0
instance : Add Int8 := ⟨Int8.add⟩
instance : Sub Int8 := ⟨Int8.sub⟩
instance : Mul Int8 := ⟨Int8.mul⟩
instance : Mod Int8 := ⟨Int8.mod⟩
instance : Div Int8 := ⟨Int8.div⟩
instance : LT Int8 := ⟨Int8.lt⟩
instance : LE Int8 := ⟨Int8.le⟩
instance : Complement Int8 := ⟨Int8.complement⟩
instance : AndOp Int8 := ⟨Int8.land⟩
instance : OrOp Int8 := ⟨Int8.lor⟩
instance : Xor Int8 := ⟨Int8.xor⟩
instance : ShiftLeft Int8 := ⟨Int8.shiftLeft⟩
instance : ShiftRight Int8 := ⟨Int8.shiftRight⟩
instance : DecidableEq Int8 := Int8.decEq
@[extern "lean_int8_dec_lt"]
def Int8.decLt (a b : Int8) : Decidable (a < b) :=
inferInstanceAs (Decidable (a.toBitVec.slt b.toBitVec))
@[extern "lean_int8_dec_le"]
def Int8.decLe (a b : Int8) : Decidable (a ≤ b) :=
inferInstanceAs (Decidable (a.toBitVec.sle b.toBitVec))
instance (a b : Int8) : Decidable (a < b) := Int8.decLt a b
instance (a b : Int8) : Decidable (a ≤ b) := Int8.decLe a b
instance : Max Int8 := maxOfLe
instance : Min Int8 := minOfLe

View file

@ -1847,6 +1847,140 @@ static inline uint8_t lean_usize_dec_le(size_t a1, size_t a2) { return a1 <= a2;
static inline uint32_t lean_usize_to_uint32(size_t a) { return ((uint32_t)a); }
static inline uint64_t lean_usize_to_uint64(size_t a) { return ((uint64_t)a); }
/*
* Note that we compile all files with -frwapv so in the following section all potential UB that
* may arise from signed overflow is forced to match 2's complement behavior.
*
* We furthermore rely on the implementation defined behavior of gcc/clang to apply reduction mod
* 2^N when casting to an integer type of size N:
* https://gcc.gnu.org/onlinedocs/gcc/Integers-implementation.html#Integers-implementation
* Unfortunately LLVM does not yet document their implementation defined behavior but it is
* most likely fine to rely on the fact that GCC and LLVM match on this:
* https://github.com/llvm/llvm-project/issues/11644
*/
/* Int8 */
LEAN_EXPORT int8_t lean_int8_of_big_int(b_lean_obj_arg a);
static inline uint8_t lean_int8_of_int(b_lean_obj_arg a) {
int8_t res;
if (lean_is_scalar(a)) {
res = (int8_t)lean_scalar_to_int64(a);
} else {
res = lean_int8_of_big_int(a);
}
return (uint8_t)res;
}
static inline lean_obj_res lean_int8_to_int(uint8_t a) {
int8_t arg = (int8_t)a;
return lean_int64_to_int((int64_t)arg);
}
static inline uint8_t lean_int8_neg(uint8_t a) {
int8_t arg = (int8_t)a;
return (uint8_t)(-arg);
}
static inline uint8_t lean_int8_add(uint8_t a1, uint8_t a2) {
int8_t lhs = (int8_t) a1;
int8_t rhs = (int8_t) a2;
return (uint8_t)(lhs + rhs);
}
static inline uint8_t lean_int8_sub(uint8_t a1, uint8_t a2) {
int8_t lhs = (int8_t) a1;
int8_t rhs = (int8_t) a2;
return (uint8_t)(lhs - rhs);
}
static inline uint8_t lean_int8_mul(uint8_t a1, uint8_t a2) {
int8_t lhs = (int8_t) a1;
int8_t rhs = (int8_t) a2;
return (uint8_t)(lhs * rhs);
}
static inline uint8_t lean_int8_div(uint8_t a1, uint8_t a2) {
int8_t lhs = (int8_t) a1;
int8_t rhs = (int8_t) a2;
return (uint8_t)(rhs == 0 ? 0 : lhs / rhs);
}
static inline uint8_t lean_int8_mod(uint8_t a1, uint8_t a2) {
int8_t lhs = (int8_t) a1;
int8_t rhs = (int8_t) a2;
return (uint8_t)(rhs == 0 ? 0 : lhs % rhs);
}
static inline uint8_t lean_int8_land(uint8_t a1, uint8_t a2) {
int8_t lhs = (int8_t) a1;
int8_t rhs = (int8_t) a2;
return (uint8_t)(lhs & rhs);
}
static inline uint8_t lean_int8_lor(uint8_t a1, uint8_t a2) {
int8_t lhs = (int8_t) a1;
int8_t rhs = (int8_t) a2;
return (uint8_t)(lhs | rhs);
}
static inline uint8_t lean_int8_xor(uint8_t a1, uint8_t a2) {
int8_t lhs = (int8_t) a1;
int8_t rhs = (int8_t) a2;
return (uint8_t)(lhs ^ rhs);
}
static inline uint8_t lean_int8_shift_right(uint8_t a1, uint8_t a2) {
int8_t lhs = (int8_t) a1;
int8_t rhs = (int8_t) a2;
return (uint8_t)(lhs >> (rhs % 8));
}
static inline uint8_t lean_int8_shift_left(uint8_t a1, uint8_t a2) {
int8_t lhs = (int8_t) a1;
int8_t rhs = (int8_t) a2;
return (uint8_t)(lhs << (rhs % 8));
}
static inline uint8_t lean_int8_complement(uint8_t a) {
int8_t arg = (int8_t)a;
return (uint8_t)(~arg);
}
static inline uint8_t lean_int8_dec_eq(uint8_t a1, uint8_t a2) {
int8_t lhs = (int8_t) a1;
int8_t rhs = (int8_t) a2;
return lhs == rhs;
}
static inline uint8_t lean_int8_dec_lt(uint8_t a1, uint8_t a2) {
int8_t lhs = (int8_t) a1;
int8_t rhs = (int8_t) a2;
return lhs < rhs;
}
static inline uint8_t lean_int8_dec_le(uint8_t a1, uint8_t a2) {
int8_t lhs = (int8_t) a1;
int8_t rhs = (int8_t) a2;
return lhs <= rhs;
}
/* Float */
LEAN_EXPORT lean_obj_res lean_float_to_string(double a);

View file

@ -6,7 +6,7 @@ Author: Leonardo de Moura
*/
#pragma once
#include "runtime/debug.h"
#include "runtime/int64.h"
#include "runtime/int.h"
namespace lean {

33
src/runtime/int.h Normal file
View file

@ -0,0 +1,33 @@
/*
Copyright (c) 2013 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include <stdint.h>
#include <cstddef>
namespace lean {
typedef int8_t int8;
typedef uint8_t uint8;
static_assert(sizeof(int8) == 1, "unexpected int8 size"); // NOLINT
static_assert(sizeof(uint8) == 1, "unexpected uint8 size"); // NOLINT
//
typedef int16_t int16;
typedef uint16_t uint16;
static_assert(sizeof(int16) == 2, "unexpected int16 size"); // NOLINT
static_assert(sizeof(uint16) == 2, "unexpected uint16 size"); // NOLINT
//
typedef int32_t int32;
typedef uint32_t uint32;
static_assert(sizeof(int32) == 4, "unexpected int32 size"); // NOLINT
static_assert(sizeof(uint32) == 4, "unexpected uint32 size"); // NOLINT
typedef int64_t int64;
typedef uint64_t uint64;
static_assert(sizeof(int64) == 8, "unexpected int64 size"); // NOLINT
static_assert(sizeof(uint64) == 8, "unexpected uint64 size"); // NOLINT
//
typedef size_t usize;
}

View file

@ -1,13 +0,0 @@
/*
Copyright (c) 2013 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Author: Leonardo de Moura
*/
#include <stdint.h>
namespace lean {
typedef int64_t int64;
typedef uint64_t uint64;
static_assert(sizeof(int64) == 8, "unexpected int64 size"); // NOLINT
static_assert(sizeof(uint64) == 8, "unexpected uint64 size"); // NOLINT
}

View file

@ -239,22 +239,22 @@ void div2k(mpz & a, mpz const & b, unsigned k) {
mpz_tdiv_q_2exp(a.m_val, b.m_val, k);
}
unsigned mpz::mod8() const {
uint8 mpz::mod8() const {
mpz a;
mpz_tdiv_r_2exp(a.m_val, m_val, 8);
return a.get_unsigned_int();
return static_cast<uint8>(a.get_unsigned_int());
}
unsigned mpz::mod16() const {
uint16 mpz::mod16() const {
mpz a;
mpz_tdiv_r_2exp(a.m_val, m_val, 16);
return a.get_unsigned_int();
return static_cast<uint16>(a.get_unsigned_int());
}
unsigned mpz::mod32() const {
uint32 mpz::mod32() const {
mpz a;
mpz_tdiv_r_2exp(a.m_val, m_val, 32);
return a.get_unsigned_int();
return static_cast<uint32>(a.get_unsigned_int());
}
uint64 mpz::mod64() const {
@ -267,6 +267,12 @@ uint64 mpz::mod64() const {
return (static_cast<uint64>(h.get_unsigned_int()) << 32) + static_cast<uint64>(l.get_unsigned_int());
}
int8 mpz::smod8() const {
mpz a;
mpz_tdiv_r_2exp(a.m_val, m_val, 8);
return static_cast<int8>(a.get_int());
}
void power(mpz & a, mpz const & b, unsigned k) {
mpz_pow_ui(a.m_val, b.m_val, k);
}
@ -945,16 +951,16 @@ void div2k(mpz & a, mpz const & b, unsigned k) {
a.set(new_sz, ds.begin());
}
unsigned mpz::mod8() const {
return m_digits[0] & 0xFFu;
uint8 mpz::mod8() const {
return static_cast<uint8>(m_digits[0] & 0xFFu);
}
unsigned mpz::mod16() const {
return m_digits[0] & 0xFFFFu;
uint16 mpz::mod16() const {
return static_cast<uint16>(m_digits[0] & 0xFFFFu);
}
unsigned mpz::mod32() const {
return m_digits[0];
uint32 mpz::mod32() const {
return static_cast<uint32>(m_digits[0]);
}
uint64 mpz::mod64() const {
@ -964,6 +970,14 @@ uint64 mpz::mod64() const {
return m_digits[0] + (static_cast<uint64>(m_digits[1]) << 8*sizeof(mpn_digit));
}
int8 mpz::smod8() const {
int8_t val = mod8();
if (m_sign) {
val = -val;
}
return val;
}
void power(mpz & a, mpz const & b, unsigned k) {
a = b;
a.pow(k);

View file

@ -15,7 +15,7 @@ Author: Leonardo de Moura
#include <iostream>
#include <limits>
#include <lean/lean.h>
#include "runtime/int64.h"
#include "runtime/int.h"
#include "runtime/debug.h"
namespace lean {
@ -266,11 +266,13 @@ public:
// a <- b / 2^k
friend void div2k(mpz & a, mpz const & b, unsigned k);
unsigned mod8() const;
unsigned mod16() const;
unsigned mod32() const;
uint8 mod8() const;
uint16 mod16() const;
uint32 mod32() const;
uint64 mod64() const;
int8 smod8() const;
/**
\brief Return the position of the most significant bit.
Return 0 if the number is negative

View file

@ -1536,11 +1536,11 @@ extern "C" LEAN_EXPORT bool lean_int_big_nonneg(object * a) {
// UInt
extern "C" LEAN_EXPORT uint8 lean_uint8_of_big_nat(b_obj_arg a) {
return static_cast<uint8>(mpz_value(a).mod8());
return mpz_value(a).mod8();
}
extern "C" LEAN_EXPORT uint16 lean_uint16_of_big_nat(b_obj_arg a) {
return static_cast<uint16>(mpz_value(a).mod16());
return mpz_value(a).mod16();
}
extern "C" LEAN_EXPORT uint32 lean_uint32_of_big_nat(b_obj_arg a) {
@ -1574,6 +1574,13 @@ extern "C" LEAN_EXPORT usize lean_usize_big_modn(usize a1, b_lean_obj_arg) {
return a1;
}
// =======================================
// IntX
extern "C" LEAN_EXPORT int8 lean_int8_of_big_int(b_obj_arg a) {
return mpz_value(a).smod8();
}
// =======================================
// Float

View file

@ -10,11 +10,6 @@ Author: Leonardo de Moura
#include "runtime/mpz.h"
namespace lean {
typedef uint8_t uint8;
typedef uint16_t uint16;
typedef uint32_t uint32;
typedef uint64_t uint64;
typedef size_t usize;
typedef lean_object object;
typedef object * obj_arg;

View file

@ -7,7 +7,7 @@ Author: Gabriel Ebner
#pragma once
#include <utility>
#include <string>
#include "runtime/int64.h"
#include "runtime/int.h"
#include "util/name.h"
namespace lean {

View file

@ -0,0 +1,95 @@
#check Int8
#eval Int8.ofInt 20
#eval Int8.ofInt (-20)
#eval Int8.ofInt (-20) = -20
#eval Int8.ofInt (-130) = 126
#eval (10 : Int8) ≠ (11 : Int8)
#eval (-10 : Int8) ≠ (10 : Int8)
#eval Int8.ofNat 10 = 10
#eval Int8.ofNat 130 = 130
#eval Int8.ofNat 120 = 120
#eval Int8.ofInt (-20) = -20
#eval (Int8.ofInt (-2)).toInt = -2
#eval (Int8.ofInt (-2)).toNat = 0
#eval (Int8.ofInt (10)).toNat = 10
#eval (Int8.ofInt (10)).toInt = 10
#eval Int8.ofNat (2^64) == 0
#eval Int8.ofInt (-2^64) == 0
#eval Int8.neg 10 = -10
#eval (20 : Int8) + 20 = 40
#eval (127 : Int8) + 1 = -128
#eval (-10 : Int8) + 10 = 0
#eval (1 : Int8) - 2 = -1
#eval (-128 : Int8) - 1 = 127
#eval (1 : Int8) * 120 = 120
#eval (2 : Int8) * 10 = 20
#eval (2 : Int8) * 128 = 0
#eval (-1 : Int8) * (-1) = 1
#eval (1 : Int8) * (-1) = -1
#eval (2 : Int8) * (-10) = -20
#eval (-5 : Int8) * (-5) = 25
#eval (10 : Int8) / 2 = 5
#eval (-10 : Int8) / 2 = -5
#eval (-10 : Int8) / -2 = 5
#eval (10 : Int8) / -2 = -5
#eval (10 : Int8) / 0 = 0
#eval (10 : Int8) % 1 = 0
#eval (10 : Int8) % 0 = 0
#eval (10 : Int8) % 3 = 1
#eval (-10 : Int8) % 3 = -1
#eval (-10 : Int8) % -3 = -1
#eval (10 : Int8) % -3 = 1
#eval (10 : Int8) &&& 10 = 10
#eval (-1 : Int8) &&& 1 = 1
#eval (-1 : Int8) ^^^ 123 = ~~~123
#eval (10 : Int8) ||| 10 = 10
#eval (10 : Int8) ||| 0 = 10
#eval (10 : Int8) ||| -1 = -1
#eval (16 : Int8) >>> 1 = 8
#eval (16 : Int8) >>> 16 = 16
#eval (16 : Int8) >>> 9 = 8
#eval (-16 : Int8) >>> 1 = -8
#eval (16 : Int8) <<< 1 = 32
#eval (16 : Int8) <<< 9 = 32
#eval (-16 : Int8) <<< 1 = -32
#eval (-16 : Int8) <<< 9 = -32
#eval (-16 : Int8) >>> 1 <<< 1 = -16
#eval (0 : Int8) < 1
#eval (0 : Int8) < 120
#eval (120 : Int8) > 0
#eval -1 < (0 : Int8)
#eval -120 < (0 : Int8)
#eval ¬((0 : Int8) < (0 : Int8))
#eval (0 : Int8) ≤ 1
#eval (0 : Int8) ≤ 120
#eval -1 ≤ (0 : Int8)
#eval -120 ≤ (0 : Int8)
#eval (0 : Int8) ≤ (0 : Int8)
#eval (-10 : Int8) ≥ (-10 : Int8)
#eval max (10 : Int8) 20 = 20
#eval max (10 : Int8) (-1) = 10
#eval min (10 : Int8) 20 = 10
#eval min (10 : Int8) (-1) = -1
def test : Option Int := Id.run do
let doTest (base : Int) (i : Int) : Bool :=
let t := base + i
let a := ⟨⟨BitVec.ofInt 8 t⟩⟩
let b := Int8.ofInt t
a == b
let range := 2^9
for i in [0:2*range] do
let i := Int.ofNat i - range
if !(doTest (2^256) i) then
return i
if !(doTest (-2^256) i) then
return i
return none
#eval test.isNone
-- runtime representation
set_option trace.compiler.ir.result true in
def myId (x : Int8) : Int8 := x

View file

@ -0,0 +1,83 @@
Int8 : Type
20
-20
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
true
[result]
def myId (x_1 : u8) : u8 :=
ret x_1
def myId._boxed (x_1 : obj) : obj :=
let x_2 : u8 := unbox x_1;
dec x_1;
let x_3 : u8 := myId x_2;
let x_4 : obj := box x_3;
ret x_4