lean4-htt/tests/elab/context_async.lean
Sofia Rodrigues ed0d50fcf0
fix: timing issues and race conditions in ContextAsync.race (#13718)
This PR fixes tests in context_async.lean by removing all the issues
with Async.sleep and IO.sleep and improving how ContextAsync.race works.
2026-05-13 01:25:01 +00:00

664 lines
17 KiB
Text
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import Std.Async
import Std.Sync
open Std Async
/--
Test ContextAsync cancellation check: samples `ContextAsync.isCancelled`
before and after cancelling the current context with `.cancel`.
-/
def testIsCancelled : IO Unit := do
  let (before, after) ← Async.block <| ContextAsync.run do
    let initially ← ContextAsync.isCancelled
    ContextAsync.cancel .cancel
    let afterwards ← ContextAsync.isCancelled
    pure (initially, afterwards)
  IO.println s!"Before: {before}, After: {after}"

/--
info: Before: false, After: true
-/
#guard_msgs in
#eval testIsCancelled
/--
Test ContextAsync cancellation reason: cancels with a custom reason and reads
it back through `ContextAsync.getCancellationReason`.
-/
def testGetCancellationReason : IO Unit := do
  let res ← Async.block <| ContextAsync.run do
    ContextAsync.cancel (.custom "test reason")
    match (← ContextAsync.getCancellationReason) with
    | some reason => return s!"Reason: {reason}"
    | none => return "ERROR: No reason found"
  IO.println res

/--
info: Reason: custom("test reason")
-/
#guard_msgs in
#eval testGetCancellationReason
/--
Test awaitCancellation: one branch of `ContextAsync.concurrently` blocks in
`ContextAsync.awaitCancellation` and then sets `received`. The other branch
waits until the first one has started and then cancels with `.shutdown`.

`concurrently` returns the results of both branches, so both have finished by
the time `ContextAsync.run` returns and `received` can be read directly. The
message reports the flag's actual value instead of assuming success.
NOTE(review): this test has no `#eval` here; confirm whether that is intended.
-/
def testAwaitCancellation : IO Unit := do
  let received ← Std.Mutex.new false
  Async.block do
    let started ← Std.Mutex.new false
    ContextAsync.run do
      discard <| ContextAsync.concurrently
        (do
          started.atomically (set true)
          ContextAsync.awaitCancellation
          received.atomically (set true))
        (do
          -- Wait for task to start
          while !(← started.atomically get) do
            Async.sleep 10
          Async.sleep 100
          ContextAsync.cancel .shutdown)
  -- Previously the flag was read and discarded; report what actually happened.
  if ← received.atomically get then
    IO.println "Cancellation received"
  else
    IO.println "Cancellation not received"
/--
Test that a failing branch of `ContextAsync.concurrently` cancels its sibling.
The first branch races `ContextAsync.doneSelector` against a 2 s
`Selector.sleep` and records which one fired. The second branch throws
`failed`. The error is captured, the outcome is reported, and then the error
is rethrown.

The first branch releases `branchDone` after recording its outcome. The test
waits on it before reading `received`, so the report reflects what the branch
actually observed instead of being printed unconditionally.
-/
def testSelectorCancellationFail : IO Unit := do
  let received ← Std.Mutex.new false
  let branchDone ← Std.Semaphore.new 0
  let result ← Async.block do
    let ctx ← Std.CancellationContext.new
    let result ← do
      try
        ContextAsync.runIn ctx do
          discard <| ContextAsync.concurrently
            (do
              let res ← Selectable.one #[
                .case (← ContextAsync.doneSelector) (fun _ => pure true),
                .case (← Selector.sleep 2000) (fun _ => pure false)
              ]
              received.atomically (set res)
              branchDone.release)
            (do
              throw (.userError "failed")
              return ())
        return Except.ok ()
      catch err =>
        return Except.error err
    -- Make sure the selector branch has stored its outcome before we read it.
    discard <| MonadAwait.await (← branchDone.acquire).result!
    return result
  if ← received.atomically get then
    IO.println "Cancellation received"
  else
    IO.println "Cancellation not received"
  if let Except.error err := result then
    throw err

/--
info: Cancellation received
---
error: failed
-/
#guard_msgs in
#eval testSelectorCancellationFail
/--
Test concurrently with both tasks succeeding: both branches return
immediately and the resulting pair is printed.
-/
def testConcurrently : IO Unit := do
  let pair ← Async.block <| ContextAsync.run <|
    ContextAsync.concurrently (pure 42) (pure "hello")
  let (a, b) := pair
  IO.println s!"Results: {a}, {b}"

/--
info: Results: 42, hello
-/
#guard_msgs in
#eval testConcurrently
/--
Test `ContextAsync.race`: the left action blocks in
`ContextAsync.awaitCancellation`, while the right action returns "fast" at
once. The winner's value is printed.
-/
def testRace : IO Unit := do
  let slow : ContextAsync String := do
    ContextAsync.awaitCancellation
    return "slow"
  let fast : ContextAsync String := pure "fast"
  let result ← Async.block <| ContextAsync.run <| ContextAsync.race slow fast
  IO.println s!"Winner: {result}"

/--
info: Winner: fast
-/
#guard_msgs in
#eval testRace
/--
Test `ContextAsync.concurrentlyAll` on three actions that return immediately.
The collected results array is printed.
-/
def testConcurrentlyAll : IO Unit := do
  let results ← Async.block <| ContextAsync.run <|
    ContextAsync.concurrentlyAll #[pure 1, pure 2, pure 3]
  IO.println s!"All results: {results}"

/--
info: All results: #[1, 2, 3]
-/
#guard_msgs in
#eval testConcurrentlyAll
/--
Test a cancellable worker loop. One branch increments `counter` up to 10
times, 50 ms apart, and stops early once `ContextAsync.isCancelled` reports
true. The other branch cancels with `.cancel` after 150 ms.
NOTE(review): there is no `#eval` for this test here; presumably because the
printed count depends on timing.
-/
def testBackground : IO Unit := do
  let counter ← Std.Mutex.new 0
  Async.block do
    ContextAsync.run do
      let worker : ContextAsync Unit := do
        for _ in [0:10] do
          if ← ContextAsync.isCancelled then
            break
          counter.atomically (modify (· + 1))
          Async.sleep 50
      let canceller : ContextAsync Unit := do
        -- Let it run for a bit
        Async.sleep 150
        ContextAsync.cancel .cancel
      discard <| ContextAsync.concurrently worker canceller
    Async.sleep 200
  let final ← counter.atomically get
  IO.println s!"Counter reached: {final}"
/--
Test fork cancellation isolation. Inside `concurrentlyAll`, one task cancels
the context returned by `ContextAsync.getContext`. Another task checks
whether the explicit `parent` context passed to `runIn` was cancelled.
-/
def testForkCancellation : IO Unit := do
  let parent ← Std.CancellationContext.new
  let childCancelled ← Std.Mutex.new false
  let parentCancelled ← Std.Mutex.new false
  let cancelOwnContext : ContextAsync Unit := do
    let own ← ContextAsync.getContext
    own.cancel .cancel
    childCancelled.atomically (set true)
  let probeParent : ContextAsync Unit := do
    if ← parent.isCancelled then
      parentCancelled.atomically (set true)
  Async.block <| ContextAsync.runIn parent do
    discard <| ContextAsync.concurrentlyAll #[cancelOwnContext, probeParent]
  let childWasCancelled ← childCancelled.atomically get
  let parentWasCancelled ← parentCancelled.atomically get
  IO.println s!"Child cancelled: {childWasCancelled}, Parent cancelled: {parentWasCancelled}"

/--
info: Child cancelled: true, Parent cancelled: false
-/
#guard_msgs in
#eval testForkCancellation
/--
Test `ContextAsync.doneSelector`. The run's context and a done-selector for it
are obtained up front. One branch of `concurrently` cancels that context with
`.deadline`, while the other blocks in `Selectable.one` on the selector and
returns `true` when it fires.

This was previously declared `partial def`, although the function is not
recursive. A plain `def` is enough.
-/
def testNestedFork : IO Unit := do
  let res ← Async.block do
    ContextAsync.run do
      let ctx ← ContextAsync.getContext
      let sel ← ContextAsync.doneSelector
      let (_, result) ← ContextAsync.concurrently
        (do ctx.cancel .deadline)
        (Selectable.one #[.case sel (fun _ => pure true)])
      return result
  IO.println s!"Done selector triggered: {res}"

/--
info: Done selector triggered: true
-/
#guard_msgs in
#eval testNestedFork
/--
Test `Selector.cancelled`. The selector is created inside the run. One branch
cancels the run's context with `.shutdown`, while the other waits on the
selector and yields `true` when it fires.
-/
def testSelectorCancelled : IO Unit := do
  let res ← Async.block <| ContextAsync.run do
    let context ← ContextAsync.getContext
    let cancelledSel ← Selector.cancelled
    let (_, fired) ← ContextAsync.concurrently
      (do context.cancel .shutdown)
      (Selectable.one #[.case cancelledSel (fun _ => pure true)])
    pure fired
  IO.println s!"Selector.cancelled triggered: {res}"

/--
info: Selector.cancelled triggered: true
-/
#guard_msgs in
#eval testSelectorCancelled
/--
Test MonadLift instances: runs actions typed in `IO`, `BaseIO` and `Async`
directly inside `ContextAsync`, then prints the values obtained from them.

Previously the "IO" and "BaseIO" messages were plain `let` bindings, so no
lift was actually exercised. Now each one is bound with `←` from an action of
the named monad.
-/
def testMonadLift : IO Unit := do
  let (msg1, msg2) ← Async.block do
    ContextAsync.run do
      -- Lift from IO
      let msg1 ← (pure "From IO" : IO String)
      -- Lift from BaseIO
      let msg2 ← (pure "From BaseIO" : BaseIO String)
      -- Lift from Async
      let _ ← (pure () : Async Unit)
      return (msg1, msg2)
  IO.println msg1
  IO.println msg2
  IO.println "All lifts work"

/--
info: From IO
From BaseIO
All lifts work
-/
#guard_msgs in
#eval testMonadLift
/--
Test exception handling in ContextAsync: a `try`/`catch` inside the run
catches an `IO.userError` thrown in its own body.
-/
def testExceptionHandling : IO Unit := do
  let outcome ← Async.block <| ContextAsync.run do
    try
      throw (IO.userError "test error")
      return "Should not reach here"
    catch e =>
      return s!"Caught: {e}"
  IO.println outcome

/--
info: Caught: test error
-/
#guard_msgs in
#eval testExceptionHandling
/--
Test `try`/`finally` in ContextAsync. The body cancels its own context and
then waits in `ContextAsync.awaitCancellation`. The `finally` block sets
`cleaned`, which is printed afterwards.
-/
def testTryFinally : IO Unit := do
  let cleaned ← Std.Mutex.new false
  let body : ContextAsync Unit := do
    ContextAsync.cancel .cancel
    ContextAsync.awaitCancellation
  Async.block <| ContextAsync.run do
    try
      body
    finally
      cleaned.atomically (set true)
  let wasCleanedUp ← cleaned.atomically get
  IO.println s!"Cleanup ran: {wasCleanedUp}"

/--
info: Cleanup ran: true
-/
#guard_msgs in
#eval testTryFinally
/--
Test race with cancellation. Each branch records in its `finally` block
whether `ContextAsync.isCancelled` was true at that point:
- The left branch blocks in `awaitCancellation`.
- The right branch returns immediately.

The left branch releases `leftDone` from its `finally`. The test waits on
that before reading the flags.
-/
def testRaceWithCancellation : IO Unit := do
  let ctx ← Std.CancellationContext.new
  let leftCancelled ← Std.Mutex.new false
  let rightCancelled ← Std.Mutex.new false
  Async.block do
    let leftDone ← Std.Semaphore.new 0
    let leftBranch : ContextAsync String := do
      try
        ContextAsync.awaitCancellation
        return "left"
      finally
        if ← ContextAsync.isCancelled then
          leftCancelled.atomically (set true)
        leftDone.release
    let rightBranch : ContextAsync String := do
      try
        return "right"
      finally
        if ← ContextAsync.isCancelled then
          rightCancelled.atomically (set true)
    ContextAsync.runIn ctx do
      discard <| ContextAsync.race leftBranch rightBranch
    discard <| MonadAwait.await (← leftDone.acquire).result!
  let left ← leftCancelled.atomically get
  let right ← rightCancelled.atomically get
  IO.println s!"Left cancelled: {left}, Right cancelled: {right}"

/--
info: Left cancelled: true, Right cancelled: false
-/
#guard_msgs in
#eval testRaceWithCancellation
/--
Test a two-stage concurrent workflow: the second `concurrently` uses the
results of the first. Every branch prepends a tag to a shared list. The list
is sorted before printing, so the output does not depend on branch order.
-/
def testComplexWorkflow : IO Unit := do
  let results ← Std.Mutex.new ([] : List String)
  let record (tag : String) : ContextAsync Unit :=
    results.atomically (modify (tag :: ·))
  Async.block <| ContextAsync.run do
    let (a, b) ← ContextAsync.concurrently
      (do
        record "A"
        return 1)
      (do
        record "B"
        return 2)
    discard <| ContextAsync.concurrently
      (record "BG")
      (record s!"Sum:{a+b}")
  let final ← results.atomically get
  IO.println s!"Results: {final.mergeSort}"

/--
info: Results: [A, B, BG, Sum:3]
-/
#guard_msgs in
#eval testComplexWorkflow
/--
Test that when one task of `ContextAsync.concurrentlyAll` throws, its sibling
observes cancellation and the error is propagated. The `finally` block prints
what the sibling recorded in `ref`.

The sibling used to sleep a fixed 1000 ms and then check
`ContextAsync.isCancelled`, which was slow and depended on how the two sleeps
interleaved. It now waits on `ContextAsync.doneSelector`, with a 2 s
`Selector.sleep` as a timeout, and then performs the same `isCancelled` check.
This is the same pattern the other cancellation tests in this file use.
-/
def testConcurrentlyAllException : IO Unit := do
  let ref ← IO.mkRef ""
  try
    Async.block do
      ContextAsync.run do
        let tasks : Array (ContextAsync Unit) := #[
          (do
            -- Wake on cancellation (or the timeout), then inspect the context.
            Selectable.one #[
              .case (← ContextAsync.doneSelector) (fun _ => pure ()),
              .case (← Selector.sleep 2000) (fun _ => pure ())
            ]
            if ← ContextAsync.isCancelled then
              ref.set "cancelled"
              return
            else
              ref.set "not cancelled"
            Async.sleep 500
            if ← ContextAsync.isCancelled then
              ref.modify (· ++ ", cancelled")
            else
              ref.modify (· ++ ", not cancelled")),
          (do
            Async.sleep 250
            throw (IO.userError "Error: Hello"))
        ]
        discard <| ContextAsync.concurrentlyAll tasks
  finally
    IO.println (← ref.get)

