feat: usize for array types (#4801)

Add efficient `usize` functions for `Array`, `ByteArray`, `FloatArray`.

This is part 1 of 2 since there is a need to update stage0 between the
two parts. (See discussion below.)

Closes #4654
This commit is contained in:
François G. Dorais 2024-07-21 06:23:49 -04:00 committed by GitHub
parent 08acf5a136
commit 8f0631ab1f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 21 additions and 4 deletions

View file

@ -50,6 +50,13 @@ instance : Inhabited (Array α) where
def singleton (v : α) : Array α :=
mkArray 1 v
/-- Low-level version of `size` that directly queries the C array object cached size.
While this is not provable, `usize` always returns the exact size of the array since
the implementation only supports arrays of size less than `USize.size`.
-/
@[extern "lean_array_size", simp]
def usize (a : @& Array α) : USize := a.size.toUSize
/-- Low-level version of `fget` which is as fast as a C array read.
`Fin` values are represented as tag pointers in the Lean runtime. Thus,
`fget` may be slightly slower than `uget`. -/
@ -174,7 +181,7 @@ def modifyOp (self : Array α) (idx : Nat) (f : αα) : Array α :=
This kind of low level trick can be removed with a little bit of compiler support. For example, if the compiler simplifies `as.size < usizeSz` to true. -/
@[inline] unsafe def forInUnsafe {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (as : Array α) (b : β) (f : α → β → m (ForInStep β)) : m β :=
let sz := USize.ofNat as.size
let sz := as.size.toUSize -- TODO: use usize
let rec @[specialize] loop (i : USize) (b : β) : m β := do
if i < sz then
let a := as.uget i lcProof
@ -280,7 +287,7 @@ def foldrM {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (f : α
/-- See comment at `forInUnsafe` -/
@[inline]
unsafe def mapMUnsafe {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (f : α → m β) (as : Array α) : m (Array β) :=
let sz := USize.ofNat as.size
let sz := as.size.toUSize -- TODO: use usize
let rec @[specialize] map (i : USize) (r : Array NonScalar) : m (Array PNonScalar.{v}) := do
if i < sz then
let v := r.uget i lcProof

View file

@ -37,6 +37,10 @@ def push : ByteArray → UInt8 → ByteArray
def size : (@& ByteArray) → Nat
| ⟨bs⟩ => bs.size
@[extern "lean_sarray_size", simp]
def usize (a : @& ByteArray) : USize :=
a.size.toUSize
@[extern "lean_byte_array_uget"]
def uget : (a : @& ByteArray) → (i : USize) → i.toNat < a.size → UInt8
| ⟨bs⟩, i, h => bs[i]
@ -119,7 +123,7 @@ def toList (bs : ByteArray) : List UInt8 :=
TODO: avoid code duplication in the future after we improve the compiler.
-/
@[inline] unsafe def forInUnsafe {β : Type v} {m : Type v → Type w} [Monad m] (as : ByteArray) (b : β) (f : UInt8 → β → m (ForInStep β)) : m β :=
let sz := USize.ofNat as.size
let sz := as.size.toUSize -- TODO: use usize
let rec @[specialize] loop (i : USize) (b : β) : m β := do
if i < sz then
let a := as.uget i lcProof

View file

@ -37,6 +37,10 @@ def push : FloatArray → Float → FloatArray
def size : (@& FloatArray) → Nat
| ⟨ds⟩ => ds.size
@[extern "lean_sarray_size", simp]
def usize (a : @& FloatArray) : USize :=
a.size.toUSize
@[extern "lean_float_array_uget"]
def uget : (a : @& FloatArray) → (i : USize) → i.toNat < a.size → Float
| ⟨ds⟩, i, h => ds[i]
@ -90,7 +94,7 @@ partial def toList (ds : FloatArray) : List Float :=
-/
-- TODO: avoid code duplication in the future after we improve the compiler.
@[inline] unsafe def forInUnsafe {β : Type v} {m : Type v → Type w} [Monad m] (as : FloatArray) (b : β) (f : Float → β → m (ForInStep β)) : m β :=
let sz := USize.ofNat as.size
let sz := as.size.toUSize -- TODO: use usize
let rec @[specialize] loop (i : USize) (b : β) : m β := do
if i < sz then
let a := as.uget i lcProof

View file

@ -18,3 +18,5 @@ options get_default_options() {
return opts;
}
}
// please update stage0