diff --git a/src/bindings/lua/splay_map.cpp b/src/bindings/lua/splay_map.cpp index 9b75c94abf..864d9b7942 100644 --- a/src/bindings/lua/splay_map.cpp +++ b/src/bindings/lua/splay_map.cpp @@ -168,7 +168,11 @@ static int splay_map_pred(lua_State * L) { } static int splay_map_for_each(lua_State * L) { - lua_splay_map & m = to_splay_map(L, 1); // map + // Remark: we take a copy of the map to make sure + // for_each will not crash if the map is updated while being + // traversed. + // The copy operation is very cheap O(1). + lua_splay_map m(to_splay_map(L, 1)); // map luaL_checktype(L, 2, LUA_TFUNCTION); // user-fun m.for_each([&](lua_ref const & k, lua_ref const & v) { lua_pushvalue(L, 2); // push user-fun diff --git a/tests/lua/map2.lua b/tests/lua/map2.lua new file mode 100644 index 0000000000..4a3f248a6b --- /dev/null +++ b/tests/lua/map2.lua @@ -0,0 +1,33 @@ +local m = splay_map() +for i = 1, 100 do + m:insert(i, i * 3) +end + +local prev_k = nil +-- It is safe to add/erase elements from m while we are traversing. +-- The for_each method will traverse the elements that are in m +-- when the for_each is invoked +m:for_each( + function(k, v) + if prev_k then + assert(prev_k < k) + end + print(k .. " -> " .. v) + prev_k = k + m:insert(-k, v) + end +) + +assert(m:size() == 200) +m2 = m:copy() +m:for_each( + function(k, v) + assert(m2:contains(-k)) + assert(m2:find(-k) == v) + if k > 0 then + m:erase(k) + end + end +) +assert(m:size() == 100) +m:for_each(function(k, v) assert(k < 0) end)