feat: implement a Selector for channels (#8150)

This PR is a follow up to #8055 and implements a Selector for
`Std.Channel` in order to allow
 multiplexing using channels.

There is one subtlety to the implementation: suppose we run `select` in a
loop on two channels. One of the channels is always quiet, while the other
only occasionally has data available (not always, as that would trigger the
`tryFn` fast path and hide the issue). In this situation the select receivers
that are enqueued on the quiet channel would remain there indefinitely, as
nothing ever happens on it, causing a memory leak. To avoid this, we want a
channel select to clean up after itself, even if it fails.

In an imperative programming language we could implement the receive
queue as a doubly linked list, have each receive select keep a pointer to
its element in the queue, and let it remove itself in `O(1)` upon failure.
As that is not trivially possible in Lean, we decided to go for another
approach for now: `unregisterFn` simply filters the queue for selects that
have failed. While this approach is `O(n)`, we expect the number of receivers
enqueued on a channel to not be terribly large, and thus this to be a
reasonably fast operation compared to the remaining overhead. If it ever
becomes an issue, we could switch, at a certain wait queue size, to an
approach that uses a `TreeMap` with numbered receivers and get `O(log(n))`.
This commit is contained in:
Henrik Böving 2025-04-29 17:15:38 +02:00 committed by GitHub
parent db35bbb1a0
commit eaa5d3498c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 362 additions and 30 deletions

View file

@ -9,7 +9,7 @@ Note: this is only a temporary placeholder.
module
prelude
import Init.Data.Array.Basic
import Init.Data.List.Control
namespace Std
@ -62,3 +62,18 @@ def dequeue? (q : Queue α) : Option (α × Queue α) :=
def toArray (q : Queue α) : Array α :=
  -- `dList` is emitted as-is, `eList` is reversed before it is appended.
  -- NOTE(review): presumably `eList` stores the most recently enqueued element first — confirm
  -- against `enqueue`.
  let front := q.dList.toArray
  let back := q.eList.toArray.reverse
  front ++ back
/--
`O(n)`. Runs the monadic predicate `p` on every element of the queue and returns a queue that
contains exactly the elements for which `p` produced `true`. The order in which `p` is applied to
the elements is currently unspecified.
-/
@[specialize]
def filterM {m : Type → Type v} [Monad m] {α : Type} (p : α → m Bool) (q : Queue α) :
    m (Queue α) := do
  let front ← q.dList.filterM p
  let back ← q.eList.filterM p
  match front with
  -- Nothing survived in `dList`: move the reversed survivors of `eList` over right away.
  | [] => return { dList := back.reverse, eList := [] }
  | _ :: _ => return { dList := front, eList := back }

View file

@ -32,6 +32,7 @@ structure Waiter (α : Type) where
Swap out the `IO.Promise` within the `Waiter`. Note that the part which determines whether the
`Waiter` is finished is not swapped out.
-/
@[inline]
def Waiter.withPromise (w : Waiter α) (p : IO.Promise (Except IO.Error β)) : Waiter β :=
  -- Reuse the `finished` state of `w`; only the promise is replaced.
  ⟨w.finished, p⟩
@ -49,6 +50,14 @@ def Waiter.race [Monad m] [MonadLiftT (ST IO.RealWorld) m] (w : Waiter α)
else
lose
/--
Atomically reads whether the `Waiter` has already finished. The result is only a snapshot: the
state may already have changed right after this function returns.
-/
@[inline]
def Waiter.checkFinished [Monad m] [MonadLiftT (ST IO.RealWorld) m] (w : Waiter α) : m Bool := do
  let isDone ← w.finished.get
  return isDone
/--
An event source that can be multiplexed using `Selectable.one`, see the documentation of
`Selectable.one` for how the protocol of communicating with a `Selector` works.

View file

@ -7,6 +7,7 @@ prelude
import Init.System.Promise
import Init.Data.Queue
import Std.Sync.Mutex
import Std.Internal.Async.Select
/-!
This module contains the implementation of `Std.Channel`. `Std.Channel` is a multi-producer
@ -44,6 +45,23 @@ instance : ToString Error where
instance : MonadLift (EIO Error) IO where
monadLift x := EIO.toIO (.userError <| toString ·) x
open Internal.IO.Async in
/--
A receiver that is enqueued on a channel until a producer provides a value or the channel closes.
-/
private inductive Consumer (α : Type) where
  /-- A plain blocked receive; the promise is resolved directly with the received value. -/
  | normal (promise : IO.Promise (Option α))
  /-- A receive registered through a `Selector`; its `Waiter` is raced before delivering. -/
  | select (finished : Waiter (Option α))
/--
Tries to deliver `x` to the consumer `c`.

Returns `true` if the consumer took the value. A `.normal` consumer always takes it; a `.select`
consumer only takes it if its `Waiter` race is won, otherwise `false` is returned and `x` has not
been delivered.
-/
private def Consumer.resolve (c : Consumer α) (x : Option α) : BaseIO Bool :=
  match c with
  | .normal promise => do
    promise.resolve x
    pure true
  | .select waiter =>
    -- Losing the race means the waiter was already finished, so nothing is delivered.
    waiter.race (pure false) fun promise => do
      promise.resolve (.ok x)
      pure true
/--
The central state structure for an unbounded channel, maintains the following invariants:
1. `values = ∅ ∨ consumers = ∅`
@ -55,10 +73,10 @@ private structure Unbounded.State (α : Type) where
-/
values : Std.Queue α
/--
Consumers that are blocked on a producer providing them a value. The `IO.Promise` will be
resolved to `none` if the channel closes.
Consumers that are blocked on a producer providing them a value. They will be resolved to `none`
if the channel closes.
-/
consumers : Std.Queue (IO.Promise (Option α))
consumers : Std.Queue (Consumer α)
/--
Whether the channel is closed already.
-/
@ -85,12 +103,18 @@ private def trySend (ch : Unbounded α) (v : α) : BaseIO Bool := do
let st ← get
if st.closed then
return false
else if let some (consumer, consumers) := st.consumers.dequeue? then
consumer.resolve (some v)
set { st with consumers }
return true
else
set { st with values := st.values.enqueue v }
while true do
let st ← get
if let some (consumer, consumers) := st.consumers.dequeue? then
let success ← consumer.resolve (some v)
set { st with consumers }
if success then
break
else
set { st with values := st.values.enqueue v }
break
return true
private def send (ch : Unbounded α) (v : α) : BaseIO (Task (Except Error Unit)) := do
@ -103,7 +127,8 @@ private def close (ch : Unbounded α) : EIO Error Unit := do
ch.state.atomically do
let st ← get
if st.closed then throw .alreadyClosed
for consumer in st.consumers.toArray do consumer.resolve none
for consumer in st.consumers.toArray do
discard <| consumer.resolve none
set { st with consumers := ∅, closed := true }
return ()
@ -111,7 +136,8 @@ private def isClosed (ch : Unbounded α) : BaseIO Bool :=
ch.state.atomically do
return (← get).closed
private def tryRecv' : AtomicT (Unbounded.State α) BaseIO (Option α) := do
private def tryRecv' [Monad m] [MonadLiftT (ST IO.RealWorld) m] :
AtomicT (Unbounded.State α) m (Option α) := do
let st ← get
if let some (a, values) := st.values.dequeue? then
set { st with values }
@ -131,9 +157,47 @@ private def recv (ch : Unbounded α) : BaseIO (Task (Option α)) := do
return .pure none
else
let promise ← IO.Promise.new
modify fun st => { st with consumers := st.consumers.enqueue promise }
modify fun st => { st with consumers := st.consumers.enqueue (.normal promise) }
return promise.result?.map (sync := true) (·.bind id)
/--
Returns `true` if a receive does not have to wait: either a value is buffered or the channel is
closed.
-/
@[inline]
private def recvReady' [Monad m] [MonadLiftT (ST IO.RealWorld) m] :
    AtomicT (Unbounded.State α) m Bool := do
  let st ← get
  let hasValue := !st.values.isEmpty
  return hasValue || st.closed
open Internal.IO.Async in
/--
`Selector` for receiving from an unbounded channel. Once the channel is ready (see `recvReady'`)
it provides the result of `tryRecv'`; otherwise the waiter is enqueued as a `.select` consumer.
-/
private def recvSelector (ch : Unbounded α) : Selector (Option α) where
  tryFn := do
    ch.state.atomically do
      if !(← recvReady') then
        return none
      let received ← tryRecv'
      return some received
  registerFn waiter := do
    ch.state.atomically do
      -- The lock was released after `tryFn`, so the channel may have become ready in between.
      if ← recvReady' then
        let onLose := pure ()
        let onWin promise := do
          -- The channel is ready, so `tryRecv'` is only run once the race has been won.
          let received ← tryRecv'
          promise.resolve (.ok received)
        waiter.race onLose onWin
      else
        modify fun st => { st with consumers := st.consumers.enqueue (.select waiter) }
  unregisterFn := do
    ch.state.atomically do
      let st ← get
      -- Keep plain receivers; drop select consumers whose waiter has already finished.
      let stillWaiting ← st.consumers.filterM fun consumer => do
        match consumer with
        | .normal .. => return true
        | .select waiter => return !(← waiter.checkFinished)
      set { st with consumers := stillWaiting }
end Unbounded
/--
@ -147,10 +211,10 @@ private structure Zero.State (α : Type) where
-/
producers : Std.Queue (α × IO.Promise Bool)
/--
Consumers that are blocked on a producer providing them a value. The `IO.Promise` will be resolved
to `none` if the channel closes.
Consumers that are blocked on a producer providing them a value. They will be resolved to `none`
if the channel closes.
-/
consumers : Std.Queue (IO.Promise (Option α))
consumers : Std.Queue (Consumer α)
/--
Whether the channel is closed already.
-/
@ -174,13 +238,17 @@ private def new : BaseIO (Zero α) := do
Precondition: The channel must not be closed.
-/
private def trySend' (v : α) : AtomicT (Zero.State α) BaseIO Bool := do
let st ← get
if let some (consumer, consumers) := st.consumers.dequeue? then
consumer.resolve (some v)
set { st with consumers }
return true
else
return false
while true do
let st ← get
if let some (consumer, consumers) := st.consumers.dequeue? then
let success ← consumer.resolve (some v)
set { st with consumers }
if success then
break
else
return false
return true
private def trySend (ch : Zero α) (v : α) : BaseIO Bool := do
ch.state.atomically do
@ -207,7 +275,8 @@ private def close (ch : Zero α) : EIO Error Unit := do
ch.state.atomically do
let st ← get
if st.closed then throw .alreadyClosed
for consumer in st.consumers.toArray do consumer.resolve none
for consumer in st.consumers.toArray do
discard <| consumer.resolve none
set { st with consumers := ∅, closed := true }
return ()
@ -215,7 +284,8 @@ private def isClosed (ch : Zero α) : BaseIO Bool :=
ch.state.atomically do
return (← get).closed
private def tryRecv' : AtomicT (Zero.State α) BaseIO (Option α) := do
private def tryRecv' [Monad m] [MonadLiftT (ST IO.RealWorld) m] [MonadLiftT BaseIO m] :
AtomicT (Zero.State α) m (Option α) := do
let st ← get
if let some ((val, promise), producers) := st.producers.dequeue? then
set { st with producers }
@ -235,13 +305,59 @@ private def recv (ch : Zero α) : BaseIO (Task (Option α)) := do
return .pure <| some val
else if !st.closed then
let promise ← IO.Promise.new
set { st with consumers := st.consumers.enqueue promise }
set { st with consumers := st.consumers.enqueue (.normal promise) }
return promise.result?.map (sync := true) (·.bind id)
else
return .pure <| none
/--
Returns `true` if a receive does not have to wait: either a producer is enqueued or the channel
is closed.
-/
@[inline]
private def recvReady' [Monad m] [MonadLiftT (ST IO.RealWorld) m] :
    AtomicT (Zero.State α) m Bool := do
  let st ← get
  let hasProducer := !st.producers.isEmpty
  return hasProducer || st.closed
open Internal.IO.Async in
/--
`Selector` for receiving from a zero-capacity channel. Once the channel is ready (see
`recvReady'`) it provides the result of `tryRecv'`; otherwise the waiter is enqueued as a
`.select` consumer.
-/
private def recvSelector (ch : Zero α) : Selector (Option α) where
  tryFn := do
    ch.state.atomically do
      if !(← recvReady') then
        return none
      let received ← tryRecv'
      return some received
  registerFn waiter := do
    ch.state.atomically do
      -- The lock was released after `tryFn`, so the channel may have become ready in between.
      if ← recvReady' then
        let onLose := pure ()
        let onWin promise := do
          -- The channel is ready, so `tryRecv'` is only run once the race has been won.
          let received ← tryRecv'
          promise.resolve (.ok received)
        waiter.race onLose onWin
      else
        modify fun st => { st with consumers := st.consumers.enqueue (.select waiter) }
  unregisterFn := do
    ch.state.atomically do
      let st ← get
      -- Keep plain receivers; drop select consumers whose waiter has already finished.
      let stillWaiting ← st.consumers.filterM fun consumer => do
        match consumer with
        | .normal .. => return true
        | .select waiter => return !(← waiter.checkFinished)
      set { st with consumers := stillWaiting }
end Zero
open Internal.IO.Async in
/--
A receiver that is enqueued on a bounded channel because no value was available when it tried to
dequeue.
-/
private structure Bounded.Consumer (α : Type) where
  /-- Promise used to wake the receiver up, see `Bounded.Consumer.resolve`. -/
  promise : IO.Promise Bool
  /-- The `Waiter` of the select that enqueued this consumer, if it stems from a `Selector`. -/
  waiter : Option (Waiter (Option α))
/--
Resolves the wake-up promise of the consumer `c` with `b`. The `waiter` field is not touched.
-/
private def Bounded.Consumer.resolve (c : Bounded.Consumer α) (b : Bool) : BaseIO Unit :=
  match c with
  | ⟨promise, _⟩ => promise.resolve b
/--
The central state structure for a bounded channel, maintains the following invariants:
1. `0 < capacity`
@ -265,10 +381,9 @@ private structure Bounded.State (α : Type) where
producers : Std.Queue (IO.Promise Bool)
/--
Consumers that are blocked on a producer providing them a value, as there was no value
enqueued when they tried to dequeue. The `IO.Promise` will be resolved to `false` if the channel
closes.
enqueued when they tried to dequeue. They will be resolved to `false` if the channel closes.
-/
consumers : Std.Queue (IO.Promise Bool)
consumers : Std.Queue (Bounded.Consumer α)
/--
The capacity of the buffer space.
-/
@ -390,7 +505,8 @@ private def isClosed (ch : Bounded α) : BaseIO Bool :=
ch.state.atomically do
return (← get).closed
private def tryRecv' : AtomicT (Bounded.State α) BaseIO (Option α) := do
private def tryRecv' [Monad m] [MonadLiftT (ST IO.RealWorld) m] [MonadLiftT BaseIO m] :
AtomicT (Bounded.State α) m (Option α) := do
let mut st ← get
if st.bufCount == 0 then
return none
@ -423,13 +539,71 @@ private partial def recv (ch : Bounded α) : BaseIO (Task (Option α)) := do
return .pure none
else
let promise ← IO.Promise.new
modify fun st => { st with consumers := st.consumers.enqueue promise }
modify fun st => { st with consumers := st.consumers.enqueue promise, none⟩ }
BaseIO.bindTask promise.result? fun res => do
if res.getD false then
Bounded.recv ch
else
return .pure none
/--
Returns `true` if a receive does not have to wait: either the buffer holds at least one value or
the channel is closed.
-/
@[inline]
private def recvReady' [Monad m] [MonadLiftT (ST IO.RealWorld) m] :
    AtomicT (Bounded.State α) m Bool := do
  let st ← get
  let hasBuffered := st.bufCount != 0
  return hasBuffered || st.closed
open Internal.IO.Async in
/--
`Selector` for receiving from a bounded channel.

* `tryFn` provides the result of `tryRecv'` if the channel is ready (see `recvReady'`).
* `registerFn` (implemented by `registerAux`) either resolves the waiter right away or enqueues a
  consumer that re-runs the registration once its promise is resolved to `true`.
* `unregisterFn` removes all enqueued select consumers whose waiter has already finished.
-/
private partial def recvSelector (ch : Bounded α) : Selector (Option α) where
  tryFn := do
    ch.state.atomically do
      if ← recvReady' then
        let val ← tryRecv'
        return some val
      else
        return none
  registerFn := registerAux ch
  unregisterFn := do
    ch.state.atomically do
      let st ← get
      -- Consumers without a waiter are kept unconditionally; select consumers are kept only
      -- while their waiter has not finished yet.
      let consumers ← st.consumers.filterM fun c => do
        match c.waiter with
        | some waiter => return !(← waiter.checkFinished)
        | none => return true
      set { st with consumers }
where
  registerAux (ch : Bounded α) (waiter : Waiter (Option α)) : IO Unit := do
    ch.state.atomically do
      -- We did drop the lock between `tryFn` and now so maybe ready?
      if ← recvReady' then
        -- if we lose we must trigger the next promise (if available) to avoid deadlocking
        let lose := do
          let st ← get
          if let some (consumer, consumers) := st.consumers.dequeue? then
            consumer.resolve true
            set { st with consumers }
        let win promise := do
          -- We know we are ready so the value by this is fine
          promise.resolve (.ok (← tryRecv'))
        waiter.race lose win
      else
        -- Not ready: enqueue a consumer that remembers the waiter and react once its promise
        -- has a result.
        let promise ← IO.Promise.new
        modify fun st => { st with consumers := st.consumers.enqueue ⟨promise, some waiter⟩ }

        IO.chainTask promise.result? fun res? => do
          match res? with
          -- No result is available for the promise, nothing left to do for this registration.
          | none => return ()
          | some res =>
            if res then
              -- Woken up with `true`: run the registration again.
              registerAux ch waiter
            else
              -- Woken up with `false`: provide `none` to the waiter if it wins the race.
              let lose := return ()
              let win promise := promise.resolve (.ok none)
              waiter.race lose win
end Bounded
/--
@ -551,6 +725,18 @@ def recv (ch : CloseableChannel α) : BaseIO (Task (Option α)) :=
| .zero ch => CloseableChannel.Zero.recv ch
| .bounded ch => CloseableChannel.Bounded.recv ch
open Internal.IO.Async in
/--
Creates a `Selector` that resolves once `ch` has data available and provides that data.
In particular, if `ch` is closed while this `Selector` is being waited on and no data is
available already, it resolves to `none`.
-/
def recvSelector (ch : CloseableChannel α) : Selector (Option α) :=
  -- Dispatch to the selector of the concrete channel flavour.
  match ch with
  | .unbounded inner => CloseableChannel.Unbounded.recvSelector inner
  | .zero inner => CloseableChannel.Zero.recvSelector inner
  | .bounded inner => CloseableChannel.Bounded.recvSelector inner
/--
`ch.forAsync f` calls `f` for every message received on `ch`.
@ -674,6 +860,29 @@ def recv [Inhabited α] (ch : Channel α) : BaseIO (Task α) := do
| some val => return .pure val
| none => unreachable!
open Internal.IO.Async in
/--
Creates a `Selector` that resolves once `ch` has data available and provides that data.
-/
def recvSelector [Inhabited α] (ch : Channel α) : Selector α :=
  let inner := CloseableChannel.recvSelector ch.inner
  {
    tryFn := ch.tryRecv
    registerFn waiter := do
      -- The inner selector produces `Option α`, so it is registered with a proxy promise whose
      -- result is unwrapped and forwarded to the promise of the original waiter.
      let target := waiter.promise
      let proxy ← IO.Promise.new
      inner.registerFn (waiter.withPromise proxy)
      IO.chainTask (sync := true) proxy.result? fun res? =>
        match res? with
        | none => pure ()
        | some res =>
          -- `res` can only be `.err` or `.ok some` as we are in a non closeable channel.
          target.resolve (res.map Option.get!)
    unregisterFn := inner.unregisterFn
  }
@[inherit_doc CloseableChannel.forAsync]
partial def forAsync [Inhabited α] (f : α → BaseIO Unit) (ch : Channel α)
(prio : Task.Priority := .default) : BaseIO (Task Unit) := do

View file

@ -0,0 +1,99 @@
import Std.Sync.Channel
open Std Internal IO Async
namespace A
/--
Receives `count` messages by repeatedly selecting on `ch1` and `ch2` and yields the sum of all
received values.
-/
def testReceiver (ch1 ch2 : Std.Channel Nat) (count : Nat) : IO (AsyncTask Nat) :=
  go ch1 ch2 count 0
where
  -- `count` is the number of messages still to be received, `acc` the sum received so far.
  go (ch1 ch2 : Std.Channel Nat) (count : Nat) (acc : Nat) : IO (AsyncTask Nat) := do
    match count with
    | 0 => return AsyncTask.pure acc
    | remaining + 1 =>
      -- Both channels continue the same way: add the value and receive the rest.
      let step : Nat → IO (AsyncTask Nat) := fun data => go ch1 ch2 remaining (acc + data)
      Selectable.one #[
        .case ch1.recvSelector step,
        .case ch2.recvSelector step,
      ]
/--
Sends the numbers `0, …, 999` over two channels of the given `capacity`, picking the channel for
each message at random, and checks that the selecting receiver sums up every message.
-/
def testIt (capacity : Option Nat) : IO Bool := do
  let amount := 1000
  let messages := Array.range amount
  let ch1 ← Std.Channel.new capacity
  let ch2 ← Std.Channel.new capacity
  -- Start the receiver before anything is sent.
  let recvTask ← testReceiver ch1 ch2 amount
  for msg in messages do
    let coin ← IO.rand 0 1
    let target := if coin = 0 then ch1 else ch2
    target.sync.send msg
  let received ← recvTask.block
  let expected := messages.sum
  return received == expected
/-- info: true -/
#guard_msgs in
#eval testIt none
/-- info: true -/
#guard_msgs in
#eval testIt (some 0)
/-- info: true -/
#guard_msgs in
#eval testIt (some 1)
/-- info: true -/
#guard_msgs in
#eval testIt (some 128)
end A
namespace B
/--
Receives `count` messages by repeatedly selecting on `ch1` and `ch2` and yields the sum of all
received values, counting a `none` result as `0`.
-/
def testReceiver (ch1 ch2 : Std.CloseableChannel Nat) (count : Nat) : IO (AsyncTask Nat) :=
  go ch1 ch2 count 0
where
  -- `count` is the number of messages still to be received, `acc` the sum received so far.
  go (ch1 ch2 : Std.CloseableChannel Nat) (count : Nat) (acc : Nat) : IO (AsyncTask Nat) := do
    match count with
    | 0 => return AsyncTask.pure acc
    | remaining + 1 =>
      -- Both channels continue the same way: add the value and receive the rest.
      let step : Option Nat → IO (AsyncTask Nat) :=
        fun data => go ch1 ch2 remaining (acc + data.getD 0)
      Selectable.one #[
        .case ch1.recvSelector step,
        .case ch2.recvSelector step,
      ]
/--
Sends the numbers `0, …, 999` over two closeable channels of the given `capacity`, picking the
channel for each message at random, and checks that the selecting receiver sums up every message.
-/
def testIt (capacity : Option Nat) : IO Bool := do
  let amount := 1000
  let messages := Array.range amount
  let ch1 ← Std.CloseableChannel.new capacity
  let ch2 ← Std.CloseableChannel.new capacity
  -- Start the receiver before anything is sent.
  let recvTask ← testReceiver ch1 ch2 amount
  for msg in messages do
    let coin ← IO.rand 0 1
    let target := if coin = 0 then ch1 else ch2
    target.sync.send msg
  let received ← recvTask.block
  let expected := messages.sum
  return received == expected
/-- info: true -/
#guard_msgs in
#eval testIt none
/-- info: true -/
#guard_msgs in
#eval testIt (some 0)
/-- info: true -/
#guard_msgs in
#eval testIt (some 1)
/-- info: true -/
#guard_msgs in
#eval testIt (some 128)
end B