Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Update configuration for gemm on AMD GPUs #494

Merged
merged 4 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions cmake/CmakeFunctionHelper.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -495,13 +495,12 @@ elseif(${TUNING_TARGET} STREQUAL "AMD_GPU") # need investigation
set(twr "${workgroup_${data}}")
set(twc "${workgroup_${data}}")

add_gemm_configuration(
"${data}" 256 "false" "false" "false"
64 1 1 ${twr} ${twc} 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
# General configuration
add_gemm_configuration(
"${data}" 256 "false" "false" "false"
64 4 4 ${twr} ${twc} 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 2 "strided" "false")

# configuration for tall_skinny
add_gemm_configuration(
"${data}" 256 "true" "true" "true"
64 1 1 ${twr} ${twc} 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 2 "strided" "false")
Expand All @@ -518,9 +517,37 @@ elseif(${TUNING_TARGET} STREQUAL "AMD_GPU") # need investigation
"${data}" 256 "true" "true" "true"
64 4 1 ${twr} ${twc} 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 2 "strided" "false")

# configuration for batch
add_gemm_configuration(
"${data}" 64 "false" "false" "false"
64 4 4 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false")

# Configurations for gemm

# low arithmetic intensity
add_gemm_configuration(
"${data}" 256 "false" "false" "true"
128 1 1 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
add_gemm_configuration(
"${data}" 256 "false" "false" "true"
64 4 8 16 16 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
# highest arithmetic intensity
add_gemm_configuration(
"${data}" 256 "false" "false" "true"
32 8 8 16 16 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
# high arithmetic intensity
add_gemm_configuration(
"${data}" 256 "false" "false" "true"
64 4 4 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
# mid high 162 < a < 240
add_gemm_configuration(
"${data}" 256 "false" "false" "true"
128 4 4 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
# mid low 100 < a < 162
add_gemm_configuration(
"${data}" 256 "false" "true" "true"
128 2 2 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")

endforeach()
if(BLAS_ENABLE_COMPLEX)
# Extract list of complex<data> for each data in supported_types
Expand Down
86 changes: 83 additions & 3 deletions src/interface/blas3/backend/amd_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
if constexpr (s_a && s_b || ((s_a && _t_b) || (s_b && _t_a))) {
return _dependencies;
} else {
// Compute the arithmetic ratio from the input sizes and use it as a
// heuristic: the numerator is the number of FMAs, the denominator is the
// number of elements accessed.
const auto n_fma = (static_cast<int64_t>(_M) * static_cast<int64_t>(_K) *
static_cast<int64_t>(_N));
const auto n_elem_access = (_M * _K + _K * _N + _M * _N);
const auto arith_ratio = n_fma / n_elem_access;
static constexpr int ClSize = 64;
static constexpr int tileWgSize = ClSize / sizeof(element_t);
if (batch_type == gemm_batch_type_t::interleaved) {
Expand All @@ -59,6 +66,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
}

/* Tall & Skinny matrices. */
#ifdef GEMM_TALL_SKINNY_SUPPORT
if (batch_size == 1 &&
Expand Down Expand Up @@ -123,18 +131,90 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
}
} else
#endif // GEMM_TALL_SKINNY_SUPPORT
if (_M * _N <= 65536) {
// The following configurations were obtained with the auto tuner on
// amd-mi210 and are grouped by their arith_ratio or by the ratio between
// the _N and _K input sizes.
if ((_N >> 4) > _K) {
if (arith_ratio <= 100) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false,
true, 64, Tile<4, 8, 16, 16>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c,
_ldc, _stridec, batch_size, _dependencies);
} else {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false,
true, 32, Tile<8, 8, 16, 16>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c,
_ldc, _stridec, batch_size, _dependencies);
}
} else if (arith_ratio >= 360) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false,
false, ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a,
s_b, static_cast<int>(gemm_memory_t::local),
true, 32, Tile<8, 8, 16, 16>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else if (arith_ratio >= 240) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false,
true, 64, Tile<4, 4, 16, 8>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else if (arith_ratio > 162) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false,
true, 128, Tile<4, 4, 16, 8>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else if (arith_ratio >= 100 && arith_ratio <= 162) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, true, true,
128, Tile<2, 2, 16, 8>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else if (arith_ratio <= 100) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false,
true, 128, Tile<1, 1, 16, 8>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda,
_stridea, _b, _ldb, _strideb, _beta, _c, _ldc,
_stridec, batch_size, _dependencies);
} else {
// This branch is a safety net in case no other branch is taken.
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false,
false, ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a,
Expand Down
Loading