From 406bda8807e6015bfa48806bc5fae3e6ff6136a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20B=C3=B6ving?= Date: Thu, 24 Apr 2025 18:05:35 +0200 Subject: [PATCH] feat: implement a Selector for async TCP (#8078) This PR is a follow-up to #8055 and implements a `Selector` for async TCP in order to allow IO multiplexing using TCP sockets. As we must not commit to actually fetching data from the socket buffer, this cannot be implemented by just racing on `recv?`. Instead we perform a call to `uv_read_start` and pass an `alloc_cb` that allocates no memory at all. According to the docs of [`uv_alloc_cb`](https://docs.libuv.org/en/v1.x/handle.html#c.uv_alloc_cb) this is guaranteed to give us a `UV_ENOBUFS` in the relevant callback. Thus we can first run this "zero read" and then go into one of three cases: 1. We get cancelled before the zero read completes, in this case just cancel the zero read and give up. 2. The zero read completes and we lose the race for completing the `select`, in this case just don't do anything anymore. 3. The zero read completes and we win the race for completing the `select`, in this case we perform the actual read on the socket. As we know that data is available already (since the read callback of the zero read is only triggered if data actually is available), we know that the subsequent actual read should complete right away. In this way we avoid any data loss if we lose the race.
--- src/Std/Internal/Async/TCP.lean | 39 +++++++++- src/Std/Internal/UV/TCP.lean | 20 +++++ src/runtime/uv/tcp.cpp | 98 ++++++++++++++++++++++++- src/runtime/uv/tcp.h | 2 + tests/lean/run/async_select_socket.lean | 50 +++++++++++++ 5 files changed, 206 insertions(+), 3 deletions(-) create mode 100644 tests/lean/run/async_select_socket.lean diff --git a/src/Std/Internal/Async/TCP.lean b/src/Std/Internal/Async/TCP.lean index bcf75bbf55..56116079d5 100644 --- a/src/Std/Internal/Async/TCP.lean +++ b/src/Std/Internal/Async/TCP.lean @@ -5,8 +5,8 @@ Authors: Sofia Rodrigues -/ prelude import Std.Time -import Std.Internal.UV -import Std.Internal.Async.Basic +import Std.Internal.UV.TCP +import Std.Internal.Async.Select import Std.Net.Addr namespace Std @@ -125,11 +125,46 @@ def send (s : Client) (data : ByteArray) : IO (AsyncTask Unit) := Receives data from the client socket. If data is received, it’s wrapped in .some. If EOF is reached, the result is .none, indicating no more data is available. Receiving data in parallel on the same socket is not supported. Instead, we recommend binding multiple sockets to the same address. +Furthermore calling this function in parallel with `recvSelector` is not supported. -/ @[inline] def recv? (s : Client) (size : UInt64) : IO (AsyncTask (Option ByteArray)) := AsyncTask.ofPromise <$> s.native.recv? size +/-- +Creates a `Selector` that resolves once `s` has data available, up to at most `size` bytes, +and provides that data. Calling this function starts the data wait, so it must not be called +in parallel with `recv?`. +-/ +def recvSelector (s : TCP.Socket.Client) (size : UInt64) : IO (Selector (Option ByteArray)) := do + let readableWaiter ← s.native.waitReadable + return { + tryFn := do + if ← readableWaiter.isResolved then + -- We know that this read should not block + let res ← (← s.recv? 
size).block + return some res + else + return none + registerFn waiter := do + -- If we get cancelled the promise will be dropped so prepare for that + discard <| IO.mapTask (t := readableWaiter.result?) fun res => do + match res with + | none => return () + | some res => + let lose := return () + let win promise := do + try + discard <| IO.ofExcept res + -- We know that this read should not block + let res ← (← s.recv? size).block + promise.resolve (.ok res) + catch e => + promise.resolve (.error e) + waiter.race lose win + unregisterFn := s.native.cancelRecv + } + /-- Shuts down the write side of the client socket. -/ diff --git a/src/Std/Internal/UV/TCP.lean b/src/Std/Internal/UV/TCP.lean index caa8dde8e1..9477ba3785 100644 --- a/src/Std/Internal/UV/TCP.lean +++ b/src/Std/Internal/UV/TCP.lean @@ -50,10 +50,30 @@ Receives data from a TCP socket with a maximum size of size bytes. The promise r available or an error occurs. If data is received, it’s wrapped in .some. If EOF is reached, the result is .none, indicating no more data is available. Receiving data in parallel on the same socket is not supported. Instead, we recommend binding multiple sockets to the same address. +Furthermore calling this function in parallel with `waitReadable` is not supported. -/ @[extern "lean_uv_tcp_recv"] opaque recv? (socket : @& Socket) (size : UInt64) : IO (IO.Promise (Except IO.Error (Option ByteArray))) +/-- +Returns an `IO.Promise` that resolves to `true` once `socket` has data available for reading, +or to `false` if `socket` is closed before that. Calling this function twice on the same `Socket` +or in parallel with `recv?` is not supported. +-/ +@[extern "lean_uv_tcp_wait_readable"] +opaque waitReadable (socket : @& Socket) : IO (IO.Promise (Except IO.Error Bool)) + +/-- +Cancels a receive operation in the form of `recv?` or `waitReadable` if there is currently one +pending. This resolves their returned `IO.Promise` to `none`. 
This function is considered dangerous, +as improper use can cause data loss, and is therefore not exposed to the top-level API. + +Note that this function is idempotent and as such can be called multiple times on the same socket +without causing errors, in particular also without a receive running in the first place. +-/ +@[extern "lean_uv_tcp_cancel_recv"] +opaque cancelRecv (socket : @& Socket) : IO Unit + /-- Binds a TCP socket to a specific address. -/ diff --git a/src/runtime/uv/tcp.cpp b/src/runtime/uv/tcp.cpp index 459375da45..30f76609cb 100644 --- a/src/runtime/uv/tcp.cpp +++ b/src/runtime/uv/tcp.cpp @@ -227,7 +227,7 @@ extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_recv(b_obj_arg socket, uint64_t // Locking early prevents potential parallelism issues setting the byte_array. event_loop_lock(&global_ev); - if (tcp_socket->m_byte_array != nullptr) { + if (tcp_socket->m_promise_read != nullptr) { event_loop_unlock(&global_ev); return lean_io_result_mk_error(lean_decode_uv_error(UV_EALREADY, nullptr)); } @@ -295,6 +295,102 @@ extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_recv(b_obj_arg socket, uint64_t return lean_io_result_mk_ok(promise); } +/* Std.Internal.UV.TCP.Socket.waitReadable (socket : @& Socket) : IO (IO.Promise (Except IO.Error Bool)) */ +extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_wait_readable(b_obj_arg socket, obj_arg /* w */) { + lean_uv_tcp_socket_object* tcp_socket = lean_to_uv_tcp_socket(socket); + + event_loop_lock(&global_ev); + + if (tcp_socket->m_promise_read != nullptr) { + event_loop_unlock(&global_ev); + return lean_io_result_mk_error(lean_decode_uv_error(UV_EALREADY, nullptr)); + } + + lean_object* promise = lean_promise_new(); + mark_mt(promise); + + tcp_socket->m_promise_read = promise; + + // The event loop owns the socket. 
+ lean_inc(socket); + lean_inc(promise); + + int result = uv_read_start((uv_stream_t*)tcp_socket->m_uv_tcp, [](uv_handle_t* handle, size_t suggested_size, uv_buf_t* buf) { + // According to libuv documentation if we do this we do not lose data and a UV_ENOBUFS will + // be triggered in the read cb. + buf->base = NULL; + buf->len = 0; + }, [](uv_stream_t* stream, ssize_t nread, const uv_buf_t* buf) { + uv_read_stop(stream); + + lean_uv_tcp_socket_object* tcp_socket = lean_to_uv_tcp_socket((lean_object*)stream->data); + lean_object* promise = tcp_socket->m_promise_read; + + tcp_socket->m_promise_read = nullptr; + + if (nread == UV_ENOBUFS) { + lean_promise_resolve(mk_except_ok(lean_box(1)), promise); + } else if (nread == UV_EOF) { + lean_promise_resolve(mk_except_ok(lean_box(0)), promise); + } else if (nread < 0) { + lean_promise_resolve(mk_except_err(lean_decode_uv_error(nread, nullptr)), promise); + } else { + // This branch should be dead, we cannot receive a value >= 0 according to docs. + lean_always_assert(false); + } + + lean_dec(promise); + + // The event loop does not own the object anymore. + lean_dec((lean_object*)stream->data); + }); + + if (result < 0) { + tcp_socket->m_promise_read = nullptr; + + event_loop_unlock(&global_ev); + + lean_dec(promise); // The structure does not own it. + lean_dec(promise); // We are not going to return it. 
+ lean_dec(socket); + + return lean_io_result_mk_error(lean_decode_uv_error(result, nullptr)); + } + + event_loop_unlock(&global_ev); + + return lean_io_result_mk_ok(promise); +} + +/* Std.Internal.UV.TCP.Socket.cancelRecv (socket : @& Socket) : IO Unit */ +extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_cancel_recv(b_obj_arg socket, obj_arg /* w */) { + lean_uv_tcp_socket_object* tcp_socket = lean_to_uv_tcp_socket(socket); + + event_loop_lock(&global_ev); + + if (tcp_socket->m_promise_read == nullptr) { /* no read pending: nothing to cancel, succeed as a no-op */ + event_loop_unlock(&global_ev); + return lean_io_result_mk_ok(lean_box(0)); + } + + uv_read_stop((uv_stream_t*)tcp_socket->m_uv_tcp); + + lean_object* promise = tcp_socket->m_promise_read; + lean_dec(promise); /* the pending promise is dropped here without being resolved */ + tcp_socket->m_promise_read = nullptr; + + lean_object* byte_array = tcp_socket->m_byte_array; + if (byte_array != nullptr) { + lean_dec(byte_array); + tcp_socket->m_byte_array = nullptr; + } + + lean_dec(socket); /* release the reference taken via lean_inc(socket) when the read was started; must be the Lean object, not the tcp_socket struct pointer obtained from it */ + + event_loop_unlock(&global_ev); + return lean_io_result_mk_ok(lean_box(0)); +} + /* Std.Internal.UV.TCP.Socket.bind (socket : @& Socket) (addr : @& SocketAddress) : IO Unit */ extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_bind(b_obj_arg socket, b_obj_arg addr, obj_arg /* w */) { lean_uv_tcp_socket_object* tcp_socket = lean_to_uv_tcp_socket(socket); diff --git a/src/runtime/uv/tcp.h b/src/runtime/uv/tcp.h index 1fe176d62d..cc2bcee2df 100644 --- a/src/runtime/uv/tcp.h +++ b/src/runtime/uv/tcp.h @@ -42,6 +42,8 @@ extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_new(obj_arg /* w */); extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_connect(b_obj_arg socket, b_obj_arg addr, obj_arg /* w */); extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_send(b_obj_arg socket, obj_arg data, obj_arg /* w */); extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_recv(b_obj_arg socket, uint64_t buffer_size, obj_arg /* w */); +extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_wait_readable(b_obj_arg socket, obj_arg /* w */); +extern "C" LEAN_EXPORT
lean_obj_res lean_uv_tcp_cancel_recv(b_obj_arg socket, obj_arg /* w */); extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_bind(b_obj_arg socket, b_obj_arg addr, obj_arg /* w */); extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_listen(b_obj_arg socket, int32_t backlog, obj_arg /* w */); extern "C" LEAN_EXPORT lean_obj_res lean_uv_tcp_accept(b_obj_arg socket, obj_arg /* w */); diff --git a/tests/lean/run/async_select_socket.lean b/tests/lean/run/async_select_socket.lean new file mode 100644 index 0000000000..1695ff11d2 --- /dev/null +++ b/tests/lean/run/async_select_socket.lean @@ -0,0 +1,50 @@ +import Std.Internal.Async.Timer +import Std.Internal.Async.TCP + +open Std Internal IO Async + +def testClient (addr : Net.SocketAddress) : IO (AsyncTask String) := do + let client ← TCP.Socket.Client.mk + (← client.connect addr).bindIO fun _ => do + Selectable.one #[ + .case (← Selector.sleep 1000) fun _ => return AsyncTask.pure "Timeout", + .case (← client.recvSelector 4096) fun data? => do + if let some data := data? then + return AsyncTask.pure <| String.fromUTF8! 
data + else + return AsyncTask.pure "Closed" + ] + +def test (serverFn : TCP.Socket.Server → IO (AsyncTask Unit)) (addr : Net.SocketAddress) : + IO Unit := do + let server ← TCP.Socket.Server.mk + server.bind addr + server.listen 1 + let serverTask ← serverFn server + let clientTask ← testClient addr + serverTask.block + IO.println (← clientTask.block) + +def testServerSend (server : TCP.Socket.Server) : IO (AsyncTask Unit) := do + (← server.accept).bindIO fun client => do + client.send (String.toUTF8 "Success") + +def testServerTimeout (server : TCP.Socket.Server) : IO (AsyncTask Unit) := do + (← server.accept).bindIO fun client => do + (← Async.sleep 1500).bindIO fun _ => do + client.shutdown + +def testServerClose (server : TCP.Socket.Server) : IO (AsyncTask Unit) := do + (← server.accept).bindIO fun client => client.shutdown + +/-- info: Success -/ +#guard_msgs in +#eval test testServerSend (Net.SocketAddressV4.mk (.ofParts 127 0 0 1) 7070) + +/-- info: Closed -/ +#guard_msgs in +#eval test testServerClose (Net.SocketAddressV4.mk (.ofParts 127 0 0 1) 7071) + +/-- info: Timeout -/ +#guard_msgs in +#eval test testServerTimeout (Net.SocketAddressV4.mk (.ofParts 127 0 0 1) 7072)