feat: monadic interface for asynchronous operations in Std (#8003)

This PR adds a new monadic interface for `Async` operations.

This is the design for the `Async` monad that I liked the most. The idea
was refined with the help of @tydeu. Before that, I had some
prerequisites in mind:

1. Good performance
2. Explicit `yield` points, so we could avoid using `bindTask` for every
lifted IO operation
3. A way to avoid creating an infinite chain of `Task`s during recursion

The 2 and 3 points are not covered in this PR, I wish I had a good
solution but right now only a few sketches of this.

### Explicit `yield` points

I thought this would be easy at first, but it actually turned out kinda
tricky. I ended up creating the `suspend` syntax, which is just a small
modification of the lift method (`<- ...`) syntax. It desugars to
`Suspend.suspend task fun _ => ...`. So something like:

```lean
do
  IO.println "a"
  IO.println "b"
  let result := suspend (client.recv? 1024)
  IO.println "c"
  IO.println "d"
```

Would become:

```lean
Bind.bind (IO.println "a") fun _ =>
Bind.bind (IO.println "b") fun _ =>
Suspend.suspend (client.recv? 1024) fun message =>
  Bind.bind (IO.println "c") fun _ =>
  IO.println "d"
```

This makes things a bit more efficient. When using `bind`, we would try
to avoid creating a `Task` chain, and the `suspend` would be the only
place we use `Task.bind`. But there's a problem if we use `bind` with
something that needs `suspend`, it’ll block the whole task. Blocking is
the only way to prevent task accumulation when using plain `bind` inside
a structure like that:

```
inductive AsyncResult (ε σ α : Type u) where
    | ok    : α → σ → AsyncResult ε σ α
    | error : ε → σ → AsyncResult ε σ α
    | ofTask  : Task (EStateM.Result ε σ α) → σ →AsyncResult ε σ α
```

Because we simply need to remove the `ofTask` and transform it into an
`ok`.

### Infinite chain of Tasks

If you create an infinite recursive function using `Task` (which is
super common in servers like HTTP ones), it can lead to a lot of memory
usage. Because those tasks get chained forever and won't be freed until
the function returns.

To get around that, I used CPS and instead of just calling `Task.bind`,
I’d spawn a new task and return an "empty" one like:

```lean
fun k => Task.bind (...) fun value => do k value; pure emptyTask
```

This works great with a CPS-style monad, but it generates a huge IR by
itself.

Just doing CPS alone was too much, though, because every lifted
operation created a new continuation and a `Task.bind`. So, I used it
with `suspend` and got a better performance, but the usage is not good
with `suspend`.

### The current monad

Right now, the monad I’m using is super simple. It doesn't solve the
earlier problems, but the API is clean, and the generated IR is small
enough. An example of how we should use it is:

```lean
-- A loop that repeatedly sends a message and waits for a reply.
partial def writeLoop (client : Socket.Client) (message : String) : Async (AsyncTask Unit) := async do
  IO.println s!"sending: {message}"
  await (← client.send (String.toUTF8 message))

  if let some mes ← await (← client.recv? 1024) then
    IO.println s!"received: {String.fromUTF8! mes}"
    -- use parallel to avoid building up an infinite task chain
    parallel (writeLoop client message)
  else
    IO.println "client disconnected from receiving"

-- Server’s main accept loop, keeps accepting and echoing for new clients.
partial def acceptLoop (server : Socket.Server) (promise : IO.Promise Unit) : Async (AsyncTask Unit) := async do
  let client ← await (← server.accept)
  await (← client.send (String.toUTF8 "tutturu "))

  -- allow multiple clients to connect at the same time
  parallel (writeLoop client "hi!!")

  -- and keep accepting more clients, parallel again to avoid building up an infinite task chain
  parallel (acceptLoop server promise)

-- A simple client that connects and sends a message.
def echoClient (addr : SocketAddress) (message : String) : Async (AsyncTask Unit) := async do
  let socket ← Client.mk
  await (← socket.connect addr)
  parallel (writeLoop socket message)

-- TCP setup: bind, listen, serve, and run a sample client.
partial def mainTCP : Async Unit := do
  let addr := SocketAddressV4.mk (.ofParts 127 0 0 1) 8080

  let server ← Server.mk
  server.bind addr
  server.listen 128

  -- promise exists since the server is (probably) never going to stop
  let promise ← IO.Promise.new
  let acceptAction ← acceptLoop server promise

  await (← echoClient addr "hi!")
  await acceptAction
  await promise

-- Entry point
def main : IO Unit := mainTCP.wait
```

---------

Co-authored-by: Henrik Böving <hargonix@gmail.com>
Co-authored-by: Mac Malone <tydeu@hatpress.net>
This commit is contained in:
Sofia Rodrigues 2025-06-25 23:51:26 -03:00 committed by GitHub
parent 1e135f2187
commit b15cfadde8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 831 additions and 145 deletions

View file

@ -1,7 +1,7 @@
/-
Copyright (c) 2024 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Henrik Böving
Authors: Henrik Böving, Sofia Rodrigues, Mac Malone
-/
prelude
import Init.Core
@ -13,21 +13,230 @@ namespace Internal
namespace IO
namespace Async
/--
A `Task` that may resolve to a value or an `IO.Error`.
/-!
# Asynchronous Programming Primitives
This module provides a layered approach to asynchronous programming, combining monadic types,
type classes, and concrete task types that work together in a cohesive system.
- **Monadic Types**: These types provide a good way to to chain and manipulate context. These
can contain a `Task`, enabling manipulation of both asynchronous and synchronous code.
- **Concrete Task Types**: Concrete units of work that can be executed within these contexts.
## Monadic Types
These types provide a good way to to chain and manipulate context. These can contain a `Task`,
enabling manipulation of both asynchronous and synchronous code.
- `BaseAsync`: A monadic type for infallible asynchronous computations
- `EAsync`: A monadic type for asynchronous computations that may fail with an error of type
`ε`
- `Async`: A monadic type for IO-based asynchronous computations that may fail with `IO.Error`
(alias for `EAsync IO.Error`)
## Concurrent Units of Work
These are the concrete computational units that exist within the monadic contexts. These types
should not be created directly.
- `Task`: A computation that will resolve to a value of type `α`,
- `ETask`: A task that may fail with an error of type `ε`.
- `AsyncTask`: A task that may fail with an `IO.Error` (alias for `ETask IO.Error`).
## Relation
These types are related by two functions in the type classes `MonadAsync` and `MonadAwait`: `async`
and `await`. The `async` function extracts a concrete asynchronous task from a computation within the
monadic context. In effect, it runs the computation in the background and returns a task handle that
can be awaited later. On the other hand, the `await` function takes a task and re-inserts it into the
monadic context, allowing its result to be composed using monadic bind and also pausing to wait for that result.
This relationship between `async` and `await` enables precise control over when a computation begins
and when its result is used. You can spawn multiple asynchronous tasks using `async`, perform other
operations in the meantime, and later rejoin the computation flow by awaiting their results.
These functions should not be used directly. Instead, prefer higher-level combinators such as
`race`, `raceAll`, `concurrently`, `background` and `concurrentlyAll`. The best way to think about
how to write your async code, it to avoid using the concurrent units of work, and only use it when integrating
with non async code that uses them.
-/
def AsyncTask (α : Type u) : Type u := Task (Except IO.Error α)
/--
Typeclass for monads that can "await" a computation of type `t α` in a monad `m` until the result is
available.
-/
class MonadAwait (t : Type → Type) (m : Type → Type) extends Monad m where
/--
Awaits the result of `t α` and returns it inside the `m` monad.
-/
await : t α → m α
/--
Represents monads that can launch computations asynchronously of type `t` in a monad `m`.
-/
class MonadAsync (t : Type → Type) (m : Type → Type) extends Monad m where
/--
Starts an asynchronous computation in another monad.
-/
async (x : m α) (prio := Task.Priority.default) : m (t α)
/-
These instances have the default_instance attribute so that other default instances
can function correctly within monad transformers.
-/
@[default_instance]
instance [Monad m] [MonadAwait t m] : MonadAwait t (StateT n m) where
await := liftM (m := m) ∘ MonadAwait.await
@[default_instance]
instance [Monad m] [MonadAwait t m] : MonadAwait t (ExceptT n m) where
await := liftM (m := m) ∘ MonadAwait.await
@[default_instance]
instance [Monad m] [MonadAwait t m] : MonadAwait t (ReaderT n m) where
await := liftM (m := m) ∘ MonadAwait.await
@[default_instance]
instance [MonadAwait t m] : MonadAwait t (StateRefT' s n m) where
await := liftM (m := m) ∘ MonadAwait.await
@[default_instance]
instance [MonadAwait t m] : MonadAwait t (StateT s m) where
await := liftM (m := m) ∘ MonadAwait.await
@[default_instance]
instance [MonadAsync t m] : MonadAsync t (ReaderT n m) where
async p prio := MonadAsync.async (prio := prio) ∘ p
@[default_instance]
instance [MonadAsync t m] : MonadAsync t (StateRefT' s n m) where
async p prio := MonadAsync.async (prio := prio) ∘ p
@[default_instance]
instance [Functor t] [inst : MonadAsync t m] : MonadAsync t (StateT s m) where
async p prio := fun s => do
let t ← inst.async (prio := prio) (p s)
pure (t <&> Prod.fst, s)
/--
A `Task` that may resolve to either a value of type `α` or an error value of type `ε`.
-/
abbrev ETask (ε : Type) (α : Type) : Type := ExceptT ε Task α
namespace ETask
/--
Construct an `ETask` that is already resolved with value `x`.
-/
@[inline]
protected def pure (x : α) : ETask ε α :=
Task.pure <| .ok x
/--
Creates a new `ETask` that will run after `x` has finished. If `x`:
- errors, return an `ETask` that resolves to the error.
- succeeds, return an `ETask` that resolves to `f x`.
-/
@[inline]
protected def map (f : α → β) (x : ETask ε α) (prio := Task.Priority.default) (sync := false) : ETask ε β :=
Task.map (x := x) (f <$> ·) prio sync
/--
Creates a new `ETask` that will run after `x` has completed. If `x`:
- errors, return an `ETask` that resolves to the error.
- succeeds, run `f` on the result of `x` and return the `ETask` produced by `f`.
-/
@[inline]
protected def bind (x : ETask ε α) (f : α → ETask ε β) (prio := Task.Priority.default) (sync := false) : ETask ε β :=
Task.bind x (prio := prio) (sync := sync) fun
| .ok a => f a
| .error e => Task.pure <| .error e
/--
Similar to `bind`, however `f` has access to the `EIO` monad. If `f` throws an error, the returned
`ETask` resolves to that error.
-/
@[inline]
protected def bindEIO (x : ETask ε α) (f : α → EIO ε (ETask ε β)) (prio := Task.Priority.default) (sync := false) : EIO ε (ETask ε β) :=
EIO.bindTask x (prio := prio) (sync := sync) fun
| .ok a => f a
| .error e => .error e
/--
Similar to `bind`, however `f` has access to the `EIO` monad. If `f` throws an error, the returned
`ETask` resolves to that error.
-/
@[inline]
protected def mapEIO (f : α → EIO ε β) (x : ETask ε α) (prio := Task.Priority.default) (sync := false) : BaseIO (ETask ε β) :=
EIO.mapTask (t := x) (prio := prio) (sync := sync) fun
| .ok a => f a
| .error e => .error e
/--
Block until the `ETask` in `x` finishes and returns its value. Propagates any error encountered
during execution.
-/
@[inline]
def block (x : ETask ε α) : EIO ε α := do
match x.get with
| .ok a => return a
| .error e => .error e
/--
Create an `ETask` that resolves to the value of the promise `x`.
-/
@[inline]
def ofPromise (x : IO.Promise (Except ε α)) : ETask ε α :=
x.result!
/--
Create an `ETask` that resolves to the pure value of the promise `x`.
-/
@[inline]
def ofPurePromise (x : IO.Promise α) : ETask ε α :=
x.result!.map pure (sync := true)
/--
Obtain the `IO.TaskState` of `x`.
-/
@[inline]
def getState (x : ETask ε α) : BaseIO IO.TaskState :=
IO.getTaskState x
instance : Functor (ETask ε) where
map := ETask.map
instance : Monad (ETask ε) where
pure := ETask.pure
bind := ETask.bind
end ETask
/--
A `Task` that may resolve to a value or an error value of type `IO.Error`. Alias for `ETask IO.Error`.
-/
abbrev AsyncTask := ETask IO.Error
namespace AsyncTask
/--
Similar to `map`, however `f` has access to the `IO` monad. If `f` throws an error, the returned
`AsyncTask` resolves to that error.
-/
@[inline]
protected def mapIO (f : α → IO β) (x : AsyncTask α) (prio := Task.Priority.default) (sync := false) : BaseIO (AsyncTask β) :=
EIO.mapTask (t := x) (prio := prio) (sync := sync) fun
| .ok a => f a
| .error e => .error e
/--
Construct an `AsyncTask` that is already resolved with value `x`.
-/
@[inline]
protected def pure (x : α) : AsyncTask α := Task.pure <| .ok x
instance : Pure AsyncTask where
pure := AsyncTask.pure
protected def pure (x : α) : AsyncTask α :=
Task.pure <| .ok x
/--
Create a new `AsyncTask` that will run after `x` has finished.
@ -36,9 +245,8 @@ If `x`:
- succeeds, run `f` on the result of `x` and return the `AsyncTask` produced by `f`.
-/
@[inline]
protected def bind (x : AsyncTask α) (f : α → AsyncTask β) : AsyncTask β :=
Task.bind x fun r =>
match r with
protected def bind (x : AsyncTask α) (f : α → AsyncTask β) (prio := Task.Priority.default) (sync := false) : AsyncTask β :=
Task.bind x (prio := prio) (sync := sync) fun
| .ok a => f a
| .error e => Task.pure <| .error e
@ -49,40 +257,34 @@ If `x`:
- succeeds, return an `AsyncTask` that resolves to `f x`.
-/
@[inline]
def map (f : α → β) (x : AsyncTask α) : AsyncTask β :=
Task.map (x := x) fun r =>
match r with
| .ok a => .ok (f a)
| .error e => .error e
def map (f : α → β) (x : AsyncTask α) (prio := Task.Priority.default) (sync := false) : AsyncTask β :=
Task.map (x := x) (f <$> ·) prio sync
/--
Similar to `bind`, however `f` has access to the `IO` monad. If `f` throws an error, the returned
`AsyncTask` resolves to that error.
-/
@[inline]
def bindIO (x : AsyncTask α) (f : α → IO (AsyncTask β)) : BaseIO (AsyncTask β) :=
IO.bindTask x fun r =>
match r with
def bindIO (x : AsyncTask α) (f : α → IO (AsyncTask β)) (prio := Task.Priority.default) (sync := false) : BaseIO (AsyncTask β) :=
IO.bindTask x (prio := prio) (sync := sync) fun
| .ok a => f a
| .error e => .error e
/--
Similar to `bind`, however `f` has access to the `IO` monad. If `f` throws an error, the returned
Similar to `map`, however `f` has access to the `IO` monad. If `f` throws an error, the returned
`AsyncTask` resolves to that error.
-/
@[inline]
def mapIO (f : α → IO β) (x : AsyncTask α) : BaseIO (AsyncTask β) :=
IO.mapTask (t := x) fun r =>
match r with
def mapTaskIO (f : α → IO β) (x : AsyncTask α) (prio := Task.Priority.default) (sync := false) : BaseIO (AsyncTask β) :=
IO.mapTask (t := x) (prio := prio) (sync := sync) fun
| .ok a => f a
| .error e => .error e
/--
Block until the `AsyncTask` in `x` finishes.
-/
def block (x : AsyncTask α) : IO α := do
let res := x.get
match res with
def block (x : AsyncTask α) : IO α :=
match x.get with
| .ok a => return a
| .error e => .error e
@ -98,7 +300,7 @@ Create an `AsyncTask` that resolves to the value of `x`.
-/
@[inline]
def ofPurePromise (x : IO.Promise α) : AsyncTask α :=
x.result!.map pure
x.result!.map pure (sync := true)
/--
Obtain the `IO.TaskState` of `x`.
@ -109,6 +311,491 @@ def getState (x : AsyncTask α) : BaseIO IO.TaskState :=
end AsyncTask
/--
A `MaybeTask α` represents a computation that either:
- Is immediately available as an `α` value, or
- Is an asynchronous computation that will eventually produce an `α` value.
-/
inductive MaybeTask (α : Type)
| pure : α → MaybeTask α
| ofTask : Task α → MaybeTask α
namespace MaybeTask
/--
Constructs an `Task` from a `MaybeTask`.
-/
@[inline]
def toTask : MaybeTask α → Task α
| .pure a => .pure a
| .ofTask t => t
/--
Gets the value of the `MaybeTask` by blocking.
-/
@[inline]
def get {α : Type} : MaybeTask αα
| .pure a => a
| .ofTask t => t.get
/--
Maps a function over a `MaybeTask`.
-/
@[inline]
def map (f : α → β) (prio := Task.Priority.default) (sync := false) : MaybeTask α → MaybeTask β
| .pure a => .pure <| f a
| .ofTask t => .ofTask <| t.map f prio sync
/--
Sequences two computations, allowing the second to depend on the value computed by the first.
-/
@[inline]
protected def bind (t : MaybeTask α) (f : α → MaybeTask β) (prio := Task.Priority.default) (sync := false) : MaybeTask β :=
match t with
| .pure a => f a
| .ofTask t => .ofTask <| t.bind (f · |>.toTask) prio sync
/--
Join the `MaybeTask` to an `Task`.
-/
@[inline]
def joinTask (t : Task (MaybeTask α)) : Task α :=
t.bind (sync := true) fun
| .pure a => .pure a
| .ofTask t => t
instance : Functor (MaybeTask) where
map := MaybeTask.map
instance : Monad (MaybeTask) where
pure := MaybeTask.pure
bind := MaybeTask.bind
end MaybeTask
/--
An asynchronous computation that never fails.
-/
def BaseAsync (α : Type) := BaseIO (MaybeTask α)
namespace BaseAsync
/--
Converts a `BaseIO` into a `BaseAsync`
-/
@[inline]
def mk (x : BaseIO (MaybeTask α)) : BaseAsync α :=
x
/--
Converts a `BaseAsync` into a `BaseIO`
-/
@[inline]
def toRawBaseIO (x : BaseAsync α) : BaseIO (MaybeTask α) :=
x
/--
Converts a `BaseAsync` to a `BaseIO Task`.
-/
@[inline]
protected def toBaseIO (x : BaseAsync α) : BaseIO (Task α) :=
MaybeTask.toTask <$> x.toRawBaseIO
/--
Creates a new `BaseAsync` out of a `Task`.
-/
@[inline]
protected def ofTask (x : Task α) : BaseAsync α :=
.mk <| pure <| MaybeTask.ofTask x
/--
Creates a `BaseAsync` computation that immediately returns the given value.
-/
@[inline]
protected def pure (a : α) : BaseAsync α :=
.mk <| pure <| .pure a
/--
Maps the result of a `BaseAsync` computation with a function.
-/
@[inline]
protected def map (f : α → β) (self : BaseAsync α) (prio := Task.Priority.default) (sync := false) : BaseAsync β :=
mk <| (·.map f prio sync) <$> self.toRawBaseIO
/--
Sequences two computations, allowing the second to depend on the value computed by the first.
-/
@[inline]
protected def bind (self : BaseAsync α) (f : α → BaseAsync β) (prio := Task.Priority.default) (sync := false) : BaseAsync β :=
mk <| self.toRawBaseIO >>= (bindAsyncTask · f |>.toRawBaseIO)
where
bindAsyncTask (t : MaybeTask α) (f : α → BaseAsync β) : BaseAsync β := .mk <|
match t with
| .pure a => (f a) |>.toRawBaseIO
| .ofTask t => .ofTask <$> BaseIO.bindTask t (fun a => MaybeTask.toTask <$> (f a |>.toRawBaseIO)) prio sync
/--
Lifts a `BaseIO` action into a `BaseAsync` computation.
-/
@[inline]
protected def lift (x : BaseIO α) : BaseAsync α :=
.mk <| (.pure ∘ .pure) =<< x
/--
Waits for the result of the `BaseAsync` computation, blocking if necessary.
-/
@[inline]
protected def wait (self : BaseAsync α) : BaseIO α :=
pure ∘ Task.get =<< self.toBaseIO
/--
Lifts a `BaseAsync` computation into a `Task` that can be awaited and joined.
-/
@[inline]
protected def asTask (x : BaseAsync α) (prio := Task.Priority.default) : BaseIO (Task α) := do
let res ← BaseIO.asTask (prio := prio) x.toRawBaseIO
return MaybeTask.joinTask res
/--
Creates a `BaseAsync` that awaits the completion of the given `Task α`.
-/
@[inline]
def await (t : Task α) : BaseAsync α :=
.mk <| pure <| MaybeTask.ofTask t
/--
Returns the `BaseAsync` computation inside a `Task α`, so it can be awaited.
-/
@[inline]
def async (self : BaseAsync α) (prio := Task.Priority.default) : BaseAsync (Task α) :=
BaseAsync.lift <| self.asTask (prio := prio)
instance : Functor BaseAsync where
map := BaseAsync.map
instance : Monad BaseAsync where
pure := BaseAsync.pure
bind := BaseAsync.bind
instance : MonadLift BaseIO BaseAsync where
monadLift := BaseAsync.lift
instance : MonadAwait Task BaseAsync where
await := BaseAsync.await
instance : MonadAsync Task BaseAsync where
async t prio := BaseAsync.async t prio
instance [Inhabited α] : Inhabited (BaseAsync α) where
default := .mk <| pure (MaybeTask.pure default)
end BaseAsync
/--
An asynchronous computation that may produce an error of type `ε`.
-/
def EAsync (ε : Type) (α : Type) := BaseAsync (Except ε α)
namespace EAsync
/--
Converts a `EAsync` to a `ETask`.
-/
@[inline]
protected def toBaseIO (x : EAsync ε α) : BaseIO (ETask ε α) :=
MaybeTask.toTask <$> x.toRawBaseIO
/--
Creates a new `EAsync` out of a `RTask`.
-/
@[inline]
protected def ofTask (x : ETask ε α) : EAsync ε α :=
.mk <| pure <| MaybeTask.ofTask x
/--
Converts a `BaseAsync` to a `EIO ETask`.
-/
@[inline]
protected def toEIO (x : EAsync ε α) : EIO ε (ETask ε α) :=
MaybeTask.toTask <$> x.toRawBaseIO
/--
Creates a new `EAsync` out of a `ETask`.
-/
@[inline]
protected def ofETask (x : ETask ε α) : EAsync ε α :=
.mk <| BaseAsync.ofTask x
/--
Creates an `EAsync` computation that immediately returns the given value.
-/
@[inline]
protected def pure (a : α) : EAsync ε α :=
.mk <| BaseAsync.pure <| .ok a
/--
Maps the result of an `EAsync` computation with a pure function.
-/
@[inline]
protected def map (f : α → β) (self : EAsync ε α) : EAsync ε β :=
.mk <| BaseAsync.map (.map f) self
/--
Sequences two computations, allowing the second to depend on the value computed by the first.
-/
@[inline]
protected def bind (self : EAsync ε α) (f : α → EAsync ε β) : EAsync ε β :=
.mk <| BaseAsync.bind self fun
| .ok a => f a
| .error e => BaseAsync.pure (.error e)
/--
Lifts an `EIO` action into an `EAsync` computation.
-/
@[inline]
protected def lift (x : EIO ε α) : EAsync ε α :=
.mk <| BaseAsync.lift x.toBaseIO
/--
Waits for the result of the `EAsync` computation, blocking if necessary.
-/
@[inline]
protected def wait (self : EAsync ε α) : EIO ε α := do
let result ← self |> BaseAsync.wait
match result with
| .ok a => return a
| .error e => .error e
/--
Lifts an `EAsync` computation into an `ETask` that can be awaited and joined.
-/
@[inline]
protected def asTask (x : EAsync ε α) (prio := Task.Priority.default) : EIO ε (ETask ε α) :=
x |> BaseAsync.asTask (prio := prio)
/--
Raises an error of type `ε` within the `EAsync` monad.
-/
@[inline]
protected def throw (e : ε) : EAsync ε α :=
.mk <| BaseAsync.pure (.error e)
/--
Handles errors in an `EAsync` computation by running a handler if one occurs.
-/
@[inline]
protected def tryCatch (x : EAsync ε α) (f : ε → EAsync ε α) (prio := Task.Priority.default) (sync := false) : EAsync ε α :=
.mk <| BaseAsync.bind (sync := sync) (prio := prio) x fun
| .ok a => BaseAsync.pure (.ok a)
| .error e => (f e)
/--
Runs an action, ensuring that some other action always happens afterward.
-/
protected def tryFinally'
(x : EAsync ε α) (f : Option α → EAsync ε β)
(prio := Task.Priority.default) (sync := false) :
EAsync ε (α × β) :=
.mk <| BaseAsync.bind x (prio := prio) (sync := sync) fun
| .ok a => do
match ← (f (some a)) with
| .ok b => BaseAsync.pure (.ok (a, b))
| .error e => BaseAsync.pure (.error e)
| .error e => do
match ← (f none) with
| .ok _ => BaseAsync.pure (.error e)
| .error e' => BaseAsync.pure (.error e')
/--
Creates an `EAsync` computation that awaits the completion of the given `ETask ε α`.
-/
@[inline]
def await (x : ETask ε α) : EAsync ε α :=
.mk (BaseAsync.ofTask x)
/--
Returns the `EAsync` computation inside an `ETask ε α`, so it can be awaited.
-/
@[inline]
def async (self : EAsync ε α) (prio := Task.Priority.default) : EAsync ε (ETask ε α) :=
EAsync.lift <| self.asTask prio
instance : Functor (EAsync ε) where
map := EAsync.map
instance : Monad (EAsync ε) where
pure := EAsync.pure
bind := EAsync.bind
instance : MonadLift (EIO ε) (EAsync ε) where
monadLift := EAsync.lift
instance : MonadExcept ε (EAsync ε) where
throw := EAsync.throw
tryCatch := EAsync.tryCatch
instance : MonadExceptOf ε (EAsync ε) where
throw := EAsync.throw
tryCatch := EAsync.tryCatch
instance : MonadFinally (EAsync ε) where
tryFinally' := EAsync.tryFinally'
instance : OrElse (EAsync ε α) where
orElse := MonadExcept.orElse
instance [Inhabited ε] : Inhabited (EAsync ε α) where
default := .mk <| BaseAsync.pure default
instance : MonadAwait (ETask ε) (EAsync ε) where
await t := .mk <| BaseAsync.ofTask t
instance : MonadAwait Task (EAsync ε) where
await t := .mk <| BaseAsync.ofTask (t.map (.ok))
instance : MonadAwait AsyncTask (EAsync IO.Error) where
await t := .mk <| BaseAsync.ofTask t
instance : MonadAwait IO.Promise (EAsync ε) where
await t := .mk <| BaseAsync.ofTask (t.result!.map (.ok))
instance : MonadAsync (ETask ε) (EAsync ε) where
async t prio := EAsync.lift <| t.asTask (prio := prio)
instance : MonadAsync AsyncTask (EAsync IO.Error) where
async t prio := EAsync.lift <| t.asTask (prio := prio)
instance : MonadLift BaseIO (EAsync ε) where
monadLift x := .mk <| (pure ∘ .ok) <$> x
instance : MonadLift (EIO ε) (EAsync ε) where
monadLift x := .mk <| pure <$> x.toBaseIO
instance : MonadLift BaseAsync (EAsync ε) where
monadLift x := .mk <| x.map (.ok)
@[inline]
protected partial def forIn
{β : Type} [i : Inhabited ε] (init : β)
(f : Unit → β → EAsync ε (ForInStep β))
(prio := Task.Priority.default) :
EAsync ε β := do
let promise ← IO.Promise.new
let rec @[specialize] loop (b : β) : EAsync ε (ETask ε Unit) := async (prio := prio) do
match ← f () b with
| ForInStep.done b => promise.resolve (.ok b)
| ForInStep.yield b => discard <| (loop b)
discard <| loop init
.mk <| BaseAsync.ofTask promise.result!
instance [Inhabited ε] : ForIn (EAsync ε) Lean.Loop Unit where
forIn _ := EAsync.forIn
end EAsync
/--
An asynchronous computation that may produce an error of type `IO.Error`..
-/
abbrev Async (α : Type) := EAsync IO.Error α
namespace Async
/--
Converts a `Async` to a `AsyncTask`.
-/
@[inline]
protected def toIO (x : Async α) : IO (AsyncTask α) :=
MaybeTask.toTask <$> x.toRawBaseIO
@[default_instance]
instance : MonadAsync AsyncTask Async :=
inferInstanceAs (MonadAsync (ETask IO.Error) (EAsync IO.Error))
instance : MonadAwait AsyncTask Async :=
inferInstanceAs (MonadAwait AsyncTask (EAsync IO.Error))
instance : MonadAwait IO.Promise Async :=
inferInstanceAs (MonadAwait IO.Promise (EAsync IO.Error))
end Async
export MonadAsync (async)
export MonadAwait (await)
/--
This function transforms the operation inside the monad `m` into a task and let it run in the background.
-/
@[inline, specialize]
def background [Monad m] [MonadAsync t m] (prio := Task.Priority.default) : m α → m Unit :=
discard ∘ async (t := t) (prio := prio)
/--
Runs two computations concurrently and returns both results as a pair.
-/
@[inline, specialize]
def concurrently
[Monad m] [MonadAwait t m] [MonadAsync t m]
(x : m α) (y : m β)
(prio := Task.Priority.default) :
m (α × β) := do
let taskX : t α ← async x (prio := prio)
let taskY : t β ← async y (prio := prio)
let resultX ← await taskX
let resultY ← await taskY
return (resultX, resultY)
/--
Runs two computations concurrently and returns the result of the one that finishes first.
The other result is lost and the other task is not cancelled, so the task will continue the execution
until the end.
-/
@[inline, specialize]
def race
[MonadLiftT BaseIO m] [MonadAwait Task m] [MonadAsync t m] [MonadAwait t m]
[Monad m] [Inhabited α] (x : m α) (y : m α)
(prio := Task.Priority.default) :
m α := do
let promise ← IO.Promise.new
discard (async (t := t) (prio := prio) <| Bind.bind x (liftM ∘ promise.resolve))
discard (async (t := t) (prio := prio) <| Bind.bind y (liftM ∘ promise.resolve))
await promise.result!
/--
Runs all computations in an `Array` concurrently and returns all results as an array.
-/
@[inline, specialize]
def concurrentlyAll
[Monad m] [MonadAwait t m] [MonadAsync t m] (xs : Array (m α))
(prio := Task.Priority.default) :
m (Array α) := do
let tasks : Array (t α) ← xs.mapM (async (prio := prio))
tasks.mapM await
/--
Runs all computations concurrently and returns the result of the first one to finish.
All other results are lost, and the tasks are not cancelled, so they'll continue their executing
until the end.
-/
@[inline, specialize]
def raceAll
[ForM m c (m α)] [MonadLiftT BaseIO m] [MonadAwait Task m]
[MonadAsync t m] [MonadAwait t m] [Monad m] [Inhabited α]
(xs : c)
(prio := Task.Priority.default) :
m α := do
let promise ← IO.Promise.new
ForM.forM xs fun x =>
discard (async (t := t) (prio := prio) <| Bind.bind x (liftM ∘ promise.resolve))
await promise.result!
end Async
end IO
end Internal

31
tests/lean/run/async.lean Normal file
View file

@ -0,0 +1,31 @@
import Std.Internal.Async
import Std.Internal.UV
import Std.Net.Addr
open Std.Internal.IO.Async.UDP
open Std.Internal.IO.Async
open Std.Net
def t : IO (Async Nat) := do
IO.println "a"
return do
IO.println "b"
return 2
-- Usage example of the ForIn instance
def loop : Async Nat := do
let mut res := 0
while res < 10 do
res := res + 1
discard <| t
pure res
/--
info: 10
-/
#guard_msgs in
#eval IO.println =<< ETask.block =<< loop.asTask

View file

@ -0,0 +1,54 @@
import Std.Internal.Async
import Std.Sync.Mutex
open Std
open Std.Internal.IO.Async
def wait (ms : Nat) (ref : Std.Mutex Nat) (val : Nat) : Async Unit := do
ref.atomically (·.modify (· * val))
IO.sleep ms
ref.atomically (·.modify (· + val))
-- Tests
def sequential : Async Unit := do
let ref ← Std.Mutex.new 0
wait 200 ref 1
wait 400 ref 2
ref.atomically (·.modify (· * 10))
assert! (← ref.atomically (·.get)) == 40
#eval do (← sequential.toEIO).block
def conc : Async Unit := do
let ref ← Std.Mutex.new 0
discard <| concurrently (wait 200 ref 1) (wait 400 ref 2)
ref.atomically (·.modify (· * 10))
assert! (← ref.atomically (·.get)) == 30
#eval do (← conc.toEIO).block
def racer : Async Unit := do
let ref ← Std.Mutex.new 0
race (wait 200 ref 1) (wait 400 ref 2)
ref.atomically (·.modify (· * 10))
assert! (← ref.atomically (·.get)) == 10
#eval do (← racer.toEIO).block
def concAll : Async Unit := do
let ref ← Std.Mutex.new 0
discard <| concurrentlyAll #[(wait 200 ref 1), (wait 400 ref 2)]
ref.atomically (·.modify (· * 10))
assert! (← ref.atomically (·.get)) == 30
#eval do (← concAll.toEIO).block
def racerAll : Async Unit := do
let ref ← Std.Mutex.new 0
raceAll #[(wait 200 ref 1), (wait 400 ref 2)]
ref.atomically (·.modify (· * 10))
assert! (← ref.atomically (·.get)) == 10
#eval do (← racerAll.toEIO).block

View file

@ -2,7 +2,7 @@ import Std.Internal.Async
import Std.Internal.UV
import Std.Net.Addr
open Std.Internal.IO.Async
open Std.Internal.IO Async
open Std.Net
def assertBEq [BEq α] [ToString α] (actual expected : α) : IO Unit := do
@ -10,43 +10,20 @@ def assertBEq [BEq α] [ToString α] (actual expected : α) : IO Unit := do
throw <| IO.userError <|
s!"expected '{expected}', got '{actual}'"
-- Define the Async monad
structure Async (α : Type) where
run : IO (AsyncTask α)
namespace Async
-- Monad instance for Async
instance : Monad Async where
pure x := Async.mk (pure (AsyncTask.pure x))
bind ma f := Async.mk do
let task ← ma.run
task.bindIO fun a => (f a).run
-- Await function to simplify AsyncTask handling
def await (task : IO (AsyncTask α)) : Async α :=
Async.mk task
instance : MonadLift IO Async where
monadLift io := Async.mk (io >>= (pure ∘ AsyncTask.pure))
--------------------------------------------------------------
/-- Mike is another client. -/
def runMike (client: TCP.Socket.Client) : Async Unit := do
let message ← await (client.recv? 1024)
let message ← await (← client.recv? 1024)
assertBEq (String.fromUTF8? =<< message) none
/-- Joe is another client. -/
def runJoe (client: TCP.Socket.Client) : Async Unit := do
let message ← await (client.recv? 1024)
let message ← await (← client.recv? 1024)
assertBEq (String.fromUTF8? =<< message) none
/-- Robert is the server. -/
def runRobert (server: TCP.Socket.Server) : Async Unit := do
discard <| await server.accept
discard <| await server.accept
discard <| await (← server.accept)
discard <| await (← server.accept)
def clientServer : IO Unit := do
let addr := SocketAddressV4.mk (.ofParts 127 0 0 1) 8083
@ -73,16 +50,14 @@ def clientServer : IO Unit := do
mike.noDelay
let serverTask ← (runRobert server).run
let serverTask ← (runRobert server).toIO
let joeTask ← (runJoe joe).run
let mikeTask ← (runMike mike).run
let joeTask ← (runJoe joe).toIO
let mikeTask ← (runMike mike).toIO
serverTask.block
joeTask.block
mikeTask.block
end Async
#eval Async.clientServer
#eval clientServer

View file

@ -2,7 +2,7 @@ import Std.Internal.Async
import Std.Internal.UV
import Std.Net.Addr
open Std.Internal.IO.Async
open Std.Internal.IO Async
open Std.Net
-- Using this function to create IO Error. For some reason the assert! is not pausing the execution.
@ -11,33 +11,13 @@ def assertBEq [BEq α] [ToString α] (actual expected : α) : IO Unit := do
throw <| IO.userError <|
s!"expected '{expected}', got '{actual}'"
-- Define the Async monad
structure Async (α : Type) where
run : IO (AsyncTask α)
namespace Async
-- Monad instance for Async
instance : Monad Async where
pure x := Async.mk (pure (AsyncTask.pure x))
bind ma f := Async.mk do
let task ← ma.run
task.bindIO fun a => (f a).run
-- Await function to simplify AsyncTask handling
def await (task : IO (AsyncTask α)) : Async α :=
Async.mk task
instance : MonadLift IO Async where
monadLift io := Async.mk (io >>= (pure ∘ AsyncTask.pure))
/-- Joe is another client. -/
def runJoe (addr: SocketAddress) : Async Unit := do
let client ← TCP.Socket.Client.mk
await (client.connect addr)
await (client.send (String.toUTF8 "hello robert!"))
await client.shutdown
await (← client.connect addr)
await (← client.send (String.toUTF8 "hello robert!"))
await (← client.shutdown)
def listenClose : IO Unit := do
let addr := SocketAddressV4.mk (.ofParts 127 0 0 1) 8080
@ -53,7 +33,7 @@ def acceptClose : IO Unit := do
server.bind addr
server.listen 128
let joeTask ← (runJoe addr).run
let joeTask ← (runJoe addr).toIO
let task ← server.accept
let client ← task.block
@ -71,5 +51,4 @@ def acceptClose : IO Unit := do
-- Waiting to avoid errors from escaping.
joeTask.block
#eval acceptClose
#eval listenClose

View file

@ -2,7 +2,7 @@ import Std.Internal.Async
import Std.Internal.UV
import Std.Net.Addr
open Std.Internal.IO.Async
open Std.Internal.IO Async
open Std.Net
-- Using this function to create IO Error. For some reason the assert! is not pausing the execution.
@ -11,53 +11,33 @@ def assertBEq [BEq α] [ToString α] (actual expected : α) : IO Unit := do
throw <| IO.userError <|
s!"expected '{expected}', got '{actual}'"
-- Define the Async monad
structure Async (α : Type) where
run : IO (AsyncTask α)
namespace Async
-- Monad instance for Async
instance : Monad Async where
pure x := Async.mk (pure (AsyncTask.pure x))
bind ma f := Async.mk do
let task ← ma.run
task.bindIO fun a => (f a).run
-- Await function to simplify AsyncTask handling
def await (task : IO (AsyncTask α)) : Async α :=
Async.mk task
instance : MonadLift IO Async where
monadLift io := Async.mk (io >>= (pure ∘ AsyncTask.pure))
--------------------------------------------------------------
/-- Mike is another client. -/
def runMike (client: TCP.Socket.Client) : Async Unit := do
let mes ← await (client.recv? 1024)
let mes ← await (← client.recv? 1024)
assertBEq (String.fromUTF8? =<< mes) (some "hi mike!! :)")
await (client.send (String.toUTF8 "hello robert!!"))
await (client.shutdown)
await (← client.send (String.toUTF8 "hello robert!!"))
await (← client.shutdown)
/-- Joe is another client. -/
def runJoe (client: TCP.Socket.Client) : Async Unit := do
let mes ← await (client.recv? 1024)
let mes ← await (← client.recv? 1024)
assertBEq (String.fromUTF8? =<< mes) (some "hi joe! :)")
await (client.send (String.toUTF8 "hello robert!"))
await client.shutdown
await (← client.send (String.toUTF8 "hello robert!"))
await (← client.shutdown)
/-- Robert is the server. -/
def runRobert (server: TCP.Socket.Server) : Async Unit := do
let joe ← await server.accept
let mike ← await server.accept
let joe ← await (← server.accept)
let mike ← await (← server.accept)
await (joe.send (String.toUTF8 "hi joe! :)"))
let mes ← await (joe.recv? 1024)
await (joe.send (String.toUTF8 "hi joe! :)"))
let mes ← await (joe.recv? 1024)
assertBEq (String.fromUTF8? =<< mes) (some "hello robert!")
await (mike.send (String.toUTF8 "hi mike!! :)"))
let mes ← await (mike.recv? 1024)
await (mike.send (String.toUTF8 "hi mike!! :)"))
let mes ← await (mike.recv? 1024)
assertBEq (String.fromUTF8? =<< mes) (some "hello robert!!")
pure ()
@ -69,7 +49,7 @@ def clientServer (addr : SocketAddress) : IO Unit := do
let serverTask := runRobert server
let serverTask ← serverTask.run
let serverTask ← serverTask.toIO
assertBEq (← server.getSockName).port addr.port
@ -92,8 +72,8 @@ def clientServer (addr : SocketAddress) : IO Unit := do
let joeTask := runJoe joe
let mikeTask := runMike mike
let joeTask ← joeTask.run
let mikeTask ← mikeTask.run
let joeTask ← joeTask.toIO
let mikeTask ← mikeTask.toIO
serverTask.block
joeTask.block

View file

@ -11,26 +11,6 @@ def assertBEq [BEq α] [ToString α] (actual expected : α) : IO Unit := do
throw <| IO.userError <|
s!"expected '{expected}', got '{actual}'"
-- Define the Async monad
structure Async (α : Type) where
run : IO (AsyncTask α)
namespace Async
-- Monad instance for Async
instance : Monad Async where
pure x := Async.mk (pure (AsyncTask.pure x))
bind ma f := Async.mk do
let task ← ma.run
task.bindIO fun a => (f a).run
-- Await function to simplify AsyncTask handling
def await (task : IO (AsyncTask α)) : Async α :=
Async.mk task
instance : MonadLift IO Async where
monadLift io := Async.mk (io >>= (pure ∘ AsyncTask.pure))
/-- Joe is another client. -/
def runJoe (addr : UInt16 → SocketAddress) (first second : UInt16) : Async Unit := do
let client ← UDP.Socket.mk
@ -38,7 +18,7 @@ def runJoe (addr : UInt16 → SocketAddress) (first second : UInt16) : Async Uni
client.bind (addr second)
client.connect (addr first)
await (client.send (String.toUTF8 "hello robert!"))
await (client.send (String.toUTF8 "hello robert!"))
def acceptClose (addr : UInt16 → SocketAddress) (first second : UInt16) : IO Unit := do
@ -46,7 +26,7 @@ def acceptClose (addr : UInt16 → SocketAddress) (first second : UInt16) : IO U
let server ← UDP.Socket.mk
server.bind (addr first)
let res ← (runJoe addr first second).run
let res ← (runJoe addr first second).toIO
res.block
let res ← server.recv 1024