181 lines
6 KiB
Text
181 lines
6 KiB
Text
/-
|
||
Copyright (c) 2021 Mac Malone. All rights reserved.
|
||
Released under Apache 2.0 license as described in the file LICENSE.
|
||
Authors: Mac Malone
|
||
-/
|
||
|
||
namespace Lake
|
||
|
||
--------------------------------------------------------------------------------
|
||
-- # Async / Await Abstraction
|
||
--------------------------------------------------------------------------------
|
||
|
||
/--
Abstraction over spawning asynchronous tasks.
Actions run in the monad `m`; spawning one yields a task of type `n α`.
-/
class Async (m : Type u → Type v) (n : outParam $ Type u → Type u) where
  /-- Run the monadic action as an asynchronous task. -/
  async : m α → m (n α)

export Async (async)
|
||
|
||
/--
Abstraction over waiting on asynchronous tasks.
A task of type `n α` is awaited from within the monad `m`.
-/
class Await (n : Type u → Type u) (m : Type u → Type v) where
  /-- Wait for an asynchronous task to finish. -/
  await : n α → m α

export Await (await)
|
||
|
||
-- ## Monadic Specializations
|
||
|
||
/--
Map a monadic function over the result of an asynchronous task, yielding a new task.

No default is provided; the intended generic implementation is
`fun f x => async (await x >>= f)`.
-/
class MapAsync (m : Type u → Type v) (n : Type u → Type u) where
  mapAsync {α : Type u} {β : Type u} : (α → m β) → n α → m (n β)

export MapAsync (mapAsync)
|
||
|
||
/--
Run an action after an asynchronous task, yielding a task with the *first* task's result.

No default is provided; the intended generic implementation is
`fun x y => async (await x <* y)`.
-/
class SeqLeftAsync (m : Type u → Type v) (n : Type u → Type u) where
  seqLeftAsync {α : Type u} {β : Type u} : n α → m β → m (n α)

export SeqLeftAsync (seqLeftAsync)
|
||
|
||
/--
Run an action after an asynchronous task, yielding a task with the *action's* result.

No default is provided; the intended generic implementation is
`fun x y => async (await x *> y)`.
-/
class SeqRightAsync (m : Type u → Type v) (n : Type u → Type u) where
  seqRightAsync {α : Type u} {β : Type u} : n α → m β → m (n β)

export SeqRightAsync (seqRightAsync)
|
||
|
||
/--
Chain a task-producing action onto the result of an asynchronous task.

The task type `n` is an `outParam`, determined by `m`. No default is provided;
the intended generic implementation is `fun x f => async (await x >>= f >>= await)`.
-/
class BindAsync (m : Type u → Type v) (n : outParam $ Type u → Type u) where
  bindAsync {α : Type u} {β : Type u} : n α → (α → m (n β)) → m (n β)

export BindAsync (bindAsync)
|
||
|
||
--------------------------------------------------------------------------------
|
||
-- # List/Array Utilities
|
||
--------------------------------------------------------------------------------
|
||
|
||
-- ## Sequencing Lists/Arrays of Asynchronous Tasks
|
||
|
||
section
|
||
variable [BindAsync m n]
|
||
|
||
-- ### List
|
||
|
||
/--
Spawn the asynchronous task `last` after `tasks` finish.

Walks `tasks` front to back, chaining each one with `bindAsync` and ignoring
its result. The recursive step sits inside the callback; once the list is
exhausted, `last` is the result.
-/
def afterListAsync (last : m (n β)) : (tasks : List (n α)) → m (n β)
  | task :: rest => bindAsync task (fun _ => afterListAsync last rest)
  | []           => last
|
||
|
||
/--
Join all asynchronous tasks in a List into a single task beginning with `init`.

Each task except the final one is chained with `bindAsync` (its result is
discarded). The final task is returned, lifted into `m` with `pure`.
-/
def joinTaskList1 [Pure m] (init : n α) : (tasks : List (n α)) → m (n α)
  | next :: rest => bindAsync init fun _ => joinTaskList1 next rest
  | []           => pure init
