feat: add StreamMap (#10400)
This PR adds the StreamMap type that enables multiplexing in asynchronous streams. This PR depends on: #10366, #10367 and #10370. --------- Co-authored-by: Markus Himmel <markus@lean-fro.org>
This commit is contained in:
parent
1f7374a5d6
commit
ad701b577b
4 changed files with 552 additions and 0 deletions
|
|
@ -169,6 +169,83 @@ partial def Selectable.one (selectables : Array (Selectable α)) : Async α := d
|
|||
|
||||
Async.ofPromise (pure promise)
|
||||
|
||||
/--
|
||||
Performs fair and data-loss free non-blocking multiplexing on the `Selectable`s in `selectables`.
|
||||
|
||||
This function only tries the non-blocking `tryFn` for each `Selectable` without registering
|
||||
waiters or blocking. It returns `some result` if any `Selectable` is immediately available,
|
||||
or `none` if all would block.
|
||||
|
||||
The protocol for this is as follows:
|
||||
1. The `selectables` are shuffled randomly for fairness.
|
||||
2. Run `Selector.tryFn` for each element in `selectables`. If any succeed, the corresponding
|
||||
`Selectable.cont` is executed and its result is returned as `some result`.
|
||||
3. If none succeed, `none` is returned immediately without blocking.
|
||||
-/
|
||||
def Selectable.tryOne (selectables : Array (Selectable α)) : Async (Option α) := do
|
||||
if selectables.isEmpty then
|
||||
return none
|
||||
|
||||
let seed := UInt64.toNat (ByteArray.toUInt64LE! (← IO.getRandomBytes 8))
|
||||
let gen := mkStdGen seed
|
||||
let selectables := shuffleIt selectables gen
|
||||
|
||||
for selectable in selectables do
|
||||
if let some val ← selectable.selector.tryFn then
|
||||
let result ← selectable.cont val
|
||||
return some result
|
||||
|
||||
return none
|
||||
|
||||
/--
|
||||
Creates a `Selector` that performs fair and data-loss free multiplexing on multiple `Selectable`s.
|
||||
This allows the multiplexing operation to be composed with other selectors.
|
||||
-/
|
||||
def Selectable.combine (selectables : Array (Selectable α)) : IO (Selector α) := do
|
||||
if selectables.isEmpty then
|
||||
throw <| .userError "Selectable.one requires at least one Selectable"
|
||||
|
||||
let seed := UInt64.toNat (ByteArray.toUInt64LE! (← IO.getRandomBytes 8))
|
||||
let gen := mkStdGen seed
|
||||
let selectables := shuffleIt selectables gen
|
||||
|
||||
return {
|
||||
tryFn := do
|
||||
for selectable in selectables do
|
||||
if let some val ← selectable.selector.tryFn then
|
||||
let result ← selectable.cont val
|
||||
return some result
|
||||
return none
|
||||
|
||||
registerFn := fun waiter => do
|
||||
for selectable in selectables do
|
||||
let waiterPromise ← IO.Promise.new
|
||||
let derivedWaiter := Waiter.mk waiter.finished waiterPromise
|
||||
selectable.selector.registerFn derivedWaiter
|
||||
|
||||
discard <| IO.bindTask (t := waiterPromise.result?) fun res? => do
|
||||
match res? with
|
||||
| none => return (Task.pure (.ok ()))
|
||||
| some res =>
|
||||
let async : Async _ := do
|
||||
let mainPromise := waiter.promise
|
||||
|
||||
for selectable in selectables do
|
||||
selectable.selector.unregisterFn
|
||||
|
||||
try
|
||||
let val ← IO.ofExcept res
|
||||
let result ← selectable.cont val
|
||||
mainPromise.resolve (.ok result)
|
||||
catch e =>
|
||||
mainPromise.resolve (.error e)
|
||||
async.toBaseIO
|
||||
|
||||
unregisterFn := do
|
||||
for selectable in selectables do
|
||||
selectable.selector.unregisterFn
|
||||
}
|
||||
|
||||
end Async
|
||||
end IO
|
||||
end Internal
|
||||
|
|
|
|||
|
|
@ -14,5 +14,6 @@ public import Std.Sync.Barrier
|
|||
public import Std.Sync.SharedMutex
|
||||
public import Std.Sync.Notify
|
||||
public import Std.Sync.Broadcast
|
||||
public import Std.Sync.StreamMap
|
||||
|
||||
@[expose] public section
|
||||
|
|
|
|||
147
src/Std/Sync/StreamMap.lean
Normal file
147
src/Std/Sync/StreamMap.lean
Normal file
|
|
@ -0,0 +1,147 @@
|
|||
/-
|
||||
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Sofia Rodrigues
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Std.Data
|
||||
public import Init.System.Promise
|
||||
public import Init.Data.Queue
|
||||
public import Std.Internal.Async.IO
|
||||
public import Std.Internal.Async.Select
|
||||
public import Std.Internal.Async.Basic
|
||||
|
||||
public section
|
||||
|
||||
open Std.Internal.Async.IO
|
||||
open Std.Internal.IO.Async
|
||||
|
||||
/-!
|
||||
This module provides `StreamMap`, a container that maps keys to async streams.
|
||||
It allows for dynamic management of multiple named streams with async operations.
|
||||
-/
|
||||
|
||||
namespace Std
|
||||
|
||||
/--
|
||||
This is an existential wrapper for AsyncStream that is used for the `.ofArray` function
|
||||
with `CoeDep` so it's easier and we keep StreamMap on `Type 0`.
|
||||
-/
|
||||
inductive AnyAsyncStream (α : Type) where
|
||||
| mk : {t : Type} → [AsyncStream t α] → t → AnyAsyncStream α
|
||||
|
||||
def AnyAsyncStream.getSelector : AnyAsyncStream α → Selector α × IO Unit
|
||||
| AnyAsyncStream.mk stream => (AsyncStream.next stream, AsyncStream.stop stream)
|
||||
|
||||
instance [AsyncStream t α] : CoeDep t x (AnyAsyncStream α) where
|
||||
coe := AnyAsyncStream.mk x
|
||||
|
||||
/--
|
||||
A container that maps keys to async streams, enabling dynamic stream management
|
||||
and unified selection operations across multiple named data sources.
|
||||
-/
|
||||
structure StreamMap (α : Type) (β : Type) where
|
||||
private mk ::
|
||||
private streams : Array (α × Selector β × IO Unit)
|
||||
|
||||
namespace StreamMap
|
||||
|
||||
/--
|
||||
Create an empty StreamMap
|
||||
-/
|
||||
def empty {α} : StreamMap α β :=
|
||||
{ streams := #[] }
|
||||
|
||||
/--
|
||||
Register a new async stream with the given name
|
||||
-/
|
||||
def register [BEq α] [AsyncStream t β] (sm : StreamMap α β) (name : α) (reader : t) : StreamMap α β :=
|
||||
let newSelector := AsyncStream.next reader
|
||||
let filteredStreams := sm.streams.filter (fun (n, _) => n != name)
|
||||
{ sm with streams := filteredStreams.push (name, newSelector, AsyncStream.stop reader) }
|
||||
|
||||
/--
|
||||
Create a StreamMap from an array of named streams
|
||||
-/
|
||||
def ofArray [BEq α] (streams : Array (α × AnyAsyncStream β)) : StreamMap α β :=
|
||||
let arrayOfSelectors := streams.map (fun (name, sel) => (name, sel.getSelector))
|
||||
{ streams := arrayOfSelectors }
|
||||
|
||||
/--
|
||||
Get a combined selector that returns the stream name and value
|
||||
-/
|
||||
def selector (stream : StreamMap α β) : Async (Selector (α × β)) :=
|
||||
let selectables := stream.streams.map fun (name, selector) => Selectable.case selector.fst (fun x => pure (name, x))
|
||||
Selectable.combine selectables
|
||||
|
||||
/--
|
||||
Wait for the first value inside of the stream map.
|
||||
-/
|
||||
def recv (stream : StreamMap α β) : Async (α × β) :=
|
||||
let selectables := stream.streams.map fun (name, selector) => Selectable.case selector.fst (fun x => pure (name, x))
|
||||
Selectable.one selectables
|
||||
|
||||
/--
|
||||
Wait for the first value inside of the stream map.
|
||||
-/
|
||||
def tryRecv (stream : StreamMap α β) : Async (Option (α × β)) :=
|
||||
let selectables := stream.streams.map fun (name, selector) => Selectable.case selector.fst (fun x => pure (name, x))
|
||||
Selectable.tryOne selectables
|
||||
|
||||
/--
|
||||
Remove a stream by name
|
||||
-/
|
||||
def unregister [BEq α] (sm : StreamMap α β) (name : α) : StreamMap α β :=
|
||||
{ sm with streams := sm.streams.filter (fun (n, _) => n != name) }
|
||||
|
||||
/--
|
||||
Check if a stream with the given name exists
|
||||
-/
|
||||
def contains [BEq α] (sm : StreamMap α β) (name : α) : Bool :=
|
||||
sm.streams.any (fun (n, _) => n == name)
|
||||
|
||||
/--
|
||||
Get the number of registered streams
|
||||
-/
|
||||
def size (sm : StreamMap α β) : Nat :=
|
||||
sm.streams.size
|
||||
|
||||
/--
|
||||
Check if the StreamMap is empty
|
||||
-/
|
||||
def isEmpty (sm : StreamMap α β) : Bool :=
|
||||
sm.streams.isEmpty
|
||||
|
||||
/--
|
||||
Get all registered stream names
|
||||
-/
|
||||
def keys (sm : StreamMap α β) : Array α :=
|
||||
sm.streams.map (·.1)
|
||||
|
||||
/--
|
||||
Get a specific stream selector by name
|
||||
-/
|
||||
def get? [BEq α] (sm : StreamMap α β) (name : α) : Option (Selector β) :=
|
||||
sm.streams.find? (fun (n, _) => n == name) |>.map (·.2.1)
|
||||
|
||||
/--
|
||||
Filter streams based on their names
|
||||
-/
|
||||
def filterByName (sm : StreamMap α β) (pred : α → Bool) : StreamMap α β :=
|
||||
{ streams := sm.streams.filter (fun (name, _) => pred name) }
|
||||
|
||||
/--
|
||||
Convert to array of name-selector pairs
|
||||
-/
|
||||
def toArray (sm : StreamMap α β) : Array (α × Selector β) :=
|
||||
sm.streams.map (fun (n, s, _) => (n, s))
|
||||
|
||||
/--
|
||||
Cleanup function
|
||||
-/
|
||||
def close (sm : StreamMap α β) : IO Unit :=
|
||||
sm.streams.forM (fun (_, _, cleanup) => cleanup)
|
||||
|
||||
end StreamMap
|
||||
327
tests/lean/run/async_streammap.lean
Normal file
327
tests/lean/run/async_streammap.lean
Normal file
|
|
@ -0,0 +1,327 @@
|
|||
import Std.Internal.Async
|
||||
import Std.Sync
|
||||
|
||||
open Std.Internal.IO Async
|
||||
|
||||
-- Test basic message reception from multiple channels
|
||||
def testSimpleMessages : Async Unit := do
|
||||
let channelA ← Std.Channel.new (α := Nat)
|
||||
let channelB ← Std.Channel.new (α := Nat)
|
||||
let channelC ← Std.Channel.new (α := Nat)
|
||||
|
||||
let channel := Std.StreamMap.ofArray #[
|
||||
("a", channelA),
|
||||
("b", channelB),
|
||||
("c", channelC),
|
||||
]
|
||||
|
||||
await (← channelC.send 1)
|
||||
let (name, message) ← channel.recv
|
||||
assert! name == "c" && message == 1
|
||||
|
||||
await (← channelA.send 2)
|
||||
let (name, message) ← channel.recv
|
||||
assert! name == "a" && message == 2
|
||||
|
||||
await (← channelB.send 3)
|
||||
let (name, message) ← channel.recv
|
||||
assert! name == "b" && message == 3
|
||||
|
||||
#eval testSimpleMessages.block
|
||||
|
||||
-- Test empty StreamMap
|
||||
def testEmpty : Async Unit := do
|
||||
let stream : Std.StreamMap String Nat := Std.StreamMap.empty
|
||||
|
||||
assert! stream.isEmpty
|
||||
assert! stream.size == 0
|
||||
assert! stream.keys.size == 0
|
||||
|
||||
#eval testEmpty.block
|
||||
|
||||
-- Test register and unregister operations
|
||||
def testRegisterUnregister : Async Unit := do
|
||||
let channelA ← Std.Channel.new (α := Nat)
|
||||
let channelB ← Std.Channel.new (α := Nat)
|
||||
|
||||
let stream := Std.StreamMap.empty.register "a" channelA
|
||||
|
||||
assert! stream.contains "a"
|
||||
assert! not (stream.contains "b")
|
||||
assert! stream.size == 1
|
||||
|
||||
let stream := stream.register "b" channelB
|
||||
assert! stream.contains "a" && stream.contains "b"
|
||||
assert! stream.size == 2
|
||||
|
||||
let stream := stream.unregister "a"
|
||||
assert! not (stream.contains "a")
|
||||
assert! stream.contains "b"
|
||||
assert! stream.size == 1
|
||||
|
||||
let stream := stream.unregister "b"
|
||||
assert! stream.isEmpty
|
||||
|
||||
#eval testRegisterUnregister.block
|
||||
|
||||
-- Test replacing existing stream with same name
|
||||
def testRegisterReplace : Async Unit := do
|
||||
let channelA1 ← Std.Channel.new (α := Nat)
|
||||
let channelA2 ← Std.Channel.new (α := Nat)
|
||||
|
||||
let stream := Std.StreamMap.empty.register "a" channelA1
|
||||
assert! stream.size == 1
|
||||
|
||||
-- Register with same name should replace
|
||||
let stream := stream.register "a" channelA2
|
||||
assert! stream.size == 1
|
||||
assert! stream.contains "a"
|
||||
|
||||
-- Send to new channel
|
||||
await (← channelA2.send 42)
|
||||
let (name, message) ← stream.recv
|
||||
assert! name == "a" && message == 42
|
||||
|
||||
#eval testRegisterReplace.block
|
||||
|
||||
-- Test keys functionality
|
||||
def testKeys : Async Unit := do
|
||||
let channelA ← Std.Channel.new (α := Nat)
|
||||
let channelB ← Std.Channel.new (α := Nat)
|
||||
let channelC ← Std.Channel.new (α := Nat)
|
||||
|
||||
let stream := Std.StreamMap.ofArray #[
|
||||
("c", .mk channelC),
|
||||
("a", .mk channelA),
|
||||
("b", .mk channelB),
|
||||
]
|
||||
|
||||
let keys := stream.keys
|
||||
assert! keys.size == 3
|
||||
assert! keys.contains "a"
|
||||
assert! keys.contains "b"
|
||||
assert! keys.contains "c"
|
||||
|
||||
#eval testKeys.block
|
||||
|
||||
-- Test get? functionality
|
||||
def testGet : Async Unit := do
|
||||
let channelA ← Std.Channel.new (α := Nat)
|
||||
let channelB ← Std.Channel.new (α := Nat)
|
||||
|
||||
let stream := Std.StreamMap.ofArray #[
|
||||
("a", .mk channelA),
|
||||
("b", .mk channelB),
|
||||
]
|
||||
|
||||
let selectorA := stream.get? "a"
|
||||
let selectorC := stream.get? "c"
|
||||
|
||||
assert! selectorA.isSome
|
||||
assert! selectorC.isNone
|
||||
|
||||
#eval testGet.block
|
||||
|
||||
-- Test filterByName functionality
|
||||
def testFilterByName : Async Unit := do
|
||||
let channelA ← Std.Channel.new (α := Nat)
|
||||
let channelB ← Std.Channel.new (α := Nat)
|
||||
let channelC ← Std.Channel.new (α := Nat)
|
||||
|
||||
let stream := Std.StreamMap.ofArray #[
|
||||
("prefix_a", .mk channelA),
|
||||
("prefix_b", .mk channelB),
|
||||
("other_c", .mk channelC),
|
||||
]
|
||||
|
||||
let filtered := stream.filterByName (fun name => name.startsWith "prefix_")
|
||||
|
||||
assert! filtered.size == 2
|
||||
assert! filtered.contains "prefix_a"
|
||||
assert! filtered.contains "prefix_b"
|
||||
assert! not (filtered.contains "other_c")
|
||||
|
||||
#eval testFilterByName.block
|
||||
|
||||
-- Test toArray functionality
|
||||
def testToArray : Async Unit := do
|
||||
let channelA ← Std.Channel.new (α := Nat)
|
||||
let channelB ← Std.Channel.new (α := Nat)
|
||||
|
||||
let stream := Std.StreamMap.ofArray #[
|
||||
("a", .mk channelA),
|
||||
("b", .mk channelB),
|
||||
]
|
||||
|
||||
let array := stream.toArray
|
||||
assert! array.size == 2
|
||||
|
||||
let names := array.map (·.1)
|
||||
assert! names.contains "a"
|
||||
assert! names.contains "b"
|
||||
|
||||
#eval testToArray.block
|
||||
|
||||
-- Test multiple messages from same channel
|
||||
def testMultipleFromSame : Async Unit := do
|
||||
let channelA ← Std.Channel.new (α := Nat)
|
||||
let channelB ← Std.Channel.new (α := Nat)
|
||||
|
||||
let stream := Std.StreamMap.ofArray #[
|
||||
("a", .mk channelA),
|
||||
("b", .mk channelB),
|
||||
]
|
||||
|
||||
-- Send multiple messages to same channel
|
||||
await (← channelA.send 1)
|
||||
await (← channelA.send 2)
|
||||
await (← channelB.send 10)
|
||||
|
||||
let (name1, msg1) ← stream.recv
|
||||
let (name2, msg2) ← stream.recv
|
||||
let (name3, msg3) ← stream.recv
|
||||
|
||||
-- Should receive all messages, order may vary but names should match sources
|
||||
assert! (name1 == "a" && msg1 == 1) || (name1 == "a" && msg1 == 2) || (name1 == "b" && msg1 == 10)
|
||||
assert! (name2 == "a" && msg2 == 1) || (name2 == "a" && msg2 == 2) || (name2 == "b" && msg2 == 10)
|
||||
assert! (name3 == "a" && msg3 == 1) || (name3 == "a" && msg3 == 2) || (name3 == "b" && msg3 == 10)
|
||||
|
||||
#eval testMultipleFromSame.block
|
||||
|
||||
-- Test interleaved messages from different channels
|
||||
def testInterleavedMessages : Async Unit := do
|
||||
let channelA ← Std.Channel.new (α := String)
|
||||
let channelB ← Std.Channel.new (α := String)
|
||||
let channelC ← Std.Channel.new (α := String)
|
||||
|
||||
let stream := Std.StreamMap.ofArray #[
|
||||
("first", .mk channelA),
|
||||
("second", .mk channelB),
|
||||
("third", .mk channelC),
|
||||
]
|
||||
|
||||
-- Send messages in specific order
|
||||
await (← channelB.send "msg1")
|
||||
await (← channelA.send "msg2")
|
||||
await (← channelC.send "msg3")
|
||||
await (← channelA.send "msg4")
|
||||
|
||||
let results ← (List.range 4).mapM (fun _ => stream.recv)
|
||||
|
||||
-- Verify we got all messages (order may vary)
|
||||
let messages := results.map (·.2)
|
||||
assert! messages.contains "msg1"
|
||||
assert! messages.contains "msg2"
|
||||
assert! messages.contains "msg3"
|
||||
assert! messages.contains "msg4"
|
||||
|
||||
#eval testInterleavedMessages.block
|
||||
|
||||
-- Test with different data typez
|
||||
def testDifferentTypes : Async Unit := do
|
||||
let channelStr ← Std.Channel.new (α := String)
|
||||
let channelBool ← Std.Channel.new (α := String)
|
||||
|
||||
let stream := Std.StreamMap.ofArray #[
|
||||
("strings", .mk channelStr),
|
||||
("bools", .mk channelBool),
|
||||
]
|
||||
|
||||
await (← channelStr.send "hello")
|
||||
await (← channelBool.send "world")
|
||||
|
||||
let (name1, msg1) ← stream.recv
|
||||
let (name2, msg2) ← stream.recv
|
||||
|
||||
assert! ((name1 == "strings" && msg1 == "hello") || (name1 == "bools" && msg1 == "world"))
|
||||
assert! ((name2 == "strings" && msg2 == "hello") || (name2 == "bools" && msg2 == "world"))
|
||||
assert! name1 != name2
|
||||
|
||||
#eval testDifferentTypes.block
|
||||
|
||||
-- Test unregister during operation
|
||||
def testUnregisterDuringOperation : Async Unit := do
|
||||
let channelA ← Std.Channel.new (α := Nat)
|
||||
let channelB ← Std.Channel.new (α := Nat)
|
||||
let channelC ← Std.Channel.new (α := Nat)
|
||||
|
||||
let stream := Std.StreamMap.ofArray #[
|
||||
("a", .mk channelA),
|
||||
("b", .mk channelB),
|
||||
("c", .mk channelC),
|
||||
]
|
||||
|
||||
await (← channelA.send 1)
|
||||
await (← channelB.send 2)
|
||||
await (← channelC.send 3)
|
||||
|
||||
assert! (← stream.tryRecv).isSome
|
||||
assert! (← stream.tryRecv).isSome
|
||||
assert! (← stream.tryRecv).isSome
|
||||
|
||||
let stream := stream.unregister "b"
|
||||
assert! not (stream.contains "b")
|
||||
assert! stream.size == 2
|
||||
|
||||
let newChannelD ← Std.Channel.new (α := Nat)
|
||||
let stream := stream.register "d" newChannelD
|
||||
|
||||
await (← newChannelD.send 4)
|
||||
let (name2, msg2) ← stream.recv
|
||||
|
||||
assert! name2 == "d" && msg2 == 4
|
||||
|
||||
#eval testUnregisterDuringOperation.block
|
||||
|
||||
-- Test selector functionality
|
||||
def testSelector : Async Unit := do
|
||||
let channelA ← Std.Channel.new (α := Nat)
|
||||
let channelB ← Std.Channel.new (α := Nat)
|
||||
|
||||
let stream := Std.StreamMap.ofArray #[
|
||||
("a", .mk channelA),
|
||||
("b", .mk channelB),
|
||||
]
|
||||
|
||||
let selector ← stream.selector
|
||||
|
||||
await (← channelA.send 42)
|
||||
|
||||
let result ← Selectable.one #[.case selector (fun x => pure x)]
|
||||
assert! result.1 == "a" && result.2 == 42
|
||||
|
||||
#eval testSelector.block
|
||||
|
||||
def testClose : Async Unit := do
|
||||
let channelA ← Std.Channel.new (α := Nat)
|
||||
let channelB ← Std.Channel.new (α := Nat)
|
||||
|
||||
let stream := Std.StreamMap.ofArray #[
|
||||
("a", .mk channelA),
|
||||
("b", .mk channelB),
|
||||
]
|
||||
|
||||
stream.close
|
||||
|
||||
#eval testClose.block
|
||||
|
||||
-- Test large number of channels
|
||||
def testManyChannels : Async Unit := do
|
||||
let channels : Vector _ 128 ← Vector.ofFnM (fun _ => Std.Channel.new (α := Nat))
|
||||
|
||||
let streamArray := channels.mapIdx (fun i ch => (s!"channel_{i}", .mk ch))
|
||||
let stream := Std.StreamMap.ofArray streamArray.toArray
|
||||
|
||||
assert! stream.size == 128
|
||||
|
||||
await (← channels[3].send 100)
|
||||
await (← channels[7].send 200)
|
||||
|
||||
let (name1, msg1) ← stream.recv
|
||||
let (name2, msg2) ← stream.recv
|
||||
|
||||
assert! ((name1 == "channel_3" && msg1 == 100) || (name1 == "channel_7" && msg1 == 200))
|
||||
assert! ((name2 == "channel_3" && msg2 == 100) || (name2 == "channel_7" && msg2 == 200))
|
||||
assert! name1 != name2
|
||||
|
||||
#eval testManyChannels.block
|
||||
Loading…
Add table
Reference in a new issue