diff --git a/Lake/Async.lean b/Lake/Async.lean index a64165026a..3706f05c34 100644 --- a/Lake/Async.lean +++ b/Lake/Async.lean @@ -6,126 +6,176 @@ Authors: Mac Malone namespace Lake --- # Async / Await +-------------------------------------------------------------------------------- +-- # Async / Await Abstraction +-------------------------------------------------------------------------------- class Async (m : Type u → Type v) (n : outParam $ Type u → Type u) where + /- Run the monadic action as an asynchronous task. -/ async : m α → m (n α) export Async (async) -class Await (m : Type u → Type v) (n : outParam $ Type u → Type u) where +class Await (n : Type u → Type u) (m : Type u → Type v) where + /- Wait for an asynchronous task to finish. -/ await : n α → m α export Await (await) -class ApplicativeAsync (m : Type u → Type v) (n : outParam $ Type u → Type u) extends Async m n, Await m n where - seqLeftAsync {α β : Type u} : n α → m β → m (n α) -- := fun x y => async (await x <* y) - seqRightAsync {α β : Type u} : n α → m β → m (n β) -- := fun x y => async (await x *> y) +-- ## Monadic Specializations -export ApplicativeAsync (seqLeftAsync seqRightAsync) +class MapAsync (m : Type u → Type v) (n : Type u → Type u) where + mapAsync {α β : Type u} : (α → m β) → n α → m (n β) + -- := fun f x => async (await x >>= f) -class MonadAsync (m : Type u → Type v) (n : outParam $ Type u → Type u) extends ApplicativeAsync m n where - mapAsync {α β : Type u} : (α → m β) → n α → m (n β) -- := fun f x => async (await x >>= f) - bindAsync {α β : Type u} : n α → (α → m (n β)) → m (n β) -- := fun x f => async (await x >>= f >>= await) +export MapAsync (mapAsync) -export MonadAsync (mapAsync bindAsync) +class SeqLeftAsync (m : Type u → Type v) (n : Type u → Type u) where + seqLeftAsync {α β : Type u} : n α → m β → m (n α) + -- := fun x y => async (await x <* y) + +export SeqLeftAsync (seqLeftAsync) + +class SeqRightAsync (m : Type u → Type v) (n : Type u → Type u) where + seqRightAsync {α β : Type u} : n α → m β → m (n β) + -- := fun x y => async (await x *> y) + +export SeqRightAsync (seqRightAsync) + +class BindAsync (m : Type u → Type v) (n : outParam $ Type u → Type u) where + bindAsync {α β : Type u} : n α → (α → m (n β)) → m (n β) + -- := fun x f => async (await x >>= f >>= await) + +export BindAsync (bindAsync) + +-------------------------------------------------------------------------------- +-- # List/Array Utilities +-------------------------------------------------------------------------------- + +-- ## Sequencing Lists/Arrays of Asynchronous Tasks section -variable [MonadAsync m n] +variable [BindAsync m n] --- ## List Utilities +-- ### List -/-- `MonadAsync` version of `IO.mapTasks` -/ -def mapListAsync (f : List α → m β) (ts : List (n α)) : m (n β) := - go ts [] +/-- Spawn the asynchronous task `last` after `tasks` finish. -/ +def afterListAsync (last : m (n β)) : (tasks : List (n α)) → m (n β) + | [] => last + | t::ts => bindAsync t fun _ => afterListAsync last ts + +/-- Join all asynchronous tasks in a List into a single task beginning with `init`. -/ +def joinTaskList1 [Pure m] (init : (n α)) : (tasks : List (n α)) → m (n α) + | [] => pure init + | t::ts => bindAsync init fun _ => joinTaskList1 t ts + +/-- Join all asynchronous tasks in a List into a single task. -/ +def joinTaskList [Pure m] [Pure n] : (tasks : List (n PUnit)) → m (n PUnit) + | [] => pure (pure ()) + | t::ts => joinTaskList1 t ts + +/-- Asynchronously after completing `tasks`, perform `act`. -/ +def afterTaskList [Async m n] (tasks : List (n α)) (act : m β) : m (n β) := + afterListAsync (async act) tasks + +-- ### Array + +/-- Efficient implementation of `afterArrayAsync`. Assumes Arrays are at max `USize`. -/ +@[inline] unsafe def afterArrayAsyncUnsafe (last : m (n β)) (tasks : Array (n α)) (start := 0) (stop := tasks.size) : m (n β) := + let rec @[specialize] fold (i : USize) (stop : USize) : m (n β) := + if i == stop then + last + else + bindAsync (tasks.uget i lcProof) fun _ => fold (i+1) stop + if start < stop then + if stop ≤ tasks.size then + fold (USize.ofNat start) (USize.ofNat stop) + else + last + else + last + +/-- Spawn the asynchronous task `last` after `tasks` finish. -/ +@[implementedBy afterArrayAsyncUnsafe] +def afterArrayAsync (last : m (n β)) (tasks : Array (n α)) (start := 0) (stop := tasks.size) : m (n β) := + let fold (stop : Nat) (h : stop ≤ tasks.size) := + let rec loop (i : Nat) (j : Nat) : m (n β) := + if hlt : j < stop then + match i with + | Nat.zero => last + | Nat.succ i' => + let t := tasks.get ⟨j, Nat.lt_of_lt_of_le hlt h⟩ + bindAsync t fun a => loop i' (j+1) + else + last + loop (stop - start) start + if h : stop ≤ tasks.size then + fold stop h + else + fold tasks.size (Nat.le_refl _) + +/-- Join all asynchronous tasks in a Array into a single task. -/ +def joinTaskArray [Pure m] [Pure n] (tasks : Array (n PUnit)) : m (n PUnit) := + if h : 0 < tasks.size then + afterArrayAsync (tasks.get ⟨tasks.size - 1, Nat.sub_lt h (by decide)⟩) tasks.pop + else + pure (pure ()) + +/-- Asynchronously after completing `tasks`, perform `act`. -/ +def afterTaskArray [Async m n] (tasks : Array (n α)) (act : m β) : m (n β) := + afterArrayAsync (async act) tasks + +end + +-- ## Mapping Lists/Arrays of Asynchronous Tasks + +section +variable [BindAsync m n] [Async m n] + +-- ### List + +/-- Abstract version of `IO.mapTasks`. -/ +def mapListAsync (f : List α → m β) (tasks : List (n α)) : m (n β) := + go tasks [] where go | [], as => async (f as.reverse) - | t::ts, as => bindAsync t fun a => go ts (a :: as) + | t::tasks, as => bindAsync t fun a => go tasks (a :: as) -def afterListAsync (task : m (n β)) : (ts : List (n α)) → m (n β) - | [] => task - | t::ts => bindAsync t fun _ => afterListAsync task ts +-- ### Array -def andThenListAsync [Pure m] (task : (n α)) : (ts : List (n α)) → m (n α) - | [] => pure task - | t::ts => bindAsync task fun _ => andThenListAsync t ts - -def seqListAsync [Pure m] [Pure n] : (ts : List (n PUnit)) → m (n PUnit) - | [] => pure (pure ()) - | t::ts => andThenListAsync t ts - --- ## Array Utilities --- These Follow the pattern of Array iterators established in the Lean core. - -@[inline] unsafe def mapArrayAsyncUnsafe (f : Array α → m β) (ts : Array (n α)) (start := 0) (stop := ts.size) : m (n β) := +/-- Efficient implementation of `mapArrayAsync`. Assumes Arrays are at max `USize`. -/ +@[inline] unsafe def mapArrayAsyncUnsafe (f : Array α → m β) (tasks : Array (n α)) (start := 0) (stop := tasks.size) : m (n β) := let rec @[specialize] fold (i : USize) (stop : USize) (as : Array α) : m (n β) := if i == stop then async (f as) else - bindAsync (ts.uget i lcProof) fun a => fold (i+1) stop (as.push a) + bindAsync (tasks.uget i lcProof) fun a => fold (i+1) stop (as.push a) if start < stop then - if stop ≤ ts.size then + if stop ≤ tasks.size then fold (USize.ofNat start) (USize.ofNat stop) (Array.mkEmpty (start - stop)) else async (f #[]) else async (f #[]) +/-- Abstract version of `IO.mapTasks` for Arrays. -/ @[implementedBy mapArrayAsyncUnsafe] -def mapArrayAsync (f : Array α → m β) (ts : Array (n α)) (start := 0) (stop := ts.size) : m (n β) := - let fold (stop : Nat) (h : stop ≤ ts.size) := +def mapArrayAsync (f : Array α → m β) (tasks : Array (n α)) (start := 0) (stop := tasks.size) : m (n β) := + let fold (stop : Nat) (h : stop ≤ tasks.size) := let rec loop (i : Nat) (j : Nat) (as : Array α) : m (n β) := if hlt : j < stop then match i with | Nat.zero => async (f as) | Nat.succ i' => - let t := ts.get ⟨j, Nat.lt_of_lt_of_le hlt h⟩ + let t := tasks.get ⟨j, Nat.lt_of_lt_of_le hlt h⟩ bindAsync t fun a => loop i' (j+1) (as.push a) else async (f as) loop (stop - start) start (Array.mkEmpty (stop - start)) - if h : stop ≤ ts.size then + if h : stop ≤ tasks.size then fold stop h else - fold ts.size (Nat.le_refl _) - -@[inline] unsafe def afterArrayAsyncUnsafe (task : m (n β)) (ts : Array (n α)) (start := 0) (stop := ts.size) : m (n β) := - let rec @[specialize] fold (i : USize) (stop : USize) : m (n β) := - if i == stop then - task - else - bindAsync (ts.uget i lcProof) fun _ => fold (i+1) stop - if start < stop then - if stop ≤ ts.size then - fold (USize.ofNat start) (USize.ofNat stop) - else - task - else - task - -@[implementedBy afterArrayAsyncUnsafe] -def afterArrayAsync (task : m (n β)) (ts : Array (n α)) (start := 0) (stop := ts.size) : m (n β) := - let fold (stop : Nat) (h : stop ≤ ts.size) := - let rec loop (i : Nat) (j : Nat) : m (n β) := - if hlt : j < stop then - match i with - | Nat.zero => task - | Nat.succ i' => - let t := ts.get ⟨j, Nat.lt_of_lt_of_le hlt h⟩ - bindAsync t fun a => loop i' (j+1) - else - task - loop (stop - start) start - if h : stop ≤ ts.size then - fold stop h - else - fold ts.size (Nat.le_refl _) - -def seqArrayAsync [Pure m] [Pure n] (ts : Array (n PUnit)) : m (n PUnit) := - if h : 0 < ts.size then - afterArrayAsync (ts.get ⟨ts.size - 1, Nat.sub_lt h (by decide)⟩) ts.pop - else - pure (pure ()) + fold tasks.size (Nat.le_refl _) end diff --git a/Lake/Target.lean b/Lake/Target.lean index 4da4c6c803..b70fe63100 100644 --- a/Lake/Target.lean +++ b/Lake/Target.lean @@ -76,18 +76,18 @@ protected def pure [Pure m] (artifact : a) (trace : t) : ActiveTarget t m a := def nil [Pure m] [Inhabited t] : ActiveTarget t m PUnit := ActiveTarget.pure () Inhabited.default -def materialize [Await m n] (self : ActiveTarget t n α) : m PUnit := +def materialize [Await n m] (self : ActiveTarget t n α) : m PUnit := await self.task -def andThen [MonadAsync m n] (target : ActiveTarget t n a) (act : m PUnit) : m (n PUnit) := +def andThen [SeqRightAsync m n] (target : ActiveTarget t n a) (act : m PUnit) : m (n PUnit) := seqRightAsync target.task act -instance [MonadAsync m n] : HAndThen (ActiveTarget t n a) (m PUnit) (m (n PUnit)) := +instance [SeqRightAsync m n] : HAndThen (ActiveTarget t n a) (m PUnit) (m (n PUnit)) := ⟨andThen⟩ -def all [Monad m] [Pure n] [MonadAsync m n] [NilTrace t] [MixTrace t] +def all [Monad m] [Pure n] [BindAsync m n] [Async m n] [NilTrace t] [MixTrace t] (targets : List (ActiveTarget t n a)) : m (ActiveTarget t n PUnit) := do - let task ← seqListAsync <| targets.map (·.task) + let task ← joinTaskList <| targets.map (·.task) let trace := mixTraceList <| targets.map (·.trace) return ActiveTarget.mk () trace task @@ -96,7 +96,7 @@ end ActiveTarget -- ## Combinators section -variable [Monad m] [MonadAsync m n] +variable [Monad m] [BindAsync m n] [Async m n] def afterActiveList (targets : List (ActiveTarget t n a)) (act : m PUnit) : m (n PUnit) := afterTaskList (targets.map (·.task)) act @@ -143,18 +143,18 @@ def materialize (self : Target t m a) : m PUnit := self.task section -variable [Monad m] [Pure n] [MonadAsync m n] +variable [Monad m] [Pure n] [BindAsync m n] [Async m n] def materializeListAsync (targets : List (Target t m a)) : m (n PUnit) := do - seqListAsync (← targets.mapM (·.materializeAsync)) + joinTaskList (← targets.mapM (·.materializeAsync)) -def materializeList (targets : List (Target t m a)) : m PUnit := do +def materializeList [Await n m] (targets : List (Target t m a)) : m PUnit := do await <| ← materializeListAsync targets -def materializeArrayAsync (targets : Array (Target t m a)) : m (n PUnit) := do - seqArrayAsync (← targets.mapM (·.materializeAsync)) +def materializeArrayAsync [Await n m] (targets : Array (Target t m a)) : m (n PUnit) := do + joinTaskArray (← targets.mapM (·.materializeAsync)) -def materializeArray (targets : Array (Target t m a)) : m PUnit := do +def materializeArray [Await n m] (targets : Array (Target t m a)) : m PUnit := do await <| ← materializeArrayAsync targets end diff --git a/Lake/Task.lean b/Lake/Task.lean index 562c32ea70..fa273c165a 100644 --- a/Lake/Task.lean +++ b/Lake/Task.lean @@ -26,43 +26,35 @@ namespace IOTask def spawn (act : IO α) (prio := Task.Priority.dedicated) : IO (IOTask α) := IO.asTask act prio +instance : Async IO IOTask := ⟨spawn⟩ + def await (self : IOTask α) : IO α := do IO.ofExcept (← IO.wait self) +instance : Await IOTask IO := ⟨await⟩ + def mapAsync (f : α → IO β) (self : IOTask α) (prio := Task.Priority.dedicated) : IO (IOTask β) := IO.mapTask (fun x => do let x ← IO.ofExcept x; f x) self prio -def seqLeftAsync (self : IOTask α) (act : IO β) (prio := Task.Priority.dedicated) : IO (IOTask α) := - IO.mapTask (fun x => IO.ofExcept x <* act) self prio - -def seqRightAsync (self : IOTask α) (act : IO β) (prio := Task.Priority.dedicated) : IO (IOTask β) := - IO.mapTask (fun x => IO.ofExcept x *> act) self prio - -instance : HAndThen (IOTask α) (IO β) (IO (IOTask β)) := ⟨seqRightAsync⟩ +instance : MapAsync IO IOTask := ⟨mapAsync⟩ def bindAsync (self : IOTask α) (f : α → IO (IOTask β)) (prio := Task.Priority.dedicated) : IO (IOTask β) := IO.bindTask self (fun x => do let x ← IO.ofExcept x; f x) prio -instance : MonadAsync IO IOTask where - async := IOTask.spawn - await := IOTask.await - mapAsync := IOTask.mapAsync - bindAsync := IOTask.bindAsync - seqLeftAsync := IOTask.seqLeftAsync - seqRightAsync := IOTask.seqRightAsync +instance : BindAsync IO IOTask := ⟨bindAsync⟩ + +def seqLeftAsync (self : IOTask α) (act : IO β) (prio := Task.Priority.dedicated) : IO (IOTask α) := + IO.mapTask (fun x => IO.ofExcept x <* act) self prio + +instance : SeqLeftAsync IO IOTask := ⟨seqLeftAsync⟩ + +def seqRightAsync (self : IOTask α) (act : IO β) (prio := Task.Priority.dedicated) : IO (IOTask β) := + IO.mapTask (fun x => IO.ofExcept x *> act) self prio + +instance : SeqRightAsync IO IOTask := ⟨seqRightAsync⟩ +instance : HAndThen (IOTask α) (IO β) (IO (IOTask β)) := ⟨seqRightAsync⟩ end IOTask -section -variable [Monad m] [MonadAsync m n] - -def afterTaskList (tasks : List (n α)) (act : m β) : m (n β) := - afterListAsync (async act) tasks - -def afterTaskArray (tasks : Array (n α)) (act : m β) : m (n β) := - afterArrayAsync (async act) tasks - -end - instance : HAndThen (List (IOTask α)) (IO β) (IO (IOTask β)) := ⟨afterTaskList⟩ instance : HAndThen (Array (IOTask α)) (IO β) (IO (IOTask β)) := ⟨afterTaskArray⟩