Skip to content

Commit

Permalink
Add a container-based version of std::sample()
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 592864147
Change-Id: I83179b0225aa446ae0b57b46b604af14f1fa14df
  • Loading branch information
ericastor authored and copybara-github committed Dec 21, 2023
1 parent 794352a commit 258e5a1
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 1 deletion.
30 changes: 30 additions & 0 deletions absl/algorithm/container.h
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,36 @@ void c_shuffle(RandomAccessContainer& c, UniformRandomBitGenerator&& gen) {
std::forward<UniformRandomBitGenerator>(gen));
}

// c_sample()
//
// Container-based version of the <algorithm> `std::sample()` function to
// randomly sample elements from the container without replacement using a
// `gen()` uniform random number generator and write them to an iterator range.
template <typename C, typename OutputIterator, typename Distance,
typename UniformRandomBitGenerator>
OutputIterator c_sample(const C& c, OutputIterator result, Distance n,
UniformRandomBitGenerator&& gen) {
#if defined(__cpp_lib_sample) && __cpp_lib_sample >= 201603L
return std::sample(container_algorithm_internal::c_begin(c),
container_algorithm_internal::c_end(c), result, n,
std::forward<UniformRandomBitGenerator>(gen));
#else
// Fall back to a stable selection-sampling implementation.
auto first = container_algorithm_internal::c_begin(c);
Distance unsampled_elements = c_distance(c);
n = (std::min)(n, unsampled_elements);
for (; n != 0; ++first) {
Distance r =
std::uniform_int_distribution<Distance>(0, --unsampled_elements)(gen);
if (r < n) {
*result++ = *first;
--n;
}
}
return result;
#endif
}

//------------------------------------------------------------------------------
// <algorithm> Partition functions
//------------------------------------------------------------------------------
Expand Down
22 changes: 21 additions & 1 deletion absl/algorithm/container_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

#include "absl/algorithm/container.h"

#include <algorithm>
#include <functional>
#include <initializer_list>
#include <iterator>
Expand All @@ -40,8 +41,10 @@ using ::testing::Each;
using ::testing::ElementsAre;
using ::testing::Gt;
using ::testing::IsNull;
using ::testing::IsSubsetOf;
using ::testing::Lt;
using ::testing::Pointee;
using ::testing::SizeIs;
using ::testing::Truly;
using ::testing::UnorderedElementsAre;

Expand Down Expand Up @@ -963,12 +966,29 @@ TEST(MutatingTest, RotateCopy) {
EXPECT_THAT(actual, ElementsAre(3, 4, 1, 2, 5));
}

template <typename T>
T RandomlySeededPrng() {
std::random_device rdev;
std::seed_seq::result_type data[T::state_size];
std::generate_n(data, T::state_size, std::ref(rdev));
std::seed_seq prng_seed(data, data + T::state_size);
return T(prng_seed);
}

TEST(MutatingTest, Shuffle) {
std::vector<int> actual = {1, 2, 3, 4, 5};
absl::c_shuffle(actual, std::random_device());
absl::c_shuffle(actual, RandomlySeededPrng<std::mt19937_64>());
EXPECT_THAT(actual, UnorderedElementsAre(1, 2, 3, 4, 5));
}

TEST(MutatingTest, Sample) {
std::vector<int> actual;
absl::c_sample(std::vector<int>{1, 2, 3, 4, 5}, std::back_inserter(actual), 3,
RandomlySeededPrng<std::mt19937_64>());
EXPECT_THAT(actual, IsSubsetOf({1, 2, 3, 4, 5}));
EXPECT_THAT(actual, SizeIs(3));
}

TEST(MutatingTest, PartialSort) {
std::vector<int> sequence{5, 3, 42, 0};
absl::c_partial_sort(sequence, sequence.begin() + 2);
Expand Down

0 comments on commit 258e5a1

Please sign in to comment.