From 5f684b4777b482ea35e25ff787c078754373e9fa Mon Sep 17 00:00:00 2001 From: Leonardo de Moura Date: Sun, 6 Apr 2025 12:52:50 -0700 Subject: [PATCH] feat: support `mpz` in the `shareCommon` APIs (#7838) This PR adds support for mpz objects (i.e., big nums) to the `shareCommon` functions. --- src/runtime/sharecommon.cpp | 39 ++++++++++++++------- tests/lean/run/sharecommon_mpz.lean | 54 +++++++++++++++++++++++++++++ 2 files changed, 81 insertions(+), 12 deletions(-) create mode 100644 tests/lean/run/sharecommon_mpz.lean diff --git a/src/runtime/sharecommon.cpp b/src/runtime/sharecommon.cpp index ae4cf30cb7..887054e971 100644 --- a/src/runtime/sharecommon.cpp +++ b/src/runtime/sharecommon.cpp @@ -17,22 +17,32 @@ extern "C" LEAN_EXPORT uint8 lean_sharecommon_eq(b_obj_arg o1, b_obj_arg o2) { size_t sz2 = lean_object_data_byte_size(o2); if (sz1 != sz2) return false; // compare relevant parts of the header - if (lean_ptr_tag(o1) != lean_ptr_tag(o2)) return false; + uint8_t tag = lean_ptr_tag(o1); + if (tag != lean_ptr_tag(o2)) return false; if (lean_ptr_other(o1) != lean_ptr_other(o2)) return false; - size_t header_sz = sizeof(lean_object); - lean_assert(sz1 >= header_sz); - // compare objects' bodies - return memcmp(reinterpret_cast(o1) + header_sz, reinterpret_cast(o2) + header_sz, sz1 - header_sz) == 0; + if (tag == LeanMPZ) { + return mpz_value(o1) == mpz_value(o2); + } else { + size_t header_sz = sizeof(lean_object); + lean_assert(sz1 >= header_sz); + // compare objects' bodies + return memcmp(reinterpret_cast(o1) + header_sz, reinterpret_cast(o2) + header_sz, sz1 - header_sz) == 0; + } } extern "C" LEAN_EXPORT uint64_t lean_sharecommon_hash(b_obj_arg o) { lean_assert(!lean_is_scalar(o)); size_t sz = lean_object_data_byte_size(o); size_t header_sz = sizeof(lean_object); - // hash relevant parts of the header - unsigned init = hash(lean_ptr_tag(o), lean_ptr_other(o)); - // hash body - return hash_str(sz - header_sz, reinterpret_cast(o) + header_sz, init); + uint8_t tag = lean_ptr_tag(o); + if (tag == LeanMPZ) { + return hash(tag, mpz_value(o).hash()); + } else { + // hash relevant parts of the header + unsigned init = hash(tag, lean_ptr_other(o)); + // hash body + return hash_str(sz - header_sz, reinterpret_cast(o) + header_sz, init); + } } static obj_res mk_pair(obj_arg a, obj_arg b) { @@ -114,7 +124,7 @@ class sharecommon_fn { case LeanReserved: lean_unreachable(); // We do not maximize sharing for the following kinds of objects - case LeanMPZ: case LeanThunk: + case LeanThunk: case LeanTask: case LeanRef: case LeanExternal: case LeanClosure: case LeanPromise: @@ -201,6 +211,11 @@ class sharecommon_fn { save(a, (lean_object*)new_a); } + void visit_mpz(b_obj_arg a) { + object * new_a = alloc_mpz(mpz_value(a)); + save(a, new_a); + } + void visit_ctor(b_obj_arg a) { clear_children(); unsigned num_objs = lean_ctor_num_objs(a); @@ -247,7 +262,7 @@ public: case LeanArray: visit_array(curr); break; case LeanScalarArray: visit_sarray(curr); break; case LeanString: visit_string(curr); break; - case LeanMPZ: lean_unreachable(); + case LeanMPZ: visit_mpz(curr); break; case LeanThunk: lean_unreachable(); case LeanTask: lean_unreachable(); case LeanPromise: lean_unreachable(); @@ -409,7 +424,6 @@ lean_object * sharecommon_quick_fn::visit(lean_object * a) { Similarly to `sharecommon_fn`, we only maximally share arrays, scalar arrays, strings, and constructor objects. */ - case LeanMPZ: lean_inc_ref(a); return a; case LeanClosure: lean_inc_ref(a); return a; case LeanThunk: lean_inc_ref(a); return a; case LeanTask: lean_inc_ref(a); return a; @@ -417,6 +431,7 @@ lean_object * sharecommon_quick_fn::visit(lean_object * a) { case LeanRef: lean_inc_ref(a); return a; case LeanExternal: lean_inc_ref(a); return a; case LeanReserved: lean_inc_ref(a); return a; + case LeanMPZ: return visit_terminal(a); case LeanScalarArray: return visit_terminal(a); case LeanString: return visit_terminal(a); case LeanArray: return visit_array(a); diff --git a/tests/lean/run/sharecommon_mpz.lean b/tests/lean/run/sharecommon_mpz.lean new file mode 100644 index 0000000000..d6b26a76e5 --- /dev/null +++ b/tests/lean/run/sharecommon_mpz.lean @@ -0,0 +1,54 @@ +import Lean + +open Lean Meta Tactic Grind + +def runGrind (x : GrindM α) : MetaM α := do + GrindM.run x `dummy (← mkParams {}) (pure ()) + +@[noinline] def mkA (x : Nat) := x + 1 + +def tst (a b : Nat) : GrindM Unit := do + IO.println a + IO.println b + let a ← shareCommon (mkNatLit a) + let b ← shareCommon (mkNatLit b) + IO.println (isSameExpr a b) + +/-- +info: 1000000000000000000000000001 +1000000000000000000000000001 +true +-/ +#guard_msgs (info) in +run_meta do + let a := mkA 1000000000000000000000000000 + let b := 1000000000000000000000000001 + runGrind (tst a b) + +/-- +info: 1001 +1001 +true +-/ +#guard_msgs (info) in +run_meta do + let a := mkA 1000 + let b := 1001 + runGrind (tst a b) + +def tst2 (a b : Nat) : IO Unit := do + IO.println a + IO.println b + let (a, b) := ShareCommon.shareCommon' (mkNatLit a, mkNatLit b) + IO.println (isSameExpr a b) + +/-- +info: 1000000000000000000000000001 +1000000000000000000000000001 +true +-/ +#guard_msgs (info) in +run_meta do + let a := mkA 1000000000000000000000000000 + let b := 1000000000000000000000000001 + tst2 a b