Skip to content

Commit

Permalink
Avoid using thread_local to work around a MinGW bug
Browse files Browse the repository at this point in the history
  • Loading branch information
peterbell10 committed Dec 6, 2021
1 parent daa8bb1 commit b637fc6
Showing 1 changed file with 11 additions and 27 deletions.
38 changes: 11 additions & 27 deletions pocketfft_hdronly.h
Original file line number Diff line number Diff line change
Expand Up @@ -515,25 +515,12 @@ namespace threading {

#ifdef POCKETFFT_NO_MULTITHREADING

// Build without multithreading (POCKETFFT_NO_MULTITHREADING):
// the reported thread index is the compile-time constant 0.
inline constexpr size_t thread_id()
  {
  return 0;
  }
// Build without multithreading (POCKETFFT_NO_MULTITHREADING):
// the reported thread count is the compile-time constant 1.
inline constexpr size_t num_threads()
  {
  return 1;
  }

// Single-threaded variant of thread_map (POCKETFFT_NO_MULTITHREADING branch):
// the requested thread count is ignored (its parameter name is commented out)
// and f is invoked once, directly on the calling thread.
// NOTE(review): this text is a rendered diff without +/- markers, so both the
// pre-commit body `{ f(); }` and the post-commit body `{ f(0, 1); }` appear
// below. Presumably only the second one (thread index 0 out of 1 thread) is
// present in the header after the commit -- confirm against the actual file.
template <typename Func>
void thread_map(size_t /* nthreads */, Func f)
{ f(); }
{ f(0, 1); }

#else

// Gives read/write access to the calling thread's index slot.
// Every thread has its own slot, which starts out as 0; callers assign
// through the returned reference.
inline size_t &thread_id()
  {
  // thread_local at block scope implies static storage semantics per thread.
  thread_local size_t my_index = 0;
  return my_index;
  }
// Gives read/write access to the calling thread's "number of threads" slot.
// Every thread has its own slot, which starts out as 1; callers assign
// through the returned reference.
inline size_t &num_threads()
  {
  // thread_local at block scope implies static storage semantics per thread.
  thread_local size_t my_count = 1;
  return my_count;
  }
// Hardware concurrency as reported by the standard library, clamped to a
// minimum of 1: hardware_concurrency() is only a hint and may return 0 when
// the value cannot be determined.
static const size_t max_threads =
  std::max(std::thread::hardware_concurrency(), 1u);

class latch
Expand Down Expand Up @@ -786,7 +773,7 @@ void thread_map(size_t nthreads, Func f)
nthreads = max_threads;

if (nthreads == 1)
{ f(); return; }
{ f(0, 1); return; }

auto & pool = get_pool();
latch counter(nthreads);
Expand All @@ -796,9 +783,7 @@ void thread_map(size_t nthreads, Func f)
{
pool.submit(
[&f, &counter, &ex, &ex_mut, i, nthreads] {
thread_id() = i;
num_threads() = nthreads;
try { f(); }
try { f(i, nthreads); }
catch (...)
{
std::lock_guard<std::mutex> lock(ex_mut);
Expand Down Expand Up @@ -2881,15 +2866,14 @@ template<size_t N> class multi_iter
}

public:
multi_iter(const arr_info &iarr_, const arr_info &oarr_, size_t idim_)
multi_iter(const arr_info &iarr_, const arr_info &oarr_, size_t idim_,
size_t nshares, size_t myshare)
: pos(iarr_.ndim(), 0), iarr(iarr_), oarr(oarr_), p_ii(0),
str_i(iarr.stride(idim_)), p_oi(0), str_o(oarr.stride(idim_)),
idim(idim_), rem(iarr.size()/iarr.shape(idim))
{
auto nshares = threading::num_threads();
if (nshares==1) return;
if (nshares==0) throw std::runtime_error("can't run with zero threads");
auto myshare = threading::thread_id();
if (myshare>=nshares) throw std::runtime_error("impossible share requested");
size_t nbase = rem/nshares;
size_t additional = rem%nshares;
Expand Down Expand Up @@ -3134,11 +3118,11 @@ POCKETFFT_NOINLINE void general_nd(const cndarr<T> &in, ndarr<T> &out,

threading::thread_map(
util::thread_count(nthreads, in.shape(), axes[iax], VLEN<T>::val),
[&] {
[&] (size_t tid, size_t nthreads) {
constexpr auto vlen = VLEN<T0>::val;
auto storage = alloc_tmp<T0>(in.shape(), len, sizeof(T));
const auto &tin(iax==0? in : out);
multi_iter<vlen> it(tin, out, axes[iax]);
multi_iter<vlen> it(tin, out, axes[iax], nthreads, tid);
#ifndef POCKETFFT_NO_VECTORS
if (vlen>1)
while (it.remaining()>=vlen)
Expand Down Expand Up @@ -3241,10 +3225,10 @@ template<typename T> POCKETFFT_NOINLINE void general_r2c(
size_t len=in.shape(axis);
threading::thread_map(
util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val),
[&] {
[&] (size_t tid, size_t nthreads) {
constexpr auto vlen = VLEN<T>::val;
auto storage = alloc_tmp<T>(in.shape(), len, sizeof(T));
multi_iter<vlen> it(in, out, axis);
multi_iter<vlen> it(in, out, axis, nthreads, tid);
#ifndef POCKETFFT_NO_VECTORS
if (vlen>1)
while (it.remaining()>=vlen)
Expand Down Expand Up @@ -3296,10 +3280,10 @@ template<typename T> POCKETFFT_NOINLINE void general_c2r(
size_t len=out.shape(axis);
threading::thread_map(
util::thread_count(nthreads, in.shape(), axis, VLEN<T>::val),
[&] {
[&] (size_t tid, size_t nthreads) {
constexpr auto vlen = VLEN<T>::val;
auto storage = alloc_tmp<T>(out.shape(), len, sizeof(T));
multi_iter<vlen> it(in, out, axis);
multi_iter<vlen> it(in, out, axis, nthreads, tid);
#ifndef POCKETFFT_NO_VECTORS
if (vlen>1)
while (it.remaining()>=vlen)
Expand Down

0 comments on commit b637fc6

Please sign in to comment.