diff --git a/src/Std/Internal/Async.lean b/src/Std/Internal/Async.lean index 1d7954fd16..214f36aafd 100644 --- a/src/Std/Internal/Async.lean +++ b/src/Std/Internal/Async.lean @@ -7,6 +7,7 @@ module prelude public import Std.Internal.Async.Basic +public import Std.Internal.Async.ContextAsync public import Std.Internal.Async.Timer public import Std.Internal.Async.TCP public import Std.Internal.Async.UDP diff --git a/src/Std/Internal/Async/ContextAsync.lean b/src/Std/Internal/Async/ContextAsync.lean new file mode 100644 index 0000000000..bde6733470 --- /dev/null +++ b/src/Std/Internal/Async/ContextAsync.lean @@ -0,0 +1,273 @@ +/- +Copyright (c) 2025 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Sofia Rodrigues +-/ +module + +prelude +public import Std.Time +public import Std.Internal.UV +public import Std.Internal.Async.Basic +public import Std.Internal.Async.Timer +public import Std.Sync.CancellationContext + +public section + +/-! +This module contains the implementation of `ContextAsync`, a monad for asynchronous computations with +cooperative cancellation support that must be explicitly checked for and cancelled explicitly. +-/ + +namespace Std +namespace Internal +namespace IO +namespace Async + +/-- +An asynchronous computation with cooperative cancellation support via a `CancellationContext`. `ContextAsync α` +is equivalent to `ReaderT CancellationContext Async α`, providing a `CancellationContext` value to async +computations. +-/ +abbrev ContextAsync (α : Type) := ReaderT CancellationContext Async α + +namespace ContextAsync + +/-- +Runs a `ContextAsync` computation with a given context. See also `ContextAsync.run` for running with a new +context that automatically cancels after execution. +-/ +@[inline] +protected def runIn (ctx : CancellationContext) (x : ContextAsync α) : Async α := + x ctx + +/-- +Runs a `ContextAsync` computation with a new context that cancels after the execution of the computation. +See also `ContextAsync.runIn` for running with an existing context. +-/ +@[inline] +protected def run (x : ContextAsync α) : Async α := do + let ctx ← CancellationContext.new + x ctx <* ctx.cancel .cancel + +/-- +Returns the current context for inspection or to pass to other functions. +-/ +@[inline] +def getContext : ContextAsync CancellationContext := + fun ctx => pure ctx + +/-- +Checks if the current context is cancelled. Returns `true` if the context (or any ancestor) has been cancelled. +Long-running operations should periodically check this and exit gracefully when cancelled. +-/ +@[inline] +def isCancelled : ContextAsync Bool := do + let ctx ← getContext + ctx.isCancelled + +/-- +Gets the cancellation reason if the context is cancelled. Returns `some reason` if cancelled, `none` otherwise, +allowing you to distinguish between different cancellation types. +-/ +@[inline] +def getCancellationReason : ContextAsync (Option CancellationReason) := do + let ctx ← getContext + ctx.getCancellationReason + +/-- +Cancels the current context with the given reason, cascading to all child contexts. +Cancellation is cooperative, operations must explicitly check `isCancelled` or use `awaitCancellation` to respond. +-/ +@[inline] +def cancel (reason : CancellationReason) : ContextAsync Unit := do + let ctx ← getContext + ctx.cancel reason + +/-- +Returns a selector that completes when the current context is cancelled. +-/ +@[inline] +def doneSelector : ContextAsync (Selector Unit) := do + let ctx ← getContext + return ctx.doneSelector + +/-- +Waits for the current context to be cancelled. +-/ +@[inline] +def awaitCancellation : ContextAsync Unit := do + let ctx ← getContext + let task ← ctx.done + await task + +/-- +Runs two computations concurrently and returns both results. Each computation runs in its own child context; +if either fails or is cancelled, both are cancelled immediately and the exception is propagated. +-/ +@[inline, specialize] +def concurrently (x : ContextAsync α) (y : ContextAsync β) + (prio := Task.Priority.default) : ContextAsync (α × β) := do + + let ctx ← getContext + let concurrentCtx ← ctx.fork + + let childCtx1 ← concurrentCtx.fork + let childCtx2 ← concurrentCtx.fork + + let result ← Async.concurrently + (try x childCtx1 catch err => do concurrentCtx.cancel .cancel; throw err finally childCtx1.cancel .cancel) + (try y childCtx2 catch err => do concurrentCtx.cancel .cancel; throw err finally childCtx2.cancel .cancel) + prio + + concurrentCtx.cancel .cancel + return result + +/-- +Runs two computations concurrently and returns the result of the first to complete. Each computation runs +in its own child context; when either completes, the other is cancelled immediately. +-/ +@[inline, specialize] +def race [Inhabited α] (x : ContextAsync α) (y : ContextAsync α) + (prio := Task.Priority.default) : ContextAsync α := do + let parent ← getContext + let ctx1 ← CancellationContext.fork parent + let ctx2 ← CancellationContext.fork parent + + let task1 ← async (x ctx1) prio + let task2 ← async (y ctx2) prio + + let result ← Async.race + (await task1 <* ctx2.cancel .cancel) + (await task2 <* ctx1.cancel .cancel) + prio + + pure result + +/-- +Runs all computations concurrently and collects results in the same order. Each runs in its own child context; +if any computation fails, all others are cancelled and the exception is propagated. +-/ +@[inline, specialize] +def concurrentlyAll (xs : Array (ContextAsync α)) + (prio := Task.Priority.default) : ContextAsync (Array α) := do + let ctx ← getContext + let concurrentCtx ← ctx.fork + + let tasks : Array (AsyncTask α) ← xs.mapM fun ctxAsync => do + let childCtx ← concurrentCtx.fork + async (prio := prio) + (try + ctxAsync childCtx + catch err => do + concurrentCtx.cancel .cancel + throw err + finally + childCtx.cancel .cancel) + + let result ← tasks.mapM await + return result + +/-- +Launches a `ContextAsync` computation in the background, discarding its result. + +The computation runs independently in the background in its own child context. The parent computation does not wait +for background tasks to complete. This means that if the parent finishes its execution it will cause +the cancellation of the background functions. See also `disown` for launching tasks that continue independently +even after parent cancellation. +-/ +@[inline, specialize] +def background (action : ContextAsync α) (prio := Task.Priority.default) : ContextAsync Unit := do + let ctx ← getContext + let childCtx ← ctx.fork + Async.background (action childCtx *> childCtx.cancel .cancel) prio + +/-- +Launches a `ContextAsync` computation in the background, discarding its result. It's similar to `background`, +but the child context is not automatically cancelled when the action completes. This allows the disowned +computation to continue running independently, even if the parent context is cancelled. The child context +will remain alive as long as the computation needs it. See also `background` for launching tasks that are +cancelled when the parent finishes. +-/ +@[inline, specialize] +def disown (action : ContextAsync α) (prio := Task.Priority.default) : ContextAsync Unit := do + let childCtx ← CancellationContext.new + Async.background (action childCtx) prio + +/-- +Runs all computations concurrently and returns the first result. Each computation runs in its own child context; +when the first completes successfully, all others are cancelled immediately. +-/ +def raceAll [ForM ContextAsync c (ContextAsync α)] (xs : c) + (prio := Task.Priority.default) : ContextAsync α := do + let parent ← getContext + let promise ← IO.Promise.new + + ForM.forM xs fun x => do + let ctx ← CancellationContext.fork parent + let task ← async (x ctx) prio + + background do + try + let result ← await task + promise.resolve (.ok result) + catch e => + discard $ promise.resolve (.error e) + + let result ← await promise + parent.cancel .cancel + Async.ofExcept result + +/-- +Launches a `ContextAsync` computation as an asynchronous task with a forked child context. +The child context is automatically cancelled when the task completes or fails. +-/ +@[inline, specialize] +def async (x : ContextAsync α) (prio := Task.Priority.default) : ContextAsync (AsyncTask α) := + fun ctx => do + let childCtx ← ctx.fork + Async.async (try x childCtx finally childCtx.cancel .cancel) prio + +instance : MonadAsync AsyncTask ContextAsync where + async x prio := ContextAsync.async x prio + +instance : Functor ContextAsync where + map f x := fun ctx => f <$> x ctx + +instance : Monad ContextAsync where + pure a := fun _ => pure a + bind x f := fun ctx => x ctx >>= fun a => f a ctx + +instance : MonadLift IO ContextAsync where + monadLift x := fun _ => Async.ofIOTask (Task.pure <$> x) + +instance : MonadLift BaseIO ContextAsync where + monadLift x := fun _ => liftM (m := Async) x + +instance : MonadExcept IO.Error ContextAsync where + throw e := fun _ => throw e + tryCatch x h := fun ctx => tryCatch (x ctx) (fun e => h e ctx) + +instance : MonadFinally ContextAsync where + tryFinally' x f := fun ctx => + tryFinally' (x ctx) (fun opt => f opt ctx) + +instance [Inhabited α] : Inhabited (ContextAsync α) where + default := fun _ => default + +instance : MonadAwait AsyncTask ContextAsync where + await t := fun _ => await t + +end ContextAsync + +/-- +Returns a selector that completes when the current context is cancelled. +This is useful for selecting on cancellation alongside other asynchronous operations. +-/ +def Selector.cancelled : ContextAsync (Selector Unit) := do + ContextAsync.doneSelector + +end Async +end IO +end Internal +end Std diff --git a/src/Std/Sync.lean b/src/Std/Sync.lean index 65313a6ed7..c79ce06a5c 100644 --- a/src/Std/Sync.lean +++ b/src/Std/Sync.lean @@ -16,5 +16,6 @@ public import Std.Sync.Notify public import Std.Sync.Broadcast public import Std.Sync.StreamMap public import Std.Sync.CancellationToken +public import Std.Sync.CancellationContext @[expose] public section diff --git a/src/Std/Sync/CancellationContext.lean b/src/Std/Sync/CancellationContext.lean new file mode 100644 index 0000000000..b62e931acb --- /dev/null +++ b/src/Std/Sync/CancellationContext.lean @@ -0,0 +1,152 @@ +/- +Copyright (c) 2025 Lean FRO, LLC. All rights reserved. +Released under Apache 2.0 license as described in the file LICENSE. +Authors: Sofia Rodrigues +-/ +module + +prelude +public import Std.Data +public import Init.System.Promise +public import Init.Data.Queue +public import Std.Sync.Mutex +public import Std.Sync.CancellationToken +public import Std.Internal.Async.Select + +public section + +/-! +This module provides a tree-structured cancellation context called `CancellationToken` where cancelling a parent +automatically cancels all child contexts. +-/ + +namespace Std +open Std.Internal.IO.Async + +structure CancellationContext.State where + /-- + Map of token IDs to optional tokens and their children. + -/ + tokens : TreeMap UInt64 (CancellationToken × Array UInt64) := .empty + + /-- + Next available ID + -/ + id : UInt64 := 1 + +/-- +A cancellation context that allows multiple consumers to wait until cancellation is requested. Forms +a tree structure where cancelling a parent cancels all children. +-/ +structure CancellationContext where + state : Std.Mutex CancellationContext.State + token : CancellationToken + id : UInt64 + +namespace CancellationContext + +/-- +Creates a new root cancellation context. +-/ +def new : BaseIO CancellationContext := do + let token ← Std.CancellationToken.new + return { + state := ← Std.Mutex.new { tokens := .empty |>.insert 0 (token, #[]) }, + token, + id := 0 + } + +/-- +Forks a child context from a parent. If the parent is already cancelled, returns the parent context. +Otherwise, creates a new child that will be cancelled when the parent is cancelled. +-/ +def fork (root : CancellationContext) : BaseIO CancellationContext := do + root.state.atomically do + if ← root.token.isCancelled then + return root + + let token ← Std.CancellationToken.new + + let st ← get + let newId := st.id + + set { st with + id := newId + 1, + tokens := st.tokens.insert newId (token, #[]) + |>.modify root.id (.map (·) (.push · newId)) + } + + return { state := root.state, token, id := newId } + +/-- +Recursively cancels a context and all its children with the given reason. +-/ +private partial def cancelChildren (state : CancellationContext.State) (id : UInt64) (reason : CancellationReason) : BaseIO CancellationContext.State := do + let mut state := state + + let some (token, children) := state.tokens.get? id + | return state + + for tokenId in children do + state ← cancelChildren state tokenId reason + + token.cancel reason + + pure { state with tokens := state.tokens.erase id } + +/-- +Cancels this context and all child contexts with the given reason. +-/ +def cancel (x : CancellationContext) (reason : CancellationReason) : BaseIO Unit := do + if ← x.token.isCancelled then + return + + x.state.atomically do + let st ← get + let st ← cancelChildren st x.id reason + set st + +/-- +Checks if the context is cancelled. +-/ +@[inline] +def isCancelled (x : CancellationContext) : BaseIO Bool := do + x.token.isCancelled + +/-- +Returns the cancellation reason if the context is cancelled. +-/ +@[inline] +def getCancellationReason (x : CancellationContext) : BaseIO (Option CancellationReason) := do + x.token.getCancellationReason + +/-- +Waits for cancellation. Returns a task that completes when the context is cancelled. +-/ +@[inline] +def done (x : CancellationContext) : IO (AsyncTask Unit) := + x.token.wait + +/-- +Creates a selector that waits for cancellation. +-/ +@[inline] +def doneSelector (x : CancellationContext) : Selector Unit := + x.token.selector + +private partial def countAliveTokensRec (state : CancellationContext.State) (id : UInt64) : Nat := + match state.tokens.get? id with + | none => 0 + | some (_, children) => 1 + children.foldl (fun acc childId => acc + countAliveTokensRec state childId) 0 + +/-- +Counts the number of alive (non-cancelled) tokens in the context tree, including +this context and all its descendants. +-/ +def countAliveTokens (x : CancellationContext) : BaseIO Nat := do + x.state.atomically do + let st ← get + return countAliveTokensRec st x.id + +end CancellationContext +end Std diff --git a/src/Std/Sync/CancellationToken.lean b/src/Std/Sync/CancellationToken.lean index a674e4b23d..4f5c0debbf 100644 --- a/src/Std/Sync/CancellationToken.lean +++ b/src/Std/Sync/CancellationToken.lean @@ -23,6 +23,38 @@ that a cancellation has occurred. namespace Std open Std.Internal.IO.Async +/-- +Reasons for cancellation. +-/ +inductive CancellationReason where + /-- + Cancelled due to a deadline or timeout + -/ + | deadline + + /-- + Cancelled due to shutdown + -/ + | shutdown + + /-- + Explicitly cancelled + -/ + | cancel + + /-- + Custom cancellation reason + -/ + | custom (msg : String) + deriving Repr, BEq + +instance : ToString CancellationReason where + toString + | .deadline => "deadline" + | .shutdown => "shutdown" + | .cancel => "cancel" + | .custom msg => s!"custom(\"{msg}\")" + inductive CancellationToken.Consumer where | normal (promise : IO.Promise Unit) | select (finished : Waiter Unit) @@ -44,9 +76,9 @@ The central state structure for a `CancellationToken`. -/ structure CancellationToken.State where /-- - Whether this token has been cancelled. + The cancellation reason if cancelled, none otherwise. -/ - cancelled : Bool + reason : Option CancellationReason /-- Consumers that are blocked waiting for cancellation. @@ -63,24 +95,24 @@ structure CancellationToken where namespace CancellationToken /-- -Create a new cancellation token. +Creates a new cancellation token. -/ def new : BaseIO CancellationToken := do - return { state := ← Std.Mutex.new { cancelled := false, consumers := ∅ } } + return { state := ← Std.Mutex.new { reason := none, consumers := ∅ } } /-- -Cancel the token, notifying all currently waiting consumers with `true`. +Cancels the token with the given reason, notifying all currently waiting consumers. Once cancelled, the token remains cancelled. -/ -def cancel (x : CancellationToken) : BaseIO Unit := do +def cancel (x : CancellationToken) (reason : CancellationReason := .cancel) : BaseIO Unit := do x.state.atomically do let mut st ← get - if st.cancelled then + if st.reason.isSome then return let mut remainingConsumers := st.consumers - st := { cancelled := true, consumers := ∅ } + st := { reason := some reason, consumers := ∅ } while true do if let some (consumer, rest) := remainingConsumers.dequeue? then @@ -92,21 +124,29 @@ def cancel (x : CancellationToken) : BaseIO Unit := do set st /-- -Check if the token is cancelled. +Checks if the token is cancelled. -/ def isCancelled (x : CancellationToken) : BaseIO Bool := do x.state.atomically do let st ← get - return st.cancelled + return st.reason.isSome /-- -Wait for cancellation. Returns a task that completes when cancelled, +Gets the cancellation reason if the token is cancelled. +-/ +def getCancellationReason (x : CancellationToken) : BaseIO (Option CancellationReason) := do + x.state.atomically do + let st ← get + return st.reason + +/-- +Waits for cancellation. Returns a task that completes when cancelled. -/ def wait (x : CancellationToken) : IO (AsyncTask Unit) := x.state.atomically do let st ← get - if st.cancelled then + if st.reason.isSome then return Task.pure (.ok ()) let promise ← IO.Promise.new @@ -118,7 +158,7 @@ def wait (x : CancellationToken) : IO (AsyncTask Unit) := | none => throw (IO.userError "cancellation token dropped") /-- -Creates a selector that waits for cancellation +Creates a selector that waits for cancellation. -/ def selector (token : CancellationToken) : Selector Unit := { tryFn := do @@ -131,7 +171,7 @@ def selector (token : CancellationToken) : Selector Unit := { token.state.atomically do let st ← get - if st.cancelled then + if st.reason.isSome then discard <| waiter.race (return false) (fun promise => do promise.resolve (.ok ()) return true) diff --git a/tests/lean/run/async_cancellation_reasons.lean b/tests/lean/run/async_cancellation_reasons.lean new file mode 100644 index 0000000000..99ccad59fb --- /dev/null +++ b/tests/lean/run/async_cancellation_reasons.lean @@ -0,0 +1,163 @@ +import Std.Internal.Async +import Std.Sync + +open Std.Internal.IO Async + +-- Test basic cancellation with default reason +def testBasicCancellationWithReason : Async Unit := do + let token ← Std.CancellationToken.new + assert! not (← token.isCancelled) + + token.cancel + assert! (← token.isCancelled) + + let reason ← token.getCancellationReason + assert! reason == some .cancel + +#eval testBasicCancellationWithReason.block + +-- Test cancellation with deadline reason +def testDeadlineReason : Async Unit := do + let token ← Std.CancellationToken.new + assert! not (← token.isCancelled) + + token.cancel .deadline + assert! (← token.isCancelled) + + let reason ← token.getCancellationReason + assert! reason == some .deadline + +#eval testDeadlineReason.block + +-- Test cancellation with shutdown reason +def testShutdownReason : Async Unit := do + let token ← Std.CancellationToken.new + token.cancel .shutdown + + let reason ← token.getCancellationReason + assert! reason == some .shutdown + +#eval testShutdownReason.block + +-- Test cancellation with custom reason +def testCustomReason : Async Unit := do + let token ← Std.CancellationToken.new + token.cancel (.custom "connection timeout") + + let reason ← token.getCancellationReason + assert! reason == some (.custom "connection timeout") + +#eval testCustomReason.block + +-- Test that uncancelled token has no reason +def testUncancelledNoReason : Async Unit := do + let token ← Std.CancellationToken.new + + let reason ← token.getCancellationReason + assert! reason == none + +#eval testUncancelledNoReason.block + +-- Test context cancellation with reason +def testContextCancellation : Async Unit := do + let ctx ← Std.CancellationContext.new + assert! not (← ctx.isCancelled) + + ctx.cancel .shutdown + assert! (← ctx.isCancelled) + + let reason ← ctx.token.getCancellationReason + assert! reason == some .shutdown + +#eval testContextCancellation.block + +-- Test context tree with different reasons +def testContextTreeReasons : Async Unit := do + let root ← Std.CancellationContext.new + let child1 ← root.fork + let child2 ← root.fork + let grandchild ← child1.fork + + -- Cancel root with shutdown reason + root.cancel .shutdown + + -- All should be cancelled + assert! (← root.isCancelled) + assert! (← child1.isCancelled) + assert! (← child2.isCancelled) + assert! (← grandchild.isCancelled) + + -- All should have the shutdown reason (propagated from root) + assert! (← root.token.getCancellationReason) == some .shutdown + assert! (← child1.token.getCancellationReason) == some .shutdown + assert! (← child2.token.getCancellationReason) == some .shutdown + assert! (← grandchild.token.getCancellationReason) == some .shutdown + +#eval testContextTreeReasons.block + +-- Test child cancellation doesn't affect parent +def testChildCancellationIndependent : Async Unit := do + let root ← Std.CancellationContext.new + let child ← root.fork + + -- Cancel child with deadline + child.cancel .deadline + + -- Child should be cancelled with deadline reason + assert! (← child.isCancelled) + assert! (← child.token.getCancellationReason) == some .deadline + + -- Parent should still be active + assert! not (← root.isCancelled) + assert! (← root.token.getCancellationReason) == none + +#eval testChildCancellationIndependent.block + +-- Test selector with reason +def testSelectorWithReason : Async Unit := do + let token ← Std.CancellationToken.new + let completed ← Std.Mutex.new false + let reasonRef ← Std.Mutex.new none + + let task ← async do + Selectable.one #[.case token.selector (fun _ => pure ())] + completed.atomically (set true) + reasonRef.atomically (set (← token.getCancellationReason)) + + assert! not (← completed.atomically get) + + token.cancel .deadline + await task + + assert! (← completed.atomically get) + assert! (← reasonRef.atomically get) == some Std.CancellationReason.deadline + +#eval testSelectorWithReason.block + +-- Test wait with reason +def testWaitWithReason : Async Unit := do + let token ← Std.CancellationToken.new + + let task ← async do + let _ ← await (← token.wait) + token.getCancellationReason + + Async.sleep 10 + token.cancel (.custom "test reason") + + let reason ← await task + assert! reason == some (.custom "test reason") + +#eval testWaitWithReason.block + +-- Test multiple cancellations (first one wins) +def testMultipleCancellations : Async Unit := do + let token ← Std.CancellationToken.new + + token.cancel .deadline + token.cancel .shutdown -- This should be ignored + + let reason ← token.getCancellationReason + assert! reason == some .deadline -- First reason should persist + +#eval testMultipleCancellations.block diff --git a/tests/lean/run/cancellation_context.lean b/tests/lean/run/cancellation_context.lean new file mode 100644 index 0000000000..73562d42d8 --- /dev/null +++ b/tests/lean/run/cancellation_context.lean @@ -0,0 +1,251 @@ +import Std.Internal.Async +import Std.Sync + +open Std.Internal.IO Async + +/-- Test basic tree cancellation -/ +partial def testCancelTree : IO Unit := do + let mutex ← Std.Mutex.new 0 + let context ← Std.CancellationContext.new + + Async.block do + + let rec loop (x : Nat) (parent : Std.CancellationContext) : Async Unit := do + match x with + | 0 => do + await (← parent.done) + mutex.atomically (modify (· + 1)) + + | n + 1 => do + background (loop n (← parent.fork)) + background (loop n (← parent.fork)) + await (← parent.done) + mutex.atomically (modify (· + 1)) + + background (loop 3 context) + Async.sleep 500 + context.cancel .cancel + Async.sleep 1000 + + assert! (← context.countAliveTokens) == 0 + let size ← mutex.atomically get + IO.println s!"cancelled {size}" + +/-- +info: cancelled 15 +-/ +#guard_msgs in +#eval testCancelTree + +/-- Test cancellation with different reasons -/ +def testCancellationReasons : IO Unit := do + let ctx ← Std.CancellationContext.new + let (reason1, reason2, reason3, reason4) ← Async.block do + -- Test with .cancel reason + let ctx1 ← ctx.fork + ctx1.cancel .cancel + let some reason1 ← ctx1.getCancellationReason | return (none, none, none, none) + + -- Test with .deadline reason + let ctx2 ← ctx.fork + ctx2.cancel .deadline + let some reason2 ← ctx2.getCancellationReason | return (none, none, none, none) + + -- Test with .shutdown reason + let ctx3 ← ctx.fork + ctx3.cancel .shutdown + let some reason3 ← ctx3.getCancellationReason | return (none, none, none, none) + + -- Test with custom reason + let ctx4 ← ctx.fork + ctx4.cancel (.custom "test error") + let some reason4 ← ctx4.getCancellationReason | return (none, none, none, none) + + return (some reason1, some reason2, some reason3, some reason4) + + if let some r1 := reason1 then IO.println s!"Reason 1: {r1}" + if let some r2 := reason2 then IO.println s!"Reason 2: {r2}" + if let some r3 := reason3 then IO.println s!"Reason 3: {r3}" + if let some r4 := reason4 then IO.println s!"Reason 4: {r4}" + + assert! (← ctx.countAliveTokens) == 1 + +/-- +info: Reason 1: cancel +Reason 2: deadline +Reason 3: shutdown +Reason 4: custom("test error") +-/ +#guard_msgs in +#eval testCancellationReasons + +/-- Test cancellation propagates reason to children -/ +def testReasonPropagation : IO Unit := do + let (parentReason, child1Reason, child2Reason, grandchildReason) ← Async.block do + let parent ← Std.CancellationContext.new + let child1 ← parent.fork + let child2 ← parent.fork + let grandchild ← child1.fork + + parent.cancel (.custom "parent cancelled") + Async.sleep 100 + + let some parentReason ← parent.getCancellationReason | return (none, none, none, none) + let some child1Reason ← child1.getCancellationReason | return (none, none, none, none) + let some child2Reason ← child2.getCancellationReason | return (none, none, none, none) + let some grandchildReason ← grandchild.getCancellationReason | return (none, none, none, none) + + return (some parentReason, some child1Reason, some child2Reason, some grandchildReason) + + if let some r := parentReason then IO.println s!"Parent: {r}" + if let some r := child1Reason then IO.println s!"Child1: {r}" + if let some r := child2Reason then IO.println s!"Child2: {r}" + if let some r := grandchildReason then IO.println s!"Grandchild: {r}" + +/-- +info: Parent: custom("parent cancelled") +Child1: custom("parent cancelled") +Child2: custom("parent cancelled") +Grandchild: custom("parent cancelled") +-/ +#guard_msgs in +#eval testReasonPropagation + +/-- Test cancellation in the middle of work -/ +def testCancelInMiddle : IO Unit := do + let counter ← Std.Mutex.new 0 + let cancelledCounter ← Std.Mutex.new 0 + + let (finalCount, cancelledCount) ← Async.block do + let context ← Std.CancellationContext.new + + -- Worker that does work until cancelled + let worker (ctx : Std.CancellationContext) : Async Unit := do + for _ in [0:100] do + if ← ctx.isCancelled then + cancelledCounter.atomically (modify (· + 1)) + break + counter.atomically (modify (· + 1)) + Async.sleep 10 + + -- Start 5 workers + for _ in [0:5] do + background (worker context) + + -- Let them run for a bit, then cancel + Async.sleep 200 + context.cancel .deadline + + -- Wait for them to finish + Async.sleep 500 + + let finalCount ← counter.atomically get + let cancelledCount ← cancelledCounter.atomically get + return (finalCount, cancelledCount) + + IO.println s!"Completed {finalCount} iterations before cancellation" + IO.println s!"{cancelledCount} workers detected cancellation" + +/-- Test cancellation before forking -/ +def testCancelBeforeFork : IO Unit := do + let (isSame, isChildCancelled) ← Async.block do + let ctx ← Std.CancellationContext.new + ctx.cancel .cancel + + -- Fork after cancellation should return same context + let child ← ctx.fork + let isSame := ctx.id == child.id + let isChildCancelled ← child.isCancelled + + return (isSame, isChildCancelled) + + IO.println s!"Same context: {isSame}, Child cancelled: {isChildCancelled}" + +/-- +info: Same context: true, Child cancelled: true +-/ +#guard_msgs in +#eval testCancelBeforeFork + +/-- Test deep tree cancellation with reason -/ +partial def testDeepTreeCancellation : IO Unit := do + let depths ← Std.Mutex.new ([] : List (Nat × Std.CancellationReason)) + + let (count, allSameReason) ← Async.block do + let root ← Std.CancellationContext.new + + let rec makeTree (depth : Nat) (ctx : Std.CancellationContext) : Async Unit := do + if depth == 0 then + await (← ctx.done) + if let some reason ← ctx.getCancellationReason then + depths.atomically (modify (·.cons (depth, reason))) + else + let child1 ← ctx.fork + let child2 ← ctx.fork + background (makeTree (depth - 1) child1) + background (makeTree (depth - 1) child2) + await (← ctx.done) + if let some reason ← ctx.getCancellationReason then + depths.atomically (modify (·.cons (depth, reason))) + + background (makeTree 4 root) + Async.sleep 200 + root.cancel (.custom "deep tree cancel") + Async.sleep 500 + + let results ← depths.atomically get + let count := results.length + let allSameReason := results.all fun (_, r) => r == .custom "deep tree cancel" + return (count, allSameReason) + + IO.println s!"Cancelled {count} nodes, all with same reason: {allSameReason}" + +/-- +info: Cancelled 31 nodes, all with same reason: true +-/ +#guard_msgs in +#eval testDeepTreeCancellation + +/-- Test counting alive tokens -/ +def testCountAliveTokens : IO Unit := do + let (count0, count1, count2, count3, count4) ← Async.block do + let root ← Std.CancellationContext.new + let count0 ← root.countAliveTokens -- Root only + + -- Fork 3 children + let child1 ← root.fork + let child2 ← root.fork + let _child3 ← root.fork + let count1 ← root.countAliveTokens -- Root + 3 children = 4 + + -- Cancel one child (and its subtree) + child1.cancel .cancel + Async.sleep 100 + let count2 ← root.countAliveTokens -- Root + 2 children = 3 + + -- Fork a grandchild from child2 + let _grandchild ← child2.fork + let count3 ← root.countAliveTokens -- Root + 2 children + 1 grandchild = 4 + + -- Cancel root (should cancel everything) + root.cancel .cancel + Async.sleep 100 + let count4 ← root.countAliveTokens -- All cancelled = 0 + + return (count0, count1, count2, count3, count4) + + IO.println s!"Initial (root only): {count0}" + IO.println s!"After forking 3 children: {count1}" + IO.println s!"After cancelling 1 child: {count2}" + IO.println s!"After forking grandchild: {count3}" + IO.println s!"After cancelling root: {count4}" + +/-- +info: Initial (root only): 1 +After forking 3 children: 4 +After cancelling 1 child: 3 +After forking grandchild: 4 +After cancelling root: 0 +-/ +#guard_msgs in +#eval testCountAliveTokens diff --git a/tests/lean/run/context_async.lean b/tests/lean/run/context_async.lean new file mode 100644 index 0000000000..ffabcae20f --- /dev/null +++ b/tests/lean/run/context_async.lean @@ -0,0 +1,698 @@ +import Std.Internal.Async +import Std.Sync + +open Std.Internal.IO Async + +/-- Test ContextAsync cancellation check -/ +def testIsCancelled : IO Unit := do + let (before, after) ← Async.block do + ContextAsync.run do + let before ← ContextAsync.isCancelled + ContextAsync.cancel .cancel + Async.sleep 50 + let after ← ContextAsync.isCancelled + return (before, after) + + IO.println s!"Before: {before}, After: {after}" + +/-- +info: Before: false, After: true +-/ +#guard_msgs in +#eval testIsCancelled + +/-- Test ContextAsync cancellation reason -/ +def testGetCancellationReason : IO Unit := do + let res ← Async.block do + ContextAsync.run do + ContextAsync.cancel (.custom "test reason") + Async.sleep 50 + let some reason ← ContextAsync.getCancellationReason + | return "ERROR: No reason found" + return s!"Reason: {reason}" + + IO.println res + +/-- +info: Reason: custom("test reason") +-/ +#guard_msgs in +#eval testGetCancellationReason + +/-- Test awaitCancellation -/ +def testAwaitCancellation : IO Unit := do + let received ← Std.Mutex.new false + + Async.block do + let started ← Std.Mutex.new false + + ContextAsync.run do + discard <| ContextAsync.concurrently + (do + started.atomically (set true) + ContextAsync.awaitCancellation + received.atomically (set true)) + (do + -- Wait for task to start + while !(← started.atomically get) do + Async.sleep 10 + + Async.sleep 100 + ContextAsync.cancel .shutdown) + + Async.sleep 200 + + let _ ← received.atomically get + IO.println "Cancellation received" + + +def testSelectorCancellationFail : IO Unit := do + let received ← Std.Mutex.new false + + let result ← Async.block do + let ctx ← Std.CancellationContext.new + let started ← Std.Mutex.new false + + + let result ← do + try + ContextAsync.runIn ctx do + discard <| ContextAsync.concurrently + (do + started.atomically (set true) + let res ← Selectable.one #[ + .case (← ContextAsync.doneSelector) (fun _ => pure true), + .case (← Selector.sleep 2000) (fun _ => pure false) + ] + received.atomically (set res)) + (do + throw (.userError "failed") + return ()) + return Except.ok () + catch err => + return Except.error err + + Async.sleep 500 + + return result + + let _ ← received.atomically get + IO.println "Cancellation received" + + if let Except.error err := result then + throw err + +/-- +info: Cancellation received +--- +error: failed +-/ +#guard_msgs in +#eval testSelectorCancellationFail + +/-- Test concurrently with both tasks succeeding -/ +def testConcurrently : IO Unit := do + let (a, b) ← Async.block do + ContextAsync.run do + ContextAsync.concurrently + (do + Async.sleep 100 + return 42) + (do + Async.sleep 150 + return "hello") + + IO.println s!"Results: {a}, {b}" + +/-- +info: Results: 42, hello +-/ +#guard_msgs in +#eval testConcurrently + +/-- Test race with first task winning -/ +def testRace : IO Unit := do + let result ← Async.block do + ContextAsync.run do + ContextAsync.race + (do + Async.sleep 50 + return "fast") + (do + Async.sleep 200 + return "slow") + + IO.println s!"Winner: {result}" + +/-- +info: Winner: fast +-/ +#guard_msgs in +#eval testRace + +/-- Test concurrentlyAll -/ +def testConcurrentlyAll : IO Unit := do + let results ← Async.block do + ContextAsync.run do + let tasks := #[ + (do Async.sleep 50; return 1), + (do Async.sleep 100; return 2), + (do Async.sleep 75; return 3) + ] + ContextAsync.concurrentlyAll tasks + + IO.println s!"All results: {results}" + +/-- +info: All results: #[1, 2, 3] +-/ +#guard_msgs in +#eval testConcurrentlyAll + +/-- Test background task with cancellation -/ +def testBackground : IO Unit := do + let counter ← Std.Mutex.new 0 + + Async.block do + ContextAsync.run do + discard <| ContextAsync.concurrently + (do + for _ in [0:10] do + if ← ContextAsync.isCancelled then + break + counter.atomically (modify (· + 1)) + Async.sleep 50) + (do + -- Let it run for a bit + Async.sleep 150 + ContextAsync.cancel .cancel) + + Async.sleep 200 + + let final ← counter.atomically get + IO.println s!"Counter reached: {final}" + +/-- Test fork cancellation isolation -/ +def testForkCancellation : IO Unit := do + let parent ← Std.CancellationContext.new + let childCancelled ← Std.Mutex.new false + let parentCancelled ← Std.Mutex.new false + + Async.block do + ContextAsync.runIn parent do + discard <| ContextAsync.concurrentlyAll #[ + (do + let child ← ContextAsync.getContext + Async.sleep 100 + child.cancel .cancel + childCancelled.atomically (set true)), + (do + Async.sleep 200 + if ← parent.isCancelled then + parentCancelled.atomically (set true)) + ] + + let childWasCancelled ← childCancelled.atomically get + let parentWasCancelled ← parentCancelled.atomically get + + IO.println s!"Child cancelled: {childWasCancelled}, Parent cancelled: {parentWasCancelled}" + +/-- +info: Child cancelled: true, Parent cancelled: false +-/ +#guard_msgs in +#eval testForkCancellation + +/-- Test doneSelector -/ +partial def testNestedFork : IO Unit := do + let res ← Async.block do + ContextAsync.run do + let ctx ← ContextAsync.getContext + let sel ← ContextAsync.doneSelector + + let (_, result) ← ContextAsync.concurrently + (do + Async.sleep 100 + ctx.cancel .deadline) + (Selectable.one #[.case sel (fun _ => pure true)]) + + return result + + IO.println s!"Done selector triggered: {res}" + +/-- +info: Done selector triggered: true +-/ +#guard_msgs in +#eval testNestedFork + +/-- Test Selector.cancelled -/ +def testSelectorCancelled : IO Unit := do + let res ← Async.block do + ContextAsync.run do + let ctx ← ContextAsync.getContext + let sel ← Selector.cancelled + + let (_, result) ← ContextAsync.concurrently + (do + Async.sleep 150 + ctx.cancel .shutdown) + (Selectable.one #[.case sel (fun _ => pure true)]) + + return result + + IO.println s!"Selector.cancelled triggered: {res}" + +/-- +info: Selector.cancelled triggered: true +-/ +#guard_msgs in +#eval testSelectorCancelled + +/-- Test MonadLift instances -/ +def testMonadLift : IO Unit := do + let (msg1, msg2) ← Async.block do + ContextAsync.run do + -- Lift from IO + let msg1 : String := "From IO" + + -- Lift from BaseIO + let msg2 : String := "From BaseIO" + + -- Lift from Async + let _ ← (Async.sleep 50 : Async Unit) + + return (msg1, msg2) + + IO.println msg1 + IO.println msg2 + IO.println "All lifts work" + +/-- +info: From IO +From BaseIO +All lifts work +-/ +#guard_msgs in +#eval testMonadLift + +/-- Test exception handling in ContextAsync -/ +def testExceptionHandling : IO Unit := do + let res ← Async.block do + ContextAsync.run do + try + throw (IO.userError "test error") + return "Should not reach here" + catch e => + return s!"Caught: {e}" + + IO.println res + +/-- +info: Caught: test error +-/ +#guard_msgs in +#eval testExceptionHandling + +/-- Test tryFinally in ContextAsync -/ +def testTryFinally : IO Unit := do + let cleaned ← Std.Mutex.new false + + Async.block do + ContextAsync.run do + try + ContextAsync.cancel .cancel + ContextAsync.awaitCancellation + finally + cleaned.atomically (set true) + + let wasCleanedUp ← cleaned.atomically get + IO.println s!"Cleanup ran: {wasCleanedUp}" + +/-- +info: Cleanup ran: true +-/ +#guard_msgs in +#eval testTryFinally + +/-- Test race with cancellation -/ +def testRaceWithCancellation : IO Unit := do + let ctx ← Std.CancellationContext.new + let leftCancelled ← Std.Mutex.new false + let rightCancelled ← Std.Mutex.new false + + Async.block do + ContextAsync.runIn ctx do + let _ ← ContextAsync.race + (do + try + Async.sleep 500 + return "left" + finally + if ← ContextAsync.isCancelled then + leftCancelled.atomically (set true)) + (do + try + Async.sleep 50 + return "right" + finally + if ← ContextAsync.isCancelled then + rightCancelled.atomically (set true)) + + Async.sleep 1000 + + let left ← leftCancelled.atomically get + let right ← rightCancelled.atomically get + IO.println s!"Left cancelled: {left}, Right cancelled: {right}" + +/-- +info: Left cancelled: true, Right cancelled: false +-/ +#guard_msgs in +#eval testRaceWithCancellation + +/-- Test complex concurrent workflow -/ +def testComplexWorkflow : IO Unit := do + let results ← Std.Mutex.new ([] : List String) + + Async.block do + ContextAsync.run do + -- Run multiple concurrent operations + let (a, b) ← ContextAsync.concurrently + (do + Async.sleep 50 + results.atomically (modify ("A"::·)) + return 1) + (do + Async.sleep 75 + results.atomically (modify ("B"::·)) + return 2) + + -- Additional concurrent task + discard <| ContextAsync.concurrently + (do + Async.sleep 100 + results.atomically (modify ("BG"::·))) + (do + Async.sleep 200 + results.atomically (modify (s!"Sum:{a+b}"::·))) + + let final ← results.atomically get + IO.println s!"Results: {final.reverse}" + +/-- +info: Results: [A, B, BG, Sum:3] +-/ +#guard_msgs in +#eval testComplexWorkflow + +def testConcurrentlyAllException : IO Unit := do + let ref ← IO.mkRef "" + + try + Async.block do + ContextAsync.run do + let tasks := #[ + (do + Async.sleep 1000 + if ← ContextAsync.isCancelled then + ref.set "cancelled" + return + else + ref.set "not cancelled" + Async.sleep 500 + if ← ContextAsync.isCancelled then + ref.modify (· ++ ", cancelled") + else + ref.modify (· ++ ", not cancelled")), + (do + Async.sleep 250 + throw (IO.userError "Error: Hello")) + ] + discard <| ContextAsync.concurrentlyAll tasks + finally + IO.println (← ref.get) + +/-- +info: cancelled +--- +error: Error: Hello +-/ +#guard_msgs in +#eval testConcurrentlyAllException + +/-- Test that tasks in ContextAsync.run are not cancelled when run completes -/ +def test0 : IO Unit := do + let ref ← IO.mkRef false + + Async.block do + ContextAsync.run do + Async.sleep 100 + if ← ContextAsync.isCancelled then + ref.set true + + IO.sleep 200 + IO.println s!"{← ref.get}" + +/-- +info: false +-/ +#guard_msgs in +#eval test0 + +/-- Test that background tasks are cancelled when ContextAsync.run completes -/ +def test1 : IO Unit := do + let ref ← IO.mkRef false + + Async.block do + ContextAsync.run do + ContextAsync.background do + Async.sleep 100 + if ← ContextAsync.isCancelled then + ref.set true + + IO.sleep 200 + IO.println s!"{← ref.get}" + +/-- +info: true +-/ +#guard_msgs in +#eval test1 + +/-- Test that nested background tasks (ContextAsync.background in ContextAsync.background) are cancelled -/ +def test2 : IO Unit := do + let ref ← IO.mkRef false + + Async.block do + ContextAsync.run do + ContextAsync.background do + ContextAsync.background do + Async.sleep 100 + if ← ContextAsync.isCancelled then + ref.set true + + IO.sleep 200 + IO.println s!"{← ref.get}" + +/-- +info: true +-/ +#guard_msgs in +#eval test2 + +/-- Test that ContextAsync.background in Async.background is cancelled -/ +def test2' : IO Unit := do + let ref ← IO.mkRef false + + Async.block do + ContextAsync.run do + Async.background do + ContextAsync.background do + Async.sleep 100 + if ← ContextAsync.isCancelled then + ref.set true + + IO.sleep 200 + IO.println s!"{← ref.get}" + +/-- +info: true +-/ +#guard_msgs in +#eval test2' + +/-- Test that Async.background in ContextAsync.background is cancelled -/ +def test2'' : IO Unit := do + let ref ← IO.mkRef false + + Async.block do + ContextAsync.run do + ContextAsync.background do + Async.background do + Async.sleep 100 + if ← ContextAsync.isCancelled then + ref.set true + + IO.sleep 200 + IO.println s!"{← ref.get}" + +/-- +info: true +-/ +#guard_msgs in +#eval test2'' + +/-- Test concurrently with first task succeeding immediately, others checking cancellation -/ +def testConcurrentlySuccessWithCancellation : IO Unit := do + let task2Cancelled ← Std.Mutex.new false + let task3Cancelled ← Std.Mutex.new false + + let results ← Async.block do + ContextAsync.run do + ContextAsync.concurrentlyAll #[ + (do + return "first"), + (do + -- Second task waits and checks for cancellation + let res ← Selectable.one #[ + .case (← ContextAsync.doneSelector) (fun _ => pure true), + .case (← Selector.sleep 500) (fun _ => pure false) + ] + + task2Cancelled.atomically (set (res)) + return "second"), + (do + let res ← Selectable.one #[ + .case (← ContextAsync.doneSelector) (fun _ => pure true), + .case (← Selector.sleep 500) (fun _ => pure false) + ] + + task3Cancelled.atomically (set (res)) + return "third") + ] + + let t2 ← task2Cancelled.atomically get + let t3 ← task3Cancelled.atomically get + IO.println s!"Results: {results}" + IO.println s!"Task2 cancelled: {t2}, Task3 cancelled: {t3}" + +/-- +info: Results: #[first, second, third] +Task2 cancelled: false, Task3 cancelled: false +-/ +#guard_msgs in +#eval testConcurrentlySuccessWithCancellation + +/-- Test concurrently with first task failing, others checking for cancellation -/ +def testConcurrentlyFailWithCancellation : IO Unit := do + let task2Cancelled ← Std.Mutex.new false + let task3Cancelled ← Std.Mutex.new false + + let results ← Async.block do + ContextAsync.run do + try + let result ← ContextAsync.concurrentlyAll #[ + (do + -- First task fails immediately + throw (IO.userError "first task failed")), + (do + -- Second task waits and checks for cancellation + let res ← Selectable.one #[ + .case (← ContextAsync.doneSelector) (fun _ => pure true), + .case (← Selector.sleep 2000) (fun _ => pure false) + ] + + task2Cancelled.atomically (set (res)) + return "second"), + (do + let res ← Selectable.one #[ + .case (← ContextAsync.doneSelector) (fun _ => pure true), + .case (← Selector.sleep 2000) (fun _ => pure false) + ] + + task3Cancelled.atomically (set (res)) + return "third") + ] + return Except.ok result + catch e => + Async.sleep 500 + return Except.error e + + let t2 ← task2Cancelled.atomically get + let t3 ← task3Cancelled.atomically get + + match results with + | .ok results => IO.println s!"Results: {results}" + | .error e => IO.println s!"Error: {e}" + + IO.println s!"Task2 cancelled: {t2}, Task3 cancelled: {t3}" + +/-- +info: Error: first task failed +Task2 cancelled: true, Task3 cancelled: true +-/ +#guard_msgs in +#eval testConcurrentlyFailWithCancellation + +/-- Test concurrently with both tasks succeeding, checking cancellation status -/ +def testConcurrentlySuccessWithCancellation2Tasks : IO Unit := do + let task2Cancelled ← Std.Mutex.new false + + let (r1, r2) ← Async.block do + ContextAsync.run do + ContextAsync.concurrently + (do return "first") + (do + -- Second task waits and checks for cancellation + let res ← Selectable.one #[ + .case (← ContextAsync.doneSelector) (fun _ => pure true), + .case (← Selector.sleep 500) (fun _ => pure false) + ] + + task2Cancelled.atomically (set res) + return "second") + + let t2 ← task2Cancelled.atomically get + IO.println s!"Results: {r1}, {r2}" + IO.println s!"Task2 cancelled: {t2}" + +/-- +info: Results: first, second +Task2 cancelled: false +-/ +#guard_msgs in +#eval testConcurrentlySuccessWithCancellation2Tasks + +/-- Test concurrently with first task failing, second task checking for cancellation -/ +def testConcurrentlyFailWithCancellation2Tasks : IO Unit := do + let task2Cancelled ← Std.Mutex.new false + + try + Async.block do + ContextAsync.run do + let (_ : (String × String)) ← ContextAsync.concurrently + (do + -- First task fails immediately + throw (IO.userError "first task failed") : ContextAsync String) + (do + -- Second task waits and checks for cancellation + let res ← Selectable.one #[ + .case (← ContextAsync.doneSelector) (fun _ => pure true), + .case (← Selector.sleep 2000) (fun _ => pure false) + ] + + task2Cancelled.atomically (set res) + return "second") + catch e => + IO.sleep 500 + let t2 ← task2Cancelled.atomically get + IO.println s!"Error: {e}" + IO.println s!"Task2 cancelled: {t2}" + +/-- +info: Error: first task failed +Task2 cancelled: true +-/ +#guard_msgs in +#eval testConcurrentlyFailWithCancellation2Tasks