|
||
|
||
/--
Join all asynchronous tasks in a List into a single task.

An empty list yields `pure (pure ())`. Otherwise the work is delegated to
`joinTaskList1`, with the head acting as the initial task.
-/
def joinTaskList [Pure m] [Pure n] (tasks : List (n PUnit)) : m (n PUnit) :=
  match tasks with
  | []         => pure (pure ())
  | hd :: rest => joinTaskList1 hd rest
|
||
|
||
/--
Asynchronously after completing `tasks`, perform `act`.

`act` is wrapped with `async`, and that spawning action is chained after
`tasks` with `afterListAsync`.
-/
def afterTaskList [Async m n] (tasks : List (n α)) (act : m β) : m (n β) :=
  let spawnAct := async act
  afterListAsync spawnAct tasks
|
||
|
||
-- ### Array
|
||
|
||
/--
Efficient implementation of `afterArrayAsync`. Assumes Arrays are at max `USize`.

Chains the tasks in `tasks[start:stop)` with `bindAsync`, using unchecked
`USize` indexing, and then yields `last`. A `stop` past the end of `tasks`
is clamped to `tasks.size`, which matches the reference `afterArrayAsync`.
-/
@[inline] unsafe def afterArrayAsyncUnsafe (last : m (n β)) (tasks : Array (n α)) (start := 0) (stop := tasks.size) : m (n β) :=
  let rec @[specialize] fold (i : USize) (stop : USize) : m (n β) :=
    if i == stop then
      last
    else
      bindAsync (tasks.uget i lcProof) fun _ => fold (i+1) stop
  -- The reference definition processes `tasks[start:tasks.size)` when `stop`
  -- overflows. Returning `last` immediately (as before) made the compiled code
  -- skip every task, so it disagreed with the spec.
  let stop := if stop ≤ tasks.size then stop else tasks.size
  if start < stop then
    fold (USize.ofNat start) (USize.ofNat stop)
  else
    last
|
||
|
||
/--
Spawn the asynchronous task `last` after `tasks` finish.

This is the reference definition; compiled code uses `afterArrayAsyncUnsafe`.
Only the slice `tasks[start:stop)` is chained, and a `stop` beyond `tasks.size`
is treated as `tasks.size`.
-/
@[implementedBy afterArrayAsyncUnsafe]
def afterArrayAsync (last : m (n β)) (tasks : Array (n α)) (start := 0) (stop := tasks.size) : m (n β) :=
  let fold (stop : Nat) (hStop : stop ≤ tasks.size) : m (n β) :=
    -- `fuel` counts the remaining slots, which keeps the recursion structural on `Nat`.
    -- `idx` is the position of the next task to chain.
    let rec loop (fuel : Nat) (idx : Nat) : m (n β) :=
      if hIdx : idx < stop then
        match fuel with
        | 0 => last
        | fuel + 1 =>
          bindAsync (tasks.get ⟨idx, Nat.lt_of_lt_of_le hIdx hStop⟩) fun _ => loop fuel (idx + 1)
      else
        last
    loop (stop - start) start
  if hStop : stop ≤ tasks.size then
    fold stop hStop
  else
    fold tasks.size (Nat.le_refl _)
|
||
|
||
/--
Join all asynchronous tasks in an Array into a single task.

