diff --git a/include/kernels/restrictor_mma.cuh b/include/kernels/restrictor_mma.cuh
index c49a3b0209..b8516c20b7 100644
--- a/include/kernels/restrictor_mma.cuh
+++ b/include/kernels/restrictor_mma.cuh
@@ -149,6 +149,7 @@ namespace quda
     constexpr int elements_per_thread = 16 / (sizeof(typename gmem_obj_t::store_t) * 2);
     static_assert(contiguous_dim % elements_per_thread == 0, "contiguous_dim %% elements_per_thread == 0");
     float block_rescale_factor = 1.0f;
+    float block_rescale_factor_inv = 1.0f;
 
     using store_t = typename gmem_obj_t::store_t;
 
@@ -186,6 +187,8 @@ namespace quda
 
       __syncthreads();
       block_rescale_factor = mma::numeric_limits::max() / block_max_all;
+      constexpr float inv_max_half = 1.0f / mma::numeric_limits::max();
+      block_rescale_factor_inv = block_max_all * inv_max_half;
     }
 
     auto write_to_smem = [&](int smem_m, int smem_k, complex a[elements_per_thread], float scale_inv_) {
@@ -236,7 +239,7 @@ namespace quda
 
     loop_over(
       gmem, x_coarse, coarse_spin, contiguous_dim_offset, aggregate_k_offset, coarse_to_fine, arg, write_to_smem);
-    return 1.0f / block_rescale_factor;
+    return block_rescale_factor_inv;
   }
 
   template
diff --git a/include/targets/cuda/mma_tensor_op/gmem_loader.cuh b/include/targets/cuda/mma_tensor_op/gmem_loader.cuh
index d070af66e3..3214858d10 100644
--- a/include/targets/cuda/mma_tensor_op/gmem_loader.cuh
+++ b/include/targets/cuda/mma_tensor_op/gmem_loader.cuh
@@ -59,6 +59,8 @@ namespace quda
 
   inline __device__ float abs_max(float a, float max) { return fmaxf(fabsf(a), max); }
 
+  inline __device__ float abs_max(short a, short max) { return ::max(::abs(a), max); }
+  inline __device__ float abs_max(float2 a, float max) { return fmaxf(fabsf(a.y), fmaxf(fabsf(a.x), max)); }
 
   template struct batch_load_t {
 
@@ -316,41 +318,31 @@ namespace quda
   template
   inline __device__ float find_abs_max(half2, complex *p, int m_idx, int n_idx, float scale_inv)
   {
-    float this_max = 0.0f;
+    T this_max = 0;
 
     if constexpr (x) {
       auto xx = p[(m_idx + 0) * ld + n_idx];
       auto yy = p[(m_idx + 1) * ld + n_idx];
 
-      if constexpr (fixed) {
-        this_max = abs_max(scale_inv * xx.real(), this_max);
-        this_max = abs_max(scale_inv * xx.imag(), this_max);
-        this_max = abs_max(scale_inv * yy.real(), this_max);
-        this_max = abs_max(scale_inv * yy.imag(), this_max);
-      } else {
-        this_max = abs_max(xx.real(), this_max);
-        this_max = abs_max(xx.imag(), this_max);
-        this_max = abs_max(yy.real(), this_max);
-        this_max = abs_max(yy.imag(), this_max);
-      }
+      this_max = abs_max(xx.real(), this_max);
+      this_max = abs_max(xx.imag(), this_max);
+      this_max = abs_max(yy.real(), this_max);
+      this_max = abs_max(yy.imag(), this_max);
     } else {
       complex v[2];
       batch_load_t, 2>::load(v, &p[n_idx * ld + m_idx]);
 
-      if constexpr (fixed) {
-        this_max = abs_max(scale_inv * v[0].real(), this_max);
-        this_max = abs_max(scale_inv * v[0].imag(), this_max);
-        this_max = abs_max(scale_inv * v[1].real(), this_max);
-        this_max = abs_max(scale_inv * v[1].imag(), this_max);
-      } else {
-        this_max = abs_max(v[0].real(), this_max);
-        this_max = abs_max(v[0].imag(), this_max);
-        this_max = abs_max(v[1].real(), this_max);
-        this_max = abs_max(v[1].imag(), this_max);
-      }
+      this_max = abs_max(v[0].real(), this_max);
+      this_max = abs_max(v[0].imag(), this_max);
+      this_max = abs_max(v[1].real(), this_max);
+      this_max = abs_max(v[1].imag(), this_max);
     }
 
-    return this_max;
+    if constexpr (fixed) {
+      return scale_inv * this_max;
+    } else {
+      return this_max;
+    }
   }
 
   /**
@@ -479,9 +471,13 @@ namespace quda
     constexpr int warp_cycle = (total_tiles + n_warp - 1) / n_warp;
 
     float block_rescale_factor = 1.0f;
+    float block_rescale_factor_inv = 1.0f;
     if constexpr (rescale) {
-      if constexpr (fixed) {
-        block_rescale_factor = scale_inv > 0 ? numeric_limits::max() / (scale_inv * fixedMaxValue::value) : 1.0f;
+      if (fixed && scale_inv > 0) {
+        float f = scale_inv * fixedMaxValue::value;
+        block_rescale_factor = numeric_limits::max() / f;
+        constexpr float inv_max_half = 1.0f / numeric_limits::max();
+        block_rescale_factor_inv = f * inv_max_half;
       } else {
         float thread_max = 0.0f;
 #pragma unroll
@@ -521,6 +517,8 @@ namespace quda
       __syncthreads();
 
       block_rescale_factor = numeric_limits::max() / block_max_all;
+      constexpr float inv_max_half = 1.0f / numeric_limits::max();
+      block_rescale_factor_inv = block_max_all * inv_max_half;
     }
   }
 
@@ -549,7 +547,7 @@ namespace quda
       }
     }
 
-    return 1.0f / block_rescale_factor;
+    return block_rescale_factor_inv;
   }
 
   template
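For reference, the pattern both files adopt is sketched below as a self-contained kernel: find the block-wide absolute maximum, derive the half-precision rescale factor, and keep its reciprocal so callers multiply by block_rescale_factor_inv instead of dividing by block_rescale_factor. This is a minimal illustration only, not QUDA code; the kernel name rescale_block, the fixed block_size, and the half_max constant are assumptions made for the example.

// Illustrative sketch only (not QUDA code): per-block rescaling into fp16 with a
// precomputed inverse factor, mirroring the block_rescale_factor_inv change above.
#include <cuda_fp16.h>

constexpr int block_size = 256;      // assumed launch configuration (blockDim.x == block_size)
constexpr float half_max = 65504.0f; // largest finite fp16 value

__global__ void rescale_block(const float *in, __half *out, float *scale_inv_out, int n)
{
  __shared__ float smem[block_size];

  // Each thread scans a strided slice of the input for its local absolute maximum.
  float thread_max = 0.0f;
  for (int i = threadIdx.x; i < n; i += block_size) thread_max = fmaxf(fabsf(in[i]), thread_max);
  smem[threadIdx.x] = thread_max;
  __syncthreads();

  // Shared-memory tree reduction to obtain the block-wide maximum.
  for (int s = block_size / 2; s > 0; s /= 2) {
    if (threadIdx.x < s) smem[threadIdx.x] = fmaxf(smem[threadIdx.x], smem[threadIdx.x + s]);
    __syncthreads();
  }
  float block_max_all = smem[0];

  // Compute the rescale factor and its inverse once; returning the precomputed
  // inverse replaces the former "return 1.0f / block_rescale_factor" division.
  float block_rescale_factor = block_max_all > 0.0f ? half_max / block_max_all : 1.0f;
  float block_rescale_factor_inv = block_max_all > 0.0f ? block_max_all * (1.0f / half_max) : 1.0f;

  // Store rescaled values in fp16; the caller multiplies results by the inverse
  // factor to undo the rescaling.
  for (int i = threadIdx.x; i < n; i += block_size) out[i] = __float2half(in[i] * block_rescale_factor);
  if (threadIdx.x == 0) *scale_inv_out = block_rescale_factor_inv;
}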