diff --git a/src/Std/Async/ContextAsync.lean b/src/Std/Async/ContextAsync.lean index ad341212c9..7c528e6438 100644 --- a/src/Std/Async/ContextAsync.lean +++ b/src/Std/Async/ContextAsync.lean @@ -119,27 +119,6 @@ def concurrently (x : ContextAsync α) (y : ContextAsync β) concurrentCtx.cancel .cancel return result -/-- -Runs two computations concurrently and returns the result of the first to complete. Each computation runs -in its own child context; when either completes, the other is cancelled immediately. --/ -@[inline, specialize] -def race [Inhabited α] (x : ContextAsync α) (y : ContextAsync α) - (prio := Task.Priority.default) : ContextAsync α := do - let parent ← getContext - let ctx1 ← CancellationContext.fork parent - let ctx2 ← CancellationContext.fork parent - - let task1 ← async (x ctx1) prio - let task2 ← async (y ctx2) prio - - let result ← Async.race - (await task1 <* ctx2.cancel .cancel) - (await task2 <* ctx1.cancel .cancel) - prio - - pure result - /-- Runs all computations concurrently and collects results in the same order. Each runs in its own child context; if any computation fails, all others are cancelled and the exception is propagated. @@ -254,6 +233,27 @@ instance [Inhabited α] : Inhabited (ContextAsync α) where instance : MonadAwait AsyncTask ContextAsync where await t := fun _ => await t +/-- +Runs two computations concurrently and returns the result of the first to complete. Each computation runs +in its own child context; when either completes, the other is cancelled immediately. +-/ +@[inline, specialize] +def race [Inhabited α] (x : ContextAsync α) (y : ContextAsync α) + (prio := Task.Priority.default) : ContextAsync α := do + let parent ← getContext + let ctx1 ← CancellationContext.fork parent + let ctx2 ← CancellationContext.fork parent + + let task1 ← async (x ctx1) prio + let task2 ← async (y ctx2) prio + + let promise ← IO.Promise.new + BaseIO.chainTask task1 fun result => liftM (promise.resolve result) *> ctx2.cancel .cancel + BaseIO.chainTask task2 fun result => liftM (promise.resolve result) *> ctx1.cancel .cancel + + let result ← MonadAwait.await promise + Async.ofExcept result + end ContextAsync /-- diff --git a/tests/elab/context_async.lean b/tests/elab/context_async.lean index cbd6dfb31d..37dfb7a50a 100644 --- a/tests/elab/context_async.lean +++ b/tests/elab/context_async.lean @@ -9,7 +9,6 @@ def testIsCancelled : IO Unit := do ContextAsync.run do let before ← ContextAsync.isCancelled ContextAsync.cancel .cancel - Async.sleep 50 let after ← ContextAsync.isCancelled return (before, after) @@ -26,7 +25,6 @@ def testGetCancellationReason : IO Unit := do let res ← Async.block do ContextAsync.run do ContextAsync.cancel (.custom "test reason") - Async.sleep 50 let some reason ← ContextAsync.getCancellationReason | return "ERROR: No reason found" return s!"Reason: {reason}" @@ -92,8 +90,6 @@ def testSelectorCancellationFail : IO Unit := do catch err => return Except.error err - Async.sleep 500 - return result let _ ← received.atomically get @@ -115,12 +111,8 @@ def testConcurrently : IO Unit := do let (a, b) ← Async.block do ContextAsync.run do ContextAsync.concurrently - (do - Async.sleep 100 - return 42) - (do - Async.sleep 150 - return "hello") + (do return 42) + (do return "hello") IO.println s!"Results: {a}, {b}" @@ -135,12 +127,8 @@ def testRace : IO Unit := do let result ← Async.block do ContextAsync.run do ContextAsync.race - (do - Async.sleep 50 - return "fast") - (do - Async.sleep 200 - return "slow") + (do ContextAsync.awaitCancellation; return "slow") + (do return "fast") IO.println s!"Winner: {result}" @@ -155,9 +143,9 @@ def testConcurrentlyAll : IO Unit := do let results ← Async.block do ContextAsync.run do let tasks := #[ - (do Async.sleep 50; return 1), - (do Async.sleep 100; return 2), - (do Async.sleep 75; return 3) + (do return 1), + (do return 2), + (do return 3) ] ContextAsync.concurrentlyAll tasks @@ -203,11 +191,9 @@ def testForkCancellation : IO Unit := do discard <| ContextAsync.concurrentlyAll #[ (do let child ← ContextAsync.getContext - Async.sleep 100 child.cancel .cancel childCancelled.atomically (set true)), (do - Async.sleep 200 if ← parent.isCancelled then parentCancelled.atomically (set true)) ] @@ -231,9 +217,7 @@ partial def testNestedFork : IO Unit := do let sel ← ContextAsync.doneSelector let (_, result) ← ContextAsync.concurrently - (do - Async.sleep 100 - ctx.cancel .deadline) + (do ctx.cancel .deadline) (Selectable.one #[.case sel (fun _ => pure true)]) return result @@ -254,9 +238,7 @@ def testSelectorCancelled : IO Unit := do let sel ← Selector.cancelled let (_, result) ← ContextAsync.concurrently - (do - Async.sleep 150 - ctx.cancel .shutdown) + (do ctx.cancel .shutdown) (Selectable.one #[.case sel (fun _ => pure true)]) return result @@ -280,7 +262,7 @@ def testMonadLift : IO Unit := do let msg2 : String := "From BaseIO" -- Lift from Async - let _ ← (Async.sleep 50 : Async Unit) + let _ ← (pure () : Async Unit) return (msg1, msg2) @@ -342,24 +324,24 @@ def testRaceWithCancellation : IO Unit := do let rightCancelled ← Std.Mutex.new false Async.block do + let leftDone ← Std.Semaphore.new 0 ContextAsync.runIn ctx do let _ ← ContextAsync.race (do try - Async.sleep 500 + ContextAsync.awaitCancellation return "left" finally if ← ContextAsync.isCancelled then - leftCancelled.atomically (set true)) + leftCancelled.atomically (set true) + leftDone.release) (do try - Async.sleep 50 return "right" finally if ← ContextAsync.isCancelled then rightCancelled.atomically (set true)) - - Async.sleep 1000 + discard <| MonadAwait.await (← leftDone.acquire).result! let left ← leftCancelled.atomically get let right ← rightCancelled.atomically get @@ -377,28 +359,20 @@ def testComplexWorkflow : IO Unit := do Async.block do ContextAsync.run do - -- Run multiple concurrent operations let (a, b) ← ContextAsync.concurrently (do - Async.sleep 50 results.atomically (modify ("A"::·)) return 1) (do - Async.sleep 75 results.atomically (modify ("B"::·)) return 2) - -- Additional concurrent task discard <| ContextAsync.concurrently - (do - Async.sleep 100 - results.atomically (modify ("BG"::·))) - (do - Async.sleep 200 - results.atomically (modify (s!"Sum:{a+b}"::·))) + (do results.atomically (modify ("BG"::·))) + (do results.atomically (modify (s!"Sum:{a+b}"::·))) let final ← results.atomically get - IO.println s!"Results: {final.reverse}" + IO.println s!"Results: {final.mergeSort}" /-- info: Results: [A, B, BG, Sum:3] @@ -447,11 +421,9 @@ def test0 : IO Unit := do Async.block do ContextAsync.run do - Async.sleep 100 if ← ContextAsync.isCancelled then ref.set true - IO.sleep 200 IO.println s!"{← ref.get}" /-- @@ -465,13 +437,14 @@ def test1 : IO Unit := do let ref ← IO.mkRef false Async.block do + let done ← Std.Semaphore.new 0 ContextAsync.run do ContextAsync.background do - Async.sleep 100 - if ← ContextAsync.isCancelled then - ref.set true + ContextAsync.awaitCancellation + ref.set true + done.release + discard <| MonadAwait.await (← done.acquire).result! - IO.sleep 200 IO.println s!"{← ref.get}" /-- @@ -485,14 +458,15 @@ def test2 : IO Unit := do let ref ← IO.mkRef false Async.block do + let done ← Std.Semaphore.new 0 ContextAsync.run do ContextAsync.background do ContextAsync.background do - Async.sleep 100 - if ← ContextAsync.isCancelled then - ref.set true + ContextAsync.awaitCancellation + ref.set true + done.release + discard <| MonadAwait.await (← done.acquire).result! - IO.sleep 200 IO.println s!"{← ref.get}" /-- @@ -506,14 +480,16 @@ def test2' : IO Unit := do let ref ← IO.mkRef false Async.block do + let done ← Std.Semaphore.new 0 ContextAsync.run do Async.background do ContextAsync.background do - Async.sleep 100 - if ← ContextAsync.isCancelled then - ref.set true + ContextAsync.awaitCancellation + ref.set true + done.release + + discard <| MonadAwait.await (← done.acquire).result! - IO.sleep 200 IO.println s!"{← ref.get}" /-- @@ -522,26 +498,6 @@ info: true #guard_msgs in #eval test2' -/-- Test that Async.background in ContextAsync.background is cancelled -/ -def test2'' : IO Unit := do - let ref ← IO.mkRef false - - Async.block do - ContextAsync.run do - ContextAsync.background do - Async.background do - Async.sleep 100 - if ← ContextAsync.isCancelled then - ref.set true - - IO.sleep 200 - IO.println s!"{← ref.get}" - -/-- -info: true --/ -#guard_msgs in -#eval test2'' /-- Test concurrently with first task succeeding immediately, others checking cancellation -/ def testConcurrentlySuccessWithCancellation : IO Unit := do @@ -590,34 +546,38 @@ def testConcurrentlyFailWithCancellation : IO Unit := do let task3Cancelled ← Std.Mutex.new false let results ← Async.block do - ContextAsync.run do - try - let result ← ContextAsync.concurrentlyAll #[ - (do - -- First task fails immediately - throw (IO.userError "first task failed")), - (do - -- Second task waits and checks for cancellation - let res ← Selectable.one #[ - .case (← ContextAsync.doneSelector) (fun _ => pure true), - .case (← Selector.sleep 2000) (fun _ => pure false) - ] - - task2Cancelled.atomically (set (res)) - return "second"), - (do - let res ← Selectable.one #[ - .case (← ContextAsync.doneSelector) (fun _ => pure true), - .case (← Selector.sleep 2000) (fun _ => pure false) - ] - - task3Cancelled.atomically (set (res)) - return "third") - ] - return Except.ok result - catch e => - Async.sleep 500 - return Except.error e + let task2Done ← Std.Semaphore.new 0 + let task3Done ← Std.Semaphore.new 0 + let result ← ContextAsync.run do + try + let result ← ContextAsync.concurrentlyAll #[ + (do + -- First task fails immediately + throw (IO.userError "first task failed")), + (do + -- Second task waits and checks for cancellation + let res ← Selectable.one #[ + .case (← ContextAsync.doneSelector) (fun _ => pure true), + .case (← Selector.sleep 2000) (fun _ => pure false) + ] + task2Cancelled.atomically (set (res)) + task2Done.release + return "second"), + (do + let res ← Selectable.one #[ + .case (← ContextAsync.doneSelector) (fun _ => pure true), + .case (← Selector.sleep 2000) (fun _ => pure false) + ] + task3Cancelled.atomically (set (res)) + task3Done.release + return "third") + ] + return Except.ok result + catch e => + return Except.error e + discard <| MonadAwait.await (← task2Done.acquire).result! + discard <| MonadAwait.await (← task3Done.acquire).result! + return result let t2 ← task2Cancelled.atomically get let t3 ← task3Cancelled.atomically get @@ -668,24 +628,30 @@ Task2 cancelled: false def testConcurrentlyFailWithCancellation2Tasks : IO Unit := do let task2Cancelled ← Std.Mutex.new false - try - Async.block do - ContextAsync.run do - let (_ : (String × String)) ← ContextAsync.concurrently - (do - -- First task fails immediately - throw (IO.userError "first task failed") : ContextAsync String) - (do - -- Second task waits and checks for cancellation - let res ← Selectable.one #[ - .case (← ContextAsync.doneSelector) (fun _ => pure true), - .case (← Selector.sleep 2000) (fun _ => pure false) - ] - - task2Cancelled.atomically (set res) - return "second") - catch e => - IO.sleep 500 + let task2Done ← Std.Semaphore.new 0 + let err ← Async.block do + let err : Option IO.Error ← + try + ContextAsync.run do + let (_ : (String × String)) ← ContextAsync.concurrently + (do + -- First task fails immediately + throw (IO.userError "first task failed") : ContextAsync String) + (do + -- Second task waits and checks for cancellation + let res ← Selectable.one #[ + .case (← ContextAsync.doneSelector) (fun _ => pure true), + .case (← Selector.sleep 2000) (fun _ => pure false) + ] + task2Cancelled.atomically (set res) + task2Done.release + return "second") + pure none + catch e => + pure (some e) + discard <| MonadAwait.await (← task2Done.acquire).result! + return err + if let some e := err then let t2 ← task2Cancelled.atomically get IO.println s!"Error: {e}" IO.println s!"Task2 cancelled: {t2}"