feat: add StreamMap (#10400)

This PR adds the StreamMap type that enables multiplexing in
asynchronous streams.

This PR depends on: #10366, #10367 and #10370.

---------

Co-authored-by: Markus Himmel <markus@lean-fro.org>
This commit is contained in:
Sofia Rodrigues 2025-10-06 20:39:44 -03:00 committed by GitHub
parent 1f7374a5d6
commit ad701b577b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 552 additions and 0 deletions

View file

@ -169,6 +169,83 @@ partial def Selectable.one (selectables : Array (Selectable α)) : Async α := d
Async.ofPromise (pure promise)
/--
Non-blocking variant of fair, data-loss free multiplexing over `selectables`.

Only the `Selector.tryFn` of each `Selectable` is consulted: no waiter is registered and the
call does not block. The steps are:
1. If `selectables` is empty, `none` is returned.
2. The `selectables` are shuffled randomly for fairness.
3. `Selector.tryFn` is run for each element in shuffled order. For the first one that yields a
   value, the matching `Selectable.cont` is executed and its result is returned as `some result`.
4. If no `tryFn` yields a value, `none` is returned.
-/
def Selectable.tryOne (selectables : Array (Selectable α)) : Async (Option α) := do
  if selectables.isEmpty then
    return none

  -- Seed the shuffle from 8 fresh random bytes.
  let entropy ← IO.getRandomBytes 8
  let shuffled := shuffleIt selectables (mkStdGen (UInt64.toNat (ByteArray.toUInt64LE! entropy)))

  for candidate in shuffled do
    match ← candidate.selector.tryFn with
    | some ready => return some (← candidate.cont ready)
    | none => pure ()

  return none
/--
Creates a `Selector` that performs fair and data-loss free multiplexing on multiple `Selectable`s.
This allows the multiplexing operation to be composed with other selectors.

Throws a user error if `selectables` is empty. The random shuffle used for fairness is performed
once, when the selector is created; every later `tryFn`/`registerFn` call walks the same order.

The resulting selector behaves as follows:
* `tryFn` runs `Selector.tryFn` of each element in shuffled order and, for the first one that
  yields a value, runs its `Selectable.cont` and returns the result.
* `registerFn` registers one derived waiter per element. Each derived waiter has its own promise
  but shares the `finished` field of the outer waiter. When a derived promise is resolved, all
  elements are unregistered, the winning element's `cont` is run and the outer promise is resolved
  with its result (or with the error that was raised).
* `unregisterFn` unregisters every element.
-/
def Selectable.combine (selectables : Array (Selectable α)) : IO (Selector α) := do
  if selectables.isEmpty then
    throw <| .userError "Selectable.combine requires at least one Selectable"

  let seed := UInt64.toNat (ByteArray.toUInt64LE! (← IO.getRandomBytes 8))
  let gen := mkStdGen seed
  let selectables := shuffleIt selectables gen

  return {
    tryFn := do
      for selectable in selectables do
        if let some val ← selectable.selector.tryFn then
          let result ← selectable.cont val
          return some result
      return none

    registerFn := fun waiter => do
      for selectable in selectables do
        -- NOTE(review): presumably the shared `finished` flag lets at most one derived waiter
        -- win; confirm against the definition of `Waiter`.
        let waiterPromise ← IO.Promise.new
        let derivedWaiter := Waiter.mk waiter.finished waiterPromise
        selectable.selector.registerFn derivedWaiter

        discard <| IO.bindTask (t := waiterPromise.result?) fun res? => do
          match res? with
          -- The derived promise was dropped without a value: nothing to forward.
          | none => return (Task.pure (.ok ()))
          | some res =>
            let async : Async _ := do
              let mainPromise := waiter.promise

              -- Use a distinct name here so that `selectable` below still refers to the
              -- element whose waiter fired, not to the last element of this loop.
              for other in selectables do
                other.selector.unregisterFn

              try
                let val ← IO.ofExcept res
                let result ← selectable.cont val
                mainPromise.resolve (.ok result)
              catch e =>
                mainPromise.resolve (.error e)

            async.toBaseIO

    unregisterFn := do
      for selectable in selectables do
        selectable.selector.unregisterFn
  }
end Async
end IO
end Internal

View file

@ -14,5 +14,6 @@ public import Std.Sync.Barrier
public import Std.Sync.SharedMutex
public import Std.Sync.Notify
public import Std.Sync.Broadcast
public import Std.Sync.StreamMap
@[expose] public section

147
src/Std/Sync/StreamMap.lean Normal file
View file

@ -0,0 +1,147 @@
/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Sofia Rodrigues
-/
module
prelude
public import Std.Data
public import Init.System.Promise
public import Init.Data.Queue
public import Std.Internal.Async.IO
public import Std.Internal.Async.Select
public import Std.Internal.Async.Basic
public section
open Std.Internal.Async.IO
open Std.Internal.IO.Async
/-!
This module provides `StreamMap`, a container that maps keys to async streams.
It allows for dynamic management of multiple named streams with async operations.
-/
namespace Std
/--
Existential wrapper around a value of some type `t` that has an `AsyncStream t α` instance.

It allows streams with different underlying types to be stored in a single array, which is what
`StreamMap.ofArray` consumes. Combined with the `CoeDep` instance this makes such arrays easy to
write, and (per the original author's note) it keeps `StreamMap` in `Type 0`.
-/
inductive AnyAsyncStream (α : Type) where
| mk : {t : Type} → [AsyncStream t α] → t → AnyAsyncStream α
/--
Unpacks the wrapped stream into the pair of its `AsyncStream.next` selector and its
`AsyncStream.stop` action.
-/
def AnyAsyncStream.getSelector : AnyAsyncStream α → Selector α × IO Unit := fun wrapped =>
  match wrapped with
  | .mk inner => (AsyncStream.next inner, AsyncStream.stop inner)
/--
Any value `x : t` with an `AsyncStream t α` instance can be used where an `AnyAsyncStream α` is
expected; it is wrapped with `AnyAsyncStream.mk`.
-/
instance [AsyncStream t α] : CoeDep t x (AnyAsyncStream α) :=
  ⟨AnyAsyncStream.mk x⟩
/--
A container that maps keys to async streams, enabling dynamic stream management
and unified selection operations across multiple named data sources.

`α` is the type of the stream names (keys) and `β` is the type of the values produced by
every stream in the map.
-/
structure StreamMap (α : Type) (β : Type) where
-- The constructor is private; values are built with `empty`, `register` and `ofArray`.
private mk ::
-- One entry per stream: its name, the selector producing its next value, and its cleanup action.
private streams : Array (α × Selector β × IO Unit)
namespace StreamMap
/--
The `StreamMap` that has no registered streams.
-/
def empty {α} : StreamMap α β :=
  ⟨#[]⟩
/--
Registers the async stream `reader` under `name` and returns the updated map.

Any existing entry whose name equals `name` is removed first, so the new stream replaces it and
is placed at the end. The cleanup action of a replaced entry is dropped without being run.
-/
def register [BEq α] [AsyncStream t β] (sm : StreamMap α β) (name : α) (reader : t) : StreamMap α β :=
  let entry : α × Selector β × IO Unit := (name, AsyncStream.next reader, AsyncStream.stop reader)
  let others := sm.streams.filter fun (key, _) => key != name
  { streams := others.push entry }
/--
Builds a `StreamMap` from an array of `(name, stream)` pairs, keeping the array order.

Each wrapped stream is unpacked with `AnyAsyncStream.getSelector`. Entries are taken as given:
names that occur more than once are all kept (the `BEq α` instance is not used here).
-/
def ofArray [BEq α] (streams : Array (α × AnyAsyncStream β)) : StreamMap α β :=
  { streams := streams.map fun entry => (entry.1, entry.2.getSelector) }
/--
Builds a single combined selector over all registered streams. A value produced by it is paired
with the name of the stream it came from.

The selector is created with `Selectable.combine`, which throws a user error when there are no
streams, so this fails on an empty map.
-/
def selector (stream : StreamMap α β) : Async (Selector (α × β)) :=
  let cases := stream.streams.map fun (key, sel, _) =>
    Selectable.case sel fun value => pure (key, value)
  Selectable.combine cases
/--
Waits until one of the registered streams produces a value and returns it together with the
name of that stream. Selection is delegated to `Selectable.one`.

NOTE(review): `Selectable.one` presumably fails on an empty array, so calling this on an empty
map is expected to raise an error; confirm against `Selectable.one`.
-/
def recv (stream : StreamMap α β) : Async (α × β) :=
  let cases := stream.streams.map fun (key, sel, _) =>
    Selectable.case sel fun value => pure (key, value)
  Selectable.one cases
/--
Tries to receive a value from the stream map without waiting.

Returns `some (name, value)` if one of the registered streams has a value available right now,
where `name` is the name of that stream. Returns `none` if the map is empty or no stream has a
value available. Selection is delegated to `Selectable.tryOne`.
-/
def tryRecv (stream : StreamMap α β) : Async (Option (α × β)) :=
  let selectables := stream.streams.map fun (name, selector) => Selectable.case selector.fst (fun x => pure (name, x))
  Selectable.tryOne selectables
/--
Returns the map without any entry whose name equals `name`.

The cleanup action of a removed entry is dropped without being run.
-/
def unregister [BEq α] (sm : StreamMap α β) (name : α) : StreamMap α β :=
  ⟨sm.streams.filter fun entry => entry.1 != name⟩
/--
Returns `true` if some registered stream has a name equal to `name`.
-/
def contains [BEq α] (sm : StreamMap α β) (name : α) : Bool :=
  sm.streams.any fun entry => entry.1 == name
/--
The number of entries currently stored in the map.
-/
def size (sm : StreamMap α β) : Nat :=
  Array.size sm.streams
/--
Returns `true` if the map has no entries.
-/
def isEmpty (sm : StreamMap α β) : Bool :=
  Array.isEmpty sm.streams
/--
The names of all entries, in the order in which the entries are stored.
-/
def keys (sm : StreamMap α β) : Array α :=
  sm.streams.map Prod.fst
/--
Looks up the selector of the first entry whose name equals `name`.
Returns `none` if there is no such entry.
-/
def get? [BEq α] (sm : StreamMap α β) (name : α) : Option (Selector β) :=
  match sm.streams.find? (fun entry => entry.1 == name) with
  | some (_, sel, _) => some sel
  | none => none
/--
Keeps only the entries whose name satisfies `pred`.

The cleanup actions of the discarded entries are dropped without being run.
-/
def filterByName (sm : StreamMap α β) (pred : α → Bool) : StreamMap α β :=
  ⟨sm.streams.filter fun entry => pred entry.1⟩
/--
Returns the entries as an array of `(name, selector)` pairs, in storage order.
The cleanup action of each entry is not included.
-/
def toArray (sm : StreamMap α β) : Array (α × Selector β) :=
  sm.streams.map fun entry => (entry.1, entry.2.1)
/--
Runs the cleanup action of every entry, in storage order.

The cleanup action of an entry is the `AsyncStream.stop` action of the stream it was created
from. The map value itself is not modified.
-/
def close (sm : StreamMap α β) : IO Unit := do
  for (_, _, cleanup) in sm.streams do
    cleanup
end StreamMap

View file

@ -0,0 +1,327 @@
import Std.Internal.Async
import Std.Sync
open Std.Internal.IO Async
-- A message sent on one of three channels is received from the map tagged with that channel's name.
def testSimpleMessages : Async Unit := do
  let chanA ← Std.Channel.new (α := Nat)
  let chanB ← Std.Channel.new (α := Nat)
  let chanC ← Std.Channel.new (α := Nat)

  -- The channels are turned into `AnyAsyncStream` values by the `CoeDep` instance.
  let streams := Std.StreamMap.ofArray #[
    ("a", chanA),
    ("b", chanB),
    ("c", chanC)
  ]

  await (← chanC.send 1)
  let (key, value) ← streams.recv
  assert! key == "c" && value == 1

  await (← chanA.send 2)
  let (key, value) ← streams.recv
  assert! key == "a" && value == 2

  await (← chanB.send 3)
  let (key, value) ← streams.recv
  assert! key == "b" && value == 3
#eval testSimpleMessages.block
-- An empty StreamMap reports itself as empty, with size zero and no keys.
def testEmpty : Async Unit := do
  let emptyMap := (Std.StreamMap.empty : Std.StreamMap String Nat)

  assert! emptyMap.isEmpty
  assert! emptyMap.size == 0
  assert! emptyMap.keys.size == 0
#eval testEmpty.block
-- `register` adds a name, `unregister` removes it, and `contains`/`size` follow along.
def testRegisterUnregister : Async Unit := do
  let chanA ← Std.Channel.new (α := Nat)
  let chanB ← Std.Channel.new (α := Nat)

  let withA := Std.StreamMap.empty.register "a" chanA
  assert! withA.contains "a"
  assert! not (withA.contains "b")
  assert! withA.size == 1

  let withAB := withA.register "b" chanB
  assert! withAB.contains "a" && withAB.contains "b"
  assert! withAB.size == 2

  let withB := withAB.unregister "a"
  assert! not (withB.contains "a")
  assert! withB.contains "b"
  assert! withB.size == 1

  let drained := withB.unregister "b"
  assert! drained.isEmpty
#eval testRegisterUnregister.block
-- Registering a second stream under an existing name replaces the first one.
def testRegisterReplace : Async Unit := do
  let oldChan ← Std.Channel.new (α := Nat)
  let newChan ← Std.Channel.new (α := Nat)

  let initial := Std.StreamMap.empty.register "a" oldChan
  assert! initial.size == 1

  -- Same name again: the entry count must not grow.
  let replaced := initial.register "a" newChan
  assert! replaced.size == 1
  assert! replaced.contains "a"

  -- A message on the replacement channel is delivered under the name "a".
  await (← newChan.send 42)
  let (key, value) ← replaced.recv
  assert! key == "a" && value == 42
#eval testRegisterReplace.block
-- `keys` returns every registered name.
def testKeys : Async Unit := do
  let chanA ← Std.Channel.new (α := Nat)
  let chanB ← Std.Channel.new (α := Nat)
  let chanC ← Std.Channel.new (α := Nat)

  let streams := Std.StreamMap.ofArray #[
    ("c", .mk chanC),
    ("a", .mk chanA),
    ("b", .mk chanB)
  ]

  let names := streams.keys
  assert! names.size == 3
  assert! names.contains "a"
  assert! names.contains "b"
  assert! names.contains "c"
