refactor: reorganize Async.lean
This commit is contained in:
parent
f9d6f57725
commit
0bfebc1975
3 changed files with 155 additions and 113 deletions
202
Lake/Async.lean
202
Lake/Async.lean
|
|
@ -6,126 +6,176 @@ Authors: Mac Malone
|
|||
|
||||
namespace Lake
|
||||
|
||||
-- # Async / Await
|
||||
--------------------------------------------------------------------------------
|
||||
-- # Async / Await Abstraction
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
class Async (m : Type u → Type v) (n : outParam $ Type u → Type u) where
|
||||
/- Run the monadic action as an asynchronous task. -/
|
||||
async : m α → m (n α)
|
||||
|
||||
export Async (async)
|
||||
|
||||
class Await (m : Type u → Type v) (n : outParam $ Type u → Type u) where
|
||||
class Await (n : Type u → Type u) (m : Type u → Type v) where
|
||||
/- Wait for an asynchronous task to finish. -/
|
||||
await : n α → m α
|
||||
|
||||
export Await (await)
|
||||
|
||||
class ApplicativeAsync (m : Type u → Type v) (n : outParam $ Type u → Type u) extends Async m n, Await m n where
|
||||
seqLeftAsync {α β : Type u} : n α → m β → m (n α) -- := fun x y => async (await x <* y)
|
||||
seqRightAsync {α β : Type u} : n α → m β → m (n β) -- := fun x y => async (await x *> y)
|
||||
-- ## Monadic Specializations
|
||||
|
||||
export ApplicativeAsync (seqLeftAsync seqRightAsync)
|
||||
class MapAsync (m : Type u → Type v) (n : Type u → Type u) where
|
||||
mapAsync {α β : Type u} : (α → m β) → n α → m (n β)
|
||||
-- := fun f x => async (await x >>= f)
|
||||
|
||||
class MonadAsync (m : Type u → Type v) (n : outParam $ Type u → Type u) extends ApplicativeAsync m n where
|
||||
mapAsync {α β : Type u} : (α → m β) → n α → m (n β) -- := fun f x => async (await x >>= f)
|
||||
bindAsync {α β : Type u} : n α → (α → m (n β)) → m (n β) -- := fun x f => async (await x >>= f >>= await)
|
||||
export MapAsync (mapAsync)
|
||||
|
||||
export MonadAsync (mapAsync bindAsync)
|
||||
class SeqLeftAsync (m : Type u → Type v) (n : Type u → Type u) where
|
||||
seqLeftAsync {α β : Type u} : n α → m β → m (n α)
|
||||
-- := fun x y => async (await x <* y)
|
||||
|
||||
export SeqLeftAsync (seqLeftAsync)
|
||||
|
||||
class SeqRightAsync (m : Type u → Type v) (n : Type u → Type u) where
|
||||
seqRightAsync {α β : Type u} : n α → m β → m (n β)
|
||||
-- := fun x y => async (await x *> y)
|
||||
|
||||
export SeqRightAsync (seqRightAsync)
|
||||
|
||||
class BindAsync (m : Type u → Type v) (n : outParam $ Type u → Type u) where
|
||||
bindAsync {α β : Type u} : n α → (α → m (n β)) → m (n β)
|
||||
-- := fun x f => async (await x >>= f >>= await)
|
||||
|
||||
export BindAsync (bindAsync)
|
||||
|
||||
--------------------------------------------------------------------------------
|
||||
-- # List/Array Utilities
|
||||
--------------------------------------------------------------------------------
|
||||
|
||||
-- ## Sequencing Lists/Arrays of Asynchronous Tasks
|
||||
|
||||
section
|
||||
variable [MonadAsync m n]
|
||||
variable [BindAsync m n]
|
||||
|
||||
-- ## List Utilities
|
||||
-- ### List
|
||||
|
||||
/-- `MonadAsync` version of `IO.mapTasks` -/
|
||||
def mapListAsync (f : List α → m β) (ts : List (n α)) : m (n β) :=
|
||||
go ts []
|
||||
/-- Spawn the asynchronous task `last` after `tasks` finish. -/
|
||||
def afterListAsync (last : m (n β)) : (tasks : List (n α)) → m (n β)
|
||||
| [] => last
|
||||
| t::ts => bindAsync t fun _ => afterListAsync last ts
|
||||
|
||||
/-- Join all asynchronous tasks in a List into a single task beginning with `init`. -/
|
||||
def joinTaskList1 [Pure m] (init : (n α)) : (tasks : List (n α)) → m (n α)
|
||||
| [] => pure init
|
||||
| t::ts => bindAsync init fun _ => joinTaskList1 t ts
|
||||
|
||||
/-- Join all asynchronous tasks in a List into a single task. -/
|
||||
def joinTaskList [Pure m] [Pure n] : (tasks : List (n PUnit)) → m (n PUnit)
|
||||
| [] => pure (pure ())
|
||||
| t::ts => joinTaskList1 t ts
|
||||
|
||||
/-- Asynchronously after completing `tasks`, perform `act`. -/
|
||||
def afterTaskList [Async m n] (tasks : List (n α)) (act : m β) : m (n β) :=
|
||||
afterListAsync (async act) tasks
|
||||
|
||||
-- ### Array
|
||||
|
||||
/-- Efficient implementation of `afterArrayAsync`. Assumes Arrays are at max `USize`. -/
|
||||
@[inline] unsafe def afterArrayAsyncUnsafe (last : m (n β)) (tasks : Array (n α)) (start := 0) (stop := tasks.size) : m (n β) :=
|
||||
let rec @[specialize] fold (i : USize) (stop : USize) : m (n β) :=
|
||||
if i == stop then
|
||||
last
|
||||
else
|
||||
bindAsync (tasks.uget i lcProof) fun _ => fold (i+1) stop
|
||||
if start < stop then
|
||||
if stop ≤ tasks.size then
|
||||
fold (USize.ofNat start) (USize.ofNat stop)
|
||||
else
|
||||
last
|
||||
else
|
||||
last
|
||||
|
||||
/-- Spawn the asynchronous task `last` after `tasks` finish. -/
|
||||
@[implementedBy afterArrayAsyncUnsafe]
|
||||
def afterArrayAsync (last : m (n β)) (tasks : Array (n α)) (start := 0) (stop := tasks.size) : m (n β) :=
|
||||
let fold (stop : Nat) (h : stop ≤ tasks.size) :=
|
||||
let rec loop (i : Nat) (j : Nat) : m (n β) :=
|
||||
if hlt : j < stop then
|
||||
match i with
|
||||
| Nat.zero => last
|
||||
| Nat.succ i' =>
|
||||
let t := tasks.get ⟨j, Nat.lt_of_lt_of_le hlt h⟩
|
||||
bindAsync t fun a => loop i' (j+1)
|
||||
else
|
||||
last
|
||||
loop (stop - start) start
|
||||
if h : stop ≤ tasks.size then
|
||||
fold stop h
|
||||
else
|
||||
fold tasks.size (Nat.le_refl _)
|
||||
|
||||
/-- Join all asynchronous tasks in a Array into a single task. -/
|
||||
def joinTaskArray [Pure m] [Pure n] (tasks : Array (n PUnit)) : m (n PUnit) :=
|
||||
if h : 0 < tasks.size then
|
||||
afterArrayAsync (tasks.get ⟨tasks.size - 1, Nat.sub_lt h (by decide)⟩) tasks.pop
|
||||
else
|
||||
pure (pure ())
|
||||
|
||||
/-- Asynchronously after completing `tasks`, perform `act`. -/
|
||||
def afterTaskArray [Async m n] (tasks : Array (n α)) (act : m β) : m (n β) :=
|
||||
afterArrayAsync (async act) tasks
|
||||
|
||||
end
|
||||
|
||||
-- ## Mapping Lists/Arrays of Asynchronous Tasks
|
||||
|
||||
section
|
||||
variable [BindAsync m n] [Async m n]
|
||||
|
||||
-- ### List
|
||||
|
||||
/-- Abstract version of `IO.mapTasks`. -/
|
||||
def mapListAsync (f : List α → m β) (tasks : List (n α)) : m (n β) :=
|
||||
go tasks []
|
||||
where
|
||||
go
|
||||
| [], as => async (f as.reverse)
|
||||
| t::ts, as => bindAsync t fun a => go ts (a :: as)
|
||||
| t::tasks, as => bindAsync t fun a => go tasks (a :: as)
|
||||
|
||||
def afterListAsync (task : m (n β)) : (ts : List (n α)) → m (n β)
|
||||
| [] => task
|
||||
| t::ts => bindAsync t fun _ => afterListAsync task ts
|
||||
-- ### Array
|
||||
|
||||
def andThenListAsync [Pure m] (task : (n α)) : (ts : List (n α)) → m (n α)
|
||||
| [] => pure task
|
||||
| t::ts => bindAsync task fun _ => andThenListAsync t ts
|
||||
|
||||
def seqListAsync [Pure m] [Pure n] : (ts : List (n PUnit)) → m (n PUnit)
|
||||
| [] => pure (pure ())
|
||||
| t::ts => andThenListAsync t ts
|
||||
|
||||
-- ## Array Utilities
|
||||
-- These Follow the pattern of Array iterators established in the Lean core.
|
||||
|
||||
@[inline] unsafe def mapArrayAsyncUnsafe (f : Array α → m β) (ts : Array (n α)) (start := 0) (stop := ts.size) : m (n β) :=
|
||||
/-- Efficient implementation of `mapArrayAsync`. Assumes Arrays are at max `USize`. -/
|
||||
@[inline] unsafe def mapArrayAsyncUnsafe (f : Array α → m β) (tasks : Array (n α)) (start := 0) (stop := tasks.size) : m (n β) :=
|
||||
let rec @[specialize] fold (i : USize) (stop : USize) (as : Array α) : m (n β) :=
|
||||
if i == stop then
|
||||
async (f as)
|
||||
else
|
||||
bindAsync (ts.uget i lcProof) fun a => fold (i+1) stop (as.push a)
|
||||
bindAsync (tasks.uget i lcProof) fun a => fold (i+1) stop (as.push a)
|
||||
if start < stop then
|
||||
if stop ≤ ts.size then
|
||||
if stop ≤ tasks.size then
|
||||
fold (USize.ofNat start) (USize.ofNat stop) (Array.mkEmpty (start - stop))
|
||||
else
|
||||
async (f #[])
|
||||
else
|
||||
async (f #[])
|
||||
|
||||
/-- Abstract version of `IO.mapTasks` for Arrays. -/
|
||||
@[implementedBy mapArrayAsyncUnsafe]
|
||||
def mapArrayAsync (f : Array α → m β) (ts : Array (n α)) (start := 0) (stop := ts.size) : m (n β) :=
|
||||
let fold (stop : Nat) (h : stop ≤ ts.size) :=
|
||||
def mapArrayAsync (f : Array α → m β) (tasks : Array (n α)) (start := 0) (stop := tasks.size) : m (n β) :=
|
||||
let fold (stop : Nat) (h : stop ≤ tasks.size) :=
|
||||
let rec loop (i : Nat) (j : Nat) (as : Array α) : m (n β) :=
|
||||
if hlt : j < stop then
|
||||
match i with
|
||||
| Nat.zero => async (f as)
|
||||
| Nat.succ i' =>
|
||||
let t := ts.get ⟨j, Nat.lt_of_lt_of_le hlt h⟩
|
||||
let t := tasks.get ⟨j, Nat.lt_of_lt_of_le hlt h⟩
|
||||
bindAsync t fun a => loop i' (j+1) (as.push a)
|
||||
else
|
||||
async (f as)
|
||||
loop (stop - start) start (Array.mkEmpty (stop - start))
|
||||
if h : stop ≤ ts.size then
|
||||
if h : stop ≤ tasks.size then
|
||||
fold stop h
|
||||
else
|
||||
fold ts.size (Nat.le_refl _)
|
||||
|
||||
@[inline] unsafe def afterArrayAsyncUnsafe (task : m (n β)) (ts : Array (n α)) (start := 0) (stop := ts.size) : m (n β) :=
|
||||
let rec @[specialize] fold (i : USize) (stop : USize) : m (n β) :=
|
||||
if i == stop then
|
||||
task
|
||||
else
|
||||
bindAsync (ts.uget i lcProof) fun _ => fold (i+1) stop
|
||||
if start < stop then
|
||||
if stop ≤ ts.size then
|
||||
fold (USize.ofNat start) (USize.ofNat stop)
|
||||
else
|
||||
task
|
||||
else
|
||||
task
|
||||
|
||||
@[implementedBy afterArrayAsyncUnsafe]
|
||||
def afterArrayAsync (task : m (n β)) (ts : Array (n α)) (start := 0) (stop := ts.size) : m (n β) :=
|
||||
let fold (stop : Nat) (h : stop ≤ ts.size) :=
|
||||
let rec loop (i : Nat) (j : Nat) : m (n β) :=
|
||||
if hlt : j < stop then
|
||||
match i with
|
||||
| Nat.zero => task
|
||||
| Nat.succ i' =>
|
||||
let t := ts.get ⟨j, Nat.lt_of_lt_of_le hlt h⟩
|
||||
bindAsync t fun a => loop i' (j+1)
|
||||
else
|
||||
task
|
||||
loop (stop - start) start
|
||||
if h : stop ≤ ts.size then
|
||||
fold stop h
|
||||
else
|
||||
fold ts.size (Nat.le_refl _)
|
||||
|
||||
def seqArrayAsync [Pure m] [Pure n] (ts : Array (n PUnit)) : m (n PUnit) :=
|
||||
if h : 0 < ts.size then
|
||||
afterArrayAsync (ts.get ⟨ts.size - 1, Nat.sub_lt h (by decide)⟩) ts.pop
|
||||
else
|
||||
pure (pure ())
|
||||
fold tasks.size (Nat.le_refl _)
|
||||
|
||||
end
|
||||
|
|
|
|||
|
|
@ -76,18 +76,18 @@ protected def pure [Pure m] (artifact : a) (trace : t) : ActiveTarget t m a :=
|
|||
def nil [Pure m] [Inhabited t] : ActiveTarget t m PUnit :=
|
||||
ActiveTarget.pure () Inhabited.default
|
||||
|
||||
def materialize [Await m n] (self : ActiveTarget t n α) : m PUnit :=
|
||||
def materialize [Await n m] (self : ActiveTarget t n α) : m PUnit :=
|
||||
await self.task
|
||||
|
||||
def andThen [MonadAsync m n] (target : ActiveTarget t n a) (act : m PUnit) : m (n PUnit) :=
|
||||
def andThen [SeqRightAsync m n] (target : ActiveTarget t n a) (act : m PUnit) : m (n PUnit) :=
|
||||
seqRightAsync target.task act
|
||||
|
||||
instance [MonadAsync m n] : HAndThen (ActiveTarget t n a) (m PUnit) (m (n PUnit)) :=
|
||||
instance [SeqRightAsync m n] : HAndThen (ActiveTarget t n a) (m PUnit) (m (n PUnit)) :=
|
||||
⟨andThen⟩
|
||||
|
||||
def all [Monad m] [Pure n] [MonadAsync m n] [NilTrace t] [MixTrace t]
|
||||
def all [Monad m] [Pure n] [BindAsync m n] [Async m n] [NilTrace t] [MixTrace t]
|
||||
(targets : List (ActiveTarget t n a)) : m (ActiveTarget t n PUnit) := do
|
||||
let task ← seqListAsync <| targets.map (·.task)
|
||||
let task ← joinTaskList <| targets.map (·.task)
|
||||
let trace := mixTraceList <| targets.map (·.trace)
|
||||
return ActiveTarget.mk () trace task
|
||||
|
||||
|
|
@ -96,7 +96,7 @@ end ActiveTarget
|
|||
-- ## Combinators
|
||||
|
||||
section
|
||||
variable [Monad m] [MonadAsync m n]
|
||||
variable [Monad m] [BindAsync m n] [Async m n]
|
||||
|
||||
def afterActiveList (targets : List (ActiveTarget t n a)) (act : m PUnit) : m (n PUnit) :=
|
||||
afterTaskList (targets.map (·.task)) act
|
||||
|
|
@ -143,18 +143,18 @@ def materialize (self : Target t m a) : m PUnit :=
|
|||
self.task
|
||||
|
||||
section
|
||||
variable [Monad m] [Pure n] [MonadAsync m n]
|
||||
variable [Monad m] [Pure n] [BindAsync m n] [Async m n]
|
||||
|
||||
def materializeListAsync (targets : List (Target t m a)) : m (n PUnit) := do
|
||||
seqListAsync (← targets.mapM (·.materializeAsync))
|
||||
joinTaskList (← targets.mapM (·.materializeAsync))
|
||||
|
||||
def materializeList (targets : List (Target t m a)) : m PUnit := do
|
||||
def materializeList [Await n m] (targets : List (Target t m a)) : m PUnit := do
|
||||
await <| ← materializeListAsync targets
|
||||
|
||||
def materializeArrayAsync (targets : Array (Target t m a)) : m (n PUnit) := do
|
||||
seqArrayAsync (← targets.mapM (·.materializeAsync))
|
||||
def materializeArrayAsync [Await n m] (targets : Array (Target t m a)) : m (n PUnit) := do
|
||||
joinTaskArray (← targets.mapM (·.materializeAsync))
|
||||
|
||||
def materializeArray (targets : Array (Target t m a)) : m PUnit := do
|
||||
def materializeArray [Await n m] (targets : Array (Target t m a)) : m PUnit := do
|
||||
await <| ← materializeArrayAsync targets
|
||||
|
||||
end
|
||||
|
|
|
|||
|
|
@ -26,43 +26,35 @@ namespace IOTask
|
|||
def spawn (act : IO α) (prio := Task.Priority.dedicated) : IO (IOTask α) :=
|
||||
IO.asTask act prio
|
||||
|
||||
instance : Async IO IOTask := ⟨spawn⟩
|
||||
|
||||
def await (self : IOTask α) : IO α := do
|
||||
IO.ofExcept (← IO.wait self)
|
||||
|
||||
instance : Await IOTask IO := ⟨await⟩
|
||||
|
||||
def mapAsync (f : α → IO β) (self : IOTask α) (prio := Task.Priority.dedicated) : IO (IOTask β) :=
|
||||
IO.mapTask (fun x => do let x ← IO.ofExcept x; f x) self prio
|
||||
|
||||
def seqLeftAsync (self : IOTask α) (act : IO β) (prio := Task.Priority.dedicated) : IO (IOTask α) :=
|
||||
IO.mapTask (fun x => IO.ofExcept x <* act) self prio
|
||||
|
||||
def seqRightAsync (self : IOTask α) (act : IO β) (prio := Task.Priority.dedicated) : IO (IOTask β) :=
|
||||
IO.mapTask (fun x => IO.ofExcept x *> act) self prio
|
||||
|
||||
instance : HAndThen (IOTask α) (IO β) (IO (IOTask β)) := ⟨seqRightAsync⟩
|
||||
instance : MapAsync IO IOTask := ⟨mapAsync⟩
|
||||
|
||||
def bindAsync (self : IOTask α) (f : α → IO (IOTask β)) (prio := Task.Priority.dedicated) : IO (IOTask β) :=
|
||||
IO.bindTask self (fun x => do let x ← IO.ofExcept x; f x) prio
|
||||
|
||||
instance : MonadAsync IO IOTask where
|
||||
async := IOTask.spawn
|
||||
await := IOTask.await
|
||||
mapAsync := IOTask.mapAsync
|
||||
bindAsync := IOTask.bindAsync
|
||||
seqLeftAsync := IOTask.seqLeftAsync
|
||||
seqRightAsync := IOTask.seqRightAsync
|
||||
instance : BindAsync IO IOTask := ⟨bindAsync⟩
|
||||
|
||||
def seqLeftAsync (self : IOTask α) (act : IO β) (prio := Task.Priority.dedicated) : IO (IOTask α) :=
|
||||
IO.mapTask (fun x => IO.ofExcept x <* act) self prio
|
||||
|
||||
instance : SeqLeftAsync IO IOTask := ⟨seqLeftAsync⟩
|
||||
|
||||
def seqRightAsync (self : IOTask α) (act : IO β) (prio := Task.Priority.dedicated) : IO (IOTask β) :=
|
||||
IO.mapTask (fun x => IO.ofExcept x *> act) self prio
|
||||
|
||||
instance : SeqRightAsync IO IOTask := ⟨seqRightAsync⟩
|
||||
instance : HAndThen (IOTask α) (IO β) (IO (IOTask β)) := ⟨seqRightAsync⟩
|
||||
|
||||
end IOTask
|
||||
|
||||
section
|
||||
variable [Monad m] [MonadAsync m n]
|
||||
|
||||
def afterTaskList (tasks : List (n α)) (act : m β) : m (n β) :=
|
||||
afterListAsync (async act) tasks
|
||||
|
||||
def afterTaskArray (tasks : Array (n α)) (act : m β) : m (n β) :=
|
||||
afterArrayAsync (async act) tasks
|
||||
|
||||
end
|
||||
|
||||
instance : HAndThen (List (IOTask α)) (IO β) (IO (IOTask β)) := ⟨afterTaskList⟩
|
||||
instance : HAndThen (Array (IOTask α)) (IO β) (IO (IOTask β)) := ⟨afterTaskArray⟩
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue