feat: add support for foldlM, foldl, ForIn instances for byte/float arrays
This commit is contained in:
parent
d03aaec944
commit
2fd024c26f
4 changed files with 161 additions and 8 deletions
|
|
@ -139,13 +139,6 @@ def modifyOp [Inhabited α] (self : Array α) (idx : Nat) (f : α → α) : Arra
|
|||
pure b
|
||||
loop 0 b
|
||||
|
||||
-- Move?
|
||||
private theorem zero_lt_of_lt : {a b : Nat} → a < b → 0 < b
|
||||
| 0, _, h => h
|
||||
| a+1, b, h =>
|
||||
have : a < b := Nat.lt_trans (Nat.lt_succ_self _) h
|
||||
zero_lt_of_lt this
|
||||
|
||||
/- Reference implementation for `forIn` -/
|
||||
@[implementedBy Array.forInUnsafe]
|
||||
protected def forIn {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (as : Array α) (b : β) (f : α → β → m (ForInStep β)) : m β :=
|
||||
|
|
@ -154,7 +147,7 @@ protected def forIn {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m
|
|||
| 0, _ => pure b
|
||||
| i+1, h =>
|
||||
have h' : i < as.size := Nat.lt_of_lt_of_le (Nat.lt_succ_self i) h
|
||||
have : as.size - 1 < as.size := Nat.sub_lt (zero_lt_of_lt h') (by decide)
|
||||
have : as.size - 1 < as.size := Nat.sub_lt (Nat.zero_lt_of_lt h') (by decide)
|
||||
have : as.size - 1 - i < as.size := Nat.lt_of_le_of_lt (Nat.sub_le (as.size - 1) i) this
|
||||
match (← f (as.get ⟨as.size - 1 - i, this⟩) b) with
|
||||
| ForInStep.done b => pure b
|
||||
|
|
|
|||
|
|
@ -92,6 +92,83 @@ partial def toList (bs : ByteArray) : List UInt8 :=
|
|||
none
|
||||
loop start
|
||||
|
||||
/-
|
||||
We claim this unsafe implementation is correct because an array cannot have more than `usizeSz` elements in our runtime.
|
||||
This is similar to the `Array` version.
|
||||
|
||||
TODO: avoid code duplication in the future after we improve the compiler.
|
||||
-/
|
||||
@[inline] unsafe def forInUnsafe {β : Type v} {m : Type v → Type w} [Monad m] (as : ByteArray) (b : β) (f : UInt8 → β → m (ForInStep β)) : m β :=
|
||||
let sz := USize.ofNat as.size
|
||||
let rec @[specialize] loop (i : USize) (b : β) : m β := do
|
||||
if i < sz then
|
||||
let a := as.uget i lcProof
|
||||
match (← f a b) with
|
||||
| ForInStep.done b => pure b
|
||||
| ForInStep.yield b => loop (i+1) b
|
||||
else
|
||||
pure b
|
||||
loop 0 b
|
||||
|
||||
/- Reference implementation for `forIn` -/
|
||||
@[implementedBy ByteArray.forInUnsafe]
|
||||
protected def forIn {β : Type v} {m : Type v → Type w} [Monad m] (as : ByteArray) (b : β) (f : UInt8 → β → m (ForInStep β)) : m β :=
|
||||
let rec loop (i : Nat) (h : i ≤ as.size) (b : β) : m β := do
|
||||
match i, h with
|
||||
| 0, _ => pure b
|
||||
| i+1, h =>
|
||||
have h' : i < as.size := Nat.lt_of_lt_of_le (Nat.lt_succ_self i) h
|
||||
have : as.size - 1 < as.size := Nat.sub_lt (Nat.zero_lt_of_lt h') (by decide)
|
||||
have : as.size - 1 - i < as.size := Nat.lt_of_le_of_lt (Nat.sub_le (as.size - 1) i) this
|
||||
match (← f (as.get ⟨as.size - 1 - i, this⟩) b) with
|
||||
| ForInStep.done b => pure b
|
||||
| ForInStep.yield b => loop i (Nat.le_of_lt h') b
|
||||
loop as.size (Nat.le_refl _) b
|
||||
|
||||
instance : ForIn m ByteArray UInt8 where
|
||||
forIn := ByteArray.forIn
|
||||
|
||||
/-
|
||||
See comment at forInUnsafe
|
||||
TODO: avoid code duplication.
|
||||
-/
|
||||
@[inline]
|
||||
unsafe def foldlMUnsafe {β : Type v} {m : Type v → Type w} [Monad m] (f : β → UInt8 → m β) (init : β) (as : ByteArray) (start := 0) (stop := as.size) : m β :=
|
||||
let rec @[specialize] fold (i : USize) (stop : USize) (b : β) : m β := do
|
||||
if i == stop then
|
||||
pure b
|
||||
else
|
||||
fold (i+1) stop (← f b (as.uget i lcProof))
|
||||
if start < stop then
|
||||
if stop ≤ as.size then
|
||||
fold (USize.ofNat start) (USize.ofNat stop) init
|
||||
else
|
||||
pure init
|
||||
else
|
||||
pure init
|
||||
|
||||
/- Reference implementation for `foldlM` -/
|
||||
@[implementedBy foldlMUnsafe]
|
||||
def foldlM {β : Type v} {m : Type v → Type w} [Monad m] (f : β → UInt8 → m β) (init : β) (as : ByteArray) (start := 0) (stop := as.size) : m β :=
|
||||
let fold (stop : Nat) (h : stop ≤ as.size) :=
|
||||
let rec loop (i : Nat) (j : Nat) (b : β) : m β := do
|
||||
if hlt : j < stop then
|
||||
match i with
|
||||
| 0 => pure b
|
||||
| i'+1 =>
|
||||
loop i' (j+1) (← f b (as.get ⟨j, Nat.lt_of_lt_of_le hlt h⟩))
|
||||
else
|
||||
pure b
|
||||
loop (stop - start) start init
|
||||
if h : stop ≤ as.size then
|
||||
fold stop h
|
||||
else
|
||||
fold as.size (Nat.le_refl _)
|
||||
|
||||
@[inline]
|
||||
def foldl {β : Type v} (f : β → UInt8 → β) (init : β) (as : ByteArray) (start := 0) (stop := as.size) : β :=
|
||||
Id.run <| as.foldlM f init start stop
|
||||
|
||||
end ByteArray
|
||||
|
||||
def List.toByteArray (bs : List UInt8) : ByteArray :=
|
||||
|
|
|
|||
|
|
@ -74,6 +74,83 @@ partial def toList (ds : FloatArray) : List Float :=
|
|||
r.reverse
|
||||
loop 0 []
|
||||
|
||||
/-
|
||||
We claim this unsafe implementation is correct because an array cannot have more than `usizeSz` elements in our runtime.
|
||||
This is similar to the `Array` version.
|
||||
|
||||
TODO: avoid code duplication in the future after we improve the compiler.
|
||||
-/
|
||||
@[inline] unsafe def forInUnsafe {β : Type v} {m : Type v → Type w} [Monad m] (as : FloatArray) (b : β) (f : Float → β → m (ForInStep β)) : m β :=
|
||||
let sz := USize.ofNat as.size
|
||||
let rec @[specialize] loop (i : USize) (b : β) : m β := do
|
||||
if i < sz then
|
||||
let a := as.uget i lcProof
|
||||
match (← f a b) with
|
||||
| ForInStep.done b => pure b
|
||||
| ForInStep.yield b => loop (i+1) b
|
||||
else
|
||||
pure b
|
||||
loop 0 b
|
||||
|
||||
/- Reference implementation for `forIn` -/
|
||||
@[implementedBy FloatArray.forInUnsafe]
|
||||
protected def forIn {β : Type v} {m : Type v → Type w} [Monad m] (as : FloatArray) (b : β) (f : Float → β → m (ForInStep β)) : m β :=
|
||||
let rec loop (i : Nat) (h : i ≤ as.size) (b : β) : m β := do
|
||||
match i, h with
|
||||
| 0, _ => pure b
|
||||
| i+1, h =>
|
||||
have h' : i < as.size := Nat.lt_of_lt_of_le (Nat.lt_succ_self i) h
|
||||
have : as.size - 1 < as.size := Nat.sub_lt (Nat.zero_lt_of_lt h') (by decide)
|
||||
have : as.size - 1 - i < as.size := Nat.lt_of_le_of_lt (Nat.sub_le (as.size - 1) i) this
|
||||
match (← f (as.get ⟨as.size - 1 - i, this⟩) b) with
|
||||
| ForInStep.done b => pure b
|
||||
| ForInStep.yield b => loop i (Nat.le_of_lt h') b
|
||||
loop as.size (Nat.le_refl _) b
|
||||
|
||||
instance : ForIn m FloatArray Float where
|
||||
forIn := FloatArray.forIn
|
||||
|
||||
/-
|
||||
See comment at forInUnsafe
|
||||
TODO: avoid code duplication.
|
||||
-/
|
||||
@[inline]
|
||||
unsafe def foldlMUnsafe {β : Type v} {m : Type v → Type w} [Monad m] (f : β → Float → m β) (init : β) (as : FloatArray) (start := 0) (stop := as.size) : m β :=
|
||||
let rec @[specialize] fold (i : USize) (stop : USize) (b : β) : m β := do
|
||||
if i == stop then
|
||||
pure b
|
||||
else
|
||||
fold (i+1) stop (← f b (as.uget i lcProof))
|
||||
if start < stop then
|
||||
if stop ≤ as.size then
|
||||
fold (USize.ofNat start) (USize.ofNat stop) init
|
||||
else
|
||||
pure init
|
||||
else
|
||||
pure init
|
||||
|
||||
/- Reference implementation for `foldlM` -/
|
||||
@[implementedBy foldlMUnsafe]
|
||||
def foldlM {β : Type v} {m : Type v → Type w} [Monad m] (f : β → Float → m β) (init : β) (as : FloatArray) (start := 0) (stop := as.size) : m β :=
|
||||
let fold (stop : Nat) (h : stop ≤ as.size) :=
|
||||
let rec loop (i : Nat) (j : Nat) (b : β) : m β := do
|
||||
if hlt : j < stop then
|
||||
match i with
|
||||
| 0 => pure b
|
||||
| i'+1 =>
|
||||
loop i' (j+1) (← f b (as.get ⟨j, Nat.lt_of_lt_of_le hlt h⟩))
|
||||
else
|
||||
pure b
|
||||
loop (stop - start) start init
|
||||
if h : stop ≤ as.size then
|
||||
fold stop h
|
||||
else
|
||||
fold as.size (Nat.le_refl _)
|
||||
|
||||
@[inline]
|
||||
def foldl {β : Type v} (f : β → Float → β) (init : β) (as : FloatArray) (start := 0) (stop := as.size) : β :=
|
||||
Id.run <| as.foldlM f init start stop
|
||||
|
||||
end FloatArray
|
||||
|
||||
def List.toFloatArray (ds : List Float) : FloatArray :=
|
||||
|
|
|
|||
|
|
@ -261,6 +261,12 @@ theorem lt_of_succ_le {n m : Nat} (h : succ n ≤ m) : n < m :=
|
|||
theorem succ_le_of_lt {n m : Nat} (h : n < m) : succ n ≤ m :=
|
||||
h
|
||||
|
||||
theorem zero_lt_of_lt : {a b : Nat} → a < b → 0 < b
|
||||
| 0, _, h => h
|
||||
| a+1, b, h =>
|
||||
have : a < b := Nat.lt_trans (Nat.lt_succ_self _) h
|
||||
zero_lt_of_lt this
|
||||
|
||||
theorem lt_or_eq_or_le_succ {m n : Nat} (h : m ≤ succ n) : m ≤ n ∨ m = succ n :=
|
||||
Decidable.byCases
|
||||
(fun (h' : m = succ n) => Or.inr h')
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue