Commit 18609f9

Attempt fix

Signed-off-by: Joaquin Anton Guirao <janton@nvidia.com>
jantonguirao committed Jul 23, 2024
1 parent 1238d5d commit 18609f9

Showing 4 changed files with 27 additions and 22 deletions.
3 changes: 3 additions & 0 deletions dali/core/cuda_event_pool.cc
@@ -29,6 +29,9 @@ CUDAEventPool::CUDAEventPool(unsigned event_flags) {
   int num_devices = 0;
   CUDA_CALL(cudaGetDeviceCount(&num_devices));
   dev_events_.resize(num_devices);
+  for (int i = 0; i < 20000; i++) {
+    Put(CUDAEvent::CreateWithFlags(cudaEventDisableTiming));
+  }
 }
 
 CUDAEvent CUDAEventPool::Get(int device_id) {
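The change above pre-populates the pool with 20000 events at construction time, so that steady-state Get()/Put() traffic recycles existing events instead of paying for cudaEventCreateWithFlags on the hot path. A minimal sketch of the same warm-up idea, assuming a simple vector-backed, single-device pool (EventPool and its members are illustrative names, not DALI's actual CUDAEventPool API):

#include <cuda_runtime_api.h>
#include <mutex>
#include <vector>

class EventPool {
 public:
  // Creating events up front moves the cudaEventCreateWithFlags cost out of
  // the hot path; cudaEventDisableTiming keeps the events lightweight.
  explicit EventPool(int prewarm) {
    for (int i = 0; i < prewarm; i++) {
      cudaEvent_t e;
      if (cudaEventCreateWithFlags(&e, cudaEventDisableTiming) == cudaSuccess)
        events_.push_back(e);
    }
  }

  cudaEvent_t Get() {
    std::lock_guard<std::mutex> g(m_);
    if (events_.empty()) {  // pool exhausted: fall back to on-demand creation
      cudaEvent_t e;
      cudaEventCreateWithFlags(&e, cudaEventDisableTiming);
      return e;
    }
    cudaEvent_t e = events_.back();
    events_.pop_back();
    return e;
  }

  void Put(cudaEvent_t e) {
    std::lock_guard<std::mutex> g(m_);
    events_.push_back(e);
  }

  ~EventPool() {
    for (auto e : events_) cudaEventDestroy(e);
  }

 private:
  std::mutex m_;
  std::vector<cudaEvent_t> events_;
};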
28 changes: 15 additions & 13 deletions dali/operators/imgcodec/image_decoder.h
@@ -526,6 +526,7 @@ class ImageDecoder : public StatelessOperator<Backend> {
     const bool use_cache = cache_ && cache_->IsCacheEnabled() && dtype_ == DALI_UINT8;
     auto get_task = [&](int block_idx, int nblocks) {
       return [&, block_idx, nblocks](int tid) {
+        DomainTimeRange tr(make_string("SetupImpl #", block_idx), DomainTimeRange::kBlue1);
         int i_start = nsamples * block_idx / nblocks;
         int i_end = nsamples * (block_idx + 1) / nblocks;
         for (int i = i_start; i < i_end; i++) {
@@ -572,17 +573,17 @@ class ImageDecoder : public StatelessOperator<Backend> {
       };
     };
 
-    int nblocks = tp_->NumThreads() + 1;
-    if (nsamples > nblocks * 4) {
+    if (nsamples < 16) {
+      get_task(0, 1)(-1);  // run all in current thread
+    } else {
+      int nblocks = std::max(tp_->NumThreads() + 1, 8);
       int block_idx = 0;
-      for (; block_idx < tp_->NumThreads(); block_idx++) {
+      for (; block_idx < nblocks - 1; block_idx++) {
         tp_->AddWork(get_task(block_idx, nblocks), -block_idx);
       }
-      tp_->RunAll(false);  // start work but not wait
-      get_task(block_idx, nblocks)(-1);  // run last block
-      tp_->WaitForWork();  // wait for the other threads
-    } else {  // not worth parallelizing
-      get_task(0, 1)(-1);  // run all in current thread
+      tp_->RunAll(false);  // start work but not wait
+      get_task(block_idx, nblocks)(-1);  // run last block
+      tp_->WaitForWork();  // wait for the other threads
     }
 
     output_descs[0] = {std::move(shapes), dtype_};
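The new scheduling policy: batches under 16 samples run inline in the calling thread; anything larger is split into at least 8 blocks (nblocks = max(NumThreads() + 1, 8)), the pool threads get all blocks but the last, and the caller processes the last block itself before waiting. A standalone sketch of the same partitioning, with plain std::thread standing in for DALI's ThreadPool (run_blocked and process_range are illustrative names):

#include <algorithm>
#include <thread>
#include <vector>

void process_range(int i_start, int i_end) {
  for (int i = i_start; i < i_end; i++) {
    // per-sample work goes here
  }
}

void run_blocked(int nsamples, int num_threads) {
  if (nsamples < 16) {           // not worth parallelizing
    process_range(0, nsamples);  // run everything in the current thread
    return;
  }
  // At least 8 blocks, so small thread pools still get chunks fine-grained
  // enough to balance uneven per-sample cost.
  int nblocks = std::max(num_threads + 1, 8);
  std::vector<std::thread> workers;
  for (int b = 0; b < nblocks - 1; b++)
    workers.emplace_back(process_range, nsamples * b / nblocks,
                         nsamples * (b + 1) / nblocks);
  // The caller takes the last block instead of idling while it waits.
  process_range(nsamples * (nblocks - 1) / nblocks, nsamples);
  for (auto &t : workers) t.join();
}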
@@ -764,6 +765,7 @@ class ImageDecoder : public StatelessOperator<Backend> {
     DomainTimeRange tr(make_string("Prepare descs"), DomainTimeRange::kOrange);
     auto get_task = [&](int block_idx, int nblocks) {
       return [&, block_idx, nblocks](int tid) {
+        DomainTimeRange tr(make_string("Prepare desc #", block_idx), DomainTimeRange::kBlue1);
         int i_start = decode_nsamples * block_idx / nblocks;
         int i_end = decode_nsamples * (block_idx + 1) / nblocks;
         for (int i = i_start; i < i_end; i++) {
@@ -773,17 +775,17 @@
       };
     };
 
-    int nblocks = tp_->NumThreads() + 1;
-    if (decode_nsamples > nblocks * 4) {
+    if (decode_nsamples < 16) {
+      get_task(0, 1)(-1);  // run all in current thread
+    } else {
+      int nblocks = std::max(tp_->NumThreads() + 1, 8);
       int block_idx = 0;
-      for (; block_idx < tp_->NumThreads(); block_idx++) {
+      for (; block_idx < nblocks - 1; block_idx++) {
         tp_->AddWork(get_task(block_idx, nblocks), -block_idx);
       }
       tp_->RunAll(false);  // start work but not wait
       get_task(block_idx, nblocks)(-1);  // run last block
       tp_->WaitForWork();  // wait for the other threads
-    } else {  // not worth parallelizing
-      get_task(0, 1)(-1);  // run all in current thread
     }
 
     for (int orig_idx : decode_sample_idxs_) {
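The same policy is applied here to the descriptor-preparation loop. Both sites rely on the same integer partition: block b of k covers [n*b/k, n*(b+1)/k), and because consecutive blocks share endpoints, the k blocks tile [0, n) exactly, with sizes differing by at most one. A quick standalone check of that property:

#include <cassert>

int main() {
  // Block b of k covers [n*b/k, n*(b+1)/k); summing the block sizes must
  // give back n for every (n, k) pair, i.e. no gaps and no overlap.
  for (int n : {16, 17, 100, 1000}) {
    for (int k : {1, 8, 13}) {
      int covered = 0;
      for (int b = 0; b < k; b++)
        covered += n * (b + 1) / k - n * b / k;
      assert(covered == n);
    }
  }
  return 0;
}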
10 changes: 5 additions & 5 deletions dali/pipeline/util/thread_pool.cc
@@ -46,7 +46,7 @@ ThreadPool::ThreadPool(int num_thread, int device_id, bool set_affinity, const c
 ThreadPool::~ThreadPool() {
   WaitForWork(false);
 
-  std::unique_lock<std::mutex> lock(mutex_);
+  std::unique_lock<spinlock> lock(lock_);
   running_ = false;
   condition_.notify_all();
   lock.unlock();
@@ -59,7 +59,7 @@ ThreadPool::~ThreadPool() {
 void ThreadPool::AddWork(Work work, int64_t priority, bool start_immediately) {
   bool started_before = false;
   {
-    std::lock_guard<std::mutex> lock(mutex_);
+    std::lock_guard<spinlock> lock(lock_);
     work_queue_.push({priority, std::move(work)});
     work_complete_ = false;
     started_before = started_;
@@ -75,7 +75,7 @@ void ThreadPool::AddWork(Work work, int64_t priority, bool start_immediately) {

 // Blocks until all work issued to the thread pool is complete
 void ThreadPool::WaitForWork(bool checkForErrors) {
-  std::unique_lock<std::mutex> lock(mutex_);
+  std::unique_lock<spinlock> lock(lock_);
   completed_.wait(lock, [this] { return this->work_complete_; });
   started_ = false;
   if (checkForErrors) {
@@ -93,7 +93,7 @@

 void ThreadPool::RunAll(bool wait) {
   {
-    std::lock_guard<std::mutex> lock(mutex_);
+    std::lock_guard<spinlock> lock(lock_);
     started_ = true;
   }
   condition_.notify_all();  // other threads will be waken up if needed
@@ -145,7 +145,7 @@ void ThreadPool::ThreadMain(int thread_id, int device_id, bool set_affinity,

   while (running_) {
     // Block on the condition to wait for work
-    std::unique_lock<std::mutex> lock(mutex_);
+    std::unique_lock<spinlock> lock(lock_);
     condition_.wait(lock, [this] { return !running_ || (!work_queue_.empty() && started_); });
     // If we're no longer running, exit the run loop
     if (!running_) break;
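The swap from std::mutex to spinlock works because std::condition_variable_any (see the header change below) can wait on any Lockable type, while std::condition_variable is restricted to std::unique_lock<std::mutex>. A minimal sketch of the pattern, with a test-and-set spinlock standing in for dali/core/spinlock.h, whose actual implementation may differ:

#include <atomic>
#include <condition_variable>
#include <mutex>
#include <queue>

class spinlock {  // BasicLockable: enough for unique_lock and cv_any
 public:
  void lock() {
    while (flag_.test_and_set(std::memory_order_acquire)) {
      // busy-wait; a production spinlock would typically pause or yield here
    }
  }
  void unlock() { flag_.clear(std::memory_order_release); }

 private:
  std::atomic_flag flag_ = ATOMIC_FLAG_INIT;
};

spinlock lock_;
std::condition_variable_any condition_;
std::queue<int> work_queue_;

int consume_one() {
  std::unique_lock<spinlock> lk(lock_);
  // condition_variable_any releases the spinlock while sleeping and
  // re-acquires it before re-checking the predicate, just like the
  // mutex-based version did.
  condition_.wait(lk, [] { return !work_queue_.empty(); });
  int item = work_queue_.front();
  work_queue_.pop();
  return item;
}

void produce_one(int item) {
  {
    std::lock_guard<spinlock> g(lock_);
    work_queue_.push(item);
  }
  condition_.notify_one();
}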
8 changes: 4 additions & 4 deletions dali/pipeline/util/thread_pool.h
@@ -19,12 +19,12 @@
 #include <utility>
 #include <condition_variable>
 #include <functional>
-#include <mutex>
 #include <queue>
 #include <thread>
 #include <vector>
 #include <string>
 #include "dali/core/common.h"
+#include "dali/core/spinlock.h"
 #if NVML_ENABLED
 #include "dali/util/nvml.h"
 #endif
@@ -90,9 +90,9 @@ class DLL_PUBLIC ThreadPool {
   bool work_complete_;
   bool started_;
   int active_threads_;
-  std::mutex mutex_;
-  std::condition_variable condition_;
-  std::condition_variable completed_;
+  spinlock lock_;
+  std::condition_variable_any condition_;
+  std::condition_variable_any completed_;
 
   // Stored error strings for each thread
   vector<std::queue<string>> tl_errors_;
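Worth noting on the tradeoff (the commit message doesn't spell out the rationale): the pool's critical sections are tiny, a queue push or pop plus a couple of flag updates, which is the case where a spinlock can beat a mutex by avoiding a scheduler round-trip on contention. Long waits still sleep properly, because they go through the two condition variables rather than busy-waiting on the lock.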
