fix: revert the waitAny refactoring (#11000)

This PR fixes a memleak caused by the Lean based `IO.waitAny`
implementation by reverting it.

This the faulty Lean implementation:
```lean
def IO.waitAny (tasks : @& List (Task α)) (h : tasks.length > 0 := by exact Nat.zero_lt_succ _) :
    BaseIO α := do
  have : Nonempty α := ⟨tasks[0].get⟩
  let promise : IO.Promise α ← IO.Promise.new
  tasks.forM <| fun t => BaseIO.chainTask (sync := true) t promise.resolve
  return promise.result!.get
```
In a situation where we call this function repeatedly in a loop with a
pair of tasks `[t1, t2]`
where `t2` is a long lived task that we pass every time and `t1` is
fresh a short lived task, `t2` will
accumlate more and more children from `BaseIO.chainTask` that fill
memory over time. The old C++
implementation did not have this issue so we are reverting.
This commit is contained in:
Henrik Böving 2025-10-29 09:27:16 +01:00 committed by GitHub
parent 2497cf0b40
commit 2cfd980528
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 43 additions and 11 deletions

View file

@ -557,6 +557,13 @@ Waits for the task to finish, then returns its result.
@[extern "lean_io_wait"] opaque wait (t : Task α) : BaseIO α :=
return t.get
/--
Waits until any of the tasks in the list has finished, then return its result.
-/
@[extern "lean_io_wait_any"] opaque waitAny (tasks : @& List (Task α))
(h : tasks.length > 0 := by exact Nat.zero_lt_succ _) : BaseIO α :=
return tasks[0].get
/--
Returns the number of _heartbeats_ that have occurred during the current thread's execution. The
heartbeat count is the number of “small” memory allocations performed in a thread.

View file

@ -12,17 +12,6 @@ public import Init.System.Promise
public section
/--
Waits until any of the tasks in the list has finished, then return its result.
-/
@[noinline]
def IO.waitAny (tasks : @& List (Task α)) (h : tasks.length > 0 := by exact Nat.zero_lt_succ _) :
BaseIO α := do
have : Nonempty α := ⟨tasks[0].get⟩
let promise : IO.Promise α ← IO.Promise.new
tasks.forM <| fun t => BaseIO.chainTask (sync := true) t promise.resolve
return promise.result!.get
namespace Task
/--

View file

@ -1225,6 +1225,8 @@ LEAN_EXPORT bool lean_io_check_canceled_core(void);
LEAN_EXPORT void lean_io_cancel_core(b_lean_obj_arg t);
/* primitive for implementing `IO.getTaskState : Task a -> IO TaskState` */
LEAN_EXPORT uint8_t lean_io_get_task_state_core(b_lean_obj_arg t);
/* primitive for implementing `IO.waitAny : List (Task a) -> IO (Task a)` */
LEAN_EXPORT b_lean_obj_res lean_io_wait_any_core(b_lean_obj_arg task_list);
/* External objects */

View file

@ -1554,6 +1554,13 @@ extern "C" LEAN_EXPORT obj_res lean_io_wait(obj_arg t) {
return lean_task_get_own(t);
}
extern "C" LEAN_EXPORT obj_res lean_io_wait_any(b_obj_arg task_list) {
object * t = lean_io_wait_any_core(task_list);
object * v = lean_task_get(t);
lean_inc(v);
return v;
}
extern "C" LEAN_EXPORT obj_res lean_io_exit(uint8_t code) {
exit(code);
}

View file

@ -847,6 +847,17 @@ class task_manager {
}
}
object * wait_any_check(object * task_list) {
object * it = task_list;
while (!is_scalar(it)) {
object * head = lean_ctor_get(it, 0);
if (lean_to_task(head)->m_value)
return head;
it = cnstr_get(it, 1);
}
return nullptr;
}
public:
task_manager(unsigned max_std_workers):
m_max_std_workers(max_std_workers) {
@ -929,6 +940,17 @@ public:
}
}
object * wait_any(object * task_list) {
if (object * t = wait_any_check(task_list))
return t;
unique_lock<mutex> lock(m_mutex);
while (true) {
if (object * t = wait_any_check(task_list))
return t;
m_task_finished_cv.wait(lock);
}
}
void deactivate_task(lean_task_object * t) {
unique_lock<mutex> lock(m_mutex);
if (object * v = t->m_value) {
@ -1166,6 +1188,10 @@ extern "C" LEAN_EXPORT uint8_t lean_io_get_task_state_core(b_obj_arg t) {
return g_task_manager->get_task_state(o);
}
extern "C" LEAN_EXPORT b_obj_res lean_io_wait_any_core(b_obj_arg task_list) {
return g_task_manager->wait_any(task_list);
}
obj_res lean_promise_new() {
lean_always_assert(g_task_manager);

View file

@ -287,6 +287,7 @@ inline b_obj_res task_get(b_obj_arg t) { return lean_task_get(t); }
inline bool io_check_canceled_core() { return lean_io_check_canceled_core(); }
inline void io_cancel_core(b_obj_arg t) { return lean_io_cancel_core(t); }
inline bool io_get_task_state_core(b_obj_arg t) { return lean_io_get_task_state_core(t); }
inline b_obj_res io_wait_any_core(b_obj_arg task_list) { return lean_io_wait_any_core(task_list); }
// =======================================
// External