From 2a18ba753bca64e79f6d9857a9f64e638b32c371 Mon Sep 17 00:00:00 2001 From: Evan Brown Date: Mon, 16 Oct 2023 12:41:17 -0700 Subject: [PATCH] Add sanitizer mode checks that element constructors/destructors don't make reentrant calls to raw_hash_set member functions. PiperOrigin-RevId: 573897598 Change-Id: If40c23ac3cd9fff315ee18774e27c480cbca3a81 --- absl/container/BUILD.bazel | 2 + absl/container/CMakeLists.txt | 2 + absl/container/internal/container_memory.h | 8 ++ absl/container/internal/raw_hash_set.h | 39 ++++++--- absl/container/internal/raw_hash_set_test.cc | 84 +++++++++++++++++--- 5 files changed, 113 insertions(+), 22 deletions(-) diff --git a/absl/container/BUILD.bazel b/absl/container/BUILD.bazel index 5b69ae6b972..4aa67d35ed5 100644 --- a/absl/container/BUILD.bazel +++ b/absl/container/BUILD.bazel @@ -686,8 +686,10 @@ cc_test( "//absl/base:config", "//absl/base:core_headers", "//absl/base:prefetch", + "//absl/functional:function_ref", "//absl/log", "//absl/strings", + "//absl/types:optional", "@com_google_googletest//:gtest", "@com_google_googletest//:gtest_main", ], diff --git a/absl/container/CMakeLists.txt b/absl/container/CMakeLists.txt index 96cdf59cbbe..8362c440501 100644 --- a/absl/container/CMakeLists.txt +++ b/absl/container/CMakeLists.txt @@ -743,10 +743,12 @@ absl_cc_test( absl::core_headers absl::flat_hash_map absl::flat_hash_set + absl::function_ref absl::hash_function_defaults absl::hash_policy_testing absl::hashtable_debug absl::log + absl::optional absl::prefetch absl::raw_hash_set absl::strings diff --git a/absl/container/internal/container_memory.h b/absl/container/internal/container_memory.h index f59ca4ee22d..a735ca3429a 100644 --- a/absl/container/internal/container_memory.h +++ b/absl/container/internal/container_memory.h @@ -249,6 +249,14 @@ inline void SanitizerUnpoisonObject(const T* object) { SanitizerUnpoisonMemoryRegion(object, sizeof(T)); } +template +void RunWithReentrancyGuard(Container& c, Alloc& a, F f) { + SanitizerPoisonObject(&c); + if (!std::is_empty()) SanitizerUnpoisonObject(&a); + f(); + SanitizerUnpoisonObject(&c); +} + namespace memory_internal { // If Pair is a standard-layout type, OffsetOf::kFirst and diff --git a/absl/container/internal/raw_hash_set.h b/absl/container/internal/raw_hash_set.h index a8fd80caa54..67964f545b8 100644 --- a/absl/container/internal/raw_hash_set.h +++ b/absl/container/internal/raw_hash_set.h @@ -2143,7 +2143,7 @@ class raw_hash_set { alignas(slot_type) unsigned char raw[sizeof(slot_type)]; slot_type* slot = reinterpret_cast(&raw); - PolicyTraits::construct(&alloc_ref(), slot, std::forward(args)...); + construct(slot, std::forward(args)...); const auto& elem = PolicyTraits::element(slot); return PolicyTraits::apply(InsertSlot{*this, std::move(*slot)}, elem); } @@ -2248,7 +2248,7 @@ class raw_hash_set { // a better match if non-const iterator is passed as an argument. void erase(iterator it) { AssertIsFull(it.ctrl_, it.generation(), it.generation_ptr(), "erase()"); - PolicyTraits::destroy(&alloc_ref(), it.slot_); + destroy(it.slot_); erase_meta_only(it); } @@ -2541,10 +2541,9 @@ class raw_hash_set { std::pair operator()(const K& key, Args&&...) && { auto res = s.find_or_prepare_insert(key); if (res.second) { - PolicyTraits::transfer(&s.alloc_ref(), s.slot_array() + res.first, - &slot); + s.transfer(s.slot_array() + res.first, &slot); } else if (do_destroy) { - PolicyTraits::destroy(&s.alloc_ref(), &slot); + s.destroy(&slot); } return {s.iterator_at(res.first), res.second}; } @@ -2553,13 +2552,31 @@ class raw_hash_set { slot_type&& slot; }; + // Helpers to enable sanitizer mode validation to protect against reentrant + // calls during element constructor/destructor. + template + inline void construct(slot_type* slot, Args&&... args) { + RunWithReentrancyGuard(*this, alloc_ref(), [&] { + PolicyTraits::construct(&alloc_ref(), slot, std::forward(args)...); + }); + } + inline void destroy(slot_type* slot) { + RunWithReentrancyGuard(*this, alloc_ref(), + [&] { PolicyTraits::destroy(&alloc_ref(), slot); }); + } + inline void transfer(slot_type* to, slot_type* from) { + RunWithReentrancyGuard(*this, alloc_ref(), [&] { + PolicyTraits::transfer(&alloc_ref(), to, from); + }); + } + inline void destroy_slots() { const size_t cap = capacity(); const ctrl_t* ctrl = control(); slot_type* slot = slot_array(); for (size_t i = 0; i != cap; ++i) { if (IsFull(ctrl[i])) { - PolicyTraits::destroy(&alloc_ref(), slot + i); + destroy(slot + i); } } } @@ -2622,7 +2639,7 @@ class raw_hash_set { size_t new_i = target.offset; total_probe_length += target.probe_length; SetCtrl(common(), new_i, H2(hash), sizeof(slot_type)); - PolicyTraits::transfer(&alloc_ref(), new_slots + new_i, old_slots + i); + transfer(new_slots + new_i, old_slots + i); } } if (old_capacity) { @@ -2725,7 +2742,7 @@ class raw_hash_set { reserve(size); for (iterator it = that.begin(); it != that.end(); ++it) { insert(std::move(PolicyTraits::element(it.slot_))); - PolicyTraits::destroy(&that.alloc_ref(), it.slot_); + that.destroy(it.slot_); } that.dealloc(); that.common() = CommonFields{}; @@ -2816,8 +2833,7 @@ class raw_hash_set { // POSTCONDITION: *m.iterator_at(i) == value_type(forward(args)...). template void emplace_at(size_t i, Args&&... args) { - PolicyTraits::construct(&alloc_ref(), slot_array() + i, - std::forward(args)...); + construct(slot_array() + i, std::forward(args)...); assert(PolicyTraits::apply(FindElement{*this}, *iterator_at(i)) == iterator_at(i) && @@ -2883,8 +2899,7 @@ class raw_hash_set { } static void transfer_slot_fn(void* set, void* dst, void* src) { auto* h = static_cast(set); - PolicyTraits::transfer(&h->alloc_ref(), static_cast(dst), - static_cast(src)); + h->transfer(static_cast(dst), static_cast(src)); } // Note: dealloc_fn will only be used if we have a non-standard allocator. static void dealloc_fn(CommonFields& common, const PolicyFunctions&) { diff --git a/absl/container/internal/raw_hash_set_test.cc b/absl/container/internal/raw_hash_set_test.cc index d194ca1b5d3..4e67b79ef7a 100644 --- a/absl/container/internal/raw_hash_set_test.cc +++ b/absl/container/internal/raw_hash_set_test.cc @@ -49,8 +49,10 @@ #include "absl/container/internal/hash_policy_testing.h" #include "absl/container/internal/hashtable_debug.h" #include "absl/container/internal/test_allocator.h" +#include "absl/functional/function_ref.h" #include "absl/log/log.h" #include "absl/strings/string_view.h" +#include "absl/types/optional.h" namespace absl { ABSL_NAMESPACE_BEGIN @@ -409,19 +411,15 @@ struct StringTable using Base::Base; }; -struct IntTable - : raw_hash_set, - std::equal_to, std::allocator> { - using Base = typename IntTable::raw_hash_set; +template +struct ValueTable : raw_hash_set, hash_default_hash, + std::equal_to, std::allocator> { + using Base = typename ValueTable::raw_hash_set; using Base::Base; }; -struct Uint8Table - : raw_hash_set, - std::equal_to, std::allocator> { - using Base = typename Uint8Table::raw_hash_set; - using Base::Base; -}; +using IntTable = ValueTable; +using Uint8Table = ValueTable; template struct CustomAlloc : std::allocator { @@ -2489,6 +2487,72 @@ using RawHashSetAlloc = raw_hash_set, TEST(Table, AllocatorPropagation) { TestAllocPropagation(); } +struct ConstructCaller { + explicit ConstructCaller(int v) : val(v) {} + ConstructCaller(int v, absl::FunctionRef func) : val(v) { func(); } + template + friend H AbslHashValue(H h, const ConstructCaller& d) { + return H::combine(std::move(h), d.val); + } + bool operator==(const ConstructCaller& c) const { return val == c.val; } + + int val; +}; + +struct DestroyCaller { + explicit DestroyCaller(int v) : val(v) {} + DestroyCaller(int v, absl::FunctionRef func) + : val(v), destroy_func(func) {} + DestroyCaller(DestroyCaller&& that) + : val(that.val), destroy_func(std::move(that.destroy_func)) { + that.Deactivate(); + } + ~DestroyCaller() { + if (destroy_func) (*destroy_func)(); + } + void Deactivate() { destroy_func = absl::nullopt; } + + template + friend H AbslHashValue(H h, const DestroyCaller& d) { + return H::combine(std::move(h), d.val); + } + bool operator==(const DestroyCaller& d) const { return val == d.val; } + + int val; + absl::optional> destroy_func; +}; + +#if defined(ABSL_HAVE_ADDRESS_SANITIZER) || defined(ABSL_HAVE_MEMORY_SANITIZER) +TEST(Table, ReentrantCallsFail) { + constexpr const char* kDeathMessage = + "use-after-poison|use-of-uninitialized-value"; + { + ValueTable t; + t.insert(ConstructCaller{0}); + auto erase_begin = [&] { t.erase(t.begin()); }; + EXPECT_DEATH_IF_SUPPORTED(t.emplace(1, erase_begin), kDeathMessage); + } + { + ValueTable t; + t.insert(DestroyCaller{0}); + auto find_0 = [&] { t.find(DestroyCaller{0}); }; + t.insert(DestroyCaller{1, find_0}); + for (int i = 10; i < 20; ++i) t.insert(DestroyCaller{i}); + EXPECT_DEATH_IF_SUPPORTED(t.clear(), kDeathMessage); + for (auto& elem : t) elem.Deactivate(); + } + { + ValueTable t; + t.insert(DestroyCaller{0}); + auto insert_1 = [&] { t.insert(DestroyCaller{1}); }; + t.insert(DestroyCaller{1, insert_1}); + for (int i = 10; i < 20; ++i) t.insert(DestroyCaller{i}); + EXPECT_DEATH_IF_SUPPORTED(t.clear(), kDeathMessage); + for (auto& elem : t) elem.Deactivate(); + } +} +#endif + } // namespace } // namespace container_internal ABSL_NAMESPACE_END