lean4-htt/Lake/Async.lean
2021-08-18 21:30:41 -04:00

181 lines
6 KiB
Text
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

/-
Copyright (c) 2021 Mac Malone. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Mac Malone
-/
namespace Lake
--------------------------------------------------------------------------------
-- # Async / Await Abstraction
--------------------------------------------------------------------------------
/-- Monads `m` whose actions can be started as asynchronous tasks of type `n`. -/
class Async (m : Type u → Type v) (n : outParam (Type u → Type u)) where
  /- Turn a monadic action producing `α` into an action producing a task `n α`. -/
  async : m α → m (n α)

export Async (async)
/-- Tasks of type `n` whose results can be waited for from within the monad `m`. -/
class Await (n : Type u → Type u) (m : Type u → Type v) where
  /- Block on the task `n α` and produce its result inside `m`. -/
  await : n α → m α

export Await (await)
-- ## Monadic Specializations
/-- Specialized operation: apply a monadic function to the result of a task, yielding a new task. -/
class MapAsync (m : Type u → Type v) (n : Type u → Type u) where
  mapAsync {α β : Type u} : (α → m β) → n α → m (n β)
  -- Commented-out default (in terms of `async`/`await`):
  -- := fun f x => async (await x >>= f)

export MapAsync (mapAsync)
/-- Specialized operation: sequence a task with a monadic action, keeping the task's result type `α`. -/
class SeqLeftAsync (m : Type u → Type v) (n : Type u → Type u) where
  seqLeftAsync {α β : Type u} : n α → m β → m (n α)
  -- Commented-out default (in terms of `async`/`await`):
  -- := fun x y => async (await x <* y)

export SeqLeftAsync (seqLeftAsync)
/-- Specialized operation: sequence a task with a monadic action, keeping the action's result type `β`. -/
class SeqRightAsync (m : Type u → Type v) (n : Type u → Type u) where
  seqRightAsync {α β : Type u} : n α → m β → m (n β)
  -- Commented-out default (in terms of `async`/`await`):
  -- := fun x y => async (await x *> y)

export SeqRightAsync (seqRightAsync)
/-- Specialized operation: feed the result of a task into a function that spawns the next task. -/
class BindAsync (m : Type u → Type v) (n : outParam (Type u → Type u)) where
  bindAsync {α β : Type u} : n α → (α → m (n β)) → m (n β)
  -- Commented-out default (in terms of `async`/`await`):
  -- := fun x f => async (await x >>= f >>= await)

export BindAsync (bindAsync)
--------------------------------------------------------------------------------
-- # List/Array Utilities
--------------------------------------------------------------------------------
-- ## Sequencing Lists/Arrays of Asynchronous Tasks
section
variable [BindAsync m n]
-- ### List
/--
Chain the task-spawning action `last` behind every task in `tasks`.

Tasks are bound head to tail with `bindAsync` and their results are ignored;
with no tasks, the result is `last` itself.
-/
def afterListAsync (last : m (n β)) : (tasks : List (n α)) → m (n β)
  | [] => last
  | task :: rest => bindAsync task (fun _ => afterListAsync last rest)
/--
Join `init` and the asynchronous tasks of a List into a single task.

`init` is bound first, then each task of the list in order; intermediate
results are ignored. With no further tasks, the result is `pure init`.
-/
def joinTaskList1 [Pure m] (init : (n α)) : (tasks : List (n α)) → m (n α)
  | [] => pure init
  | next :: rest => bindAsync init (fun _ => joinTaskList1 next rest)
/--
Join all asynchronous tasks in a List into a single task.

An empty list produces the trivial task `pure (pure ())`; otherwise the head
seeds `joinTaskList1` and the tail supplies the remaining tasks.
-/
def joinTaskList [Pure m] [Pure n] : (tasks : List (n PUnit)) → m (n PUnit)
  | [] => pure (pure ())
  | first :: rest => joinTaskList1 first rest
/-- Perform `act` asynchronously once the tasks in the List `tasks` have completed. -/
def afterTaskList [Async m n] (tasks : List (n α)) (act : m β) : m (n β) :=
  -- `async act` is the spawning action that `afterListAsync` chains behind `tasks`.
  let spawnAct : m (n β) := async act
  afterListAsync spawnAct tasks
-- ### Array
/--
Efficient implementation of `afterArrayAsync`. Assumes Arrays are at max `USize`.

Chains `last` behind the tasks at indices `start` up to (but excluding) `stop`
using `bindAsync`, ignoring each task's result. To agree with the reference
implementation `afterArrayAsync`, a `stop` beyond `tasks.size` is clamped to
`tasks.size`; an empty range yields `last` unchanged.
-/
@[inline] unsafe def afterArrayAsyncUnsafe (last : m (n β)) (tasks : Array (n α)) (start := 0) (stop := tasks.size) : m (n β) :=
  let rec @[specialize] fold (i : USize) (stop : USize) : m (n β) :=
    if i == stop then
      last
    else
      -- `uget` takes `lcProof` in place of a bounds proof, so the caller must
      -- guarantee `i < tasks.size`; the clamping below ensures this.
      bindAsync (tasks.uget i lcProof) fun _ => fold (i+1) stop
  -- Clamp the upper bound exactly as `afterArrayAsync` does, instead of
  -- skipping every task when `stop` exceeds the array size.
  let bound : Nat := if stop ≤ tasks.size then stop else tasks.size
  if start < bound then
    fold (USize.ofNat start) (USize.ofNat bound)
  else
    last
/--
Chain the task-spawning action `last` behind the tasks of an Array.