#eval testKeys.block
-- `get?` finds a registered name and returns `none` for an unknown one.
def testGet : Async Unit := do
  let chanA ← Std.Channel.new (α := Nat)
  let chanB ← Std.Channel.new (α := Nat)

  let streams := Std.StreamMap.ofArray #[
    ("a", .mk chanA),
    ("b", .mk chanB)
  ]

  let present := streams.get? "a"
  let missing := streams.get? "c"

  assert! present.isSome
  assert! missing.isNone
#eval testGet.block
-- `filterByName` keeps exactly the entries whose name satisfies the predicate.
def testFilterByName : Async Unit := do
  let chanA ← Std.Channel.new (α := Nat)
  let chanB ← Std.Channel.new (α := Nat)
  let chanC ← Std.Channel.new (α := Nat)

  let streams := Std.StreamMap.ofArray #[
    ("prefix_a", .mk chanA),
    ("prefix_b", .mk chanB),
    ("other_c", .mk chanC)
  ]

  let prefixed := streams.filterByName (fun name => name.startsWith "prefix_")

  assert! prefixed.size == 2
  assert! prefixed.contains "prefix_a"
  assert! prefixed.contains "prefix_b"
  assert! not (prefixed.contains "other_c")
#eval testFilterByName.block
-- `toArray` yields one (name, selector) pair per registered stream.
def testToArray : Async Unit := do
  let chanA ← Std.Channel.new (α := Nat)
  let chanB ← Std.Channel.new (α := Nat)

  let streams := Std.StreamMap.ofArray #[
    ("a", .mk chanA),
    ("b", .mk chanB)
  ]

  let pairs := streams.toArray
  assert! pairs.size == 2

  let pairNames := pairs.map (·.1)
  assert! pairNames.contains "a"
  assert! pairNames.contains "b"
