From 823173a7610543d2669a654a0617bbcb55ece708 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henrik=20B=C3=B6ving?= Date: Sun, 2 Nov 2025 11:42:33 +0100 Subject: [PATCH] fix: make ST.Ref.ptrEq behave as stated in the docs (#11056) This PR fixes `ST.Ref.ptrEq` to act as described in the docs. This fixes two bugs: 1. The recent `IO.RealWorld` elimination PR overlooked this function (afaik this is the only one), causing its return value to be generally wrong. 2. The implementation of `ptrEq` would previously always consider two different cells with pointer equivalent value to be pointer equal. However, the function is supposed to check whether two `Ref` are the same cell, not whether the contained elements are. --- src/runtime/io.cpp | 6 +-- tests/lean/run/st_test.lean | 85 +++++++++++++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 4 deletions(-) create mode 100644 tests/lean/run/st_test.lean diff --git a/src/runtime/io.cpp b/src/runtime/io.cpp index e1e16df4f0..18c897a1c5 100644 --- a/src/runtime/io.cpp +++ b/src/runtime/io.cpp @@ -1495,10 +1495,8 @@ extern "C" LEAN_EXPORT obj_res lean_st_ref_swap(b_obj_arg ref, obj_arg a) { } } -extern "C" LEAN_EXPORT obj_res lean_st_ref_ptr_eq(b_obj_arg ref1, b_obj_arg ref2) { - // TODO(Leo): ref_maybe_mt - bool r = lean_to_ref(ref1)->m_value == lean_to_ref(ref2)->m_value; - return box(r); +extern "C" LEAN_EXPORT uint8_t lean_st_ref_ptr_eq(b_obj_arg ref1, b_obj_arg ref2) { + return lean_to_ref(ref1) == lean_to_ref(ref2); } /* {α : Type} (act : BaseIO α) : α */ diff --git a/tests/lean/run/st_test.lean b/tests/lean/run/st_test.lean new file mode 100644 index 0000000000..1d83d40420 --- /dev/null +++ b/tests/lean/run/st_test.lean @@ -0,0 +1,85 @@ +/-! +Some basic tests for the ST monad. +-/ + +namespace STTest + +def ptrEq : IO Unit := do + let ref1 ← IO.mkRef 0 + let ref2 ← IO.mkRef 0 + IO.println (← ref1.ptrEq ref1) + IO.println (← ref1.ptrEq ref2) + +/-- +info: true +false +-/ +#guard_msgs in +#eval ptrEq + +def readWriteRegister : IO Unit := do + let ref1 ← IO.mkRef 0 + IO.println (← ref1.get) + ref1.set 1 + IO.println (← ref1.get) + +/-- +info: 0 +1 +-/ +#guard_msgs in +#eval readWriteRegister + +def swapRegister : IO Unit := do + let ref1 ← IO.mkRef 0 + IO.println (← ref1.swap 5) + IO.println (← ref1.get) + +/-- +info: 0 +5 +-/ +#guard_msgs in +#eval swapRegister + +unsafe def takeRegister : IO Unit := do + let ref1 ← IO.mkRef 0 + IO.println (← ref1.take) + ref1.set 5 + IO.println (← ref1.get) + +/-- +info: 0 +5 +-/ +#guard_msgs in +#eval takeRegister + +def modifyRegister : IO Unit := do + let ref1 ← IO.mkRef 1 + IO.println (← ref1.get) + ref1.modify (fun x => 2 * x) + IO.println (← ref1.get) + +/-- +info: 1 +2 +-/ +#guard_msgs in +#eval modifyRegister + +def modifyGetRegister : IO Unit := do + let ref1 ← IO.mkRef 1 + IO.println (← ref1.get) + IO.println (← ref1.modifyGet (fun x => (x, 2 * x))) + IO.println (← ref1.get) + +/-- +info: 1 +1 +2 +-/ +#guard_msgs in +#eval modifyGetRegister + +end STTest