/--
info: cancelled
---
error: Error: Hello
-/
#guard_msgs in
#eval testConcurrentlyAllException
/--
Test that the context of `ContextAsync.run` is not cancelled while the run
body is still executing. The observed `isCancelled` value is stored in `ref`
and printed.
-/
def test0 : IO Unit := do
  let ref ← IO.mkRef false
  Async.block <| ContextAsync.run do
    ref.set (← ContextAsync.isCancelled)
  IO.println s!"{← ref.get}"

/--
info: false
-/
#guard_msgs in
#eval test0
/--
Test that background tasks are cancelled when ContextAsync.run completes. The
`ContextAsync.background` task waits for cancellation, sets `ref`, and
releases `done`. The test waits on `done` before printing `ref`.
-/
def test1 : IO Unit := do
  let ref ← IO.mkRef false
  Async.block do
    let done ← Std.Semaphore.new 0
    let waiter : ContextAsync Unit := do
      ContextAsync.awaitCancellation
      ref.set true
      done.release
    ContextAsync.run do
      ContextAsync.background waiter
    discard <| MonadAwait.await (← done.acquire).result!
  IO.println s!"{← ref.get}"

/--
info: true
-/
#guard_msgs in
#eval test1
/--
Test that nested background tasks (ContextAsync.background inside
ContextAsync.background) are cancelled. The innermost task waits for
cancellation, sets `ref`, and releases `done`. The test waits on `done`
before printing `ref`.
-/
def test2 : IO Unit := do
  let ref ← IO.mkRef false
  Async.block do
    let done ← Std.Semaphore.new 0
    let innermost : ContextAsync Unit := do
      ContextAsync.awaitCancellation
      ref.set true
      done.release
    ContextAsync.run do
      ContextAsync.background do
        ContextAsync.background innermost
    discard <| MonadAwait.await (← done.acquire).result!
  IO.println s!"{← ref.get}"