#eval testToArray.block
-- Test multiple messages from same channel
def testMultipleFromSame : Async Unit := do
  let channelA ← Std.Channel.new (α := Nat)
  let channelB ← Std.Channel.new (α := Nat)

  let stream := Std.StreamMap.ofArray #[
    ("a", .mk channelA),
    ("b", .mk channelB),
  ]

  -- Send multiple messages to same channel
  await (← channelA.send 1)
  await (← channelA.send 2)
  await (← channelB.send 10)

  let first ← stream.recv
  let second ← stream.recv
  let third ← stream.recv
  let received := #[first, second, third]

  -- Three values were received and three distinct (name, message) pairs are required below,
  -- so each pair must occur exactly once: a lost or duplicated message fails the test.
  -- The order across channels may vary.
  assert! received.contains ("a", 1)
  assert! received.contains ("a", 2)
  assert! received.contains ("b", 10)
#eval testMultipleFromSame.block
-- Messages sent alternately on several channels are all delivered.
def testInterleavedMessages : Async Unit := do
  let chanA ← Std.Channel.new (α := String)
  let chanB ← Std.Channel.new (α := String)
  let chanC ← Std.Channel.new (α := String)

  let streams := Std.StreamMap.ofArray #[
    ("first", .mk chanA),
    ("second", .mk chanB),
    ("third", .mk chanC)
  ]

  -- Four sends spread over the three channels.
  await (← chanB.send "msg1")
  await (← chanA.send "msg2")
  await (← chanC.send "msg3")
  await (← chanA.send "msg4")

  -- Receive four times, collecting the (name, message) pairs.
  let mut received : Array (String × String) := #[]
  for _ in [0:4] do
    received := received.push (← streams.recv)

  -- Every message must show up; the order across channels is not fixed.
  let payloads := received.map (·.2)
  assert! payloads.contains "msg1"
  assert! payloads.contains "msg2"
  assert! payloads.contains "msg3"
  assert! payloads.contains "msg4"
