From b15cfadde8a5263319c88d1262955224a074f5f3 Mon Sep 17 00:00:00 2001 From: Sofia Rodrigues Date: Wed, 25 Jun 2025 23:51:26 -0300 Subject: [PATCH] feat: monadic interface for asynchronous operations in Std (#8003) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds a new monadic interface for `Async` operations. This is the design for the `Async` monad that I liked the most. The idea was refined with the help of @tydeu. Before that, I had some prerequisites in mind: 1. Good performance 2. Explicit `yield` points, so we could avoid using `bindTask` for every lifted IO operation 3. A way to avoid creating an infinite chain of `Task`s during recursion The 2 and 3 points are not covered in this PR, I wish I had a good solution but right now only a few sketches of this. ### Explicit `yield` points I thought this would be easy at first, but it actually turned out kinda tricky. I ended up creating the `suspend` syntax, which is just a small modification of the lift method (`<- ...`) syntax. It desugars to `Suspend.suspend task fun _ => ...`. So something like: ```lean do IO.println "a" IO.println "b" let result := suspend (client.recv? 1024) IO.println "c" IO.println "d" ``` Would become: ```lean Bind.bind (IO.println "a") fun _ => Bind.bind (IO.println "b") fun _ => Suspend.suspend (client.recv? 1024) fun message => Bind.bind (IO.println "c") fun _ => IO.println "d" ``` This makes things a bit more efficient. When using `bind`, we would try to avoid creating a `Task` chain, and the `suspend` would be the only place we use `Task.bind`. But there's a problem if we use `bind` with something that needs `suspend`, it’ll block the whole task. Blocking is the only way to prevent task accumulation when using plain `bind` inside a structure like that: ``` inductive AsyncResult (ε σ α : Type u) where | ok : α → σ → AsyncResult ε σ α | error : ε → σ → AsyncResult ε σ α | ofTask : Task (EStateM.Result ε σ α) → σ →AsyncResult ε σ α ``` Because we simply need to remove the `ofTask` and transform it into an `ok`. ### Infinite chain of Tasks If you create an infinite recursive function using `Task` (which is super common in servers like HTTP ones), it can lead to a lot of memory usage. Because those tasks get chained forever and won't be freed until the function returns. To get around that, I used CPS and instead of just calling `Task.bind`, I’d spawn a new task and return an "empty" one like: ```lean fun k => Task.bind (...) fun value => do k value; pure emptyTask ``` This works great with a CPS-style monad, but it generates a huge IR by itself. Just doing CPS alone was too much, though, because every lifted operation created a new continuation and a `Task.bind`. So, I used it with `suspend` and got a better performance, but the usage is not good with `suspend`. ### The current monad Right now, the monad I’m using is super simple. It doesn't solve the earlier problems, but the API is clean, and the generated IR is small enough. An example of how we should use it is: ```lean -- A loop that repeatedly sends a message and waits for a reply. partial def writeLoop (client : Socket.Client) (message : String) : Async (AsyncTask Unit) := async do IO.println s!"sending: {message}" await (← client.send (String.toUTF8 message)) if let some mes ← await (← client.recv? 1024) then IO.println s!"received: {String.fromUTF8! mes}" -- use parallel to avoid building up an infinite task chain parallel (writeLoop client message) else IO.println "client disconnected from receiving" -- Server’s main accept loop, keeps accepting and echoing for new clients. partial def acceptLoop (server : Socket.Server) (promise : IO.Promise Unit) : Async (AsyncTask Unit) := async do let client ← await (← server.accept) await (← client.send (String.toUTF8 "tutturu ")) -- allow multiple clients to connect at the same time parallel (writeLoop client "hi!!") -- and keep accepting more clients, parallel again to avoid building up an infinite task chain parallel (acceptLoop server promise) -- A simple client that connects and sends a message. def echoClient (addr : SocketAddress) (message : String) : Async (AsyncTask Unit) := async do let socket ← Client.mk await (← socket.connect addr) parallel (writeLoop socket message) -- TCP setup: bind, listen, serve, and run a sample client. partial def mainTCP : Async Unit := do let addr := SocketAddressV4.mk (.ofParts 127 0 0 1) 8080 let server ← Server.mk server.bind addr server.listen 128 -- promise exists since the server is (probably) never going to stop let promise ← IO.Promise.new let acceptAction ← acceptLoop server promise await (← echoClient addr "hi!") await acceptAction await promise -- Entry point def main : IO Unit := mainTCP.wait ``` --------- Co-authored-by: Henrik Böving Co-authored-by: Mac Malone --- src/Std/Internal/Async/Basic.lean | 741 +++++++++++++++++++- tests/lean/run/async.lean | 31 + tests/lean/run/async_base_functions.lean | 54 ++ tests/lean/run/async_tcp_fname_errors.lean | 43 +- tests/lean/run/async_tcp_half.lean | 31 +- tests/lean/run/async_tcp_server_client.lean | 52 +- tests/lean/run/async_udp_sockets.lean | 24 +- 7 files changed, 831 insertions(+), 145 deletions(-) create mode 100644 tests/lean/run/async.lean create mode 100644 tests/lean/run/async_base_functions.lean diff --git a/src/Std/Internal/Async/Basic.lean b/src/Std/Internal/Async/Basic.lean index 2ea9fdf26d..a38e09d3c8 100644 --- a/src/Std/Internal/Async/Basic.lean +++ b/src/Std/Internal/Async/Basic.lean @@ -1,7 +1,7 @@ /- Copyright (c) 2024 Lean FRO, LLC. All rights reserved. Released under Apache 2.0 license as described in the file LICENSE. -Authors: Henrik Böving +Authors: Henrik Böving, Sofia Rodrigues, Mac Malone -/ prelude import Init.Core @@ -13,21 +13,230 @@ namespace Internal namespace IO namespace Async -/-- -A `Task` that may resolve to a value or an `IO.Error`. +/-! + +# Asynchronous Programming Primitives + +This module provides a layered approach to asynchronous programming, combining monadic types, +type classes, and concrete task types that work together in a cohesive system. + +- **Monadic Types**: These types provide a good way to to chain and manipulate context. These + can contain a `Task`, enabling manipulation of both asynchronous and synchronous code. +- **Concrete Task Types**: Concrete units of work that can be executed within these contexts. + +## Monadic Types + +These types provide a good way to to chain and manipulate context. These can contain a `Task`, +enabling manipulation of both asynchronous and synchronous code. + +- `BaseAsync`: A monadic type for infallible asynchronous computations +- `EAsync`: A monadic type for asynchronous computations that may fail with an error of type + `ε` +- `Async`: A monadic type for IO-based asynchronous computations that may fail with `IO.Error` + (alias for `EAsync IO.Error`) + +## Concurrent Units of Work + +These are the concrete computational units that exist within the monadic contexts. These types +should not be created directly. + +- `Task`: A computation that will resolve to a value of type `α`, +- `ETask`: A task that may fail with an error of type `ε`. +- `AsyncTask`: A task that may fail with an `IO.Error` (alias for `ETask IO.Error`). + +## Relation + +These types are related by two functions in the type classes `MonadAsync` and `MonadAwait`: `async` +and `await`. The `async` function extracts a concrete asynchronous task from a computation within the +monadic context. In effect, it runs the computation in the background and returns a task handle that +can be awaited later. On the other hand, the `await` function takes a task and re-inserts it into the +monadic context, allowing its result to be composed using monadic bind and also pausing to wait for that result. +This relationship between `async` and `await` enables precise control over when a computation begins +and when its result is used. You can spawn multiple asynchronous tasks using `async`, perform other +operations in the meantime, and later rejoin the computation flow by awaiting their results. + +These functions should not be used directly. Instead, prefer higher-level combinators such as +`race`, `raceAll`, `concurrently`, `background` and `concurrentlyAll`. The best way to think about +how to write your async code, it to avoid using the concurrent units of work, and only use it when integrating +with non async code that uses them. + -/ -def AsyncTask (α : Type u) : Type u := Task (Except IO.Error α) + +/-- +Typeclass for monads that can "await" a computation of type `t α` in a monad `m` until the result is +available. +-/ +class MonadAwait (t : Type → Type) (m : Type → Type) extends Monad m where + /-- + Awaits the result of `t α` and returns it inside the `m` monad. + -/ + await : t α → m α + +/-- +Represents monads that can launch computations asynchronously of type `t` in a monad `m`. +-/ +class MonadAsync (t : Type → Type) (m : Type → Type) extends Monad m where + /-- + Starts an asynchronous computation in another monad. + -/ + async (x : m α) (prio := Task.Priority.default) : m (t α) + +/- +These instances have the default_instance attribute so that other default instances +can function correctly within monad transformers. +-/ + +@[default_instance] +instance [Monad m] [MonadAwait t m] : MonadAwait t (StateT n m) where + await := liftM (m := m) ∘ MonadAwait.await + +@[default_instance] +instance [Monad m] [MonadAwait t m] : MonadAwait t (ExceptT n m) where + await := liftM (m := m) ∘ MonadAwait.await + +@[default_instance] +instance [Monad m] [MonadAwait t m] : MonadAwait t (ReaderT n m) where + await := liftM (m := m) ∘ MonadAwait.await + +@[default_instance] +instance [MonadAwait t m] : MonadAwait t (StateRefT' s n m) where + await := liftM (m := m) ∘ MonadAwait.await + +@[default_instance] +instance [MonadAwait t m] : MonadAwait t (StateT s m) where + await := liftM (m := m) ∘ MonadAwait.await + +@[default_instance] +instance [MonadAsync t m] : MonadAsync t (ReaderT n m) where + async p prio := MonadAsync.async (prio := prio) ∘ p + +@[default_instance] +instance [MonadAsync t m] : MonadAsync t (StateRefT' s n m) where + async p prio := MonadAsync.async (prio := prio) ∘ p + +@[default_instance] +instance [Functor t] [inst : MonadAsync t m] : MonadAsync t (StateT s m) where + async p prio := fun s => do + let t ← inst.async (prio := prio) (p s) + pure (t <&> Prod.fst, s) + +/-- +A `Task` that may resolve to either a value of type `α` or an error value of type `ε`. +-/ +abbrev ETask (ε : Type) (α : Type) : Type := ExceptT ε Task α + +namespace ETask + +/-- +Construct an `ETask` that is already resolved with value `x`. +-/ +@[inline] +protected def pure (x : α) : ETask ε α := + Task.pure <| .ok x + +/-- +Creates a new `ETask` that will run after `x` has finished. If `x`: +- errors, return an `ETask` that resolves to the error. +- succeeds, return an `ETask` that resolves to `f x`. +-/ +@[inline] +protected def map (f : α → β) (x : ETask ε α) (prio := Task.Priority.default) (sync := false) : ETask ε β := + Task.map (x := x) (f <$> ·) prio sync + +/-- +Creates a new `ETask` that will run after `x` has completed. If `x`: +- errors, return an `ETask` that resolves to the error. +- succeeds, run `f` on the result of `x` and return the `ETask` produced by `f`. +-/ +@[inline] +protected def bind (x : ETask ε α) (f : α → ETask ε β) (prio := Task.Priority.default) (sync := false) : ETask ε β := + Task.bind x (prio := prio) (sync := sync) fun + | .ok a => f a + | .error e => Task.pure <| .error e + +/-- +Similar to `bind`, however `f` has access to the `EIO` monad. If `f` throws an error, the returned +`ETask` resolves to that error. +-/ +@[inline] +protected def bindEIO (x : ETask ε α) (f : α → EIO ε (ETask ε β)) (prio := Task.Priority.default) (sync := false) : EIO ε (ETask ε β) := + EIO.bindTask x (prio := prio) (sync := sync) fun + | .ok a => f a + | .error e => .error e + +/-- +Similar to `bind`, however `f` has access to the `EIO` monad. If `f` throws an error, the returned +`ETask` resolves to that error. +-/ +@[inline] +protected def mapEIO (f : α → EIO ε β) (x : ETask ε α) (prio := Task.Priority.default) (sync := false) : BaseIO (ETask ε β) := + EIO.mapTask (t := x) (prio := prio) (sync := sync) fun + | .ok a => f a + | .error e => .error e + +/-- +Block until the `ETask` in `x` finishes and returns its value. Propagates any error encountered +during execution. +-/ +@[inline] +def block (x : ETask ε α) : EIO ε α := do + match x.get with + | .ok a => return a + | .error e => .error e + +/-- +Create an `ETask` that resolves to the value of the promise `x`. +-/ +@[inline] +def ofPromise (x : IO.Promise (Except ε α)) : ETask ε α := + x.result! + +/-- +Create an `ETask` that resolves to the pure value of the promise `x`. +-/ +@[inline] +def ofPurePromise (x : IO.Promise α) : ETask ε α := + x.result!.map pure (sync := true) + +/-- +Obtain the `IO.TaskState` of `x`. +-/ +@[inline] +def getState (x : ETask ε α) : BaseIO IO.TaskState := + IO.getTaskState x + +instance : Functor (ETask ε) where + map := ETask.map + +instance : Monad (ETask ε) where + pure := ETask.pure + bind := ETask.bind + +end ETask + +/-- +A `Task` that may resolve to a value or an error value of type `IO.Error`. Alias for `ETask IO.Error`. +-/ +abbrev AsyncTask := ETask IO.Error namespace AsyncTask +/-- +Similar to `map`, however `f` has access to the `IO` monad. If `f` throws an error, the returned +`AsyncTask` resolves to that error. +-/ +@[inline] +protected def mapIO (f : α → IO β) (x : AsyncTask α) (prio := Task.Priority.default) (sync := false) : BaseIO (AsyncTask β) := + EIO.mapTask (t := x) (prio := prio) (sync := sync) fun + | .ok a => f a + | .error e => .error e + /-- Construct an `AsyncTask` that is already resolved with value `x`. -/ @[inline] -protected def pure (x : α) : AsyncTask α := Task.pure <| .ok x - -instance : Pure AsyncTask where - pure := AsyncTask.pure +protected def pure (x : α) : AsyncTask α := + Task.pure <| .ok x /-- Create a new `AsyncTask` that will run after `x` has finished. @@ -36,9 +245,8 @@ If `x`: - succeeds, run `f` on the result of `x` and return the `AsyncTask` produced by `f`. -/ @[inline] -protected def bind (x : AsyncTask α) (f : α → AsyncTask β) : AsyncTask β := - Task.bind x fun r => - match r with +protected def bind (x : AsyncTask α) (f : α → AsyncTask β) (prio := Task.Priority.default) (sync := false) : AsyncTask β := + Task.bind x (prio := prio) (sync := sync) fun | .ok a => f a | .error e => Task.pure <| .error e @@ -49,40 +257,34 @@ If `x`: - succeeds, return an `AsyncTask` that resolves to `f x`. -/ @[inline] -def map (f : α → β) (x : AsyncTask α) : AsyncTask β := - Task.map (x := x) fun r => - match r with - | .ok a => .ok (f a) - | .error e => .error e +def map (f : α → β) (x : AsyncTask α) (prio := Task.Priority.default) (sync := false) : AsyncTask β := + Task.map (x := x) (f <$> ·) prio sync /-- Similar to `bind`, however `f` has access to the `IO` monad. If `f` throws an error, the returned `AsyncTask` resolves to that error. -/ @[inline] -def bindIO (x : AsyncTask α) (f : α → IO (AsyncTask β)) : BaseIO (AsyncTask β) := - IO.bindTask x fun r => - match r with +def bindIO (x : AsyncTask α) (f : α → IO (AsyncTask β)) (prio := Task.Priority.default) (sync := false) : BaseIO (AsyncTask β) := + IO.bindTask x (prio := prio) (sync := sync) fun | .ok a => f a | .error e => .error e /-- -Similar to `bind`, however `f` has access to the `IO` monad. If `f` throws an error, the returned +Similar to `map`, however `f` has access to the `IO` monad. If `f` throws an error, the returned `AsyncTask` resolves to that error. -/ @[inline] -def mapIO (f : α → IO β) (x : AsyncTask α) : BaseIO (AsyncTask β) := - IO.mapTask (t := x) fun r => - match r with +def mapTaskIO (f : α → IO β) (x : AsyncTask α) (prio := Task.Priority.default) (sync := false) : BaseIO (AsyncTask β) := + IO.mapTask (t := x) (prio := prio) (sync := sync) fun | .ok a => f a | .error e => .error e /-- Block until the `AsyncTask` in `x` finishes. -/ -def block (x : AsyncTask α) : IO α := do - let res := x.get - match res with +def block (x : AsyncTask α) : IO α := + match x.get with | .ok a => return a | .error e => .error e @@ -98,7 +300,7 @@ Create an `AsyncTask` that resolves to the value of `x`. -/ @[inline] def ofPurePromise (x : IO.Promise α) : AsyncTask α := - x.result!.map pure + x.result!.map pure (sync := true) /-- Obtain the `IO.TaskState` of `x`. @@ -109,6 +311,491 @@ def getState (x : AsyncTask α) : BaseIO IO.TaskState := end AsyncTask +/-- +A `MaybeTask α` represents a computation that either: + +- Is immediately available as an `α` value, or +- Is an asynchronous computation that will eventually produce an `α` value. +-/ +inductive MaybeTask (α : Type) + | pure : α → MaybeTask α + | ofTask : Task α → MaybeTask α + +namespace MaybeTask + +/-- +Constructs an `Task` from a `MaybeTask`. +-/ +@[inline] +def toTask : MaybeTask α → Task α + | .pure a => .pure a + | .ofTask t => t + +/-- +Gets the value of the `MaybeTask` by blocking. +-/ +@[inline] +def get {α : Type} : MaybeTask α → α + | .pure a => a + | .ofTask t => t.get + +/-- +Maps a function over a `MaybeTask`. +-/ +@[inline] +def map (f : α → β) (prio := Task.Priority.default) (sync := false) : MaybeTask α → MaybeTask β + | .pure a => .pure <| f a + | .ofTask t => .ofTask <| t.map f prio sync + +/-- +Sequences two computations, allowing the second to depend on the value computed by the first. +-/ +@[inline] +protected def bind (t : MaybeTask α) (f : α → MaybeTask β) (prio := Task.Priority.default) (sync := false) : MaybeTask β := + match t with + | .pure a => f a + | .ofTask t => .ofTask <| t.bind (f · |>.toTask) prio sync + +/-- +Join the `MaybeTask` to an `Task`. +-/ +@[inline] +def joinTask (t : Task (MaybeTask α)) : Task α := + t.bind (sync := true) fun + | .pure a => .pure a + | .ofTask t => t + +instance : Functor (MaybeTask) where + map := MaybeTask.map + +instance : Monad (MaybeTask) where + pure := MaybeTask.pure + bind := MaybeTask.bind + +end MaybeTask + +/-- +An asynchronous computation that never fails. +-/ +def BaseAsync (α : Type) := BaseIO (MaybeTask α) + +namespace BaseAsync + +/-- +Converts a `BaseIO` into a `BaseAsync` +-/ +@[inline] +def mk (x : BaseIO (MaybeTask α)) : BaseAsync α := + x + +/-- +Converts a `BaseAsync` into a `BaseIO` +-/ +@[inline] +def toRawBaseIO (x : BaseAsync α) : BaseIO (MaybeTask α) := + x + +/-- +Converts a `BaseAsync` to a `BaseIO Task`. +-/ +@[inline] +protected def toBaseIO (x : BaseAsync α) : BaseIO (Task α) := + MaybeTask.toTask <$> x.toRawBaseIO + +/-- +Creates a new `BaseAsync` out of a `Task`. +-/ +@[inline] +protected def ofTask (x : Task α) : BaseAsync α := + .mk <| pure <| MaybeTask.ofTask x + +/-- +Creates a `BaseAsync` computation that immediately returns the given value. +-/ +@[inline] +protected def pure (a : α) : BaseAsync α := + .mk <| pure <| .pure a + +/-- +Maps the result of a `BaseAsync` computation with a function. +-/ +@[inline] +protected def map (f : α → β) (self : BaseAsync α) (prio := Task.Priority.default) (sync := false) : BaseAsync β := + mk <| (·.map f prio sync) <$> self.toRawBaseIO + +/-- +Sequences two computations, allowing the second to depend on the value computed by the first. +-/ +@[inline] +protected def bind (self : BaseAsync α) (f : α → BaseAsync β) (prio := Task.Priority.default) (sync := false) : BaseAsync β := + mk <| self.toRawBaseIO >>= (bindAsyncTask · f |>.toRawBaseIO) +where + bindAsyncTask (t : MaybeTask α) (f : α → BaseAsync β) : BaseAsync β := .mk <| + match t with + | .pure a => (f a) |>.toRawBaseIO + | .ofTask t => .ofTask <$> BaseIO.bindTask t (fun a => MaybeTask.toTask <$> (f a |>.toRawBaseIO)) prio sync + +/-- +Lifts a `BaseIO` action into a `BaseAsync` computation. +-/ +@[inline] +protected def lift (x : BaseIO α) : BaseAsync α := + .mk <| (.pure ∘ .pure) =<< x + +/-- +Waits for the result of the `BaseAsync` computation, blocking if necessary. +-/ +@[inline] +protected def wait (self : BaseAsync α) : BaseIO α := + pure ∘ Task.get =<< self.toBaseIO + +/-- +Lifts a `BaseAsync` computation into a `Task` that can be awaited and joined. +-/ +@[inline] +protected def asTask (x : BaseAsync α) (prio := Task.Priority.default) : BaseIO (Task α) := do + let res ← BaseIO.asTask (prio := prio) x.toRawBaseIO + return MaybeTask.joinTask res + +/-- +Creates a `BaseAsync` that awaits the completion of the given `Task α`. +-/ +@[inline] +def await (t : Task α) : BaseAsync α := + .mk <| pure <| MaybeTask.ofTask t + +/-- +Returns the `BaseAsync` computation inside a `Task α`, so it can be awaited. +-/ +@[inline] +def async (self : BaseAsync α) (prio := Task.Priority.default) : BaseAsync (Task α) := + BaseAsync.lift <| self.asTask (prio := prio) + +instance : Functor BaseAsync where + map := BaseAsync.map + +instance : Monad BaseAsync where + pure := BaseAsync.pure + bind := BaseAsync.bind + +instance : MonadLift BaseIO BaseAsync where + monadLift := BaseAsync.lift + +instance : MonadAwait Task BaseAsync where + await := BaseAsync.await + +instance : MonadAsync Task BaseAsync where + async t prio := BaseAsync.async t prio + +instance [Inhabited α] : Inhabited (BaseAsync α) where + default := .mk <| pure (MaybeTask.pure default) + +end BaseAsync + +/-- +An asynchronous computation that may produce an error of type `ε`. +-/ +def EAsync (ε : Type) (α : Type) := BaseAsync (Except ε α) + +namespace EAsync + +/-- +Converts a `EAsync` to a `ETask`. +-/ +@[inline] +protected def toBaseIO (x : EAsync ε α) : BaseIO (ETask ε α) := + MaybeTask.toTask <$> x.toRawBaseIO + +/-- +Creates a new `EAsync` out of a `RTask`. +-/ +@[inline] +protected def ofTask (x : ETask ε α) : EAsync ε α := + .mk <| pure <| MaybeTask.ofTask x + +/-- +Converts a `BaseAsync` to a `EIO ETask`. +-/ +@[inline] +protected def toEIO (x : EAsync ε α) : EIO ε (ETask ε α) := + MaybeTask.toTask <$> x.toRawBaseIO + +/-- +Creates a new `EAsync` out of a `ETask`. +-/ +@[inline] +protected def ofETask (x : ETask ε α) : EAsync ε α := + .mk <| BaseAsync.ofTask x + +/-- +Creates an `EAsync` computation that immediately returns the given value. +-/ +@[inline] +protected def pure (a : α) : EAsync ε α := + .mk <| BaseAsync.pure <| .ok a + +/-- +Maps the result of an `EAsync` computation with a pure function. +-/ +@[inline] +protected def map (f : α → β) (self : EAsync ε α) : EAsync ε β := + .mk <| BaseAsync.map (.map f) self + +/-- +Sequences two computations, allowing the second to depend on the value computed by the first. +-/ +@[inline] +protected def bind (self : EAsync ε α) (f : α → EAsync ε β) : EAsync ε β := + .mk <| BaseAsync.bind self fun + | .ok a => f a + | .error e => BaseAsync.pure (.error e) + +/-- +Lifts an `EIO` action into an `EAsync` computation. +-/ +@[inline] +protected def lift (x : EIO ε α) : EAsync ε α := + .mk <| BaseAsync.lift x.toBaseIO + +/-- +Waits for the result of the `EAsync` computation, blocking if necessary. +-/ +@[inline] +protected def wait (self : EAsync ε α) : EIO ε α := do + let result ← self |> BaseAsync.wait + match result with + | .ok a => return a + | .error e => .error e + +/-- +Lifts an `EAsync` computation into an `ETask` that can be awaited and joined. +-/ +@[inline] +protected def asTask (x : EAsync ε α) (prio := Task.Priority.default) : EIO ε (ETask ε α) := + x |> BaseAsync.asTask (prio := prio) + +/-- +Raises an error of type `ε` within the `EAsync` monad. +-/ +@[inline] +protected def throw (e : ε) : EAsync ε α := + .mk <| BaseAsync.pure (.error e) + +/-- +Handles errors in an `EAsync` computation by running a handler if one occurs. +-/ +@[inline] +protected def tryCatch (x : EAsync ε α) (f : ε → EAsync ε α) (prio := Task.Priority.default) (sync := false) : EAsync ε α := + .mk <| BaseAsync.bind (sync := sync) (prio := prio) x fun + | .ok a => BaseAsync.pure (.ok a) + | .error e => (f e) + +/-- +Runs an action, ensuring that some other action always happens afterward. +-/ +protected def tryFinally' + (x : EAsync ε α) (f : Option α → EAsync ε β) + (prio := Task.Priority.default) (sync := false) : + EAsync ε (α × β) := + .mk <| BaseAsync.bind x (prio := prio) (sync := sync) fun + | .ok a => do + match ← (f (some a)) with + | .ok b => BaseAsync.pure (.ok (a, b)) + | .error e => BaseAsync.pure (.error e) + | .error e => do + match ← (f none) with + | .ok _ => BaseAsync.pure (.error e) + | .error e' => BaseAsync.pure (.error e') + +/-- +Creates an `EAsync` computation that awaits the completion of the given `ETask ε α`. +-/ +@[inline] +def await (x : ETask ε α) : EAsync ε α := + .mk (BaseAsync.ofTask x) + +/-- +Returns the `EAsync` computation inside an `ETask ε α`, so it can be awaited. +-/ +@[inline] +def async (self : EAsync ε α) (prio := Task.Priority.default) : EAsync ε (ETask ε α) := + EAsync.lift <| self.asTask prio + +instance : Functor (EAsync ε) where + map := EAsync.map + +instance : Monad (EAsync ε) where + pure := EAsync.pure + bind := EAsync.bind + +instance : MonadLift (EIO ε) (EAsync ε) where + monadLift := EAsync.lift + +instance : MonadExcept ε (EAsync ε) where + throw := EAsync.throw + tryCatch := EAsync.tryCatch + +instance : MonadExceptOf ε (EAsync ε) where + throw := EAsync.throw + tryCatch := EAsync.tryCatch + +instance : MonadFinally (EAsync ε) where + tryFinally' := EAsync.tryFinally' + +instance : OrElse (EAsync ε α) where + orElse := MonadExcept.orElse + +instance [Inhabited ε] : Inhabited (EAsync ε α) where + default := .mk <| BaseAsync.pure default + +instance : MonadAwait (ETask ε) (EAsync ε) where + await t := .mk <| BaseAsync.ofTask t + +instance : MonadAwait Task (EAsync ε) where + await t := .mk <| BaseAsync.ofTask (t.map (.ok)) + +instance : MonadAwait AsyncTask (EAsync IO.Error) where + await t := .mk <| BaseAsync.ofTask t + +instance : MonadAwait IO.Promise (EAsync ε) where + await t := .mk <| BaseAsync.ofTask (t.result!.map (.ok)) + +instance : MonadAsync (ETask ε) (EAsync ε) where + async t prio := EAsync.lift <| t.asTask (prio := prio) + +instance : MonadAsync AsyncTask (EAsync IO.Error) where + async t prio := EAsync.lift <| t.asTask (prio := prio) + +instance : MonadLift BaseIO (EAsync ε) where + monadLift x := .mk <| (pure ∘ .ok) <$> x + +instance : MonadLift (EIO ε) (EAsync ε) where + monadLift x := .mk <| pure <$> x.toBaseIO + +instance : MonadLift BaseAsync (EAsync ε) where + monadLift x := .mk <| x.map (.ok) + +@[inline] +protected partial def forIn + {β : Type} [i : Inhabited ε] (init : β) + (f : Unit → β → EAsync ε (ForInStep β)) + (prio := Task.Priority.default) : + EAsync ε β := do + let promise ← IO.Promise.new + + let rec @[specialize] loop (b : β) : EAsync ε (ETask ε Unit) := async (prio := prio) do + match ← f () b with + | ForInStep.done b => promise.resolve (.ok b) + | ForInStep.yield b => discard <| (loop b) + + discard <| loop init + + .mk <| BaseAsync.ofTask promise.result! + +instance [Inhabited ε] : ForIn (EAsync ε) Lean.Loop Unit where + forIn _ := EAsync.forIn + +end EAsync + +/-- +An asynchronous computation that may produce an error of type `IO.Error`.. +-/ +abbrev Async (α : Type) := EAsync IO.Error α + +namespace Async + +/-- +Converts a `Async` to a `AsyncTask`. +-/ +@[inline] +protected def toIO (x : Async α) : IO (AsyncTask α) := + MaybeTask.toTask <$> x.toRawBaseIO + +@[default_instance] +instance : MonadAsync AsyncTask Async := + inferInstanceAs (MonadAsync (ETask IO.Error) (EAsync IO.Error)) + +instance : MonadAwait AsyncTask Async := + inferInstanceAs (MonadAwait AsyncTask (EAsync IO.Error)) + +instance : MonadAwait IO.Promise Async := + inferInstanceAs (MonadAwait IO.Promise (EAsync IO.Error)) + +end Async + +export MonadAsync (async) +export MonadAwait (await) + +/-- +This function transforms the operation inside the monad `m` into a task and let it run in the background. +-/ +@[inline, specialize] +def background [Monad m] [MonadAsync t m] (prio := Task.Priority.default) : m α → m Unit := + discard ∘ async (t := t) (prio := prio) + +/-- +Runs two computations concurrently and returns both results as a pair. +-/ +@[inline, specialize] +def concurrently + [Monad m] [MonadAwait t m] [MonadAsync t m] + (x : m α) (y : m β) + (prio := Task.Priority.default) : + m (α × β) := do + let taskX : t α ← async x (prio := prio) + let taskY : t β ← async y (prio := prio) + let resultX ← await taskX + let resultY ← await taskY + return (resultX, resultY) + +/-- +Runs two computations concurrently and returns the result of the one that finishes first. +The other result is lost and the other task is not cancelled, so the task will continue the execution +until the end. +-/ +@[inline, specialize] +def race + [MonadLiftT BaseIO m] [MonadAwait Task m] [MonadAsync t m] [MonadAwait t m] + [Monad m] [Inhabited α] (x : m α) (y : m α) + (prio := Task.Priority.default) : + m α := do + let promise ← IO.Promise.new + + discard (async (t := t) (prio := prio) <| Bind.bind x (liftM ∘ promise.resolve)) + discard (async (t := t) (prio := prio) <| Bind.bind y (liftM ∘ promise.resolve)) + + await promise.result! + +/-- +Runs all computations in an `Array` concurrently and returns all results as an array. +-/ +@[inline, specialize] +def concurrentlyAll + [Monad m] [MonadAwait t m] [MonadAsync t m] (xs : Array (m α)) + (prio := Task.Priority.default) : + m (Array α) := do + let tasks : Array (t α) ← xs.mapM (async (prio := prio)) + tasks.mapM await + +/-- +Runs all computations concurrently and returns the result of the first one to finish. +All other results are lost, and the tasks are not cancelled, so they'll continue their executing +until the end. +-/ +@[inline, specialize] +def raceAll + [ForM m c (m α)] [MonadLiftT BaseIO m] [MonadAwait Task m] + [MonadAsync t m] [MonadAwait t m] [Monad m] [Inhabited α] + (xs : c) + (prio := Task.Priority.default) : + m α := do + let promise ← IO.Promise.new + + ForM.forM xs fun x => + discard (async (t := t) (prio := prio) <| Bind.bind x (liftM ∘ promise.resolve)) + + await promise.result! + end Async end IO end Internal diff --git a/tests/lean/run/async.lean b/tests/lean/run/async.lean new file mode 100644 index 0000000000..331a866286 --- /dev/null +++ b/tests/lean/run/async.lean @@ -0,0 +1,31 @@ +import Std.Internal.Async +import Std.Internal.UV +import Std.Net.Addr + +open Std.Internal.IO.Async.UDP +open Std.Internal.IO.Async +open Std.Net + +def t : IO (Async Nat) := do + IO.println "a" + return do + IO.println "b" + return 2 + +-- Usage example of the ForIn instance + +def loop : Async Nat := do + let mut res := 0 + + while res < 10 do + res := res + 1 + + discard <| t + + pure res + +/-- +info: 10 +-/ +#guard_msgs in +#eval IO.println =<< ETask.block =<< loop.asTask diff --git a/tests/lean/run/async_base_functions.lean b/tests/lean/run/async_base_functions.lean new file mode 100644 index 0000000000..f3f6564000 --- /dev/null +++ b/tests/lean/run/async_base_functions.lean @@ -0,0 +1,54 @@ +import Std.Internal.Async +import Std.Sync.Mutex + +open Std + +open Std.Internal.IO.Async + +def wait (ms : Nat) (ref : Std.Mutex Nat) (val : Nat) : Async Unit := do + ref.atomically (·.modify (· * val)) + IO.sleep ms + ref.atomically (·.modify (· + val)) + +-- Tests + +def sequential : Async Unit := do + let ref ← Std.Mutex.new 0 + wait 200 ref 1 + wait 400 ref 2 + ref.atomically (·.modify (· * 10)) + assert! (← ref.atomically (·.get)) == 40 + +#eval do (← sequential.toEIO).block + +def conc : Async Unit := do + let ref ← Std.Mutex.new 0 + discard <| concurrently (wait 200 ref 1) (wait 400 ref 2) + ref.atomically (·.modify (· * 10)) + assert! (← ref.atomically (·.get)) == 30 + +#eval do (← conc.toEIO).block + +def racer : Async Unit := do + let ref ← Std.Mutex.new 0 + race (wait 200 ref 1) (wait 400 ref 2) + ref.atomically (·.modify (· * 10)) + assert! (← ref.atomically (·.get)) == 10 + +#eval do (← racer.toEIO).block + +def concAll : Async Unit := do + let ref ← Std.Mutex.new 0 + discard <| concurrentlyAll #[(wait 200 ref 1), (wait 400 ref 2)] + ref.atomically (·.modify (· * 10)) + assert! (← ref.atomically (·.get)) == 30 + +#eval do (← concAll.toEIO).block + +def racerAll : Async Unit := do + let ref ← Std.Mutex.new 0 + raceAll #[(wait 200 ref 1), (wait 400 ref 2)] + ref.atomically (·.modify (· * 10)) + assert! (← ref.atomically (·.get)) == 10 + +#eval do (← racerAll.toEIO).block diff --git a/tests/lean/run/async_tcp_fname_errors.lean b/tests/lean/run/async_tcp_fname_errors.lean index 9b345052a8..01c75b4f46 100644 --- a/tests/lean/run/async_tcp_fname_errors.lean +++ b/tests/lean/run/async_tcp_fname_errors.lean @@ -2,7 +2,7 @@ import Std.Internal.Async import Std.Internal.UV import Std.Net.Addr -open Std.Internal.IO.Async +open Std.Internal.IO Async open Std.Net def assertBEq [BEq α] [ToString α] (actual expected : α) : IO Unit := do @@ -10,43 +10,20 @@ def assertBEq [BEq α] [ToString α] (actual expected : α) : IO Unit := do throw <| IO.userError <| s!"expected '{expected}', got '{actual}'" - --- Define the Async monad -structure Async (α : Type) where - run : IO (AsyncTask α) - -namespace Async - --- Monad instance for Async -instance : Monad Async where - pure x := Async.mk (pure (AsyncTask.pure x)) - bind ma f := Async.mk do - let task ← ma.run - task.bindIO fun a => (f a).run - --- Await function to simplify AsyncTask handling -def await (task : IO (AsyncTask α)) : Async α := - Async.mk task - -instance : MonadLift IO Async where - monadLift io := Async.mk (io >>= (pure ∘ AsyncTask.pure)) - --------------------------------------------------------------- - /-- Mike is another client. -/ def runMike (client: TCP.Socket.Client) : Async Unit := do - let message ← await (client.recv? 1024) + let message ← await (← client.recv? 1024) assertBEq (String.fromUTF8? =<< message) none /-- Joe is another client. -/ def runJoe (client: TCP.Socket.Client) : Async Unit := do - let message ← await (client.recv? 1024) + let message ← await (← client.recv? 1024) assertBEq (String.fromUTF8? =<< message) none /-- Robert is the server. -/ def runRobert (server: TCP.Socket.Server) : Async Unit := do - discard <| await server.accept - discard <| await server.accept + discard <| await (← server.accept) + discard <| await (← server.accept) def clientServer : IO Unit := do let addr := SocketAddressV4.mk (.ofParts 127 0 0 1) 8083 @@ -73,16 +50,14 @@ def clientServer : IO Unit := do mike.noDelay - let serverTask ← (runRobert server).run + let serverTask ← (runRobert server).toIO - let joeTask ← (runJoe joe).run - let mikeTask ← (runMike mike).run + let joeTask ← (runJoe joe).toIO + let mikeTask ← (runMike mike).toIO serverTask.block joeTask.block mikeTask.block -end Async - -#eval Async.clientServer +#eval clientServer diff --git a/tests/lean/run/async_tcp_half.lean b/tests/lean/run/async_tcp_half.lean index ee91e825f3..c2fdf655e1 100644 --- a/tests/lean/run/async_tcp_half.lean +++ b/tests/lean/run/async_tcp_half.lean @@ -2,7 +2,7 @@ import Std.Internal.Async import Std.Internal.UV import Std.Net.Addr -open Std.Internal.IO.Async +open Std.Internal.IO Async open Std.Net -- Using this function to create IO Error. For some reason the assert! is not pausing the execution. @@ -11,33 +11,13 @@ def assertBEq [BEq α] [ToString α] (actual expected : α) : IO Unit := do throw <| IO.userError <| s!"expected '{expected}', got '{actual}'" --- Define the Async monad -structure Async (α : Type) where - run : IO (AsyncTask α) - -namespace Async - --- Monad instance for Async -instance : Monad Async where - pure x := Async.mk (pure (AsyncTask.pure x)) - bind ma f := Async.mk do - let task ← ma.run - task.bindIO fun a => (f a).run - --- Await function to simplify AsyncTask handling -def await (task : IO (AsyncTask α)) : Async α := - Async.mk task - -instance : MonadLift IO Async where - monadLift io := Async.mk (io >>= (pure ∘ AsyncTask.pure)) - /-- Joe is another client. -/ def runJoe (addr: SocketAddress) : Async Unit := do let client ← TCP.Socket.Client.mk - await (client.connect addr) - await (client.send (String.toUTF8 "hello robert!")) - await client.shutdown + await (← client.connect addr) + await (← client.send (String.toUTF8 "hello robert!")) + await (← client.shutdown) def listenClose : IO Unit := do let addr := SocketAddressV4.mk (.ofParts 127 0 0 1) 8080 @@ -53,7 +33,7 @@ def acceptClose : IO Unit := do server.bind addr server.listen 128 - let joeTask ← (runJoe addr).run + let joeTask ← (runJoe addr).toIO let task ← server.accept let client ← task.block @@ -71,5 +51,4 @@ def acceptClose : IO Unit := do -- Waiting to avoid errors from escaping. joeTask.block -#eval acceptClose #eval listenClose diff --git a/tests/lean/run/async_tcp_server_client.lean b/tests/lean/run/async_tcp_server_client.lean index c80abfbac3..48ef4b7149 100644 --- a/tests/lean/run/async_tcp_server_client.lean +++ b/tests/lean/run/async_tcp_server_client.lean @@ -2,7 +2,7 @@ import Std.Internal.Async import Std.Internal.UV import Std.Net.Addr -open Std.Internal.IO.Async +open Std.Internal.IO Async open Std.Net -- Using this function to create IO Error. For some reason the assert! is not pausing the execution. @@ -11,53 +11,33 @@ def assertBEq [BEq α] [ToString α] (actual expected : α) : IO Unit := do throw <| IO.userError <| s!"expected '{expected}', got '{actual}'" --- Define the Async monad -structure Async (α : Type) where - run : IO (AsyncTask α) - -namespace Async - --- Monad instance for Async -instance : Monad Async where - pure x := Async.mk (pure (AsyncTask.pure x)) - bind ma f := Async.mk do - let task ← ma.run - task.bindIO fun a => (f a).run - --- Await function to simplify AsyncTask handling -def await (task : IO (AsyncTask α)) : Async α := - Async.mk task - -instance : MonadLift IO Async where - monadLift io := Async.mk (io >>= (pure ∘ AsyncTask.pure)) - -------------------------------------------------------------- /-- Mike is another client. -/ def runMike (client: TCP.Socket.Client) : Async Unit := do - let mes ← await (client.recv? 1024) + let mes ← await (← client.recv? 1024) assertBEq (String.fromUTF8? =<< mes) (some "hi mike!! :)") - await (client.send (String.toUTF8 "hello robert!!")) - await (client.shutdown) + await (← client.send (String.toUTF8 "hello robert!!")) + await (← client.shutdown) /-- Joe is another client. -/ def runJoe (client: TCP.Socket.Client) : Async Unit := do - let mes ← await (client.recv? 1024) + let mes ← await (← client.recv? 1024) assertBEq (String.fromUTF8? =<< mes) (some "hi joe! :)") - await (client.send (String.toUTF8 "hello robert!")) - await client.shutdown + await (← client.send (String.toUTF8 "hello robert!")) + await (← client.shutdown) /-- Robert is the server. -/ def runRobert (server: TCP.Socket.Server) : Async Unit := do - let joe ← await server.accept - let mike ← await server.accept + let joe ← await (← server.accept) + let mike ← await (← server.accept) - await (joe.send (String.toUTF8 "hi joe! :)")) - let mes ← await (joe.recv? 1024) + await (← joe.send (String.toUTF8 "hi joe! :)")) + let mes ← await (← joe.recv? 1024) assertBEq (String.fromUTF8? =<< mes) (some "hello robert!") - await (mike.send (String.toUTF8 "hi mike!! :)")) - let mes ← await (mike.recv? 1024) + await (← mike.send (String.toUTF8 "hi mike!! :)")) + let mes ← await (← mike.recv? 1024) assertBEq (String.fromUTF8? =<< mes) (some "hello robert!!") pure () @@ -69,7 +49,7 @@ def clientServer (addr : SocketAddress) : IO Unit := do let serverTask := runRobert server - let serverTask ← serverTask.run + let serverTask ← serverTask.toIO assertBEq (← server.getSockName).port addr.port @@ -92,8 +72,8 @@ def clientServer (addr : SocketAddress) : IO Unit := do let joeTask := runJoe joe let mikeTask := runMike mike - let joeTask ← joeTask.run - let mikeTask ← mikeTask.run + let joeTask ← joeTask.toIO + let mikeTask ← mikeTask.toIO serverTask.block joeTask.block diff --git a/tests/lean/run/async_udp_sockets.lean b/tests/lean/run/async_udp_sockets.lean index 0582f88feb..a7bd64b5a1 100644 --- a/tests/lean/run/async_udp_sockets.lean +++ b/tests/lean/run/async_udp_sockets.lean @@ -11,26 +11,6 @@ def assertBEq [BEq α] [ToString α] (actual expected : α) : IO Unit := do throw <| IO.userError <| s!"expected '{expected}', got '{actual}'" --- Define the Async monad -structure Async (α : Type) where - run : IO (AsyncTask α) - -namespace Async - --- Monad instance for Async -instance : Monad Async where - pure x := Async.mk (pure (AsyncTask.pure x)) - bind ma f := Async.mk do - let task ← ma.run - task.bindIO fun a => (f a).run - --- Await function to simplify AsyncTask handling -def await (task : IO (AsyncTask α)) : Async α := - Async.mk task - -instance : MonadLift IO Async where - monadLift io := Async.mk (io >>= (pure ∘ AsyncTask.pure)) - /-- Joe is another client. -/ def runJoe (addr : UInt16 → SocketAddress) (first second : UInt16) : Async Unit := do let client ← UDP.Socket.mk @@ -38,7 +18,7 @@ def runJoe (addr : UInt16 → SocketAddress) (first second : UInt16) : Async Uni client.bind (addr second) client.connect (addr first) - await (client.send (String.toUTF8 "hello robert!")) + await (← client.send (String.toUTF8 "hello robert!")) def acceptClose (addr : UInt16 → SocketAddress) (first second : UInt16) : IO Unit := do @@ -46,7 +26,7 @@ def acceptClose (addr : UInt16 → SocketAddress) (first second : UInt16) : IO U let server ← UDP.Socket.mk server.bind (addr first) - let res ← (runJoe addr first second).run + let res ← (runJoe addr first second).toIO res.block let res ← server.recv 1024