feat: add Std.CancellationToken type (#10510)
This PR adds a `Std.CancellationToken` type
This commit is contained in:
parent
ad701b577b
commit
6964a15b5d
4 changed files with 413 additions and 2 deletions
|
|
@ -15,5 +15,6 @@ public import Std.Sync.SharedMutex
|
|||
public import Std.Sync.Notify
|
||||
public import Std.Sync.Broadcast
|
||||
public import Std.Sync.StreamMap
|
||||
public import Std.Sync.CancellationToken
|
||||
|
||||
@[expose] public section
|
||||
|
|
|
|||
154
src/Std/Sync/CancellationToken.lean
Normal file
154
src/Std/Sync/CancellationToken.lean
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
/-
|
||||
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
|
||||
Released under Apache 2.0 license as described in the file LICENSE.
|
||||
Authors: Sofia Rodrigues
|
||||
-/
|
||||
module
|
||||
|
||||
prelude
|
||||
public import Std.Data
|
||||
public import Init.System.Promise
|
||||
public import Init.Data.Queue
|
||||
public import Std.Sync.Mutex
|
||||
public import Std.Internal.Async.Select
|
||||
|
||||
public section
|
||||
|
||||
/-!
|
||||
This module contains the implementation of `Std.CancellationToken`. `Std.CancellationToken` provides a
|
||||
cancellation primitive for signaling cancellation between tasks or threads. It supports both synchronous
|
||||
and asynchronous waiting, and is useful for cases where you want to notify one or more waiters
|
||||
that a cancellation has occurred.
|
||||
-/
|
||||
|
||||
namespace Std
|
||||
open Std.Internal.IO.Async
|
||||
|
||||
inductive CancellationToken.Consumer where
|
||||
| normal (promise : IO.Promise Unit)
|
||||
| select (finished : Waiter Unit)
|
||||
|
||||
def CancellationToken.Consumer.resolve (c : Consumer) : BaseIO Bool := do
|
||||
match c with
|
||||
| .normal promise =>
|
||||
promise.resolve ()
|
||||
return true
|
||||
| .select waiter =>
|
||||
let lose := return false
|
||||
let win promise := do
|
||||
promise.resolve (.ok ())
|
||||
return true
|
||||
waiter.race lose win
|
||||
|
||||
/--
|
||||
The central state structure for a `CancellationToken`.
|
||||
-/
|
||||
structure CancellationToken.State where
|
||||
/--
|
||||
Whether this token has been cancelled.
|
||||
-/
|
||||
cancelled : Bool
|
||||
|
||||
/--
|
||||
Consumers that are blocked waiting for cancellation.
|
||||
--/
|
||||
consumers : Std.Queue (CancellationToken.Consumer)
|
||||
|
||||
/--
|
||||
A cancellation token is a synchronization primitive that allows multiple consumers to wait
|
||||
until cancellation is requested.
|
||||
-/
|
||||
structure CancellationToken where
|
||||
state : Std.Mutex CancellationToken.State
|
||||
|
||||
namespace CancellationToken
|
||||
|
||||
/--
|
||||
Create a new cancellation token.
|
||||
-/
|
||||
def new : BaseIO CancellationToken := do
|
||||
return { state := ← Std.Mutex.new { cancelled := false, consumers := ∅ } }
|
||||
|
||||
/--
|
||||
Cancel the token, notifying all currently waiting consumers with `true`.
|
||||
Once cancelled, the token remains cancelled.
|
||||
-/
|
||||
def cancel (x : CancellationToken) : BaseIO Unit := do
|
||||
x.state.atomically do
|
||||
let mut st ← get
|
||||
|
||||
if st.cancelled then
|
||||
return
|
||||
|
||||
let mut remainingConsumers := st.consumers
|
||||
st := { cancelled := true, consumers := ∅ }
|
||||
|
||||
while true do
|
||||
if let some (consumer, rest) := remainingConsumers.dequeue? then
|
||||
remainingConsumers := rest
|
||||
discard <| consumer.resolve
|
||||
else
|
||||
break
|
||||
|
||||
set st
|
||||
|
||||
/--
|
||||
Check if the token is cancelled.
|
||||
-/
|
||||
def isCancelled (x : CancellationToken) : BaseIO Bool := do
|
||||
x.state.atomically do
|
||||
let st ← get
|
||||
return st.cancelled
|
||||
|
||||
/--
|
||||
Wait for cancellation. Returns a task that completes when cancelled,
|
||||
-/
|
||||
def wait (x : CancellationToken) : IO (AsyncTask Unit) :=
|
||||
x.state.atomically do
|
||||
let st ← get
|
||||
|
||||
if st.cancelled then
|
||||
return Task.pure (.ok ())
|
||||
|
||||
let promise ← IO.Promise.new
|
||||
|
||||
modify fun st => { st with consumers := st.consumers.enqueue (.normal promise) }
|
||||
|
||||
IO.bindTask promise.result? fun
|
||||
| some _ => pure <| Task.pure (.ok ())
|
||||
| none => throw (IO.userError "cancellation token dropped")
|
||||
|
||||
/--
|
||||
Creates a selector that waits for cancellation
|
||||
-/
|
||||
def selector (token : CancellationToken) : Selector Unit := {
|
||||
tryFn := do
|
||||
if ← token.isCancelled then
|
||||
return some ()
|
||||
else
|
||||
return none
|
||||
|
||||
registerFn := fun waiter => do
|
||||
token.state.atomically do
|
||||
let st ← get
|
||||
|
||||
if st.cancelled then
|
||||
discard <| waiter.race (return false) (fun promise => do
|
||||
promise.resolve (.ok ())
|
||||
return true)
|
||||
else
|
||||
modify fun st => { st with consumers := st.consumers.enqueue (.select waiter) }
|
||||
|
||||
unregisterFn := do
|
||||
token.state.atomically do
|
||||
let st ← get
|
||||
|
||||
let consumers ← st.consumers.filterM fun
|
||||
| .normal _ => return true
|
||||
| .select waiter => return !(← waiter.checkFinished)
|
||||
|
||||
set { st with consumers }
|
||||
}
|
||||
|
||||
end CancellationToken
|
||||
end Std
|
||||
256
tests/lean/run/async_cancellation.lean
Normal file
256
tests/lean/run/async_cancellation.lean
Normal file
|
|
@ -0,0 +1,256 @@
|
|||
import Std.Internal.Async
|
||||
import Std.Sync
|
||||
|
||||
open Std.Internal.IO Async
|
||||
|
||||
def cancellableSelector [Monad m] [MonadLift IO m] [MonadAsync AsyncTask m] (fn : Std.CancellationToken → m α) : m (Selector (Except IO.Error α)) := do
|
||||
let signal ← Std.CancellationToken.new
|
||||
let promise ← IO.Promise.new
|
||||
let result : AsyncTask α ← async (fn signal)
|
||||
|
||||
IO.chainTask result (promise.resolve ·)
|
||||
|
||||
return {
|
||||
tryFn := do
|
||||
if ← promise.isResolved
|
||||
then return promise.result!.get
|
||||
else return none
|
||||
|
||||
registerFn := fun waiter => do
|
||||
discard <| IO.mapTask (t := promise.result?) fun
|
||||
| none => pure ()
|
||||
| some res => do
|
||||
if ¬ (← signal.isCancelled) then
|
||||
waiter.race (pure ()) (·.resolve (.ok res))
|
||||
|
||||
unregisterFn := do
|
||||
signal.cancel
|
||||
}
|
||||
|
||||
-- Test basic cancellation token creation and cancellation
|
||||
def testBasicCancellation : Async Unit := do
|
||||
let token ← Std.CancellationToken.new
|
||||
assert! not (← token.isCancelled)
|
||||
token.cancel
|
||||
assert! (← token.isCancelled)
|
||||
|
||||
#eval testBasicCancellation.block
|
||||
|
||||
-- Test selector functionality
|
||||
def testSelector : Async Unit := do
|
||||
let token ← Std.CancellationToken.new
|
||||
let completed ← Std.Mutex.new false
|
||||
|
||||
let task ← async do
|
||||
Selectable.one #[.case token.selector (fun _ => pure ())]
|
||||
completed.atomically (set true)
|
||||
|
||||
assert! not (← completed.atomically get)
|
||||
|
||||
token.cancel
|
||||
await task
|
||||
|
||||
assert! (← completed.atomically get)
|
||||
|
||||
#eval testSelector.block
|
||||
|
||||
-- Test selector with already cancelled token
|
||||
def testSelectorAlreadyCancelled : Async Unit := do
|
||||
let token ← Std.CancellationToken.new
|
||||
token.cancel
|
||||
|
||||
let completed ← Std.Mutex.new false
|
||||
|
||||
let task ← async do
|
||||
Selectable.one #[.case token.selector pure]
|
||||
completed.atomically (set true)
|
||||
|
||||
await task
|
||||
assert! (← completed.atomically get)
|
||||
|
||||
#eval testSelectorAlreadyCancelled.block
|
||||
|
||||
-- Test multiple selectors on same token
|
||||
def testMultipleSelectors : Async Unit := do
|
||||
let token ← Std.CancellationToken.new
|
||||
let completed1 ← Std.Mutex.new false
|
||||
let completed2 ← Std.Mutex.new false
|
||||
let completed3 ← Std.Mutex.new false
|
||||
|
||||
let task1 ← async do
|
||||
Selectable.one #[.case token.selector pure]
|
||||
completed1.atomically (set true)
|
||||
|
||||
let task2 ← async do
|
||||
Selectable.one #[.case token.selector pure]
|
||||
completed2.atomically (set true)
|
||||
|
||||
let task3 ← async do
|
||||
Selectable.one #[.case token.selector pure]
|
||||
completed3.atomically (set true)
|
||||
|
||||
-- Verify none completed initially
|
||||
assert! not (← completed1.atomically get)
|
||||
assert! not (← completed2.atomically get)
|
||||
assert! not (← completed3.atomically get)
|
||||
|
||||
-- Cancel token
|
||||
token.cancel
|
||||
|
||||
-- Wait for all tasks to complete
|
||||
await task1
|
||||
await task2
|
||||
await task3
|
||||
|
||||
-- Verify all completed
|
||||
assert! (← completed1.atomically get)
|
||||
assert! (← completed2.atomically get)
|
||||
assert! (← completed3.atomically get)
|
||||
|
||||
#eval testMultipleSelectors.block
|
||||
|
||||
-- Test cancellation during async operations
|
||||
def testCancellationDuringOperation : Async Unit := do
|
||||
let token ← Std.CancellationToken.new
|
||||
let operationStarted ← Std.Mutex.new false
|
||||
let operationCompleted ← Std.Mutex.new false
|
||||
let operationCancelled ← Std.Mutex.new false
|
||||
|
||||
let task ← async do
|
||||
operationStarted.atomically (set true)
|
||||
try
|
||||
for _ in List.range 100 do
|
||||
if (← token.isCancelled) then
|
||||
operationCancelled.atomically (set true)
|
||||
return
|
||||
Async.sleep 5
|
||||
operationCompleted.atomically (set true)
|
||||
catch _ =>
|
||||
operationCancelled.atomically (set true)
|
||||
|
||||
-- Wait for operation to start
|
||||
while not (← operationStarted.atomically get) do
|
||||
Async.sleep 1
|
||||
|
||||
-- Cancel after operation started
|
||||
Async.sleep 20
|
||||
token.cancel
|
||||
|
||||
await task
|
||||
|
||||
-- Verify operation was cancelled, not completed
|
||||
assert! (← operationStarted.atomically get)
|
||||
assert! (← operationCancelled.atomically get)
|
||||
assert! not (← operationCompleted.atomically get)
|
||||
|
||||
#eval testCancellationDuringOperation.block
|
||||
|
||||
-- Test token reuse (create new tokens)
|
||||
def testTokenReuse : Async Unit := do
|
||||
let token1 ← Std.CancellationToken.new
|
||||
|
||||
-- First use
|
||||
token1.cancel
|
||||
assert! (← token1.isCancelled)
|
||||
|
||||
-- Create new token for second use
|
||||
let token2 ← Std.CancellationToken.new
|
||||
assert! not (← token2.isCancelled)
|
||||
|
||||
token2.cancel
|
||||
assert! (← token2.isCancelled)
|
||||
|
||||
#eval testTokenReuse.block
|
||||
|
||||
-- Test performance with many tokens
|
||||
def testManyTokens : Async Unit := do
|
||||
let tokens : Array Std.CancellationToken ← (Array.range 100).mapM (fun _ => Std.CancellationToken.new)
|
||||
|
||||
-- All should start unresolved
|
||||
for token in tokens do
|
||||
assert! not (← token.isCancelled)
|
||||
|
||||
-- Cancel all tokens
|
||||
for token in tokens do
|
||||
token.cancel
|
||||
|
||||
-- Verify all are cancelled
|
||||
for token in tokens do
|
||||
assert! (← token.isCancelled)
|
||||
|
||||
#eval testManyTokens.block
|
||||
|
||||
-- Test cooperative cancellation pattern
|
||||
def cooperativeWork (token : Std.CancellationToken) (workDone : Std.Mutex Nat) : Async Unit := do
|
||||
for _ in List.range 50 do
|
||||
-- Check for cancellation before each work unit
|
||||
if ← token.isCancelled then
|
||||
return
|
||||
|
||||
-- Do some work
|
||||
workDone.atomically (modify (· + 1))
|
||||
Async.sleep 10
|
||||
|
||||
def testCooperativeCancellation : Async Unit := do
|
||||
let token ← Std.CancellationToken.new
|
||||
let workDone ← Std.Mutex.new 0
|
||||
|
||||
-- Start cooperative work
|
||||
let workTask ← async (cooperativeWork token workDone)
|
||||
|
||||
-- Let some work happen
|
||||
Async.sleep 150
|
||||
|
||||
-- Cancel the work
|
||||
token.cancel
|
||||
|
||||
await workTask
|
||||
|
||||
-- Verify some but not all work was done
|
||||
let finalCount ← workDone.atomically get
|
||||
assert! finalCount > 0
|
||||
assert! finalCount < 50
|
||||
|
||||
#eval testCooperativeCancellation.block
|
||||
|
||||
-- Test selector with other operations
|
||||
def testSelectorMixed : Async Unit := do
|
||||
let token ← Std.CancellationToken.new
|
||||
let result ← Std.Mutex.new ""
|
||||
|
||||
let task ← async do
|
||||
let selected ← Selectable.one #[
|
||||
.case token.selector (fun _ => pure "cancelled")
|
||||
]
|
||||
result.atomically (set selected)
|
||||
|
||||
-- Race between promise resolution and cancellation
|
||||
Async.sleep 50
|
||||
token.cancel
|
||||
|
||||
await task
|
||||
|
||||
let finalResult ← result.atomically get
|
||||
assert! finalResult == "cancelled"
|
||||
|
||||
#eval testSelectorMixed.block
|
||||
|
||||
-- Test immediate cancellation
|
||||
def testImmediateCancellation : Async Unit := do
|
||||
let token ← Std.CancellationToken.new
|
||||
|
||||
-- Cancel immediately
|
||||
token.cancel
|
||||
|
||||
-- Should be resolved right away
|
||||
assert! (← token.isCancelled)
|
||||
|
||||
-- Selector should work with already cancelled token
|
||||
let task ← async do
|
||||
Selectable.one #[.case token.selector pure]
|
||||
return "done"
|
||||
|
||||
let result ← await task
|
||||
assert! result == "done"
|
||||
|
||||
#eval testImmediateCancellation.block
|
||||
|
|
@ -231,10 +231,10 @@ def testRecvOnEmpty : Async Unit := do
|
|||
assert! (← IO.getTaskState recv) == IO.TaskState.waiting
|
||||
|
||||
let result ← await (← channel.send 3)
|
||||
let result ← await recv
|
||||
|
||||
assert! (← IO.getTaskState recv) == IO.TaskState.finished
|
||||
|
||||
assert! recv.get == some 3
|
||||
assert! result == some 3
|
||||
|
||||
#eval testRecvOnEmpty.block
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue