diff --git a/src/Init/Data/Array/Basic.lean b/src/Init/Data/Array/Basic.lean index 7ca753ad51..86765aa384 100644 --- a/src/Init/Data/Array/Basic.lean +++ b/src/Init/Data/Array/Basic.lean @@ -139,13 +139,6 @@ def modifyOp [Inhabited α] (self : Array α) (idx : Nat) (f : α → α) : Arra pure b loop 0 b --- Move? -private theorem zero_lt_of_lt : {a b : Nat} → a < b → 0 < b - | 0, _, h => h - | a+1, b, h => - have : a < b := Nat.lt_trans (Nat.lt_succ_self _) h - zero_lt_of_lt this - /- Reference implementation for `forIn` -/ @[implementedBy Array.forInUnsafe] protected def forIn {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m] (as : Array α) (b : β) (f : α → β → m (ForInStep β)) : m β := @@ -154,7 +147,7 @@ protected def forIn {α : Type u} {β : Type v} {m : Type v → Type w} [Monad m | 0, _ => pure b | i+1, h => have h' : i < as.size := Nat.lt_of_lt_of_le (Nat.lt_succ_self i) h - have : as.size - 1 < as.size := Nat.sub_lt (zero_lt_of_lt h') (by decide) + have : as.size - 1 < as.size := Nat.sub_lt (Nat.zero_lt_of_lt h') (by decide) have : as.size - 1 - i < as.size := Nat.lt_of_le_of_lt (Nat.sub_le (as.size - 1) i) this match (← f (as.get ⟨as.size - 1 - i, this⟩) b) with | ForInStep.done b => pure b diff --git a/src/Init/Data/ByteArray/Basic.lean b/src/Init/Data/ByteArray/Basic.lean index 783ba5c619..1e73aa7de3 100644 --- a/src/Init/Data/ByteArray/Basic.lean +++ b/src/Init/Data/ByteArray/Basic.lean @@ -92,6 +92,83 @@ partial def toList (bs : ByteArray) : List UInt8 := none loop start +/- + We claim this unsafe implementation is correct because an array cannot have more than `usizeSz` elements in our runtime. + This is similar to the `Array` version. + + TODO: avoid code duplication in the future after we improve the compiler. +-/ +@[inline] unsafe def forInUnsafe {β : Type v} {m : Type v → Type w} [Monad m] (as : ByteArray) (b : β) (f : UInt8 → β → m (ForInStep β)) : m β := + let sz := USize.ofNat as.size + let rec @[specialize] loop (i : USize) (b : β) : m β := do + if i < sz then + let a := as.uget i lcProof + match (← f a b) with + | ForInStep.done b => pure b + | ForInStep.yield b => loop (i+1) b + else + pure b + loop 0 b + +/- Reference implementation for `forIn` -/ +@[implementedBy ByteArray.forInUnsafe] +protected def forIn {β : Type v} {m : Type v → Type w} [Monad m] (as : ByteArray) (b : β) (f : UInt8 → β → m (ForInStep β)) : m β := + let rec loop (i : Nat) (h : i ≤ as.size) (b : β) : m β := do + match i, h with + | 0, _ => pure b + | i+1, h => + have h' : i < as.size := Nat.lt_of_lt_of_le (Nat.lt_succ_self i) h + have : as.size - 1 < as.size := Nat.sub_lt (Nat.zero_lt_of_lt h') (by decide) + have : as.size - 1 - i < as.size := Nat.lt_of_le_of_lt (Nat.sub_le (as.size - 1) i) this + match (← f (as.get ⟨as.size - 1 - i, this⟩) b) with + | ForInStep.done b => pure b + | ForInStep.yield b => loop i (Nat.le_of_lt h') b + loop as.size (Nat.le_refl _) b + +instance : ForIn m ByteArray UInt8 where + forIn := ByteArray.forIn + +/- + See comment at forInUnsafe + TODO: avoid code duplication. +-/ +@[inline] +unsafe def foldlMUnsafe {β : Type v} {m : Type v → Type w} [Monad m] (f : β → UInt8 → m β) (init : β) (as : ByteArray) (start := 0) (stop := as.size) : m β := + let rec @[specialize] fold (i : USize) (stop : USize) (b : β) : m β := do + if i == stop then + pure b + else + fold (i+1) stop (← f b (as.uget i lcProof)) + if start < stop then + if stop ≤ as.size then + fold (USize.ofNat start) (USize.ofNat stop) init + else + pure init + else + pure init + +/- Reference implementation for `foldlM` -/ +@[implementedBy foldlMUnsafe] +def foldlM {β : Type v} {m : Type v → Type w} [Monad m] (f : β → UInt8 → m β) (init : β) (as : ByteArray) (start := 0) (stop := as.size) : m β := + let fold (stop : Nat) (h : stop ≤ as.size) := + let rec loop (i : Nat) (j : Nat) (b : β) : m β := do + if hlt : j < stop then + match i with + | 0 => pure b + | i'+1 => + loop i' (j+1) (← f b (as.get ⟨j, Nat.lt_of_lt_of_le hlt h⟩)) + else + pure b + loop (stop - start) start init + if h : stop ≤ as.size then + fold stop h + else + fold as.size (Nat.le_refl _) + +@[inline] +def foldl {β : Type v} (f : β → UInt8 → β) (init : β) (as : ByteArray) (start := 0) (stop := as.size) : β := + Id.run <| as.foldlM f init start stop + end ByteArray def List.toByteArray (bs : List UInt8) : ByteArray := diff --git a/src/Init/Data/FloatArray/Basic.lean b/src/Init/Data/FloatArray/Basic.lean index f890e1c078..a609712081 100644 --- a/src/Init/Data/FloatArray/Basic.lean +++ b/src/Init/Data/FloatArray/Basic.lean @@ -74,6 +74,83 @@ partial def toList (ds : FloatArray) : List Float := r.reverse loop 0 [] +/- + We claim this unsafe implementation is correct because an array cannot have more than `usizeSz` elements in our runtime. + This is similar to the `Array` version. + + TODO: avoid code duplication in the future after we improve the compiler. +-/ +@[inline] unsafe def forInUnsafe {β : Type v} {m : Type v → Type w} [Monad m] (as : FloatArray) (b : β) (f : Float → β → m (ForInStep β)) : m β := + let sz := USize.ofNat as.size + let rec @[specialize] loop (i : USize) (b : β) : m β := do + if i < sz then + let a := as.uget i lcProof + match (← f a b) with + | ForInStep.done b => pure b + | ForInStep.yield b => loop (i+1) b + else + pure b + loop 0 b + +/- Reference implementation for `forIn` -/ +@[implementedBy FloatArray.forInUnsafe] +protected def forIn {β : Type v} {m : Type v → Type w} [Monad m] (as : FloatArray) (b : β) (f : Float → β → m (ForInStep β)) : m β := + let rec loop (i : Nat) (h : i ≤ as.size) (b : β) : m β := do + match i, h with + | 0, _ => pure b + | i+1, h => + have h' : i < as.size := Nat.lt_of_lt_of_le (Nat.lt_succ_self i) h + have : as.size - 1 < as.size := Nat.sub_lt (Nat.zero_lt_of_lt h') (by decide) + have : as.size - 1 - i < as.size := Nat.lt_of_le_of_lt (Nat.sub_le (as.size - 1) i) this + match (← f (as.get ⟨as.size - 1 - i, this⟩) b) with + | ForInStep.done b => pure b + | ForInStep.yield b => loop i (Nat.le_of_lt h') b + loop as.size (Nat.le_refl _) b + +instance : ForIn m FloatArray Float where + forIn := FloatArray.forIn + +/- + See comment at forInUnsafe + TODO: avoid code duplication. +-/ +@[inline] +unsafe def foldlMUnsafe {β : Type v} {m : Type v → Type w} [Monad m] (f : β → Float → m β) (init : β) (as : FloatArray) (start := 0) (stop := as.size) : m β := + let rec @[specialize] fold (i : USize) (stop : USize) (b : β) : m β := do + if i == stop then + pure b + else + fold (i+1) stop (← f b (as.uget i lcProof)) + if start < stop then + if stop ≤ as.size then + fold (USize.ofNat start) (USize.ofNat stop) init + else + pure init + else + pure init + +/- Reference implementation for `foldlM` -/ +@[implementedBy foldlMUnsafe] +def foldlM {β : Type v} {m : Type v → Type w} [Monad m] (f : β → Float → m β) (init : β) (as : FloatArray) (start := 0) (stop := as.size) : m β := + let fold (stop : Nat) (h : stop ≤ as.size) := + let rec loop (i : Nat) (j : Nat) (b : β) : m β := do + if hlt : j < stop then + match i with + | 0 => pure b + | i'+1 => + loop i' (j+1) (← f b (as.get ⟨j, Nat.lt_of_lt_of_le hlt h⟩)) + else + pure b + loop (stop - start) start init + if h : stop ≤ as.size then + fold stop h + else + fold as.size (Nat.le_refl _) + +@[inline] +def foldl {β : Type v} (f : β → Float → β) (init : β) (as : FloatArray) (start := 0) (stop := as.size) : β := + Id.run <| as.foldlM f init start stop + end FloatArray def List.toFloatArray (ds : List Float) : FloatArray := diff --git a/src/Init/Data/Nat/Basic.lean b/src/Init/Data/Nat/Basic.lean index e49a40a613..27582e1e5c 100644 --- a/src/Init/Data/Nat/Basic.lean +++ b/src/Init/Data/Nat/Basic.lean @@ -261,6 +261,12 @@ theorem lt_of_succ_le {n m : Nat} (h : succ n ≤ m) : n < m := theorem succ_le_of_lt {n m : Nat} (h : n < m) : succ n ≤ m := h +theorem zero_lt_of_lt : {a b : Nat} → a < b → 0 < b + | 0, _, h => h + | a+1, b, h => + have : a < b := Nat.lt_trans (Nat.lt_succ_self _) h + zero_lt_of_lt this + theorem lt_or_eq_or_le_succ {m n : Nat} (h : m ≤ succ n) : m ≤ n ∨ m = succ n := Decidable.byCases (fun (h' : m = succ n) => Or.inr h')