From 6964a15b5d7a4dd36fee299e650a9c249854645f Mon Sep 17 00:00:00 2001 From: Sofia Rodrigues Date: Tue, 7 Oct 2025 00:21:45 -0300 Subject: [PATCH] feat: add `Std.CancellationToken` type (#10510) This PR adds a `Std.CancellationToken` type --- src/Std/Sync.lean | 1 + src/Std/Sync/CancellationToken.lean | 154 +++++++++++++++ tests/lean/run/async_cancellation.lean | 256 +++++++++++++++++++++++++ tests/lean/run/broadcast.lean | 4 +- 4 files changed, 413 insertions(+), 2 deletions(-) create mode 100644 src/Std/Sync/CancellationToken.lean create mode 100644 tests/lean/run/async_cancellation.lean diff --git a/src/Std/Sync.lean b/src/Std/Sync.lean index 3514226d34..65313a6ed7 100644 --- a/src/Std/Sync.lean +++ b/src/Std/Sync.lean @@ -15,5 +15,6 @@ public import Std.Sync.SharedMutex public import Std.Sync.Notify public import Std.Sync.Broadcast public import Std.Sync.StreamMap +public import Std.Sync.CancellationToken @[expose] public section diff --git a/src/Std/Sync/CancellationToken.lean b/src/Std/Sync/CancellationToken.lean new file mode 100644 index 0000000000..8ad95e6bb6 --- /dev/null +++ b/src/Std/Sync/CancellationToken.lean @@ -0,0 +1,154 @@ +/- +Copyright (c) 2025 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Sofia Rodrigues +-/ +module + +prelude +public import Std.Data +public import Init.System.Promise +public import Init.Data.Queue +public import Std.Sync.Mutex +public import Std.Internal.Async.Select + +public section + +/-! +This module contains the implementation of `Std.CancellationToken`. `Std.CancellationToken` provides a +cancellation primitive for signaling cancellation between tasks or threads. It supports both synchronous +and asynchronous waiting, and is useful for cases where you want to notify one or more waiters +that a cancellation has occurred. +-/ + +namespace Std +open Std.Internal.IO.Async + +inductive CancellationToken.Consumer where + | normal (promise : IO.Promise Unit) + | select (finished : Waiter Unit) + +def CancellationToken.Consumer.resolve (c : Consumer) : BaseIO Bool := do + match c with + | .normal promise => + promise.resolve () + return true + | .select waiter => + let lose := return false + let win promise := do + promise.resolve (.ok ()) + return true + waiter.race lose win + +/-- +The central state structure for a `CancellationToken`. +-/ +structure CancellationToken.State where + /-- + Whether this token has been cancelled. + -/ + cancelled : Bool + + /-- + Consumers that are blocked waiting for cancellation. + --/ + consumers : Std.Queue (CancellationToken.Consumer) + +/-- +A cancellation token is a synchronization primitive that allows multiple consumers to wait +until cancellation is requested. +-/ +structure CancellationToken where + state : Std.Mutex CancellationToken.State + +namespace CancellationToken + +/-- +Create a new cancellation token. +-/ +def new : BaseIO CancellationToken := do + return { state := ← Std.Mutex.new { cancelled := false, consumers := ∅ } } + +/-- +Cancel the token, notifying all currently waiting consumers with `true`. +Once cancelled, the token remains cancelled. +-/ +def cancel (x : CancellationToken) : BaseIO Unit := do + x.state.atomically do + let mut st ← get + + if st.cancelled then + return + + let mut remainingConsumers := st.consumers + st := { cancelled := true, consumers := ∅ } + + while true do + if let some (consumer, rest) := remainingConsumers.dequeue? then + remainingConsumers := rest + discard <| consumer.resolve + else + break + + set st + +/-- +Check if the token is cancelled. +-/ +def isCancelled (x : CancellationToken) : BaseIO Bool := do + x.state.atomically do + let st ← get + return st.cancelled + +/-- +Wait for cancellation. Returns a task that completes when cancelled, +-/ +def wait (x : CancellationToken) : IO (AsyncTask Unit) := + x.state.atomically do + let st ← get + + if st.cancelled then + return Task.pure (.ok ()) + + let promise ← IO.Promise.new + + modify fun st => { st with consumers := st.consumers.enqueue (.normal promise) } + + IO.bindTask promise.result? fun + | some _ => pure <| Task.pure (.ok ()) + | none => throw (IO.userError "cancellation token dropped") + +/-- +Creates a selector that waits for cancellation +-/ +def selector (token : CancellationToken) : Selector Unit := { + tryFn := do + if ← token.isCancelled then + return some () + else + return none + + registerFn := fun waiter => do + token.state.atomically do + let st ← get + + if st.cancelled then + discard <| waiter.race (return false) (fun promise => do + promise.resolve (.ok ()) + return true) + else + modify fun st => { st with consumers := st.consumers.enqueue (.select waiter) } + + unregisterFn := do + token.state.atomically do + let st ← get + + let consumers ← st.consumers.filterM fun + | .normal _ => return true + | .select waiter => return !(← waiter.checkFinished) + + set { st with consumers } +} + +end CancellationToken +end Std diff --git a/tests/lean/run/async_cancellation.lean b/tests/lean/run/async_cancellation.lean new file mode 100644 index 0000000000..90a162b60c --- /dev/null +++ b/tests/lean/run/async_cancellation.lean @@ -0,0 +1,256 @@ +import Std.Internal.Async +import Std.Sync + +open Std.Internal.IO Async + +def cancellableSelector [Monad m] [MonadLift IO m] [MonadAsync AsyncTask m] (fn : Std.CancellationToken → m α) : m (Selector (Except IO.Error α)) := do + let signal ← Std.CancellationToken.new + let promise ← IO.Promise.new + let result : AsyncTask α ← async (fn signal) + + IO.chainTask result (promise.resolve ·) + + return { + tryFn := do + if ← promise.isResolved + then return promise.result!.get + else return none + + registerFn := fun waiter => do + discard <| IO.mapTask (t := promise.result?) fun + | none => pure () + | some res => do + if ¬ (← signal.isCancelled) then + waiter.race (pure ()) (·.resolve (.ok res)) + + unregisterFn := do + signal.cancel + } + +-- Test basic cancellation token creation and cancellation +def testBasicCancellation : Async Unit := do + let token ← Std.CancellationToken.new + assert! not (← token.isCancelled) + token.cancel + assert! (← token.isCancelled) + +#eval testBasicCancellation.block + +-- Test selector functionality +def testSelector : Async Unit := do + let token ← Std.CancellationToken.new + let completed ← Std.Mutex.new false + + let task ← async do + Selectable.one #[.case token.selector (fun _ => pure ())] + completed.atomically (set true) + + assert! not (← completed.atomically get) + + token.cancel + await task + + assert! (← completed.atomically get) + +#eval testSelector.block + +-- Test selector with already cancelled token +def testSelectorAlreadyCancelled : Async Unit := do + let token ← Std.CancellationToken.new + token.cancel + + let completed ← Std.Mutex.new false + + let task ← async do + Selectable.one #[.case token.selector pure] + completed.atomically (set true) + + await task + assert! (← completed.atomically get) + +#eval testSelectorAlreadyCancelled.block + +-- Test multiple selectors on same token +def testMultipleSelectors : Async Unit := do + let token ← Std.CancellationToken.new + let completed1 ← Std.Mutex.new false + let completed2 ← Std.Mutex.new false + let completed3 ← Std.Mutex.new false + + let task1 ← async do + Selectable.one #[.case token.selector pure] + completed1.atomically (set true) + + let task2 ← async do + Selectable.one #[.case token.selector pure] + completed2.atomically (set true) + + let task3 ← async do + Selectable.one #[.case token.selector pure] + completed3.atomically (set true) + + -- Verify none completed initially + assert! not (← completed1.atomically get) + assert! not (← completed2.atomically get) + assert! not (← completed3.atomically get) + + -- Cancel token + token.cancel + + -- Wait for all tasks to complete + await task1 + await task2 + await task3 + + -- Verify all completed + assert! (← completed1.atomically get) + assert! (← completed2.atomically get) + assert! (← completed3.atomically get) + +#eval testMultipleSelectors.block + +-- Test cancellation during async operations +def testCancellationDuringOperation : Async Unit := do + let token ← Std.CancellationToken.new + let operationStarted ← Std.Mutex.new false + let operationCompleted ← Std.Mutex.new false + let operationCancelled ← Std.Mutex.new false + + let task ← async do + operationStarted.atomically (set true) + try + for _ in List.range 100 do + if (← token.isCancelled) then + operationCancelled.atomically (set true) + return + Async.sleep 5 + operationCompleted.atomically (set true) + catch _ => + operationCancelled.atomically (set true) + + -- Wait for operation to start + while not (← operationStarted.atomically get) do + Async.sleep 1 + + -- Cancel after operation started + Async.sleep 20 + token.cancel + + await task + + -- Verify operation was cancelled, not completed + assert! (← operationStarted.atomically get) + assert! (← operationCancelled.atomically get) + assert! not (← operationCompleted.atomically get) + +#eval testCancellationDuringOperation.block + +-- Test token reuse (create new tokens) +def testTokenReuse : Async Unit := do + let token1 ← Std.CancellationToken.new + + -- First use + token1.cancel + assert! (← token1.isCancelled) + + -- Create new token for second use + let token2 ← Std.CancellationToken.new + assert! not (← token2.isCancelled) + + token2.cancel + assert! (← token2.isCancelled) + +#eval testTokenReuse.block + +-- Test performance with many tokens +def testManyTokens : Async Unit := do + let tokens : Array Std.CancellationToken ← (Array.range 100).mapM (fun _ => Std.CancellationToken.new) + + -- All should start unresolved + for token in tokens do + assert! not (← token.isCancelled) + + -- Cancel all tokens + for token in tokens do + token.cancel + + -- Verify all are cancelled + for token in tokens do + assert! (← token.isCancelled) + +#eval testManyTokens.block + +-- Test cooperative cancellation pattern +def cooperativeWork (token : Std.CancellationToken) (workDone : Std.Mutex Nat) : Async Unit := do + for _ in List.range 50 do + -- Check for cancellation before each work unit + if ← token.isCancelled then + return + + -- Do some work + workDone.atomically (modify (· + 1)) + Async.sleep 10 + +def testCooperativeCancellation : Async Unit := do + let token ← Std.CancellationToken.new + let workDone ← Std.Mutex.new 0 + + -- Start cooperative work + let workTask ← async (cooperativeWork token workDone) + + -- Let some work happen + Async.sleep 150 + + -- Cancel the work + token.cancel + + await workTask + + -- Verify some but not all work was done + let finalCount ← workDone.atomically get + assert! finalCount > 0 + assert! finalCount < 50 + +#eval testCooperativeCancellation.block + +-- Test selector with other operations +def testSelectorMixed : Async Unit := do + let token ← Std.CancellationToken.new + let result ← Std.Mutex.new "" + + let task ← async do + let selected ← Selectable.one #[ + .case token.selector (fun _ => pure "cancelled") + ] + result.atomically (set selected) + + -- Race between promise resolution and cancellation + Async.sleep 50 + token.cancel + + await task + + let finalResult ← result.atomically get + assert! finalResult == "cancelled" + +#eval testSelectorMixed.block + +-- Test immediate cancellation +def testImmediateCancellation : Async Unit := do + let token ← Std.CancellationToken.new + + -- Cancel immediately + token.cancel + + -- Should be resolved right away + assert! (← token.isCancelled) + + -- Selector should work with already cancelled token + let task ← async do + Selectable.one #[.case token.selector pure] + return "done" + + let result ← await task + assert! result == "done" + +#eval testImmediateCancellation.block diff --git a/tests/lean/run/broadcast.lean b/tests/lean/run/broadcast.lean index 161c083035..d39d94b863 100644 --- a/tests/lean/run/broadcast.lean +++ b/tests/lean/run/broadcast.lean @@ -231,10 +231,10 @@ def testRecvOnEmpty : Async Unit := do assert! (← IO.getTaskState recv) == IO.TaskState.waiting let result ← await (← channel.send 3) + let result ← await recv assert! (← IO.getTaskState recv) == IO.TaskState.finished - - assert! recv.get == some 3 + assert! result == some 3 #eval testRecvOnEmpty.block