Only the tasks at indices `start` up to (but excluding) `stop` are bound, in
order, and their results are ignored. A `stop` larger than `tasks.size` is
treated as `tasks.size`; an empty range yields `last` unchanged.
-/
@[implementedBy afterArrayAsyncUnsafe]
def afterArrayAsync (last : m (n β)) (tasks : Array (n α)) (start := 0) (stop := tasks.size) : m (n β) :=
  let fold (bound : Nat) (hBound : bound ≤ tasks.size) :=
    -- `fuel` counts the remaining steps and is what the recursion decreases on;
    -- `idx` is the index of the next task to bind.
    let rec loop (fuel : Nat) (idx : Nat) : m (n β) :=
      if hIdx : idx < bound then
        match fuel with
        | Nat.zero => last
        | Nat.succ fuel' =>
          bindAsync (tasks.get ⟨idx, Nat.lt_of_lt_of_le hIdx hBound⟩) fun _ =>
            loop fuel' (idx+1)
      else
        last
    loop (bound - start) start
  if hStop : stop ≤ tasks.size then
    fold stop hStop
  else
    fold tasks.size (Nat.le_refl _)
/--
Join all asynchronous tasks in an Array into a single task.

For a non-empty array, the final element is passed as the `last` argument of
`afterArrayAsync` and the remaining elements (`tasks.pop`) are the tasks it is
chained behind. For an empty array, the result is `pure (pure ())`.
-/
def joinTaskArray [Pure m] [Pure n] (tasks : Array (n PUnit)) : m (n PUnit) :=
if h : 0 < tasks.size then
-- NOTE(review): `afterArrayAsync` declares `last : m (n β)`, while the element
-- fetched here has type `n PUnit`; presumably the elaborator lifts it into `m`
-- (e.g. via `pure`) — confirm.
afterArrayAsync (tasks.get ⟨tasks.size - 1, Nat.sub_lt h (by decide)⟩) tasks.pop
else
-- No tasks to join: produce a trivial task.
pure (pure ())
/-- Perform `act` asynchronously once the tasks in the Array `tasks` have completed. -/
def afterTaskArray [Async m n] (tasks : Array (n α)) (act : m β) : m (n β) :=
  -- `async act` is the spawning action that `afterArrayAsync` chains behind `tasks`.
  let spawnAct : m (n β) := async act
  afterArrayAsync spawnAct tasks
end
-- ## Mapping Lists/Arrays of Asynchronous Tasks
section
variable [BindAsync m n] [Async m n]
-- ### List
/--
Abstract version of `IO.mapTasks`.

Binds each task of `tasks` in order, collecting the results, and finally
spawns `f` (via `async`) on the collected results in their original order.
-/
def mapListAsync (f : List α → m β) (tasks : List (n α)) : m (n β) :=
  go tasks []
where
  -- Results are consed onto the accumulator, so it is reversed before `f` sees it.
  go
    | [], acc => async (f acc.reverse)
    | task :: rest, acc => bindAsync task (fun val => go rest (val :: acc))
-- ### Array
/--
Efficient implementation of `mapArrayAsync`. Assumes Arrays are at max `USize`.

Binds the tasks at indices `start` up to (but excluding) `stop` in order,
pushing each result onto an accumulator, then spawns `f` (via `async`) on the
collected results. To agree with the reference implementation `mapArrayAsync`,
a `stop` beyond `tasks.size` is clamped to `tasks.size`; an empty range spawns
`f #[]`.
-/
@[inline] unsafe def mapArrayAsyncUnsafe (f : Array α → m β) (tasks : Array (n α)) (start := 0) (stop := tasks.size) : m (n β) :=
  let rec @[specialize] fold (i : USize) (stop : USize) (as : Array α) : m (n β) :=
    if i == stop then
      async (f as)
    else
      -- `uget` takes `lcProof` in place of a bounds proof, so the caller must
      -- guarantee `i < tasks.size`; the clamping below ensures this.
      bindAsync (tasks.uget i lcProof) fun a => fold (i+1) stop (as.push a)
  -- Clamp the upper bound exactly as `mapArrayAsync` does, instead of
  -- skipping every task when `stop` exceeds the array size.
  let bound : Nat := if stop ≤ tasks.size then stop else tasks.size
  if start < bound then
    -- Reserve room for `bound - start` results. (The previous `start - stop`
    -- was always 0 here because `Nat` subtraction truncates.)
    fold (USize.ofNat start) (USize.ofNat bound) (Array.mkEmpty (bound - start))
  else
    async (f #[])
/--
Abstract version of `IO.mapTasks` for Arrays.

Binds the tasks at indices `start` up to (but excluding) `stop` in order,
pushing each result onto an accumulator, then spawns `f` (via `async`) on the
collected results. A `stop` larger than `tasks.size` is treated as
`tasks.size`; an empty range spawns `f` on an empty Array.
-/
@[implementedBy mapArrayAsyncUnsafe]
def mapArrayAsync (f : Array α → m β) (tasks : Array (n α)) (start := 0) (stop := tasks.size) : m (n β) :=
  let fold (bound : Nat) (hBound : bound ≤ tasks.size) :=
    -- `fuel` counts the remaining steps and is what the recursion decreases on;
    -- `idx` is the index of the next task; `acc` holds the results so far.
    let rec loop (fuel : Nat) (idx : Nat) (acc : Array α) : m (n β) :=
      if hIdx : idx < bound then
        match fuel with
        | Nat.zero => async (f acc)
        | Nat.succ fuel' =>
          bindAsync (tasks.get ⟨idx, Nat.lt_of_lt_of_le hIdx hBound⟩) fun val =>
            loop fuel' (idx+1) (acc.push val)
      else
        async (f acc)
    loop (bound - start) start (Array.mkEmpty (bound - start))
  if hStop : stop ≤ tasks.size then
    fold stop hStop
  else
    fold tasks.size (Nat.le_refl _)
end