fix: use saturating casts in lean_float_to_uint8 to avoid UB

This commit is contained in:
Mario Carneiro 2022-08-11 13:20:39 -04:00 committed by Leonardo de Moura
parent 9ac4cf927d
commit d8c6c827fe
3 changed files with 77 additions and 46 deletions

View file

@ -1826,20 +1826,33 @@ static inline uint64_t lean_name_hash(b_lean_obj_arg n) {
}
/* float primitives */
static inline uint8_t lean_float_to_uint8(double a) { return (uint8_t)a; }
static inline uint16_t lean_float_to_uint16(double a) { return (uint16_t)a; }
static inline uint32_t lean_float_to_uint32(double a) { return (uint32_t)a; }
static inline uint64_t lean_float_to_uint64(double a) { return (uint64_t)a; }
static inline size_t lean_float_to_usize(double a) { return (size_t)a; }
static inline double lean_float_add(double a, double b) { return a + b; }
static inline double lean_float_sub(double a, double b) { return a - b; }
static inline double lean_float_mul(double a, double b) { return a * b; }
static inline double lean_float_div(double a, double b) { return a / b; }
static inline double lean_float_negate(double a) { return -a; }
static inline uint8_t lean_float_beq(double a, double b) { return a == b; }
static inline uint8_t lean_float_decLe(double a, double b) { return a <= b; }
static inline uint8_t lean_float_decLt(double a, double b) { return a < b; }
static inline double lean_uint64_to_float(uint64_t a) { return (double) a; }
static inline uint8_t lean_float_to_uint8(double a) {
return 0. <= a ? (a < 256. ? (uint8_t)a : UINT8_MAX) : 0;
}
static inline uint16_t lean_float_to_uint16(double a) {
return 0. <= a ? (a < 65536. ? (uint16_t)a : UINT16_MAX) : 0;
}
static inline uint32_t lean_float_to_uint32(double a) {
return 0. <= a ? (a < 4294967296. ? (uint32_t)a : UINT32_MAX) : 0;
}
static inline uint64_t lean_float_to_uint64(double a) {
return 0. <= a ? (a < 18446744073709551616. ? (uint64_t)a : UINT64_MAX) : 0;
}
static inline size_t lean_float_to_usize(double a) {
if (sizeof(size_t) == sizeof(uint64_t)) // NOLINT
return (size_t) lean_float_to_uint64(a); // NOLINT
else
return (size_t) lean_float_to_uint32(a); // NOLINT
}
static inline double lean_float_add(double a, double b) { return a + b; }
static inline double lean_float_sub(double a, double b) { return a - b; }
static inline double lean_float_mul(double a, double b) { return a * b; }
static inline double lean_float_div(double a, double b) { return a / b; }
static inline double lean_float_negate(double a) { return -a; }
static inline uint8_t lean_float_beq(double a, double b) { return a == b; }
static inline uint8_t lean_float_decLe(double a, double b) { return a <= b; }
static inline uint8_t lean_float_decLt(double a, double b) { return a < b; }
static inline double lean_uint64_to_float(uint64_t a) { return (double) a; }
#ifdef __cplusplus
}

View file

@ -1,45 +1,53 @@
def tst1 : IO Unit := do
IO.println (1 : Float);
IO.println ((1 : Float) + 2);
IO.println ((2 : Float) - 3);
IO.println ((3 : Float) * 2);
IO.println ((3 : Float) / 2);
IO.println (decide ((3 : Float) < 2));
IO.println (decide ((3 : Float) < 4));
IO.println ((3 : Float) == 2);
IO.println ((2 : Float) == 2);
IO.println (decide ((3 : Float) ≤ 2));
IO.println (decide ((3 : Float) ≤ 3));
IO.println (decide ((3 : Float) ≤ 4));
IO.println (Float.ofInt 0)
IO.println (Float.ofInt 42)
IO.println (Float.ofInt (-42))
pure ()
IO.println (1 : Float)
IO.println ((1 : Float) + 2)
IO.println ((2 : Float) - 3)
IO.println ((3 : Float) * 2)
IO.println ((3 : Float) / 2)
IO.println (decide ((3 : Float) < 2))
IO.println (decide ((3 : Float) < 4))
IO.println ((3 : Float) == 2)
IO.println ((2 : Float) == 2)
IO.println (decide ((3 : Float) ≤ 2))
IO.println (decide ((3 : Float) ≤ 3))
IO.println (decide ((3 : Float) ≤ 4))
IO.println (Float.ofInt 0)
IO.println (Float.ofInt 42)
IO.println (Float.ofInt (-42))
IO.println (0 / 0 : Float).toUInt8
IO.println (0 / 0 : Float).toUInt16
IO.println (0 / 0 : Float).toUInt32
IO.println (0 / 0 : Float).toUInt64
IO.println (-1 : Float).toUInt8
IO.println (256 : Float).toUInt8
IO.println (1 / 0 : Float).toUInt8
IO.println (-1 : Float).toUInt64
IO.println (2^64 : Float).toUInt64
IO.println (1 / 0 : Float).toUInt64
structure Foo :=
(x : Nat)
(w : UInt64)
(y : Float)
(z : Float)
structure Foo where
x : Nat
w : UInt64
y : Float
z : Float
@[noinline] def mkFoo (x : Nat) : Foo :=
{ x := x, w := x.toUInt64, y := x.toFloat / 3, z := x.toFloat / 2 }
{ x := x, w := x.toUInt64, y := x.toFloat / 3, z := x.toFloat / 2 }
def tst2 (x : Nat) : IO Unit := do
let foo := mkFoo x;
IO.println foo.y;
IO.println foo.z
let foo := mkFoo x
IO.println foo.y
IO.println foo.z
@[noinline] def fMap (f : Float → Float) (xs : List Float) :=
xs.map f
xs.map f
def tst3 (xs : List Float) (y : Float) : IO Unit :=
IO.println (fMap (fun x => x / y) xs)
IO.println (fMap (fun x => x / y) xs)
def main : IO Unit := do
tst1;
IO.println "-----";
tst2 7;
tst3 [3, 4, 7, 8, 9, 11] 2;
pure ()
tst1
IO.println "-----"
tst2 7
tst3 [3, 4, 7, 8, 9, 11] 2

View file

@ -13,6 +13,16 @@ true
0.000000
42.000000
-42.000000
0
0
0
0
0
255
255
0
18446744073709551615
18446744073709551615
-----
2.333333
3.500000