feat: introduce CancellationContext type for cancellation with context propagation (#11499)

This PR adds the `Context` type for cancellation with context
propagation. It works by storing a tree of forks of the main context,
providing a way to control cancellation.
This commit is contained in:
Sofia Rodrigues 2025-12-15 18:20:11 -03:00 committed by GitHub
parent 7b8e51e025
commit 95a7c769d8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 1593 additions and 14 deletions

View file

@ -7,6 +7,7 @@ module
prelude
public import Std.Internal.Async.Basic
public import Std.Internal.Async.ContextAsync
public import Std.Internal.Async.Timer
public import Std.Internal.Async.TCP
public import Std.Internal.Async.UDP

View file

@ -0,0 +1,273 @@
/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Sofia Rodrigues
-/
module
prelude
public import Std.Time
public import Std.Internal.UV
public import Std.Internal.Async.Basic
public import Std.Internal.Async.Timer
public import Std.Sync.CancellationContext
public section
/-!
This module contains the implementation of `ContextAsync`, a monad for asynchronous computations with
cooperative cancellation support that must be explicitly checked for and cancelled explicitly.
-/
namespace Std
namespace Internal
namespace IO
namespace Async
/--
An asynchronous computation with cooperative cancellation support via a `CancellationContext`. `ContextAsync α`
is equivalent to `ReaderT CancellationContext Async α`, providing a `CancellationContext` value to async
computations.
-/
abbrev ContextAsync (α : Type) := ReaderT CancellationContext Async α
namespace ContextAsync
/--
Runs a `ContextAsync` computation with a given context. See also `ContextAsync.run` for running with a new
context that automatically cancels after execution.
-/
@[inline]
protected def runIn (ctx : CancellationContext) (x : ContextAsync α) : Async α :=
x ctx
/--
Runs a `ContextAsync` computation with a new context that cancels after the execution of the computation.
See also `ContextAsync.runIn` for running with an existing context.
-/
@[inline]
protected def run (x : ContextAsync α) : Async α := do
let ctx ← CancellationContext.new
x ctx <* ctx.cancel .cancel
/--
Returns the current context for inspection or to pass to other functions.
-/
@[inline]
def getContext : ContextAsync CancellationContext :=
fun ctx => pure ctx
/--
Checks if the current context is cancelled. Returns `true` if the context (or any ancestor) has been cancelled.
Long-running operations should periodically check this and exit gracefully when cancelled.
-/
@[inline]
def isCancelled : ContextAsync Bool := do
let ctx ← getContext
ctx.isCancelled
/--
Gets the cancellation reason if the context is cancelled. Returns `some reason` if cancelled, `none` otherwise,
allowing you to distinguish between different cancellation types.
-/
@[inline]
def getCancellationReason : ContextAsync (Option CancellationReason) := do
let ctx ← getContext
ctx.getCancellationReason
/--
Cancels the current context with the given reason, cascading to all child contexts.
Cancellation is cooperative, operations must explicitly check `isCancelled` or use `awaitCancellation` to respond.
-/
@[inline]
def cancel (reason : CancellationReason) : ContextAsync Unit := do
let ctx ← getContext
ctx.cancel reason
/--
Returns a selector that completes when the current context is cancelled.
-/
@[inline]
def doneSelector : ContextAsync (Selector Unit) := do
let ctx ← getContext
return ctx.doneSelector
/--
Waits for the current context to be cancelled.
-/
@[inline]
def awaitCancellation : ContextAsync Unit := do
let ctx ← getContext
let task ← ctx.done
await task
/--
Runs two computations concurrently and returns both results. Each computation runs in its own child context;
if either fails or is cancelled, both are cancelled immediately and the exception is propagated.
-/
@[inline, specialize]
def concurrently (x : ContextAsync α) (y : ContextAsync β)
(prio := Task.Priority.default) : ContextAsync (α × β) := do
let ctx ← getContext
let concurrentCtx ← ctx.fork
let childCtx1 ← concurrentCtx.fork
let childCtx2 ← concurrentCtx.fork
let result ← Async.concurrently
(try x childCtx1 catch err => do concurrentCtx.cancel .cancel; throw err finally childCtx1.cancel .cancel)
(try y childCtx2 catch err => do concurrentCtx.cancel .cancel; throw err finally childCtx2.cancel .cancel)
prio
concurrentCtx.cancel .cancel
return result
/--
Runs two computations concurrently and returns the result of the first to complete. Each computation runs
in its own child context; when either completes, the other is cancelled immediately.
-/
@[inline, specialize]
def race [Inhabited α] (x : ContextAsync α) (y : ContextAsync α)
(prio := Task.Priority.default) : ContextAsync α := do
let parent ← getContext
let ctx1 ← CancellationContext.fork parent
let ctx2 ← CancellationContext.fork parent
let task1 ← async (x ctx1) prio
let task2 ← async (y ctx2) prio
let result ← Async.race
(await task1 <* ctx2.cancel .cancel)
(await task2 <* ctx1.cancel .cancel)
prio
pure result
/--
Runs all computations concurrently and collects results in the same order. Each runs in its own child context;
if any computation fails, all others are cancelled and the exception is propagated.
-/
@[inline, specialize]
def concurrentlyAll (xs : Array (ContextAsync α))
(prio := Task.Priority.default) : ContextAsync (Array α) := do
let ctx ← getContext
let concurrentCtx ← ctx.fork
let tasks : Array (AsyncTask α) ← xs.mapM fun ctxAsync => do
let childCtx ← concurrentCtx.fork
async (prio := prio)
(try
ctxAsync childCtx
catch err => do
concurrentCtx.cancel .cancel
throw err
finally
childCtx.cancel .cancel)
let result ← tasks.mapM await
return result
/--
Launches a `ContextAsync` computation in the background, discarding its result.
The computation runs independently in the background in its own child context. The parent computation does not wait
for background tasks to complete. This means that if the parent finishes its execution it will cause
the cancellation of the background functions. See also `disown` for launching tasks that continue independently
even after parent cancellation.
-/
@[inline, specialize]
def background (action : ContextAsync α) (prio := Task.Priority.default) : ContextAsync Unit := do
let ctx ← getContext
let childCtx ← ctx.fork
Async.background (action childCtx *> childCtx.cancel .cancel) prio
/--
Launches a `ContextAsync` computation in the background, discarding its result. It's similar to `background`,
but the child context is not automatically cancelled when the action completes. This allows the disowned
computation to continue running independently, even if the parent context is cancelled. The child context
will remain alive as long as the computation needs it. See also `background` for launching tasks that are
cancelled when the parent finishes.
-/
@[inline, specialize]
def disown (action : ContextAsync α) (prio := Task.Priority.default) : ContextAsync Unit := do
let childCtx ← CancellationContext.new
Async.background (action childCtx) prio
/--
Runs all computations concurrently and returns the first result. Each computation runs in its own child context;
when the first completes successfully, all others are cancelled immediately.
-/
def raceAll [ForM ContextAsync c (ContextAsync α)] (xs : c)
(prio := Task.Priority.default) : ContextAsync α := do
let parent ← getContext
let promise ← IO.Promise.new
ForM.forM xs fun x => do
let ctx ← CancellationContext.fork parent
let task ← async (x ctx) prio
background do
try
let result ← await task
promise.resolve (.ok result)
catch e =>
discard $ promise.resolve (.error e)
let result ← await promise
parent.cancel .cancel
Async.ofExcept result
/--
Launches a `ContextAsync` computation as an asynchronous task with a forked child context.
The child context is automatically cancelled when the task completes or fails.
-/
@[inline, specialize]
def async (x : ContextAsync α) (prio := Task.Priority.default) : ContextAsync (AsyncTask α) :=
fun ctx => do
let childCtx ← ctx.fork
Async.async (try x childCtx finally childCtx.cancel .cancel) prio
instance : MonadAsync AsyncTask ContextAsync where
async x prio := ContextAsync.async x prio
instance : Functor ContextAsync where
map f x := fun ctx => f <$> x ctx
instance : Monad ContextAsync where
pure a := fun _ => pure a
bind x f := fun ctx => x ctx >>= fun a => f a ctx
instance : MonadLift IO ContextAsync where
monadLift x := fun _ => Async.ofIOTask (Task.pure <$> x)
instance : MonadLift BaseIO ContextAsync where
monadLift x := fun _ => liftM (m := Async) x
instance : MonadExcept IO.Error ContextAsync where
throw e := fun _ => throw e
tryCatch x h := fun ctx => tryCatch (x ctx) (fun e => h e ctx)
instance : MonadFinally ContextAsync where
tryFinally' x f := fun ctx =>
tryFinally' (x ctx) (fun opt => f opt ctx)
instance [Inhabited α] : Inhabited (ContextAsync α) where
default := fun _ => default
instance : MonadAwait AsyncTask ContextAsync where
await t := fun _ => await t
end ContextAsync
/--
Returns a selector that completes when the current context is cancelled.
This is useful for selecting on cancellation alongside other asynchronous operations.
-/
def Selector.cancelled : ContextAsync (Selector Unit) := do
ContextAsync.doneSelector
end Async
end IO
end Internal
end Std

View file

@ -16,5 +16,6 @@ public import Std.Sync.Notify
public import Std.Sync.Broadcast
public import Std.Sync.StreamMap
public import Std.Sync.CancellationToken
public import Std.Sync.CancellationContext
@[expose] public section

View file

@ -0,0 +1,152 @@
/-
Copyright (c) 2025 Lean FRO, LLC. All rights reserved.
Released under Apache 2.0 license as described in the file LICENSE.
Authors: Sofia Rodrigues
-/
module
prelude
public import Std.Data
public import Init.System.Promise
public import Init.Data.Queue
public import Std.Sync.Mutex
public import Std.Sync.CancellationToken
public import Std.Internal.Async.Select
public section
/-!
This module provides a tree-structured cancellation context called `CancellationToken` where cancelling a parent
automatically cancels all child contexts.
-/
namespace Std
open Std.Internal.IO.Async
structure CancellationContext.State where
/--
Map of token IDs to optional tokens and their children.
-/
tokens : TreeMap UInt64 (CancellationToken × Array UInt64) := .empty
/--
Next available ID
-/
id : UInt64 := 1
/--
A cancellation context that allows multiple consumers to wait until cancellation is requested. Forms
a tree structure where cancelling a parent cancels all children.
-/
structure CancellationContext where
state : Std.Mutex CancellationContext.State
token : CancellationToken
id : UInt64
namespace CancellationContext
/--
Creates a new root cancellation context.
-/
def new : BaseIO CancellationContext := do
let token ← Std.CancellationToken.new
return {
state := ← Std.Mutex.new { tokens := .empty |>.insert 0 (token, #[]) },
token,
id := 0
}
/--
Forks a child context from a parent. If the parent is already cancelled, returns the parent context.
Otherwise, creates a new child that will be cancelled when the parent is cancelled.
-/
def fork (root : CancellationContext) : BaseIO CancellationContext := do
root.state.atomically do
if ← root.token.isCancelled then
return root
let token ← Std.CancellationToken.new
let st ← get
let newId := st.id
set { st with
id := newId + 1,
tokens := st.tokens.insert newId (token, #[])
|>.modify root.id (.map (·) (.push · newId))
}
return { state := root.state, token, id := newId }
/--
Recursively cancels a context and all its children with the given reason.
-/
private partial def cancelChildren (state : CancellationContext.State) (id : UInt64) (reason : CancellationReason) : BaseIO CancellationContext.State := do
let mut state := state
let some (token, children) := state.tokens.get? id
| return state
for tokenId in children do
state ← cancelChildren state tokenId reason
token.cancel reason
pure { state with tokens := state.tokens.erase id }
/--
Cancels this context and all child contexts with the given reason.
-/
def cancel (x : CancellationContext) (reason : CancellationReason) : BaseIO Unit := do
if ← x.token.isCancelled then
return
x.state.atomically do
let st ← get
let st ← cancelChildren st x.id reason
set st
/--
Checks if the context is cancelled.
-/
@[inline]
def isCancelled (x : CancellationContext) : BaseIO Bool := do
x.token.isCancelled
/--
Returns the cancellation reason if the context is cancelled.
-/
@[inline]
def getCancellationReason (x : CancellationContext) : BaseIO (Option CancellationReason) := do
x.token.getCancellationReason
/--
Waits for cancellation. Returns a task that completes when the context is cancelled.
-/
@[inline]
def done (x : CancellationContext) : IO (AsyncTask Unit) :=
x.token.wait
/--
Creates a selector that waits for cancellation.
-/
@[inline]
def doneSelector (x : CancellationContext) : Selector Unit :=
x.token.selector
private partial def countAliveTokensRec (state : CancellationContext.State) (id : UInt64) : Nat :=
match state.tokens.get? id with
| none => 0
| some (_, children) => 1 + children.foldl (fun acc childId => acc + countAliveTokensRec state childId) 0
/--
Counts the number of alive (non-cancelled) tokens in the context tree, including
this context and all its descendants.
-/
def countAliveTokens (x : CancellationContext) : BaseIO Nat := do
x.state.atomically do
let st ← get
return countAliveTokensRec st x.id
end CancellationContext
end Std

View file

@ -23,6 +23,38 @@ that a cancellation has occurred.
namespace Std
open Std.Internal.IO.Async
/--
Reasons for cancellation.
-/
inductive CancellationReason where
/--
Cancelled due to a deadline or timeout
-/
| deadline
/--
Cancelled due to shutdown
-/
| shutdown
/--
Explicitly cancelled
-/
| cancel
/--
Custom cancellation reason
-/
| custom (msg : String)
deriving Repr, BEq
instance : ToString CancellationReason where
toString
| .deadline => "deadline"
| .shutdown => "shutdown"
| .cancel => "cancel"
| .custom msg => s!"custom(\"{msg}\")"
inductive CancellationToken.Consumer where
| normal (promise : IO.Promise Unit)
| select (finished : Waiter Unit)
@ -44,9 +76,9 @@ The central state structure for a `CancellationToken`.
-/
structure CancellationToken.State where
/--
Whether this token has been cancelled.
The cancellation reason if cancelled, none otherwise.
-/
cancelled : Bool
reason : Option CancellationReason
/--
Consumers that are blocked waiting for cancellation.
@ -63,24 +95,24 @@ structure CancellationToken where
namespace CancellationToken
/--
Create a new cancellation token.
Creates a new cancellation token.
-/
def new : BaseIO CancellationToken := do
return { state := ← Std.Mutex.new { cancelled := false, consumers := ∅ } }
return { state := ← Std.Mutex.new { reason := none, consumers := ∅ } }
/--
Cancel the token, notifying all currently waiting consumers with `true`.
Cancels the token with the given reason, notifying all currently waiting consumers.
Once cancelled, the token remains cancelled.
-/
def cancel (x : CancellationToken) : BaseIO Unit := do
def cancel (x : CancellationToken) (reason : CancellationReason := .cancel) : BaseIO Unit := do
x.state.atomically do
let mut st ← get
if st.cancelled then
if st.reason.isSome then
return
let mut remainingConsumers := st.consumers
st := { cancelled := true, consumers := ∅ }
st := { reason := some reason, consumers := ∅ }
while true do
if let some (consumer, rest) := remainingConsumers.dequeue? then
@ -92,21 +124,29 @@ def cancel (x : CancellationToken) : BaseIO Unit := do
set st
/--
Check if the token is cancelled.
Checks if the token is cancelled.
-/
def isCancelled (x : CancellationToken) : BaseIO Bool := do
x.state.atomically do
let st ← get
return st.cancelled
return st.reason.isSome
/--
Wait for cancellation. Returns a task that completes when cancelled,
Gets the cancellation reason if the token is cancelled.
-/
def getCancellationReason (x : CancellationToken) : BaseIO (Option CancellationReason) := do
x.state.atomically do
let st ← get
return st.reason
/--
Waits for cancellation. Returns a task that completes when cancelled.
-/
def wait (x : CancellationToken) : IO (AsyncTask Unit) :=
x.state.atomically do
let st ← get
if st.cancelled then
if st.reason.isSome then
return Task.pure (.ok ())
let promise ← IO.Promise.new
@ -118,7 +158,7 @@ def wait (x : CancellationToken) : IO (AsyncTask Unit) :=
| none => throw (IO.userError "cancellation token dropped")
/--
Creates a selector that waits for cancellation
Creates a selector that waits for cancellation.
-/
def selector (token : CancellationToken) : Selector Unit := {
tryFn := do
@ -131,7 +171,7 @@ def selector (token : CancellationToken) : Selector Unit := {
token.state.atomically do
let st ← get
if st.cancelled then
if st.reason.isSome then
discard <| waiter.race (return false) (fun promise => do
promise.resolve (.ok ())
return true)

View file

@ -0,0 +1,163 @@
import Std.Internal.Async
import Std.Sync
open Std.Internal.IO Async
-- Test basic cancellation with default reason
def testBasicCancellationWithReason : Async Unit := do
let token ← Std.CancellationToken.new
assert! not (← token.isCancelled)
token.cancel
assert! (← token.isCancelled)
let reason ← token.getCancellationReason
assert! reason == some .cancel
#eval testBasicCancellationWithReason.block
-- Test cancellation with deadline reason
def testDeadlineReason : Async Unit := do
let token ← Std.CancellationToken.new
assert! not (← token.isCancelled)
token.cancel .deadline
assert! (← token.isCancelled)
let reason ← token.getCancellationReason
assert! reason == some .deadline
#eval testDeadlineReason.block
-- Test cancellation with shutdown reason
def testShutdownReason : Async Unit := do
let token ← Std.CancellationToken.new
token.cancel .shutdown
let reason ← token.getCancellationReason
assert! reason == some .shutdown
#eval testShutdownReason.block
-- Test cancellation with custom reason
def testCustomReason : Async Unit := do
let token ← Std.CancellationToken.new
token.cancel (.custom "connection timeout")
let reason ← token.getCancellationReason
assert! reason == some (.custom "connection timeout")
#eval testCustomReason.block
-- Test that uncancelled token has no reason
def testUncancelledNoReason : Async Unit := do
let token ← Std.CancellationToken.new
let reason ← token.getCancellationReason
assert! reason == none
#eval testUncancelledNoReason.block
-- Test context cancellation with reason
def testContextCancellation : Async Unit := do
let ctx ← Std.CancellationContext.new
assert! not (← ctx.isCancelled)
ctx.cancel .shutdown
assert! (← ctx.isCancelled)
let reason ← ctx.token.getCancellationReason
assert! reason == some .shutdown
#eval testContextCancellation.block
-- Test context tree with different reasons
def testContextTreeReasons : Async Unit := do
let root ← Std.CancellationContext.new
let child1 ← root.fork
let child2 ← root.fork
let grandchild ← child1.fork
-- Cancel root with shutdown reason
root.cancel .shutdown
-- All should be cancelled
assert! (← root.isCancelled)
assert! (← child1.isCancelled)
assert! (← child2.isCancelled)
assert! (← grandchild.isCancelled)
-- All should have the shutdown reason (propagated from root)
assert! (← root.token.getCancellationReason) == some .shutdown
assert! (← child1.token.getCancellationReason) == some .shutdown
assert! (← child2.token.getCancellationReason) == some .shutdown
assert! (← grandchild.token.getCancellationReason) == some .shutdown
#eval testContextTreeReasons.block
-- Test child cancellation doesn't affect parent
def testChildCancellationIndependent : Async Unit := do
let root ← Std.CancellationContext.new
let child ← root.fork
-- Cancel child with deadline
child.cancel .deadline
-- Child should be cancelled with deadline reason
assert! (← child.isCancelled)
assert! (← child.token.getCancellationReason) == some .deadline
-- Parent should still be active
assert! not (← root.isCancelled)
assert! (← root.token.getCancellationReason) == none
#eval testChildCancellationIndependent.block
-- Test selector with reason
def testSelectorWithReason : Async Unit := do
let token ← Std.CancellationToken.new
let completed ← Std.Mutex.new false
let reasonRef ← Std.Mutex.new none
let task ← async do
Selectable.one #[.case token.selector (fun _ => pure ())]
completed.atomically (set true)
reasonRef.atomically (set (← token.getCancellationReason))
assert! not (← completed.atomically get)
token.cancel .deadline
await task
assert! (← completed.atomically get)
assert! (← reasonRef.atomically get) == some Std.CancellationReason.deadline
#eval testSelectorWithReason.block
-- Test wait with reason
def testWaitWithReason : Async Unit := do
let token ← Std.CancellationToken.new
let task ← async do
let _ ← await (← token.wait)
token.getCancellationReason
Async.sleep 10
token.cancel (.custom "test reason")
let reason ← await task
assert! reason == some (.custom "test reason")
#eval testWaitWithReason.block
-- Test multiple cancellations (first one wins)
def testMultipleCancellations : Async Unit := do
let token ← Std.CancellationToken.new
token.cancel .deadline
token.cancel .shutdown -- This should be ignored
let reason ← token.getCancellationReason
assert! reason == some .deadline -- First reason should persist
#eval testMultipleCancellations.block

View file

@ -0,0 +1,251 @@
import Std.Internal.Async
import Std.Sync
open Std.Internal.IO Async
/-- Test basic tree cancellation -/
partial def testCancelTree : IO Unit := do
let mutex ← Std.Mutex.new 0
let context ← Std.CancellationContext.new
Async.block do
let rec loop (x : Nat) (parent : Std.CancellationContext) : Async Unit := do
match x with
| 0 => do
await (← parent.done)
mutex.atomically (modify (· + 1))
| n + 1 => do
background (loop n (← parent.fork))
background (loop n (← parent.fork))
await (← parent.done)
mutex.atomically (modify (· + 1))
background (loop 3 context)
Async.sleep 500
context.cancel .cancel
Async.sleep 1000
assert! (← context.countAliveTokens) == 0
let size ← mutex.atomically get
IO.println s!"cancelled {size}"
/--
info: cancelled 15
-/
#guard_msgs in
#eval testCancelTree
/-- Test cancellation with different reasons -/
def testCancellationReasons : IO Unit := do
let ctx ← Std.CancellationContext.new
let (reason1, reason2, reason3, reason4) ← Async.block do
-- Test with .cancel reason
let ctx1 ← ctx.fork
ctx1.cancel .cancel
let some reason1 ← ctx1.getCancellationReason | return (none, none, none, none)
-- Test with .deadline reason
let ctx2 ← ctx.fork
ctx2.cancel .deadline
let some reason2 ← ctx2.getCancellationReason | return (none, none, none, none)
-- Test with .shutdown reason
let ctx3 ← ctx.fork
ctx3.cancel .shutdown
let some reason3 ← ctx3.getCancellationReason | return (none, none, none, none)
-- Test with custom reason
let ctx4 ← ctx.fork
ctx4.cancel (.custom "test error")
let some reason4 ← ctx4.getCancellationReason | return (none, none, none, none)
return (some reason1, some reason2, some reason3, some reason4)
if let some r1 := reason1 then IO.println s!"Reason 1: {r1}"
if let some r2 := reason2 then IO.println s!"Reason 2: {r2}"
if let some r3 := reason3 then IO.println s!"Reason 3: {r3}"
if let some r4 := reason4 then IO.println s!"Reason 4: {r4}"
assert! (← ctx.countAliveTokens) == 1
/--
info: Reason 1: cancel
Reason 2: deadline
Reason 3: shutdown
Reason 4: custom("test error")
-/
#guard_msgs in
#eval testCancellationReasons
/-- Test cancellation propagates reason to children -/
def testReasonPropagation : IO Unit := do
let (parentReason, child1Reason, child2Reason, grandchildReason) ← Async.block do
let parent ← Std.CancellationContext.new
let child1 ← parent.fork
let child2 ← parent.fork
let grandchild ← child1.fork
parent.cancel (.custom "parent cancelled")
Async.sleep 100
let some parentReason ← parent.getCancellationReason | return (none, none, none, none)
let some child1Reason ← child1.getCancellationReason | return (none, none, none, none)
let some child2Reason ← child2.getCancellationReason | return (none, none, none, none)
let some grandchildReason ← grandchild.getCancellationReason | return (none, none, none, none)
return (some parentReason, some child1Reason, some child2Reason, some grandchildReason)
if let some r := parentReason then IO.println s!"Parent: {r}"
if let some r := child1Reason then IO.println s!"Child1: {r}"
if let some r := child2Reason then IO.println s!"Child2: {r}"
if let some r := grandchildReason then IO.println s!"Grandchild: {r}"
/--
info: Parent: custom("parent cancelled")
Child1: custom("parent cancelled")
Child2: custom("parent cancelled")
Grandchild: custom("parent cancelled")
-/
#guard_msgs in
#eval testReasonPropagation
/-- Test cancellation in the middle of work -/
def testCancelInMiddle : IO Unit := do
let counter ← Std.Mutex.new 0
let cancelledCounter ← Std.Mutex.new 0
let (finalCount, cancelledCount) ← Async.block do
let context ← Std.CancellationContext.new
-- Worker that does work until cancelled
let worker (ctx : Std.CancellationContext) : Async Unit := do
for _ in [0:100] do
if ← ctx.isCancelled then
cancelledCounter.atomically (modify (· + 1))
break
counter.atomically (modify (· + 1))
Async.sleep 10
-- Start 5 workers
for _ in [0:5] do
background (worker context)
-- Let them run for a bit, then cancel
Async.sleep 200
context.cancel .deadline
-- Wait for them to finish
Async.sleep 500
let finalCount ← counter.atomically get
let cancelledCount ← cancelledCounter.atomically get
return (finalCount, cancelledCount)
IO.println s!"Completed {finalCount} iterations before cancellation"
IO.println s!"{cancelledCount} workers detected cancellation"
/-- Test cancellation before forking -/
def testCancelBeforeFork : IO Unit := do
let (isSame, isChildCancelled) ← Async.block do
let ctx ← Std.CancellationContext.new
ctx.cancel .cancel
-- Fork after cancellation should return same context
let child ← ctx.fork
let isSame := ctx.id == child.id
let isChildCancelled ← child.isCancelled
return (isSame, isChildCancelled)
IO.println s!"Same context: {isSame}, Child cancelled: {isChildCancelled}"
/--
info: Same context: true, Child cancelled: true
-/
#guard_msgs in
#eval testCancelBeforeFork
/-- Test deep tree cancellation with reason -/
partial def testDeepTreeCancellation : IO Unit := do
let depths ← Std.Mutex.new ([] : List (Nat × Std.CancellationReason))
let (count, allSameReason) ← Async.block do
let root ← Std.CancellationContext.new
let rec makeTree (depth : Nat) (ctx : Std.CancellationContext) : Async Unit := do
if depth == 0 then
await (← ctx.done)
if let some reason ← ctx.getCancellationReason then
depths.atomically (modify (·.cons (depth, reason)))
else
let child1 ← ctx.fork
let child2 ← ctx.fork
background (makeTree (depth - 1) child1)
background (makeTree (depth - 1) child2)
await (← ctx.done)
if let some reason ← ctx.getCancellationReason then
depths.atomically (modify (·.cons (depth, reason)))
background (makeTree 4 root)
Async.sleep 200
root.cancel (.custom "deep tree cancel")
Async.sleep 500
let results ← depths.atomically get
let count := results.length
let allSameReason := results.all fun (_, r) => r == .custom "deep tree cancel"
return (count, allSameReason)
IO.println s!"Cancelled {count} nodes, all with same reason: {allSameReason}"
/--
info: Cancelled 31 nodes, all with same reason: true
-/
#guard_msgs in
#eval testDeepTreeCancellation
/-- Test counting alive tokens -/
def testCountAliveTokens : IO Unit := do
let (count0, count1, count2, count3, count4) ← Async.block do
let root ← Std.CancellationContext.new
let count0 ← root.countAliveTokens -- Root only
-- Fork 3 children
let child1 ← root.fork
let child2 ← root.fork
let _child3 ← root.fork
let count1 ← root.countAliveTokens -- Root + 3 children = 4
-- Cancel one child (and its subtree)
child1.cancel .cancel
Async.sleep 100
let count2 ← root.countAliveTokens -- Root + 2 children = 3
-- Fork a grandchild from child2
let _grandchild ← child2.fork
let count3 ← root.countAliveTokens -- Root + 2 children + 1 grandchild = 4
-- Cancel root (should cancel everything)
root.cancel .cancel
Async.sleep 100
let count4 ← root.countAliveTokens -- All cancelled = 0
return (count0, count1, count2, count3, count4)
IO.println s!"Initial (root only): {count0}"
IO.println s!"After forking 3 children: {count1}"
IO.println s!"After cancelling 1 child: {count2}"
IO.println s!"After forking grandchild: {count3}"
IO.println s!"After cancelling root: {count4}"
/--
info: Initial (root only): 1
After forking 3 children: 4
After cancelling 1 child: 3
After forking grandchild: 4
After cancelling root: 0
-/
#guard_msgs in
#eval testCountAliveTokens

View file

@ -0,0 +1,698 @@
import Std.Internal.Async
import Std.Sync
open Std.Internal.IO Async
/-- Test ContextAsync cancellation check -/
def testIsCancelled : IO Unit := do
let (before, after) ← Async.block do
ContextAsync.run do
let before ← ContextAsync.isCancelled
ContextAsync.cancel .cancel
Async.sleep 50
let after ← ContextAsync.isCancelled
return (before, after)
IO.println s!"Before: {before}, After: {after}"
/--
info: Before: false, After: true
-/
#guard_msgs in
#eval testIsCancelled
/-- Test ContextAsync cancellation reason -/
def testGetCancellationReason : IO Unit := do
let res ← Async.block do
ContextAsync.run do
ContextAsync.cancel (.custom "test reason")
Async.sleep 50
let some reason ← ContextAsync.getCancellationReason
| return "ERROR: No reason found"
return s!"Reason: {reason}"
IO.println res
/--
info: Reason: custom("test reason")
-/
#guard_msgs in
#eval testGetCancellationReason
/-- Test awaitCancellation -/
def testAwaitCancellation : IO Unit := do
let received ← Std.Mutex.new false
Async.block do
let started ← Std.Mutex.new false
ContextAsync.run do
discard <| ContextAsync.concurrently
(do
started.atomically (set true)
ContextAsync.awaitCancellation
received.atomically (set true))
(do
-- Wait for task to start
while !(← started.atomically get) do
Async.sleep 10
Async.sleep 100
ContextAsync.cancel .shutdown)
Async.sleep 200
let _ ← received.atomically get
IO.println "Cancellation received"
def testSelectorCancellationFail : IO Unit := do
let received ← Std.Mutex.new false
let result ← Async.block do
let ctx ← Std.CancellationContext.new
let started ← Std.Mutex.new false
let result ← do
try
ContextAsync.runIn ctx do
discard <| ContextAsync.concurrently
(do
started.atomically (set true)
let res ← Selectable.one #[
.case (← ContextAsync.doneSelector) (fun _ => pure true),
.case (← Selector.sleep 2000) (fun _ => pure false)
]
received.atomically (set res))
(do
throw (.userError "failed")
return ())
return Except.ok ()
catch err =>
return Except.error err
Async.sleep 500
return result
let _ ← received.atomically get
IO.println "Cancellation received"
if let Except.error err := result then
throw err
/--
info: Cancellation received
---
error: failed
-/
#guard_msgs in
#eval testSelectorCancellationFail
/-- Test concurrently with both tasks succeeding -/
def testConcurrently : IO Unit := do
let (a, b) ← Async.block do
ContextAsync.run do
ContextAsync.concurrently
(do
Async.sleep 100
return 42)
(do
Async.sleep 150
return "hello")
IO.println s!"Results: {a}, {b}"
/--
info: Results: 42, hello
-/
#guard_msgs in
#eval testConcurrently
/-- Test race with first task winning -/
def testRace : IO Unit := do
let result ← Async.block do
ContextAsync.run do
ContextAsync.race
(do
Async.sleep 50
return "fast")
(do
Async.sleep 200
return "slow")
IO.println s!"Winner: {result}"
/--
info: Winner: fast
-/
#guard_msgs in
#eval testRace
/-- Test concurrentlyAll -/
def testConcurrentlyAll : IO Unit := do
let results ← Async.block do
ContextAsync.run do
let tasks := #[
(do Async.sleep 50; return 1),
(do Async.sleep 100; return 2),
(do Async.sleep 75; return 3)
]
ContextAsync.concurrentlyAll tasks
IO.println s!"All results: {results}"
/--
info: All results: #[1, 2, 3]
-/
#guard_msgs in
#eval testConcurrentlyAll
/-- Test background task with cancellation -/
def testBackground : IO Unit := do
let counter ← Std.Mutex.new 0
Async.block do
ContextAsync.run do
discard <| ContextAsync.concurrently
(do
for _ in [0:10] do
if ← ContextAsync.isCancelled then
break
counter.atomically (modify (· + 1))
Async.sleep 50)
(do
-- Let it run for a bit
Async.sleep 150
ContextAsync.cancel .cancel)
Async.sleep 200
let final ← counter.atomically get
IO.println s!"Counter reached: {final}"
/-- Test fork cancellation isolation -/
def testForkCancellation : IO Unit := do
let parent ← Std.CancellationContext.new
let childCancelled ← Std.Mutex.new false
let parentCancelled ← Std.Mutex.new false
Async.block do
ContextAsync.runIn parent do
discard <| ContextAsync.concurrentlyAll #[
(do
let child ← ContextAsync.getContext
Async.sleep 100
child.cancel .cancel
childCancelled.atomically (set true)),
(do
Async.sleep 200
if ← parent.isCancelled then
parentCancelled.atomically (set true))
]
let childWasCancelled ← childCancelled.atomically get
let parentWasCancelled ← parentCancelled.atomically get
IO.println s!"Child cancelled: {childWasCancelled}, Parent cancelled: {parentWasCancelled}"
/--
info: Child cancelled: true, Parent cancelled: false
-/
#guard_msgs in
#eval testForkCancellation
/-- Test doneSelector -/
partial def testNestedFork : IO Unit := do
let res ← Async.block do
ContextAsync.run do
let ctx ← ContextAsync.getContext
let sel ← ContextAsync.doneSelector
let (_, result) ← ContextAsync.concurrently
(do
Async.sleep 100
ctx.cancel .deadline)
(Selectable.one #[.case sel (fun _ => pure true)])
return result
IO.println s!"Done selector triggered: {res}"
/--
info: Done selector triggered: true
-/
#guard_msgs in
#eval testNestedFork
/-- Test Selector.cancelled -/
def testSelectorCancelled : IO Unit := do
let res ← Async.block do
ContextAsync.run do
let ctx ← ContextAsync.getContext
let sel ← Selector.cancelled
let (_, result) ← ContextAsync.concurrently
(do
Async.sleep 150
ctx.cancel .shutdown)
(Selectable.one #[.case sel (fun _ => pure true)])
return result
IO.println s!"Selector.cancelled triggered: {res}"
/--
info: Selector.cancelled triggered: true
-/
#guard_msgs in
#eval testSelectorCancelled
/-- Test MonadLift instances -/
def testMonadLift : IO Unit := do
let (msg1, msg2) ← Async.block do
ContextAsync.run do
-- Lift from IO
let msg1 : String := "From IO"
-- Lift from BaseIO
let msg2 : String := "From BaseIO"
-- Lift from Async
let _ ← (Async.sleep 50 : Async Unit)
return (msg1, msg2)
IO.println msg1
IO.println msg2
IO.println "All lifts work"
/--
info: From IO
From BaseIO
All lifts work
-/
#guard_msgs in
#eval testMonadLift
/-- Test exception handling in ContextAsync -/
def testExceptionHandling : IO Unit := do
let res ← Async.block do
ContextAsync.run do
try
throw (IO.userError "test error")
return "Should not reach here"
catch e =>
return s!"Caught: {e}"
IO.println res
/--
info: Caught: test error
-/
#guard_msgs in
#eval testExceptionHandling
/-- Test tryFinally in ContextAsync -/
def testTryFinally : IO Unit := do
let cleaned ← Std.Mutex.new false
Async.block do
ContextAsync.run do
try
ContextAsync.cancel .cancel
ContextAsync.awaitCancellation
finally
cleaned.atomically (set true)
let wasCleanedUp ← cleaned.atomically get
IO.println s!"Cleanup ran: {wasCleanedUp}"
/--
info: Cleanup ran: true
-/
#guard_msgs in
#eval testTryFinally
/-- Test race with cancellation -/
def testRaceWithCancellation : IO Unit := do
let ctx ← Std.CancellationContext.new
let leftCancelled ← Std.Mutex.new false
let rightCancelled ← Std.Mutex.new false
Async.block do
ContextAsync.runIn ctx do
let _ ← ContextAsync.race
(do
try
Async.sleep 500
return "left"
finally
if ← ContextAsync.isCancelled then
leftCancelled.atomically (set true))
(do
try
Async.sleep 50
return "right"
finally
if ← ContextAsync.isCancelled then
rightCancelled.atomically (set true))
Async.sleep 1000
let left ← leftCancelled.atomically get
let right ← rightCancelled.atomically get
IO.println s!"Left cancelled: {left}, Right cancelled: {right}"
/--
info: Left cancelled: true, Right cancelled: false
-/
#guard_msgs in
#eval testRaceWithCancellation
/-- Test complex concurrent workflow -/
def testComplexWorkflow : IO Unit := do
let results ← Std.Mutex.new ([] : List String)
Async.block do
ContextAsync.run do
-- Run multiple concurrent operations
let (a, b) ← ContextAsync.concurrently
(do
Async.sleep 50
results.atomically (modify ("A"::·))
return 1)
(do
Async.sleep 75
results.atomically (modify ("B"::·))
return 2)
-- Additional concurrent task
discard <| ContextAsync.concurrently
(do
Async.sleep 100
results.atomically (modify ("BG"::·)))
(do
Async.sleep 200
results.atomically (modify (s!"Sum:{a+b}"::·)))
let final ← results.atomically get
IO.println s!"Results: {final.reverse}"
/--
info: Results: [A, B, BG, Sum:3]
-/
#guard_msgs in
#eval testComplexWorkflow
def testConcurrentlyAllException : IO Unit := do
let ref ← IO.mkRef ""
try
Async.block do
ContextAsync.run do
let tasks := #[
(do
Async.sleep 1000
if ← ContextAsync.isCancelled then
ref.set "cancelled"
return
else
ref.set "not cancelled"
Async.sleep 500
if ← ContextAsync.isCancelled then
ref.modify (· ++ ", cancelled")
else
ref.modify (· ++ ", not cancelled")),
(do
Async.sleep 250
throw (IO.userError "Error: Hello"))
]
discard <| ContextAsync.concurrentlyAll tasks
finally
IO.println (← ref.get)
/--
info: cancelled
---
error: Error: Hello
-/
#guard_msgs in
#eval testConcurrentlyAllException
/-- Test that tasks in ContextAsync.run are not cancelled when run completes -/
def test0 : IO Unit := do
let ref ← IO.mkRef false
Async.block do
ContextAsync.run do
Async.sleep 100
if ← ContextAsync.isCancelled then
ref.set true
IO.sleep 200
IO.println s!"{← ref.get}"
/--
info: false
-/
#guard_msgs in
#eval test0
/-- Test that background tasks are cancelled when ContextAsync.run completes -/
def test1 : IO Unit := do
let ref ← IO.mkRef false
Async.block do
ContextAsync.run do
ContextAsync.background do
Async.sleep 100
if ← ContextAsync.isCancelled then
ref.set true
IO.sleep 200
IO.println s!"{← ref.get}"
/--
info: true
-/
#guard_msgs in
#eval test1
/-- Test that nested background tasks (ContextAsync.background in ContextAsync.background) are cancelled -/
def test2 : IO Unit := do
let ref ← IO.mkRef false
Async.block do
ContextAsync.run do
ContextAsync.background do
ContextAsync.background do
Async.sleep 100
if ← ContextAsync.isCancelled then
ref.set true
IO.sleep 200
IO.println s!"{← ref.get}"
/--
info: true
-/
#guard_msgs in
#eval test2
/-- Test that ContextAsync.background in Async.background is cancelled -/
def test2' : IO Unit := do
let ref ← IO.mkRef false
Async.block do
ContextAsync.run do
Async.background do
ContextAsync.background do
Async.sleep 100
if ← ContextAsync.isCancelled then
ref.set true
IO.sleep 200
IO.println s!"{← ref.get}"
/--
info: true
-/
#guard_msgs in
#eval test2'
/-- Test that Async.background in ContextAsync.background is cancelled -/
def test2'' : IO Unit := do
let ref ← IO.mkRef false
Async.block do
ContextAsync.run do
ContextAsync.background do
Async.background do
Async.sleep 100
if ← ContextAsync.isCancelled then
ref.set true
IO.sleep 200
IO.println s!"{← ref.get}"
/--
info: true
-/
#guard_msgs in
#eval test2''
/-- Test concurrently with first task succeeding immediately, others checking cancellation -/
def testConcurrentlySuccessWithCancellation : IO Unit := do
let task2Cancelled ← Std.Mutex.new false
let task3Cancelled ← Std.Mutex.new false
let results ← Async.block do
ContextAsync.run do
ContextAsync.concurrentlyAll #[
(do
return "first"),
(do
-- Second task waits and checks for cancellation
let res ← Selectable.one #[
.case (← ContextAsync.doneSelector) (fun _ => pure true),
.case (← Selector.sleep 500) (fun _ => pure false)
]
task2Cancelled.atomically (set (res))
return "second"),
(do
let res ← Selectable.one #[
.case (← ContextAsync.doneSelector) (fun _ => pure true),
.case (← Selector.sleep 500) (fun _ => pure false)
]
task3Cancelled.atomically (set (res))
return "third")
]
let t2 ← task2Cancelled.atomically get
let t3 ← task3Cancelled.atomically get
IO.println s!"Results: {results}"
IO.println s!"Task2 cancelled: {t2}, Task3 cancelled: {t3}"
/--
info: Results: #[first, second, third]
Task2 cancelled: false, Task3 cancelled: false
-/
#guard_msgs in
#eval testConcurrentlySuccessWithCancellation
/-- Test concurrently with first task failing, others checking for cancellation -/
def testConcurrentlyFailWithCancellation : IO Unit := do
let task2Cancelled ← Std.Mutex.new false
let task3Cancelled ← Std.Mutex.new false
let results ← Async.block do
ContextAsync.run do
try
let result ← ContextAsync.concurrentlyAll #[
(do
-- First task fails immediately
throw (IO.userError "first task failed")),
(do
-- Second task waits and checks for cancellation
let res ← Selectable.one #[
.case (← ContextAsync.doneSelector) (fun _ => pure true),
.case (← Selector.sleep 2000) (fun _ => pure false)
]
task2Cancelled.atomically (set (res))
return "second"),
(do
let res ← Selectable.one #[
.case (← ContextAsync.doneSelector) (fun _ => pure true),
.case (← Selector.sleep 2000) (fun _ => pure false)
]
task3Cancelled.atomically (set (res))
return "third")
]
return Except.ok result
catch e =>
Async.sleep 500
return Except.error e
let t2 ← task2Cancelled.atomically get
let t3 ← task3Cancelled.atomically get
match results with
| .ok results => IO.println s!"Results: {results}"
| .error e => IO.println s!"Error: {e}"
IO.println s!"Task2 cancelled: {t2}, Task3 cancelled: {t3}"
/--
info: Error: first task failed
Task2 cancelled: true, Task3 cancelled: true
-/
#guard_msgs in
#eval testConcurrentlyFailWithCancellation
/-- Test concurrently with both tasks succeeding, checking cancellation status -/
def testConcurrentlySuccessWithCancellation2Tasks : IO Unit := do
let task2Cancelled ← Std.Mutex.new false
let (r1, r2) ← Async.block do
ContextAsync.run do
ContextAsync.concurrently
(do return "first")
(do
-- Second task waits and checks for cancellation
let res ← Selectable.one #[
.case (← ContextAsync.doneSelector) (fun _ => pure true),
.case (← Selector.sleep 500) (fun _ => pure false)
]
task2Cancelled.atomically (set res)
return "second")
let t2 ← task2Cancelled.atomically get
IO.println s!"Results: {r1}, {r2}"
IO.println s!"Task2 cancelled: {t2}"
/--
info: Results: first, second
Task2 cancelled: false
-/
#guard_msgs in
#eval testConcurrentlySuccessWithCancellation2Tasks
/-- Test concurrently with first task failing, second task checking for cancellation -/
def testConcurrentlyFailWithCancellation2Tasks : IO Unit := do
let task2Cancelled ← Std.Mutex.new false
try
Async.block do
ContextAsync.run do
let (_ : (String × String)) ← ContextAsync.concurrently
(do
-- First task fails immediately
throw (IO.userError "first task failed") : ContextAsync String)
(do
-- Second task waits and checks for cancellation
let res ← Selectable.one #[
.case (← ContextAsync.doneSelector) (fun _ => pure true),
.case (← Selector.sleep 2000) (fun _ => pure false)
]
task2Cancelled.atomically (set res)
return "second")
catch e =>
IO.sleep 500
let t2 ← task2Cancelled.atomically get
IO.println s!"Error: {e}"
IO.println s!"Task2 cancelled: {t2}"
/--
info: Error: first task failed
Task2 cancelled: true
-/
#guard_msgs in
#eval testConcurrentlyFailWithCancellation2Tasks