From 2e11b8ac8879f48ad73ffa06386c017aa58fabe5 Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Mon, 9 Dec 2024 22:33:43 +0100 Subject: [PATCH] feat: add support for `Float32` to the Lean runtime (#6348) This PR adds support for `Float32` to the Lean runtime. We need an update stage0, and then uncomment `Float32.lean` file. --- src/Init/Data.lean | 1 + src/Init/Data/Float32.lean | 180 +++++++++++++++++++++++++++++ src/Lean/Compiler/IR/Basic.lean | 28 +++-- src/Lean/Compiler/IR/EmitC.lean | 38 +++--- src/Lean/Compiler/IR/EmitLLVM.lean | 49 ++++---- src/Lean/Compiler/IR/Format.lean | 1 + src/include/lean/lean.h | 63 ++++++++++ src/library/compiler/ir.cpp | 2 + src/library/compiler/ir.h | 3 +- src/library/compiler/llnf.cpp | 6 +- src/library/compiler/llnf.h | 1 + src/library/compiler/util.cpp | 7 +- src/library/constants.cpp | 5 + src/library/constants.h | 1 + src/library/constants.txt | 1 + src/runtime/object.cpp | 52 +++++++++ 16 files changed, 382 insertions(+), 56 deletions(-) create mode 100644 src/Init/Data/Float32.lean diff --git a/src/Init/Data.lean b/src/Init/Data.lean index 3eaf3e9d5f..4d73951246 100644 --- a/src/Init/Data.lean +++ b/src/Init/Data.lean @@ -21,6 +21,7 @@ import Init.Data.Fin import Init.Data.UInt import Init.Data.SInt import Init.Data.Float +import Init.Data.Float32 import Init.Data.Option import Init.Data.Ord import Init.Data.Random diff --git a/src/Init/Data/Float32.lean b/src/Init/Data/Float32.lean new file mode 100644 index 0000000000..c5e08a43db --- /dev/null +++ b/src/Init/Data/Float32.lean @@ -0,0 +1,180 @@ +/- +Copyright (c) 2023 Amazon.com, Inc. or its affiliates. All Rights Reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Leonardo de Moura +-/ +prelude +import Init.Core +import Init.Data.Int.Basic +import Init.Data.ToString.Basic +import Init.Data.Float + +/- +#exit -- TODO: Remove after update stage0 + +-- Just show FloatSpec is inhabited. +opaque float32Spec : FloatSpec := { + float := Unit, + val := (), + lt := fun _ _ => True, + le := fun _ _ => True, + decLt := fun _ _ => inferInstanceAs (Decidable True), + decLe := fun _ _ => inferInstanceAs (Decidable True) +} + +/-- Native floating point type, corresponding to the IEEE 754 *binary32* format +(`float` in C or `f32` in Rust). -/ +structure Float32 where + val : float32Spec.float + +instance : Nonempty Float32 := ⟨{ val := float32Spec.val }⟩ + +@[extern "lean_float32_add"] opaque Float32.add : Float32 → Float32 → Float32 +@[extern "lean_float32_sub"] opaque Float32.sub : Float32 → Float32 → Float32 +@[extern "lean_float32_mul"] opaque Float32.mul : Float32 → Float32 → Float32 +@[extern "lean_float32_div"] opaque Float32.div : Float32 → Float32 → Float32 +@[extern "lean_float32_negate"] opaque Float32.neg : Float32 → Float32 + +set_option bootstrap.genMatcherCode false +def Float32.lt : Float32 → Float32 → Prop := fun a b => + match a, b with + | ⟨a⟩, ⟨b⟩ => float32Spec.lt a b + +def Float32.le : Float32 → Float32 → Prop := fun a b => + float32Spec.le a.val b.val + +/-- +Raw transmutation from `UInt32`. + +Float32s and UInts have the same endianness on all supported platforms. +IEEE 754 very precisely specifies the bit layout of floats. +-/ +@[extern "lean_float32_of_bits"] opaque Float32.ofBits : UInt32 → Float32 + +/-- +Raw transmutation to `UInt32`. + +Float32s and UInts have the same endianness on all supported platforms. +IEEE 754 very precisely specifies the bit layout of floats. + +Note that this function is distinct from `Float32.toUInt32`, which attempts +to preserve the numeric value, and not the bitwise value. +-/ +@[extern "lean_float32_to_bits"] opaque Float32.toBits : Float32 → UInt32 + +instance : Add Float32 := ⟨Float32.add⟩ +instance : Sub Float32 := ⟨Float32.sub⟩ +instance : Mul Float32 := ⟨Float32.mul⟩ +instance : Div Float32 := ⟨Float32.div⟩ +instance : Neg Float32 := ⟨Float32.neg⟩ +instance : LT Float32 := ⟨Float32.lt⟩ +instance : LE Float32 := ⟨Float32.le⟩ + +/-- Note: this is not reflexive since `NaN != NaN`.-/ +@[extern "lean_float32_beq"] opaque Float32.beq (a b : Float32) : Bool + +instance : BEq Float32 := ⟨Float32.beq⟩ + +@[extern "lean_float32_decLt"] opaque Float32.decLt (a b : Float32) : Decidable (a < b) := + match a, b with + | ⟨a⟩, ⟨b⟩ => float32Spec.decLt a b + +@[extern "lean_float32_decLe"] opaque Float32.decLe (a b : Float32) : Decidable (a ≤ b) := + match a, b with + | ⟨a⟩, ⟨b⟩ => float32Spec.decLe a b + +instance float32DecLt (a b : Float32) : Decidable (a < b) := Float32.decLt a b +instance float32DecLe (a b : Float32) : Decidable (a ≤ b) := Float32.decLe a b + +@[extern "lean_float32_to_string"] opaque Float32.toString : Float32 → String +/-- If the given float is non-negative, truncates the value to the nearest non-negative integer. +If negative or NaN, returns `0`. +If larger than the maximum value for `UInt8` (including Inf), returns the maximum value of `UInt8` +(i.e. `UInt8.size - 1`). +-/ +@[extern "lean_float32_to_uint8"] opaque Float32.toUInt8 : Float32 → UInt8 +/-- If the given float is non-negative, truncates the value to the nearest non-negative integer. +If negative or NaN, returns `0`. +If larger than the maximum value for `UInt16` (including Inf), returns the maximum value of `UInt16` +(i.e. `UInt16.size - 1`). +-/ +@[extern "lean_float32_to_uint16"] opaque Float32.toUInt16 : Float32 → UInt16 +/-- If the given float is non-negative, truncates the value to the nearest non-negative integer. +If negative or NaN, returns `0`. +If larger than the maximum value for `UInt32` (including Inf), returns the maximum value of `UInt32` +(i.e. `UInt32.size - 1`). +-/ +@[extern "lean_float32_to_uint32"] opaque Float32.toUInt32 : Float32 → UInt32 +/-- If the given float is non-negative, truncates the value to the nearest non-negative integer. +If negative or NaN, returns `0`. +If larger than the maximum value for `UInt64` (including Inf), returns the maximum value of `UInt64` +(i.e. `UInt64.size - 1`). +-/ +@[extern "lean_float32_to_uint64"] opaque Float32.toUInt64 : Float32 → UInt64 +/-- If the given float is non-negative, truncates the value to the nearest non-negative integer. +If negative or NaN, returns `0`. +If larger than the maximum value for `USize` (including Inf), returns the maximum value of `USize` +(i.e. `USize.size - 1`). This value is platform dependent). +-/ +@[extern "lean_float32_to_usize"] opaque Float32.toUSize : Float32 → USize + +@[extern "lean_float32_isnan"] opaque Float32.isNaN : Float32 → Bool +@[extern "lean_float32_isfinite"] opaque Float32.isFinite : Float32 → Bool +@[extern "lean_float32_isinf"] opaque Float32.isInf : Float32 → Bool +/-- Splits the given float `x` into a significand/exponent pair `(s, i)` +such that `x = s * 2^i` where `s ∈ (-1;-0.5] ∪ [0.5; 1)`. +Returns an undefined value if `x` is not finite. +-/ +@[extern "lean_float32_frexp"] opaque Float32.frExp : Float32 → Float32 × Int + +instance : ToString Float32 where + toString := Float32.toString + +@[extern "lean_uint64_to_float"] opaque UInt64.toFloat32 (n : UInt64) : Float32 + +instance : Inhabited Float32 where + default := UInt64.toFloat32 0 + +instance : Repr Float32 where + reprPrec n prec := if n < UInt64.toFloat32 0 then Repr.addAppParen (toString n) prec else toString n + +instance : ReprAtom Float32 := ⟨⟩ + +@[extern "sinf"] opaque Float32.sin : Float32 → Float32 +@[extern "cosf"] opaque Float32.cos : Float32 → Float32 +@[extern "tanf"] opaque Float32.tan : Float32 → Float32 +@[extern "asinf"] opaque Float32.asin : Float32 → Float32 +@[extern "acosf"] opaque Float32.acos : Float32 → Float32 +@[extern "atanf"] opaque Float32.atan : Float32 → Float32 +@[extern "atan2f"] opaque Float32.atan2 : Float32 → Float32 → Float32 +@[extern "sinhf"] opaque Float32.sinh : Float32 → Float32 +@[extern "coshf"] opaque Float32.cosh : Float32 → Float32 +@[extern "tanhf"] opaque Float32.tanh : Float32 → Float32 +@[extern "asinhf"] opaque Float32.asinh : Float32 → Float32 +@[extern "acoshf"] opaque Float32.acosh : Float32 → Float32 +@[extern "atanhf"] opaque Float32.atanh : Float32 → Float32 +@[extern "expf"] opaque Float32.exp : Float32 → Float32 +@[extern "exp2f"] opaque Float32.exp2 : Float32 → Float32 +@[extern "logf"] opaque Float32.log : Float32 → Float32 +@[extern "log2f"] opaque Float32.log2 : Float32 → Float32 +@[extern "log10f"] opaque Float32.log10 : Float32 → Float32 +@[extern "powf"] opaque Float32.pow : Float32 → Float32 → Float32 +@[extern "sqrtf"] opaque Float32.sqrt : Float32 → Float32 +@[extern "cbrtf"] opaque Float32.cbrt : Float32 → Float32 +@[extern "ceilf"] opaque Float32.ceil : Float32 → Float32 +@[extern "floorf"] opaque Float32.floor : Float32 → Float32 +@[extern "roundf"] opaque Float32.round : Float32 → Float32 +@[extern "fabsf"] opaque Float32.abs : Float32 → Float32 + +instance : HomogeneousPow Float32 := ⟨Float32.pow⟩ + +instance : Min Float32 := minOfLe + +instance : Max Float32 := maxOfLe + +/-- +Efficiently computes `x * 2^i`. +-/ +@[extern "lean_float32_scaleb"] +opaque Float32.scaleB (x : Float32) (i : @& Int) : Float32 +-/ diff --git a/src/Lean/Compiler/IR/Basic.lean b/src/Lean/Compiler/IR/Basic.lean index b543ffffa8..4b5059ca6d 100644 --- a/src/Lean/Compiler/IR/Basic.lean +++ b/src/Lean/Compiler/IR/Basic.lean @@ -83,12 +83,14 @@ inductive IRType where | irrelevant | object | tobject | struct (leanTypeName : Option Name) (types : Array IRType) : IRType | union (leanTypeName : Name) (types : Array IRType) : IRType + | float32 deriving Inhabited, Repr namespace IRType partial def beq : IRType → IRType → Bool | float, float => true + | float32, float32 => true | uint8, uint8 => true | uint16, uint16 => true | uint32, uint32 => true @@ -104,13 +106,14 @@ partial def beq : IRType → IRType → Bool instance : BEq IRType := ⟨beq⟩ def isScalar : IRType → Bool - | float => true - | uint8 => true - | uint16 => true - | uint32 => true - | uint64 => true - | usize => true - | _ => false + | float => true + | float32 => true + | uint8 => true + | uint16 => true + | uint32 => true + | uint64 => true + | usize => true + | _ => false def isObj : IRType → Bool | object => true @@ -611,10 +614,11 @@ def mkIf (x : VarId) (t e : FnBody) : FnBody := def getUnboxOpName (t : IRType) : String := match t with - | IRType.usize => "lean_unbox_usize" - | IRType.uint32 => "lean_unbox_uint32" - | IRType.uint64 => "lean_unbox_uint64" - | IRType.float => "lean_unbox_float" - | _ => "lean_unbox" + | IRType.usize => "lean_unbox_usize" + | IRType.uint32 => "lean_unbox_uint32" + | IRType.uint64 => "lean_unbox_uint64" + | IRType.float => "lean_unbox_float" + | IRType.float32 => "lean_unbox_float32" + | _ => "lean_unbox" end Lean.IR diff --git a/src/Lean/Compiler/IR/EmitC.lean b/src/Lean/Compiler/IR/EmitC.lean index abc0eb777c..ce5eecbdaa 100644 --- a/src/Lean/Compiler/IR/EmitC.lean +++ b/src/Lean/Compiler/IR/EmitC.lean @@ -55,6 +55,7 @@ def emitArg (x : Arg) : M Unit := def toCType : IRType → String | IRType.float => "double" + | IRType.float32 => "float" | IRType.uint8 => "uint8_t" | IRType.uint16 => "uint16_t" | IRType.uint32 => "uint32_t" @@ -311,12 +312,13 @@ def emitUSet (x : VarId) (n : Nat) (y : VarId) : M Unit := do def emitSSet (x : VarId) (n : Nat) (offset : Nat) (y : VarId) (t : IRType) : M Unit := do match t with - | IRType.float => emit "lean_ctor_set_float" - | IRType.uint8 => emit "lean_ctor_set_uint8" - | IRType.uint16 => emit "lean_ctor_set_uint16" - | IRType.uint32 => emit "lean_ctor_set_uint32" - | IRType.uint64 => emit "lean_ctor_set_uint64" - | _ => throw "invalid instruction"; + | IRType.float => emit "lean_ctor_set_float" + | IRType.float32 => emit "lean_ctor_set_float32" + | IRType.uint8 => emit "lean_ctor_set_uint8" + | IRType.uint16 => emit "lean_ctor_set_uint16" + | IRType.uint32 => emit "lean_ctor_set_uint32" + | IRType.uint64 => emit "lean_ctor_set_uint64" + | _ => throw "invalid instruction"; emit "("; emit x; emit ", "; emitOffset n offset; emit ", "; emit y; emitLn ");" def emitJmp (j : JoinPointId) (xs : Array Arg) : M Unit := do @@ -386,12 +388,13 @@ def emitUProj (z : VarId) (i : Nat) (x : VarId) : M Unit := do def emitSProj (z : VarId) (t : IRType) (n offset : Nat) (x : VarId) : M Unit := do emitLhs z; match t with - | IRType.float => emit "lean_ctor_get_float" - | IRType.uint8 => emit "lean_ctor_get_uint8" - | IRType.uint16 => emit "lean_ctor_get_uint16" - | IRType.uint32 => emit "lean_ctor_get_uint32" - | IRType.uint64 => emit "lean_ctor_get_uint64" - | _ => throw "invalid instruction" + | IRType.float => emit "lean_ctor_get_float" + | IRType.float32 => emit "lean_ctor_get_float32" + | IRType.uint8 => emit "lean_ctor_get_uint8" + | IRType.uint16 => emit "lean_ctor_get_uint16" + | IRType.uint32 => emit "lean_ctor_get_uint32" + | IRType.uint64 => emit "lean_ctor_get_uint64" + | _ => throw "invalid instruction" emit "("; emit x; emit ", "; emitOffset n offset; emitLn ");" def toStringArgs (ys : Array Arg) : List String := @@ -446,11 +449,12 @@ def emitApp (z : VarId) (f : VarId) (ys : Array Arg) : M Unit := def emitBoxFn (xType : IRType) : M Unit := match xType with - | IRType.usize => emit "lean_box_usize" - | IRType.uint32 => emit "lean_box_uint32" - | IRType.uint64 => emit "lean_box_uint64" - | IRType.float => emit "lean_box_float" - | _ => emit "lean_box" + | IRType.usize => emit "lean_box_usize" + | IRType.uint32 => emit "lean_box_uint32" + | IRType.uint64 => emit "lean_box_uint64" + | IRType.float => emit "lean_box_float" + | IRType.float32 => emit "lean_box_float32" + | _ => emit "lean_box" def emitBox (z : VarId) (x : VarId) (xType : IRType) : M Unit := do emitLhs z; emitBoxFn xType; emit "("; emit x; emitLn ");" diff --git a/src/Lean/Compiler/IR/EmitLLVM.lean b/src/Lean/Compiler/IR/EmitLLVM.lean index a93e26d164..34d6ff2886 100644 --- a/src/Lean/Compiler/IR/EmitLLVM.lean +++ b/src/Lean/Compiler/IR/EmitLLVM.lean @@ -315,6 +315,7 @@ def callLeanCtorSetTag (builder : LLVM.Builder llvmctx) def toLLVMType (t : IRType) : M llvmctx (LLVM.LLVMType llvmctx) := do match t with | IRType.float => LLVM.doubleTypeInContext llvmctx + | IRType.float32 => LLVM.floatTypeInContext llvmctx | IRType.uint8 => LLVM.intTypeInContext llvmctx 8 | IRType.uint16 => LLVM.intTypeInContext llvmctx 16 | IRType.uint32 => LLVM.intTypeInContext llvmctx 32 @@ -817,12 +818,13 @@ def emitSProj (builder : LLVM.Builder llvmctx) (z : VarId) (t : IRType) (n offset : Nat) (x : VarId) : M llvmctx Unit := do let (fnName, retty) ← match t with - | IRType.float => pure ("lean_ctor_get_float", ← LLVM.doubleTypeInContext llvmctx) - | IRType.uint8 => pure ("lean_ctor_get_uint8", ← LLVM.i8Type llvmctx) - | IRType.uint16 => pure ("lean_ctor_get_uint16", ← LLVM.i16Type llvmctx) - | IRType.uint32 => pure ("lean_ctor_get_uint32", ← LLVM.i32Type llvmctx) - | IRType.uint64 => pure ("lean_ctor_get_uint64", ← LLVM.i64Type llvmctx) - | _ => throw s!"Invalid type for lean_ctor_get: '{t}'" + | IRType.float => pure ("lean_ctor_get_float", ← LLVM.doubleTypeInContext llvmctx) + | IRType.float32 => pure ("lean_ctor_get_float32", ← LLVM.floatTypeInContext llvmctx) + | IRType.uint8 => pure ("lean_ctor_get_uint8", ← LLVM.i8Type llvmctx) + | IRType.uint16 => pure ("lean_ctor_get_uint16", ← LLVM.i16Type llvmctx) + | IRType.uint32 => pure ("lean_ctor_get_uint32", ← LLVM.i32Type llvmctx) + | IRType.uint64 => pure ("lean_ctor_get_uint64", ← LLVM.i64Type llvmctx) + | _ => throw s!"Invalid type for lean_ctor_get: '{t}'" let argtys := #[ ← LLVM.voidPtrType llvmctx, ← LLVM.unsignedType llvmctx] let fn ← getOrCreateFunctionPrototype (← getLLVMModule) retty fnName argtys let xval ← emitLhsVal builder x @@ -862,11 +864,12 @@ def emitBox (builder : LLVM.Builder llvmctx) (z : VarId) (x : VarId) (xType : IR let xv ← emitLhsVal builder x let (fnName, argTy, xv) ← match xType with - | IRType.usize => pure ("lean_box_usize", ← LLVM.size_tType llvmctx, xv) - | IRType.uint32 => pure ("lean_box_uint32", ← LLVM.i32Type llvmctx, xv) - | IRType.uint64 => pure ("lean_box_uint64", ← LLVM.size_tType llvmctx, xv) - | IRType.float => pure ("lean_box_float", ← LLVM.doubleTypeInContext llvmctx, xv) - | _ => do + | IRType.usize => pure ("lean_box_usize", ← LLVM.size_tType llvmctx, xv) + | IRType.uint32 => pure ("lean_box_uint32", ← LLVM.i32Type llvmctx, xv) + | IRType.uint64 => pure ("lean_box_uint64", ← LLVM.size_tType llvmctx, xv) + | IRType.float => pure ("lean_box_float", ← LLVM.doubleTypeInContext llvmctx, xv) + | IRType.float32 => pure ("lean_box_float32", ← LLVM.floatTypeInContext llvmctx, xv) + | _ => -- sign extend smaller values into i64 let xv ← LLVM.buildSext builder xv (← LLVM.size_tType llvmctx) pure ("lean_box", ← LLVM.size_tType llvmctx, xv) @@ -892,11 +895,12 @@ def callUnboxForType (builder : LLVM.Builder llvmctx) (retName : String := "") : M llvmctx (LLVM.Value llvmctx) := do let (fnName, retty) ← match t with - | IRType.usize => pure ("lean_unbox_usize", ← toLLVMType t) - | IRType.uint32 => pure ("lean_unbox_uint32", ← toLLVMType t) - | IRType.uint64 => pure ("lean_unbox_uint64", ← toLLVMType t) - | IRType.float => pure ("lean_unbox_float", ← toLLVMType t) - | _ => pure ("lean_unbox", ← LLVM.size_tType llvmctx) + | IRType.usize => pure ("lean_unbox_usize", ← toLLVMType t) + | IRType.uint32 => pure ("lean_unbox_uint32", ← toLLVMType t) + | IRType.uint64 => pure ("lean_unbox_uint64", ← toLLVMType t) + | IRType.float => pure ("lean_unbox_float", ← toLLVMType t) + | IRType.float32 => pure ("lean_unbox_float32", ← toLLVMType t) + | _ => pure ("lean_unbox", ← LLVM.size_tType llvmctx) let argtys := #[← LLVM.voidPtrType llvmctx ] let fn ← getOrCreateFunctionPrototype (← getLLVMModule) retty fnName argtys let fnty ← LLVM.functionType retty argtys @@ -1041,12 +1045,13 @@ def emitJmp (builder : LLVM.Builder llvmctx) (jp : JoinPointId) (xs : Array Arg) def emitSSet (builder : LLVM.Builder llvmctx) (x : VarId) (n : Nat) (offset : Nat) (y : VarId) (t : IRType) : M llvmctx Unit := do let (fnName, setty) ← match t with - | IRType.float => pure ("lean_ctor_set_float", ← LLVM.doubleTypeInContext llvmctx) - | IRType.uint8 => pure ("lean_ctor_set_uint8", ← LLVM.i8Type llvmctx) - | IRType.uint16 => pure ("lean_ctor_set_uint16", ← LLVM.i16Type llvmctx) - | IRType.uint32 => pure ("lean_ctor_set_uint32", ← LLVM.i32Type llvmctx) - | IRType.uint64 => pure ("lean_ctor_set_uint64", ← LLVM.i64Type llvmctx) - | _ => throw s!"invalid type for 'lean_ctor_set': '{t}'" + | IRType.float => pure ("lean_ctor_set_float", ← LLVM.doubleTypeInContext llvmctx) + | IRType.float32 => pure ("lean_ctor_set_float32", ← LLVM.floatTypeInContext llvmctx) + | IRType.uint8 => pure ("lean_ctor_set_uint8", ← LLVM.i8Type llvmctx) + | IRType.uint16 => pure ("lean_ctor_set_uint16", ← LLVM.i16Type llvmctx) + | IRType.uint32 => pure ("lean_ctor_set_uint32", ← LLVM.i32Type llvmctx) + | IRType.uint64 => pure ("lean_ctor_set_uint64", ← LLVM.i64Type llvmctx) + | _ => throw s!"invalid type for 'lean_ctor_set': '{t}'" let argtys := #[ ← LLVM.voidPtrType llvmctx, ← LLVM.unsignedType llvmctx, setty] let retty ← LLVM.voidType llvmctx let fn ← getOrCreateFunctionPrototype (← getLLVMModule) retty fnName argtys diff --git a/src/Lean/Compiler/IR/Format.lean b/src/Lean/Compiler/IR/Format.lean index 80a44c145f..12720d830c 100644 --- a/src/Lean/Compiler/IR/Format.lean +++ b/src/Lean/Compiler/IR/Format.lean @@ -55,6 +55,7 @@ instance : ToString Expr := ⟨fun e => Format.pretty (format e)⟩ private partial def formatIRType : IRType → Format | IRType.float => "float" + | IRType.float32 => "float32" | IRType.uint8 => "u8" | IRType.uint16 => "u16" | IRType.uint32 => "u32" diff --git a/src/include/lean/lean.h b/src/include/lean/lean.h index 88ad3af320..4b88e8b1a3 100644 --- a/src/include/lean/lean.h +++ b/src/include/lean/lean.h @@ -614,6 +614,11 @@ static inline double lean_ctor_get_float(b_lean_obj_arg o, unsigned offset) { return *((double*)((uint8_t*)(lean_ctor_obj_cptr(o)) + offset)); } +static inline float lean_ctor_get_float32(b_lean_obj_arg o, unsigned offset) { + assert(offset >= lean_ctor_num_objs(o) * sizeof(void*)); + return *((float*)((uint8_t*)(lean_ctor_obj_cptr(o)) + offset)); +} + static inline void lean_ctor_set_usize(b_lean_obj_arg o, unsigned i, size_t v) { assert(i >= lean_ctor_num_objs(o)); *((size_t*)(lean_ctor_obj_cptr(o) + i)) = v; @@ -644,6 +649,11 @@ static inline void lean_ctor_set_float(b_lean_obj_arg o, unsigned offset, double *((double*)((uint8_t*)(lean_ctor_obj_cptr(o)) + offset)) = v; } +static inline void lean_ctor_set_float32(b_lean_obj_arg o, unsigned offset, float v) { + assert(offset >= lean_ctor_num_objs(o) * sizeof(void*)); + *((float*)((uint8_t*)(lean_ctor_obj_cptr(o)) + offset)) = v; +} + /* Closures */ static inline void * lean_closure_fun(lean_object * o) { return lean_to_closure(o)->m_fun; } @@ -2561,6 +2571,15 @@ LEAN_EXPORT uint8_t lean_float_isfinite(double a); LEAN_EXPORT uint8_t lean_float_isinf(double a); LEAN_EXPORT lean_obj_res lean_float_frexp(double a); +/* Float32 */ + +LEAN_EXPORT lean_obj_res lean_float32_to_string(float a); +LEAN_EXPORT float lean_float32_scaleb(float a, b_lean_obj_arg b); +LEAN_EXPORT uint8_t lean_float32_isnan(float a); +LEAN_EXPORT uint8_t lean_float32_isfinite(float a); +LEAN_EXPORT uint8_t lean_float32_isinf(float a); +LEAN_EXPORT lean_obj_res lean_float32_frexp(float a); + /* Boxing primitives */ static inline lean_obj_res lean_box_uint32(uint32_t v) { @@ -2615,6 +2634,16 @@ static inline double lean_unbox_float(b_lean_obj_arg o) { return lean_ctor_get_float(o, 0); } +static inline lean_obj_res lean_box_float32(float v) { + lean_obj_res r = lean_alloc_ctor(0, 0, sizeof(float)); // NOLINT + lean_ctor_set_float32(r, 0, v); + return r; +} + +static inline float lean_unbox_float32(b_lean_obj_arg o) { + return lean_ctor_get_float32(o, 0); +} + /* Debugging helper functions */ LEAN_EXPORT lean_object * lean_dbg_trace(lean_obj_arg s, lean_obj_arg fn); @@ -2729,6 +2758,40 @@ static inline uint8_t lean_float_decLe(double a, double b) { return a <= b; } static inline uint8_t lean_float_decLt(double a, double b) { return a < b; } static inline double lean_uint64_to_float(uint64_t a) { return (double) a; } +/* float32 primitives */ +static inline uint8_t lean_float32_to_uint8(float a) { + return 0. <= a ? (a < 256. ? (uint8_t)a : UINT8_MAX) : 0; +} +static inline uint16_t lean_float32_to_uint16(float a) { + return 0. <= a ? (a < 65536. ? (uint16_t)a : UINT16_MAX) : 0; +} +static inline uint32_t lean_float32_to_uint32(float a) { + return 0. <= a ? (a < 4294967296. ? (uint32_t)a : UINT32_MAX) : 0; +} +static inline uint64_t lean_float32_to_uint64(float a) { + return 0. <= a ? (a < 18446744073709551616. ? (uint64_t)a : UINT64_MAX) : 0; +} +static inline size_t lean_float32_to_usize(float a) { + if (sizeof(size_t) == sizeof(uint64_t)) // NOLINT + return (size_t) lean_float32_to_uint64(a); // NOLINT + else + return (size_t) lean_float32_to_uint32(a); // NOLINT +} +LEAN_EXPORT float lean_float32_of_bits(uint32_t u); +LEAN_EXPORT uint32_t lean_float32_to_bits(float d); +static inline float lean_float32_add(float a, float b) { return a + b; } +static inline float lean_float32_sub(float a, float b) { return a - b; } +static inline float lean_float32_mul(float a, float b) { return a * b; } +static inline float lean_float32_div(float a, float b) { return a / b; } +static inline float lean_float32_negate(float a) { return -a; } +static inline uint8_t lean_float32_beq(float a, float b) { return a == b; } +static inline uint8_t lean_float32_decLe(float a, float b) { return a <= b; } +static inline uint8_t lean_float32_decLt(float a, float b) { return a < b; } +static inline float lean_uint64_to_float32(uint64_t a) { return (float) a; } + +static inline float lean_float_to_float32(double a) { return (float)a; } +static inline double lean_float32_to_float(float a) { return (double)a; } + /* Efficient C implementations of defns used by the compiler */ static inline size_t lean_hashmap_mk_idx(lean_obj_arg sz, uint64_t hash) { return (size_t)(hash & (lean_unbox(sz) - 1)); diff --git a/src/library/compiler/ir.cpp b/src/library/compiler/ir.cpp index 55c333ea6d..aa4fdbe425 100644 --- a/src/library/compiler/ir.cpp +++ b/src/library/compiler/ir.cpp @@ -123,6 +123,8 @@ static ir::type to_ir_type(expr const & e) { return ir::type::USize; } else if (const_name(e) == get_float_name()) { return ir::type::Float; + } else if (const_name(e) == get_float32_name()) { + return ir::type::Float32; } } else if (is_pi(e)) { return ir::type::Object; diff --git a/src/library/compiler/ir.h b/src/library/compiler/ir.h index 07b395acca..3a49fee0ff 100644 --- a/src/library/compiler/ir.h +++ b/src/library/compiler/ir.h @@ -16,10 +16,11 @@ inductive IRType | irrelevant | object | tobject | struct (leanTypeName : Option Name) (types : Array IRType) : IRType | union (leanTypeName : Name) (types : Array IRType) : IRType +| float32 Remark: we don't create struct/union types from C++. */ -enum class type { Float, UInt8, UInt16, UInt32, UInt64, USize, Irrelevant, Object, TObject }; +enum class type { Float, UInt8, UInt16, UInt32, UInt64, USize, Irrelevant, Object, TObject, Float32 }; typedef nat var_id; typedef nat jp_id; diff --git a/src/library/compiler/llnf.cpp b/src/library/compiler/llnf.cpp index 0c8b453ce9..d86a3e5e98 100644 --- a/src/library/compiler/llnf.cpp +++ b/src/library/compiler/llnf.cpp @@ -586,7 +586,7 @@ class to_lambda_pure_fn { fields.push_back(mk_let_decl(info.get_type(), mk_uproj(major, info.m_idx))); break; case field_info::Scalar: - if (info.is_float()) { + if (info.is_float() || info.is_float32()) { fields.push_back(mk_let_decl(info.get_type(), mk_fproj(major, info.m_idx, info.m_offset))); } else { fields.push_back(mk_let_decl(info.get_type(), mk_sproj(major, info.m_size, info.m_idx, info.m_offset))); @@ -684,7 +684,7 @@ class to_lambda_pure_fn { if (first) { r = mk_let_decl(mk_enf_object_type(), r); } - if (info.is_float()) { + if (info.is_float() || info.is_float32()) { r = mk_let_decl(mk_enf_object_type(), mk_fset(r, info.m_idx, info.m_offset, args[j])); } else { r = mk_let_decl(mk_enf_object_type(), mk_sset(r, info.m_size, info.m_idx, info.m_offset, args[j])); @@ -731,7 +731,7 @@ class to_lambda_pure_fn { break; case field_info::Scalar: if (proj_idx(e) == i) { - if (info.is_float()) { + if (info.is_float() || info.is_float32()) { return mk_fproj(visit(proj_expr(e)), info.m_idx, info.m_offset); } else { return mk_sproj(visit(proj_expr(e)), info.m_size, info.m_idx, info.m_offset); diff --git a/src/library/compiler/llnf.h b/src/library/compiler/llnf.h index eddd346830..5dc6f98f57 100644 --- a/src/library/compiler/llnf.h +++ b/src/library/compiler/llnf.h @@ -66,6 +66,7 @@ struct field_info { m_kind(Scalar), m_size(sz), m_idx(num), m_offset(offset), m_type(type) {} expr get_type() const { return m_type; } bool is_float() const { return is_constant(m_type, get_float_name()); } + bool is_float32() const { return is_constant(m_type, get_float32_name()); } static field_info mk_irrelevant() { return field_info(); } static field_info mk_object(unsigned idx) { return field_info(idx); } static field_info mk_usize() { return field_info(0, true); } diff --git a/src/library/compiler/util.cpp b/src/library/compiler/util.cpp index 2d9e742cd0..96f8ed4c4d 100644 --- a/src/library/compiler/util.cpp +++ b/src/library/compiler/util.cpp @@ -386,6 +386,7 @@ bool is_runtime_builtin_type(name const & n) { n == get_uint64_name() || n == get_usize_name() || n == get_float_name() || + n == get_float32_name() || n == get_thunk_name() || n == get_task_name() || n == get_array_name() || @@ -403,7 +404,8 @@ bool is_runtime_scalar_type(name const & n) { n == get_uint32_name() || n == get_uint64_name() || n == get_usize_name() || - n == get_float_name(); + n == get_float_name() || + n == get_float32_name(); } bool is_llnf_unboxed_type(expr const & type) { @@ -493,6 +495,8 @@ expr mk_runtime_type(type_checker::state & st, local_ctx const & lctx, expr e) { return e; } else if (c == get_float_name()) { return e; + } else if (c == get_float32_name()) { + return e; } else if (optional nbytes = is_enum_type(st.env(), c)) { return *to_uint_type(*nbytes); } @@ -807,6 +811,7 @@ void initialize_compiler_util() { g_builtin_scalar_size->emplace_back(get_uint32_name(), 4); g_builtin_scalar_size->emplace_back(get_uint64_name(), 8); g_builtin_scalar_size->emplace_back(get_float_name(), 8); + g_builtin_scalar_size->emplace_back(get_float32_name(), 4); } void finalize_compiler_util() { diff --git a/src/library/constants.cpp b/src/library/constants.cpp index 63d027ab67..e172177055 100644 --- a/src/library/constants.cpp +++ b/src/library/constants.cpp @@ -44,6 +44,7 @@ name const * g_eq_subst = nullptr; name const * g_eq_symm = nullptr; name const * g_eq_trans = nullptr; name const * g_float = nullptr; +name const * g_float32 = nullptr; name const * g_float_array = nullptr; name const * g_float_array_data = nullptr; name const * g_false = nullptr; @@ -191,6 +192,8 @@ void initialize_constants() { mark_persistent(g_eq_trans->raw()); g_float = new name{"Float"}; mark_persistent(g_float->raw()); + g_float32 = new name{"Float32"}; + mark_persistent(g_float32->raw()); g_float_array = new name{"FloatArray"}; mark_persistent(g_float_array->raw()); g_float_array_data = new name{"FloatArray", "data"}; @@ -362,6 +365,7 @@ void finalize_constants() { delete g_eq_symm; delete g_eq_trans; delete g_float; + delete g_float32; delete g_float_array; delete g_float_array_data; delete g_false; @@ -468,6 +472,7 @@ name const & get_eq_subst_name() { return *g_eq_subst; } name const & get_eq_symm_name() { return *g_eq_symm; } name const & get_eq_trans_name() { return *g_eq_trans; } name const & get_float_name() { return *g_float; } +name const & get_float32_name() { return *g_float32; } name const & get_float_array_name() { return *g_float_array; } name const & get_float_array_data_name() { return *g_float_array_data; } name const & get_false_name() { return *g_false; } diff --git a/src/library/constants.h b/src/library/constants.h index 7693928801..4d2e50d6ff 100644 --- a/src/library/constants.h +++ b/src/library/constants.h @@ -46,6 +46,7 @@ name const & get_eq_subst_name(); name const & get_eq_symm_name(); name const & get_eq_trans_name(); name const & get_float_name(); +name const & get_float32_name(); name const & get_float_array_name(); name const & get_float_array_data_name(); name const & get_false_name(); diff --git a/src/library/constants.txt b/src/library/constants.txt index 973c231dab..1f6bf38a88 100644 --- a/src/library/constants.txt +++ b/src/library/constants.txt @@ -39,6 +39,7 @@ Eq.subst Eq.symm Eq.trans Float +Float32 FloatArray FloatArray.data False diff --git a/src/runtime/object.cpp b/src/runtime/object.cpp index 24444786de..654e1b59f2 100644 --- a/src/runtime/object.cpp +++ b/src/runtime/object.cpp @@ -1661,6 +1661,58 @@ extern "C" LEAN_EXPORT uint64_t lean_float_to_bits(double d) return ret; } +// ======================================= +// Float32 + +extern "C" LEAN_EXPORT lean_obj_res lean_float32_to_string(float a) { + if (isnan(a)) + // override NaN because we don't want NaNs to be distinguishable + // because the sign bit / payload bits can be architecture-dependent + return mk_ascii_string_unchecked("NaN"); + else + return mk_ascii_string_unchecked(std::to_string(a)); +} + +extern "C" LEAN_EXPORT float lean_float32_scaleb(float a, b_lean_obj_arg b) { + if (lean_is_scalar(b)) { + return scalbn(a, lean_scalar_to_int(b)); + } else if (a == 0 || mpz_value(b).is_neg()) { + return 0; + } else { + return a * (1.0 / 0.0); + } +} + +extern "C" LEAN_EXPORT uint8_t lean_float32_isnan(float a) { return (bool) isnan(a); } +extern "C" LEAN_EXPORT uint8_t lean_float32_isfinite(float a) { return (bool) isfinite(a); } +extern "C" LEAN_EXPORT uint8_t lean_float32_isinf(float a) { return (bool) isinf(a); } +extern "C" LEAN_EXPORT obj_res lean_float32_frexp(float a) { + object* r = lean_alloc_ctor(0, 2, 0); + int exp; + lean_ctor_set(r, 0, lean_box_float32(frexp(a, &exp))); + lean_ctor_set(r, 1, isfinite(a) ? lean_int_to_int(exp) : lean_box(0)); + return r; +} + +extern "C" LEAN_EXPORT float lean_float32_of_bits(uint32_t u) +{ + static_assert(sizeof(float) == sizeof(u), "`float` unexpected size."); + float ret; + std::memcpy(&ret, &u, sizeof(float)); + if (isnan(ret)) + ret = std::numeric_limits::quiet_NaN(); + return ret; +} + +extern "C" LEAN_EXPORT uint32_t lean_float32_to_bits(float d) +{ + uint32_t ret; + if (isnan(d)) + d = std::numeric_limits::quiet_NaN(); + std::memcpy(&ret, &d, sizeof(float)); + return ret; +} + // ======================================= // Strings