feat: implement a Selector for async TCP (#8078)

This PR is a follow up to #8055 and implements a `Selector` for async
TCP in order to allow IO multiplexing using TCP sockets.

As we must not commit to actually fetching data from the socket buffer
this cannot be implemented by just racing on `recv?`. Instead we perform
a call to `uv_read_start` and pass an `alloc_cb` that allocates no
memory at all. According to the docs of
[`uv_alloc_cb`](https://docs.libuv.org/en/v1.x/handle.html#c.uv_alloc_cb)
this is guaranteed to give us a `UV_ENOBUFS` in the relevant callback.
Thus we can first run this "zero read" and then go into one of three
cases:
1. We get cancelled before the zero read completes, in this case just
cancel the zero read and give up.
2. The zero read completes and we lose the race for completing the
`select`, in this case just don't do anything anymore.
3. The zero read completes and we win the race for completing the
`select`, in this case we perform the actual read on the socket. As we
know that data is available already (since the read callback of the zero
read is only triggered if data actually is available) we know that the
subsequent actual read should complete right away.

In this way we avoid any data loss if we lose the race.
This commit is contained in:
Henrik Böving 2025-04-24 18:05:35 +02:00 committed by GitHub
parent bc032eec8d
commit 406bda8807
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 206 additions and 3 deletions

View file

@ -5,8 +5,8 @@ Authors: Sofia Rodrigues
-/
prelude
import Std.Time
import Std.Internal.UV
import Std.Internal.Async.Basic
import Std.Internal.UV.TCP
import Std.Internal.Async.Select
import Std.Net.Addr
namespace Std
@ -125,11 +125,46 @@ def send (s : Client) (data : ByteArray) : IO (AsyncTask Unit) :=
Receives data from the client socket. If data is received, it is wrapped in .some. If EOF is reached,
the result is .none, indicating no more data is available. Receiving data in parallel on the same
socket is not supported. Instead, we recommend binding multiple sockets to the same address.
Furthermore calling this function in parallel with `recvSelector` is not supported.
-/
@[inline]
def recv? (s : Client) (size : UInt64) : IO (AsyncTask (Option ByteArray)) := do
  -- Start the native receive and expose its promise as an `AsyncTask`.
  let promise ← s.native.recv? size
  return AsyncTask.ofPromise promise
/--
Builds a `Selector` that becomes ready as soon as `s` is readable and then yields the data that
was read, at most `size` bytes of it. The wait for readability is started by this call itself
(via `waitReadable` on the native socket), so it must not be used in parallel with `recv?`.
-/
def recvSelector (s : TCP.Socket.Client) (size : UInt64) : IO (Selector (Option ByteArray)) := do
  let readableWaiter ← s.native.waitReadable
  return {
    tryFn := do
      let ready ← readableWaiter.isResolved
      if !ready then
        return none
      -- The readability wait already finished, so this read is not expected to block.
      let recvTask ← s.recv? size
      let data? ← recvTask.block
      return some data?
    registerFn waiter := do
      discard <| IO.mapTask (t := readableWaiter.result?) fun outcome? => do
        -- On cancellation the promise is dropped, which shows up here as `none`.
        let some outcome := outcome? | return ()
        let onLose := return ()
        let onWin promise := do
          try
            -- Re-raise an error reported by the readability wait, if any.
            discard <| IO.ofExcept outcome
            -- Data is known to be available, so this read is not expected to block.
            let recvTask ← s.recv? size
            let data? ← recvTask.block
            promise.resolve (.ok data?)
          catch err =>
            promise.resolve (.error err)
        waiter.race onLose onWin
    unregisterFn := s.native.cancelRecv
  }
/--
Shuts down the write side of the client socket.
-/

View file

@ -50,10 +50,30 @@ Receives data from a TCP socket with a maximum size of size bytes. The promise r
available or an error occurs. If data is received, it is wrapped in .some. If EOF is reached, the
result is .none, indicating no more data is available. Receiving data in parallel on the same
socket is not supported. Instead, we recommend binding multiple sockets to the same address.
Furthermore calling this function in parallel with `waitReadable` is not supported.
-/
@[extern "lean_uv_tcp_recv"]
opaque recv? (socket : @& Socket) (size : UInt64) : IO (IO.Promise (Except IO.Error (Option ByteArray)))
/--
Returns an `IO.Promise` that resolves to `true` once `socket` has data available for reading,
or to `false` if `socket` is closed (end of file is reached) before that. Any other failure
reported while waiting resolves the promise to an error instead.

Calling this function twice on the same `Socket` or in parallel with `recv?` is not supported.
-/
@[extern "lean_uv_tcp_wait_readable"]
opaque waitReadable (socket : @& Socket) : IO (IO.Promise (Except IO.Error Bool))
/--
Cancels a receive operation started by `recv?` or `waitReadable`, if one is currently pending.
This resolves their returned `IO.Promise` to `none`.

This function is considered dangerous, as improper use can cause data loss, and is therefore not
exposed to the top-level API.

Note that this function is idempotent: it can be called multiple times on the same socket without
causing errors, in particular also when no receive is running in the first place.
-/
@[extern "lean_uv_tcp_cancel_recv"]
opaque cancelRecv (socket : @& Socket) : IO Unit
/--
Binds a TCP socket to a specific address.
-/

View file

@ -227,7 +227,7 @@ extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_recv(b_obj_arg socket, uint64_t
// Locking early prevents potential parallelism issues setting the byte_array.
event_loop_lock(&global_ev);
if (tcp_socket->m_byte_array != nullptr) {
if (tcp_socket->m_promise_read != nullptr) {
event_loop_unlock(&global_ev);
return lean_io_result_mk_error(lean_decode_uv_error(UV_EALREADY, nullptr));
}
@ -295,6 +295,102 @@ extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_recv(b_obj_arg socket, uint64_t
return lean_io_result_mk_ok(promise);
}
/* Std.Internal.UV.TCP.Socket.waitReadable (socket : @& Socket) : IO (IO.Promise (Except IO.Error Bool)) */
extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_wait_readable(b_obj_arg socket, obj_arg /* w */) {
    lean_uv_tcp_socket_object* sock = lean_to_uv_tcp_socket(socket);

    event_loop_lock(&global_ev);

    // A read-style operation is already pending on this socket: refuse to start another one.
    if (sock->m_promise_read != nullptr) {
        event_loop_unlock(&global_ev);
        return lean_io_result_mk_error(lean_decode_uv_error(UV_EALREADY, nullptr));
    }

    lean_object* readable_promise = lean_promise_new();
    mark_mt(readable_promise);
    sock->m_promise_read = readable_promise;

    // While the read is pending the event loop owns the socket, and the promise is referenced
    // both by the socket structure and by the value returned to the caller.
    lean_inc(socket);
    lean_inc(readable_promise);

    // Hand libuv an empty buffer. Per the libuv documentation no data is lost this way and the
    // read callback is invoked with UV_ENOBUFS.
    uv_alloc_cb provide_no_buffer = [](uv_handle_t* /* handle */, size_t /* suggested_size */, uv_buf_t* buf) {
        buf->base = NULL;
        buf->len = 0;
    };

    uv_read_cb on_read_event = [](uv_stream_t* stream, ssize_t nread, const uv_buf_t* /* buf */) {
        uv_read_stop(stream);

        lean_object* socket_obj = (lean_object*)stream->data;
        lean_uv_tcp_socket_object* sock = lean_to_uv_tcp_socket(socket_obj);
        lean_object* pending = sock->m_promise_read;
        sock->m_promise_read = nullptr;

        if (nread >= 0) {
            // This branch should be dead, we cannot receive a value >= 0 according to docs.
            lean_always_assert(false);
        } else if (nread == UV_ENOBUFS) {
            // Data is waiting in the socket: report "readable".
            lean_promise_resolve(mk_except_ok(lean_box(1)), pending);
        } else if (nread == UV_EOF) {
            // End of stream before any data showed up.
            lean_promise_resolve(mk_except_ok(lean_box(0)), pending);
        } else {
            lean_promise_resolve(mk_except_err(lean_decode_uv_error(nread, nullptr)), pending);
        }

        // Release the reference held through the socket structure.
        lean_dec(pending);

        // The event loop does not own the object anymore.
        lean_dec(socket_obj);
    };

    int status = uv_read_start((uv_stream_t*)sock->m_uv_tcp, provide_no_buffer, on_read_event);

    if (status >= 0) {
        event_loop_unlock(&global_ev);
        return lean_io_result_mk_ok(readable_promise);
    }

    // Starting the read failed: undo all of the bookkeeping done above.
    sock->m_promise_read = nullptr;
    event_loop_unlock(&global_ev);

    lean_dec(readable_promise); // The structure does not own it.
    lean_dec(readable_promise); // We are not going to return it.
    lean_dec(socket);

    return lean_io_result_mk_error(lean_decode_uv_error(status, nullptr));
}
/* Std.Internal.UV.TCP.Socket.cancelRecv (socket : @& Socket) : IO Unit */
// Cancels a pending `recv?`/`waitReadable` on `socket`, if there is one. Stops reading on the
// underlying libuv stream, releases the references held by the socket structure (the pending read
// promise and, if present, the receive buffer) and gives back the socket reference that was taken
// for the event loop when the read was started. Calling it with no read pending is a no-op.
extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_cancel_recv(b_obj_arg socket, obj_arg /* w */) {
    lean_uv_tcp_socket_object* tcp_socket = lean_to_uv_tcp_socket(socket);

    event_loop_lock(&global_ev);

    // Nothing is pending, so there is nothing to cancel.
    if (tcp_socket->m_promise_read == nullptr) {
        event_loop_unlock(&global_ev);
        return lean_io_result_mk_ok(lean_box(0));
    }

    // Make sure the read callback is not invoked for this read anymore.
    uv_read_stop((uv_stream_t*)tcp_socket->m_uv_tcp);

    // Release the structure's reference to the promise without resolving it.
    lean_object* promise = tcp_socket->m_promise_read;
    lean_dec(promise);
    tcp_socket->m_promise_read = nullptr;

    lean_object* byte_array = tcp_socket->m_byte_array;
    if (byte_array != nullptr) {
        lean_dec(byte_array);
        tcp_socket->m_byte_array = nullptr;
    }

    // The event loop does not own the object anymore. The reference was taken on the Lean object
    // (`lean_inc(socket)` when the read was started), so it must be released on `socket` as well:
    // `tcp_socket` is the external data behind that object, not a `lean_object` itself.
    lean_dec(socket);

    event_loop_unlock(&global_ev);

    return lean_io_result_mk_ok(lean_box(0));
}
/* Std.Internal.UV.TCP.Socket.bind (socket : @& Socket) (addr : @& SocketAddress) : IO Unit */
extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_bind(b_obj_arg socket, b_obj_arg addr, obj_arg /* w */) {
lean_uv_tcp_socket_object* tcp_socket = lean_to_uv_tcp_socket(socket);

View file

@ -42,6 +42,8 @@ extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_new(obj_arg /* w */);
extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_connect(b_obj_arg socket, b_obj_arg addr, obj_arg /* w */);
extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_send(b_obj_arg socket, obj_arg data, obj_arg /* w */);
extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_recv(b_obj_arg socket, uint64_t buffer_size, obj_arg /* w */);
extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_wait_readable(b_obj_arg socket, obj_arg /* w */);
extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_cancel_recv(b_obj_arg socket, obj_arg /* w */);
extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_bind(b_obj_arg socket, b_obj_arg addr, obj_arg /* w */);
extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_listen(b_obj_arg socket, int32_t backlog, obj_arg /* w */);
extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_accept(b_obj_arg socket, obj_arg /* w */);

View file

@ -0,0 +1,50 @@
import Std.Internal.Async.Timer
import Std.Internal.Async.TCP
open Std Internal IO Async
/--
Connects a fresh client to `addr` and then races a `Selector.sleep 1000` timeout against a
`recvSelector` on the connection, reporting which one fired as a `String`.
-/
def testClient (addr : Net.SocketAddress) : IO (AsyncTask String) := do
  let client ← TCP.Socket.Client.mk
  let connected ← client.connect addr
  connected.bindIO fun _ => do
    let timeout ← Selector.sleep 1000
    let incoming ← client.recvSelector 4096
    Selectable.one #[
      .case timeout fun _ => return AsyncTask.pure "Timeout",
      .case incoming fun data? => do
        match data? with
        | some data => return AsyncTask.pure <| String.fromUTF8! data
        | none => return AsyncTask.pure "Closed"
    ]
/--
Runs one scenario: starts a server on `addr`, lets `serverFn` drive the server side, runs
`testClient` against it, waits for the server task and finally prints the client's result.
-/
def test (serverFn : TCP.Socket.Server → IO (AsyncTask Unit)) (addr : Net.SocketAddress) :
    IO Unit := do
  let listener ← TCP.Socket.Server.mk
  listener.bind addr
  listener.listen 1
  let serverSide ← serverFn listener
  let clientSide ← testClient addr
  serverSide.block
  let outcome ← clientSide.block
  IO.println outcome
/-- Server side that accepts one connection and sends `"Success"` over it. -/
def testServerSend (server : TCP.Socket.Server) : IO (AsyncTask Unit) := do
  let accepted ← server.accept
  accepted.bindIO fun conn =>
    conn.send (String.toUTF8 "Success")
/-- Server side that accepts one connection, sleeps for `1500` and only then shuts it down. -/
def testServerTimeout (server : TCP.Socket.Server) : IO (AsyncTask Unit) := do
  let accepted ← server.accept
  accepted.bindIO fun conn => do
    let delay ← Async.sleep 1500
    delay.bindIO fun _ =>
      conn.shutdown
/-- Server side that accepts one connection and shuts it down right away. -/
def testServerClose (server : TCP.Socket.Server) : IO (AsyncTask Unit) := do
  let accepted ← server.accept
  accepted.bindIO fun conn =>
    conn.shutdown
-- The server sends "Success" after accepting, so the client is expected to print the received text.
/-- info: Success -/
#guard_msgs in
#eval test testServerSend (Net.SocketAddressV4.mk (.ofParts 127 0 0 1) 7070)
-- The server shuts the connection down without sending, so the client is expected to see no data.
/-- info: Closed -/
#guard_msgs in
#eval test testServerClose (Net.SocketAddressV4.mk (.ofParts 127 0 0 1) 7071)
-- The server sleeps for 1500 before shutting down while the client's timeout is 1000 (presumably
-- both in milliseconds), so the client's timeout case is expected to fire first.
/-- info: Timeout -/
#guard_msgs in
#eval test testServerTimeout (Net.SocketAddressV4.mk (.ofParts 127 0 0 1) 7072)