Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Commit

Permalink
Removed symm kernels generation from complex data types
Browse files Browse the repository at this point in the history
  • Loading branch information
OuadiElfarouki committed Oct 23, 2023
1 parent 2dc363d commit 6a0e010
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 17 deletions.
7 changes: 5 additions & 2 deletions cmake/CmakeFunctionHelper.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,9 @@ function(add_gemm_configuration
cpp_type(cpp_data ${data})
foreach(symm_a ${boolean_list})
foreach(symm_b ${boolean_list})
if ((${data} MATCHES "complex") AND (symm_a OR symm_b))
continue()
endif()
foreach(trans_a ${boolean_list})
foreach(trans_b ${boolean_list})
foreach(is_beta_zero ${boolean_list})
Expand Down Expand Up @@ -591,8 +594,8 @@ elseif(${TUNING_TARGET} STREQUAL "NVIDIA_GPU")
set_complex_list(data_list_c "${supported_types}" "false")
foreach(data ${data_list_c})
add_gemm_configuration(
"${data}" 64 "false" "false" "true"
64 8 8 8 8 1 1 2 2 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
"${data}" 256 "false" "false" "true"
64 2 2 16 16 1 1 2 2 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
endforeach()
endif() # BLAS_ENABLE_COMPLEX
else() # default cpu backend
Expand Down
8 changes: 4 additions & 4 deletions src/interface/blas3/backend/amd_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,10 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
static constexpr int tileWgSize = ClSize / sizeof(element_t);
/* Tall & Skinny matrices. */
#ifdef GEMM_TALL_SKINNY_SUPPORT
if (batch_size == 1 && (_M / _N > 8 || _N / _M > 8) && (!s_a && !s_b)) {
if (batch_size == 1 && (_M / _N > 8 || _N / _M > 8)) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, true, true, true,
ClSize, Tile<1, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
ClSize, Tile<1, 4, tileWgSize, tileWgSize>, _t_a, _t_b, false, false,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::tall_skinny),
static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 1,
Expand All @@ -177,7 +177,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
if (_M * _N <= 65536) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false, false,
ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, false, false,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
Expand All @@ -188,7 +188,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
} else {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 256, false, false, false,
ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b,
ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, false, false,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
Expand Down
4 changes: 2 additions & 2 deletions src/interface/blas3/backend/default_cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
if (_M <= 256 && _N <= 256 && _K <= 256) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, false,
64, Tile<2, 2, 4, 4>, _t_a, _t_b, s_a, s_b,
64, Tile<2, 2, 4, 4>, _t_a, _t_b, false, false,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
Expand All @@ -130,7 +130,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
} else {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, false,
64, Tile<8, 8, 4, 4>, _t_a, _t_b, s_a, s_b,
64, Tile<8, 8, 4, 4>, _t_a, _t_b, false, false,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::partial), is_beta_zero, 1,
Expand Down
12 changes: 6 additions & 6 deletions src/interface/blas3/backend/intel_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,11 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
gemm_batch_type_t batch_type,
const typename sb_handle_t::event_t& _dependencies) {
#ifdef GEMM_TALL_SKINNY_SUPPORT
if (!s_a && !s_b && batch_size == 1) {
if (batch_size == 1) {
constexpr int wg_size = sizeof(element_t) == 16 ? 4 : 8;
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, true, true, true, 64,
Tile<4, 4, wg_size, wg_size>, _t_a, _t_b, s_a, s_b,
Tile<4, 4, wg_size, wg_size>, _t_a, _t_b, false, false,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::tall_skinny),
static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 1,
Expand All @@ -239,18 +239,18 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
if (_M <= 128 && _N <= 128) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, true, false, false, 64,
Tile<4, 4, 8, 8>, _t_a, _t_b, s_a, s_b,
Tile<4, 4, 8, 8>, _t_a, _t_b, false, false,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided)>::
template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea,
_b, _ldb, _strideb, _beta, _c, _ldc, _stridec,
batch_size, _dependencies);
} else if (_t_b && !_t_a && !s_a && !s_b) {
} else if (_t_b && !_t_a) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, false,
64, Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b,
64, Tile<8, 8, 8, 8>, _t_a, _t_b, false, false,
static_cast<int>(gemm_memory_t::no_local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::partial), is_beta_zero, 1,
Expand All @@ -261,7 +261,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
} else {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, false,
64, Tile<4, 8, 16, 8>, _t_a, _t_b, s_a, s_b,
64, Tile<4, 8, 16, 8>, _t_a, _t_b, false, false,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
Expand Down
6 changes: 3 additions & 3 deletions src/interface/blas3/backend/nvidia_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -183,9 +183,9 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
gemm_batch_type_t batch_type,
const typename sb_handle_t::event_t& _dependencies) {
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 64, false, false, true, 64,
Tile<8, 8, 8, 8, 1, 1, 2, 2, 1, 1, 1, 1, 1, float, float>, _t_a, _t_b,
s_a, s_b, static_cast<int>(gemm_memory_t::local),
container_0_t, container_1_t, container_2_t, 256, false, false, true, 64,
Tile<2, 2, 16, 16, 1, 1, 2, 2, 1, 1, 1, 1, 1, float, float>, _t_a, _t_b,
false, false, static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::standard),
static_cast<int>(gemm_vectorization_t::full), is_beta_zero, 1,
static_cast<int>(gemm_batch_type_t::strided),
Expand Down

0 comments on commit 6a0e010

Please sign in to comment.