refactor: more complete channel implementation for Std.Channel (#7819)
This PR extends `Std.Channel` to provide a full sync and async API, as well as unbounded, zero sized and bounded channels. A few notes on the implementation: - the bounded channel is inspired by [Go channels on steroids](https://docs.google.com/document/d/1yIAYmbvL3JxOKOjuCyon7JhW4cSv1wy5hC0ApeGMV9s/pub) though currently doesn't do any of the lock-free optimizations - @mhuisi convinced me that having a non-closable channel may be a good idea as this alleviates the need for error handling which is very annoying when working with `Task`. This does complicate the API a little bit and I'm not quite sure whether this is a choice we want users to give. An alternative to this would be to just write `send!` that panics on sending to a closed channel (receiving from a closed channel is not an error), this is for example the behavior that golang goes with.
This commit is contained in:
parent
85a0232e87
commit
dd7ca772d8
9 changed files with 956 additions and 246 deletions
|
|
@ -34,7 +34,6 @@ import Init.Data.Stream
|
|||
import Init.Data.Prod
|
||||
import Init.Data.AC
|
||||
import Init.Data.Queue
|
||||
import Init.Data.Channel
|
||||
import Init.Data.Sum
|
||||
import Init.Data.BEq
|
||||
import Init.Data.Subtype
|
||||
|
|
|
|||
|
|
@ -1,149 +0,0 @@
|
|||
/-
|
||||
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Gabriel Ebner
|
||||
-/
|
||||
prelude
|
||||
import Init.Data.Queue
|
||||
import Init.System.Promise
|
||||
import Init.System.Mutex
|
||||
|
||||
set_option linter.deprecated false
|
||||
|
||||
namespace IO
|
||||
|
||||
/--
|
||||
Internal state of an `Channel`.
|
||||
|
||||
We maintain the invariant that at all times either `consumers` or `values` is empty.
|
||||
-/
|
||||
@[deprecated "Use Std.Channel.State from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
structure Channel.State (α : Type) where
|
||||
values : Std.Queue α := ∅
|
||||
consumers : Std.Queue (Promise (Option α)) := ∅
|
||||
closed := false
|
||||
deriving Inhabited
|
||||
|
||||
/--
|
||||
FIFO channel with unbounded buffer, where `recv?` returns a `Task`.
|
||||
|
||||
A channel can be closed. Once it is closed, all `send`s are ignored, and
|
||||
`recv?` returns `none` once the queue is empty.
|
||||
-/
|
||||
@[deprecated "Use Std.Channel from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
def Channel (α : Type) : Type := Mutex (Channel.State α)
|
||||
|
||||
instance : Nonempty (Channel α) :=
|
||||
inferInstanceAs (Nonempty (Mutex _))
|
||||
|
||||
/-- Creates a new `Channel`. -/
|
||||
@[deprecated "Use Std.Channel.new from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
def Channel.new : BaseIO (Channel α) :=
|
||||
Mutex.new {}
|
||||
|
||||
/--
|
||||
Sends a message on an `Channel`.
|
||||
|
||||
This function does not block.
|
||||
-/
|
||||
@[deprecated "Use Std.Channel.send from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
def Channel.send (ch : Channel α) (v : α) : BaseIO Unit :=
|
||||
ch.atomically do
|
||||
let st ← get
|
||||
if st.closed then return
|
||||
if let some (consumer, consumers) := st.consumers.dequeue? then
|
||||
consumer.resolve (some v)
|
||||
set { st with consumers }
|
||||
else
|
||||
set { st with values := st.values.enqueue v }
|
||||
|
||||
/--
|
||||
Closes an `Channel`.
|
||||
-/
|
||||
@[deprecated "Use Std.Channel.close from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
def Channel.close (ch : Channel α) : BaseIO Unit :=
|
||||
ch.atomically do
|
||||
let st ← get
|
||||
for consumer in st.consumers.toArray do consumer.resolve none
|
||||
set { st with closed := true, consumers := ∅ }
|
||||
|
||||
/--
|
||||
Receives a message, without blocking.
|
||||
The returned task waits for the message.
|
||||
Every message is only received once.
|
||||
|
||||
Returns `none` if the channel is closed and the queue is empty.
|
||||
-/
|
||||
@[deprecated "Use Std.Channel.recv? from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
def Channel.recv? (ch : Channel α) : BaseIO (Task (Option α)) :=
|
||||
ch.atomically do
|
||||
let st ← get
|
||||
if let some (a, values) := st.values.dequeue? then
|
||||
set { st with values }
|
||||
return .pure a
|
||||
else if !st.closed then
|
||||
let promise ← Promise.new
|
||||
set { st with consumers := st.consumers.enqueue promise }
|
||||
return promise.result
|
||||
else
|
||||
return .pure none
|
||||
|
||||
/--
|
||||
`ch.forAsync f` calls `f` for every messages received on `ch`.
|
||||
|
||||
Note that if this function is called twice, each `forAsync` only gets half the messages.
|
||||
-/
|
||||
@[deprecated "Use Std.Channel.forAsync from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
partial def Channel.forAsync (f : α → BaseIO Unit) (ch : Channel α)
|
||||
(prio : Task.Priority := .default) : BaseIO (Task Unit) := do
|
||||
BaseIO.bindTask (prio := prio) (← ch.recv?) fun
|
||||
| none => return .pure ()
|
||||
| some v => do f v; ch.forAsync f prio
|
||||
|
||||
/--
|
||||
Receives all currently queued messages from the channel.
|
||||
|
||||
Those messages are dequeued and will not be returned by `recv?`.
|
||||
-/
|
||||
@[deprecated "Use Std.Channel.recvAllCurrent from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
def Channel.recvAllCurrent (ch : Channel α) : BaseIO (Array α) :=
|
||||
ch.atomically do
|
||||
modifyGet fun st => (st.values.toArray, { st with values := ∅ })
|
||||
|
||||
/-- Type tag for synchronous (blocking) operations on a `Channel`. -/
|
||||
@[deprecated "Use Std.Channel.Sync from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
def Channel.Sync := Channel
|
||||
|
||||
/--
|
||||
Accesses synchronous (blocking) version of channel operations.
|
||||
|
||||
For example, `ch.sync.recv?` blocks until the next message,
|
||||
and `for msg in ch.sync do ...` iterates synchronously over the channel.
|
||||
These functions should only be used in dedicated threads.
|
||||
-/
|
||||
@[deprecated "Use Std.Channel.sync from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
def Channel.sync (ch : Channel α) : Channel.Sync α := ch
|
||||
|
||||
/--
|
||||
Synchronously receives a message from the channel.
|
||||
|
||||
Every message is only received once.
|
||||
Returns `none` if the channel is closed and the queue is empty.
|
||||
-/
|
||||
@[deprecated "Use Std.Channel.Sync.recv? from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
def Channel.Sync.recv? (ch : Channel.Sync α) : BaseIO (Option α) := do
|
||||
IO.wait (← Channel.recv? ch)
|
||||
|
||||
@[deprecated "Use Std.Channel.Sync.forIn from Std.Sync.Channel instead" (since := "2024-12-02")]
|
||||
private partial def Channel.Sync.forIn [Monad m] [MonadLiftT BaseIO m]
|
||||
(ch : Channel.Sync α) (f : α → β → m (ForInStep β)) : β → m β := fun b => do
|
||||
match ← ch.recv? with
|
||||
| some a =>
|
||||
match ← f a b with
|
||||
| .done b => pure b
|
||||
| .yield b => ch.forIn f b
|
||||
| none => pure b
|
||||
|
||||
/-- `for msg in ch.sync do ...` receives all messages in the channel until it is closed. -/
|
||||
instance [MonadLiftT BaseIO m] : ForIn m (Channel.Sync α) α where
|
||||
forIn ch b f := ch.forIn f b
|
||||
|
|
@ -111,6 +111,7 @@ inductive Message where
|
|||
| response (id : RequestID) (result : Json)
|
||||
/-- A non-successful response. -/
|
||||
| responseError (id : RequestID) (code : ErrorCode) (message : String) (data? : Option Json)
|
||||
deriving Inhabited
|
||||
|
||||
def Batch := Array Message
|
||||
|
||||
|
|
|
|||
|
|
@ -207,7 +207,7 @@ This option can only be set on the command line, not in the lakefile or via `set
|
|||
stickyInteractiveDiagnostics ++ docInteractiveDiagnostics
|
||||
|>.map (·.toDiagnostic)
|
||||
let notification := mkPublishDiagnosticsNotification doc.meta diagnostics
|
||||
ctx.chanOut.send notification
|
||||
ctx.chanOut.sync.send notification
|
||||
|
||||
open Language in
|
||||
/--
|
||||
|
|
@ -239,7 +239,7 @@ This option can only be set on the command line, not in the lakefile or via `set
|
|||
publishDiagnostics ctx doc
|
||||
-- This will overwrite existing ilean info for the file, in case something
|
||||
-- went wrong during the incremental updates.
|
||||
ctx.chanOut.send (← mkIleanInfoFinalNotification doc.meta st.allInfoTrees)
|
||||
ctx.chanOut.sync.send (← mkIleanInfoFinalNotification doc.meta st.allInfoTrees)
|
||||
return ()
|
||||
where
|
||||
/--
|
||||
|
|
@ -312,7 +312,7 @@ This option can only be set on the command line, not in the lakefile or via `set
|
|||
if let some itree := node.element.infoTree? then
|
||||
let mut newInfoTrees := (← get).newInfoTrees.push itree
|
||||
if (← get).hasBlocked then
|
||||
ctx.chanOut.send (← mkIleanInfoUpdateNotification doc.meta newInfoTrees)
|
||||
ctx.chanOut.sync.send (← mkIleanInfoUpdateNotification doc.meta newInfoTrees)
|
||||
newInfoTrees := #[]
|
||||
modify fun st => { st with newInfoTrees, allInfoTrees := st.allInfoTrees.push itree }
|
||||
|
||||
|
|
@ -329,7 +329,7 @@ This option can only be set on the command line, not in the lakefile or via `set
|
|||
| none => rs.push r
|
||||
let ranges := ranges.map (·.toLspRange doc.meta.text)
|
||||
let notifs := ranges.map ({ range := ·, kind := .processing })
|
||||
ctx.chanOut.send <| mkFileProgressNotification doc.meta notifs
|
||||
ctx.chanOut.sync.send <| mkFileProgressNotification doc.meta notifs
|
||||
|
||||
end Elab
|
||||
|
||||
|
|
@ -389,9 +389,9 @@ def setupImports
|
|||
severity? := DiagnosticSeverity.information
|
||||
message := stderrLine
|
||||
}
|
||||
chanOut.send <| mkPublishDiagnosticsNotification meta #[progressDiagnostic]
|
||||
chanOut.sync.send <| mkPublishDiagnosticsNotification meta #[progressDiagnostic]
|
||||
-- clear progress notifications in the end
|
||||
chanOut.send <| mkPublishDiagnosticsNotification meta #[]
|
||||
chanOut.sync.send <| mkPublishDiagnosticsNotification meta #[]
|
||||
match fileSetupResult.kind with
|
||||
| .importsOutOfDate =>
|
||||
return .error {
|
||||
|
|
@ -525,7 +525,7 @@ section ServerRequests
|
|||
(freshRequestId, freshRequestId + 1)
|
||||
let responseTask ← ctx.initPendingServerRequest responseType freshRequestId
|
||||
let r : JsonRpc.Request paramType := ⟨freshRequestId, method, param⟩
|
||||
ctx.chanOut.send r
|
||||
ctx.chanOut.sync.send r
|
||||
return responseTask
|
||||
|
||||
def sendUntypedServerRequest
|
||||
|
|
@ -679,7 +679,7 @@ section MessageHandling
|
|||
let availableImports ← ImportCompletion.collectAvailableImports
|
||||
let lastRequestTimestampMs ← IO.monoMsNow
|
||||
let completions := ImportCompletion.find text st.doc.initSnap.stx params availableImports
|
||||
ctx.chanOut.send <| .response id (toJson completions)
|
||||
ctx.chanOut.sync.send <| .response id (toJson completions)
|
||||
pure { availableImports, lastRequestTimestampMs : AvailableImportsCache }
|
||||
|
||||
| some task => ServerTask.IO.mapTaskCostly (t := task) fun (result : Except Error AvailableImportsCache) => do
|
||||
|
|
@ -689,7 +689,7 @@ section MessageHandling
|
|||
availableImports ← ImportCompletion.collectAvailableImports
|
||||
lastRequestTimestampMs := timestampNowMs
|
||||
let completions := ImportCompletion.find text st.doc.initSnap.stx params availableImports
|
||||
ctx.chanOut.send <| .response id (toJson completions)
|
||||
ctx.chanOut.sync.send <| .response id (toJson completions)
|
||||
pure { availableImports, lastRequestTimestampMs : AvailableImportsCache }
|
||||
|
||||
def handleStatefulPreRequestSpecialCases (id : RequestID) (method : String) (params : Json) : WorkerM Bool := do
|
||||
|
|
@ -701,7 +701,7 @@ section MessageHandling
|
|||
| "$/lean/rpc/connect" =>
|
||||
let ps ← parseParams RpcConnectParams params
|
||||
let resp ← handleRpcConnect ps
|
||||
ctx.chanOut.send <| .response id (toJson resp)
|
||||
ctx.chanOut.sync.send <| .response id (toJson resp)
|
||||
return true
|
||||
| "textDocument/completion" =>
|
||||
let params ← parseParams CompletionParams params
|
||||
|
|
@ -714,7 +714,7 @@ section MessageHandling
|
|||
| _ =>
|
||||
return false
|
||||
catch e =>
|
||||
ctx.chanOut.send <| .responseError id .internalError (toString e) none
|
||||
ctx.chanOut.sync.send <| .responseError id .internalError (toString e) none
|
||||
return true
|
||||
|
||||
open Widget RequestM Language in
|
||||
|
|
@ -836,7 +836,7 @@ section MessageHandling
|
|||
emitResponse ctx (isComplete := false) <| e.toLspResponseError id
|
||||
where
|
||||
emitResponse (ctx : WorkerContext) (m : JsonRpc.Message) (isComplete : Bool) : IO Unit := do
|
||||
ctx.chanOut.send m
|
||||
ctx.chanOut.sync.send m
|
||||
let timestamp ← IO.monoMsNow
|
||||
ctx.modifyPartialHandler method fun h => { h with
|
||||
requestsInFlight := h.requestsInFlight - 1
|
||||
|
|
|
|||
|
|
@ -1,137 +1,720 @@
|
|||
/-
|
||||
Copyright (c) 2022 Microsoft Corporation. All rights reserved.
|
||||
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Gabriel Ebner
|
||||
Authors: Henrik Böving
|
||||
-/
|
||||
prelude
|
||||
import Init.System.Promise
|
||||
import Init.Data.Queue
|
||||
import Std.Sync.Mutex
|
||||
|
||||
/-!
|
||||
This module contains the implementation of `Std.Channel`. `Std.Channel` is a multi-producer
|
||||
multi-consumer FIFO channel that offers both bounded and unbounded buffering as well as synchronous
|
||||
and asynchronous APIs.
|
||||
|
||||
Additionally `Std.CloseableChannel` is provided in case closing the channel is of interest.
|
||||
The two are distinct as the non closable `Std.Channel` can never throw errors which makes
|
||||
for cleaner code.
|
||||
-/
|
||||
|
||||
namespace Std
|
||||
|
||||
/--
|
||||
Internal state of an `Channel`.
|
||||
|
||||
We maintain the invariant that at all times either `consumers` or `values` is empty.
|
||||
-/
|
||||
structure Channel.State (α : Type) where
|
||||
values : Std.Queue α := ∅
|
||||
consumers : Std.Queue (IO.Promise (Option α)) := ∅
|
||||
closed := false
|
||||
deriving Inhabited
|
||||
namespace CloseableChannel
|
||||
|
||||
/--
|
||||
FIFO channel with unbounded buffer, where `recv?` returns a `Task`.
|
||||
|
||||
A channel can be closed. Once it is closed, all `send`s are ignored, and
|
||||
`recv?` returns `none` once the queue is empty.
|
||||
Errors that may be thrown while interacting with the channel API.
|
||||
-/
|
||||
def Channel (α : Type) : Type := Mutex (Channel.State α)
|
||||
inductive Error where
|
||||
/--
|
||||
Tried to send to a closed channel.
|
||||
-/
|
||||
| closed
|
||||
/--
|
||||
Tried to close an already closed channel.
|
||||
-/
|
||||
| alreadyClosed
|
||||
deriving Repr, DecidableEq, Hashable
|
||||
|
||||
instance : Nonempty (Channel α) :=
|
||||
inferInstanceAs (Nonempty (Mutex _))
|
||||
instance : ToString Error where
|
||||
toString
|
||||
| .closed => "trying to send on an already closed channel"
|
||||
| .alreadyClosed => "trying to close an already closed channel"
|
||||
|
||||
/-- Creates a new `Channel`. -/
|
||||
def Channel.new : BaseIO (Channel α) :=
|
||||
Mutex.new {}
|
||||
instance : MonadLift (EIO Error) IO where
|
||||
monadLift x := EIO.toIO (.userError <| toString ·) x
|
||||
|
||||
/--
|
||||
Sends a message on an `Channel`.
|
||||
|
||||
This function does not block.
|
||||
The central state structure for an unbounded channel, maintains the following invariants:
|
||||
1. `values = ∅ ∨ consumers = ∅`
|
||||
2. `closed = true → consumers = ∅`
|
||||
-/
|
||||
def Channel.send (ch : Channel α) (v : α) : BaseIO Unit :=
|
||||
ch.atomically do
|
||||
private structure Unbounded.State (α : Type) where
|
||||
/--
|
||||
Values pushed into the channel that are waiting to be consumed.
|
||||
-/
|
||||
values : Std.Queue α
|
||||
/--
|
||||
Consumers that are blocked on a producer providing them a value. The `IO.Promise` will be
|
||||
resolved to `none` if the channel closes.
|
||||
-/
|
||||
consumers : Std.Queue (IO.Promise (Option α))
|
||||
/--
|
||||
Whether the channel is closed already.
|
||||
-/
|
||||
closed : Bool
|
||||
deriving Nonempty
|
||||
|
||||
private structure Unbounded (α : Type) where
|
||||
state : Mutex (Unbounded.State α)
|
||||
deriving Nonempty
|
||||
|
||||
namespace Unbounded
|
||||
|
||||
private def new : BaseIO (Unbounded α) := do
|
||||
return {
|
||||
state := ← Mutex.new {
|
||||
values := ∅
|
||||
consumers := ∅
|
||||
closed := false
|
||||
}
|
||||
}
|
||||
|
||||
private def trySend (ch : Unbounded α) (v : α) : BaseIO Bool := do
|
||||
ch.state.atomically do
|
||||
let st ← get
|
||||
if st.closed then return
|
||||
if let some (consumer, consumers) := st.consumers.dequeue? then
|
||||
if st.closed then
|
||||
return false
|
||||
else if let some (consumer, consumers) := st.consumers.dequeue? then
|
||||
consumer.resolve (some v)
|
||||
set { st with consumers }
|
||||
return true
|
||||
else
|
||||
set { st with values := st.values.enqueue v }
|
||||
return true
|
||||
|
||||
/--
|
||||
Closes an `Channel`.
|
||||
-/
|
||||
def Channel.close (ch : Channel α) : BaseIO Unit :=
|
||||
ch.atomically do
|
||||
private def send (ch : Unbounded α) (v : α) : BaseIO (Task (Except Error Unit)) := do
|
||||
if ← Unbounded.trySend ch v then
|
||||
return .pure <| .ok ()
|
||||
else
|
||||
return .pure <| .error .closed
|
||||
|
||||
private def close (ch : Unbounded α) : EIO Error Unit := do
|
||||
ch.state.atomically do
|
||||
let st ← get
|
||||
if st.closed then throw .alreadyClosed
|
||||
for consumer in st.consumers.toArray do consumer.resolve none
|
||||
set { st with closed := true, consumers := ∅ }
|
||||
set { st with consumers := ∅, closed := true }
|
||||
return ()
|
||||
|
||||
private def isClosed (ch : Unbounded α) : BaseIO Bool :=
|
||||
ch.state.atomically do
|
||||
return (← get).closed
|
||||
|
||||
private def tryRecv' : AtomicT (Unbounded.State α) BaseIO (Option α) := do
|
||||
let st ← get
|
||||
if let some (a, values) := st.values.dequeue? then
|
||||
set { st with values }
|
||||
return some a
|
||||
else
|
||||
return none
|
||||
|
||||
private def tryRecv (ch : Unbounded α) : BaseIO (Option α) :=
|
||||
ch.state.atomically do
|
||||
tryRecv'
|
||||
|
||||
private def recv (ch : Unbounded α) : BaseIO (Task (Option α)) := do
|
||||
ch.state.atomically do
|
||||
if let some val ← tryRecv' then
|
||||
return .pure <| some val
|
||||
else if (← get).closed then
|
||||
return .pure none
|
||||
else
|
||||
let promise ← IO.Promise.new
|
||||
modify fun st => { st with consumers := st.consumers.enqueue promise }
|
||||
return promise.result?.map (sync := true) (·.bind id)
|
||||
|
||||
end Unbounded
|
||||
|
||||
/--
|
||||
Receives a message, without blocking.
|
||||
The returned task waits for the message.
|
||||
Every message is only received once.
|
||||
|
||||
Returns `none` if the channel is closed and the queue is empty.
|
||||
The central state structure for a zero buffer channel, maintains the following invariants:
|
||||
1. `producers = ∅ ∨ consumers = ∅`
|
||||
2. `closed = true → consumers = ∅`
|
||||
-/
|
||||
def Channel.recv? (ch : Channel α) : BaseIO (Task (Option α)) :=
|
||||
ch.atomically do
|
||||
private structure Zero.State (α : Type) where
|
||||
/--
|
||||
Producers that are blocked on a consumer taking their value.
|
||||
-/
|
||||
producers : Std.Queue (α × IO.Promise Bool)
|
||||
/--
|
||||
Consumers that are blocked on a producer providing them a value. The `IO.Promise` will be resolved
|
||||
to `none` if the channel closes.
|
||||
-/
|
||||
consumers : Std.Queue (IO.Promise (Option α))
|
||||
/--
|
||||
Whether the channel is closed already.
|
||||
-/
|
||||
closed : Bool
|
||||
|
||||
private structure Zero (α : Type) where
|
||||
state : Mutex (Zero.State α)
|
||||
|
||||
namespace Zero
|
||||
|
||||
private def new : BaseIO (Zero α) := do
|
||||
return {
|
||||
state := ← Mutex.new {
|
||||
producers := ∅
|
||||
consumers := ∅
|
||||
closed := false
|
||||
}
|
||||
}
|
||||
|
||||
/--
|
||||
Precondition: The channel must not be closed.
|
||||
-/
|
||||
private def trySend' (v : α) : AtomicT (Zero.State α) BaseIO Bool := do
|
||||
let st ← get
|
||||
if let some (consumer, consumers) := st.consumers.dequeue? then
|
||||
consumer.resolve (some v)
|
||||
set { st with consumers }
|
||||
return true
|
||||
else
|
||||
return false
|
||||
|
||||
private def trySend (ch : Zero α) (v : α) : BaseIO Bool := do
|
||||
ch.state.atomically do
|
||||
if (← get).closed then
|
||||
return false
|
||||
else
|
||||
trySend' v
|
||||
|
||||
private def send (ch : Zero α) (v : α) : BaseIO (Task (Except Error Unit)) := do
|
||||
ch.state.atomically do
|
||||
if (← get).closed then
|
||||
return .pure <| .error .closed
|
||||
else if ← trySend' v then
|
||||
return .pure <| .ok ()
|
||||
else
|
||||
let promise ← IO.Promise.new
|
||||
modify fun st => { st with producers := st.producers.enqueue (v, promise) }
|
||||
return promise.result?.map (sync := true)
|
||||
fun
|
||||
| none | some false => .error .closed
|
||||
| some true => .ok ()
|
||||
|
||||
private def close (ch : Zero α) : EIO Error Unit := do
|
||||
ch.state.atomically do
|
||||
let st ← get
|
||||
if let some (a, values) := st.values.dequeue? then
|
||||
set { st with values }
|
||||
return .pure a
|
||||
if st.closed then throw .alreadyClosed
|
||||
for consumer in st.consumers.toArray do consumer.resolve none
|
||||
set { st with consumers := ∅, closed := true }
|
||||
return ()
|
||||
|
||||
private def isClosed (ch : Zero α) : BaseIO Bool :=
|
||||
ch.state.atomically do
|
||||
return (← get).closed
|
||||
|
||||
private def tryRecv' : AtomicT (Zero.State α) BaseIO (Option α) := do
|
||||
let st ← get
|
||||
if let some ((val, promise), producers) := st.producers.dequeue? then
|
||||
set { st with producers }
|
||||
promise.resolve true
|
||||
return some val
|
||||
else
|
||||
return none
|
||||
|
||||
private def tryRecv (ch : Zero α) : BaseIO (Option α) := do
|
||||
ch.state.atomically do
|
||||
tryRecv'
|
||||
|
||||
private def recv (ch : Zero α) : BaseIO (Task (Option α)) := do
|
||||
ch.state.atomically do
|
||||
let st ← get
|
||||
if let some val ← tryRecv' then
|
||||
return .pure <| some val
|
||||
else if !st.closed then
|
||||
let promise ← IO.Promise.new
|
||||
set { st with consumers := st.consumers.enqueue promise }
|
||||
return promise.result?.map (sync := true) (·.bind id)
|
||||
else
|
||||
return .pure none
|
||||
return .pure <| none
|
||||
|
||||
end Zero
|
||||
|
||||
/--
|
||||
`ch.forAsync f` calls `f` for every messages received on `ch`.
|
||||
The central state structure for a bounded channel, maintains the following invariants:
|
||||
1. `0 < capacity`
|
||||
2. `0 < bufCount → consumers = ∅`
|
||||
3. `bufCount < capacity → producers = ∅`
|
||||
4. `producers = ∅ ∨ consumers = ∅`, implied by 1, 2 and 3.
|
||||
5. `bufCount` corresponds to the amount of slots in `buf` that are `some`.
|
||||
6. `sendIdx = (recvIdx + bufCount) % capacity`. However all four of these values still get tracked
|
||||
as there is potential to make a non-blocking send lock-free in the future with this approach.
|
||||
7. `closed = true → consumers = ∅`
|
||||
|
||||
Note that if this function is called twice, each `forAsync` only gets half the messages.
|
||||
While it (currently) lacks the partial lock-freeness of go channels, the protocol is based on
|
||||
[Go channels on steroids](https://docs.google.com/document/d/1yIAYmbvL3JxOKOjuCyon7JhW4cSv1wy5hC0ApeGMV9s/pub)
|
||||
as well as its [implementation](https://go.dev/src/runtime/chan.go).
|
||||
-/
|
||||
partial def Channel.forAsync (f : α → BaseIO Unit) (ch : Channel α)
|
||||
private structure Bounded.State (α : Type) where
|
||||
/--
|
||||
Producers that are blocked on a consumer taking their value as there was no buffer space
|
||||
available when they tried to enqueue.
|
||||
-/
|
||||
producers : Std.Queue (IO.Promise Bool)
|
||||
/--
|
||||
Consumers that are blocked on a producer providing them a value, as there was no value
|
||||
enqueued when they tried to dequeue. The `IO.Promise` will be resolved to `false` if the channel
|
||||
closes.
|
||||
-/
|
||||
consumers : Std.Queue (IO.Promise Bool)
|
||||
/--
|
||||
The capacity of the buffer space.
|
||||
-/
|
||||
capacity : Nat
|
||||
/--
|
||||
The buffer space for the channel, slots with `some v` contain a value that is waiting for
|
||||
consumption, the slots with `none` are free for enqueueing.
|
||||
|
||||
Note that this is a `Vector` of `IO.Ref (Option α)` as the `buf` itself is shared across threads
|
||||
and would thus keep getting copied if it was a `Vector (Option α)` instead.
|
||||
-/
|
||||
buf : Vector (IO.Ref (Option α)) capacity
|
||||
/--
|
||||
How many slots in `buf` are currently used, this is used to disambiguate between an empty and a
|
||||
full buffer without sacrificing a slot for indicating that.
|
||||
-/
|
||||
bufCount : Nat
|
||||
/--
|
||||
The slot in `buf` that the next send will happen to.
|
||||
-/
|
||||
sendIdx : Nat
|
||||
hsend : sendIdx < capacity
|
||||
/--
|
||||
The slot in `buf` that the next receive will happen from.
|
||||
-/
|
||||
recvIdx : Nat
|
||||
hrecv : recvIdx < capacity
|
||||
/--
|
||||
Whether the channel is closed already.
|
||||
-/
|
||||
closed : Bool
|
||||
|
||||
private structure Bounded (α : Type) where
|
||||
state : Mutex (Bounded.State α)
|
||||
|
||||
namespace Bounded
|
||||
|
||||
private def new (capacity : Nat) (hcap : 0 < capacity) : BaseIO (Bounded α) := do
|
||||
return {
|
||||
state := ← Mutex.new {
|
||||
producers := ∅
|
||||
consumers := ∅
|
||||
capacity := capacity
|
||||
buf := ← Vector.range capacity |>.mapM (fun _ => IO.mkRef none)
|
||||
bufCount := 0
|
||||
sendIdx := 0
|
||||
hsend := hcap
|
||||
recvIdx := 0
|
||||
hrecv := hcap
|
||||
closed := false
|
||||
}
|
||||
}
|
||||
|
||||
@[inline]
|
||||
private def incMod (idx : Nat) (cap : Nat) : Nat :=
|
||||
if idx + 1 = cap then
|
||||
0
|
||||
else
|
||||
idx + 1
|
||||
|
||||
private theorem incMod_lt {idx cap : Nat} (h : idx < cap) : incMod idx cap < cap := by
|
||||
unfold incMod
|
||||
split <;> omega
|
||||
|
||||
/--
|
||||
Precondition: The channel must not be closed.
|
||||
-/
|
||||
private def trySend' (v : α) : AtomicT (Bounded.State α) BaseIO Bool := do
|
||||
let mut st ← get
|
||||
if st.bufCount = st.capacity then
|
||||
return false
|
||||
else
|
||||
st.buf[st.sendIdx]'st.hsend |>.set (some v)
|
||||
st := { st with
|
||||
bufCount := st.bufCount + 1
|
||||
sendIdx := incMod st.sendIdx st.capacity
|
||||
hsend := incMod_lt st.hsend
|
||||
}
|
||||
|
||||
if let some (consumer, consumers) := st.consumers.dequeue? then
|
||||
consumer.resolve true
|
||||
st := { st with consumers }
|
||||
|
||||
set st
|
||||
|
||||
return true
|
||||
|
||||
private def trySend (ch : Bounded α) (v : α) : BaseIO Bool := do
|
||||
ch.state.atomically do
|
||||
if (← get).closed then
|
||||
return false
|
||||
else
|
||||
trySend' v
|
||||
|
||||
private partial def send (ch : Bounded α) (v : α) : BaseIO (Task (Except Error Unit)) := do
|
||||
ch.state.atomically do
|
||||
if (← get).closed then
|
||||
return .pure <| .error .closed
|
||||
else if ← trySend' v then
|
||||
return .pure <| .ok ()
|
||||
else
|
||||
let promise ← IO.Promise.new
|
||||
modify fun st => { st with producers := st.producers.enqueue promise }
|
||||
BaseIO.bindTask promise.result? fun res => do
|
||||
if res.getD false then
|
||||
Bounded.send ch v
|
||||
else
|
||||
return .pure <| .error .closed
|
||||
|
||||
private def close (ch : Bounded α) : EIO Error Unit := do
|
||||
ch.state.atomically do
|
||||
let st ← get
|
||||
if st.closed then throw .alreadyClosed
|
||||
for consumer in st.consumers.toArray do consumer.resolve false
|
||||
set { st with consumers := ∅, closed := true }
|
||||
return ()
|
||||
|
||||
private def isClosed (ch : Bounded α) : BaseIO Bool :=
|
||||
ch.state.atomically do
|
||||
return (← get).closed
|
||||
|
||||
private def tryRecv' : AtomicT (Bounded.State α) BaseIO (Option α) := do
|
||||
let st ← get
|
||||
if st.bufCount == 0 then
|
||||
return none
|
||||
else
|
||||
let val ← st.buf[st.recvIdx]'st.hrecv |>.swap none
|
||||
let nextRecvIdx := incMod st.recvIdx st.capacity
|
||||
set { st with
|
||||
bufCount := st.bufCount - 1
|
||||
recvIdx := nextRecvIdx,
|
||||
hrecv := incMod_lt st.hrecv
|
||||
}
|
||||
return val
|
||||
|
||||
private def tryRecv (ch : Bounded α) : BaseIO (Option α) :=
|
||||
ch.state.atomically do
|
||||
tryRecv'
|
||||
|
||||
private partial def recv (ch : Bounded α) : BaseIO (Task (Option α)) := do
|
||||
ch.state.atomically do
|
||||
if let some val ← tryRecv' then
|
||||
let st ← get
|
||||
if let some (producer, producers) := (← get).producers.dequeue? then
|
||||
producer.resolve true
|
||||
set { st with producers }
|
||||
return .pure <| some val
|
||||
else if (← get).closed then
|
||||
return .pure none
|
||||
else
|
||||
let promise ← IO.Promise.new
|
||||
modify fun st => { st with consumers := st.consumers.enqueue promise }
|
||||
BaseIO.bindTask promise.result? fun res => do
|
||||
if res.getD false then
|
||||
Bounded.recv ch
|
||||
else
|
||||
return .pure none
|
||||
|
||||
end Bounded
|
||||
|
||||
/--
|
||||
This type represents all flavors of channels that we have available.
|
||||
-/
|
||||
private inductive Flavors (α : Type) where
|
||||
| unbounded (ch : Unbounded α)
|
||||
| zero (ch : Zero α)
|
||||
| bounded (ch : Bounded α)
|
||||
deriving Nonempty
|
||||
|
||||
end CloseableChannel
|
||||
|
||||
/--
|
||||
A multi-producer multi-consumer FIFO channel that offers both bounded and unbounded buffering
|
||||
and an asynchronous API, to switch into synchronous mode use `CloseableChannel.sync`.
|
||||
|
||||
Additionally `Std.CloseableChannel` can be closed if necessary, unlike `Std.Channel`.
|
||||
This introduces a need for error handling in some cases, thus it is usually easier to use
|
||||
`Std.Channel` if applicable.
|
||||
-/
|
||||
def CloseableChannel (α : Type) : Type := CloseableChannel.Flavors α
|
||||
|
||||
/--
|
||||
A multi-producer multi-consumer FIFO channel that offers both bounded and unbounded buffering
|
||||
and a synchronous API. This type acts as a convenient layer to use a channel in a blocking fashion
|
||||
and is not actually different from the original channel.
|
||||
|
||||
Additionally `Std.CloseableChannel.Sync` can be closed if necessary, unlike `Std.Channel.Sync`.
|
||||
This introduces the need to handle errors in some cases, thus it is usually easier to use
|
||||
`Std.Channel` if applicable.
|
||||
-/
|
||||
def CloseableChannel.Sync (α : Type) : Type := CloseableChannel α
|
||||
|
||||
instance : Nonempty (CloseableChannel α) :=
|
||||
inferInstanceAs (Nonempty (CloseableChannel.Flavors α))
|
||||
|
||||
instance : Nonempty (CloseableChannel.Sync α) :=
|
||||
inferInstanceAs (Nonempty (CloseableChannel α))
|
||||
|
||||
namespace CloseableChannel
|
||||
|
||||
/--
|
||||
Create a new channel, if:
|
||||
- `capacity` is `none` it will be unbounded (the default)
|
||||
- `capacity` is `some 0` it will always force a rendezvous between sender and receiver
|
||||
- `capacity` is `some n` with `n > 0` it will use a buffer of size `n` and begin blocking once it
|
||||
is filled
|
||||
-/
|
||||
def new (capacity : Option Nat := none) : BaseIO (CloseableChannel α) := do
|
||||
match capacity with
|
||||
| none => return .unbounded (← CloseableChannel.Unbounded.new)
|
||||
| some 0 => return .zero (← CloseableChannel.Zero.new)
|
||||
| some (n + 1) => return .bounded (← CloseableChannel.Bounded.new (n + 1) (by omega))
|
||||
|
||||
/--
|
||||
Try to send a value to the channel, if this can be completed right away without blocking return
|
||||
`true`, otherwise don't send the value and return `false`.
|
||||
-/
|
||||
def trySend (ch : CloseableChannel α) (v : α) : BaseIO Bool :=
|
||||
match ch with
|
||||
| .unbounded ch => CloseableChannel.Unbounded.trySend ch v
|
||||
| .zero ch => CloseableChannel.Zero.trySend ch v
|
||||
| .bounded ch => CloseableChannel.Bounded.trySend ch v
|
||||
|
||||
/--
|
||||
Send a value through the channel, returning a task that will resolve once the transmission could be
|
||||
completed. Note that the task may resolve to `Except.error` if the channel was closed before it
|
||||
could be completed.
|
||||
-/
|
||||
def send (ch : CloseableChannel α) (v : α) : BaseIO (Task (Except Error Unit)) :=
|
||||
match ch with
|
||||
| .unbounded ch => CloseableChannel.Unbounded.send ch v
|
||||
| .zero ch => CloseableChannel.Zero.send ch v
|
||||
| .bounded ch => CloseableChannel.Bounded.send ch v
|
||||
|
||||
/--
|
||||
Closes the channel, returns `Except.ok` when called the first time, otherwise `Except.error`.
|
||||
When a channel is closed:
|
||||
- no new values can be sent successfully anymore
|
||||
- all blocked consumers are resolved to `none` (as no new messages can be sent they will never
|
||||
resolve)
|
||||
- if there are already values waiting to be received they can still be received by subsequent `recv`
|
||||
calls
|
||||
-/
|
||||
def close (ch : CloseableChannel α) : EIO Error Unit :=
|
||||
match ch with
|
||||
| .unbounded ch => CloseableChannel.Unbounded.close ch
|
||||
| .zero ch => CloseableChannel.Zero.close ch
|
||||
| .bounded ch => CloseableChannel.Bounded.close ch
|
||||
|
||||
/--
|
||||
Return `true` if the channel is closed.
|
||||
-/
|
||||
def isClosed (ch : CloseableChannel α) : BaseIO Bool :=
|
||||
match ch with
|
||||
| .unbounded ch => CloseableChannel.Unbounded.isClosed ch
|
||||
| .zero ch => CloseableChannel.Zero.isClosed ch
|
||||
| .bounded ch => CloseableChannel.Bounded.isClosed ch
|
||||
|
||||
/--
|
||||
Try to receive a value from the channel, if this can be completed right away without blocking return
|
||||
`some value`, otherwise return `none`.
|
||||
-/
|
||||
def tryRecv (ch : CloseableChannel α) : BaseIO (Option α) :=
|
||||
match ch with
|
||||
| .unbounded ch => CloseableChannel.Unbounded.tryRecv ch
|
||||
| .zero ch => CloseableChannel.Zero.tryRecv ch
|
||||
| .bounded ch => CloseableChannel.Bounded.tryRecv ch
|
||||
|
||||
/--
|
||||
Receive a value from the channel, returning a task that will resolve once the transmission could be
|
||||
completed. Note that the task may resolve to `none` if the channel was closed before it could be
|
||||
completed.
|
||||
-/
|
||||
def recv (ch : CloseableChannel α) : BaseIO (Task (Option α)) :=
|
||||
match ch with
|
||||
| .unbounded ch => CloseableChannel.Unbounded.recv ch
|
||||
| .zero ch => CloseableChannel.Zero.recv ch
|
||||
| .bounded ch => CloseableChannel.Bounded.recv ch
|
||||
|
||||
/--
|
||||
`ch.forAsync f` calls `f` for every message received on `ch`.
|
||||
|
||||
Note that if this function is called twice, each message will only arrive at exactly one invocation.
|
||||
-/
|
||||
partial def forAsync (f : α → BaseIO Unit) (ch : CloseableChannel α)
|
||||
(prio : Task.Priority := .default) : BaseIO (Task Unit) := do
|
||||
BaseIO.bindTask (prio := prio) (← ch.recv?) fun
|
||||
BaseIO.bindTask (prio := prio) (← ch.recv) fun
|
||||
| none => return .pure ()
|
||||
| some v => do f v; ch.forAsync f prio
|
||||
|
||||
/--
|
||||
Receives all currently queued messages from the channel.
|
||||
|
||||
Those messages are dequeued and will not be returned by `recv?`.
|
||||
This function is a no-op and just a convenient way to expose the synchronous API of the channel.
|
||||
-/
|
||||
def Channel.recvAllCurrent (ch : Channel α) : BaseIO (Array α) :=
|
||||
ch.atomically do
|
||||
modifyGet fun st => (st.values.toArray, { st with values := ∅ })
|
||||
@[inline]
|
||||
def sync (ch : CloseableChannel α) : CloseableChannel.Sync α := ch
|
||||
|
||||
/-- Type tag for synchronous (blocking) operations on a `Channel`. -/
|
||||
def Channel.Sync := Channel
|
||||
namespace Sync
|
||||
|
||||
@[inherit_doc CloseableChannel.new, inline]
|
||||
def new (capacity : Option Nat := none) : BaseIO (Sync α) := CloseableChannel.new capacity
|
||||
|
||||
@[inherit_doc CloseableChannel.trySend, inline]
|
||||
def trySend (ch : Sync α) (v : α) : BaseIO Bool := CloseableChannel.trySend ch v
|
||||
|
||||
/--
|
||||
Accesses synchronous (blocking) version of channel operations.
|
||||
|
||||
For example, `ch.sync.recv?` blocks until the next message,
|
||||
and `for msg in ch.sync do ...` iterates synchronously over the channel.
|
||||
These functions should only be used in dedicated threads.
|
||||
Send a value through the channel, blocking until the transmission could be completed. Note that this
|
||||
function may throw an error when trying to send to an already closed channel.
|
||||
-/
|
||||
def Channel.sync (ch : Channel α) : Channel.Sync α := ch
|
||||
def send (ch : Sync α) (v : α) : EIO Error Unit := do
|
||||
EIO.ofExcept (← IO.wait (← CloseableChannel.send ch v))
|
||||
|
||||
@[inherit_doc CloseableChannel.close, inline]
|
||||
def close (ch : Sync α) : EIO Error Unit := CloseableChannel.close ch
|
||||
|
||||
@[inherit_doc CloseableChannel.isClosed, inline]
|
||||
def isClosed (ch : Sync α) : BaseIO Bool := CloseableChannel.isClosed ch
|
||||
|
||||
@[inherit_doc CloseableChannel.tryRecv, inline]
|
||||
def tryRecv (ch : Sync α) : BaseIO (Option α) := CloseableChannel.tryRecv ch
|
||||
|
||||
/--
|
||||
Synchronously receives a message from the channel.
|
||||
|
||||
Every message is only received once.
|
||||
Returns `none` if the channel is closed and the queue is empty.
|
||||
Receive a value from the channel, blocking unitl the transmission could be completed. Note that the
|
||||
return value may be `none` if the channel was closed before it could be completed.
|
||||
-/
|
||||
def Channel.Sync.recv? (ch : Channel.Sync α) : BaseIO (Option α) := do
|
||||
IO.wait (← Channel.recv? ch)
|
||||
def recv (ch : Sync α) : BaseIO (Option α) := do
|
||||
IO.wait (← CloseableChannel.recv ch)
|
||||
|
||||
private partial def Channel.Sync.forIn [Monad m] [MonadLiftT BaseIO m]
|
||||
(ch : Channel.Sync α) (f : α → β → m (ForInStep β)) : β → m β := fun b => do
|
||||
match ← ch.recv? with
|
||||
| some a =>
|
||||
match ← f a b with
|
||||
| .done b => pure b
|
||||
| .yield b => ch.forIn f b
|
||||
| none => pure b
|
||||
private partial def forIn [Monad m] [MonadLiftT BaseIO m]
|
||||
(ch : Sync α) (f : α → β → m (ForInStep β)) : β → m β := fun b => do
|
||||
match ← ch.recv with
|
||||
| some a =>
|
||||
match ← f a b with
|
||||
| .done b => pure b
|
||||
| .yield b => ch.forIn f b
|
||||
| none => pure b
|
||||
|
||||
/-- `for msg in ch.sync do ...` receives all messages in the channel until it is closed. -/
|
||||
instance [MonadLiftT BaseIO m] : ForIn m (Channel.Sync α) α where
|
||||
instance [MonadLiftT BaseIO m] : ForIn m (Sync α) α where
|
||||
forIn ch b f := ch.forIn f b
|
||||
|
||||
end Sync
|
||||
|
||||
end CloseableChannel
|
||||
|
||||
/--
|
||||
A multi-producer multi-consumer FIFO channel that offers both bounded and unbounded buffering
|
||||
and an asynchronous API, to switch into synchronous mode use `Channel.sync`.
|
||||
|
||||
If a channel needs to be closed to indicate some sort of completion event use `Std.CloseableChannel`
|
||||
instead. Note that `Std.CloseableChannel` introduces a need for error handling in some cases, thus
|
||||
`Std.Channel` is usually easier to use if applicable.
|
||||
-/
|
||||
structure Channel (α : Type) where
|
||||
private mk ::
|
||||
private inner : CloseableChannel α
|
||||
deriving Nonempty
|
||||
|
||||
/--
|
||||
A multi-producer multi-consumer FIFO channel that offers both bounded and unbounded buffering
|
||||
and a synchronous API. This type acts as a convenient layer to use a channel in a blocking fashion
|
||||
and is not actually different from the original channel.
|
||||
|
||||
If a channel needs to be closed to indicate some sort of completion event use
|
||||
`Std.CloseableChannel.Sync` instead. Note that `Std.CloseableChannel.Sync` introduces a need for error
|
||||
handling in some cases, thus `Std.Channel.Sync` is usually easier to use if applicable.
|
||||
-/
|
||||
def Channel.Sync (α : Type) : Type := Channel α
|
||||
|
||||
instance : Nonempty (Channel.Sync α) :=
|
||||
inferInstanceAs (Nonempty (Channel α))
|
||||
|
||||
namespace Channel
|
||||
|
||||
@[inherit_doc CloseableChannel.new, inline]
|
||||
def new (capacity : Option Nat := none) : BaseIO (Channel α) := do
|
||||
return ⟨← CloseableChannel.new capacity⟩
|
||||
|
||||
@[inherit_doc CloseableChannel.trySend, inline]
|
||||
def trySend (ch : Channel α) (v : α) : BaseIO Bool :=
|
||||
CloseableChannel.trySend ch.inner v
|
||||
|
||||
/--
|
||||
Send a value through the channel, returning a task that will resolve once the transmission could be
|
||||
completed.
|
||||
-/
|
||||
def send (ch : Channel α) (v : α) : BaseIO (Task Unit) := do
|
||||
BaseIO.bindTask (sync := true) (← CloseableChannel.send ch.inner v)
|
||||
fun
|
||||
| .ok .. => return .pure ()
|
||||
| .error .. => unreachable!
|
||||
|
||||
@[inherit_doc CloseableChannel.tryRecv, inline]
|
||||
def tryRecv (ch : Channel α) : BaseIO (Option α) :=
|
||||
CloseableChannel.tryRecv ch.inner
|
||||
|
||||
@[inherit_doc CloseableChannel.recv]
|
||||
def recv [Inhabited α] (ch : Channel α) : BaseIO (Task α) := do
|
||||
BaseIO.bindTask (sync := true) (← CloseableChannel.recv ch.inner)
|
||||
fun
|
||||
| some val => return .pure val
|
||||
| none => unreachable!
|
||||
|
||||
@[inherit_doc CloseableChannel.forAsync]
|
||||
partial def forAsync [Inhabited α] (f : α → BaseIO Unit) (ch : Channel α)
|
||||
(prio : Task.Priority := .default) : BaseIO (Task Unit) := do
|
||||
BaseIO.bindTask (prio := prio) (← ch.recv) fun v => do f v; ch.forAsync f prio
|
||||
|
||||
@[inherit_doc CloseableChannel.sync, inline]
|
||||
def sync (ch : Channel α) : Channel.Sync α := ch
|
||||
|
||||
namespace Sync
|
||||
|
||||
@[inherit_doc Channel.new, inline]
|
||||
def new (capacity : Option Nat := none) : BaseIO (Sync α) := Channel.new capacity
|
||||
|
||||
@[inherit_doc Channel.trySend, inline]
|
||||
def trySend (ch : Sync α) (v : α) : BaseIO Bool := Channel.trySend ch v
|
||||
|
||||
/--
|
||||
Send a value through the channel, blocking until the transmission could be completed.
|
||||
-/
|
||||
def send (ch : Sync α) (v : α) : BaseIO Unit := do
|
||||
IO.wait (← Channel.send ch v)
|
||||
|
||||
@[inherit_doc Channel.tryRecv, inline]
|
||||
def tryRecv (ch : Sync α) : BaseIO (Option α) := Channel.tryRecv ch
|
||||
|
||||
/--
|
||||
Receive a value from the channel, blocking unitl the transmission could be completed.
|
||||
-/
|
||||
def recv [Inhabited α] (ch : Sync α) : BaseIO α := do
|
||||
IO.wait (← Channel.recv ch)
|
||||
|
||||
private partial def forIn [Inhabited α] [Monad m] [MonadLiftT BaseIO m]
|
||||
(ch : Sync α) (f : α → β → m (ForInStep β)) : β → m β := fun b => do
|
||||
let a ← ch.recv
|
||||
match ← f a b with
|
||||
| .done b => pure b
|
||||
| .yield b => ch.forIn f b
|
||||
|
||||
/-- `for msg in ch.sync do ...` receives all messages in the channel until it is closed. -/
|
||||
instance [Inhabited α] [MonadLiftT BaseIO m] : ForIn m (Sync α) α where
|
||||
forIn ch b f := ch.forIn f b
|
||||
|
||||
end Sync
|
||||
|
||||
end Channel
|
||||
|
||||
end Std
|
||||
|
|
|
|||
117
tests/bench/channel.lean
Normal file
117
tests/bench/channel.lean
Normal file
|
|
@ -0,0 +1,117 @@
|
|||
import Std.Sync.Channel
|
||||
|
||||
|
||||
/-
|
||||
Inspired by:
|
||||
https://github.com/crossbeam-rs/crossbeam/tree/bd87a61ce3858ca772c42525d5f0c9aa12cc80ac/crossbeam-channel/benchmarks.
|
||||
|
||||
We conduct for:
|
||||
- capacity 0 channels
|
||||
- capacity 1 channels
|
||||
- capacity `N` channels
|
||||
- unbounded channels
|
||||
|
||||
the following tests:
|
||||
- `seq`: A single thread sends `N` messages. Then it receives `N` messages.
|
||||
- `spsc`: One thread sends `N` messages. Another thread receives `N` messages.
|
||||
- `mpsc`: `T` threads send `N / T` messages each. One thread receives `N` messages.
|
||||
- `mpmc`: `T` threads send `N / T` messages each. `T` other threads receive `N / T` messages each.
|
||||
|
||||
Note that we will stick exclusively to the sync interface for this as there is no benefit to be
|
||||
reaped from async in this benchmark so we might as well just block.
|
||||
-/
|
||||
|
||||
def MESSAGES : Nat := 1_000_000
|
||||
def THREADS : Nat := 4
|
||||
|
||||
def seq (ch : Std.CloseableChannel.Sync Nat) (amount : Nat) : IO Unit := do
|
||||
for i in [:amount] do
|
||||
ch.send i
|
||||
|
||||
for _ in [:amount] do
|
||||
discard <| ch.recv
|
||||
|
||||
def spsc (ch : Std.CloseableChannel.Sync Nat) (amount : Nat) : IO Unit := do
|
||||
let t1 ← IO.asTask (prio := .dedicated) do
|
||||
for i in [:amount] do
|
||||
ch.send i
|
||||
|
||||
let t2 ← BaseIO.asTask (prio := .dedicated) do
|
||||
for _ in [:amount] do
|
||||
discard <| ch.recv
|
||||
|
||||
IO.ofExcept (← IO.wait t1)
|
||||
IO.wait t2
|
||||
|
||||
def mpsc (ch : Std.CloseableChannel.Sync Nat) (amount : Nat) : IO Unit := do
|
||||
let mut producers := Array.emptyWithCapacity THREADS
|
||||
for _ in [:THREADS] do
|
||||
let t ← IO.asTask (prio := .dedicated) do
|
||||
for i in [:(amount/THREADS)] do
|
||||
ch.send i
|
||||
producers := producers.push t
|
||||
|
||||
let consumer ← BaseIO.asTask (prio := .dedicated) do
|
||||
for _ in [:amount] do
|
||||
discard <| ch.recv
|
||||
|
||||
IO.wait consumer
|
||||
for producer in producers do
|
||||
(IO.ofExcept (← IO.wait producer))
|
||||
|
||||
def mpmc (ch : Std.CloseableChannel.Sync Nat) (amount : Nat) : IO Unit := do
|
||||
let mut producers := Array.emptyWithCapacity THREADS
|
||||
for _ in [:THREADS] do
|
||||
let t ← IO.asTask (prio := .dedicated) do
|
||||
for i in [:(amount/THREADS)] do
|
||||
ch.send i
|
||||
producers := producers.push t
|
||||
|
||||
let mut consumers := Array.emptyWithCapacity THREADS
|
||||
for _ in [:THREADS] do
|
||||
let t ← IO.asTask (prio := .dedicated) do
|
||||
while true do
|
||||
if let some _ ← ch.recv then
|
||||
continue
|
||||
else
|
||||
break
|
||||
consumers := consumers.push t
|
||||
|
||||
for producer in producers do
|
||||
(IO.ofExcept (← IO.wait producer))
|
||||
|
||||
ch.close
|
||||
|
||||
for consumer in consumers do
|
||||
(IO.ofExcept (← IO.wait consumer))
|
||||
|
||||
return ()
|
||||
|
||||
def run (name : String) (cap : Option Nat) (bench : Std.CloseableChannel.Sync Nat → Nat → IO Unit) :
|
||||
IO Unit := do
|
||||
let ch ← Std.CloseableChannel.new cap
|
||||
let t1 ← IO.monoMsNow
|
||||
bench ch.sync MESSAGES
|
||||
let t2 ← IO.monoMsNow
|
||||
let time : Float := (t2 - t1).toFloat / 1000.0
|
||||
IO.println s!"{name}: {time}"
|
||||
|
||||
|
||||
def main : IO Unit := do
|
||||
run "bounded0_spsc" (some 0) spsc
|
||||
run "bounded0_mpsc" (some 0) mpsc
|
||||
run "bounded0_mpmc" (some 0) mpmc
|
||||
|
||||
run "bounded1_spsc" (some 1) spsc
|
||||
run "bounded1_mpsc" (some 1) mpsc
|
||||
run "bounded1_mpmc" (some 1) mpmc
|
||||
|
||||
run "boundedn_spsc" (some MESSAGES) spsc
|
||||
run "boundedn_mpsc" (some MESSAGES) mpsc
|
||||
run "boundedn_mpmc" (some MESSAGES) mpmc
|
||||
run "boundedn_seq" (some MESSAGES) seq
|
||||
|
||||
run "unbounded_spsc" none spsc
|
||||
run "unbounded_mpsc" none mpsc
|
||||
run "unbounded_mpmc" none mpmc
|
||||
run "unbounded_seq" none seq
|
||||
|
|
@ -462,3 +462,12 @@
|
|||
run_config:
|
||||
<<: *time
|
||||
cmd: lean omega_stress.lean
|
||||
- attributes:
|
||||
description: channel.lean
|
||||
tags: [fast]
|
||||
run_config:
|
||||
<<: *time
|
||||
cmd: ./channel.lean.out
|
||||
parse_output: true
|
||||
build_config:
|
||||
cmd: ./compile.sh channel.lean
|
||||
|
|
|
|||
|
|
@ -15,15 +15,15 @@ open IO
|
|||
let promise : Promise Nat ← Promise.new
|
||||
assert! promise.result?.get = none
|
||||
|
||||
#eval do
|
||||
let ch ← Std.Channel.new
|
||||
#eval show IO _ from do
|
||||
let ch ← Std.CloseableChannel.new
|
||||
|
||||
let out ← IO.mkRef #[]
|
||||
ch.send 0
|
||||
ch.sync.send 0
|
||||
let drainFinished ← ch.forAsync fun x => out.modify (·.push x)
|
||||
ch.send 1
|
||||
ch.sync.send 1
|
||||
ch.close
|
||||
ch.send 2
|
||||
assert! (← EIO.toBaseIO (ch.sync.send 2)) matches .error .closed
|
||||
|
||||
IO.wait drainFinished
|
||||
assert! (← out.get) = #[0, 1]
|
||||
|
|
|
|||
150
tests/lean/run/sync_channel.lean
Normal file
150
tests/lean/run/sync_channel.lean
Normal file
|
|
@ -0,0 +1,150 @@
|
|||
import Std.Sync
|
||||
|
||||
open Std
|
||||
|
||||
def assertBEq [BEq α] [ToString α] (is should : α) : IO Unit := do
|
||||
if is != should then
|
||||
throw <| .userError s!"{is} should be {should}"
|
||||
|
||||
def closeClose (ch : CloseableChannel Nat) : IO Unit := do
|
||||
assertBEq (← ch.isClosed) false
|
||||
assertBEq ((← EIO.toBaseIO ch.close) matches .ok ()) true
|
||||
assertBEq (← ch.isClosed) true
|
||||
assertBEq ((← EIO.toBaseIO ch.close) matches .error .alreadyClosed) true
|
||||
assertBEq (← ch.isClosed) true
|
||||
|
||||
def paired (ch : CloseableChannel Nat) : IO Unit := do
|
||||
let sendTask ← ch.send 37
|
||||
let recvTask ← ch.recv
|
||||
assertBEq ((← IO.wait sendTask) matches .ok ()) true
|
||||
assertBEq (← IO.wait recvTask) (some 37)
|
||||
|
||||
def syncPaired (ch : CloseableChannel.Sync Nat) : IO Unit := do
|
||||
let sendTask ← IO.asTask (prio := .dedicated) (EIO.toBaseIO (ch.send 37))
|
||||
let recvTask ← IO.asTask (prio := .dedicated) (ch.recv)
|
||||
assertBEq ((← IO.ofExcept (← IO.wait sendTask)) matches .ok ()) true
|
||||
assertBEq (← IO.ofExcept (← IO.wait recvTask)) (some 37)
|
||||
|
||||
def trySend (ch : CloseableChannel Nat) (capacity : Option Nat) : IO Unit := do
|
||||
-- ready a receiver ahead of time
|
||||
let recvTask ← ch.recv
|
||||
assertBEq (← ch.trySend 37) true
|
||||
assertBEq (← IO.wait recvTask) (some 37)
|
||||
|
||||
-- the unbounded CloseableChannel cannot go out of space so it is pointless to fill it up
|
||||
let some capacity := capacity | return ()
|
||||
for i in [:capacity] do
|
||||
assertBEq (← ch.trySend i) true
|
||||
|
||||
assertBEq (← ch.trySend (capacity + 1)) false
|
||||
|
||||
def tryRecv (ch : CloseableChannel Nat) : IO Unit := do
|
||||
assertBEq (← ch.tryRecv) none
|
||||
let sendTask ← ch.send 37
|
||||
assertBEq (← ch.tryRecv) (some 37)
|
||||
assertBEq ((← IO.wait sendTask) matches .ok ()) true
|
||||
|
||||
def sendRecvClose (ch : CloseableChannel Nat) : IO Unit := do
|
||||
let sendTask ← ch.send 37
|
||||
assertBEq ((← EIO.toBaseIO ch.close) matches .ok ()) true
|
||||
let recvTask ← ch.recv
|
||||
assertBEq ((← IO.wait sendTask) matches .ok ()) true
|
||||
assertBEq (← IO.wait recvTask) (some 37)
|
||||
|
||||
let sendTask ← ch.send 37
|
||||
assertBEq ((← IO.wait sendTask) matches .error .closed) true
|
||||
let recvTask ← ch.recv
|
||||
assertBEq (← IO.wait recvTask) none
|
||||
assertBEq (← ch.trySend 37) false
|
||||
assertBEq (← ch.tryRecv) none
|
||||
|
||||
def sendIt (ch : CloseableChannel Nat) (messages : List Nat) : BaseIO (Task (Option Unit)) := do
|
||||
match messages with
|
||||
| [] => return .pure <| some ()
|
||||
| msg :: messages =>
|
||||
BaseIO.bindTask (← ch.send msg) fun
|
||||
| .error .. =>
|
||||
return .pure <| none
|
||||
| .ok .. =>
|
||||
sendIt ch messages
|
||||
|
||||
partial def recvIt (ch : CloseableChannel Nat) (messages : List Nat) : BaseIO (Task (List Nat)) := do
|
||||
BaseIO.bindTask (← ch.recv) fun
|
||||
| none => return .pure messages.reverse
|
||||
| some msg => recvIt ch (msg :: messages)
|
||||
|
||||
def sendLots (ch : CloseableChannel Nat) : IO Unit := do
|
||||
let messages := List.range 1000
|
||||
let sendTask ← sendIt ch messages
|
||||
let recvTask ← recvIt ch []
|
||||
assertBEq (← IO.wait sendTask) (some ())
|
||||
discard <| ch.close
|
||||
assertBEq (← IO.wait recvTask) messages
|
||||
|
||||
def sendItSync (ch : CloseableChannel.Sync Nat) (messages : List Nat) : IO Unit := do
|
||||
for msg in messages do
|
||||
ch.send msg
|
||||
return ()
|
||||
|
||||
def recvItSync (ch : CloseableChannel.Sync Nat) : IO (List Nat) := do
|
||||
let mut messages := []
|
||||
for msg in ch do
|
||||
messages := msg :: messages
|
||||
return messages.reverse
|
||||
|
||||
def sendLotsSync (ch : CloseableChannel.Sync Nat) : IO Unit := do
|
||||
let messages := List.range 1000
|
||||
let sendTask ← IO.asTask (prio := .dedicated) (sendItSync ch messages)
|
||||
let recvTask ← IO.asTask (prio := .dedicated) (recvItSync ch)
|
||||
IO.ofExcept (← IO.wait sendTask)
|
||||
discard <| ch.close
|
||||
assertBEq (← IO.ofExcept (← IO.wait recvTask)) messages
|
||||
|
||||
partial def sendLotsMulti (ch : CloseableChannel Nat) : IO Unit := do
|
||||
let messages := List.range 1000
|
||||
let sendTask1 ← sendIt ch messages
|
||||
let sendTask2 ← sendIt ch messages
|
||||
let recvTask1 ← recvIt ch []
|
||||
let recvTask2 ← recvIt ch []
|
||||
assertBEq (← IO.wait sendTask1) (some ())
|
||||
assertBEq (← IO.wait sendTask2) (some ())
|
||||
discard <| ch.close
|
||||
let msg1 ← IO.wait recvTask1
|
||||
let msg2 ← IO.wait recvTask2
|
||||
assertBEq (msg1.sum + msg2.sum) (2 * messages.sum)
|
||||
|
||||
partial def sendLotsMultiSync (ch : CloseableChannel.Sync Nat) : IO Unit := do
|
||||
let messages := List.range 1000
|
||||
let sendTask1 ← IO.asTask (prio := .dedicated) (sendItSync ch messages)
|
||||
let sendTask2 ← IO.asTask (prio := .dedicated) (sendItSync ch messages)
|
||||
let recvTask1 ← IO.asTask (prio := .dedicated) (recvItSync ch)
|
||||
let recvTask2 ← IO.asTask (prio := .dedicated) (recvItSync ch)
|
||||
IO.ofExcept (← IO.wait sendTask1)
|
||||
IO.ofExcept (← IO.wait sendTask2)
|
||||
discard <| ch.close
|
||||
let msg1 ← IO.ofExcept (← IO.wait recvTask1)
|
||||
let msg2 ← IO.ofExcept (← IO.wait recvTask2)
|
||||
assertBEq (msg1.sum + msg2.sum) (2 * messages.sum)
|
||||
|
||||
def testIt (capacity : Option Nat) : IO Unit := do
|
||||
paired (← CloseableChannel.new capacity)
|
||||
syncPaired (← CloseableChannel.new capacity).sync
|
||||
|
||||
closeClose (← CloseableChannel.new capacity)
|
||||
trySend (← CloseableChannel.new capacity) capacity
|
||||
tryRecv (← CloseableChannel.new capacity)
|
||||
sendRecvClose (← CloseableChannel.new capacity)
|
||||
|
||||
sendLots (← CloseableChannel.new capacity)
|
||||
sendLotsSync (← CloseableChannel.new capacity).sync
|
||||
sendLotsMulti (← CloseableChannel.new capacity)
|
||||
sendLotsMultiSync (← CloseableChannel.new capacity).sync
|
||||
|
||||
def suite : IO Unit := do
|
||||
testIt none
|
||||
testIt (some 0)
|
||||
testIt (some 1)
|
||||
testIt (some 8)
|
||||
testIt (some 128)
|
||||
|
||||
#eval suite
|
||||
Loading…
Add table
Reference in a new issue