diff --git a/src/Std/Internal/Async/Basic.lean b/src/Std/Internal/Async/Basic.lean index 2ea9fdf26d..a38e09d3c8 100644 --- a/src/Std/Internal/Async/Basic.lean +++ b/src/Std/Internal/Async/Basic.lean @@ -1,7 +1,7 @@ /- Copyright (c) 2024 Lean FRO, LLC. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. -Authors: Henrik Böving +Authors: Henrik Böving, Sofia Rodrigues, Mac Malone -/ prelude import Init.Core @@ -13,21 +13,230 @@ namespace Internal namespace IO namespace Async -/-- -A `Task` that may resolve to a value or an `IO.Error`. +/-! + +# Asynchronous Programming Primitives + +This module provides a layered approach to asynchronous programming, combining monadic types, +type classes, and concrete task types that work together in a cohesive system. + +- **Monadic Types**: These types provide a good way to to chain and manipulate context. These + can contain a `Task`, enabling manipulation of both asynchronous and synchronous code. +- **Concrete Task Types**: Concrete units of work that can be executed within these contexts. + +## Monadic Types + +These types provide a good way to to chain and manipulate context. These can contain a `Task`, +enabling manipulation of both asynchronous and synchronous code. + +- `BaseAsync`: A monadic type for infallible asynchronous computations +- `EAsync`: A monadic type for asynchronous computations that may fail with an error of type + `ε` +- `Async`: A monadic type for IO-based asynchronous computations that may fail with `IO.Error` + (alias for `EAsync IO.Error`) + +## Concurrent Units of Work + +These are the concrete computational units that exist within the monadic contexts. These types +should not be created directly. + +- `Task`: A computation that will resolve to a value of type `α`, +- `ETask`: A task that may fail with an error of type `ε`. +- `AsyncTask`: A task that may fail with an `IO.Error` (alias for `ETask IO.Error`). + +## Relation + +These types are related by two functions in the type classes `MonadAsync` and `MonadAwait`: `async` +and `await`. The `async` function extracts a concrete asynchronous task from a computation within the +monadic context. In effect, it runs the computation in the background and returns a task handle that +can be awaited later. On the other hand, the `await` function takes a task and re-inserts it into the +monadic context, allowing its result to be composed using monadic bind and also pausing to wait for that result. +This relationship between `async` and `await` enables precise control over when a computation begins +and when its result is used. You can spawn multiple asynchronous tasks using `async`, perform other +operations in the meantime, and later rejoin the computation flow by awaiting their results. + +These functions should not be used directly. Instead, prefer higher-level combinators such as +`race`, `raceAll`, `concurrently`, `background` and `concurrentlyAll`. The best way to think about +how to write your async code, it to avoid using the concurrent units of work, and only use it when integrating +with non async code that uses them. + -/ -def AsyncTask (α : Type u) : Type u := Task (Except IO.Error α) + +/-- +Typeclass for monads that can "await" a computation of type `t α` in a monad `m` until the result is +available. +-/ +class MonadAwait (t : Type → Type) (m : Type → Type) extends Monad m where + /-- + Awaits the result of `t α` and returns it inside the `m` monad. + -/ + await : t α → m α + +/-- +Represents monads that can launch computations asynchronously of type `t` in a monad `m`. +-/ +class MonadAsync (t : Type → Type) (m : Type → Type) extends Monad m where + /-- + Starts an asynchronous computation in another monad. + -/ + async (x : m α) (prio := Task.Priority.default) : m (t α) + +/- +These instances have the default_instance attribute so that other default instances +can function correctly within monad transformers. +-/ + +@[default_instance] +instance [Monad m] [MonadAwait t m] : MonadAwait t (StateT n m) where + await := liftM (m := m) ∘ MonadAwait.await + +@[default_instance] +instance [Monad m] [MonadAwait t m] : MonadAwait t (ExceptT n m) where + await := liftM (m := m) ∘ MonadAwait.await + +@[default_instance] +instance [Monad m] [MonadAwait t m] : MonadAwait t (ReaderT n m) where + await := liftM (m := m) ∘ MonadAwait.await + +@[default_instance] +instance [MonadAwait t m] : MonadAwait t (StateRefT' s n m) where + await := liftM (m := m) ∘ MonadAwait.await + +@[default_instance] +instance [MonadAwait t m] : MonadAwait t (StateT s m) where + await := liftM (m := m) ∘ MonadAwait.await + +@[default_instance] +instance [MonadAsync t m] : MonadAsync t (ReaderT n m) where + async p prio := MonadAsync.async (prio := prio) ∘ p + +@[default_instance] +instance [MonadAsync t m] : MonadAsync t (StateRefT' s n m) where + async p prio := MonadAsync.async (prio := prio) ∘ p + +@[default_instance] +instance [Functor t] [inst : MonadAsync t m] : MonadAsync t (StateT s m) where + async p prio := fun s => do + let t ← inst.async (prio := prio) (p s) + pure (t <&> Prod.fst, s) + +/-- +A `Task` that may resolve to either a value of type `α` or an error value of type `ε`. +-/ +abbrev ETask (ε : Type) (α : Type) : Type := ExceptT ε Task α + +namespace ETask + +/-- +Construct an `ETask` that is already resolved with value `x`. +-/ +@[inline] +protected def pure (x : α) : ETask ε α := + Task.pure <| .ok x + +/-- +Creates a new `ETask` that will run after `x` has finished. If `x`: +- errors, return an `ETask` that resolves to the error. +- succeeds, return an `ETask` that resolves to `f x`. +-/ +@[inline] +protected def map (f : α → β) (x : ETask ε α) (prio := Task.Priority.default) (sync := false) : ETask ε β := + Task.map (x := x) (f <$> ·) prio sync + +/-- +Creates a new `ETask` that will run after `x` has completed. If `x`: +- errors, return an `ETask` that resolves to the error. +- succeeds, run `f` on the result of `x` and return the `ETask` produced by `f`. +-/ +@[inline] +protected def bind (x : ETask ε α) (f : α → ETask ε β) (prio := Task.Priority.default) (sync := false) : ETask ε β := + Task.bind x (prio := prio) (sync := sync) fun + | .ok a => f a + | .error e => Task.pure <| .error e + +/-- +Similar to `bind`, however `f` has access to the `EIO` monad. If `f` throws an error, the returned +`ETask` resolves to that error. +-/ +@[inline] +protected def bindEIO (x : ETask ε α) (f : α → EIO ε (ETask ε β)) (prio := Task.Priority.default) (sync := false) : EIO ε (ETask ε β) := + EIO.bindTask x (prio := prio) (sync := sync) fun + | .ok a => f a + | .error e => .error e + +/-- +Similar to `bind`, however `f` has access to the `EIO` monad. If `f` throws an error, the returned +`ETask` resolves to that error. +-/ +@[inline] +protected def mapEIO (f : α → EIO ε β) (x : ETask ε α) (prio := Task.Priority.default) (sync := false) : BaseIO (ETask ε β) := + EIO.mapTask (t := x) (prio := prio) (sync := sync) fun + | .ok a => f a + | .error e => .error e + +/-- +Block until the `ETask` in `x` finishes and returns its value. Propagates any error encountered +during execution. +-/ +@[inline] +def block (x : ETask ε α) : EIO ε α := do + match x.get with + | .ok a => return a + | .error e => .error e + +/-- +Create an `ETask` that resolves to the value of the promise `x`. +-/ +@[inline] +def ofPromise (x : IO.Promise (Except ε α)) : ETask ε α := + x.result! + +/-- +Create an `ETask` that resolves to the pure value of the promise `x`. +-/ +@[inline] +def ofPurePromise (x : IO.Promise α) : ETask ε α := + x.result!.map pure (sync := true) + +/-- +Obtain the `IO.TaskState` of `x`. +-/ +@[inline] +def getState (x : ETask ε α) : BaseIO IO.TaskState := + IO.getTaskState x + +instance : Functor (ETask ε) where + map := ETask.map + +instance : Monad (ETask ε) where + pure := ETask.pure + bind := ETask.bind + +end ETask + +/-- +A `Task` that may resolve to a value or an error value of type `IO.Error`. Alias for `ETask IO.Error`. +-/ +abbrev AsyncTask := ETask IO.Error namespace AsyncTask +/-- +Similar to `map`, however `f` has access to the `IO` monad. If `f` throws an error, the returned +`AsyncTask` resolves to that error. +-/ +@[inline] +protected def mapIO (f : α → IO β) (x : AsyncTask α) (prio := Task.Priority.default) (sync := false) : BaseIO (AsyncTask β) := + EIO.mapTask (t := x) (prio := prio) (sync := sync) fun + | .ok a => f a + | .error e => .error e + /-- Construct an `AsyncTask` that is already resolved with value `x`. -/ @[inline] -protected def pure (x : α) : AsyncTask α := Task.pure <| .ok x - -instance : Pure AsyncTask where - pure := AsyncTask.pure +protected def pure (x : α) : AsyncTask α := + Task.pure <| .ok x /-- Create a new `AsyncTask` that will run after `x` has finished. @@ -36,9 +245,8 @@ If `x`: - succeeds, run `f` on the result of `x` and return the `AsyncTask` produced by `f`. -/ @[inline] -protected def bind (x : AsyncTask α) (f : α → AsyncTask β) : AsyncTask β := - Task.bind x fun r => - match r with +protected def bind (x : AsyncTask α) (f : α → AsyncTask β) (prio := Task.Priority.default) (sync := false) : AsyncTask β := + Task.bind x (prio := prio) (sync := sync) fun | .ok a => f a | .error e => Task.pure <| .error e @@ -49,40 +257,34 @@ If `x`: - succeeds, return an `AsyncTask` that resolves to `f x`. -/ @[inline] -def map (f : α → β) (x : AsyncTask α) : AsyncTask β := - Task.map (x := x) fun r => - match r with - | .ok a => .ok (f a) - | .error e => .error e +def map (f : α → β) (x : AsyncTask α) (prio := Task.Priority.default) (sync := false) : AsyncTask β := + Task.map (x := x) (f <$> ·) prio sync /-- Similar to `bind`, however `f` has access to the `IO` monad. If `f` throws an error, the returned `AsyncTask` resolves to that error. -/ @[inline] -def bindIO (x : AsyncTask α) (f : α → IO (AsyncTask β)) : BaseIO (AsyncTask β) := - IO.bindTask x fun r => - match r with +def bindIO (x : AsyncTask α) (f : α → IO (AsyncTask β)) (prio := Task.Priority.default) (sync := false) : BaseIO (AsyncTask β) := + IO.bindTask x (prio := prio) (sync := sync) fun | .ok a => f a | .error e => .error e /-- -Similar to `bind`, however `f` has access to the `IO` monad. If `f` throws an error, the returned +Similar to `map`, however `f` has access to the `IO` monad. If `f` throws an error, the returned `AsyncTask` resolves to that error. -/ @[inline] -def mapIO (f : α → IO β) (x : AsyncTask α) : BaseIO (AsyncTask β) := - IO.mapTask (t := x) fun r => - match r with +def mapTaskIO (f : α → IO β) (x : AsyncTask α) (prio := Task.Priority.default) (sync := false) : BaseIO (AsyncTask β) := + IO.mapTask (t := x) (prio := prio) (sync := sync) fun | .ok a => f a | .error e => .error e /-- Block until the `AsyncTask` in `x` finishes. -/ -def block (x : AsyncTask α) : IO α := do - let res := x.get - match res with +def block (x : AsyncTask α) : IO α := + match x.get with | .ok a => return a | .error e => .error e @@ -98,7 +300,7 @@ Create an `AsyncTask` that resolves to the value of `x`. -/ @[inline] def ofPurePromise (x : IO.Promise α) : AsyncTask α := - x.result!.map pure + x.result!.map pure (sync := true) /-- Obtain the `IO.TaskState` of `x`. @@ -109,6 +311,491 @@ def getState (x : AsyncTask α) : BaseIO IO.TaskState := end AsyncTask +/-- +A `MaybeTask α` represents a computation that either: + +- Is immediately available as an `α` value, or +- Is an asynchronous computation that will eventually produce an `α` value. +-/ +inductive MaybeTask (α : Type) + | pure : α → MaybeTask α + | ofTask : Task α → MaybeTask α + +namespace MaybeTask + +/-- +Constructs an `Task` from a `MaybeTask`. +-/ +@[inline] +def toTask : MaybeTask α → Task α + | .pure a => .pure a + | .ofTask t => t + +/-- +Gets the value of the `MaybeTask` by blocking. +-/ +@[inline] +def get {α : Type} : MaybeTask α → α + | .pure a => a + | .ofTask t => t.get + +/-- +Maps a function over a `MaybeTask`. +-/ +@[inline] +def map (f : α → β) (prio := Task.Priority.default) (sync := false) : MaybeTask α → MaybeTask β + | .pure a => .pure <| f a + | .ofTask t => .ofTask <| t.map f prio sync + +/-- +Sequences two computations, allowing the second to depend on the value computed by the first. +-/ +@[inline] +protected def bind (t : MaybeTask α) (f : α → MaybeTask β) (prio := Task.Priority.default) (sync := false) : MaybeTask β := + match t with + | .pure a => f a + | .ofTask t => .ofTask <| t.bind (f · |>.toTask) prio sync + +/-- +Join the `MaybeTask` to an `Task`. +-/ +@[inline] +def joinTask (t : Task (MaybeTask α)) : Task α := + t.bind (sync := true) fun + | .pure a => .pure a + | .ofTask t => t + +instance : Functor (MaybeTask) where + map := MaybeTask.map + +instance : Monad (MaybeTask) where + pure := MaybeTask.pure + bind := MaybeTask.bind + +end MaybeTask + +/-- +An asynchronous computation that never fails. +-/ +def BaseAsync (α : Type) := BaseIO (MaybeTask α) + +namespace BaseAsync + +/-- +Converts a `BaseIO` into a `BaseAsync` +-/ +@[inline] +def mk (x : BaseIO (MaybeTask α)) : BaseAsync α := + x + +/-- +Converts a `BaseAsync` into a `BaseIO` +-/ +@[inline] +def toRawBaseIO (x : BaseAsync α) : BaseIO (MaybeTask α) := + x + +/-- +Converts a `BaseAsync` to a `BaseIO Task`. +-/ +@[inline] +protected def toBaseIO (x : BaseAsync α) : BaseIO (Task α) := + MaybeTask.toTask <$> x.toRawBaseIO + +/-- +Creates a new `BaseAsync` out of a `Task`. +-/ +@[inline] +protected def ofTask (x : Task α) : BaseAsync α := + .mk <| pure <| MaybeTask.ofTask x + +/-- +Creates a `BaseAsync` computation that immediately returns the given value. +-/ +@[inline] +protected def pure (a : α) : BaseAsync α := + .mk <| pure <| .pure a + +/-- +Maps the result of a `BaseAsync` computation with a function. +-/ +@[inline] +protected def map (f : α → β) (self : BaseAsync α) (prio := Task.Priority.default) (sync := false) : BaseAsync β := + mk <| (·.map f prio sync) <$> self.toRawBaseIO + +/-- +Sequences two computations, allowing the second to depend on the value computed by the first. +-/ +@[inline] +protected def bind (self : BaseAsync α) (f : α → BaseAsync β) (prio := Task.Priority.default) (sync := false) : BaseAsync β := + mk <| self.toRawBaseIO >>= (bindAsyncTask · f |>.toRawBaseIO) +where + bindAsyncTask (t : MaybeTask α) (f : α → BaseAsync β) : BaseAsync β := .mk <| + match t with + | .pure a => (f a) |>.toRawBaseIO + | .ofTask t => .ofTask <$> BaseIO.bindTask t (fun a => MaybeTask.toTask <$> (f a |>.toRawBaseIO)) prio sync + +/-- +Lifts a `BaseIO` action into a `BaseAsync` computation. +-/ +@[inline] +protected def lift (x : BaseIO α) : BaseAsync α := + .mk <| (.pure ∘ .pure) =<< x + +/-- +Waits for the result of the `BaseAsync` computation, blocking if necessary. +-/ +@[inline] +protected def wait (self : BaseAsync α) : BaseIO α := + pure ∘ Task.get =<< self.toBaseIO + +/-- +Lifts a `BaseAsync` computation into a `Task` that can be awaited and joined. +-/ +@[inline] +protected def asTask (x : BaseAsync α) (prio := Task.Priority.default) : BaseIO (Task α) := do + let res ← BaseIO.asTask (prio := prio) x.toRawBaseIO + return MaybeTask.joinTask res + +/-- +Creates a `BaseAsync` that awaits the completion of the given `Task α`. +-/ +@[inline] +def await (t : Task α) : BaseAsync α := + .mk <| pure <| MaybeTask.ofTask t + +/-- +Returns the `BaseAsync` computation inside a `Task α`, so it can be awaited. +-/ +@[inline] +def async (self : BaseAsync α) (prio := Task.Priority.default) : BaseAsync (Task α) := + BaseAsync.lift <| self.asTask (prio := prio) + +instance : Functor BaseAsync where + map := BaseAsync.map + +instance : Monad BaseAsync where + pure := BaseAsync.pure + bind := BaseAsync.bind + +instance : MonadLift BaseIO BaseAsync where + monadLift := BaseAsync.lift + +instance : MonadAwait Task BaseAsync where + await := BaseAsync.await + +instance : MonadAsync Task BaseAsync where + async t prio := BaseAsync.async t prio + +instance [Inhabited α] : Inhabited (BaseAsync α) where + default := .mk <| pure (MaybeTask.pure default) + +end BaseAsync + +/-- +An asynchronous computation that may produce an error of type `ε`. +-/ +def EAsync (ε : Type) (α : Type) := BaseAsync (Except ε α) + +namespace EAsync + +/-- +Converts a `EAsync` to a `ETask`. +-/ +@[inline] +protected def toBaseIO (x : EAsync ε α) : BaseIO (ETask ε α) := + MaybeTask.toTask <$> x.toRawBaseIO + +/-- +Creates a new `EAsync` out of a `RTask`. +-/ +@[inline] +protected def ofTask (x : ETask ε α) : EAsync ε α := + .mk <| pure <| MaybeTask.ofTask x + +/-- +Converts a `BaseAsync` to a `EIO ETask`. +-/ +@[inline] +protected def toEIO (x : EAsync ε α) : EIO ε (ETask ε α) := + MaybeTask.toTask <$> x.toRawBaseIO + +/-- +Creates a new `EAsync` out of a `ETask`. +-/ +@[inline] +protected def ofETask (x : ETask ε α) : EAsync ε α := + .mk <| BaseAsync.ofTask x + +/-- +Creates an `EAsync` computation that immediately returns the given value. +-/ +@[inline] +protected def pure (a : α) : EAsync ε α := + .mk <| BaseAsync.pure <| .ok a + +/-- +Maps the result of an `EAsync` computation with a pure function. +-/ +@[inline] +protected def map (f : α → β) (self : EAsync ε α) : EAsync ε β := + .mk <| BaseAsync.map (.map f) self + +/-- +Sequences two computations, allowing the second to depend on the value computed by the first. +-/ +@[inline] +protected def bind (self : EAsync ε α) (f : α → EAsync ε β) : EAsync ε β := + .mk <| BaseAsync.bind self fun + | .ok a => f a + | .error e => BaseAsync.pure (.error e) + +/-- +Lifts an `EIO` action into an `EAsync` computation. +-/ +@[inline] +protected def lift (x : EIO ε α) : EAsync ε α := + .mk <| BaseAsync.lift x.toBaseIO + +/-- +Waits for the result of the `EAsync` computation, blocking if necessary. +-/ +@[inline] +protected def wait (self : EAsync ε α) : EIO ε α := do + let result ← self |> BaseAsync.wait + match result with + | .ok a => return a + | .error e => .error e + +/-- +Lifts an `EAsync` computation into an `ETask` that can be awaited and joined. +-/ +@[inline] +protected def asTask (x : EAsync ε α) (prio := Task.Priority.default) : EIO ε (ETask ε α) := + x |> BaseAsync.asTask (prio := prio) + +/-- +Raises an error of type `ε` within the `EAsync` monad. +-/ +@[inline] +protected def throw (e : ε) : EAsync ε α := + .mk <| BaseAsync.pure (.error e) + +/-- +Handles errors in an `EAsync` computation by running a handler if one occurs. +-/ +@[inline] +protected def tryCatch (x : EAsync ε α) (f : ε → EAsync ε α) (prio := Task.Priority.default) (sync := false) : EAsync ε α := + .mk <| BaseAsync.bind (sync := sync) (prio := prio) x fun + | .ok a => BaseAsync.pure (.ok a) + | .error e => (f e) + +/-- +Runs an action, ensuring that some other action always happens afterward. +-/ +protected def tryFinally' + (x : EAsync ε α) (f : Option α → EAsync ε β) + (prio := Task.Priority.default) (sync := false) : + EAsync ε (α × β) := + .mk <| BaseAsync.bind x (prio := prio) (sync := sync) fun + | .ok a => do + match ← (f (some a)) with + | .ok b => BaseAsync.pure (.ok (a, b)) + | .error e => BaseAsync.pure (.error e) + | .error e => do + match ← (f none) with + | .ok _ => BaseAsync.pure (.error e) + | .error e' => BaseAsync.pure (.error e') + +/-- +Creates an `EAsync` computation that awaits the completion of the given `ETask ε α`. +-/ +@[inline] +def await (x : ETask ε α) : EAsync ε α := + .mk (BaseAsync.ofTask x) + +/-- +Returns the `EAsync` computation inside an `ETask ε α`, so it can be awaited. +-/ +@[inline] +def async (self : EAsync ε α) (prio := Task.Priority.default) : EAsync ε (ETask ε α) := + EAsync.lift <| self.asTask prio + +instance : Functor (EAsync ε) where + map := EAsync.map + +instance : Monad (EAsync ε) where + pure := EAsync.pure + bind := EAsync.bind + +instance : MonadLift (EIO ε) (EAsync ε) where + monadLift := EAsync.lift + +instance : MonadExcept ε (EAsync ε) where + throw := EAsync.throw + tryCatch := EAsync.tryCatch + +instance : MonadExceptOf ε (EAsync ε) where + throw := EAsync.throw + tryCatch := EAsync.tryCatch + +instance : MonadFinally (EAsync ε) where + tryFinally' := EAsync.tryFinally' + +instance : OrElse (EAsync ε α) where + orElse := MonadExcept.orElse + +instance [Inhabited ε] : Inhabited (EAsync ε α) where + default := .mk <| BaseAsync.pure default + +instance : MonadAwait (ETask ε) (EAsync ε) where + await t := .mk <| BaseAsync.ofTask t + +instance : MonadAwait Task (EAsync ε) where + await t := .mk <| BaseAsync.ofTask (t.map (.ok)) + +instance : MonadAwait AsyncTask (EAsync IO.Error) where + await t := .mk <| BaseAsync.ofTask t + +instance : MonadAwait IO.Promise (EAsync ε) where + await t := .mk <| BaseAsync.ofTask (t.result!.map (.ok)) + +instance : MonadAsync (ETask ε) (EAsync ε) where + async t prio := EAsync.lift <| t.asTask (prio := prio) + +instance : MonadAsync AsyncTask (EAsync IO.Error) where + async t prio := EAsync.lift <| t.asTask (prio := prio) + +instance : MonadLift BaseIO (EAsync ε) where + monadLift x := .mk <| (pure ∘ .ok) <$> x + +instance : MonadLift (EIO ε) (EAsync ε) where + monadLift x := .mk <| pure <$> x.toBaseIO + +instance : MonadLift BaseAsync (EAsync ε) where + monadLift x := .mk <| x.map (.ok) + +@[inline] +protected partial def forIn + {β : Type} [i : Inhabited ε] (init : β) + (f : Unit → β → EAsync ε (ForInStep β)) + (prio := Task.Priority.default) : + EAsync ε β := do + let promise ← IO.Promise.new + + let rec @[specialize] loop (b : β) : EAsync ε (ETask ε Unit) := async (prio := prio) do + match ← f () b with + | ForInStep.done b => promise.resolve (.ok b) + | ForInStep.yield b => discard <| (loop b) + + discard <| loop init + + .mk <| BaseAsync.ofTask promise.result! + +instance [Inhabited ε] : ForIn (EAsync ε) Lean.Loop Unit where + forIn _ := EAsync.forIn + +end EAsync + +/-- +An asynchronous computation that may produce an error of type `IO.Error`.. +-/ +abbrev Async (α : Type) := EAsync IO.Error α + +namespace Async + +/-- +Converts a `Async` to a `AsyncTask`. +-/ +@[inline] +protected def toIO (x : Async α) : IO (AsyncTask α) := + MaybeTask.toTask <$> x.toRawBaseIO + +@[default_instance] +instance : MonadAsync AsyncTask Async := + inferInstanceAs (MonadAsync (ETask IO.Error) (EAsync IO.Error)) + +instance : MonadAwait AsyncTask Async := + inferInstanceAs (MonadAwait AsyncTask (EAsync IO.Error)) + +instance : MonadAwait IO.Promise Async := + inferInstanceAs (MonadAwait IO.Promise (EAsync IO.Error)) + +end Async + +export MonadAsync (async) +export MonadAwait (await) + +/-- +This function transforms the operation inside the monad `m` into a task and let it run in the background. +-/ +@[inline, specialize] +def background [Monad m] [MonadAsync t m] (prio := Task.Priority.default) : m α → m Unit := + discard ∘ async (t := t) (prio := prio) + +/-- +Runs two computations concurrently and returns both results as a pair. +-/ +@[inline, specialize] +def concurrently + [Monad m] [MonadAwait t m] [MonadAsync t m] + (x : m α) (y : m β) + (prio := Task.Priority.default) : + m (α × β) := do + let taskX : t α ← async x (prio := prio) + let taskY : t β ← async y (prio := prio) + let resultX ← await taskX + let resultY ← await taskY + return (resultX, resultY) + +/-- +Runs two computations concurrently and returns the result of the one that finishes first. +The other result is lost and the other task is not cancelled, so the task will continue the execution +until the end. +-/ +@[inline, specialize] +def race + [MonadLiftT BaseIO m] [MonadAwait Task m] [MonadAsync t m] [MonadAwait t m] + [Monad m] [Inhabited α] (x : m α) (y : m α) + (prio := Task.Priority.default) : + m α := do + let promise ← IO.Promise.new + + discard (async (t := t) (prio := prio) <| Bind.bind x (liftM ∘ promise.resolve)) + discard (async (t := t) (prio := prio) <| Bind.bind y (liftM ∘ promise.resolve)) + + await promise.result! + +/-- +Runs all computations in an `Array` concurrently and returns all results as an array. +-/ +@[inline, specialize] +def concurrentlyAll + [Monad m] [MonadAwait t m] [MonadAsync t m] (xs : Array (m α)) + (prio := Task.Priority.default) : + m (Array α) := do + let tasks : Array (t α) ← xs.mapM (async (prio := prio)) + tasks.mapM await + +/-- +Runs all computations concurrently and returns the result of the first one to finish. +All other results are lost, and the tasks are not cancelled, so they'll continue their executing +until the end. +-/ +@[inline, specialize] +def raceAll + [ForM m c (m α)] [MonadLiftT BaseIO m] [MonadAwait Task m] + [MonadAsync t m] [MonadAwait t m] [Monad m] [Inhabited α] + (xs : c) + (prio := Task.Priority.default) : + m α := do + let promise ← IO.Promise.new + + ForM.forM xs fun x => + discard (async (t := t) (prio := prio) <| Bind.bind x (liftM ∘ promise.resolve)) + + await promise.result! + end Async end IO end Internal diff --git a/tests/lean/run/async.lean b/tests/lean/run/async.lean new file mode 100644 index 0000000000..331a866286 --- /dev/null +++ b/tests/lean/run/async.lean @@ -0,0 +1,31 @@ +import Std.Internal.Async +import Std.Internal.UV +import Std.Net.Addr + +open Std.Internal.IO.Async.UDP +open Std.Internal.IO.Async +open Std.Net + +def t : IO (Async Nat) := do + IO.println "a" + return do + IO.println "b" + return 2 + +-- Usage example of the ForIn instance + +def loop : Async Nat := do + let mut res := 0 + + while res < 10 do + res := res + 1 + + discard <| t + + pure res + +/-- +info: 10 +-/ +#guard_msgs in +#eval IO.println =<< ETask.block =<< loop.asTask diff --git a/tests/lean/run/async_base_functions.lean b/tests/lean/run/async_base_functions.lean new file mode 100644 index 0000000000..f3f6564000 --- /dev/null +++ b/tests/lean/run/async_base_functions.lean @@ -0,0 +1,54 @@ +import Std.Internal.Async +import Std.Sync.Mutex + +open Std + +open Std.Internal.IO.Async + +def wait (ms : Nat) (ref : Std.Mutex Nat) (val : Nat) : Async Unit := do + ref.atomically (·.modify (· * val)) + IO.sleep ms + ref.atomically (·.modify (· + val)) + +-- Tests + +def sequential : Async Unit := do + let ref ← Std.Mutex.new 0 + wait 200 ref 1 + wait 400 ref 2 + ref.atomically (·.modify (· * 10)) + assert! (← ref.atomically (·.get)) == 40 + +#eval do (← sequential.toEIO).block + +def conc : Async Unit := do + let ref ← Std.Mutex.new 0 + discard <| concurrently (wait 200 ref 1) (wait 400 ref 2) + ref.atomically (·.modify (· * 10)) + assert! (← ref.atomically (·.get)) == 30 + +#eval do (← conc.toEIO).block + +def racer : Async Unit := do + let ref ← Std.Mutex.new 0 + race (wait 200 ref 1) (wait 400 ref 2) + ref.atomically (·.modify (· * 10)) + assert! (← ref.atomically (·.get)) == 10 + +#eval do (← racer.toEIO).block + +def concAll : Async Unit := do + let ref ← Std.Mutex.new 0 + discard <| concurrentlyAll #[(wait 200 ref 1), (wait 400 ref 2)] + ref.atomically (·.modify (· * 10)) + assert! (← ref.atomically (·.get)) == 30 + +#eval do (← concAll.toEIO).block + +def racerAll : Async Unit := do + let ref ← Std.Mutex.new 0 + raceAll #[(wait 200 ref 1), (wait 400 ref 2)] + ref.atomically (·.modify (· * 10)) + assert! (← ref.atomically (·.get)) == 10 + +#eval do (← racerAll.toEIO).block diff --git a/tests/lean/run/async_tcp_fname_errors.lean b/tests/lean/run/async_tcp_fname_errors.lean index 9b345052a8..01c75b4f46 100644 --- a/tests/lean/run/async_tcp_fname_errors.lean +++ b/tests/lean/run/async_tcp_fname_errors.lean @@ -2,7 +2,7 @@ import Std.Internal.Async import Std.Internal.UV import Std.Net.Addr -open Std.Internal.IO.Async +open Std.Internal.IO Async open Std.Net def assertBEq [BEq α] [ToString α] (actual expected : α) : IO Unit := do @@ -10,43 +10,20 @@ def assertBEq [BEq α] [ToString α] (actual expected : α) : IO Unit := do throw <| IO.userError <| s!"expected '{expected}', got '{actual}'" - --- Define the Async monad -structure Async (α : Type) where - run : IO (AsyncTask α) - -namespace Async - --- Monad instance for Async -instance : Monad Async where - pure x := Async.mk (pure (AsyncTask.pure x)) - bind ma f := Async.mk do - let task ← ma.run - task.bindIO fun a => (f a).run - --- Await function to simplify AsyncTask handling -def await (task : IO (AsyncTask α)) : Async α := - Async.mk task - -instance : MonadLift IO Async where - monadLift io := Async.mk (io >>= (pure ∘ AsyncTask.pure)) - --------------------------------------------------------------- - /-- Mike is another client. -/ def runMike (client: TCP.Socket.Client) : Async Unit := do - let message ← await (client.recv? 1024) + let message ← await (← client.recv? 1024) assertBEq (String.fromUTF8? =<< message) none /-- Joe is another client. -/ def runJoe (client: TCP.Socket.Client) : Async Unit := do - let message ← await (client.recv? 1024) + let message ← await (← client.recv? 1024) assertBEq (String.fromUTF8? =<< message) none /-- Robert is the server. -/ def runRobert (server: TCP.Socket.Server) : Async Unit := do - discard <| await server.accept - discard <| await server.accept + discard <| await (← server.accept) + discard <| await (← server.accept) def clientServer : IO Unit := do let addr := SocketAddressV4.mk (.ofParts 127 0 0 1) 8083 @@ -73,16 +50,14 @@ def clientServer : IO Unit := do mike.noDelay - let serverTask ← (runRobert server).run + let serverTask ← (runRobert server).toIO - let joeTask ← (runJoe joe).run - let mikeTask ← (runMike mike).run + let joeTask ← (runJoe joe).toIO + let mikeTask ← (runMike mike).toIO serverTask.block joeTask.block mikeTask.block -end Async - -#eval Async.clientServer +#eval clientServer diff --git a/tests/lean/run/async_tcp_half.lean b/tests/lean/run/async_tcp_half.lean index ee91e825f3..c2fdf655e1 100644 --- a/tests/lean/run/async_tcp_half.lean +++ b/tests/lean/run/async_tcp_half.lean @@ -2,7 +2,7 @@ import Std.Internal.Async import Std.Internal.UV import Std.Net.Addr -open Std.Internal.IO.Async +open Std.Internal.IO Async open Std.Net -- Using this function to create IO Error. For some reason the assert! is not pausing the execution. @@ -11,33 +11,13 @@ def assertBEq [BEq α] [ToString α] (actual expected : α) : IO Unit := do throw <| IO.userError <| s!"expected '{expected}', got '{actual}'" --- Define the Async monad -structure Async (α : Type) where - run : IO (AsyncTask α) - -namespace Async - --- Monad instance for Async -instance : Monad Async where - pure x := Async.mk (pure (AsyncTask.pure x)) - bind ma f := Async.mk do - let task ← ma.run - task.bindIO fun a => (f a).run - --- Await function to simplify AsyncTask handling -def await (task : IO (AsyncTask α)) : Async α := - Async.mk task - -instance : MonadLift IO Async where - monadLift io := Async.mk (io >>= (pure ∘ AsyncTask.pure)) - /-- Joe is another client. -/ def runJoe (addr: SocketAddress) : Async Unit := do let client ← TCP.Socket.Client.mk - await (client.connect addr) - await (client.send (String.toUTF8 "hello robert!")) - await client.shutdown + await (← client.connect addr) + await (← client.send (String.toUTF8 "hello robert!")) + await (← client.shutdown) def listenClose : IO Unit := do let addr := SocketAddressV4.mk (.ofParts 127 0 0 1) 8080 @@ -53,7 +33,7 @@ def acceptClose : IO Unit := do server.bind addr server.listen 128 - let joeTask ← (runJoe addr).run + let joeTask ← (runJoe addr).toIO let task ← server.accept let client ← task.block @@ -71,5 +51,4 @@ def acceptClose : IO Unit := do -- Waiting to avoid errors from escaping. joeTask.block -#eval acceptClose #eval listenClose diff --git a/tests/lean/run/async_tcp_server_client.lean b/tests/lean/run/async_tcp_server_client.lean index c80abfbac3..48ef4b7149 100644 --- a/tests/lean/run/async_tcp_server_client.lean +++ b/tests/lean/run/async_tcp_server_client.lean @@ -2,7 +2,7 @@ import Std.Internal.Async import Std.Internal.UV import Std.Net.Addr -open Std.Internal.IO.Async +open Std.Internal.IO Async open Std.Net -- Using this function to create IO Error. For some reason the assert! is not pausing the execution. @@ -11,53 +11,33 @@ def assertBEq [BEq α] [ToString α] (actual expected : α) : IO Unit := do throw <| IO.userError <| s!"expected '{expected}', got '{actual}'" --- Define the Async monad -structure Async (α : Type) where - run : IO (AsyncTask α) - -namespace Async - --- Monad instance for Async -instance : Monad Async where - pure x := Async.mk (pure (AsyncTask.pure x)) - bind ma f := Async.mk do - let task ← ma.run - task.bindIO fun a => (f a).run - --- Await function to simplify AsyncTask handling -def await (task : IO (AsyncTask α)) : Async α := - Async.mk task - -instance : MonadLift IO Async where - monadLift io := Async.mk (io >>= (pure ∘ AsyncTask.pure)) - -------------------------------------------------------------- /-- Mike is another client. -/ def runMike (client: TCP.Socket.Client) : Async Unit := do - let mes ← await (client.recv? 1024) + let mes ← await (← client.recv? 1024) assertBEq (String.fromUTF8? =<< mes) (some "hi mike!! :)") - await (client.send (String.toUTF8 "hello robert!!")) - await (client.shutdown) + await (← client.send (String.toUTF8 "hello robert!!")) + await (← client.shutdown) /-- Joe is another client. -/ def runJoe (client: TCP.Socket.Client) : Async Unit := do - let mes ← await (client.recv? 1024) + let mes ← await (← client.recv? 1024) assertBEq (String.fromUTF8? =<< mes) (some "hi joe! :)") - await (client.send (String.toUTF8 "hello robert!")) - await client.shutdown + await (← client.send (String.toUTF8 "hello robert!")) + await (← client.shutdown) /-- Robert is the server. -/ def runRobert (server: TCP.Socket.Server) : Async Unit := do - let joe ← await server.accept - let mike ← await server.accept + let joe ← await (← server.accept) + let mike ← await (← server.accept) - await (joe.send (String.toUTF8 "hi joe! :)")) - let mes ← await (joe.recv? 1024) + await (← joe.send (String.toUTF8 "hi joe! :)")) + let mes ← await (← joe.recv? 1024) assertBEq (String.fromUTF8? =<< mes) (some "hello robert!") - await (mike.send (String.toUTF8 "hi mike!! :)")) - let mes ← await (mike.recv? 1024) + await (← mike.send (String.toUTF8 "hi mike!! :)")) + let mes ← await (← mike.recv? 1024) assertBEq (String.fromUTF8? =<< mes) (some "hello robert!!") pure () @@ -69,7 +49,7 @@ def clientServer (addr : SocketAddress) : IO Unit := do let serverTask := runRobert server - let serverTask ← serverTask.run + let serverTask ← serverTask.toIO assertBEq (← server.getSockName).port addr.port @@ -92,8 +72,8 @@ def clientServer (addr : SocketAddress) : IO Unit := do let joeTask := runJoe joe let mikeTask := runMike mike - let joeTask ← joeTask.run - let mikeTask ← mikeTask.run + let joeTask ← joeTask.toIO + let mikeTask ← mikeTask.toIO serverTask.block joeTask.block diff --git a/tests/lean/run/async_udp_sockets.lean b/tests/lean/run/async_udp_sockets.lean index 0582f88feb..a7bd64b5a1 100644 --- a/tests/lean/run/async_udp_sockets.lean +++ b/tests/lean/run/async_udp_sockets.lean @@ -11,26 +11,6 @@ def assertBEq [BEq α] [ToString α] (actual expected : α) : IO Unit := do throw <| IO.userError <| s!"expected '{expected}', got '{actual}'" --- Define the Async monad -structure Async (α : Type) where - run : IO (AsyncTask α) - -namespace Async - --- Monad instance for Async -instance : Monad Async where - pure x := Async.mk (pure (AsyncTask.pure x)) - bind ma f := Async.mk do - let task ← ma.run - task.bindIO fun a => (f a).run - --- Await function to simplify AsyncTask handling -def await (task : IO (AsyncTask α)) : Async α := - Async.mk task - -instance : MonadLift IO Async where - monadLift io := Async.mk (io >>= (pure ∘ AsyncTask.pure)) - /-- Joe is another client. -/ def runJoe (addr : UInt16 → SocketAddress) (first second : UInt16) : Async Unit := do let client ← UDP.Socket.mk @@ -38,7 +18,7 @@ def runJoe (addr : UInt16 → SocketAddress) (first second : UInt16) : Async Uni client.bind (addr second) client.connect (addr first) - await (client.send (String.toUTF8 "hello robert!")) + await (← client.send (String.toUTF8 "hello robert!")) def acceptClose (addr : UInt16 → SocketAddress) (first second : UInt16) : IO Unit := do @@ -46,7 +26,7 @@ def acceptClose (addr : UInt16 → SocketAddress) (first second : UInt16) : IO U let server ← UDP.Socket.mk server.bind (addr first) - let res ← (runJoe addr first second).run + let res ← (runJoe addr first second).toIO res.block let res ← server.recv 1024