refactor: reorganize Async.lean

This commit is contained in:
tydeu 2021-08-18 21:30:41 -04:00
parent f9d6f57725
commit 0bfebc1975
3 changed files with 155 additions and 113 deletions

View file

@ -6,126 +6,176 @@ Authors: Mac Malone
namespace Lake
-- # Async / Await
--------------------------------------------------------------------------------
-- # Async / Await Abstraction
--------------------------------------------------------------------------------
class Async (m : Type u → Type v) (n : outParam $ Type u → Type u) where
/- Run the monadic action as an asynchronous task. -/
async : m α → m (n α)
export Async (async)
class Await (m : Type u → Type v) (n : outParam $ Type u → Type u) where
class Await (n : Type u → Type u) (m : Type u → Type v) where
/- Wait for an asynchronous task to finish. -/
await : n α → m α
export Await (await)
class ApplicativeAsync (m : Type u → Type v) (n : outParam $ Type u → Type u) extends Async m n, Await m n where
seqLeftAsync {α β : Type u} : n α → m β → m (n α) -- := fun x y => async (await x <* y)
seqRightAsync {α β : Type u} : n α → m β → m (n β) -- := fun x y => async (await x *> y)
-- ## Monadic Specializations
export ApplicativeAsync (seqLeftAsync seqRightAsync)
class MapAsync (m : Type u → Type v) (n : Type u → Type u) where
mapAsync {α β : Type u} : (α → m β) → n α → m (n β)
-- := fun f x => async (await x >>= f)
class MonadAsync (m : Type u → Type v) (n : outParam $ Type u → Type u) extends ApplicativeAsync m n where
mapAsync {α β : Type u} : (α → m β) → n α → m (n β) -- := fun f x => async (await x >>= f)
bindAsync {α β : Type u} : n α → (α → m (n β)) → m (n β) -- := fun x f => async (await x >>= f >>= await)
export MapAsync (mapAsync)
export MonadAsync (mapAsync bindAsync)
class SeqLeftAsync (m : Type u → Type v) (n : Type u → Type u) where
seqLeftAsync {α β : Type u} : n α → m β → m (n α)
-- := fun x y => async (await x <* y)
export SeqLeftAsync (seqLeftAsync)
class SeqRightAsync (m : Type u → Type v) (n : Type u → Type u) where
seqRightAsync {α β : Type u} : n α → m β → m (n β)
-- := fun x y => async (await x *> y)
export SeqRightAsync (seqRightAsync)
class BindAsync (m : Type u → Type v) (n : outParam $ Type u → Type u) where
bindAsync {α β : Type u} : n α → (α → m (n β)) → m (n β)
-- := fun x f => async (await x >>= f >>= await)
export BindAsync (bindAsync)
--------------------------------------------------------------------------------
-- # List/Array Utilities
--------------------------------------------------------------------------------
-- ## Sequencing Lists/Arrays of Asynchronous Tasks
section
variable [MonadAsync m n]
variable [BindAsync m n]
-- ## List Utilities
-- ### List
/-- `MonadAsync` version of `IO.mapTasks` -/
def mapListAsync (f : List α → m β) (ts : List (n α)) : m (n β) :=
go ts []
/-- Spawn the asynchronous task `last` after `tasks` finish. -/
def afterListAsync (last : m (n β)) : (tasks : List (n α)) → m (n β)
| [] => last
| t::ts => bindAsync t fun _ => afterListAsync last ts
/-- Join all asynchronous tasks in a List into a single task beginning with `init`. -/
def joinTaskList1 [Pure m] (init : (n α)) : (tasks : List (n α)) → m (n α)
| [] => pure init
| t::ts => bindAsync init fun _ => joinTaskList1 t ts
/-- Join all asynchronous tasks in a List into a single task. -/
def joinTaskList [Pure m] [Pure n] : (tasks : List (n PUnit)) → m (n PUnit)
| [] => pure (pure ())
| t::ts => joinTaskList1 t ts
/-- Asynchronously after completing `tasks`, perform `act`. -/
def afterTaskList [Async m n] (tasks : List (n α)) (act : m β) : m (n β) :=
afterListAsync (async act) tasks
-- ### Array
/-- Efficient implementation of `afterArrayAsync`. Assumes Arrays are at max `USize`. -/
@[inline] unsafe def afterArrayAsyncUnsafe (last : m (n β)) (tasks : Array (n α)) (start := 0) (stop := tasks.size) : m (n β) :=
let rec @[specialize] fold (i : USize) (stop : USize) : m (n β) :=
if i == stop then
last
else
bindAsync (tasks.uget i lcProof) fun _ => fold (i+1) stop
if start < stop then
if stop ≤ tasks.size then
fold (USize.ofNat start) (USize.ofNat stop)
else
last
else
last
/-- Spawn the asynchronous task `last` after `tasks` finish. -/
@[implementedBy afterArrayAsyncUnsafe]
def afterArrayAsync (last : m (n β)) (tasks : Array (n α)) (start := 0) (stop := tasks.size) : m (n β) :=
let fold (stop : Nat) (h : stop ≤ tasks.size) :=
let rec loop (i : Nat) (j : Nat) : m (n β) :=
if hlt : j < stop then
match i with
| Nat.zero => last
| Nat.succ i' =>
let t := tasks.get ⟨j, Nat.lt_of_lt_of_le hlt h⟩
bindAsync t fun a => loop i' (j+1)
else
last
loop (stop - start) start
if h : stop ≤ tasks.size then
fold stop h
else
fold tasks.size (Nat.le_refl _)
/-- Join all asynchronous tasks in a Array into a single task. -/
def joinTaskArray [Pure m] [Pure n] (tasks : Array (n PUnit)) : m (n PUnit) :=
if h : 0 < tasks.size then
afterArrayAsync (tasks.get ⟨tasks.size - 1, Nat.sub_lt h (by decide)⟩) tasks.pop
else
pure (pure ())
/-- Asynchronously after completing `tasks`, perform `act`. -/
def afterTaskArray [Async m n] (tasks : Array (n α)) (act : m β) : m (n β) :=
afterArrayAsync (async act) tasks
end
-- ## Mapping Lists/Arrays of Asynchronous Tasks
section
variable [BindAsync m n] [Async m n]
-- ### List
/-- Abstract version of `IO.mapTasks`. -/
def mapListAsync (f : List α → m β) (tasks : List (n α)) : m (n β) :=
go tasks []
where
go
| [], as => async (f as.reverse)
| t::ts, as => bindAsync t fun a => go ts (a :: as)
| t::tasks, as => bindAsync t fun a => go tasks (a :: as)
def afterListAsync (task : m (n β)) : (ts : List (n α)) → m (n β)
| [] => task
| t::ts => bindAsync t fun _ => afterListAsync task ts
-- ### Array
def andThenListAsync [Pure m] (task : (n α)) : (ts : List (n α)) → m (n α)
| [] => pure task
| t::ts => bindAsync task fun _ => andThenListAsync t ts
def seqListAsync [Pure m] [Pure n] : (ts : List (n PUnit)) → m (n PUnit)
| [] => pure (pure ())
| t::ts => andThenListAsync t ts
-- ## Array Utilities
-- These Follow the pattern of Array iterators established in the Lean core.
@[inline] unsafe def mapArrayAsyncUnsafe (f : Array α → m β) (ts : Array (n α)) (start := 0) (stop := ts.size) : m (n β) :=
/-- Efficient implementation of `mapArrayAsync`. Assumes Arrays are at max `USize`. -/
@[inline] unsafe def mapArrayAsyncUnsafe (f : Array α → m β) (tasks : Array (n α)) (start := 0) (stop := tasks.size) : m (n β) :=
let rec @[specialize] fold (i : USize) (stop : USize) (as : Array α) : m (n β) :=
if i == stop then
async (f as)
else
bindAsync (ts.uget i lcProof) fun a => fold (i+1) stop (as.push a)
bindAsync (tasks.uget i lcProof) fun a => fold (i+1) stop (as.push a)
if start < stop then
if stop ≤ ts.size then
if stop ≤ tasks.size then
fold (USize.ofNat start) (USize.ofNat stop) (Array.mkEmpty (start - stop))
else
async (f #[])
else
async (f #[])
/-- Abstract version of `IO.mapTasks` for Arrays. -/
@[implementedBy mapArrayAsyncUnsafe]
def mapArrayAsync (f : Array α → m β) (ts : Array (n α)) (start := 0) (stop := ts.size) : m (n β) :=
let fold (stop : Nat) (h : stop ≤ ts.size) :=
def mapArrayAsync (f : Array α → m β) (tasks : Array (n α)) (start := 0) (stop := tasks.size) : m (n β) :=
let fold (stop : Nat) (h : stop ≤ tasks.size) :=
let rec loop (i : Nat) (j : Nat) (as : Array α) : m (n β) :=
if hlt : j < stop then
match i with
| Nat.zero => async (f as)
| Nat.succ i' =>
let t := ts.get ⟨j, Nat.lt_of_lt_of_le hlt h⟩
let t := tasks.get ⟨j, Nat.lt_of_lt_of_le hlt h⟩
bindAsync t fun a => loop i' (j+1) (as.push a)
else
async (f as)
loop (stop - start) start (Array.mkEmpty (stop - start))
if h : stop ≤ ts.size then
if h : stop ≤ tasks.size then
fold stop h
else
fold ts.size (Nat.le_refl _)
@[inline] unsafe def afterArrayAsyncUnsafe (task : m (n β)) (ts : Array (n α)) (start := 0) (stop := ts.size) : m (n β) :=
let rec @[specialize] fold (i : USize) (stop : USize) : m (n β) :=
if i == stop then
task
else
bindAsync (ts.uget i lcProof) fun _ => fold (i+1) stop
if start < stop then
if stop ≤ ts.size then
fold (USize.ofNat start) (USize.ofNat stop)
else
task
else
task
@[implementedBy afterArrayAsyncUnsafe]
def afterArrayAsync (task : m (n β)) (ts : Array (n α)) (start := 0) (stop := ts.size) : m (n β) :=
let fold (stop : Nat) (h : stop ≤ ts.size) :=
let rec loop (i : Nat) (j : Nat) : m (n β) :=
if hlt : j < stop then
match i with
| Nat.zero => task
| Nat.succ i' =>
let t := ts.get ⟨j, Nat.lt_of_lt_of_le hlt h⟩
bindAsync t fun a => loop i' (j+1)
else
task
loop (stop - start) start
if h : stop ≤ ts.size then
fold stop h
else
fold ts.size (Nat.le_refl _)
def seqArrayAsync [Pure m] [Pure n] (ts : Array (n PUnit)) : m (n PUnit) :=
if h : 0 < ts.size then
afterArrayAsync (ts.get ⟨ts.size - 1, Nat.sub_lt h (by decide)⟩) ts.pop
else
pure (pure ())
fold tasks.size (Nat.le_refl _)
end

View file

@ -76,18 +76,18 @@ protected def pure [Pure m] (artifact : a) (trace : t) : ActiveTarget t m a :=
def nil [Pure m] [Inhabited t] : ActiveTarget t m PUnit :=
ActiveTarget.pure () Inhabited.default
def materialize [Await m n] (self : ActiveTarget t n α) : m PUnit :=
def materialize [Await n m] (self : ActiveTarget t n α) : m PUnit :=
await self.task
def andThen [MonadAsync m n] (target : ActiveTarget t n a) (act : m PUnit) : m (n PUnit) :=
def andThen [SeqRightAsync m n] (target : ActiveTarget t n a) (act : m PUnit) : m (n PUnit) :=
seqRightAsync target.task act
instance [MonadAsync m n] : HAndThen (ActiveTarget t n a) (m PUnit) (m (n PUnit)) :=
instance [SeqRightAsync m n] : HAndThen (ActiveTarget t n a) (m PUnit) (m (n PUnit)) :=
⟨andThen⟩
def all [Monad m] [Pure n] [MonadAsync m n] [NilTrace t] [MixTrace t]
def all [Monad m] [Pure n] [BindAsync m n] [Async m n] [NilTrace t] [MixTrace t]
(targets : List (ActiveTarget t n a)) : m (ActiveTarget t n PUnit) := do
let task ← seqListAsync <| targets.map (·.task)
let task ← joinTaskList <| targets.map (·.task)
let trace := mixTraceList <| targets.map (·.trace)
return ActiveTarget.mk () trace task
@ -96,7 +96,7 @@ end ActiveTarget
-- ## Combinators
section
variable [Monad m] [MonadAsync m n]
variable [Monad m] [BindAsync m n] [Async m n]
def afterActiveList (targets : List (ActiveTarget t n a)) (act : m PUnit) : m (n PUnit) :=
afterTaskList (targets.map (·.task)) act
@ -143,18 +143,18 @@ def materialize (self : Target t m a) : m PUnit :=
self.task
section
variable [Monad m] [Pure n] [MonadAsync m n]
variable [Monad m] [Pure n] [BindAsync m n] [Async m n]
def materializeListAsync (targets : List (Target t m a)) : m (n PUnit) := do
seqListAsync (← targets.mapM (·.materializeAsync))
joinTaskList (← targets.mapM (·.materializeAsync))
def materializeList (targets : List (Target t m a)) : m PUnit := do
def materializeList [Await n m] (targets : List (Target t m a)) : m PUnit := do
await <| ← materializeListAsync targets
def materializeArrayAsync (targets : Array (Target t m a)) : m (n PUnit) := do
seqArrayAsync (← targets.mapM (·.materializeAsync))
def materializeArrayAsync [Await n m] (targets : Array (Target t m a)) : m (n PUnit) := do
joinTaskArray (← targets.mapM (·.materializeAsync))
def materializeArray (targets : Array (Target t m a)) : m PUnit := do
def materializeArray [Await n m] (targets : Array (Target t m a)) : m PUnit := do
await <| ← materializeArrayAsync targets
end

View file

@ -26,43 +26,35 @@ namespace IOTask
def spawn (act : IO α) (prio := Task.Priority.dedicated) : IO (IOTask α) :=
IO.asTask act prio
instance : Async IO IOTask := ⟨spawn⟩
def await (self : IOTask α) : IO α := do
IO.ofExcept (← IO.wait self)
instance : Await IOTask IO := ⟨await⟩
def mapAsync (f : α → IO β) (self : IOTask α) (prio := Task.Priority.dedicated) : IO (IOTask β) :=
IO.mapTask (fun x => do let x ← IO.ofExcept x; f x) self prio
def seqLeftAsync (self : IOTask α) (act : IO β) (prio := Task.Priority.dedicated) : IO (IOTask α) :=
IO.mapTask (fun x => IO.ofExcept x <* act) self prio
def seqRightAsync (self : IOTask α) (act : IO β) (prio := Task.Priority.dedicated) : IO (IOTask β) :=
IO.mapTask (fun x => IO.ofExcept x *> act) self prio
instance : HAndThen (IOTask α) (IO β) (IO (IOTask β)) := ⟨seqRightAsync⟩
instance : MapAsync IO IOTask := ⟨mapAsync⟩
def bindAsync (self : IOTask α) (f : α → IO (IOTask β)) (prio := Task.Priority.dedicated) : IO (IOTask β) :=
IO.bindTask self (fun x => do let x ← IO.ofExcept x; f x) prio
instance : MonadAsync IO IOTask where
async := IOTask.spawn
await := IOTask.await
mapAsync := IOTask.mapAsync
bindAsync := IOTask.bindAsync
seqLeftAsync := IOTask.seqLeftAsync
seqRightAsync := IOTask.seqRightAsync
instance : BindAsync IO IOTask := ⟨bindAsync⟩
def seqLeftAsync (self : IOTask α) (act : IO β) (prio := Task.Priority.dedicated) : IO (IOTask α) :=
IO.mapTask (fun x => IO.ofExcept x <* act) self prio
instance : SeqLeftAsync IO IOTask := ⟨seqLeftAsync⟩
def seqRightAsync (self : IOTask α) (act : IO β) (prio := Task.Priority.dedicated) : IO (IOTask β) :=
IO.mapTask (fun x => IO.ofExcept x *> act) self prio
instance : SeqRightAsync IO IOTask := ⟨seqRightAsync⟩
instance : HAndThen (IOTask α) (IO β) (IO (IOTask β)) := ⟨seqRightAsync⟩
end IOTask
section
variable [Monad m] [MonadAsync m n]
def afterTaskList (tasks : List (n α)) (act : m β) : m (n β) :=
afterListAsync (async act) tasks
def afterTaskArray (tasks : Array (n α)) (act : m β) : m (n β) :=
afterArrayAsync (async act) tasks
end
instance : HAndThen (List (IOTask α)) (IO β) (IO (IOTask β)) := ⟨afterTaskList⟩
instance : HAndThen (Array (IOTask α)) (IO β) (IO (IOTask β)) := ⟨afterTaskArray⟩