#eval testInterleavedMessages.block
-- Test with a non-numeric payload type. Both channels carry `String`: all streams in one
-- `StreamMap α β` share the same value type `β`.
def testDifferentTypes : Async Unit := do
  let channelStr ← Std.Channel.new (α := String)
  let channelOther ← Std.Channel.new (α := String)

  let stream := Std.StreamMap.ofArray #[
    ("strings", .mk channelStr),
    ("bools", .mk channelOther),
  ]

  await (← channelStr.send "hello")
  await (← channelOther.send "world")

  let (name1, msg1) ← stream.recv
  let (name2, msg2) ← stream.recv

  -- Each received pair must match its source, and the two must come from different streams.
  assert! ((name1 == "strings" && msg1 == "hello") || (name1 == "bools" && msg1 == "world"))
  assert! ((name2 == "strings" && msg2 == "hello") || (name2 == "bools" && msg2 == "world"))
  assert! name1 != name2
#eval testDifferentTypes.block
-- After draining pending messages, a stream can be unregistered and a new one registered and used.
def testUnregisterDuringOperation : Async Unit := do
  let chanA ← Std.Channel.new (α := Nat)
  let chanB ← Std.Channel.new (α := Nat)
  let chanC ← Std.Channel.new (α := Nat)

  let original := Std.StreamMap.ofArray #[
    ("a", .mk chanA),
    ("b", .mk chanB),
    ("c", .mk chanC)
  ]

  await (← chanA.send 1)
  await (← chanB.send 2)
  await (← chanC.send 3)

  -- One pending message per channel: three non-blocking receives must all succeed.
  assert! (← original.tryRecv).isSome
  assert! (← original.tryRecv).isSome
  assert! (← original.tryRecv).isSome

  let withoutB := original.unregister "b"
  assert! not (withoutB.contains "b")
  assert! withoutB.size == 2

  let chanD ← Std.Channel.new (α := Nat)
  let withD := withoutB.register "d" chanD

  await (← chanD.send 4)
  let (key, value) ← withD.recv
  assert! key == "d" && value == 4
