feat: add Std.Notify type (#10368)
This PR adds `Notify` that is a structure that is similar to `CondVar`
but it's used for concurrency. The main difference between
`Std.Sync.Notify` and `Std.Condvar` is that depends on a `Std.Mutex` and
blocks the entire thread that the `Task` is using while waiting. If I
try to use it with async and a lot of `Task`s like this:
```lean
def condvar : Async Unit := do
let condvar ← Std.Condvar.new
let mutex ← Std.Mutex.new false
for i in [0:threads] do
background do
IO.println s!"start {i + 1}"
await =<< (show IO (ETask _ _) from IO.asTask (mutex.atomically (condvar.wait mutex)))
IO.println s!"end {i + 1}"
IO.sleep 2000
condvar.notifyAll
```
It causes some weird behavior because some tasks start running and get
notified, while others don’t, because `condvar.wait` blocks the `Task`
entire task and right now afaik it blocks an entire thread and cannot be
paused while doing blocking operations like that.
`Notify` uses `Promise`s so it’s better suited for concurrency. The
`Task` is not blocked while waiting for a notification which makes it
simpler for use cases that just involve notifying:
```lean
def notify : Async Unit := do
let notify ← Std.Notify.new
for i in [0:threads] do
background do
IO.println s!"start {i}"
notify.wait
IO.println s!"end {i}"
IO.sleep 2000
notify.notify
```
This PR depends on: #10366, #10367 and #10370.
This commit is contained in:
parent
781e3c6add
commit
161a1c06a2
3 changed files with 249 additions and 0 deletions
|
|
@ -12,5 +12,6 @@ public import Std.Sync.Mutex
|
|||
public import Std.Sync.RecursiveMutex
|
||||
public import Std.Sync.Barrier
|
||||
public import Std.Sync.SharedMutex
|
||||
public import Std.Sync.Notify
|
||||
|
||||
@[expose] public section
|
||||
|
|
|
|||
140
src/Std/Sync/Notify.lean
Normal file
140
src/Std/Sync/Notify.lean
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
/-
|
||||
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Sofia Rodrigues
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Init.System.Promise
|
||||
public import Init.Data.Queue
|
||||
public import Std.Sync.Mutex
|
||||
public import Std.Internal.Async.Select
|
||||
|
||||
public section
|
||||
|
||||
/-!
|
||||
This module contains the implementation of `Std.Notify`. `Std.Notify` provides a lightweight
|
||||
notification primitive for signaling between tasks or threads. It supports both synchronous
|
||||
and asynchronous waiting, and is useful for cases where you want to notify one or more waiters
|
||||
that an event has occurred.
|
||||
|
||||
Unlike a channel, `Std.Notify` does not buffer messages or carry data. It's simply a trigger.
|
||||
If no one is waiting, notifications are lost. If one or more waiters are present, exactly one
|
||||
will be woken up per notification.
|
||||
-/
|
||||
|
||||
namespace Std
|
||||
open Std.Internal.IO.Async
|
||||
|
||||
inductive Notify.Consumer (α : Type) where
|
||||
| normal (promise : IO.Promise α)
|
||||
| select (finished : Waiter α)
|
||||
|
||||
def Notify.Consumer.resolve (c : Consumer α) (x : α) : BaseIO Bool := do
|
||||
match c with
|
||||
| .normal promise =>
|
||||
promise.resolve x
|
||||
return true
|
||||
| .select waiter =>
|
||||
let lose := return false
|
||||
let win promise := do
|
||||
promise.resolve (.ok x)
|
||||
return true
|
||||
waiter.race lose win
|
||||
|
||||
/--
|
||||
The central state structure for an a `Notify`.
|
||||
-/
|
||||
structure Notify.State where
|
||||
|
||||
/--
|
||||
Consumers that are blocked waiting for a notification.
|
||||
--/
|
||||
consumers : Std.Queue (Notify.Consumer Unit)
|
||||
|
||||
/--
|
||||
A notify is a synchronization primitive that allows multiple consumers to wait
|
||||
until notify is called.
|
||||
-/
|
||||
structure Notify where
|
||||
state : Std.Mutex Notify.State
|
||||
|
||||
namespace Notify
|
||||
|
||||
/--
|
||||
Create a new notify.
|
||||
-/
|
||||
def new : BaseIO Notify := do
|
||||
return { state := ← Std.Mutex.new { consumers := ∅ } }
|
||||
|
||||
/--
|
||||
Notify all currently waiting consumers.
|
||||
-/
|
||||
def notify (x : Notify) : BaseIO Unit := do
|
||||
x.state.atomically do
|
||||
let mut st ← get
|
||||
|
||||
let mut remainingConsumers := st.consumers
|
||||
st := { st with consumers := ∅ }
|
||||
|
||||
while true do
|
||||
if let some (consumer, rest) := remainingConsumers.dequeue? then
|
||||
remainingConsumers := rest
|
||||
discard <| consumer.resolve ()
|
||||
else
|
||||
break
|
||||
|
||||
set st
|
||||
|
||||
/--
|
||||
Notify exactly one waiting consumer (if any). Returns true if a consumer
|
||||
was notified, false if no consumers were waiting.
|
||||
-/
|
||||
def notifyOne (x : Notify) : BaseIO Bool := do
|
||||
x.state.atomically do
|
||||
let mut st ← get
|
||||
|
||||
if let some (consumer, rest) := st.consumers.dequeue? then
|
||||
st := { st with consumers := rest }
|
||||
set st
|
||||
consumer.resolve ()
|
||||
else
|
||||
return false
|
||||
|
||||
/--
|
||||
Wait to be notified. Returns a task that completes when notify is called.
|
||||
Note: if notify was called before wait, this will wait for the next notify call.
|
||||
-/
|
||||
def wait (x : Notify) : IO (AsyncTask Unit) :=
|
||||
x.state.atomically do
|
||||
let promise ← IO.Promise.new
|
||||
modify fun st => { st with consumers := st.consumers.enqueue (.normal promise) }
|
||||
IO.bindTask promise.result? fun
|
||||
| some res => pure <| Task.pure (.ok res)
|
||||
| none => throw (IO.userError "notify dropped")
|
||||
|
||||
/--
|
||||
Creates a selector that waits for notifications
|
||||
-/
|
||||
def selector (notify : Notify) : Selector Unit := {
|
||||
tryFn := do
|
||||
return none
|
||||
|
||||
registerFn := fun waiter => do
|
||||
notify.state.atomically do
|
||||
modify fun st => { st with consumers := st.consumers.enqueue (.select waiter) }
|
||||
|
||||
unregisterFn := do
|
||||
notify.state.atomically do
|
||||
let st ← get
|
||||
|
||||
let consumers ← st.consumers.filterM fun
|
||||
| .normal _ => return true
|
||||
| .select waiter => return !(← waiter.checkFinished)
|
||||
|
||||
set { st with consumers }
|
||||
}
|
||||
|
||||
end Notify
|
||||
end Std
|
||||
108
tests/lean/run/sync_notify.lean
Normal file
108
tests/lean/run/sync_notify.lean
Normal file
|
|
@ -0,0 +1,108 @@
|
|||
import Std.Internal.Async
|
||||
import Std.Sync
|
||||
|
||||
open Std.Internal.IO Async
|
||||
|
||||
-- Test basic wait and notifyOne functionality
|
||||
def testBasicWaitNotifyOne : Async Unit := do
|
||||
let notify ← Std.Notify.new
|
||||
let waitTask ← notify.wait
|
||||
|
||||
assert! (← waitTask.getState) = .waiting
|
||||
discard <| notify.notifyOne
|
||||
await waitTask
|
||||
assert! (← waitTask.getState) = .finished
|
||||
|
||||
#eval testBasicWaitNotifyOne.block
|
||||
|
||||
-- Test multiple waiters with notifyOne (only one should be notified)
|
||||
def testMultipleWaitersNotifyOne : Async Unit := do
|
||||
let notify ← Std.Notify.new
|
||||
let task1 ← notify.wait
|
||||
let task2 ← notify.wait
|
||||
let task3 ← notify.wait
|
||||
|
||||
assert! (← task1.getState) = .waiting
|
||||
assert! (← task2.getState) = .waiting
|
||||
assert! (← task3.getState) = .waiting
|
||||
|
||||
discard <| notify.notifyOne
|
||||
|
||||
IO.sleep 100
|
||||
|
||||
let states ← [task1, task2, task3].mapM (fun t => t.getState)
|
||||
let finishedCount := states.filter (· == .finished) |>.length
|
||||
let waitingCount := states.filter (· == .waiting) |>.length
|
||||
|
||||
assert! finishedCount == 1
|
||||
assert! waitingCount == 2
|
||||
|
||||
discard <| notify.notifyOne
|
||||
|
||||
#eval testMultipleWaitersNotifyOne.block
|
||||
|
||||
-- Test multiple waiters with notify (all should be notified)
|
||||
def testMultipleWaitersNotifyAll : Async Unit := do
|
||||
let notify ← Std.Notify.new
|
||||
let task1 ← notify.wait
|
||||
let task2 ← notify.wait
|
||||
let task3 ← notify.wait
|
||||
|
||||
assert! (← task1.getState) = .waiting
|
||||
assert! (← task2.getState) = .waiting
|
||||
assert! (← task3.getState) = .waiting
|
||||
|
||||
discard <| notify.notify
|
||||
|
||||
await task1
|
||||
await task2
|
||||
await task3
|
||||
|
||||
assert! (← task1.getState) = .finished
|
||||
assert! (← task2.getState) = .finished
|
||||
assert! (← task3.getState) = .finished
|
||||
|
||||
#eval testMultipleWaitersNotifyAll.block
|
||||
|
||||
-- Test sequential notification
|
||||
def testSequentialNotification : Async Unit := do
|
||||
let notify ← Std.Notify.new
|
||||
let task1 ← notify.wait
|
||||
let task2 ← notify.wait
|
||||
let task3 ← notify.wait
|
||||
|
||||
discard <| notify.notifyOne
|
||||
await task1
|
||||
assert! (← task1.getState) = .finished
|
||||
assert! (← task2.getState) = .waiting
|
||||
assert! (← task3.getState) = .waiting
|
||||
|
||||
discard <| notify.notifyOne
|
||||
await task2
|
||||
assert! (← task2.getState) = .finished
|
||||
assert! (← task3.getState) = .waiting
|
||||
|
||||
discard <| notify.notifyOne
|
||||
await task3
|
||||
assert! (← task3.getState) = .finished
|
||||
|
||||
#eval testSequentialNotification.block
|
||||
|
||||
def testReuseAfterCompletion : Async Unit := do
|
||||
let notify ← Std.Notify.new
|
||||
|
||||
let task1 ← notify.wait
|
||||
discard <| notify.notifyOne
|
||||
await task1
|
||||
assert! (← task1.getState) = .finished
|
||||
|
||||
let task2 ← notify.wait
|
||||
let task3 ← notify.wait
|
||||
discard <| notify.notify
|
||||
await task2
|
||||
await task3
|
||||
|
||||
assert! (← task2.getState) = .finished
|
||||
assert! (← task3.getState) = .finished
|
||||
|
||||
#eval testReuseAfterCompletion.block
|
||||
Loading…
Add table
Reference in a new issue