/--
info: true
-/
#guard_msgs in
#eval test2
/--
Test that a ContextAsync.background task spawned inside Async.background is
cancelled. The innermost task waits for cancellation, sets `ref`, and
releases `done`. The test waits on `done` before printing `ref`.
-/
def test2' : IO Unit := do
  let ref ← IO.mkRef false
  Async.block do
    let done ← Std.Semaphore.new 0
    let innermost : ContextAsync Unit := do
      ContextAsync.awaitCancellation
      ref.set true
      done.release
    ContextAsync.run do
      Async.background do
        ContextAsync.background innermost
    discard <| MonadAwait.await (← done.acquire).result!
  IO.println s!"{← ref.get}"

/--
info: true
-/
#guard_msgs in
#eval test2'
/--
Test concurrentlyAll where every task succeeds. The first task returns
immediately. Tasks two and three race `ContextAsync.doneSelector` against a
500 ms `Selector.sleep` and record whether the done-selector won.
-/
def testConcurrentlySuccessWithCancellation : IO Unit := do
  let task2Cancelled ← Std.Mutex.new false
  let task3Cancelled ← Std.Mutex.new false
  -- `true` if the done-selector fires before the 500 ms sleep does.
  let cancelledBeforeTimeout : ContextAsync Bool := do
    let won ← Selectable.one #[
      .case (← ContextAsync.doneSelector) (fun _ => pure true),
      .case (← Selector.sleep 500) (fun _ => pure false)
    ]
    return won
  let results ← Async.block <| ContextAsync.run <|
    ContextAsync.concurrentlyAll #[
      pure "first",
      (do
        -- Second task waits and checks for cancellation
        let hit ← cancelledBeforeTimeout
        task2Cancelled.atomically (set hit)
        return "second"),
      (do
        let hit ← cancelledBeforeTimeout
        task3Cancelled.atomically (set hit)
        return "third")
    ]
  let t2 ← task2Cancelled.atomically get
  let t3 ← task3Cancelled.atomically get
  IO.println s!"Results: {results}"
  IO.println s!"Task2 cancelled: {t2}, Task3 cancelled: {t3}"

