lean4-htt/tests/lean/run/async_tcp_server_client.lean
Sofia Rodrigues b15cfadde8
feat: monadic interface for asynchronous operations in Std (#8003)
This PR adds a new monadic interface for `Async` operations.

This is the design for the `Async` monad that I liked the most. The idea
was refined with the help of @tydeu. Before that, I had some
prerequisites in mind:

1. Good performance
2. Explicit `yield` points, so we could avoid using `bindTask` for every
lifted IO operation
3. A way to avoid creating an infinite chain of `Task`s during recursion

Points 2 and 3 are not covered in this PR. I wish I had a good
solution, but right now I only have a few sketches.
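
To make point 2 concrete, here is a minimal sketch of what chaining through `IO.bindTask` looks like when only the core `IO.asTask`, `IO.bindTask`, and `IO.wait` primitives are used. The name `chained` and the values are illustrative and not part of this PR. The sketch only shows that each step pays for its own task.

```lean
-- Illustrative only: `chained` is a made-up name, not part of this PR.
def chained : IO Nat := do
  -- Run the first step as a task.
  let first ← IO.asTask (pure (20 : Nat))
  -- Chain the second step; the callback itself has to produce another task.
  let second ← IO.bindTask first fun (r : Except IO.Error Nat) =>
    match r with
    | .ok n => do
      let next ← IO.asTask (pure (n + 1))
      pure next
    | .error e => throw e
  -- Block until the whole chain has finished and surface any error.
  IO.ofExcept (← IO.wait second)
```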

### Explicit `yield` points

I thought this would be easy at first, but it actually turned out kinda
tricky. I ended up creating the `suspend` syntax, which is just a small
modification of the lift method (`<- ...`) syntax. It desugars to
`Suspend.suspend task fun _ => ...`. So something like:

```lean
do
  IO.println "a"
  IO.println "b"
  let result := suspend (client.recv? 1024)
  IO.println "c"
  IO.println "d"
```

Would become:

```lean
Bind.bind (IO.println "a") fun _ =>
Bind.bind (IO.println "b") fun _ =>
Suspend.suspend (client.recv? 1024) fun result =>
  Bind.bind (IO.println "c") fun _ =>
  IO.println "d"
```
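
A self-contained toy version of this desugaring may help. In the sketch below the `Suspend` class signature is assumed (only the name comes from the desugaring above), and the `IO` instance simply blocks so that the example can run. It is not the design from this PR.

```lean
/-- Hypothetical class: the name follows the desugaring, the signature is an assumption. -/
class Suspend (m : Type → Type) where
  suspend : {α β : Type} → Task α → (α → m β) → m β

/-- A blocking instance for `IO`, present only so the sketch runs. -/
instance : Suspend IO where
  suspend task k := do
    let value ← IO.wait task
    k value

/-- Hand-written desugared form: `suspend` is the only step that touches a `Task`. -/
def demo : IO Unit :=
  Bind.bind (IO.println "a") fun _ =>
  Suspend.suspend (Task.spawn fun _ => "payload") fun result =>
  Bind.bind (IO.println s!"received {result}") fun _ =>
  IO.println "d"
```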

This makes things a bit more efficient. With `bind` we would avoid
creating a `Task` chain, and `suspend` would be the only place that
uses `Task.bind`. But there is a problem: if we use `bind` with
something that needs `suspend`, it blocks the whole task. Blocking is
the only way to prevent task accumulation when using plain `bind` inside
a structure like this:

```lean
inductive AsyncResult (ε σ α : Type u) where
  | ok     : α → σ → AsyncResult ε σ α
  | error  : ε → σ → AsyncResult ε σ α
  | ofTask : Task (EStateM.Result ε σ α) → σ → AsyncResult ε σ α
```

Blocking works here because we simply need to remove the `ofTask` and
transform it into an `ok`.
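
As a rough illustration of that step, blocking on the task is what lets an `ofTask` collapse into `ok` (or `error`). The helper name `AsyncResult.block` is hypothetical, and the sketch fixes the universe to `Type` for simplicity.

```lean
inductive AsyncResult (ε σ α : Type) where
  | ok     : α → σ → AsyncResult ε σ α
  | error  : ε → σ → AsyncResult ε σ α
  | ofTask : Task (EStateM.Result ε σ α) → σ → AsyncResult ε σ α

/-- Hypothetical helper: wait for a pending task and return a resolved result. -/
def AsyncResult.block {ε σ α : Type} : AsyncResult ε σ α → AsyncResult ε σ α
  | .ok a s    => .ok a s
  | .error e s => .error e s
  | .ofTask task _ =>
    -- `Task.get` blocks the current thread until the task has finished.
    match task.get with
    | .ok a s    => .ok a s
    | .error e s => .error e s
```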

### Infinite chain of Tasks

If you create an infinitely recursive function using `Task` (which is
super common in servers such as HTTP servers), it can use a lot of
memory, because those tasks get chained forever and won't be freed until
the function returns.

To get around that, I used CPS: instead of just calling `Task.bind`,
I’d spawn a new task and return an "empty" one, like:

```lean
fun k => Task.bind (...) fun value => do k value; pure emptyTask
```

This works great with a CPS-style monad, but it generates a huge IR by
itself.

Doing CPS alone was too much, though, because every lifted
operation created a new continuation and a `Task.bind`. So I combined it
with `suspend` and got better performance, but `suspend` is not pleasant
to use.

### The current monad

Right now, the monad I’m using is super simple. It doesn't solve the
earlier problems, but the API is clean and the generated IR is small
enough. Here is an example of how it should be used:

```lean
-- A loop that repeatedly sends a message and waits for a reply.
partial def writeLoop (client : Socket.Client) (message : String) : Async (AsyncTask Unit) := async do
  IO.println s!"sending: {message}"
  await (← client.send (String.toUTF8 message))

  if let some mes ← await (← client.recv? 1024) then
    IO.println s!"received: {String.fromUTF8! mes}"
    -- use parallel to avoid building up an infinite task chain
    parallel (writeLoop client message)
  else
    IO.println "client disconnected from receiving"

-- Server’s main accept loop, keeps accepting and echoing for new clients.
partial def acceptLoop (server : Socket.Server) (promise : IO.Promise Unit) : Async (AsyncTask Unit) := async do
  let client ← await (← server.accept)
  await (← client.send (String.toUTF8 "tutturu "))

  -- allow multiple clients to connect at the same time
  parallel (writeLoop client "hi!!")

  -- and keep accepting more clients, parallel again to avoid building up an infinite task chain
  parallel (acceptLoop server promise)

-- A simple client that connects and sends a message.
def echoClient (addr : SocketAddress) (message : String) : Async (AsyncTask Unit) := async do
  let socket ← Client.mk
  await (← socket.connect addr)
  parallel (writeLoop socket message)

-- TCP setup: bind, listen, serve, and run a sample client.
partial def mainTCP : Async Unit := do
  let addr := SocketAddressV4.mk (.ofParts 127 0 0 1) 8080

  let server ← Server.mk
  server.bind addr
  server.listen 128

  -- promise exists since the server is (probably) never going to stop
  let promise ← IO.Promise.new
  let acceptAction ← acceptLoop server promise

  await (← echoClient addr "hi!")
  await acceptAction
  await promise

-- Entry point
def main : IO Unit := mainTCP.wait
```

---------

Co-authored-by: Henrik Böving <hargonix@gmail.com>
Co-authored-by: Mac Malone <tydeu@hatpress.net>
2025-06-26 02:51:26 +00:00

83 lines
2.3 KiB
Text

import Std.Internal.Async
import Std.Internal.UV
import Std.Net.Addr

open Std.Internal.IO Async
open Std.Net

-- Using this function to create IO Error. For some reason the assert! is not pausing the execution.
def assertBEq [BEq α] [ToString α] (actual expected : α) : IO Unit := do
  unless actual == expected do
    throw <| IO.userError <|
      s!"expected '{expected}', got '{actual}'"

--------------------------------------------------------------

/-- Mike is another client. -/
def runMike (client : TCP.Socket.Client) : Async Unit := do
  let mes ← await (← client.recv? 1024)
  assertBEq (String.fromUTF8? =<< mes) (some "hi mike!! :)")
  await (← client.send (String.toUTF8 "hello robert!!"))
  await (← client.shutdown)

/-- Joe is another client. -/
def runJoe (client : TCP.Socket.Client) : Async Unit := do
  let mes ← await (← client.recv? 1024)
  assertBEq (String.fromUTF8? =<< mes) (some "hi joe! :)")
  await (← client.send (String.toUTF8 "hello robert!"))
  await (← client.shutdown)

/-- Robert is the server. -/
def runRobert (server : TCP.Socket.Server) : Async Unit := do
  let joe ← await (← server.accept)
  let mike ← await (← server.accept)

  await (← joe.send (String.toUTF8 "hi joe! :)"))
  let mes ← await (← joe.recv? 1024)
  assertBEq (String.fromUTF8? =<< mes) (some "hello robert!")

  await (← mike.send (String.toUTF8 "hi mike!! :)"))
  let mes ← await (← mike.recv? 1024)
  assertBEq (String.fromUTF8? =<< mes) (some "hello robert!!")

  pure ()

def clientServer (addr : SocketAddress) : IO Unit := do
  let server ← TCP.Socket.Server.mk
  server.bind addr
  server.listen 128

  let serverTask := runRobert server
  let serverTask ← serverTask.toIO

  assertBEq (← server.getSockName).port addr.port

  let joe ← TCP.Socket.Client.mk
  let task ← joe.connect addr
  task.block

  assertBEq (← joe.getPeerName).port addr.port

  joe.noDelay

  let mike ← TCP.Socket.Client.mk
  let task ← mike.connect addr
  task.block

  assertBEq (← mike.getPeerName).port addr.port

  mike.noDelay

  let joeTask := runJoe joe
  let mikeTask := runMike mike

  let joeTask ← joeTask.toIO
  let mikeTask ← mikeTask.toIO

  serverTask.block
  joeTask.block
  mikeTask.block

#eval clientServer (SocketAddressV4.mk (.ofParts 127 0 0 1) 8084)
#eval clientServer (SocketAddressV6.mk (.ofParts 0 0 0 0 0 0 0 1) 9000)