fix: timing issues and race conditions in ContextAsync.race (#13718)
This PR fixes tests in context_async.lean by removing all the issues with Async.sleep and IO.sleep and improving how ContextAsync.race works.
This commit is contained in:
parent
cc103d8ed6
commit
ed0d50fcf0
2 changed files with 111 additions and 145 deletions
|
|
@ -119,27 +119,6 @@ def concurrently (x : ContextAsync α) (y : ContextAsync β)
|
||||||
concurrentCtx.cancel .cancel
|
concurrentCtx.cancel .cancel
|
||||||
return result
|
return result
|
||||||
|
|
||||||
/--
|
|
||||||
Runs two computations concurrently and returns the result of the first to complete. Each computation runs
|
|
||||||
in its own child context; when either completes, the other is cancelled immediately.
|
|
||||||
-/
|
|
||||||
@[inline, specialize]
|
|
||||||
def race [Inhabited α] (x : ContextAsync α) (y : ContextAsync α)
|
|
||||||
(prio := Task.Priority.default) : ContextAsync α := do
|
|
||||||
let parent ← getContext
|
|
||||||
let ctx1 ← CancellationContext.fork parent
|
|
||||||
let ctx2 ← CancellationContext.fork parent
|
|
||||||
|
|
||||||
let task1 ← async (x ctx1) prio
|
|
||||||
let task2 ← async (y ctx2) prio
|
|
||||||
|
|
||||||
let result ← Async.race
|
|
||||||
(await task1 <* ctx2.cancel .cancel)
|
|
||||||
(await task2 <* ctx1.cancel .cancel)
|
|
||||||
prio
|
|
||||||
|
|
||||||
pure result
|
|
||||||
|
|
||||||
/--
|
/--
|
||||||
Runs all computations concurrently and collects results in the same order. Each runs in its own child context;
|
Runs all computations concurrently and collects results in the same order. Each runs in its own child context;
|
||||||
if any computation fails, all others are cancelled and the exception is propagated.
|
if any computation fails, all others are cancelled and the exception is propagated.
|
||||||
|
|
@ -254,6 +233,27 @@ instance [Inhabited α] : Inhabited (ContextAsync α) where
|
||||||
instance : MonadAwait AsyncTask ContextAsync where
|
instance : MonadAwait AsyncTask ContextAsync where
|
||||||
await t := fun _ => await t
|
await t := fun _ => await t
|
||||||
|
|
||||||
|
/--
|
||||||
|
Runs two computations concurrently and returns the result of the first to complete. Each computation runs
|
||||||
|
in its own child context; when either completes, the other is cancelled immediately.
|
||||||
|
-/
|
||||||
|
@[inline, specialize]
|
||||||
|
def race [Inhabited α] (x : ContextAsync α) (y : ContextAsync α)
|
||||||
|
(prio := Task.Priority.default) : ContextAsync α := do
|
||||||
|
let parent ← getContext
|
||||||
|
let ctx1 ← CancellationContext.fork parent
|
||||||
|
let ctx2 ← CancellationContext.fork parent
|
||||||
|
|
||||||
|
let task1 ← async (x ctx1) prio
|
||||||
|
let task2 ← async (y ctx2) prio
|
||||||
|
|
||||||
|
let promise ← IO.Promise.new
|
||||||
|
BaseIO.chainTask task1 fun result => liftM (promise.resolve result) *> ctx2.cancel .cancel
|
||||||
|
BaseIO.chainTask task2 fun result => liftM (promise.resolve result) *> ctx1.cancel .cancel
|
||||||
|
|
||||||
|
let result ← MonadAwait.await promise
|
||||||
|
Async.ofExcept result
|
||||||
|
|
||||||
end ContextAsync
|
end ContextAsync
|
||||||
|
|
||||||
/--
|
/--
|
||||||
|
|
|
||||||
|
|
@ -9,7 +9,6 @@ def testIsCancelled : IO Unit := do
|
||||||
ContextAsync.run do
|
ContextAsync.run do
|
||||||
let before ← ContextAsync.isCancelled
|
let before ← ContextAsync.isCancelled
|
||||||
ContextAsync.cancel .cancel
|
ContextAsync.cancel .cancel
|
||||||
Async.sleep 50
|
|
||||||
let after ← ContextAsync.isCancelled
|
let after ← ContextAsync.isCancelled
|
||||||
return (before, after)
|
return (before, after)
|
||||||
|
|
||||||
|
|
@ -26,7 +25,6 @@ def testGetCancellationReason : IO Unit := do
|
||||||
let res ← Async.block do
|
let res ← Async.block do
|
||||||
ContextAsync.run do
|
ContextAsync.run do
|
||||||
ContextAsync.cancel (.custom "test reason")
|
ContextAsync.cancel (.custom "test reason")
|
||||||
Async.sleep 50
|
|
||||||
let some reason ← ContextAsync.getCancellationReason
|
let some reason ← ContextAsync.getCancellationReason
|
||||||
| return "ERROR: No reason found"
|
| return "ERROR: No reason found"
|
||||||
return s!"Reason: {reason}"
|
return s!"Reason: {reason}"
|
||||||
|
|
@ -92,8 +90,6 @@ def testSelectorCancellationFail : IO Unit := do
|
||||||
catch err =>
|
catch err =>
|
||||||
return Except.error err
|
return Except.error err
|
||||||
|
|
||||||
Async.sleep 500
|
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
let _ ← received.atomically get
|
let _ ← received.atomically get
|
||||||
|
|
@ -115,12 +111,8 @@ def testConcurrently : IO Unit := do
|
||||||
let (a, b) ← Async.block do
|
let (a, b) ← Async.block do
|
||||||
ContextAsync.run do
|
ContextAsync.run do
|
||||||
ContextAsync.concurrently
|
ContextAsync.concurrently
|
||||||
(do
|
(do return 42)
|
||||||
Async.sleep 100
|
(do return "hello")
|
||||||
return 42)
|
|
||||||
(do
|
|
||||||
Async.sleep 150
|
|
||||||
return "hello")
|
|
||||||
|
|
||||||
IO.println s!"Results: {a}, {b}"
|
IO.println s!"Results: {a}, {b}"
|
||||||
|
|
||||||
|
|
@ -135,12 +127,8 @@ def testRace : IO Unit := do
|
||||||
let result ← Async.block do
|
let result ← Async.block do
|
||||||
ContextAsync.run do
|
ContextAsync.run do
|
||||||
ContextAsync.race
|
ContextAsync.race
|
||||||
(do
|
(do ContextAsync.awaitCancellation; return "slow")
|
||||||
Async.sleep 50
|
(do return "fast")
|
||||||
return "fast")
|
|
||||||
(do
|
|
||||||
Async.sleep 200
|
|
||||||
return "slow")
|
|
||||||
|
|
||||||
IO.println s!"Winner: {result}"
|
IO.println s!"Winner: {result}"
|
||||||
|
|
||||||
|
|
@ -155,9 +143,9 @@ def testConcurrentlyAll : IO Unit := do
|
||||||
let results ← Async.block do
|
let results ← Async.block do
|
||||||
ContextAsync.run do
|
ContextAsync.run do
|
||||||
let tasks := #[
|
let tasks := #[
|
||||||
(do Async.sleep 50; return 1),
|
(do return 1),
|
||||||
(do Async.sleep 100; return 2),
|
(do return 2),
|
||||||
(do Async.sleep 75; return 3)
|
(do return 3)
|
||||||
]
|
]
|
||||||
ContextAsync.concurrentlyAll tasks
|
ContextAsync.concurrentlyAll tasks
|
||||||
|
|
||||||
|
|
@ -203,11 +191,9 @@ def testForkCancellation : IO Unit := do
|
||||||
discard <| ContextAsync.concurrentlyAll #[
|
discard <| ContextAsync.concurrentlyAll #[
|
||||||
(do
|
(do
|
||||||
let child ← ContextAsync.getContext
|
let child ← ContextAsync.getContext
|
||||||
Async.sleep 100
|
|
||||||
child.cancel .cancel
|
child.cancel .cancel
|
||||||
childCancelled.atomically (set true)),
|
childCancelled.atomically (set true)),
|
||||||
(do
|
(do
|
||||||
Async.sleep 200
|
|
||||||
if ← parent.isCancelled then
|
if ← parent.isCancelled then
|
||||||
parentCancelled.atomically (set true))
|
parentCancelled.atomically (set true))
|
||||||
]
|
]
|
||||||
|
|
@ -231,9 +217,7 @@ partial def testNestedFork : IO Unit := do
|
||||||
let sel ← ContextAsync.doneSelector
|
let sel ← ContextAsync.doneSelector
|
||||||
|
|
||||||
let (_, result) ← ContextAsync.concurrently
|
let (_, result) ← ContextAsync.concurrently
|
||||||
(do
|
(do ctx.cancel .deadline)
|
||||||
Async.sleep 100
|
|
||||||
ctx.cancel .deadline)
|
|
||||||
(Selectable.one #[.case sel (fun _ => pure true)])
|
(Selectable.one #[.case sel (fun _ => pure true)])
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
@ -254,9 +238,7 @@ def testSelectorCancelled : IO Unit := do
|
||||||
let sel ← Selector.cancelled
|
let sel ← Selector.cancelled
|
||||||
|
|
||||||
let (_, result) ← ContextAsync.concurrently
|
let (_, result) ← ContextAsync.concurrently
|
||||||
(do
|
(do ctx.cancel .shutdown)
|
||||||
Async.sleep 150
|
|
||||||
ctx.cancel .shutdown)
|
|
||||||
(Selectable.one #[.case sel (fun _ => pure true)])
|
(Selectable.one #[.case sel (fun _ => pure true)])
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
@ -280,7 +262,7 @@ def testMonadLift : IO Unit := do
|
||||||
let msg2 : String := "From BaseIO"
|
let msg2 : String := "From BaseIO"
|
||||||
|
|
||||||
-- Lift from Async
|
-- Lift from Async
|
||||||
let _ ← (Async.sleep 50 : Async Unit)
|
let _ ← (pure () : Async Unit)
|
||||||
|
|
||||||
return (msg1, msg2)
|
return (msg1, msg2)
|
||||||
|
|
||||||
|
|
@ -342,24 +324,24 @@ def testRaceWithCancellation : IO Unit := do
|
||||||
let rightCancelled ← Std.Mutex.new false
|
let rightCancelled ← Std.Mutex.new false
|
||||||
|
|
||||||
Async.block do
|
Async.block do
|
||||||
|
let leftDone ← Std.Semaphore.new 0
|
||||||
ContextAsync.runIn ctx do
|
ContextAsync.runIn ctx do
|
||||||
let _ ← ContextAsync.race
|
let _ ← ContextAsync.race
|
||||||
(do
|
(do
|
||||||
try
|
try
|
||||||
Async.sleep 500
|
ContextAsync.awaitCancellation
|
||||||
return "left"
|
return "left"
|
||||||
finally
|
finally
|
||||||
if ← ContextAsync.isCancelled then
|
if ← ContextAsync.isCancelled then
|
||||||
leftCancelled.atomically (set true))
|
leftCancelled.atomically (set true)
|
||||||
|
leftDone.release)
|
||||||
(do
|
(do
|
||||||
try
|
try
|
||||||
Async.sleep 50
|
|
||||||
return "right"
|
return "right"
|
||||||
finally
|
finally
|
||||||
if ← ContextAsync.isCancelled then
|
if ← ContextAsync.isCancelled then
|
||||||
rightCancelled.atomically (set true))
|
rightCancelled.atomically (set true))
|
||||||
|
discard <| MonadAwait.await (← leftDone.acquire).result!
|
||||||
Async.sleep 1000
|
|
||||||
|
|
||||||
let left ← leftCancelled.atomically get
|
let left ← leftCancelled.atomically get
|
||||||
let right ← rightCancelled.atomically get
|
let right ← rightCancelled.atomically get
|
||||||
|
|
@ -377,28 +359,20 @@ def testComplexWorkflow : IO Unit := do
|
||||||
|
|
||||||
Async.block do
|
Async.block do
|
||||||
ContextAsync.run do
|
ContextAsync.run do
|
||||||
-- Run multiple concurrent operations
|
|
||||||
let (a, b) ← ContextAsync.concurrently
|
let (a, b) ← ContextAsync.concurrently
|
||||||
(do
|
(do
|
||||||
Async.sleep 50
|
|
||||||
results.atomically (modify ("A"::·))
|
results.atomically (modify ("A"::·))
|
||||||
return 1)
|
return 1)
|
||||||
(do
|
(do
|
||||||
Async.sleep 75
|
|
||||||
results.atomically (modify ("B"::·))
|
results.atomically (modify ("B"::·))
|
||||||
return 2)
|
return 2)
|
||||||
|
|
||||||
-- Additional concurrent task
|
|
||||||
discard <| ContextAsync.concurrently
|
discard <| ContextAsync.concurrently
|
||||||
(do
|
(do results.atomically (modify ("BG"::·)))
|
||||||
Async.sleep 100
|
(do results.atomically (modify (s!"Sum:{a+b}"::·)))
|
||||||
results.atomically (modify ("BG"::·)))
|
|
||||||
(do
|
|
||||||
Async.sleep 200
|
|
||||||
results.atomically (modify (s!"Sum:{a+b}"::·)))
|
|
||||||
|
|
||||||
let final ← results.atomically get
|
let final ← results.atomically get
|
||||||
IO.println s!"Results: {final.reverse}"
|
IO.println s!"Results: {final.mergeSort}"
|
||||||
|
|
||||||
/--
|
/--
|
||||||
info: Results: [A, B, BG, Sum:3]
|
info: Results: [A, B, BG, Sum:3]
|
||||||
|
|
@ -447,11 +421,9 @@ def test0 : IO Unit := do
|
||||||
|
|
||||||
Async.block do
|
Async.block do
|
||||||
ContextAsync.run do
|
ContextAsync.run do
|
||||||
Async.sleep 100
|
|
||||||
if ← ContextAsync.isCancelled then
|
if ← ContextAsync.isCancelled then
|
||||||
ref.set true
|
ref.set true
|
||||||
|
|
||||||
IO.sleep 200
|
|
||||||
IO.println s!"{← ref.get}"
|
IO.println s!"{← ref.get}"
|
||||||
|
|
||||||
/--
|
/--
|
||||||
|
|
@ -465,13 +437,14 @@ def test1 : IO Unit := do
|
||||||
let ref ← IO.mkRef false
|
let ref ← IO.mkRef false
|
||||||
|
|
||||||
Async.block do
|
Async.block do
|
||||||
|
let done ← Std.Semaphore.new 0
|
||||||
ContextAsync.run do
|
ContextAsync.run do
|
||||||
ContextAsync.background do
|
ContextAsync.background do
|
||||||
Async.sleep 100
|
ContextAsync.awaitCancellation
|
||||||
if ← ContextAsync.isCancelled then
|
ref.set true
|
||||||
ref.set true
|
done.release
|
||||||
|
discard <| MonadAwait.await (← done.acquire).result!
|
||||||
|
|
||||||
IO.sleep 200
|
|
||||||
IO.println s!"{← ref.get}"
|
IO.println s!"{← ref.get}"
|
||||||
|
|
||||||
/--
|
/--
|
||||||
|
|
@ -485,14 +458,15 @@ def test2 : IO Unit := do
|
||||||
let ref ← IO.mkRef false
|
let ref ← IO.mkRef false
|
||||||
|
|
||||||
Async.block do
|
Async.block do
|
||||||
|
let done ← Std.Semaphore.new 0
|
||||||
ContextAsync.run do
|
ContextAsync.run do
|
||||||
ContextAsync.background do
|
ContextAsync.background do
|
||||||
ContextAsync.background do
|
ContextAsync.background do
|
||||||
Async.sleep 100
|
ContextAsync.awaitCancellation
|
||||||
if ← ContextAsync.isCancelled then
|
ref.set true
|
||||||
ref.set true
|
done.release
|
||||||
|
discard <| MonadAwait.await (← done.acquire).result!
|
||||||
|
|
||||||
IO.sleep 200
|
|
||||||
IO.println s!"{← ref.get}"
|
IO.println s!"{← ref.get}"
|
||||||
|
|
||||||
/--
|
/--
|
||||||
|
|
@ -506,14 +480,16 @@ def test2' : IO Unit := do
|
||||||
let ref ← IO.mkRef false
|
let ref ← IO.mkRef false
|
||||||
|
|
||||||
Async.block do
|
Async.block do
|
||||||
|
let done ← Std.Semaphore.new 0
|
||||||
ContextAsync.run do
|
ContextAsync.run do
|
||||||
Async.background do
|
Async.background do
|
||||||
ContextAsync.background do
|
ContextAsync.background do
|
||||||
Async.sleep 100
|
ContextAsync.awaitCancellation
|
||||||
if ← ContextAsync.isCancelled then
|
ref.set true
|
||||||
ref.set true
|
done.release
|
||||||
|
|
||||||
|
discard <| MonadAwait.await (← done.acquire).result!
|
||||||
|
|
||||||
IO.sleep 200
|
|
||||||
IO.println s!"{← ref.get}"
|
IO.println s!"{← ref.get}"
|
||||||
|
|
||||||
/--
|
/--
|
||||||
|
|
@ -522,26 +498,6 @@ info: true
|
||||||
#guard_msgs in
|
#guard_msgs in
|
||||||
#eval test2'
|
#eval test2'
|
||||||
|
|
||||||
/-- Test that Async.background in ContextAsync.background is cancelled -/
|
|
||||||
def test2'' : IO Unit := do
|
|
||||||
let ref ← IO.mkRef false
|
|
||||||
|
|
||||||
Async.block do
|
|
||||||
ContextAsync.run do
|
|
||||||
ContextAsync.background do
|
|
||||||
Async.background do
|
|
||||||
Async.sleep 100
|
|
||||||
if ← ContextAsync.isCancelled then
|
|
||||||
ref.set true
|
|
||||||
|
|
||||||
IO.sleep 200
|
|
||||||
IO.println s!"{← ref.get}"
|
|
||||||
|
|
||||||
/--
|
|
||||||
info: true
|
|
||||||
-/
|
|
||||||
#guard_msgs in
|
|
||||||
#eval test2''
|
|
||||||
|
|
||||||
/-- Test concurrently with first task succeeding immediately, others checking cancellation -/
|
/-- Test concurrently with first task succeeding immediately, others checking cancellation -/
|
||||||
def testConcurrentlySuccessWithCancellation : IO Unit := do
|
def testConcurrentlySuccessWithCancellation : IO Unit := do
|
||||||
|
|
@ -590,34 +546,38 @@ def testConcurrentlyFailWithCancellation : IO Unit := do
|
||||||
let task3Cancelled ← Std.Mutex.new false
|
let task3Cancelled ← Std.Mutex.new false
|
||||||
|
|
||||||
let results ← Async.block do
|
let results ← Async.block do
|
||||||
ContextAsync.run do
|
let task2Done ← Std.Semaphore.new 0
|
||||||
try
|
let task3Done ← Std.Semaphore.new 0
|
||||||
let result ← ContextAsync.concurrentlyAll #[
|
let result ← ContextAsync.run do
|
||||||
(do
|
try
|
||||||
-- First task fails immediately
|
let result ← ContextAsync.concurrentlyAll #[
|
||||||
throw (IO.userError "first task failed")),
|
(do
|
||||||
(do
|
-- First task fails immediately
|
||||||
-- Second task waits and checks for cancellation
|
throw (IO.userError "first task failed")),
|
||||||
let res ← Selectable.one #[
|
(do
|
||||||
.case (← ContextAsync.doneSelector) (fun _ => pure true),
|
-- Second task waits and checks for cancellation
|
||||||
.case (← Selector.sleep 2000) (fun _ => pure false)
|
let res ← Selectable.one #[
|
||||||
]
|
.case (← ContextAsync.doneSelector) (fun _ => pure true),
|
||||||
|
.case (← Selector.sleep 2000) (fun _ => pure false)
|
||||||
task2Cancelled.atomically (set (res))
|
]
|
||||||
return "second"),
|
task2Cancelled.atomically (set (res))
|
||||||
(do
|
task2Done.release
|
||||||
let res ← Selectable.one #[
|
return "second"),
|
||||||
.case (← ContextAsync.doneSelector) (fun _ => pure true),
|
(do
|
||||||
.case (← Selector.sleep 2000) (fun _ => pure false)
|
let res ← Selectable.one #[
|
||||||
]
|
.case (← ContextAsync.doneSelector) (fun _ => pure true),
|
||||||
|
.case (← Selector.sleep 2000) (fun _ => pure false)
|
||||||
task3Cancelled.atomically (set (res))
|
]
|
||||||
return "third")
|
task3Cancelled.atomically (set (res))
|
||||||
]
|
task3Done.release
|
||||||
return Except.ok result
|
return "third")
|
||||||
catch e =>
|
]
|
||||||
Async.sleep 500
|
return Except.ok result
|
||||||
return Except.error e
|
catch e =>
|
||||||
|
return Except.error e
|
||||||
|
discard <| MonadAwait.await (← task2Done.acquire).result!
|
||||||
|
discard <| MonadAwait.await (← task3Done.acquire).result!
|
||||||
|
return result
|
||||||
|
|
||||||
let t2 ← task2Cancelled.atomically get
|
let t2 ← task2Cancelled.atomically get
|
||||||
let t3 ← task3Cancelled.atomically get
|
let t3 ← task3Cancelled.atomically get
|
||||||
|
|
@ -668,24 +628,30 @@ Task2 cancelled: false
|
||||||
def testConcurrentlyFailWithCancellation2Tasks : IO Unit := do
|
def testConcurrentlyFailWithCancellation2Tasks : IO Unit := do
|
||||||
let task2Cancelled ← Std.Mutex.new false
|
let task2Cancelled ← Std.Mutex.new false
|
||||||
|
|
||||||
try
|
let task2Done ← Std.Semaphore.new 0
|
||||||
Async.block do
|
let err ← Async.block do
|
||||||
ContextAsync.run do
|
let err : Option IO.Error ←
|
||||||
let (_ : (String × String)) ← ContextAsync.concurrently
|
try
|
||||||
(do
|
ContextAsync.run do
|
||||||
-- First task fails immediately
|
let (_ : (String × String)) ← ContextAsync.concurrently
|
||||||
throw (IO.userError "first task failed") : ContextAsync String)
|
(do
|
||||||
(do
|
-- First task fails immediately
|
||||||
-- Second task waits and checks for cancellation
|
throw (IO.userError "first task failed") : ContextAsync String)
|
||||||
let res ← Selectable.one #[
|
(do
|
||||||
.case (← ContextAsync.doneSelector) (fun _ => pure true),
|
-- Second task waits and checks for cancellation
|
||||||
.case (← Selector.sleep 2000) (fun _ => pure false)
|
let res ← Selectable.one #[
|
||||||
]
|
.case (← ContextAsync.doneSelector) (fun _ => pure true),
|
||||||
|
.case (← Selector.sleep 2000) (fun _ => pure false)
|
||||||
task2Cancelled.atomically (set res)
|
]
|
||||||
return "second")
|
task2Cancelled.atomically (set res)
|
||||||
catch e =>
|
task2Done.release
|
||||||
IO.sleep 500
|
return "second")
|
||||||
|
pure none
|
||||||
|
catch e =>
|
||||||
|
pure (some e)
|
||||||
|
discard <| MonadAwait.await (← task2Done.acquire).result!
|
||||||
|
return err
|
||||||
|
if let some e := err then
|
||||||
let t2 ← task2Cancelled.atomically get
|
let t2 ← task2Cancelled.atomically get
|
||||||
IO.println s!"Error: {e}"
|
IO.println s!"Error: {e}"
|
||||||
IO.println s!"Task2 cancelled: {t2}"
|
IO.println s!"Task2 cancelled: {t2}"
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue