Skip to content

Commit

Permalink
Optimize the rescaling code:
Browse files Browse the repository at this point in the history
- Replace division with multiplications;
- Find maximum value before performing the multiplication.
  • Loading branch information
hummingtree committed Jan 3, 2025
1 parent ea867a3 commit 53c8e31
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 28 deletions.
5 changes: 4 additions & 1 deletion include/kernels/restrictor_mma.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ namespace quda
constexpr int elements_per_thread = 16 / (sizeof(typename gmem_obj_t::store_t) * 2);
static_assert(contiguous_dim % elements_per_thread == 0, "contiguous_dim %% elements_per_thread == 0");
float block_rescale_factor = 1.0f;
float block_rescale_factor_inv = 1.0f;

using store_t = typename gmem_obj_t::store_t;

Expand Down Expand Up @@ -186,6 +187,8 @@ namespace quda
__syncthreads();

block_rescale_factor = mma::numeric_limits<mma::half>::max() / block_max_all;
constexpr float inv_max_half = 1.0f / mma::numeric_limits<mma::half>::max();
block_rescale_factor_inv = block_max_all * inv_max_half;
}

auto write_to_smem = [&](int smem_m, int smem_k, complex<store_t> a[elements_per_thread], float scale_inv_) {
Expand Down Expand Up @@ -236,7 +239,7 @@ namespace quda
loop_over<contiguous_dim, contiguous_limit, elements_per_thread>(
gmem, x_coarse, coarse_spin, contiguous_dim_offset, aggregate_k_offset, coarse_to_fine, arg, write_to_smem);

return 1.0f / block_rescale_factor;
return block_rescale_factor_inv;
}

template <typename Arg>
Expand Down
52 changes: 25 additions & 27 deletions include/targets/cuda/mma_tensor_op/gmem_loader.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ namespace quda

inline __device__ float abs_max(float a, float max) { return fmaxf(fabsf(a), max); }

inline __device__ float abs_max(short a, short max) { return ::max(::abs(a), max); }

inline __device__ float abs_max(float2 a, float max) { return fmaxf(fabsf(a.y), fmaxf(fabsf(a.x), max)); }

template <class T, int batch> struct batch_load_t {
Expand Down Expand Up @@ -316,41 +318,31 @@ namespace quda
template <bool x, bool fixed, bool dagger, int ld, class T>
inline __device__ float find_abs_max(half2, complex<T> *p, int m_idx, int n_idx, float scale_inv)
{
float this_max = 0.0f;
T this_max = 0;

if constexpr (x) {
auto xx = p[(m_idx + 0) * ld + n_idx];
auto yy = p[(m_idx + 1) * ld + n_idx];

if constexpr (fixed) {
this_max = abs_max(scale_inv * xx.real(), this_max);
this_max = abs_max(scale_inv * xx.imag(), this_max);
this_max = abs_max(scale_inv * yy.real(), this_max);
this_max = abs_max(scale_inv * yy.imag(), this_max);
} else {
this_max = abs_max(xx.real(), this_max);
this_max = abs_max(xx.imag(), this_max);
this_max = abs_max(yy.real(), this_max);
this_max = abs_max(yy.imag(), this_max);
}
this_max = abs_max(xx.real(), this_max);
this_max = abs_max(xx.imag(), this_max);
this_max = abs_max(yy.real(), this_max);
this_max = abs_max(yy.imag(), this_max);
} else {
complex<T> v[2];
batch_load_t<complex<T>, 2>::load(v, &p[n_idx * ld + m_idx]);

if constexpr (fixed) {
this_max = abs_max(scale_inv * v[0].real(), this_max);
this_max = abs_max(scale_inv * v[0].imag(), this_max);
this_max = abs_max(scale_inv * v[1].real(), this_max);
this_max = abs_max(scale_inv * v[1].imag(), this_max);
} else {
this_max = abs_max(v[0].real(), this_max);
this_max = abs_max(v[0].imag(), this_max);
this_max = abs_max(v[1].real(), this_max);
this_max = abs_max(v[1].imag(), this_max);
}
this_max = abs_max(v[0].real(), this_max);
this_max = abs_max(v[0].imag(), this_max);
this_max = abs_max(v[1].real(), this_max);
this_max = abs_max(v[1].imag(), this_max);
}

return this_max;
if constexpr (fixed) {
return scale_inv * this_max;
} else {
return this_max;
}
}

/**
Expand Down Expand Up @@ -479,9 +471,13 @@ namespace quda
constexpr int warp_cycle = (total_tiles + n_warp - 1) / n_warp;

float block_rescale_factor = 1.0f;
float block_rescale_factor_inv = 1.0f;
if constexpr (rescale) {
if constexpr (fixed) {
block_rescale_factor = scale_inv > 0 ? numeric_limits<half>::max() / (scale_inv * fixedMaxValue<T>::value) : 1.0f;
if (fixed && scale_inv > 0) {
float f = scale_inv * fixedMaxValue<T>::value;
block_rescale_factor = numeric_limits<half>::max() / f;

This comment has been minimized.

Copy link
@maddyscientist

maddyscientist Jan 3, 2025

Member

I think 1/f division could yet be removed:

  • precompute 1.0 / fixedMaxValue<T>::value
  • and then multiply by scale

This comment has been minimized.

Copy link
@hummingtree

hummingtree Jan 6, 2025

Author Member

But then scale would need to be loaded - not sure if that would be a win overall or not.

This comment has been minimized.

Copy link
@hummingtree

hummingtree Jan 7, 2025

Author Member
constexpr float inv_max_half = 1.0f / numeric_limits<half>::max();
block_rescale_factor_inv = f * inv_max_half;
} else {
float thread_max = 0.0f;
#pragma unroll
Expand Down Expand Up @@ -521,6 +517,8 @@ namespace quda
__syncthreads();

block_rescale_factor = numeric_limits<half>::max() / block_max_all;
constexpr float inv_max_half = 1.0f / numeric_limits<half>::max();
block_rescale_factor_inv = block_max_all * inv_max_half;
}
}

Expand Down Expand Up @@ -549,7 +547,7 @@ namespace quda
}
}

return 1.0f / block_rescale_factor;
return block_rescale_factor_inv;
}

template <int ld, bool dagger, bool fixed, class T, class smem_accessor_t>
Expand Down

0 comments on commit 53c8e31

Please sign in to comment.