fix: timing issues and race conditions in ContextAsync.race (#13718)

This PR fixes tests in context_async.lean by removing all the issues
with Async.sleep and IO.sleep and improving how ContextAsync.race works.
This commit is contained in:
Sofia Rodrigues 2026-05-12 22:25:01 -03:00 committed by GitHub
parent cc103d8ed6
commit ed0d50fcf0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 111 additions and 145 deletions

View file

@ -119,27 +119,6 @@ def concurrently (x : ContextAsync α) (y : ContextAsync β)
concurrentCtx.cancel .cancel concurrentCtx.cancel .cancel
return result return result
/--
Runs two computations concurrently and returns the result of the first to complete. Each computation runs
in its own child context; when either completes, the other is cancelled immediately.
-/
@[inline, specialize]
def race [Inhabited α] (x : ContextAsync α) (y : ContextAsync α)
(prio := Task.Priority.default) : ContextAsync α := do
let parent ← getContext
let ctx1 ← CancellationContext.fork parent
let ctx2 ← CancellationContext.fork parent
let task1 ← async (x ctx1) prio
let task2 ← async (y ctx2) prio
let result ← Async.race
(await task1 <* ctx2.cancel .cancel)
(await task2 <* ctx1.cancel .cancel)
prio
pure result
/-- /--
Runs all computations concurrently and collects results in the same order. Each runs in its own child context; Runs all computations concurrently and collects results in the same order. Each runs in its own child context;
if any computation fails, all others are cancelled and the exception is propagated. if any computation fails, all others are cancelled and the exception is propagated.
@ -254,6 +233,27 @@ instance [Inhabited α] : Inhabited (ContextAsync α) where
instance : MonadAwait AsyncTask ContextAsync where instance : MonadAwait AsyncTask ContextAsync where
await t := fun _ => await t await t := fun _ => await t
/--
Runs two computations concurrently and returns the result of the first to complete. Each computation runs
in its own child context; when either completes, the other is cancelled immediately.
-/
@[inline, specialize]
def race [Inhabited α] (x : ContextAsync α) (y : ContextAsync α)
(prio := Task.Priority.default) : ContextAsync α := do
let parent ← getContext
let ctx1 ← CancellationContext.fork parent
let ctx2 ← CancellationContext.fork parent
let task1 ← async (x ctx1) prio
let task2 ← async (y ctx2) prio
let promise ← IO.Promise.new
BaseIO.chainTask task1 fun result => liftM (promise.resolve result) *> ctx2.cancel .cancel
BaseIO.chainTask task2 fun result => liftM (promise.resolve result) *> ctx1.cancel .cancel
let result ← MonadAwait.await promise
Async.ofExcept result
end ContextAsync end ContextAsync
/-- /--

View file

@ -9,7 +9,6 @@ def testIsCancelled : IO Unit := do
ContextAsync.run do ContextAsync.run do
let before ← ContextAsync.isCancelled let before ← ContextAsync.isCancelled
ContextAsync.cancel .cancel ContextAsync.cancel .cancel
Async.sleep 50
let after ← ContextAsync.isCancelled let after ← ContextAsync.isCancelled
return (before, after) return (before, after)
@ -26,7 +25,6 @@ def testGetCancellationReason : IO Unit := do
let res ← Async.block do let res ← Async.block do
ContextAsync.run do ContextAsync.run do
ContextAsync.cancel (.custom "test reason") ContextAsync.cancel (.custom "test reason")
Async.sleep 50
let some reason ← ContextAsync.getCancellationReason let some reason ← ContextAsync.getCancellationReason
| return "ERROR: No reason found" | return "ERROR: No reason found"
return s!"Reason: {reason}" return s!"Reason: {reason}"
@ -92,8 +90,6 @@ def testSelectorCancellationFail : IO Unit := do
catch err => catch err =>
return Except.error err return Except.error err
Async.sleep 500
return result return result
let _ ← received.atomically get let _ ← received.atomically get
@ -115,12 +111,8 @@ def testConcurrently : IO Unit := do
let (a, b) ← Async.block do let (a, b) ← Async.block do
ContextAsync.run do ContextAsync.run do
ContextAsync.concurrently ContextAsync.concurrently
(do (do return 42)
Async.sleep 100 (do return "hello")
return 42)
(do
Async.sleep 150
return "hello")
IO.println s!"Results: {a}, {b}" IO.println s!"Results: {a}, {b}"
@ -135,12 +127,8 @@ def testRace : IO Unit := do
let result ← Async.block do let result ← Async.block do
ContextAsync.run do ContextAsync.run do
ContextAsync.race ContextAsync.race
(do (do ContextAsync.awaitCancellation; return "slow")
Async.sleep 50 (do return "fast")
return "fast")
(do
Async.sleep 200
return "slow")
IO.println s!"Winner: {result}" IO.println s!"Winner: {result}"
@ -155,9 +143,9 @@ def testConcurrentlyAll : IO Unit := do
let results ← Async.block do let results ← Async.block do
ContextAsync.run do ContextAsync.run do
let tasks := #[ let tasks := #[
(do Async.sleep 50; return 1), (do return 1),
(do Async.sleep 100; return 2), (do return 2),
(do Async.sleep 75; return 3) (do return 3)
] ]
ContextAsync.concurrentlyAll tasks ContextAsync.concurrentlyAll tasks
@ -203,11 +191,9 @@ def testForkCancellation : IO Unit := do
discard <| ContextAsync.concurrentlyAll #[ discard <| ContextAsync.concurrentlyAll #[
(do (do
let child ← ContextAsync.getContext let child ← ContextAsync.getContext
Async.sleep 100
child.cancel .cancel child.cancel .cancel
childCancelled.atomically (set true)), childCancelled.atomically (set true)),
(do (do
Async.sleep 200
if ← parent.isCancelled then if ← parent.isCancelled then
parentCancelled.atomically (set true)) parentCancelled.atomically (set true))
] ]
@ -231,9 +217,7 @@ partial def testNestedFork : IO Unit := do
let sel ← ContextAsync.doneSelector let sel ← ContextAsync.doneSelector
let (_, result) ← ContextAsync.concurrently let (_, result) ← ContextAsync.concurrently
(do (do ctx.cancel .deadline)
Async.sleep 100
ctx.cancel .deadline)
(Selectable.one #[.case sel (fun _ => pure true)]) (Selectable.one #[.case sel (fun _ => pure true)])
return result return result
@ -254,9 +238,7 @@ def testSelectorCancelled : IO Unit := do
let sel ← Selector.cancelled let sel ← Selector.cancelled
let (_, result) ← ContextAsync.concurrently let (_, result) ← ContextAsync.concurrently
(do (do ctx.cancel .shutdown)
Async.sleep 150
ctx.cancel .shutdown)
(Selectable.one #[.case sel (fun _ => pure true)]) (Selectable.one #[.case sel (fun _ => pure true)])
return result return result
@ -280,7 +262,7 @@ def testMonadLift : IO Unit := do
let msg2 : String := "From BaseIO" let msg2 : String := "From BaseIO"
-- Lift from Async -- Lift from Async
let _ ← (Async.sleep 50 : Async Unit) let _ ← (pure () : Async Unit)
return (msg1, msg2) return (msg1, msg2)
@ -342,24 +324,24 @@ def testRaceWithCancellation : IO Unit := do
let rightCancelled ← Std.Mutex.new false let rightCancelled ← Std.Mutex.new false
Async.block do Async.block do
let leftDone ← Std.Semaphore.new 0
ContextAsync.runIn ctx do ContextAsync.runIn ctx do
let _ ← ContextAsync.race let _ ← ContextAsync.race
(do (do
try try
Async.sleep 500 ContextAsync.awaitCancellation
return "left" return "left"
finally finally
if ← ContextAsync.isCancelled then if ← ContextAsync.isCancelled then
leftCancelled.atomically (set true)) leftCancelled.atomically (set true)
leftDone.release)
(do (do
try try
Async.sleep 50
return "right" return "right"
finally finally
if ← ContextAsync.isCancelled then if ← ContextAsync.isCancelled then
rightCancelled.atomically (set true)) rightCancelled.atomically (set true))
discard <| MonadAwait.await (← leftDone.acquire).result!
Async.sleep 1000
let left ← leftCancelled.atomically get let left ← leftCancelled.atomically get
let right ← rightCancelled.atomically get let right ← rightCancelled.atomically get
@ -377,28 +359,20 @@ def testComplexWorkflow : IO Unit := do
Async.block do Async.block do
ContextAsync.run do ContextAsync.run do
-- Run multiple concurrent operations
let (a, b) ← ContextAsync.concurrently let (a, b) ← ContextAsync.concurrently
(do (do
Async.sleep 50
results.atomically (modify ("A"::·)) results.atomically (modify ("A"::·))
return 1) return 1)
(do (do
Async.sleep 75
results.atomically (modify ("B"::·)) results.atomically (modify ("B"::·))
return 2) return 2)
-- Additional concurrent task
discard <| ContextAsync.concurrently discard <| ContextAsync.concurrently
(do (do results.atomically (modify ("BG"::·)))
Async.sleep 100 (do results.atomically (modify (s!"Sum:{a+b}"::·)))
results.atomically (modify ("BG"::·)))
(do
Async.sleep 200
results.atomically (modify (s!"Sum:{a+b}"::·)))
let final ← results.atomically get let final ← results.atomically get
IO.println s!"Results: {final.reverse}" IO.println s!"Results: {final.mergeSort}"
/-- /--
info: Results: [A, B, BG, Sum:3] info: Results: [A, B, BG, Sum:3]
@ -447,11 +421,9 @@ def test0 : IO Unit := do
Async.block do Async.block do
ContextAsync.run do ContextAsync.run do
Async.sleep 100
if ← ContextAsync.isCancelled then if ← ContextAsync.isCancelled then
ref.set true ref.set true
IO.sleep 200
IO.println s!"{← ref.get}" IO.println s!"{← ref.get}"
/-- /--
@ -465,13 +437,14 @@ def test1 : IO Unit := do
let ref ← IO.mkRef false let ref ← IO.mkRef false
Async.block do Async.block do
let done ← Std.Semaphore.new 0
ContextAsync.run do ContextAsync.run do
ContextAsync.background do ContextAsync.background do
Async.sleep 100 ContextAsync.awaitCancellation
if ← ContextAsync.isCancelled then ref.set true
ref.set true done.release
discard <| MonadAwait.await (← done.acquire).result!
IO.sleep 200
IO.println s!"{← ref.get}" IO.println s!"{← ref.get}"
/-- /--
@ -485,14 +458,15 @@ def test2 : IO Unit := do
let ref ← IO.mkRef false let ref ← IO.mkRef false
Async.block do Async.block do
let done ← Std.Semaphore.new 0
ContextAsync.run do ContextAsync.run do
ContextAsync.background do ContextAsync.background do
ContextAsync.background do ContextAsync.background do
Async.sleep 100 ContextAsync.awaitCancellation
if ← ContextAsync.isCancelled then ref.set true
ref.set true done.release
discard <| MonadAwait.await (← done.acquire).result!
IO.sleep 200
IO.println s!"{← ref.get}" IO.println s!"{← ref.get}"
/-- /--
@ -506,14 +480,16 @@ def test2' : IO Unit := do
let ref ← IO.mkRef false let ref ← IO.mkRef false
Async.block do Async.block do
let done ← Std.Semaphore.new 0
ContextAsync.run do ContextAsync.run do
Async.background do Async.background do
ContextAsync.background do ContextAsync.background do
Async.sleep 100 ContextAsync.awaitCancellation
if ← ContextAsync.isCancelled then ref.set true
ref.set true done.release
discard <| MonadAwait.await (← done.acquire).result!
IO.sleep 200
IO.println s!"{← ref.get}" IO.println s!"{← ref.get}"
/-- /--
@ -522,26 +498,6 @@ info: true
#guard_msgs in #guard_msgs in
#eval test2' #eval test2'
/-- Test that Async.background in ContextAsync.background is cancelled -/
def test2'' : IO Unit := do
let ref ← IO.mkRef false
Async.block do
ContextAsync.run do
ContextAsync.background do
Async.background do
Async.sleep 100
if ← ContextAsync.isCancelled then
ref.set true
IO.sleep 200
IO.println s!"{← ref.get}"
/--
info: true
-/
#guard_msgs in
#eval test2''
/-- Test concurrently with first task succeeding immediately, others checking cancellation -/ /-- Test concurrently with first task succeeding immediately, others checking cancellation -/
def testConcurrentlySuccessWithCancellation : IO Unit := do def testConcurrentlySuccessWithCancellation : IO Unit := do
@ -590,34 +546,38 @@ def testConcurrentlyFailWithCancellation : IO Unit := do
let task3Cancelled ← Std.Mutex.new false let task3Cancelled ← Std.Mutex.new false
let results ← Async.block do let results ← Async.block do
ContextAsync.run do let task2Done ← Std.Semaphore.new 0
try let task3Done ← Std.Semaphore.new 0
let result ← ContextAsync.concurrentlyAll #[ let result ← ContextAsync.run do
(do try
-- First task fails immediately let result ← ContextAsync.concurrentlyAll #[
throw (IO.userError "first task failed")), (do
(do -- First task fails immediately
-- Second task waits and checks for cancellation throw (IO.userError "first task failed")),
let res ← Selectable.one #[ (do
.case (← ContextAsync.doneSelector) (fun _ => pure true), -- Second task waits and checks for cancellation
.case (← Selector.sleep 2000) (fun _ => pure false) let res ← Selectable.one #[
] .case (← ContextAsync.doneSelector) (fun _ => pure true),
.case (← Selector.sleep 2000) (fun _ => pure false)
task2Cancelled.atomically (set (res)) ]
return "second"), task2Cancelled.atomically (set (res))
(do task2Done.release
let res ← Selectable.one #[ return "second"),
.case (← ContextAsync.doneSelector) (fun _ => pure true), (do
.case (← Selector.sleep 2000) (fun _ => pure false) let res ← Selectable.one #[
] .case (← ContextAsync.doneSelector) (fun _ => pure true),
.case (← Selector.sleep 2000) (fun _ => pure false)
task3Cancelled.atomically (set (res)) ]
return "third") task3Cancelled.atomically (set (res))
] task3Done.release
return Except.ok result return "third")
catch e => ]
Async.sleep 500 return Except.ok result
return Except.error e catch e =>
return Except.error e
discard <| MonadAwait.await (← task2Done.acquire).result!
discard <| MonadAwait.await (← task3Done.acquire).result!
return result
let t2 ← task2Cancelled.atomically get let t2 ← task2Cancelled.atomically get
let t3 ← task3Cancelled.atomically get let t3 ← task3Cancelled.atomically get
@ -668,24 +628,30 @@ Task2 cancelled: false
def testConcurrentlyFailWithCancellation2Tasks : IO Unit := do def testConcurrentlyFailWithCancellation2Tasks : IO Unit := do
let task2Cancelled ← Std.Mutex.new false let task2Cancelled ← Std.Mutex.new false
try let task2Done ← Std.Semaphore.new 0
Async.block do let err ← Async.block do
ContextAsync.run do let err : Option IO.Error ←
let (_ : (String × String)) ← ContextAsync.concurrently try
(do ContextAsync.run do
-- First task fails immediately let (_ : (String × String)) ← ContextAsync.concurrently
throw (IO.userError "first task failed") : ContextAsync String) (do
(do -- First task fails immediately
-- Second task waits and checks for cancellation throw (IO.userError "first task failed") : ContextAsync String)
let res ← Selectable.one #[ (do
.case (← ContextAsync.doneSelector) (fun _ => pure true), -- Second task waits and checks for cancellation
.case (← Selector.sleep 2000) (fun _ => pure false) let res ← Selectable.one #[
] .case (← ContextAsync.doneSelector) (fun _ => pure true),
.case (← Selector.sleep 2000) (fun _ => pure false)
task2Cancelled.atomically (set res) ]
return "second") task2Cancelled.atomically (set res)
catch e => task2Done.release
IO.sleep 500 return "second")
pure none
catch e =>
pure (some e)
discard <| MonadAwait.await (← task2Done.acquire).result!
return err
if let some e := err then
let t2 ← task2Cancelled.atomically get let t2 ← task2Cancelled.atomically get
IO.println s!"Error: {e}" IO.println s!"Error: {e}"
IO.println s!"Task2 cancelled: {t2}" IO.println s!"Task2 cancelled: {t2}"