feat: support mpz in the shareCommon APIs (#7838)
This PR adds support for mpz objects (i.e., big nums) to the `shareCommon` functions.
This commit is contained in:
parent
c3ff4334cd
commit
5f684b4777
2 changed files with 81 additions and 12 deletions
|
|
@ -17,22 +17,32 @@ extern "C" LEAN_EXPORT uint8 lean_sharecommon_eq(b_obj_arg o1, b_obj_arg o2) {
|
|||
size_t sz2 = lean_object_data_byte_size(o2);
|
||||
if (sz1 != sz2) return false;
|
||||
// compare relevant parts of the header
|
||||
if (lean_ptr_tag(o1) != lean_ptr_tag(o2)) return false;
|
||||
uint8_t tag = lean_ptr_tag(o1);
|
||||
if (tag != lean_ptr_tag(o2)) return false;
|
||||
if (lean_ptr_other(o1) != lean_ptr_other(o2)) return false;
|
||||
size_t header_sz = sizeof(lean_object);
|
||||
lean_assert(sz1 >= header_sz);
|
||||
// compare objects' bodies
|
||||
return memcmp(reinterpret_cast<char*>(o1) + header_sz, reinterpret_cast<char*>(o2) + header_sz, sz1 - header_sz) == 0;
|
||||
if (tag == LeanMPZ) {
|
||||
return mpz_value(o1) == mpz_value(o2);
|
||||
} else {
|
||||
size_t header_sz = sizeof(lean_object);
|
||||
lean_assert(sz1 >= header_sz);
|
||||
// compare objects' bodies
|
||||
return memcmp(reinterpret_cast<char*>(o1) + header_sz, reinterpret_cast<char*>(o2) + header_sz, sz1 - header_sz) == 0;
|
||||
}
|
||||
}
|
||||
|
||||
extern "C" LEAN_EXPORT uint64_t lean_sharecommon_hash(b_obj_arg o) {
|
||||
lean_assert(!lean_is_scalar(o));
|
||||
size_t sz = lean_object_data_byte_size(o);
|
||||
size_t header_sz = sizeof(lean_object);
|
||||
// hash relevant parts of the header
|
||||
unsigned init = hash(lean_ptr_tag(o), lean_ptr_other(o));
|
||||
// hash body
|
||||
return hash_str(sz - header_sz, reinterpret_cast<unsigned char const *>(o) + header_sz, init);
|
||||
uint8_t tag = lean_ptr_tag(o);
|
||||
if (tag == LeanMPZ) {
|
||||
return hash(tag, mpz_value(o).hash());
|
||||
} else {
|
||||
// hash relevant parts of the header
|
||||
unsigned init = hash(tag, lean_ptr_other(o));
|
||||
// hash body
|
||||
return hash_str(sz - header_sz, reinterpret_cast<unsigned char const *>(o) + header_sz, init);
|
||||
}
|
||||
}
|
||||
|
||||
static obj_res mk_pair(obj_arg a, obj_arg b) {
|
||||
|
|
@ -114,7 +124,7 @@ class sharecommon_fn {
|
|||
case LeanReserved:
|
||||
lean_unreachable();
|
||||
// We do not maximize sharing for the following kinds of objects
|
||||
case LeanMPZ: case LeanThunk:
|
||||
case LeanThunk:
|
||||
case LeanTask: case LeanRef:
|
||||
case LeanExternal: case LeanClosure:
|
||||
case LeanPromise:
|
||||
|
|
@ -201,6 +211,11 @@ class sharecommon_fn {
|
|||
save(a, (lean_object*)new_a);
|
||||
}
|
||||
|
||||
void visit_mpz(b_obj_arg a) {
|
||||
object * new_a = alloc_mpz(mpz_value(a));
|
||||
save(a, new_a);
|
||||
}
|
||||
|
||||
void visit_ctor(b_obj_arg a) {
|
||||
clear_children();
|
||||
unsigned num_objs = lean_ctor_num_objs(a);
|
||||
|
|
@ -247,7 +262,7 @@ public:
|
|||
case LeanArray: visit_array(curr); break;
|
||||
case LeanScalarArray: visit_sarray(curr); break;
|
||||
case LeanString: visit_string(curr); break;
|
||||
case LeanMPZ: lean_unreachable();
|
||||
case LeanMPZ: visit_mpz(curr); break;
|
||||
case LeanThunk: lean_unreachable();
|
||||
case LeanTask: lean_unreachable();
|
||||
case LeanPromise: lean_unreachable();
|
||||
|
|
@ -409,7 +424,6 @@ lean_object * sharecommon_quick_fn::visit(lean_object * a) {
|
|||
Similarly to `sharecommon_fn`, we only maximally share arrays, scalar arrays, strings, and
|
||||
constructor objects.
|
||||
*/
|
||||
case LeanMPZ: lean_inc_ref(a); return a;
|
||||
case LeanClosure: lean_inc_ref(a); return a;
|
||||
case LeanThunk: lean_inc_ref(a); return a;
|
||||
case LeanTask: lean_inc_ref(a); return a;
|
||||
|
|
@ -417,6 +431,7 @@ lean_object * sharecommon_quick_fn::visit(lean_object * a) {
|
|||
case LeanRef: lean_inc_ref(a); return a;
|
||||
case LeanExternal: lean_inc_ref(a); return a;
|
||||
case LeanReserved: lean_inc_ref(a); return a;
|
||||
case LeanMPZ: return visit_terminal(a);
|
||||
case LeanScalarArray: return visit_terminal(a);
|
||||
case LeanString: return visit_terminal(a);
|
||||
case LeanArray: return visit_array(a);
|
||||
|
|
|
|||
54
tests/lean/run/sharecommon_mpz.lean
Normal file
54
tests/lean/run/sharecommon_mpz.lean
Normal file
|
|
@ -0,0 +1,54 @@
|
|||
import Lean
|
||||
|
||||
open Lean Meta Tactic Grind
|
||||
|
||||
def runGrind (x : GrindM α) : MetaM α := do
|
||||
GrindM.run x `dummy (← mkParams {}) (pure ())
|
||||
|
||||
@[noinline] def mkA (x : Nat) := x + 1
|
||||
|
||||
def tst (a b : Nat) : GrindM Unit := do
|
||||
IO.println a
|
||||
IO.println b
|
||||
let a ← shareCommon (mkNatLit a)
|
||||
let b ← shareCommon (mkNatLit b)
|
||||
IO.println (isSameExpr a b)
|
||||
|
||||
/--
|
||||
info: 1000000000000000000000000001
|
||||
1000000000000000000000000001
|
||||
true
|
||||
-/
|
||||
#guard_msgs (info) in
|
||||
run_meta do
|
||||
let a := mkA 1000000000000000000000000000
|
||||
let b := 1000000000000000000000000001
|
||||
runGrind (tst a b)
|
||||
|
||||
/--
|
||||
info: 1001
|
||||
1001
|
||||
true
|
||||
-/
|
||||
#guard_msgs (info) in
|
||||
run_meta do
|
||||
let a := mkA 1000
|
||||
let b := 1001
|
||||
runGrind (tst a b)
|
||||
|
||||
def tst2 (a b : Nat) : IO Unit := do
|
||||
IO.println a
|
||||
IO.println b
|
||||
let (a, b) := ShareCommon.shareCommon' (mkNatLit a, mkNatLit b)
|
||||
IO.println (isSameExpr a b)
|
||||
|
||||
/--
|
||||
info: 1000000000000000000000000001
|
||||
1000000000000000000000000001
|
||||
true
|
||||
-/
|
||||
#guard_msgs (info) in
|
||||
run_meta do
|
||||
let a := mkA 1000000000000000000000000000
|
||||
let b := 1000000000000000000000000001
|
||||
tst2 a b
|
||||
Loading…
Add table
Reference in a new issue