diff --git a/src/Std/Internal/Async/Select.lean b/src/Std/Internal/Async/Select.lean index aa30feb20f..a94662b8b0 100644 --- a/src/Std/Internal/Async/Select.lean +++ b/src/Std/Internal/Async/Select.lean @@ -169,6 +169,83 @@ partial def Selectable.one (selectables : Array (Selectable α)) : Async α := d Async.ofPromise (pure promise) +/-- +Performs fair and data-loss free non-blocking multiplexing on the `Selectable`s in `selectables`. + +This function only tries the non-blocking `tryFn` for each `Selectable` without registering +waiters or blocking. It returns `some result` if any `Selectable` is immediately available, +or `none` if all would block. + +The protocol for this is as follows: +1. The `selectables` are shuffled randomly for fairness. +2. Run `Selector.tryFn` for each element in `selectables`. If any succeed, the corresponding + `Selectable.cont` is executed and its result is returned as `some result`. +3. If none succeed, `none` is returned immediately without blocking. +-/ +def Selectable.tryOne (selectables : Array (Selectable α)) : Async (Option α) := do + if selectables.isEmpty then + return none + + let seed := UInt64.toNat (ByteArray.toUInt64LE! (← IO.getRandomBytes 8)) + let gen := mkStdGen seed + let selectables := shuffleIt selectables gen + + for selectable in selectables do + if let some val ← selectable.selector.tryFn then + let result ← selectable.cont val + return some result + + return none + +/-- +Creates a `Selector` that performs fair and data-loss free multiplexing on multiple `Selectable`s. +This allows the multiplexing operation to be composed with other selectors. +-/ +def Selectable.combine (selectables : Array (Selectable α)) : IO (Selector α) := do + if selectables.isEmpty then + throw <| .userError "Selectable.one requires at least one Selectable" + + let seed := UInt64.toNat (ByteArray.toUInt64LE! (← IO.getRandomBytes 8)) + let gen := mkStdGen seed + let selectables := shuffleIt selectables gen + + return { + tryFn := do + for selectable in selectables do + if let some val ← selectable.selector.tryFn then + let result ← selectable.cont val + return some result + return none + + registerFn := fun waiter => do + for selectable in selectables do + let waiterPromise ← IO.Promise.new + let derivedWaiter := Waiter.mk waiter.finished waiterPromise + selectable.selector.registerFn derivedWaiter + + discard <| IO.bindTask (t := waiterPromise.result?) fun res? => do + match res? with + | none => return (Task.pure (.ok ())) + | some res => + let async : Async _ := do + let mainPromise := waiter.promise + + for selectable in selectables do + selectable.selector.unregisterFn + + try + let val ← IO.ofExcept res + let result ← selectable.cont val + mainPromise.resolve (.ok result) + catch e => + mainPromise.resolve (.error e) + async.toBaseIO + + unregisterFn := do + for selectable in selectables do + selectable.selector.unregisterFn + } + end Async end IO end Internal diff --git a/src/Std/Sync.lean b/src/Std/Sync.lean index b6899e2a56..3514226d34 100644 --- a/src/Std/Sync.lean +++ b/src/Std/Sync.lean @@ -14,5 +14,6 @@ public import Std.Sync.Barrier public import Std.Sync.SharedMutex public import Std.Sync.Notify public import Std.Sync.Broadcast +public import Std.Sync.StreamMap @[expose] public section diff --git a/src/Std/Sync/StreamMap.lean b/src/Std/Sync/StreamMap.lean new file mode 100644 index 0000000000..47b37ea089 --- /dev/null +++ b/src/Std/Sync/StreamMap.lean @@ -0,0 +1,147 @@ +/- +Copyright (c) 2025 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Sofia Rodrigues +-/ +module + +prelude +public import Std.Data +public import Init.System.Promise +public import Init.Data.Queue +public import Std.Internal.Async.IO +public import Std.Internal.Async.Select +public import Std.Internal.Async.Basic + +public section + +open Std.Internal.Async.IO +open Std.Internal.IO.Async + +/-! +This module provides `StreamMap`, a container that maps keys to async streams. +It allows for dynamic management of multiple named streams with async operations. +-/ + +namespace Std + +/-- +This is an existential wrapper for AsyncStream that is used for the `.ofArray` function +with `CoeDep` so it's easier and we keep StreamMap on `Type 0`. +-/ +inductive AnyAsyncStream (α : Type) where + | mk : {t : Type} → [AsyncStream t α] → t → AnyAsyncStream α + +def AnyAsyncStream.getSelector : AnyAsyncStream α → Selector α × IO Unit + | AnyAsyncStream.mk stream => (AsyncStream.next stream, AsyncStream.stop stream) + +instance [AsyncStream t α] : CoeDep t x (AnyAsyncStream α) where + coe := AnyAsyncStream.mk x + +/-- +A container that maps keys to async streams, enabling dynamic stream management +and unified selection operations across multiple named data sources. +-/ +structure StreamMap (α : Type) (β : Type) where + private mk :: + private streams : Array (α × Selector β × IO Unit) + +namespace StreamMap + +/-- +Create an empty StreamMap +-/ +def empty {α} : StreamMap α β := + { streams := #[] } + +/-- +Register a new async stream with the given name +-/ +def register [BEq α] [AsyncStream t β] (sm : StreamMap α β) (name : α) (reader : t) : StreamMap α β := + let newSelector := AsyncStream.next reader + let filteredStreams := sm.streams.filter (fun (n, _) => n != name) + { sm with streams := filteredStreams.push (name, newSelector, AsyncStream.stop reader) } + +/-- +Create a StreamMap from an array of named streams +-/ +def ofArray [BEq α] (streams : Array (α × AnyAsyncStream β)) : StreamMap α β := + let arrayOfSelectors := streams.map (fun (name, sel) => (name, sel.getSelector)) + { streams := arrayOfSelectors } + +/-- +Get a combined selector that returns the stream name and value +-/ +def selector (stream : StreamMap α β) : Async (Selector (α × β)) := + let selectables := stream.streams.map fun (name, selector) => Selectable.case selector.fst (fun x => pure (name, x)) + Selectable.combine selectables + +/-- +Wait for the first value inside of the stream map. +-/ +def recv (stream : StreamMap α β) : Async (α × β) := + let selectables := stream.streams.map fun (name, selector) => Selectable.case selector.fst (fun x => pure (name, x)) + Selectable.one selectables + +/-- +Wait for the first value inside of the stream map. +-/ +def tryRecv (stream : StreamMap α β) : Async (Option (α × β)) := + let selectables := stream.streams.map fun (name, selector) => Selectable.case selector.fst (fun x => pure (name, x)) + Selectable.tryOne selectables + +/-- +Remove a stream by name +-/ +def unregister [BEq α] (sm : StreamMap α β) (name : α) : StreamMap α β := + { sm with streams := sm.streams.filter (fun (n, _) => n != name) } + +/-- +Check if a stream with the given name exists +-/ +def contains [BEq α] (sm : StreamMap α β) (name : α) : Bool := + sm.streams.any (fun (n, _) => n == name) + +/-- +Get the number of registered streams +-/ +def size (sm : StreamMap α β) : Nat := + sm.streams.size + +/-- +Check if the StreamMap is empty +-/ +def isEmpty (sm : StreamMap α β) : Bool := + sm.streams.isEmpty + +/-- +Get all registered stream names +-/ +def keys (sm : StreamMap α β) : Array α := + sm.streams.map (·.1) + +/-- +Get a specific stream selector by name +-/ +def get? [BEq α] (sm : StreamMap α β) (name : α) : Option (Selector β) := + sm.streams.find? (fun (n, _) => n == name) |>.map (·.2.1) + +/-- +Filter streams based on their names +-/ +def filterByName (sm : StreamMap α β) (pred : α → Bool) : StreamMap α β := + { streams := sm.streams.filter (fun (name, _) => pred name) } + +/-- +Convert to array of name-selector pairs +-/ +def toArray (sm : StreamMap α β) : Array (α × Selector β) := + sm.streams.map (fun (n, s, _) => (n, s)) + +/-- +Cleanup function +-/ +def close (sm : StreamMap α β) : IO Unit := + sm.streams.forM (fun (_, _, cleanup) => cleanup) + +end StreamMap diff --git a/tests/lean/run/async_streammap.lean b/tests/lean/run/async_streammap.lean new file mode 100644 index 0000000000..46cf205728 --- /dev/null +++ b/tests/lean/run/async_streammap.lean @@ -0,0 +1,327 @@ +import Std.Internal.Async +import Std.Sync + +open Std.Internal.IO Async + +-- Test basic message reception from multiple channels +def testSimpleMessages : Async Unit := do + let channelA ← Std.Channel.new (α := Nat) + let channelB ← Std.Channel.new (α := Nat) + let channelC ← Std.Channel.new (α := Nat) + + let channel := Std.StreamMap.ofArray #[ + ("a", channelA), + ("b", channelB), + ("c", channelC), + ] + + await (← channelC.send 1) + let (name, message) ← channel.recv + assert! name == "c" && message == 1 + + await (← channelA.send 2) + let (name, message) ← channel.recv + assert! name == "a" && message == 2 + + await (← channelB.send 3) + let (name, message) ← channel.recv + assert! name == "b" && message == 3 + +#eval testSimpleMessages.block + +-- Test empty StreamMap +def testEmpty : Async Unit := do + let stream : Std.StreamMap String Nat := Std.StreamMap.empty + + assert! stream.isEmpty + assert! stream.size == 0 + assert! stream.keys.size == 0 + +#eval testEmpty.block + +-- Test register and unregister operations +def testRegisterUnregister : Async Unit := do + let channelA ← Std.Channel.new (α := Nat) + let channelB ← Std.Channel.new (α := Nat) + + let stream := Std.StreamMap.empty.register "a" channelA + + assert! stream.contains "a" + assert! not (stream.contains "b") + assert! stream.size == 1 + + let stream := stream.register "b" channelB + assert! stream.contains "a" && stream.contains "b" + assert! stream.size == 2 + + let stream := stream.unregister "a" + assert! not (stream.contains "a") + assert! stream.contains "b" + assert! stream.size == 1 + + let stream := stream.unregister "b" + assert! stream.isEmpty + +#eval testRegisterUnregister.block + +-- Test replacing existing stream with same name +def testRegisterReplace : Async Unit := do + let channelA1 ← Std.Channel.new (α := Nat) + let channelA2 ← Std.Channel.new (α := Nat) + + let stream := Std.StreamMap.empty.register "a" channelA1 + assert! stream.size == 1 + + -- Register with same name should replace + let stream := stream.register "a" channelA2 + assert! stream.size == 1 + assert! stream.contains "a" + + -- Send to new channel + await (← channelA2.send 42) + let (name, message) ← stream.recv + assert! name == "a" && message == 42 + +#eval testRegisterReplace.block + +-- Test keys functionality +def testKeys : Async Unit := do + let channelA ← Std.Channel.new (α := Nat) + let channelB ← Std.Channel.new (α := Nat) + let channelC ← Std.Channel.new (α := Nat) + + let stream := Std.StreamMap.ofArray #[ + ("c", .mk channelC), + ("a", .mk channelA), + ("b", .mk channelB), + ] + + let keys := stream.keys + assert! keys.size == 3 + assert! keys.contains "a" + assert! keys.contains "b" + assert! keys.contains "c" + +#eval testKeys.block + +-- Test get? functionality +def testGet : Async Unit := do + let channelA ← Std.Channel.new (α := Nat) + let channelB ← Std.Channel.new (α := Nat) + + let stream := Std.StreamMap.ofArray #[ + ("a", .mk channelA), + ("b", .mk channelB), + ] + + let selectorA := stream.get? "a" + let selectorC := stream.get? "c" + + assert! selectorA.isSome + assert! selectorC.isNone + +#eval testGet.block + +-- Test filterByName functionality +def testFilterByName : Async Unit := do + let channelA ← Std.Channel.new (α := Nat) + let channelB ← Std.Channel.new (α := Nat) + let channelC ← Std.Channel.new (α := Nat) + + let stream := Std.StreamMap.ofArray #[ + ("prefix_a", .mk channelA), + ("prefix_b", .mk channelB), + ("other_c", .mk channelC), + ] + + let filtered := stream.filterByName (fun name => name.startsWith "prefix_") + + assert! filtered.size == 2 + assert! filtered.contains "prefix_a" + assert! filtered.contains "prefix_b" + assert! not (filtered.contains "other_c") + +#eval testFilterByName.block + +-- Test toArray functionality +def testToArray : Async Unit := do + let channelA ← Std.Channel.new (α := Nat) + let channelB ← Std.Channel.new (α := Nat) + + let stream := Std.StreamMap.ofArray #[ + ("a", .mk channelA), + ("b", .mk channelB), + ] + + let array := stream.toArray + assert! array.size == 2 + + let names := array.map (·.1) + assert! names.contains "a" + assert! names.contains "b" + +#eval testToArray.block + +-- Test multiple messages from same channel +def testMultipleFromSame : Async Unit := do + let channelA ← Std.Channel.new (α := Nat) + let channelB ← Std.Channel.new (α := Nat) + + let stream := Std.StreamMap.ofArray #[ + ("a", .mk channelA), + ("b", .mk channelB), + ] + + -- Send multiple messages to same channel + await (← channelA.send 1) + await (← channelA.send 2) + await (← channelB.send 10) + + let (name1, msg1) ← stream.recv + let (name2, msg2) ← stream.recv + let (name3, msg3) ← stream.recv + + -- Should receive all messages, order may vary but names should match sources + assert! (name1 == "a" && msg1 == 1) || (name1 == "a" && msg1 == 2) || (name1 == "b" && msg1 == 10) + assert! (name2 == "a" && msg2 == 1) || (name2 == "a" && msg2 == 2) || (name2 == "b" && msg2 == 10) + assert! (name3 == "a" && msg3 == 1) || (name3 == "a" && msg3 == 2) || (name3 == "b" && msg3 == 10) + +#eval testMultipleFromSame.block + +-- Test interleaved messages from different channels +def testInterleavedMessages : Async Unit := do + let channelA ← Std.Channel.new (α := String) + let channelB ← Std.Channel.new (α := String) + let channelC ← Std.Channel.new (α := String) + + let stream := Std.StreamMap.ofArray #[ + ("first", .mk channelA), + ("second", .mk channelB), + ("third", .mk channelC), + ] + + -- Send messages in specific order + await (← channelB.send "msg1") + await (← channelA.send "msg2") + await (← channelC.send "msg3") + await (← channelA.send "msg4") + + let results ← (List.range 4).mapM (fun _ => stream.recv) + + -- Verify we got all messages (order may vary) + let messages := results.map (·.2) + assert! messages.contains "msg1" + assert! messages.contains "msg2" + assert! messages.contains "msg3" + assert! messages.contains "msg4" + +#eval testInterleavedMessages.block + +-- Test with different data typez +def testDifferentTypes : Async Unit := do + let channelStr ← Std.Channel.new (α := String) + let channelBool ← Std.Channel.new (α := String) + + let stream := Std.StreamMap.ofArray #[ + ("strings", .mk channelStr), + ("bools", .mk channelBool), + ] + + await (← channelStr.send "hello") + await (← channelBool.send "world") + + let (name1, msg1) ← stream.recv + let (name2, msg2) ← stream.recv + + assert! ((name1 == "strings" && msg1 == "hello") || (name1 == "bools" && msg1 == "world")) + assert! ((name2 == "strings" && msg2 == "hello") || (name2 == "bools" && msg2 == "world")) + assert! name1 != name2 + +#eval testDifferentTypes.block + +-- Test unregister during operation +def testUnregisterDuringOperation : Async Unit := do + let channelA ← Std.Channel.new (α := Nat) + let channelB ← Std.Channel.new (α := Nat) + let channelC ← Std.Channel.new (α := Nat) + + let stream := Std.StreamMap.ofArray #[ + ("a", .mk channelA), + ("b", .mk channelB), + ("c", .mk channelC), + ] + + await (← channelA.send 1) + await (← channelB.send 2) + await (← channelC.send 3) + + assert! (← stream.tryRecv).isSome + assert! (← stream.tryRecv).isSome + assert! (← stream.tryRecv).isSome + + let stream := stream.unregister "b" + assert! not (stream.contains "b") + assert! stream.size == 2 + + let newChannelD ← Std.Channel.new (α := Nat) + let stream := stream.register "d" newChannelD + + await (← newChannelD.send 4) + let (name2, msg2) ← stream.recv + + assert! name2 == "d" && msg2 == 4 + +#eval testUnregisterDuringOperation.block + +-- Test selector functionality +def testSelector : Async Unit := do + let channelA ← Std.Channel.new (α := Nat) + let channelB ← Std.Channel.new (α := Nat) + + let stream := Std.StreamMap.ofArray #[ + ("a", .mk channelA), + ("b", .mk channelB), + ] + + let selector ← stream.selector + + await (← channelA.send 42) + + let result ← Selectable.one #[.case selector (fun x => pure x)] + assert! result.1 == "a" && result.2 == 42 + +#eval testSelector.block + +def testClose : Async Unit := do + let channelA ← Std.Channel.new (α := Nat) + let channelB ← Std.Channel.new (α := Nat) + + let stream := Std.StreamMap.ofArray #[ + ("a", .mk channelA), + ("b", .mk channelB), + ] + + stream.close + +#eval testClose.block + +-- Test large number of channels +def testManyChannels : Async Unit := do + let channels : Vector _ 128 ← Vector.ofFnM (fun _ => Std.Channel.new (α := Nat)) + + let streamArray := channels.mapIdx (fun i ch => (s!"channel_{i}", .mk ch)) + let stream := Std.StreamMap.ofArray streamArray.toArray + + assert! stream.size == 128 + + await (← channels[3].send 100) + await (← channels[7].send 200) + + let (name1, msg1) ← stream.recv + let (name2, msg2) ← stream.recv + + assert! ((name1 == "channel_3" && msg1 == 100) || (name1 == "channel_7" && msg1 == 200)) + assert! ((name2 == "channel_3" && msg2 == 100) || (name2 == "channel_7" && msg2 == 200)) + assert! name1 != name2 + +#eval testManyChannels.block