Every task except the last is chained with `afterArrayAsync`. The last task is
then yielded, lifted into `m` with `pure`. An empty array yields `pure (pure ())`.
-/
def joinTaskArray [Pure m] [Pure n] (tasks : Array (n PUnit)) : m (n PUnit) :=
  if h : 0 < tasks.size then
    -- `afterArrayAsync` expects an *action* `m (n β)` that produces the final task,
    -- so the last task has to be lifted with `pure` (as in `joinTaskList1`'s `[]` case).
    afterArrayAsync (pure <| tasks.get ⟨tasks.size - 1, Nat.sub_lt h (by decide)⟩) tasks.pop
  else
    pure (pure ())
|
||
|
||
/--
Asynchronously after completing `tasks`, perform `act`.

`act` is wrapped with `async`, and that spawning action is chained after
every task of the array with `afterArrayAsync`.
-/
def afterTaskArray [Async m n] (tasks : Array (n α)) (act : m β) : m (n β) :=
  let spawnAct := async act
  afterArrayAsync spawnAct tasks
|
||
|
||
end
|
||
|
||
-- ## Mapping Lists/Arrays of Asynchronous Tasks
|
||
|
||
section
|
||
variable [BindAsync m n] [Async m n]
|
||
|
||
-- ### List
|
||
|
||
/--
Abstract version of `IO.mapTasks`.

Binds each task of `tasks` in order with `bindAsync` and collects the results.
Once all are bound, `f` is run on the results, in the original order, via `async`.
-/
def mapListAsync (f : List α → m β) (tasks : List (n α)) : m (n β) :=
  go tasks []
where
  /-- Walk the `pending` tasks, accumulating their results in reverse order in `acc`. -/
  go : List (n α) → List α → m (n β)
    | t :: pending, acc => bindAsync t fun a => go pending (a :: acc)
    | [], acc           => async (f acc.reverse)
|
||
|
||
-- ### Array
|
||
|
||
/--
Efficient implementation of `mapArrayAsync`. Assumes Arrays are at max `USize`.

Binds the tasks in `tasks[start:stop)` with `bindAsync`, using unchecked `USize`
indexing, and pushes each result. It then runs `f` on the collected results via
`async`. A `stop` past the end of `tasks` is clamped to `tasks.size`, which
matches the reference `mapArrayAsync`.
-/
@[inline] unsafe def mapArrayAsyncUnsafe (f : Array α → m β) (tasks : Array (n α)) (start := 0) (stop := tasks.size) : m (n β) :=
  let rec @[specialize] fold (i : USize) (stop : USize) (as : Array α) : m (n β) :=
    if i == stop then
      async (f as)
    else
      bindAsync (tasks.uget i lcProof) fun a => fold (i+1) stop (as.push a)
  -- Clamp like the reference definition. Previously an overflowing `stop`
  -- ignored every task and ran `f #[]`.
  let stop := if stop ≤ tasks.size then stop else tasks.size
  if start < stop then
    -- The capacity hint is the slice length `stop - start`. The old `start - stop`
    -- always truncated to 0 in `Nat` here.
    fold (USize.ofNat start) (USize.ofNat stop) (Array.mkEmpty (stop - start))
  else
    async (f #[])
|
||
|
||
/--
Abstract version of `IO.mapTasks` for Arrays.

This is the reference definition; compiled code uses `mapArrayAsyncUnsafe`.
It binds the tasks in `tasks[start:stop)` in order, collecting their results,
then runs `f` on them via `async`. A `stop` beyond `tasks.size` is treated as
`tasks.size`.
-/
@[implementedBy mapArrayAsyncUnsafe]
def mapArrayAsync (f : Array α → m β) (tasks : Array (n α)) (start := 0) (stop := tasks.size) : m (n β) :=
  let fold (stop : Nat) (hStop : stop ≤ tasks.size) : m (n β) :=
    -- `fuel` bounds the recursion structurally. `idx` is the next task's position.
    -- `acc` holds the results collected so far.
    let rec loop (fuel : Nat) (idx : Nat) (acc : Array α) : m (n β) :=
      if hIdx : idx < stop then
        match fuel with
        | 0 => async (f acc)
        | fuel + 1 =>
          bindAsync (tasks.get ⟨idx, Nat.lt_of_lt_of_le hIdx hStop⟩) fun a =>
            loop fuel (idx + 1) (acc.push a)
      else
        async (f acc)
    loop (stop - start) start (Array.mkEmpty (stop - start))
  if hStop : stop ≤ tasks.size then
    fold stop hStop
  else
    fold tasks.size (Nat.le_refl _)
|
||
|
||
end
|