#eval testUnregisterDuringOperation.block
-- The combined selector from `selector` can be used as a case of `Selectable.one`.
def testSelector : Async Unit := do
  let chanA ← Std.Channel.new (α := Nat)
  let chanB ← Std.Channel.new (α := Nat)

  let streams := Std.StreamMap.ofArray #[
    ("a", .mk chanA),
    ("b", .mk chanB)
  ]

  let combined ← streams.selector

  await (← chanA.send 42)

  let (key, value) ← Selectable.one #[.case combined (fun x => pure x)]
  assert! key == "a" && value == 42
#eval testSelector.block
-- `close` runs to completion on a map with two registered channels.
def testClose : Async Unit := do
  let chanA ← Std.Channel.new (α := Nat)
  let chanB ← Std.Channel.new (α := Nat)

  let streams := Std.StreamMap.ofArray #[
    ("a", .mk chanA),
    ("b", .mk chanB)
  ]

  streams.close
#eval testClose.block
-- With 128 registered channels, messages sent on two of them are both delivered.
def testManyChannels : Async Unit := do
  let chans : Vector _ 128 ← Vector.ofFnM (fun _ => Std.Channel.new (α := Nat))
  let named := chans.mapIdx (fun i ch => (s!"channel_{i}", .mk ch))
  let streams := Std.StreamMap.ofArray named.toArray

  assert! streams.size == 128

  await (← chans[3].send 100)
  await (← chans[7].send 200)

  let (keyA, valueA) ← streams.recv
  let (keyB, valueB) ← streams.recv

  -- Either order is acceptable, but the two receives must come from different channels.
  assert! ((keyA == "channel_3" && valueA == 100) || (keyA == "channel_7" && valueA == 200))
  assert! ((keyB == "channel_3" && valueB == 100) || (keyB == "channel_7" && valueB == 200))
  assert! keyA != keyB
#eval testManyChannels.block