diff --git a/src/runtime/io.cpp b/src/runtime/io.cpp index e1e16df4f0..18c897a1c5 100644 --- a/src/runtime/io.cpp +++ b/src/runtime/io.cpp @@ -1495,10 +1495,8 @@ extern "C" LEAN_EXPORT obj_res lean_st_ref_swap(b_obj_arg ref, obj_arg a) { } } -extern "C" LEAN_EXPORT obj_res lean_st_ref_ptr_eq(b_obj_arg ref1, b_obj_arg ref2) { - // TODO(Leo): ref_maybe_mt - bool r = lean_to_ref(ref1)->m_value == lean_to_ref(ref2)->m_value; - return box(r); +extern "C" LEAN_EXPORT uint8_t lean_st_ref_ptr_eq(b_obj_arg ref1, b_obj_arg ref2) { + return lean_to_ref(ref1) == lean_to_ref(ref2); } /* {α : Type} (act : BaseIO α) : α */ diff --git a/tests/lean/run/st_test.lean b/tests/lean/run/st_test.lean new file mode 100644 index 0000000000..1d83d40420 --- /dev/null +++ b/tests/lean/run/st_test.lean @@ -0,0 +1,85 @@ +/-! +Some basic tests for the ST monad. +-/ + +namespace STTest + +def ptrEq : IO Unit := do + let ref1 ← IO.mkRef 0 + let ref2 ← IO.mkRef 0 + IO.println (← ref1.ptrEq ref1) + IO.println (← ref1.ptrEq ref2) + +/-- +info: true +false +-/ +#guard_msgs in +#eval ptrEq + +def readWriteRegister : IO Unit := do + let ref1 ← IO.mkRef 0 + IO.println (← ref1.get) + ref1.set 1 + IO.println (← ref1.get) + +/-- +info: 0 +1 +-/ +#guard_msgs in +#eval readWriteRegister + +def swapRegister : IO Unit := do + let ref1 ← IO.mkRef 0 + IO.println (← ref1.swap 5) + IO.println (← ref1.get) + +/-- +info: 0 +5 +-/ +#guard_msgs in +#eval swapRegister + +unsafe def takeRegister : IO Unit := do + let ref1 ← IO.mkRef 0 + IO.println (← ref1.take) + ref1.set 5 + IO.println (← ref1.get) + +/-- +info: 0 +5 +-/ +#guard_msgs in +#eval takeRegister + +def modifyRegister : IO Unit := do + let ref1 ← IO.mkRef 1 + IO.println (← ref1.get) + ref1.modify (fun x => 2 * x) + IO.println (← ref1.get) + +/-- +info: 1 +2 +-/ +#guard_msgs in +#eval modifyRegister + +def modifyGetRegister : IO Unit := do + let ref1 ← IO.mkRef 1 + IO.println (← ref1.get) + IO.println (← ref1.modifyGet (fun x => (x, 2 * x))) + IO.println (← ref1.get) + +/-- +info: 1 +1 +2 +-/ +#guard_msgs in +#eval modifyGetRegister + +end STTest