From eaa5d3498ca50c9ef5c2db3e74dd9ef73cdcbe67 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20B=C3=B6ving?= Date: Tue, 29 Apr 2025 17:15:38 +0200 Subject: [PATCH] feat: implement a Selector for channels (#8150) This PR is a follow up to #8055 and implements a Selector for `Std.Channel` in order to allow multiplexing using channels. There is one subtlety to the implementation: Suppose we are in a situation where we run `select` in a loop on two channels. One of the channels is always quiet while the other has data available occasionally (however not always as this would trigger the `tryFn` fast path and hide the issue). In this situation the select receivers that are enqueued on the silent channel would usually just remain there indefinitely as nothing ever happens, causing a memleak. To avoid this we want to make a channel select clean up after itself, even if it fails. In an imperative programming language we could implement the receive queue as a doubly linked list and simply make each receive select maintain a pointer to its element in the queue and then remove itself in `O(1)` upon failure. As that is not possible in Lean trivially we decided to go for another approach for now: simply filter the queue for selects that have failed in `unregisterFn`. While this approach is `O(n)` we expect the amount of receivers enqueued on a channel to not be terribly large and thus this to be a reasonably fast operation compared to the remaining overhead. If it ever ends up becoming an issue, we could switch to an approach that uses a `TreeMap` with numbered receivers instead at a certain wait queue size and go to `O(log(n))`. --- src/Init/Data/Queue.lean | 17 +- src/Std/Internal/Async/Select.lean | 9 + src/Std/Sync/Channel.lean | 267 ++++++++++++++++++++--- tests/lean/run/async_select_channel.lean | 99 +++++++++ 4 files changed, 362 insertions(+), 30 deletions(-) create mode 100644 tests/lean/run/async_select_channel.lean diff --git a/src/Init/Data/Queue.lean b/src/Init/Data/Queue.lean index 70d752e9f9..bd209c6ccc 100644 --- a/src/Init/Data/Queue.lean +++ b/src/Init/Data/Queue.lean @@ -9,7 +9,7 @@ Note: this is only a temporary placeholder. module prelude -import Init.Data.Array.Basic +import Init.Data.List.Control namespace Std @@ -62,3 +62,18 @@ def dequeue? (q : Queue α) : Option (α × Queue α) := def toArray (q : Queue α) : Array α := q.dList.toArray ++ q.eList.toArray.reverse + +/-- +`O(n)`. Applies the monadic predicate `p` to every element in the queue, and returns the queue +of elements for which `p` returns `true`. Note that there are currently no guarantees for the order +that `p` is applied in. +-/ +@[specialize] +def filterM {m : Type → Type v} [Monad m] {α : Type} (p : α → m Bool) (q : Queue α) : + m (Queue α) := do + let dList ← q.dList.filterM p + let eList ← q.eList.filterM p + if dList.isEmpty then + return { dList := eList.reverse, eList := [] } + else + return { dList, eList } diff --git a/src/Std/Internal/Async/Select.lean b/src/Std/Internal/Async/Select.lean index dcfaad7188..5fc5484441 100644 --- a/src/Std/Internal/Async/Select.lean +++ b/src/Std/Internal/Async/Select.lean @@ -32,6 +32,7 @@ structure Waiter (α : Type) where Swap out the `IO.Promise` within the `Waiter`. Note that the part which determines whether the `Waiter` is finished is not swapped out. -/ +@[inline] def Waiter.withPromise (w : Waiter α) (p : IO.Promise (Except IO.Error β)) : Waiter β := Waiter.mk w.finished p @@ -49,6 +50,14 @@ def Waiter.race [Monad m] [MonadLiftT (ST IO.RealWorld) m] (w : Waiter α) else lose +/-- +Atomically checks whether the `Waiter` has already finished. Note that right after this function +call ends this might have already changed. +-/ +@[inline] +def Waiter.checkFinished [Monad m] [MonadLiftT (ST IO.RealWorld) m] (w : Waiter α) : m Bool := do + w.finished.get + /-- An event source that can be multiplexed using `Selectable.one`, see the documentation of `Selectable.one` for how the protocol of communicating with a `Selector` works. diff --git a/src/Std/Sync/Channel.lean b/src/Std/Sync/Channel.lean index f10a433a40..b34c8ef60d 100644 --- a/src/Std/Sync/Channel.lean +++ b/src/Std/Sync/Channel.lean @@ -7,6 +7,7 @@ prelude import Init.System.Promise import Init.Data.Queue import Std.Sync.Mutex +import Std.Internal.Async.Select /-! This module contains the implementation of `Std.Channel`. `Std.Channel` is a multi-producer @@ -44,6 +45,23 @@ instance : ToString Error where instance : MonadLift (EIO Error) IO where monadLift x := EIO.toIO (.userError <| toString ·) x +open Internal.IO.Async in +private inductive Consumer (α : Type) where + | normal (promise : IO.Promise (Option α)) + | select (finished : Waiter (Option α)) + +private def Consumer.resolve (c : Consumer α) (x : Option α) : BaseIO Bool := do + match c with + | .normal promise => + promise.resolve x + return true + | .select waiter => + let lose := return false + let win promise := do + promise.resolve (.ok x) + return true + waiter.race lose win + /-- The central state structure for an unbounded channel, maintains the following invariants: 1. `values = ∅ ∨ consumers = ∅` @@ -55,10 +73,10 @@ private structure Unbounded.State (α : Type) where -/ values : Std.Queue α /-- - Consumers that are blocked on a producer providing them a value. The `IO.Promise` will be - resolved to `none` if the channel closes. + Consumers that are blocked on a producer providing them a value. They will be resolved to `none` + if the channel closes. -/ - consumers : Std.Queue (IO.Promise (Option α)) + consumers : Std.Queue (Consumer α) /-- Whether the channel is closed already. -/ @@ -85,12 +103,18 @@ private def trySend (ch : Unbounded α) (v : α) : BaseIO Bool := do let st ← get if st.closed then return false - else if let some (consumer, consumers) := st.consumers.dequeue? then - consumer.resolve (some v) - set { st with consumers } - return true else - set { st with values := st.values.enqueue v } + while true do + let st ← get + if let some (consumer, consumers) := st.consumers.dequeue? then + let success ← consumer.resolve (some v) + set { st with consumers } + if success then + break + else + set { st with values := st.values.enqueue v } + break + return true private def send (ch : Unbounded α) (v : α) : BaseIO (Task (Except Error Unit)) := do @@ -103,7 +127,8 @@ private def close (ch : Unbounded α) : EIO Error Unit := do ch.state.atomically do let st ← get if st.closed then throw .alreadyClosed - for consumer in st.consumers.toArray do consumer.resolve none + for consumer in st.consumers.toArray do + discard <| consumer.resolve none set { st with consumers := ∅, closed := true } return () @@ -111,7 +136,8 @@ private def isClosed (ch : Unbounded α) : BaseIO Bool := ch.state.atomically do return (← get).closed -private def tryRecv' : AtomicT (Unbounded.State α) BaseIO (Option α) := do +private def tryRecv' [Monad m] [MonadLiftT (ST IO.RealWorld) m] : + AtomicT (Unbounded.State α) m (Option α) := do let st ← get if let some (a, values) := st.values.dequeue? then set { st with values } @@ -131,9 +157,47 @@ private def recv (ch : Unbounded α) : BaseIO (Task (Option α)) := do return .pure none else let promise ← IO.Promise.new - modify fun st => { st with consumers := st.consumers.enqueue promise } + modify fun st => { st with consumers := st.consumers.enqueue (.normal promise) } return promise.result?.map (sync := true) (·.bind id) +@[inline] +private def recvReady' [Monad m] [MonadLiftT (ST IO.RealWorld) m] : + AtomicT (Unbounded.State α) m Bool := do + let st ← get + return !st.values.isEmpty || st.closed + +open Internal.IO.Async in +private def recvSelector (ch : Unbounded α) : Selector (Option α) where + tryFn := do + ch.state.atomically do + if ← recvReady' then + let val ← tryRecv' + return some val + else + return none + + registerFn waiter := do + ch.state.atomically do + -- We did drop the lock between `tryFn` and now so maybe ready? + if ← recvReady' then + let lose := return () + let win promise := do + -- We know we are ready so the value by this is fine + promise.resolve (.ok (← tryRecv')) + + waiter.race lose win + else + modify fun st => { st with consumers := st.consumers.enqueue (.select waiter) } + + unregisterFn := do + ch.state.atomically do + let st ← get + let consumers ← st.consumers.filterM + fun + | .normal .. => return true + | .select waiter => return !(← waiter.checkFinished) + set { st with consumers } + end Unbounded /-- @@ -147,10 +211,10 @@ private structure Zero.State (α : Type) where -/ producers : Std.Queue (α × IO.Promise Bool) /-- - Consumers that are blocked on a producer providing them a value. The `IO.Promise` will be resolved - to `none` if the channel closes. + Consumers that are blocked on a producer providing them a value. They will be resolved to `none` + if the channel closes. -/ - consumers : Std.Queue (IO.Promise (Option α)) + consumers : Std.Queue (Consumer α) /-- Whether the channel is closed already. -/ @@ -174,13 +238,17 @@ private def new : BaseIO (Zero α) := do Precondition: The channel must not be closed. -/ private def trySend' (v : α) : AtomicT (Zero.State α) BaseIO Bool := do - let st ← get - if let some (consumer, consumers) := st.consumers.dequeue? then - consumer.resolve (some v) - set { st with consumers } - return true - else - return false + while true do + let st ← get + if let some (consumer, consumers) := st.consumers.dequeue? then + let success ← consumer.resolve (some v) + set { st with consumers } + if success then + break + else + return false + + return true private def trySend (ch : Zero α) (v : α) : BaseIO Bool := do ch.state.atomically do @@ -207,7 +275,8 @@ private def close (ch : Zero α) : EIO Error Unit := do ch.state.atomically do let st ← get if st.closed then throw .alreadyClosed - for consumer in st.consumers.toArray do consumer.resolve none + for consumer in st.consumers.toArray do + discard <| consumer.resolve none set { st with consumers := ∅, closed := true } return () @@ -215,7 +284,8 @@ private def isClosed (ch : Zero α) : BaseIO Bool := ch.state.atomically do return (← get).closed -private def tryRecv' : AtomicT (Zero.State α) BaseIO (Option α) := do +private def tryRecv' [Monad m] [MonadLiftT (ST IO.RealWorld) m] [MonadLiftT BaseIO m] : + AtomicT (Zero.State α) m (Option α) := do let st ← get if let some ((val, promise), producers) := st.producers.dequeue? then set { st with producers } @@ -235,13 +305,59 @@ private def recv (ch : Zero α) : BaseIO (Task (Option α)) := do return .pure <| some val else if !st.closed then let promise ← IO.Promise.new - set { st with consumers := st.consumers.enqueue promise } + set { st with consumers := st.consumers.enqueue (.normal promise) } return promise.result?.map (sync := true) (·.bind id) else return .pure <| none +@[inline] +private def recvReady' [Monad m] [MonadLiftT (ST IO.RealWorld) m] : + AtomicT (Zero.State α) m Bool := do + let st ← get + return !st.producers.isEmpty || st.closed + +open Internal.IO.Async in +private def recvSelector (ch : Zero α) : Selector (Option α) where + tryFn := do + ch.state.atomically do + if ← recvReady' then + let val ← tryRecv' + return some val + else + return none + + registerFn waiter := do + ch.state.atomically do + -- We did drop the lock between `tryFn` and now so maybe ready? + if ← recvReady' then + let lose := return () + let win promise := do + -- We know we are ready so the value by this is fine + promise.resolve (.ok (← tryRecv')) + + waiter.race lose win + else + modify fun st => { st with consumers := st.consumers.enqueue (.select waiter) } + + unregisterFn := do + ch.state.atomically do + let st ← get + let consumers ← st.consumers.filterM + fun + | .normal .. => pure true + | .select waiter => return !(← waiter.checkFinished) + set { st with consumers } + end Zero +open Internal.IO.Async in +private structure Bounded.Consumer (α : Type) where + promise : IO.Promise Bool + waiter : Option (Waiter (Option α)) + +private def Bounded.Consumer.resolve (c : Bounded.Consumer α) (b : Bool) : BaseIO Unit := + c.promise.resolve b + /-- The central state structure for a bounded channel, maintains the following invariants: 1. `0 < capacity` @@ -265,10 +381,9 @@ private structure Bounded.State (α : Type) where producers : Std.Queue (IO.Promise Bool) /-- Consumers that are blocked on a producer providing them a value, as there was no value - enqueued when they tried to dequeue. The `IO.Promise` will be resolved to `false` if the channel - closes. + enqueued when they tried to dequeue. They will be resolved to `false` if the channel closes. -/ - consumers : Std.Queue (IO.Promise Bool) + consumers : Std.Queue (Bounded.Consumer α) /-- The capacity of the buffer space. -/ @@ -390,7 +505,8 @@ private def isClosed (ch : Bounded α) : BaseIO Bool := ch.state.atomically do return (← get).closed -private def tryRecv' : AtomicT (Bounded.State α) BaseIO (Option α) := do +private def tryRecv' [Monad m] [MonadLiftT (ST IO.RealWorld) m] [MonadLiftT BaseIO m] : + AtomicT (Bounded.State α) m (Option α) := do let mut st ← get if st.bufCount == 0 then return none @@ -423,13 +539,71 @@ private partial def recv (ch : Bounded α) : BaseIO (Task (Option α)) := do return .pure none else let promise ← IO.Promise.new - modify fun st => { st with consumers := st.consumers.enqueue promise } + modify fun st => { st with consumers := st.consumers.enqueue ⟨promise, none⟩ } BaseIO.bindTask promise.result? fun res => do if res.getD false then Bounded.recv ch else return .pure none +@[inline] +private def recvReady' [Monad m] [MonadLiftT (ST IO.RealWorld) m] : + AtomicT (Bounded.State α) m Bool := do + let st ← get + return st.bufCount != 0 || st.closed + +open Internal.IO.Async in +private partial def recvSelector (ch : Bounded α) : Selector (Option α) where + tryFn := do + ch.state.atomically do + if ← recvReady' then + let val ← tryRecv' + return some val + else + return none + + registerFn := registerAux ch + + unregisterFn := do + ch.state.atomically do + let st ← get + let consumers ← st.consumers.filterM fun c => do + match c.waiter with + | some waiter => return !(← waiter.checkFinished) + | none => return true + + set { st with consumers } +where + registerAux (ch : Bounded α) (waiter : Waiter (Option α)) : IO Unit := do + ch.state.atomically do + -- We did drop the lock between `tryFn` and now so maybe ready? + if ← recvReady' then + -- if we lose we must trigger the next promise (if available) to avoid deadlocking + let lose := do + let st ← get + if let some (consumer, consumers) := st.consumers.dequeue? then + consumer.resolve true + set { st with consumers } + let win promise := do + -- We know we are ready so the value by this is fine + promise.resolve (.ok (← tryRecv')) + + waiter.race lose win + else + let promise ← IO.Promise.new + modify fun st => { st with consumers := st.consumers.enqueue ⟨promise, some waiter⟩ } + + IO.chainTask promise.result? fun res? => do + match res? with + | none => return () + | some res => + if res then + registerAux ch waiter + else + let lose := return () + let win promise := promise.resolve (.ok none) + waiter.race lose win + end Bounded /-- @@ -551,6 +725,18 @@ def recv (ch : CloseableChannel α) : BaseIO (Task (Option α)) := | .zero ch => CloseableChannel.Zero.recv ch | .bounded ch => CloseableChannel.Bounded.recv ch +open Internal.IO.Async in +/-- +Creates a `Selector` that resolves once `ch` has data available and provides that that data. +In particular if `ch` is closed while waiting on this `Selector` and no data is available already +this will resolve to `none`. +-/ +def recvSelector (ch : CloseableChannel α) : Selector (Option α) := + match ch with + | .unbounded ch => CloseableChannel.Unbounded.recvSelector ch + | .zero ch => CloseableChannel.Zero.recvSelector ch + | .bounded ch => CloseableChannel.Bounded.recvSelector ch + /-- `ch.forAsync f` calls `f` for every message received on `ch`. @@ -674,6 +860,29 @@ def recv [Inhabited α] (ch : Channel α) : BaseIO (Task α) := do | some val => return .pure val | none => unreachable! +open Internal.IO.Async in +/-- +Creates a `Selector` that resolves once `ch` has data available and provides that that data. +-/ +def recvSelector [Inhabited α] (ch : Channel α) : Selector α := + let sel := CloseableChannel.recvSelector ch.inner + { + tryFn := ch.tryRecv + registerFn waiter := do + let original := waiter.promise + let intermediate ← IO.Promise.new + let waiter := waiter.withPromise intermediate + sel.registerFn waiter + IO.chainTask (sync := true) intermediate.result? + fun + | none => return () + | some res => + -- `res` can only be `.err` or `.ok some` as we are in a non closeable channel. + original.resolve (res.map Option.get!) + + unregisterFn := sel.unregisterFn + } + @[inherit_doc CloseableChannel.forAsync] partial def forAsync [Inhabited α] (f : α → BaseIO Unit) (ch : Channel α) (prio : Task.Priority := .default) : BaseIO (Task Unit) := do diff --git a/tests/lean/run/async_select_channel.lean b/tests/lean/run/async_select_channel.lean new file mode 100644 index 0000000000..2f6b641af9 --- /dev/null +++ b/tests/lean/run/async_select_channel.lean @@ -0,0 +1,99 @@ +import Std.Sync.Channel + +open Std Internal IO Async + +namespace A + +def testReceiver (ch1 ch2 : Std.Channel Nat) (count : Nat) : IO (AsyncTask Nat) := do + go ch1 ch2 count 0 +where + go (ch1 ch2 : Std.Channel Nat) (count : Nat) (acc : Nat) : IO (AsyncTask Nat) := do + match count with + | 0 => return AsyncTask.pure acc + | count + 1 => + Selectable.one #[ + .case ch1.recvSelector fun data => go ch1 ch2 count (acc + data), + .case ch2.recvSelector fun data => go ch1 ch2 count (acc + data), + ] + +def testIt (capacity : Option Nat) : IO Bool := do + let amount := 1000 + let messages := Array.range amount + let ch1 ← Std.Channel.new capacity + let ch2 ← Std.Channel.new capacity + let recvTask ← testReceiver ch1 ch2 amount + + for msg in messages do + if (← IO.rand 0 1) = 0 then + ch1.sync.send msg + else + ch2.sync.send msg + + let acc ← recvTask.block + return acc == messages.sum + +/-- info: true -/ +#guard_msgs in +#eval testIt none + +/-- info: true -/ +#guard_msgs in +#eval testIt (some 0) + +/-- info: true -/ +#guard_msgs in +#eval testIt (some 1) + +/-- info: true -/ +#guard_msgs in +#eval testIt (some 128) + +end A + +namespace B + +def testReceiver (ch1 ch2 : Std.CloseableChannel Nat) (count : Nat) : IO (AsyncTask Nat) := do + go ch1 ch2 count 0 +where + go (ch1 ch2 : Std.CloseableChannel Nat) (count : Nat) (acc : Nat) : IO (AsyncTask Nat) := do + match count with + | 0 => return AsyncTask.pure acc + | count + 1 => + Selectable.one #[ + .case ch1.recvSelector fun data => go ch1 ch2 count (acc + data.getD 0), + .case ch2.recvSelector fun data => go ch1 ch2 count (acc + data.getD 0), + ] + +def testIt (capacity : Option Nat) : IO Bool := do + let amount := 1000 + let messages := Array.range amount + let ch1 ← Std.CloseableChannel.new capacity + let ch2 ← Std.CloseableChannel.new capacity + let recvTask ← testReceiver ch1 ch2 amount + + for msg in messages do + if (← IO.rand 0 1) = 0 then + ch1.sync.send msg + else + ch2.sync.send msg + + let acc ← recvTask.block + return acc == messages.sum + +/-- info: true -/ +#guard_msgs in +#eval testIt none + +/-- info: true -/ +#guard_msgs in +#eval testIt (some 0) + +/-- info: true -/ +#guard_msgs in +#eval testIt (some 1) + +/-- info: true -/ +#guard_msgs in +#eval testIt (some 128) + +end B