diff --git a/pocketfft_hdronly.h b/pocketfft_hdronly.h index d75ada6..2f239d7 100644 --- a/pocketfft_hdronly.h +++ b/pocketfft_hdronly.h @@ -515,25 +515,12 @@ namespace threading { #ifdef POCKETFFT_NO_MULTITHREADING -constexpr inline size_t thread_id() { return 0; } -constexpr inline size_t num_threads() { return 1; } - template void thread_map(size_t /* nthreads */, Func f) - { f(); } + { f(0, 1); } #else -inline size_t &thread_id() - { - static thread_local size_t thread_id_=0; - return thread_id_; - } -inline size_t &num_threads() - { - static thread_local size_t num_threads_=1; - return num_threads_; - } static const size_t max_threads = std::max(1u, std::thread::hardware_concurrency()); class latch @@ -786,7 +773,7 @@ void thread_map(size_t nthreads, Func f) nthreads = max_threads; if (nthreads == 1) - { f(); return; } + { f(0, 1); return; } auto & pool = get_pool(); latch counter(nthreads); @@ -796,9 +783,7 @@ void thread_map(size_t nthreads, Func f) { pool.submit( [&f, &counter, &ex, &ex_mut, i, nthreads] { - thread_id() = i; - num_threads() = nthreads; - try { f(); } + try { f(i, nthreads); } catch (...) { std::lock_guard lock(ex_mut); @@ -2881,15 +2866,14 @@ template class multi_iter } public: - multi_iter(const arr_info &iarr_, const arr_info &oarr_, size_t idim_) + multi_iter(const arr_info &iarr_, const arr_info &oarr_, size_t idim_, + size_t nshares, size_t myshare) : pos(iarr_.ndim(), 0), iarr(iarr_), oarr(oarr_), p_ii(0), str_i(iarr.stride(idim_)), p_oi(0), str_o(oarr.stride(idim_)), idim(idim_), rem(iarr.size()/iarr.shape(idim)) { - auto nshares = threading::num_threads(); if (nshares==1) return; if (nshares==0) throw std::runtime_error("can't run with zero threads"); - auto myshare = threading::thread_id(); if (myshare>=nshares) throw std::runtime_error("impossible share requested"); size_t nbase = rem/nshares; size_t additional = rem%nshares; @@ -3134,11 +3118,11 @@ POCKETFFT_NOINLINE void general_nd(const cndarr &in, ndarr &out, threading::thread_map( util::thread_count(nthreads, in.shape(), axes[iax], VLEN::val), - [&] { + [&] (size_t tid, size_t nthreads) { constexpr auto vlen = VLEN::val; auto storage = alloc_tmp(in.shape(), len, sizeof(T)); const auto &tin(iax==0? in : out); - multi_iter it(tin, out, axes[iax]); + multi_iter it(tin, out, axes[iax], nthreads, tid); #ifndef POCKETFFT_NO_VECTORS if (vlen>1) while (it.remaining()>=vlen) @@ -3241,10 +3225,10 @@ template POCKETFFT_NOINLINE void general_r2c( size_t len=in.shape(axis); threading::thread_map( util::thread_count(nthreads, in.shape(), axis, VLEN::val), - [&] { + [&] (size_t tid, size_t nthreads) { constexpr auto vlen = VLEN::val; auto storage = alloc_tmp(in.shape(), len, sizeof(T)); - multi_iter it(in, out, axis); + multi_iter it(in, out, axis, nthreads, tid); #ifndef POCKETFFT_NO_VECTORS if (vlen>1) while (it.remaining()>=vlen) @@ -3296,10 +3280,10 @@ template POCKETFFT_NOINLINE void general_c2r( size_t len=out.shape(axis); threading::thread_map( util::thread_count(nthreads, in.shape(), axis, VLEN::val), - [&] { + [&] (size_t tid, size_t nthreads) { constexpr auto vlen = VLEN::val; auto storage = alloc_tmp(out.shape(), len, sizeof(T)); - multi_iter it(in, out, axis); + multi_iter it(in, out, axis, nthreads, tid); #ifndef POCKETFFT_NO_VECTORS if (vlen>1) while (it.remaining()>=vlen)