/--
info: Results: #[first, second, third]
Task2 cancelled: false, Task3 cancelled: false
-/
#guard_msgs in
#eval testConcurrentlySuccessWithCancellation
/--
Test concurrentlyAll with the first task failing immediately. Tasks two and
three race `ContextAsync.doneSelector` against a 2 s `Selector.sleep`, record
whether the done-selector won, and release their semaphore. The test waits on
both semaphores before reading the flags.
-/
def testConcurrentlyFailWithCancellation : IO Unit := do
  let task2Cancelled ← Std.Mutex.new false
  let task3Cancelled ← Std.Mutex.new false
  -- `true` if the done-selector fires before the 2000 ms sleep does.
  let cancelledBeforeTimeout : ContextAsync Bool := do
    let won ← Selectable.one #[
      .case (← ContextAsync.doneSelector) (fun _ => pure true),
      .case (← Selector.sleep 2000) (fun _ => pure false)
    ]
    return won
  let results ← Async.block do
    let task2Done ← Std.Semaphore.new 0
    let task3Done ← Std.Semaphore.new 0
    let outcome ← ContextAsync.run do
      try
        let values ← ContextAsync.concurrentlyAll #[
          (do
            -- First task fails immediately
            throw (IO.userError "first task failed")),
          (do
            -- Second task waits and checks for cancellation
            let hit ← cancelledBeforeTimeout
            task2Cancelled.atomically (set hit)
            task2Done.release
            return "second"),
          (do
            let hit ← cancelledBeforeTimeout
            task3Cancelled.atomically (set hit)
            task3Done.release
            return "third")
        ]
        return Except.ok values
      catch e =>
        return Except.error e
    discard <| MonadAwait.await (← task2Done.acquire).result!
    discard <| MonadAwait.await (← task3Done.acquire).result!
    return outcome
  let t2 ← task2Cancelled.atomically get
  let t3 ← task3Cancelled.atomically get
  IO.println <| match results with
    | .ok results => s!"Results: {results}"
    | .error e => s!"Error: {e}"
  IO.println s!"Task2 cancelled: {t2}, Task3 cancelled: {t3}"

/--
info: Error: first task failed
Task2 cancelled: true, Task3 cancelled: true
-/
#guard_msgs in
#eval testConcurrentlyFailWithCancellation
/--
Test concurrently with both tasks succeeding. The second task races
`ContextAsync.doneSelector` against a 500 ms `Selector.sleep` and records
whether the done-selector won.
-/
def testConcurrentlySuccessWithCancellation2Tasks : IO Unit := do
  let task2Cancelled ← Std.Mutex.new false
  let second : ContextAsync String := do
    -- Second task waits and checks for cancellation
    let cancelled ← Selectable.one #[
      .case (← ContextAsync.doneSelector) (fun _ => pure true),
      .case (← Selector.sleep 500) (fun _ => pure false)
    ]
    task2Cancelled.atomically (set cancelled)
    return "second"
  let (r1, r2) ← Async.block <| ContextAsync.run <|
    ContextAsync.concurrently (pure "first") second
  let t2 ← task2Cancelled.atomically get
  IO.println s!"Results: {r1}, {r2}"
  IO.println s!"Task2 cancelled: {t2}"

/--
info: Results: first, second
Task2 cancelled: false
-/
#guard_msgs in
#eval testConcurrentlySuccessWithCancellation2Tasks
/--
Test concurrently with the first task failing and the second checking for
cancellation. The second task races `ContextAsync.doneSelector` against a 2 s
`Selector.sleep`, records whether the done-selector won, and releases
`task2Done`. The caught error and the flag are printed only if an error was
caught.
-/
def testConcurrentlyFailWithCancellation2Tasks : IO Unit := do
  let task2Cancelled ← Std.Mutex.new false
  let task2Done ← Std.Semaphore.new 0
  let failing : ContextAsync String := do
    -- First task fails immediately
    throw (IO.userError "first task failed")
  let watcher : ContextAsync String := do
    -- Second task waits and checks for cancellation
    let res ← Selectable.one #[
      .case (← ContextAsync.doneSelector) (fun _ => pure true),
      .case (← Selector.sleep 2000) (fun _ => pure false)
    ]
    task2Cancelled.atomically (set res)
    task2Done.release
    return "second"
  let err ← Async.block do
    let caught : Option IO.Error ←
      try
        ContextAsync.run do
          discard <| ContextAsync.concurrently failing watcher
          pure none
      catch e =>
        pure (some e)
    discard <| MonadAwait.await (← task2Done.acquire).result!
    return caught
  match err with
  | some e =>
    let t2 ← task2Cancelled.atomically get
    IO.println s!"Error: {e}"
    IO.println s!"Task2 cancelled: {t2}"
  | none => pure ()

/--
info: Error: first task failed
Task2 cancelled: true
-/
#guard_msgs in
#eval testConcurrentlyFailWithCancellation2Tasks