fix: timing issues and race conditions in ContextAsync.race (#13718)
This PR fixes tests in context_async.lean by removing all the issues with Async.sleep and IO.sleep and improving how ContextAsync.race works.
This commit is contained in:
parent
cc103d8ed6
commit
ed0d50fcf0
2 changed files with 111 additions and 145 deletions
|
|
@ -119,27 +119,6 @@ def concurrently (x : ContextAsync α) (y : ContextAsync β)
|
|||
concurrentCtx.cancel .cancel
|
||||
return result
|
||||
|
||||
/--
|
||||
Runs two computations concurrently and returns the result of the first to complete. Each computation runs
|
||||
in its own child context; when either completes, the other is cancelled immediately.
|
||||
-/
|
||||
@[inline, specialize]
|
||||
def race [Inhabited α] (x : ContextAsync α) (y : ContextAsync α)
|
||||
(prio := Task.Priority.default) : ContextAsync α := do
|
||||
let parent ← getContext
|
||||
let ctx1 ← CancellationContext.fork parent
|
||||
let ctx2 ← CancellationContext.fork parent
|
||||
|
||||
let task1 ← async (x ctx1) prio
|
||||
let task2 ← async (y ctx2) prio
|
||||
|
||||
let result ← Async.race
|
||||
(await task1 <* ctx2.cancel .cancel)
|
||||
(await task2 <* ctx1.cancel .cancel)
|
||||
prio
|
||||
|
||||
pure result
|
||||
|
||||
/--
|
||||
Runs all computations concurrently and collects results in the same order. Each runs in its own child context;
|
||||
if any computation fails, all others are cancelled and the exception is propagated.
|
||||
|
|
@ -254,6 +233,27 @@ instance [Inhabited α] : Inhabited (ContextAsync α) where
|
|||
instance : MonadAwait AsyncTask ContextAsync where
|
||||
await t := fun _ => await t
|
||||
|
||||
/--
|
||||
Runs two computations concurrently and returns the result of the first to complete. Each computation runs
|
||||
in its own child context; when either completes, the other is cancelled immediately.
|
||||
-/
|
||||
@[inline, specialize]
|
||||
def race [Inhabited α] (x : ContextAsync α) (y : ContextAsync α)
|
||||
(prio := Task.Priority.default) : ContextAsync α := do
|
||||
let parent ← getContext
|
||||
let ctx1 ← CancellationContext.fork parent
|
||||
let ctx2 ← CancellationContext.fork parent
|
||||
|
||||
let task1 ← async (x ctx1) prio
|
||||
let task2 ← async (y ctx2) prio
|
||||
|
||||
let promise ← IO.Promise.new
|
||||
BaseIO.chainTask task1 fun result => liftM (promise.resolve result) *> ctx2.cancel .cancel
|
||||
BaseIO.chainTask task2 fun result => liftM (promise.resolve result) *> ctx1.cancel .cancel
|
||||
|
||||
let result ← MonadAwait.await promise
|
||||
Async.ofExcept result
|
||||
|
||||
end ContextAsync
|
||||
|
||||
/--
|
||||
|
|
|
|||
|
|
@ -9,7 +9,6 @@ def testIsCancelled : IO Unit := do
|
|||
ContextAsync.run do
|
||||
let before ← ContextAsync.isCancelled
|
||||
ContextAsync.cancel .cancel
|
||||
Async.sleep 50
|
||||
let after ← ContextAsync.isCancelled
|
||||
return (before, after)
|
||||
|
||||
|
|
@ -26,7 +25,6 @@ def testGetCancellationReason : IO Unit := do
|
|||
let res ← Async.block do
|
||||
ContextAsync.run do
|
||||
ContextAsync.cancel (.custom "test reason")
|
||||
Async.sleep 50
|
||||
let some reason ← ContextAsync.getCancellationReason
|
||||
| return "ERROR: No reason found"
|
||||
return s!"Reason: {reason}"
|
||||
|
|
@ -92,8 +90,6 @@ def testSelectorCancellationFail : IO Unit := do
|
|||
catch err =>
|
||||
return Except.error err
|
||||
|
||||
Async.sleep 500
|
||||
|
||||
return result
|
||||
|
||||
let _ ← received.atomically get
|
||||
|
|
@ -115,12 +111,8 @@ def testConcurrently : IO Unit := do
|
|||
let (a, b) ← Async.block do
|
||||
ContextAsync.run do
|
||||
ContextAsync.concurrently
|
||||
(do
|
||||
Async.sleep 100
|
||||
return 42)
|
||||
(do
|
||||
Async.sleep 150
|
||||
return "hello")
|
||||
(do return 42)
|
||||
(do return "hello")
|
||||
|
||||
IO.println s!"Results: {a}, {b}"
|
||||
|
||||
|
|
@ -135,12 +127,8 @@ def testRace : IO Unit := do
|
|||
let result ← Async.block do
|
||||
ContextAsync.run do
|
||||
ContextAsync.race
|
||||
(do
|
||||
Async.sleep 50
|
||||
return "fast")
|
||||
(do
|
||||
Async.sleep 200
|
||||
return "slow")
|
||||
(do ContextAsync.awaitCancellation; return "slow")
|
||||
(do return "fast")
|
||||
|
||||
IO.println s!"Winner: {result}"
|
||||
|
||||
|
|
@ -155,9 +143,9 @@ def testConcurrentlyAll : IO Unit := do
|
|||
let results ← Async.block do
|
||||
ContextAsync.run do
|
||||
let tasks := #[
|
||||
(do Async.sleep 50; return 1),
|
||||
(do Async.sleep 100; return 2),
|
||||
(do Async.sleep 75; return 3)
|
||||
(do return 1),
|
||||
(do return 2),
|
||||
(do return 3)
|
||||
]
|
||||
ContextAsync.concurrentlyAll tasks
|
||||
|
||||
|
|
@ -203,11 +191,9 @@ def testForkCancellation : IO Unit := do
|
|||
discard <| ContextAsync.concurrentlyAll #[
|
||||
(do
|
||||
let child ← ContextAsync.getContext
|
||||
Async.sleep 100
|
||||
child.cancel .cancel
|
||||
childCancelled.atomically (set true)),
|
||||
(do
|
||||
Async.sleep 200
|
||||
if ← parent.isCancelled then
|
||||
parentCancelled.atomically (set true))
|
||||
]
|
||||
|
|
@ -231,9 +217,7 @@ partial def testNestedFork : IO Unit := do
|
|||
let sel ← ContextAsync.doneSelector
|
||||
|
||||
let (_, result) ← ContextAsync.concurrently
|
||||
(do
|
||||
Async.sleep 100
|
||||
ctx.cancel .deadline)
|
||||
(do ctx.cancel .deadline)
|
||||
(Selectable.one #[.case sel (fun _ => pure true)])
|
||||
|
||||
return result
|
||||
|
|
@ -254,9 +238,7 @@ def testSelectorCancelled : IO Unit := do
|
|||
let sel ← Selector.cancelled
|
||||
|
||||
let (_, result) ← ContextAsync.concurrently
|
||||
(do
|
||||
Async.sleep 150
|
||||
ctx.cancel .shutdown)
|
||||
(do ctx.cancel .shutdown)
|
||||
(Selectable.one #[.case sel (fun _ => pure true)])
|
||||
|
||||
return result
|
||||
|
|
@ -280,7 +262,7 @@ def testMonadLift : IO Unit := do
|
|||
let msg2 : String := "From BaseIO"
|
||||
|
||||
-- Lift from Async
|
||||
let _ ← (Async.sleep 50 : Async Unit)
|
||||
let _ ← (pure () : Async Unit)
|
||||
|
||||
return (msg1, msg2)
|
||||
|
||||
|
|
@ -342,24 +324,24 @@ def testRaceWithCancellation : IO Unit := do
|
|||
let rightCancelled ← Std.Mutex.new false
|
||||
|
||||
Async.block do
|
||||
let leftDone ← Std.Semaphore.new 0
|
||||
ContextAsync.runIn ctx do
|
||||
let _ ← ContextAsync.race
|
||||
(do
|
||||
try
|
||||
Async.sleep 500
|
||||
ContextAsync.awaitCancellation
|
||||
return "left"
|
||||
finally
|
||||
if ← ContextAsync.isCancelled then
|
||||
leftCancelled.atomically (set true))
|
||||
leftCancelled.atomically (set true)
|
||||
leftDone.release)
|
||||
(do
|
||||
try
|
||||
Async.sleep 50
|
||||
return "right"
|
||||
finally
|
||||
if ← ContextAsync.isCancelled then
|
||||
rightCancelled.atomically (set true))
|
||||
|
||||
Async.sleep 1000
|
||||
discard <| MonadAwait.await (← leftDone.acquire).result!
|
||||
|
||||
let left ← leftCancelled.atomically get
|
||||
let right ← rightCancelled.atomically get
|
||||
|
|
@ -377,28 +359,20 @@ def testComplexWorkflow : IO Unit := do
|
|||
|
||||
Async.block do
|
||||
ContextAsync.run do
|
||||
-- Run multiple concurrent operations
|
||||
let (a, b) ← ContextAsync.concurrently
|
||||
(do
|
||||
Async.sleep 50
|
||||
results.atomically (modify ("A"::·))
|
||||
return 1)
|
||||
(do
|
||||
Async.sleep 75
|
||||
results.atomically (modify ("B"::·))
|
||||
return 2)
|
||||
|
||||
-- Additional concurrent task
|
||||
discard <| ContextAsync.concurrently
|
||||
(do
|
||||
Async.sleep 100
|
||||
results.atomically (modify ("BG"::·)))
|
||||
(do
|
||||
Async.sleep 200
|
||||
results.atomically (modify (s!"Sum:{a+b}"::·)))
|
||||
(do results.atomically (modify ("BG"::·)))
|
||||
(do results.atomically (modify (s!"Sum:{a+b}"::·)))
|
||||
|
||||
let final ← results.atomically get
|
||||
IO.println s!"Results: {final.reverse}"
|
||||
IO.println s!"Results: {final.mergeSort}"
|
||||
|
||||
/--
|
||||
info: Results: [A, B, BG, Sum:3]
|
||||
|
|
@ -447,11 +421,9 @@ def test0 : IO Unit := do
|
|||
|
||||
Async.block do
|
||||
ContextAsync.run do
|
||||
Async.sleep 100
|
||||
if ← ContextAsync.isCancelled then
|
||||
ref.set true
|
||||
|
||||
IO.sleep 200
|
||||
IO.println s!"{← ref.get}"
|
||||
|
||||
/--
|
||||
|
|
@ -465,13 +437,14 @@ def test1 : IO Unit := do
|
|||
let ref ← IO.mkRef false
|
||||
|
||||
Async.block do
|
||||
let done ← Std.Semaphore.new 0
|
||||
ContextAsync.run do
|
||||
ContextAsync.background do
|
||||
Async.sleep 100
|
||||
if ← ContextAsync.isCancelled then
|
||||
ref.set true
|
||||
ContextAsync.awaitCancellation
|
||||
ref.set true
|
||||
done.release
|
||||
discard <| MonadAwait.await (← done.acquire).result!
|
||||
|
||||
IO.sleep 200
|
||||
IO.println s!"{← ref.get}"
|
||||
|
||||
/--
|
||||
|
|
@ -485,14 +458,15 @@ def test2 : IO Unit := do
|
|||
let ref ← IO.mkRef false
|
||||
|
||||
Async.block do
|
||||
let done ← Std.Semaphore.new 0
|
||||
ContextAsync.run do
|
||||
ContextAsync.background do
|
||||
ContextAsync.background do
|
||||
Async.sleep 100
|
||||
if ← ContextAsync.isCancelled then
|
||||
ref.set true
|
||||
ContextAsync.awaitCancellation
|
||||
ref.set true
|
||||
done.release
|
||||
discard <| MonadAwait.await (← done.acquire).result!
|
||||
|
||||
IO.sleep 200
|
||||
IO.println s!"{← ref.get}"
|
||||
|
||||
/--
|
||||
|
|
@ -506,14 +480,16 @@ def test2' : IO Unit := do
|
|||
let ref ← IO.mkRef false
|
||||
|
||||
Async.block do
|
||||
let done ← Std.Semaphore.new 0
|
||||
ContextAsync.run do
|
||||
Async.background do
|
||||
ContextAsync.background do
|
||||
Async.sleep 100
|
||||
if ← ContextAsync.isCancelled then
|
||||
ref.set true
|
||||
ContextAsync.awaitCancellation
|
||||
ref.set true
|
||||
done.release
|
||||
|
||||
discard <| MonadAwait.await (← done.acquire).result!
|
||||
|
||||
IO.sleep 200
|
||||
IO.println s!"{← ref.get}"
|
||||
|
||||
/--
|
||||
|
|
@ -522,26 +498,6 @@ info: true
|
|||
#guard_msgs in
|
||||
#eval test2'
|
||||
|
||||
/-- Test that Async.background in ContextAsync.background is cancelled -/
|
||||
def test2'' : IO Unit := do
|
||||
let ref ← IO.mkRef false
|
||||
|
||||
Async.block do
|
||||
ContextAsync.run do
|
||||
ContextAsync.background do
|
||||
Async.background do
|
||||
Async.sleep 100
|
||||
if ← ContextAsync.isCancelled then
|
||||
ref.set true
|
||||
|
||||
IO.sleep 200
|
||||
IO.println s!"{← ref.get}"
|
||||
|
||||
/--
|
||||
info: true
|
||||
-/
|
||||
#guard_msgs in
|
||||
#eval test2''
|
||||
|
||||
/-- Test concurrently with first task succeeding immediately, others checking cancellation -/
|
||||
def testConcurrentlySuccessWithCancellation : IO Unit := do
|
||||
|
|
@ -590,34 +546,38 @@ def testConcurrentlyFailWithCancellation : IO Unit := do
|
|||
let task3Cancelled ← Std.Mutex.new false
|
||||
|
||||
let results ← Async.block do
|
||||
ContextAsync.run do
|
||||
try
|
||||
let result ← ContextAsync.concurrentlyAll #[
|
||||
(do
|
||||
-- First task fails immediately
|
||||
throw (IO.userError "first task failed")),
|
||||
(do
|
||||
-- Second task waits and checks for cancellation
|
||||
let res ← Selectable.one #[
|
||||
.case (← ContextAsync.doneSelector) (fun _ => pure true),
|
||||
.case (← Selector.sleep 2000) (fun _ => pure false)
|
||||
]
|
||||
|
||||
task2Cancelled.atomically (set (res))
|
||||
return "second"),
|
||||
(do
|
||||
let res ← Selectable.one #[
|
||||
.case (← ContextAsync.doneSelector) (fun _ => pure true),
|
||||
.case (← Selector.sleep 2000) (fun _ => pure false)
|
||||
]
|
||||
|
||||
task3Cancelled.atomically (set (res))
|
||||
return "third")
|
||||
]
|
||||
return Except.ok result
|
||||
catch e =>
|
||||
Async.sleep 500
|
||||
return Except.error e
|
||||
let task2Done ← Std.Semaphore.new 0
|
||||
let task3Done ← Std.Semaphore.new 0
|
||||
let result ← ContextAsync.run do
|
||||
try
|
||||
let result ← ContextAsync.concurrentlyAll #[
|
||||
(do
|
||||
-- First task fails immediately
|
||||
throw (IO.userError "first task failed")),
|
||||
(do
|
||||
-- Second task waits and checks for cancellation
|
||||
let res ← Selectable.one #[
|
||||
.case (← ContextAsync.doneSelector) (fun _ => pure true),
|
||||
.case (← Selector.sleep 2000) (fun _ => pure false)
|
||||
]
|
||||
task2Cancelled.atomically (set (res))
|
||||
task2Done.release
|
||||
return "second"),
|
||||
(do
|
||||
let res ← Selectable.one #[
|
||||
.case (← ContextAsync.doneSelector) (fun _ => pure true),
|
||||
.case (← Selector.sleep 2000) (fun _ => pure false)
|
||||
]
|
||||
task3Cancelled.atomically (set (res))
|
||||
task3Done.release
|
||||
return "third")
|
||||
]
|
||||
return Except.ok result
|
||||
catch e =>
|
||||
return Except.error e
|
||||
discard <| MonadAwait.await (← task2Done.acquire).result!
|
||||
discard <| MonadAwait.await (← task3Done.acquire).result!
|
||||
return result
|
||||
|
||||
let t2 ← task2Cancelled.atomically get
|
||||
let t3 ← task3Cancelled.atomically get
|
||||
|
|
@ -668,24 +628,30 @@ Task2 cancelled: false
|
|||
def testConcurrentlyFailWithCancellation2Tasks : IO Unit := do
|
||||
let task2Cancelled ← Std.Mutex.new false
|
||||
|
||||
try
|
||||
Async.block do
|
||||
ContextAsync.run do
|
||||
let (_ : (String × String)) ← ContextAsync.concurrently
|
||||
(do
|
||||
-- First task fails immediately
|
||||
throw (IO.userError "first task failed") : ContextAsync String)
|
||||
(do
|
||||
-- Second task waits and checks for cancellation
|
||||
let res ← Selectable.one #[
|
||||
.case (← ContextAsync.doneSelector) (fun _ => pure true),
|
||||
.case (← Selector.sleep 2000) (fun _ => pure false)
|
||||
]
|
||||
|
||||
task2Cancelled.atomically (set res)
|
||||
return "second")
|
||||
catch e =>
|
||||
IO.sleep 500
|
||||
let task2Done ← Std.Semaphore.new 0
|
||||
let err ← Async.block do
|
||||
let err : Option IO.Error ←
|
||||
try
|
||||
ContextAsync.run do
|
||||
let (_ : (String × String)) ← ContextAsync.concurrently
|
||||
(do
|
||||
-- First task fails immediately
|
||||
throw (IO.userError "first task failed") : ContextAsync String)
|
||||
(do
|
||||
-- Second task waits and checks for cancellation
|
||||
let res ← Selectable.one #[
|
||||
.case (← ContextAsync.doneSelector) (fun _ => pure true),
|
||||
.case (← Selector.sleep 2000) (fun _ => pure false)
|
||||
]
|
||||
task2Cancelled.atomically (set res)
|
||||
task2Done.release
|
||||
return "second")
|
||||
pure none
|
||||
catch e =>
|
||||
pure (some e)
|
||||
discard <| MonadAwait.await (← task2Done.acquire).result!
|
||||
return err
|
||||
if let some e := err then
|
||||
let t2 ← task2Cancelled.atomically get
|
||||
IO.println s!"Error: {e}"
|
||||
IO.println s!"Task2 cancelled: {t2}"
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue