fix: make ST.Ref.ptrEq behave as stated in the docs (#11056)

This PR fixes `ST.Ref.ptrEq` to act as described in the docs. This fixes
two bugs:
1. The recent `IO.RealWorld` elimination PR overlooked this function
(afaik this is the only one),
   causing its return value to be generally wrong.
2. The implementation of `ptrEq` would previously always consider two
different cells with pointer
equivalent value to be pointer equal. However, the function is supposed
to check whether two
   `Ref` are the same cell, not whether the contained elements are.
This commit is contained in:
Henrik Böving 2025-11-02 11:42:33 +01:00 committed by GitHub
parent 3e86729ee0
commit 823173a761
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 87 additions and 4 deletions

View file

@ -1495,10 +1495,8 @@ extern "C" LEAN_EXPORT obj_res lean_st_ref_swap(b_obj_arg ref, obj_arg a) {
}
}
extern "C" LEAN_EXPORT obj_res lean_st_ref_ptr_eq(b_obj_arg ref1, b_obj_arg ref2) {
// TODO(Leo): ref_maybe_mt
bool r = lean_to_ref(ref1)->m_value == lean_to_ref(ref2)->m_value;
return box(r);
extern "C" LEAN_EXPORT uint8_t lean_st_ref_ptr_eq(b_obj_arg ref1, b_obj_arg ref2) {
return lean_to_ref(ref1) == lean_to_ref(ref2);
}
/* {α : Type} (act : BaseIO α) : α */

View file

@ -0,0 +1,85 @@
/-!
Some basic tests for the ST monad.
-/
namespace STTest
def ptrEq : IO Unit := do
let ref1 ← IO.mkRef 0
let ref2 ← IO.mkRef 0
IO.println (← ref1.ptrEq ref1)
IO.println (← ref1.ptrEq ref2)
/--
info: true
false
-/
#guard_msgs in
#eval ptrEq
def readWriteRegister : IO Unit := do
let ref1 ← IO.mkRef 0
IO.println (← ref1.get)
ref1.set 1
IO.println (← ref1.get)
/--
info: 0
1
-/
#guard_msgs in
#eval readWriteRegister
def swapRegister : IO Unit := do
let ref1 ← IO.mkRef 0
IO.println (← ref1.swap 5)
IO.println (← ref1.get)
/--
info: 0
5
-/
#guard_msgs in
#eval swapRegister
unsafe def takeRegister : IO Unit := do
let ref1 ← IO.mkRef 0
IO.println (← ref1.take)
ref1.set 5
IO.println (← ref1.get)
/--
info: 0
5
-/
#guard_msgs in
#eval takeRegister
def modifyRegister : IO Unit := do
let ref1 ← IO.mkRef 1
IO.println (← ref1.get)
ref1.modify (fun x => 2 * x)
IO.println (← ref1.get)
/--
info: 1
2
-/
#guard_msgs in
#eval modifyRegister
def modifyGetRegister : IO Unit := do
let ref1 ← IO.mkRef 1
IO.println (← ref1.get)
IO.println (← ref1.modifyGet (fun x => (x, 2 * x)))
IO.println (← ref1.get)
/--
info: 1
1
2
-/
#guard_msgs in
#eval modifyGetRegister
end STTest