refactor: more complete channel implementation for Std.Channel (#7819)

This PR extends `Std.Channel` to provide a full sync and async API, as
well as unbounded, zero sized and bounded channels.

A few notes on the implementation:
- the bounded channel is inspired by [Go channels on
steroids](https://docs.google.com/document/d/1yIAYmbvL3JxOKOjuCyon7JhW4cSv1wy5hC0ApeGMV9s/pub)
though currently doesn't do any of the lock-free optimizations
- @mhuisi convinced me that having a non-closable channel may be a good
idea as this alleviates the need for error handling which is very
annoying when working with `Task`. This does complicate the API a little
bit and I'm not quite sure whether this is a choice we want users to
give. An alternative to this would be to just write `send!` that panics
on sending to a closed channel (receiving from a closed channel is not
an error), this is for example the behavior that golang goes with.
This commit is contained in:
Henrik Böving 2025-04-12 23:02:24 +02:00 committed by GitHub
parent 85a0232e87
commit dd7ca772d8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 956 additions and 246 deletions

View file

@ -34,7 +34,6 @@ import Init.Data.Stream
import Init.Data.Prod
import Init.Data.AC
import Init.Data.Queue
import Init.Data.Channel
import Init.Data.Sum
import Init.Data.BEq
import Init.Data.Subtype

View file

@ -1,149 +0,0 @@
/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Gabriel Ebner
-/
prelude
import Init.Data.Queue
import Init.System.Promise
import Init.System.Mutex
set_option linter.deprecated false
namespace IO
/--
Internal state of an `Channel`.
We maintain the invariant that at all times either `consumers` or `values` is empty.
-/
@[deprecated "Use Std.Channel.State from Std.Sync.Channel instead" (since := "2024-12-02")]
structure Channel.State (α : Type) where
values : Std.Queue α := ∅
consumers : Std.Queue (Promise (Option α)) := ∅
closed := false
deriving Inhabited
/--
FIFO channel with unbounded buffer, where `recv?` returns a `Task`.
A channel can be closed. Once it is closed, all `send`s are ignored, and
`recv?` returns `none` once the queue is empty.
-/
@[deprecated "Use Std.Channel from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel (α : Type) : Type := Mutex (Channel.State α)
instance : Nonempty (Channel α) :=
inferInstanceAs (Nonempty (Mutex _))
/-- Creates a new `Channel`. -/
@[deprecated "Use Std.Channel.new from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.new : BaseIO (Channel α) :=
Mutex.new {}
/--
Sends a message on an `Channel`.
This function does not block.
-/
@[deprecated "Use Std.Channel.send from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.send (ch : Channel α) (v : α) : BaseIO Unit :=
ch.atomically do
let st ← get
if st.closed then return
if let some (consumer, consumers) := st.consumers.dequeue? then
consumer.resolve (some v)
set { st with consumers }
else
set { st with values := st.values.enqueue v }
/--
Closes an `Channel`.
-/
@[deprecated "Use Std.Channel.close from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.close (ch : Channel α) : BaseIO Unit :=
ch.atomically do
let st ← get
for consumer in st.consumers.toArray do consumer.resolve none
set { st with closed := true, consumers := ∅ }
/--
Receives a message, without blocking.
The returned task waits for the message.
Every message is only received once.
Returns `none` if the channel is closed and the queue is empty.
-/
@[deprecated "Use Std.Channel.recv? from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.recv? (ch : Channel α) : BaseIO (Task (Option α)) :=
ch.atomically do
let st ← get
if let some (a, values) := st.values.dequeue? then
set { st with values }
return .pure a
else if !st.closed then
let promise ← Promise.new
set { st with consumers := st.consumers.enqueue promise }
return promise.result
else
return .pure none
/--
`ch.forAsync f` calls `f` for every messages received on `ch`.
Note that if this function is called twice, each `forAsync` only gets half the messages.
-/
@[deprecated "Use Std.Channel.forAsync from Std.Sync.Channel instead" (since := "2024-12-02")]
partial def Channel.forAsync (f : α → BaseIO Unit) (ch : Channel α)
(prio : Task.Priority := .default) : BaseIO (Task Unit) := do
BaseIO.bindTask (prio := prio) (← ch.recv?) fun
| none => return .pure ()
| some v => do f v; ch.forAsync f prio
/--
Receives all currently queued messages from the channel.
Those messages are dequeued and will not be returned by `recv?`.
-/
@[deprecated "Use Std.Channel.recvAllCurrent from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.recvAllCurrent (ch : Channel α) : BaseIO (Array α) :=
ch.atomically do
modifyGet fun st => (st.values.toArray, { st with values := ∅ })
/-- Type tag for synchronous (blocking) operations on a `Channel`. -/
@[deprecated "Use Std.Channel.Sync from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.Sync := Channel
/--
Accesses synchronous (blocking) version of channel operations.
For example, `ch.sync.recv?` blocks until the next message,
and `for msg in ch.sync do ...` iterates synchronously over the channel.
These functions should only be used in dedicated threads.
-/
@[deprecated "Use Std.Channel.sync from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.sync (ch : Channel α) : Channel.Sync α := ch
/--
Synchronously receives a message from the channel.
Every message is only received once.
Returns `none` if the channel is closed and the queue is empty.
-/
@[deprecated "Use Std.Channel.Sync.recv? from Std.Sync.Channel instead" (since := "2024-12-02")]
def Channel.Sync.recv? (ch : Channel.Sync α) : BaseIO (Option α) := do
IO.wait (← Channel.recv? ch)
@[deprecated "Use Std.Channel.Sync.forIn from Std.Sync.Channel instead" (since := "2024-12-02")]
private partial def Channel.Sync.forIn [Monad m] [MonadLiftT BaseIO m]
(ch : Channel.Sync α) (f : α → β → m (ForInStep β)) : β → m β := fun b => do
match ← ch.recv? with
| some a =>
match ← f a b with
| .done b => pure b
| .yield b => ch.forIn f b
| none => pure b
/-- `for msg in ch.sync do ...` receives all messages in the channel until it is closed. -/
instance [MonadLiftT BaseIO m] : ForIn m (Channel.Sync α) α where
forIn ch b f := ch.forIn f b

View file

@ -111,6 +111,7 @@ inductive Message where
| response (id : RequestID) (result : Json)
/-- A non-successful response. -/
| responseError (id : RequestID) (code : ErrorCode) (message : String) (data? : Option Json)
deriving Inhabited
def Batch := Array Message

View file

@ -207,7 +207,7 @@ This option can only be set on the command line, not in the lakefile or via `set
stickyInteractiveDiagnostics ++ docInteractiveDiagnostics
|>.map (·.toDiagnostic)
let notification := mkPublishDiagnosticsNotification doc.meta diagnostics
ctx.chanOut.send notification
ctx.chanOut.sync.send notification
open Language in
/--
@ -239,7 +239,7 @@ This option can only be set on the command line, not in the lakefile or via `set
publishDiagnostics ctx doc
-- This will overwrite existing ilean info for the file, in case something
-- went wrong during the incremental updates.
ctx.chanOut.send (← mkIleanInfoFinalNotification doc.meta st.allInfoTrees)
ctx.chanOut.sync.send (← mkIleanInfoFinalNotification doc.meta st.allInfoTrees)
return ()
where
/--
@ -312,7 +312,7 @@ This option can only be set on the command line, not in the lakefile or via `set
if let some itree := node.element.infoTree? then
let mut newInfoTrees := (← get).newInfoTrees.push itree
if (← get).hasBlocked then
ctx.chanOut.send (← mkIleanInfoUpdateNotification doc.meta newInfoTrees)
ctx.chanOut.sync.send (← mkIleanInfoUpdateNotification doc.meta newInfoTrees)
newInfoTrees := #[]
modify fun st => { st with newInfoTrees, allInfoTrees := st.allInfoTrees.push itree }
@ -329,7 +329,7 @@ This option can only be set on the command line, not in the lakefile or via `set
| none => rs.push r
let ranges := ranges.map (·.toLspRange doc.meta.text)
let notifs := ranges.map ({ range := ·, kind := .processing })
ctx.chanOut.send <| mkFileProgressNotification doc.meta notifs
ctx.chanOut.sync.send <| mkFileProgressNotification doc.meta notifs
end Elab
@ -389,9 +389,9 @@ def setupImports
severity? := DiagnosticSeverity.information
message := stderrLine
}
chanOut.send <| mkPublishDiagnosticsNotification meta #[progressDiagnostic]
chanOut.sync.send <| mkPublishDiagnosticsNotification meta #[progressDiagnostic]
-- clear progress notifications in the end
chanOut.send <| mkPublishDiagnosticsNotification meta #[]
chanOut.sync.send <| mkPublishDiagnosticsNotification meta #[]
match fileSetupResult.kind with
| .importsOutOfDate =>
return .error {
@ -525,7 +525,7 @@ section ServerRequests
(freshRequestId, freshRequestId + 1)
let responseTask ← ctx.initPendingServerRequest responseType freshRequestId
let r : JsonRpc.Request paramType := ⟨freshRequestId, method, param⟩
ctx.chanOut.send r
ctx.chanOut.sync.send r
return responseTask
def sendUntypedServerRequest
@ -679,7 +679,7 @@ section MessageHandling
let availableImports ← ImportCompletion.collectAvailableImports
let lastRequestTimestampMs ← IO.monoMsNow
let completions := ImportCompletion.find text st.doc.initSnap.stx params availableImports
ctx.chanOut.send <| .response id (toJson completions)
ctx.chanOut.sync.send <| .response id (toJson completions)
pure { availableImports, lastRequestTimestampMs : AvailableImportsCache }
| some task => ServerTask.IO.mapTaskCostly (t := task) fun (result : Except Error AvailableImportsCache) => do
@ -689,7 +689,7 @@ section MessageHandling
availableImports ← ImportCompletion.collectAvailableImports
lastRequestTimestampMs := timestampNowMs
let completions := ImportCompletion.find text st.doc.initSnap.stx params availableImports
ctx.chanOut.send <| .response id (toJson completions)
ctx.chanOut.sync.send <| .response id (toJson completions)
pure { availableImports, lastRequestTimestampMs : AvailableImportsCache }
def handleStatefulPreRequestSpecialCases (id : RequestID) (method : String) (params : Json) : WorkerM Bool := do
@ -701,7 +701,7 @@ section MessageHandling
| "$/lean/rpc/connect" =>
let ps ← parseParams RpcConnectParams params
let resp ← handleRpcConnect ps
ctx.chanOut.send <| .response id (toJson resp)
ctx.chanOut.sync.send <| .response id (toJson resp)
return true
| "textDocument/completion" =>
let params ← parseParams CompletionParams params
@ -714,7 +714,7 @@ section MessageHandling
| _ =>
return false
catch e =>
ctx.chanOut.send <| .responseError id .internalError (toString e) none
ctx.chanOut.sync.send <| .responseError id .internalError (toString e) none
return true
open Widget RequestM Language in
@ -836,7 +836,7 @@ section MessageHandling
emitResponse ctx (isComplete := false) <| e.toLspResponseError id
where
emitResponse (ctx : WorkerContext) (m : JsonRpc.Message) (isComplete : Bool) : IO Unit := do
ctx.chanOut.send m
ctx.chanOut.sync.send m
let timestamp ← IO.monoMsNow
ctx.modifyPartialHandler method fun h => { h with
requestsInFlight := h.requestsInFlight - 1

View file

@ -1,137 +1,720 @@
/-
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Gabriel Ebner
Authors: Henrik Böving
-/
prelude
import Init.System.Promise
import Init.Data.Queue
import Std.Sync.Mutex
/-!
This module contains the implementation of `Std.Channel`. `Std.Channel` is a multi-producer
multi-consumer FIFO channel that offers both bounded and unbounded buffering as well as synchronous
and asynchronous APIs.
Additionally `Std.CloseableChannel` is provided in case closing the channel is of interest.
The two are distinct as the non closable `Std.Channel` can never throw errors which makes
for cleaner code.
-/
namespace Std
/--
Internal state of an `Channel`.
We maintain the invariant that at all times either `consumers` or `values` is empty.
-/
structure Channel.State (α : Type) where
values : Std.Queue α := ∅
consumers : Std.Queue (IO.Promise (Option α)) := ∅
closed := false
deriving Inhabited
namespace CloseableChannel
/--
FIFO channel with unbounded buffer, where `recv?` returns a `Task`.
A channel can be closed. Once it is closed, all `send`s are ignored, and
`recv?` returns `none` once the queue is empty.
Errors that may be thrown while interacting with the channel API.
-/
def Channel (α : Type) : Type := Mutex (Channel.State α)
inductive Error where
/--
Tried to send to a closed channel.
-/
| closed
/--
Tried to close an already closed channel.
-/
| alreadyClosed
deriving Repr, DecidableEq, Hashable
instance : Nonempty (Channel α) :=
inferInstanceAs (Nonempty (Mutex _))
instance : ToString Error where
toString
| .closed => "trying to send on an already closed channel"
| .alreadyClosed => "trying to close an already closed channel"
/-- Creates a new `Channel`. -/
def Channel.new : BaseIO (Channel α) :=
Mutex.new {}
instance : MonadLift (EIO Error) IO where
monadLift x := EIO.toIO (.userError <| toString ·) x
/--
Sends a message on an `Channel`.
This function does not block.
The central state structure for an unbounded channel, maintains the following invariants:
1. `values = ∅ consumers = ∅`
2. `closed = true → consumers = ∅`
-/
def Channel.send (ch : Channel α) (v : α) : BaseIO Unit :=
ch.atomically do
private structure Unbounded.State (α : Type) where
/--
Values pushed into the channel that are waiting to be consumed.
-/
values : Std.Queue α
/--
Consumers that are blocked on a producer providing them a value. The `IO.Promise` will be
resolved to `none` if the channel closes.
-/
consumers : Std.Queue (IO.Promise (Option α))
/--
Whether the channel is closed already.
-/
closed : Bool
deriving Nonempty
private structure Unbounded (α : Type) where
state : Mutex (Unbounded.State α)
deriving Nonempty
namespace Unbounded
private def new : BaseIO (Unbounded α) := do
return {
state := ← Mutex.new {
values := ∅
consumers := ∅
closed := false
}
}
private def trySend (ch : Unbounded α) (v : α) : BaseIO Bool := do
ch.state.atomically do
let st ← get
if st.closed then return
if let some (consumer, consumers) := st.consumers.dequeue? then
if st.closed then
return false
else if let some (consumer, consumers) := st.consumers.dequeue? then
consumer.resolve (some v)
set { st with consumers }
return true
else
set { st with values := st.values.enqueue v }
return true
/--
Closes an `Channel`.
-/
def Channel.close (ch : Channel α) : BaseIO Unit :=
ch.atomically do
private def send (ch : Unbounded α) (v : α) : BaseIO (Task (Except Error Unit)) := do
if ← Unbounded.trySend ch v then
return .pure <| .ok ()
else
return .pure <| .error .closed
private def close (ch : Unbounded α) : EIO Error Unit := do
ch.state.atomically do
let st ← get
if st.closed then throw .alreadyClosed
for consumer in st.consumers.toArray do consumer.resolve none
set { st with closed := true, consumers := ∅ }
set { st with consumers := ∅, closed := true }
return ()
private def isClosed (ch : Unbounded α) : BaseIO Bool :=
ch.state.atomically do
return (← get).closed
private def tryRecv' : AtomicT (Unbounded.State α) BaseIO (Option α) := do
let st ← get
if let some (a, values) := st.values.dequeue? then
set { st with values }
return some a
else
return none
private def tryRecv (ch : Unbounded α) : BaseIO (Option α) :=
ch.state.atomically do
tryRecv'
private def recv (ch : Unbounded α) : BaseIO (Task (Option α)) := do
ch.state.atomically do
if let some val ← tryRecv' then
return .pure <| some val
else if (← get).closed then
return .pure none
else
let promise ← IO.Promise.new
modify fun st => { st with consumers := st.consumers.enqueue promise }
return promise.result?.map (sync := true) (·.bind id)
end Unbounded
/--
Receives a message, without blocking.
The returned task waits for the message.
Every message is only received once.
Returns `none` if the channel is closed and the queue is empty.
The central state structure for a zero buffer channel, maintains the following invariants:
1. `producers = ∅ consumers = ∅`
2. `closed = true → consumers = ∅`
-/
def Channel.recv? (ch : Channel α) : BaseIO (Task (Option α)) :=
ch.atomically do
private structure Zero.State (α : Type) where
/--
Producers that are blocked on a consumer taking their value.
-/
producers : Std.Queue (α × IO.Promise Bool)
/--
Consumers that are blocked on a producer providing them a value. The `IO.Promise` will be resolved
to `none` if the channel closes.
-/
consumers : Std.Queue (IO.Promise (Option α))
/--
Whether the channel is closed already.
-/
closed : Bool
private structure Zero (α : Type) where
state : Mutex (Zero.State α)
namespace Zero
private def new : BaseIO (Zero α) := do
return {
state := ← Mutex.new {
producers := ∅
consumers := ∅
closed := false
}
}
/--
Precondition: The channel must not be closed.
-/
private def trySend' (v : α) : AtomicT (Zero.State α) BaseIO Bool := do
let st ← get
if let some (consumer, consumers) := st.consumers.dequeue? then
consumer.resolve (some v)
set { st with consumers }
return true
else
return false
private def trySend (ch : Zero α) (v : α) : BaseIO Bool := do
ch.state.atomically do
if (← get).closed then
return false
else
trySend' v
private def send (ch : Zero α) (v : α) : BaseIO (Task (Except Error Unit)) := do
ch.state.atomically do
if (← get).closed then
return .pure <| .error .closed
else if ← trySend' v then
return .pure <| .ok ()
else
let promise ← IO.Promise.new
modify fun st => { st with producers := st.producers.enqueue (v, promise) }
return promise.result?.map (sync := true)
fun
| none | some false => .error .closed
| some true => .ok ()
private def close (ch : Zero α) : EIO Error Unit := do
ch.state.atomically do
let st ← get
if let some (a, values) := st.values.dequeue? then
set { st with values }
return .pure a
if st.closed then throw .alreadyClosed
for consumer in st.consumers.toArray do consumer.resolve none
set { st with consumers := ∅, closed := true }
return ()
private def isClosed (ch : Zero α) : BaseIO Bool :=
ch.state.atomically do
return (← get).closed
private def tryRecv' : AtomicT (Zero.State α) BaseIO (Option α) := do
let st ← get
if let some ((val, promise), producers) := st.producers.dequeue? then
set { st with producers }
promise.resolve true
return some val
else
return none
private def tryRecv (ch : Zero α) : BaseIO (Option α) := do
ch.state.atomically do
tryRecv'
private def recv (ch : Zero α) : BaseIO (Task (Option α)) := do
ch.state.atomically do
let st ← get
if let some val ← tryRecv' then
return .pure <| some val
else if !st.closed then
let promise ← IO.Promise.new
set { st with consumers := st.consumers.enqueue promise }
return promise.result?.map (sync := true) (·.bind id)
else
return .pure none
return .pure <| none
end Zero
/--
`ch.forAsync f` calls `f` for every messages received on `ch`.
The central state structure for a bounded channel, maintains the following invariants:
1. `0 < capacity`
2. `0 < bufCount → consumers = ∅`
3. `bufCount < capacity → producers = ∅`
4. `producers = ∅ consumers = ∅`, implied by 1, 2 and 3.
5. `bufCount` corresponds to the amount of slots in `buf` that are `some`.
6. `sendIdx = (recvIdx + bufCount) % capacity`. However all four of these values still get tracked
as there is potential to make a non-blocking send lock-free in the future with this approach.
7. `closed = true → consumers = ∅`
Note that if this function is called twice, each `forAsync` only gets half the messages.
While it (currently) lacks the partial lock-freeness of go channels, the protocol is based on
[Go channels on steroids](https://docs.google.com/document/d/1yIAYmbvL3JxOKOjuCyon7JhW4cSv1wy5hC0ApeGMV9s/pub)
as well as its [implementation](https://go.dev/src/runtime/chan.go).
-/
partial def Channel.forAsync (f : α → BaseIO Unit) (ch : Channel α)
private structure Bounded.State (α : Type) where
/--
Producers that are blocked on a consumer taking their value as there was no buffer space
available when they tried to enqueue.
-/
producers : Std.Queue (IO.Promise Bool)
/--
Consumers that are blocked on a producer providing them a value, as there was no value
enqueued when they tried to dequeue. The `IO.Promise` will be resolved to `false` if the channel
closes.
-/
consumers : Std.Queue (IO.Promise Bool)
/--
The capacity of the buffer space.
-/
capacity : Nat
/--
The buffer space for the channel, slots with `some v` contain a value that is waiting for
consumption, the slots with `none` are free for enqueueing.
Note that this is a `Vector` of `IO.Ref (Option α)` as the `buf` itself is shared across threads
and would thus keep getting copied if it was a `Vector (Option α)` instead.
-/
buf : Vector (IO.Ref (Option α)) capacity
/--
How many slots in `buf` are currently used, this is used to disambiguate between an empty and a
full buffer without sacrificing a slot for indicating that.
-/
bufCount : Nat
/--
The slot in `buf` that the next send will happen to.
-/
sendIdx : Nat
hsend : sendIdx < capacity
/--
The slot in `buf` that the next receive will happen from.
-/
recvIdx : Nat
hrecv : recvIdx < capacity
/--
Whether the channel is closed already.
-/
closed : Bool
private structure Bounded (α : Type) where
state : Mutex (Bounded.State α)
namespace Bounded
private def new (capacity : Nat) (hcap : 0 < capacity) : BaseIO (Bounded α) := do
return {
state := ← Mutex.new {
producers := ∅
consumers := ∅
capacity := capacity
buf := ← Vector.range capacity |>.mapM (fun _ => IO.mkRef none)
bufCount := 0
sendIdx := 0
hsend := hcap
recvIdx := 0
hrecv := hcap
closed := false
}
}
@[inline]
private def incMod (idx : Nat) (cap : Nat) : Nat :=
if idx + 1 = cap then
0
else
idx + 1
private theorem incMod_lt {idx cap : Nat} (h : idx < cap) : incMod idx cap < cap := by
unfold incMod
split <;> omega
/--
Precondition: The channel must not be closed.
-/
private def trySend' (v : α) : AtomicT (Bounded.State α) BaseIO Bool := do
let mut st ← get
if st.bufCount = st.capacity then
return false
else
st.buf[st.sendIdx]'st.hsend |>.set (some v)
st := { st with
bufCount := st.bufCount + 1
sendIdx := incMod st.sendIdx st.capacity
hsend := incMod_lt st.hsend
}
if let some (consumer, consumers) := st.consumers.dequeue? then
consumer.resolve true
st := { st with consumers }
set st
return true
private def trySend (ch : Bounded α) (v : α) : BaseIO Bool := do
ch.state.atomically do
if (← get).closed then
return false
else
trySend' v
private partial def send (ch : Bounded α) (v : α) : BaseIO (Task (Except Error Unit)) := do
ch.state.atomically do
if (← get).closed then
return .pure <| .error .closed
else if ← trySend' v then
return .pure <| .ok ()
else
let promise ← IO.Promise.new
modify fun st => { st with producers := st.producers.enqueue promise }
BaseIO.bindTask promise.result? fun res => do
if res.getD false then
Bounded.send ch v
else
return .pure <| .error .closed
private def close (ch : Bounded α) : EIO Error Unit := do
ch.state.atomically do
let st ← get
if st.closed then throw .alreadyClosed
for consumer in st.consumers.toArray do consumer.resolve false
set { st with consumers := ∅, closed := true }
return ()
private def isClosed (ch : Bounded α) : BaseIO Bool :=
ch.state.atomically do
return (← get).closed
private def tryRecv' : AtomicT (Bounded.State α) BaseIO (Option α) := do
let st ← get
if st.bufCount == 0 then
return none
else
let val ← st.buf[st.recvIdx]'st.hrecv |>.swap none
let nextRecvIdx := incMod st.recvIdx st.capacity
set { st with
bufCount := st.bufCount - 1
recvIdx := nextRecvIdx,
hrecv := incMod_lt st.hrecv
}
return val
private def tryRecv (ch : Bounded α) : BaseIO (Option α) :=
ch.state.atomically do
tryRecv'
private partial def recv (ch : Bounded α) : BaseIO (Task (Option α)) := do
ch.state.atomically do
if let some val ← tryRecv' then
let st ← get
if let some (producer, producers) := (← get).producers.dequeue? then
producer.resolve true
set { st with producers }
return .pure <| some val
else if (← get).closed then
return .pure none
else
let promise ← IO.Promise.new
modify fun st => { st with consumers := st.consumers.enqueue promise }
BaseIO.bindTask promise.result? fun res => do
if res.getD false then
Bounded.recv ch
else
return .pure none
end Bounded
/--
This type represents all flavors of channels that we have available.
-/
private inductive Flavors (α : Type) where
| unbounded (ch : Unbounded α)
| zero (ch : Zero α)
| bounded (ch : Bounded α)
deriving Nonempty
end CloseableChannel
/--
A multi-producer multi-consumer FIFO channel that offers both bounded and unbounded buffering
and an asynchronous API, to switch into synchronous mode use `CloseableChannel.sync`.
Additionally `Std.CloseableChannel` can be closed if necessary, unlike `Std.Channel`.
This introduces a need for error handling in some cases, thus it is usually easier to use
`Std.Channel` if applicable.
-/
def CloseableChannel (α : Type) : Type := CloseableChannel.Flavors α
/--
A multi-producer multi-consumer FIFO channel that offers both bounded and unbounded buffering
and a synchronous API. This type acts as a convenient layer to use a channel in a blocking fashion
and is not actually different from the original channel.
Additionally `Std.CloseableChannel.Sync` can be closed if necessary, unlike `Std.Channel.Sync`.
This introduces the need to handle errors in some cases, thus it is usually easier to use
`Std.Channel` if applicable.
-/
def CloseableChannel.Sync (α : Type) : Type := CloseableChannel α
instance : Nonempty (CloseableChannel α) :=
inferInstanceAs (Nonempty (CloseableChannel.Flavors α))
instance : Nonempty (CloseableChannel.Sync α) :=
inferInstanceAs (Nonempty (CloseableChannel α))
namespace CloseableChannel
/--
Create a new channel, if:
- `capacity` is `none` it will be unbounded (the default)
- `capacity` is `some 0` it will always force a rendezvous between sender and receiver
- `capacity` is `some n` with `n > 0` it will use a buffer of size `n` and begin blocking once it
is filled
-/
def new (capacity : Option Nat := none) : BaseIO (CloseableChannel α) := do
match capacity with
| none => return .unbounded (← CloseableChannel.Unbounded.new)
| some 0 => return .zero (← CloseableChannel.Zero.new)
| some (n + 1) => return .bounded (← CloseableChannel.Bounded.new (n + 1) (by omega))
/--
Try to send a value to the channel, if this can be completed right away without blocking return
`true`, otherwise don't send the value and return `false`.
-/
def trySend (ch : CloseableChannel α) (v : α) : BaseIO Bool :=
match ch with
| .unbounded ch => CloseableChannel.Unbounded.trySend ch v
| .zero ch => CloseableChannel.Zero.trySend ch v
| .bounded ch => CloseableChannel.Bounded.trySend ch v
/--
Send a value through the channel, returning a task that will resolve once the transmission could be
completed. Note that the task may resolve to `Except.error` if the channel was closed before it
could be completed.
-/
def send (ch : CloseableChannel α) (v : α) : BaseIO (Task (Except Error Unit)) :=
match ch with
| .unbounded ch => CloseableChannel.Unbounded.send ch v
| .zero ch => CloseableChannel.Zero.send ch v
| .bounded ch => CloseableChannel.Bounded.send ch v
/--
Closes the channel, returns `Except.ok` when called the first time, otherwise `Except.error`.
When a channel is closed:
- no new values can be sent successfully anymore
- all blocked consumers are resolved to `none` (as no new messages can be sent they will never
resolve)
- if there are already values waiting to be received they can still be received by subsequent `recv`
calls
-/
def close (ch : CloseableChannel α) : EIO Error Unit :=
match ch with
| .unbounded ch => CloseableChannel.Unbounded.close ch
| .zero ch => CloseableChannel.Zero.close ch
| .bounded ch => CloseableChannel.Bounded.close ch
/--
Return `true` if the channel is closed.
-/
def isClosed (ch : CloseableChannel α) : BaseIO Bool :=
match ch with
| .unbounded ch => CloseableChannel.Unbounded.isClosed ch
| .zero ch => CloseableChannel.Zero.isClosed ch
| .bounded ch => CloseableChannel.Bounded.isClosed ch
/--
Try to receive a value from the channel, if this can be completed right away without blocking return
`some value`, otherwise return `none`.
-/
def tryRecv (ch : CloseableChannel α) : BaseIO (Option α) :=
match ch with
| .unbounded ch => CloseableChannel.Unbounded.tryRecv ch
| .zero ch => CloseableChannel.Zero.tryRecv ch
| .bounded ch => CloseableChannel.Bounded.tryRecv ch
/--
Receive a value from the channel, returning a task that will resolve once the transmission could be
completed. Note that the task may resolve to `none` if the channel was closed before it could be
completed.
-/
def recv (ch : CloseableChannel α) : BaseIO (Task (Option α)) :=
match ch with
| .unbounded ch => CloseableChannel.Unbounded.recv ch
| .zero ch => CloseableChannel.Zero.recv ch
| .bounded ch => CloseableChannel.Bounded.recv ch
/--
`ch.forAsync f` calls `f` for every message received on `ch`.
Note that if this function is called twice, each message will only arrive at exactly one invocation.
-/
partial def forAsync (f : α → BaseIO Unit) (ch : CloseableChannel α)
(prio : Task.Priority := .default) : BaseIO (Task Unit) := do
BaseIO.bindTask (prio := prio) (← ch.recv?) fun
BaseIO.bindTask (prio := prio) (← ch.recv) fun
| none => return .pure ()
| some v => do f v; ch.forAsync f prio
/--
Receives all currently queued messages from the channel.
Those messages are dequeued and will not be returned by `recv?`.
This function is a no-op and just a convenient way to expose the synchronous API of the channel.
-/
def Channel.recvAllCurrent (ch : Channel α) : BaseIO (Array α) :=
ch.atomically do
modifyGet fun st => (st.values.toArray, { st with values := ∅ })
@[inline]
def sync (ch : CloseableChannel α) : CloseableChannel.Sync α := ch
/-- Type tag for synchronous (blocking) operations on a `Channel`. -/
def Channel.Sync := Channel
namespace Sync
@[inherit_doc CloseableChannel.new, inline]
def new (capacity : Option Nat := none) : BaseIO (Sync α) := CloseableChannel.new capacity
@[inherit_doc CloseableChannel.trySend, inline]
def trySend (ch : Sync α) (v : α) : BaseIO Bool := CloseableChannel.trySend ch v
/--
Accesses synchronous (blocking) version of channel operations.
For example, `ch.sync.recv?` blocks until the next message,
and `for msg in ch.sync do ...` iterates synchronously over the channel.
These functions should only be used in dedicated threads.
Send a value through the channel, blocking until the transmission could be completed. Note that this
function may throw an error when trying to send to an already closed channel.
-/
def Channel.sync (ch : Channel α) : Channel.Sync α := ch
def send (ch : Sync α) (v : α) : EIO Error Unit := do
EIO.ofExcept (← IO.wait (← CloseableChannel.send ch v))
@[inherit_doc CloseableChannel.close, inline]
def close (ch : Sync α) : EIO Error Unit := CloseableChannel.close ch
@[inherit_doc CloseableChannel.isClosed, inline]
def isClosed (ch : Sync α) : BaseIO Bool := CloseableChannel.isClosed ch
@[inherit_doc CloseableChannel.tryRecv, inline]
def tryRecv (ch : Sync α) : BaseIO (Option α) := CloseableChannel.tryRecv ch
/--
Synchronously receives a message from the channel.
Every message is only received once.
Returns `none` if the channel is closed and the queue is empty.
Receive a value from the channel, blocking unitl the transmission could be completed. Note that the
return value may be `none` if the channel was closed before it could be completed.
-/
def Channel.Sync.recv? (ch : Channel.Sync α) : BaseIO (Option α) := do
IO.wait (← Channel.recv? ch)
def recv (ch : Sync α) : BaseIO (Option α) := do
IO.wait (← CloseableChannel.recv ch)
private partial def Channel.Sync.forIn [Monad m] [MonadLiftT BaseIO m]
(ch : Channel.Sync α) (f : α → β → m (ForInStep β)) : β → m β := fun b => do
match ← ch.recv? with
| some a =>
match ← f a b with
| .done b => pure b
| .yield b => ch.forIn f b
| none => pure b
private partial def forIn [Monad m] [MonadLiftT BaseIO m]
(ch : Sync α) (f : α → β → m (ForInStep β)) : β → m β := fun b => do
match ← ch.recv with
| some a =>
match ← f a b with
| .done b => pure b
| .yield b => ch.forIn f b
| none => pure b
/-- `for msg in ch.sync do ...` receives all messages in the channel until it is closed. -/
instance [MonadLiftT BaseIO m] : ForIn m (Channel.Sync α) α where
instance [MonadLiftT BaseIO m] : ForIn m (Sync α) α where
forIn ch b f := ch.forIn f b
end Sync
end CloseableChannel
/--
A multi-producer multi-consumer FIFO channel that offers both bounded and unbounded buffering
and an asynchronous API, to switch into synchronous mode use `Channel.sync`.
If a channel needs to be closed to indicate some sort of completion event use `Std.CloseableChannel`
instead. Note that `Std.CloseableChannel` introduces a need for error handling in some cases, thus
`Std.Channel` is usually easier to use if applicable.
-/
structure Channel (α : Type) where
private mk ::
private inner : CloseableChannel α
deriving Nonempty
/--
A multi-producer multi-consumer FIFO channel that offers both bounded and unbounded buffering
and a synchronous API. This type acts as a convenient layer to use a channel in a blocking fashion
and is not actually different from the original channel.
If a channel needs to be closed to indicate some sort of completion event use
`Std.CloseableChannel.Sync` instead. Note that `Std.CloseableChannel.Sync` introduces a need for error
handling in some cases, thus `Std.Channel.Sync` is usually easier to use if applicable.
-/
def Channel.Sync (α : Type) : Type := Channel α
instance : Nonempty (Channel.Sync α) :=
inferInstanceAs (Nonempty (Channel α))
namespace Channel
@[inherit_doc CloseableChannel.new, inline]
def new (capacity : Option Nat := none) : BaseIO (Channel α) := do
return ⟨← CloseableChannel.new capacity⟩
@[inherit_doc CloseableChannel.trySend, inline]
def trySend (ch : Channel α) (v : α) : BaseIO Bool :=
CloseableChannel.trySend ch.inner v
/--
Send a value through the channel, returning a task that will resolve once the transmission could be
completed.
-/
def send (ch : Channel α) (v : α) : BaseIO (Task Unit) := do
BaseIO.bindTask (sync := true) (← CloseableChannel.send ch.inner v)
fun
| .ok .. => return .pure ()
| .error .. => unreachable!
@[inherit_doc CloseableChannel.tryRecv, inline]
def tryRecv (ch : Channel α) : BaseIO (Option α) :=
CloseableChannel.tryRecv ch.inner
@[inherit_doc CloseableChannel.recv]
def recv [Inhabited α] (ch : Channel α) : BaseIO (Task α) := do
BaseIO.bindTask (sync := true) (← CloseableChannel.recv ch.inner)
fun
| some val => return .pure val
| none => unreachable!
@[inherit_doc CloseableChannel.forAsync]
partial def forAsync [Inhabited α] (f : α → BaseIO Unit) (ch : Channel α)
(prio : Task.Priority := .default) : BaseIO (Task Unit) := do
BaseIO.bindTask (prio := prio) (← ch.recv) fun v => do f v; ch.forAsync f prio
@[inherit_doc CloseableChannel.sync, inline]
def sync (ch : Channel α) : Channel.Sync α := ch
namespace Sync
@[inherit_doc Channel.new, inline]
def new (capacity : Option Nat := none) : BaseIO (Sync α) := Channel.new capacity
@[inherit_doc Channel.trySend, inline]
def trySend (ch : Sync α) (v : α) : BaseIO Bool := Channel.trySend ch v
/--
Send a value through the channel, blocking until the transmission could be completed.
-/
def send (ch : Sync α) (v : α) : BaseIO Unit := do
IO.wait (← Channel.send ch v)
@[inherit_doc Channel.tryRecv, inline]
def tryRecv (ch : Sync α) : BaseIO (Option α) := Channel.tryRecv ch
/--
Receive a value from the channel, blocking unitl the transmission could be completed.
-/
def recv [Inhabited α] (ch : Sync α) : BaseIO α := do
IO.wait (← Channel.recv ch)
private partial def forIn [Inhabited α] [Monad m] [MonadLiftT BaseIO m]
(ch : Sync α) (f : α → β → m (ForInStep β)) : β → m β := fun b => do
let a ← ch.recv
match ← f a b with
| .done b => pure b
| .yield b => ch.forIn f b
/-- `for msg in ch.sync do ...` receives all messages in the channel until it is closed. -/
instance [Inhabited α] [MonadLiftT BaseIO m] : ForIn m (Sync α) α where
forIn ch b f := ch.forIn f b
end Sync
end Channel
end Std

117
tests/bench/channel.lean Normal file
View file

@ -0,0 +1,117 @@
import Std.Sync.Channel
/-
Inspired by:
https://github.com/crossbeam-rs/crossbeam/tree/bd87a61ce3858ca772c42525d5f0c9aa12cc80ac/crossbeam-channel/benchmarks.
We conduct for:
- capacity 0 channels
- capacity 1 channels
- capacity `N` channels
- unbounded channels
the following tests:
- `seq`: A single thread sends `N` messages. Then it receives `N` messages.
- `spsc`: One thread sends `N` messages. Another thread receives `N` messages.
- `mpsc`: `T` threads send `N / T` messages each. One thread receives `N` messages.
- `mpmc`: `T` threads send `N / T` messages each. `T` other threads receive `N / T` messages each.
Note that we will stick exclusively to the sync interface for this as there is no benefit to be
reaped from async in this benchmark so we might as well just block.
-/
def MESSAGES : Nat := 1_000_000
def THREADS : Nat := 4
def seq (ch : Std.CloseableChannel.Sync Nat) (amount : Nat) : IO Unit := do
for i in [:amount] do
ch.send i
for _ in [:amount] do
discard <| ch.recv
def spsc (ch : Std.CloseableChannel.Sync Nat) (amount : Nat) : IO Unit := do
let t1 ← IO.asTask (prio := .dedicated) do
for i in [:amount] do
ch.send i
let t2 ← BaseIO.asTask (prio := .dedicated) do
for _ in [:amount] do
discard <| ch.recv
IO.ofExcept (← IO.wait t1)
IO.wait t2
def mpsc (ch : Std.CloseableChannel.Sync Nat) (amount : Nat) : IO Unit := do
let mut producers := Array.emptyWithCapacity THREADS
for _ in [:THREADS] do
let t ← IO.asTask (prio := .dedicated) do
for i in [:(amount/THREADS)] do
ch.send i
producers := producers.push t
let consumer ← BaseIO.asTask (prio := .dedicated) do
for _ in [:amount] do
discard <| ch.recv
IO.wait consumer
for producer in producers do
(IO.ofExcept (← IO.wait producer))
def mpmc (ch : Std.CloseableChannel.Sync Nat) (amount : Nat) : IO Unit := do
let mut producers := Array.emptyWithCapacity THREADS
for _ in [:THREADS] do
let t ← IO.asTask (prio := .dedicated) do
for i in [:(amount/THREADS)] do
ch.send i
producers := producers.push t
let mut consumers := Array.emptyWithCapacity THREADS
for _ in [:THREADS] do
let t ← IO.asTask (prio := .dedicated) do
while true do
if let some _ ← ch.recv then
continue
else
break
consumers := consumers.push t
for producer in producers do
(IO.ofExcept (← IO.wait producer))
ch.close
for consumer in consumers do
(IO.ofExcept (← IO.wait consumer))
return ()
def run (name : String) (cap : Option Nat) (bench : Std.CloseableChannel.Sync Nat → Nat → IO Unit) :
IO Unit := do
let ch ← Std.CloseableChannel.new cap
let t1 ← IO.monoMsNow
bench ch.sync MESSAGES
let t2 ← IO.monoMsNow
let time : Float := (t2 - t1).toFloat / 1000.0
IO.println s!"{name}: {time}"
def main : IO Unit := do
run "bounded0_spsc" (some 0) spsc
run "bounded0_mpsc" (some 0) mpsc
run "bounded0_mpmc" (some 0) mpmc
run "bounded1_spsc" (some 1) spsc
run "bounded1_mpsc" (some 1) mpsc
run "bounded1_mpmc" (some 1) mpmc
run "boundedn_spsc" (some MESSAGES) spsc
run "boundedn_mpsc" (some MESSAGES) mpsc
run "boundedn_mpmc" (some MESSAGES) mpmc
run "boundedn_seq" (some MESSAGES) seq
run "unbounded_spsc" none spsc
run "unbounded_mpsc" none mpsc
run "unbounded_mpmc" none mpmc
run "unbounded_seq" none seq

View file

@ -462,3 +462,12 @@
run_config:
<<: *time
cmd: lean omega_stress.lean
- attributes:
description: channel.lean
tags: [fast]
run_config:
<<: *time
cmd: ./channel.lean.out
parse_output: true
build_config:
cmd: ./compile.sh channel.lean

View file

@ -15,15 +15,15 @@ open IO
let promise : Promise Nat ← Promise.new
assert! promise.result?.get = none
#eval do
let ch ← Std.Channel.new
#eval show IO _ from do
let ch ← Std.CloseableChannel.new
let out ← IO.mkRef #[]
ch.send 0
ch.sync.send 0
let drainFinished ← ch.forAsync fun x => out.modify (·.push x)
ch.send 1
ch.sync.send 1
ch.close
ch.send 2
assert! (← EIO.toBaseIO (ch.sync.send 2)) matches .error .closed
IO.wait drainFinished
assert! (← out.get) = #[0, 1]

View file

@ -0,0 +1,150 @@
import Std.Sync
open Std
def assertBEq [BEq α] [ToString α] (is should : α) : IO Unit := do
if is != should then
throw <| .userError s!"{is} should be {should}"
def closeClose (ch : CloseableChannel Nat) : IO Unit := do
assertBEq (← ch.isClosed) false
assertBEq ((← EIO.toBaseIO ch.close) matches .ok ()) true
assertBEq (← ch.isClosed) true
assertBEq ((← EIO.toBaseIO ch.close) matches .error .alreadyClosed) true
assertBEq (← ch.isClosed) true
def paired (ch : CloseableChannel Nat) : IO Unit := do
let sendTask ← ch.send 37
let recvTask ← ch.recv
assertBEq ((← IO.wait sendTask) matches .ok ()) true
assertBEq (← IO.wait recvTask) (some 37)
def syncPaired (ch : CloseableChannel.Sync Nat) : IO Unit := do
let sendTask ← IO.asTask (prio := .dedicated) (EIO.toBaseIO (ch.send 37))
let recvTask ← IO.asTask (prio := .dedicated) (ch.recv)
assertBEq ((← IO.ofExcept (← IO.wait sendTask)) matches .ok ()) true
assertBEq (← IO.ofExcept (← IO.wait recvTask)) (some 37)
def trySend (ch : CloseableChannel Nat) (capacity : Option Nat) : IO Unit := do
-- ready a receiver ahead of time
let recvTask ← ch.recv
assertBEq (← ch.trySend 37) true
assertBEq (← IO.wait recvTask) (some 37)
-- the unbounded CloseableChannel cannot go out of space so it is pointless to fill it up
let some capacity := capacity | return ()
for i in [:capacity] do
assertBEq (← ch.trySend i) true
assertBEq (← ch.trySend (capacity + 1)) false
def tryRecv (ch : CloseableChannel Nat) : IO Unit := do
assertBEq (← ch.tryRecv) none
let sendTask ← ch.send 37
assertBEq (← ch.tryRecv) (some 37)
assertBEq ((← IO.wait sendTask) matches .ok ()) true
def sendRecvClose (ch : CloseableChannel Nat) : IO Unit := do
let sendTask ← ch.send 37
assertBEq ((← EIO.toBaseIO ch.close) matches .ok ()) true
let recvTask ← ch.recv
assertBEq ((← IO.wait sendTask) matches .ok ()) true
assertBEq (← IO.wait recvTask) (some 37)
let sendTask ← ch.send 37
assertBEq ((← IO.wait sendTask) matches .error .closed) true
let recvTask ← ch.recv
assertBEq (← IO.wait recvTask) none
assertBEq (← ch.trySend 37) false
assertBEq (← ch.tryRecv) none
def sendIt (ch : CloseableChannel Nat) (messages : List Nat) : BaseIO (Task (Option Unit)) := do
match messages with
| [] => return .pure <| some ()
| msg :: messages =>
BaseIO.bindTask (← ch.send msg) fun
| .error .. =>
return .pure <| none
| .ok .. =>
sendIt ch messages
partial def recvIt (ch : CloseableChannel Nat) (messages : List Nat) : BaseIO (Task (List Nat)) := do
BaseIO.bindTask (← ch.recv) fun
| none => return .pure messages.reverse
| some msg => recvIt ch (msg :: messages)
def sendLots (ch : CloseableChannel Nat) : IO Unit := do
let messages := List.range 1000
let sendTask ← sendIt ch messages
let recvTask ← recvIt ch []
assertBEq (← IO.wait sendTask) (some ())
discard <| ch.close
assertBEq (← IO.wait recvTask) messages
def sendItSync (ch : CloseableChannel.Sync Nat) (messages : List Nat) : IO Unit := do
for msg in messages do
ch.send msg
return ()
def recvItSync (ch : CloseableChannel.Sync Nat) : IO (List Nat) := do
let mut messages := []
for msg in ch do
messages := msg :: messages
return messages.reverse
def sendLotsSync (ch : CloseableChannel.Sync Nat) : IO Unit := do
let messages := List.range 1000
let sendTask ← IO.asTask (prio := .dedicated) (sendItSync ch messages)
let recvTask ← IO.asTask (prio := .dedicated) (recvItSync ch)
IO.ofExcept (← IO.wait sendTask)
discard <| ch.close
assertBEq (← IO.ofExcept (← IO.wait recvTask)) messages
partial def sendLotsMulti (ch : CloseableChannel Nat) : IO Unit := do
let messages := List.range 1000
let sendTask1 ← sendIt ch messages
let sendTask2 ← sendIt ch messages
let recvTask1 ← recvIt ch []
let recvTask2 ← recvIt ch []
assertBEq (← IO.wait sendTask1) (some ())
assertBEq (← IO.wait sendTask2) (some ())
discard <| ch.close
let msg1 ← IO.wait recvTask1
let msg2 ← IO.wait recvTask2
assertBEq (msg1.sum + msg2.sum) (2 * messages.sum)
partial def sendLotsMultiSync (ch : CloseableChannel.Sync Nat) : IO Unit := do
let messages := List.range 1000
let sendTask1 ← IO.asTask (prio := .dedicated) (sendItSync ch messages)
let sendTask2 ← IO.asTask (prio := .dedicated) (sendItSync ch messages)
let recvTask1 ← IO.asTask (prio := .dedicated) (recvItSync ch)
let recvTask2 ← IO.asTask (prio := .dedicated) (recvItSync ch)
IO.ofExcept (← IO.wait sendTask1)
IO.ofExcept (← IO.wait sendTask2)
discard <| ch.close
let msg1 ← IO.ofExcept (← IO.wait recvTask1)
let msg2 ← IO.ofExcept (← IO.wait recvTask2)
assertBEq (msg1.sum + msg2.sum) (2 * messages.sum)
def testIt (capacity : Option Nat) : IO Unit := do
paired (← CloseableChannel.new capacity)
syncPaired (← CloseableChannel.new capacity).sync
closeClose (← CloseableChannel.new capacity)
trySend (← CloseableChannel.new capacity) capacity
tryRecv (← CloseableChannel.new capacity)
sendRecvClose (← CloseableChannel.new capacity)
sendLots (← CloseableChannel.new capacity)
sendLotsSync (← CloseableChannel.new capacity).sync
sendLotsMulti (← CloseableChannel.new capacity)
sendLotsMultiSync (← CloseableChannel.new capacity).sync
def suite : IO Unit := do
testIt none
testIt (some 0)
testIt (some 1)
testIt (some 8)
testIt (some 128)
#eval suite