diff --git a/library/init/util.lean b/library/init/util.lean index c023f5b3f0..ac38c9113e 100644 --- a/library/init/util.lean +++ b/library/init/util.lean @@ -12,6 +12,11 @@ universes u def dbgTrace {α : Type u} (s : String) (f : Unit → α) : α := f () +/- Display the given message if `a` is shared, that is, RC(a) > 1 -/ +@[extern cpp inline "lean::dbg_trace_if_shared(#2, #3)"] +def dbgTraceIfShared {α : Type u} (s : String) (a : α) : α := +a + @[extern cpp inline "lean::dbg_sleep(#2, #3)"] def dbgSleep {α : Type u} (ms : UInt32) (f : Unit → α) : α := f () diff --git a/src/runtime/object.cpp b/src/runtime/object.cpp index 18b27cb10c..69fa9b35e0 100644 --- a/src/runtime/object.cpp +++ b/src/runtime/object.cpp @@ -1944,6 +1944,15 @@ object * dbg_sleep(uint32 ms, obj_arg fn) { return apply_1(fn, box(0)); } +object * dbg_trace_if_shared(obj_arg s, obj_arg a) { + if (is_heap_obj(a) && is_shared(a)) { + unique_lock lock(g_dbg_mutex); + std::cout << "shared RC: " << get_rc(a) << ", " << string_cstr(s) << std::endl; + } + dec(s); + return a; +} + // ======================================= // Statistics #ifdef LEAN_RUNTIME_STATS diff --git a/src/runtime/object.h b/src/runtime/object.h index b6a6ea7e79..f1c9d4130f 100644 --- a/src/runtime/object.h +++ b/src/runtime/object.h @@ -1436,6 +1436,7 @@ object * mk_array(obj_arg n, obj_arg v); // debugging helper functions object * dbg_trace(obj_arg s, obj_arg fn); object * dbg_sleep(uint32 ms, obj_arg fn); +object * dbg_trace_if_shared(obj_arg s, obj_arg a); // ======================================= // IO helper functions diff --git a/tests/playground/mapVShmap.lean b/tests/playground/mapVShmap.lean index 47d5df9e4c..2a2b9ebd48 100644 --- a/tests/playground/mapVShmap.lean +++ b/tests/playground/mapVShmap.lean @@ -8,27 +8,32 @@ set_option pp.binder_types false -- set_option trace.compiler.boxed true def f1 (ps : Array (Nat × Nat)) : Array (Nat × Nat) := -ps.hmap (λ p, match p with (n, m) := (n+1, m)) +ps.hmap (λ p, + -- let p := dbgTraceIfShared "bad1" p in + match p with (n, m) := (n+1, m)) def f2 (ps : Array (Nat × Nat)) : Array (Nat × Nat) := -ps.map (λ p, match p with (n, m) := (n+1, m)) +ps.map (λ p, + -- let p := dbgTraceIfShared "bad2" p in + match p with (n, m) := (n+1, m)) def prof {α : Type} (msg : String) (p : IO α) : IO α := let msg₁ := "Time for '" ++ msg ++ "':" in timeit msg₁ p -def test1 (n : Nat) (m : Array (Nat × Nat)) : IO Unit := +def test1 (n : Nat) : IO Unit := +let m := mkBigPairs n ∅ in let m := n.repeat f1 m in let s := m.foldl (λ p n, n + p.1) 0 in IO.println s -def test2 (n : Nat) (m : Array (Nat × Nat)) : IO Unit := +def test2 (n : Nat) : IO Unit := +let m := mkBigPairs n ∅ in let m := n.repeat f2 m in let s := m.foldl (λ p n, n + p.1) 0 in IO.println s def main (xs : List String) : IO Unit := -let n := xs.head.toNat in -let m := mkBigPairs n ∅ in -prof "hmap" (test1 n m) *> -prof "map" (test2 n m) +let n := xs.head.toNat in +prof "hmap" (test1 n) *> +prof "map" (test2 n)