From 161a1c06a2e8ea736e4df1a1ef20b715a5b5bd9f Mon Sep 17 00:00:00 2001 From: Sofia Rodrigues Date: Wed, 24 Sep 2025 00:35:08 -0300 Subject: [PATCH] feat: add `Std.Notify` type (#10368) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR adds `Notify` that is a structure that is similar to `CondVar` but it's used for concurrency. The main difference between `Std.Sync.Notify` and `Std.Condvar` is that depends on a `Std.Mutex` and blocks the entire thread that the `Task` is using while waiting. If I try to use it with async and a lot of `Task`s like this: ```lean def condvar : Async Unit := do let condvar ← Std.Condvar.new let mutex ← Std.Mutex.new false for i in [0:threads] do background do IO.println s!"start {i + 1}" await =<< (show IO (ETask _ _) from IO.asTask (mutex.atomically (condvar.wait mutex))) IO.println s!"end {i + 1}" IO.sleep 2000 condvar.notifyAll ``` It causes some weird behavior because some tasks start running and get notified, while others don’t, because `condvar.wait` blocks the `Task` entire task and right now afaik it blocks an entire thread and cannot be paused while doing blocking operations like that. `Notify` uses `Promise`s so it’s better suited for concurrency. The `Task` is not blocked while waiting for a notification which makes it simpler for use cases that just involve notifying: ```lean def notify : Async Unit := do let notify ← Std.Notify.new for i in [0:threads] do background do IO.println s!"start {i}" notify.wait IO.println s!"end {i}" IO.sleep 2000 notify.notify ``` This PR depends on: #10366, #10367 and #10370. --- src/Std/Sync.lean | 1 + src/Std/Sync/Notify.lean | 140 ++++++++++++++++++++++++++++++++ tests/lean/run/sync_notify.lean | 108 ++++++++++++++++++++++++ 3 files changed, 249 insertions(+) create mode 100644 src/Std/Sync/Notify.lean create mode 100644 tests/lean/run/sync_notify.lean diff --git a/src/Std/Sync.lean b/src/Std/Sync.lean index a366781cb0..410b8d1fe0 100644 --- a/src/Std/Sync.lean +++ b/src/Std/Sync.lean @@ -12,5 +12,6 @@ public import Std.Sync.Mutex public import Std.Sync.RecursiveMutex public import Std.Sync.Barrier public import Std.Sync.SharedMutex +public import Std.Sync.Notify @[expose] public section diff --git a/src/Std/Sync/Notify.lean b/src/Std/Sync/Notify.lean new file mode 100644 index 0000000000..9945c81bff --- /dev/null +++ b/src/Std/Sync/Notify.lean @@ -0,0 +1,140 @@ +/- +Copyright (c) 2025 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Sofia Rodrigues +-/ +module + +prelude +public import Init.System.Promise +public import Init.Data.Queue +public import Std.Sync.Mutex +public import Std.Internal.Async.Select + +public section + +/-! +This module contains the implementation of `Std.Notify`. `Std.Notify` provides a lightweight +notification primitive for signaling between tasks or threads. It supports both synchronous +and asynchronous waiting, and is useful for cases where you want to notify one or more waiters +that an event has occurred. + +Unlike a channel, `Std.Notify` does not buffer messages or carry data. It's simply a trigger. +If no one is waiting, notifications are lost. If one or more waiters are present, exactly one +will be woken up per notification. +-/ + +namespace Std +open Std.Internal.IO.Async + +inductive Notify.Consumer (α : Type) where + | normal (promise : IO.Promise α) + | select (finished : Waiter α) + +def Notify.Consumer.resolve (c : Consumer α) (x : α) : BaseIO Bool := do + match c with + | .normal promise => + promise.resolve x + return true + | .select waiter => + let lose := return false + let win promise := do + promise.resolve (.ok x) + return true + waiter.race lose win + +/-- +The central state structure for an a `Notify`. +-/ +structure Notify.State where + + /-- + Consumers that are blocked waiting for a notification. + --/ + consumers : Std.Queue (Notify.Consumer Unit) + +/-- +A notify is a synchronization primitive that allows multiple consumers to wait +until notify is called. +-/ +structure Notify where + state : Std.Mutex Notify.State + +namespace Notify + +/-- +Create a new notify. +-/ +def new : BaseIO Notify := do + return { state := ← Std.Mutex.new { consumers := ∅ } } + +/-- +Notify all currently waiting consumers. +-/ +def notify (x : Notify) : BaseIO Unit := do + x.state.atomically do + let mut st ← get + + let mut remainingConsumers := st.consumers + st := { st with consumers := ∅ } + + while true do + if let some (consumer, rest) := remainingConsumers.dequeue? then + remainingConsumers := rest + discard <| consumer.resolve () + else + break + + set st + +/-- +Notify exactly one waiting consumer (if any). Returns true if a consumer +was notified, false if no consumers were waiting. +-/ +def notifyOne (x : Notify) : BaseIO Bool := do + x.state.atomically do + let mut st ← get + + if let some (consumer, rest) := st.consumers.dequeue? then + st := { st with consumers := rest } + set st + consumer.resolve () + else + return false + +/-- +Wait to be notified. Returns a task that completes when notify is called. +Note: if notify was called before wait, this will wait for the next notify call. +-/ +def wait (x : Notify) : IO (AsyncTask Unit) := + x.state.atomically do + let promise ← IO.Promise.new + modify fun st => { st with consumers := st.consumers.enqueue (.normal promise) } + IO.bindTask promise.result? fun + | some res => pure <| Task.pure (.ok res) + | none => throw (IO.userError "notify dropped") + +/-- +Creates a selector that waits for notifications +-/ +def selector (notify : Notify) : Selector Unit := { + tryFn := do + return none + + registerFn := fun waiter => do + notify.state.atomically do + modify fun st => { st with consumers := st.consumers.enqueue (.select waiter) } + + unregisterFn := do + notify.state.atomically do + let st ← get + + let consumers ← st.consumers.filterM fun + | .normal _ => return true + | .select waiter => return !(← waiter.checkFinished) + + set { st with consumers } +} + +end Notify +end Std diff --git a/tests/lean/run/sync_notify.lean b/tests/lean/run/sync_notify.lean new file mode 100644 index 0000000000..0f091c3e52 --- /dev/null +++ b/tests/lean/run/sync_notify.lean @@ -0,0 +1,108 @@ +import Std.Internal.Async +import Std.Sync + +open Std.Internal.IO Async + +-- Test basic wait and notifyOne functionality +def testBasicWaitNotifyOne : Async Unit := do + let notify ← Std.Notify.new + let waitTask ← notify.wait + + assert! (← waitTask.getState) = .waiting + discard <| notify.notifyOne + await waitTask + assert! (← waitTask.getState) = .finished + +#eval testBasicWaitNotifyOne.block + +-- Test multiple waiters with notifyOne (only one should be notified) +def testMultipleWaitersNotifyOne : Async Unit := do + let notify ← Std.Notify.new + let task1 ← notify.wait + let task2 ← notify.wait + let task3 ← notify.wait + + assert! (← task1.getState) = .waiting + assert! (← task2.getState) = .waiting + assert! (← task3.getState) = .waiting + + discard <| notify.notifyOne + + IO.sleep 100 + + let states ← [task1, task2, task3].mapM (fun t => t.getState) + let finishedCount := states.filter (· == .finished) |>.length + let waitingCount := states.filter (· == .waiting) |>.length + + assert! finishedCount == 1 + assert! waitingCount == 2 + + discard <| notify.notifyOne + +#eval testMultipleWaitersNotifyOne.block + +-- Test multiple waiters with notify (all should be notified) +def testMultipleWaitersNotifyAll : Async Unit := do + let notify ← Std.Notify.new + let task1 ← notify.wait + let task2 ← notify.wait + let task3 ← notify.wait + + assert! (← task1.getState) = .waiting + assert! (← task2.getState) = .waiting + assert! (← task3.getState) = .waiting + + discard <| notify.notify + + await task1 + await task2 + await task3 + + assert! (← task1.getState) = .finished + assert! (← task2.getState) = .finished + assert! (← task3.getState) = .finished + +#eval testMultipleWaitersNotifyAll.block + +-- Test sequential notification +def testSequentialNotification : Async Unit := do + let notify ← Std.Notify.new + let task1 ← notify.wait + let task2 ← notify.wait + let task3 ← notify.wait + + discard <| notify.notifyOne + await task1 + assert! (← task1.getState) = .finished + assert! (← task2.getState) = .waiting + assert! (← task3.getState) = .waiting + + discard <| notify.notifyOne + await task2 + assert! (← task2.getState) = .finished + assert! (← task3.getState) = .waiting + + discard <| notify.notifyOne + await task3 + assert! (← task3.getState) = .finished + +#eval testSequentialNotification.block + +def testReuseAfterCompletion : Async Unit := do + let notify ← Std.Notify.new + + let task1 ← notify.wait + discard <| notify.notifyOne + await task1 + assert! (← task1.getState) = .finished + + let task2 ← notify.wait + let task3 ← notify.wait + discard <| notify.notify + await task2 + await task3 + + assert! (← task2.getState) = .finished + assert! (← task3.getState) = .finished + +#eval testReuseAfterCompletion.block