diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index ef7fc22d6..2ae71bc5e 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -291,6 +291,9 @@ function(add_gemm_configuration cpp_type(cpp_data ${data}) foreach(symm_a ${boolean_list}) foreach(symm_b ${boolean_list}) + if ((${data} MATCHES "complex") AND (symm_a OR symm_b)) + continue() + endif() foreach(trans_a ${boolean_list}) foreach(trans_b ${boolean_list}) foreach(is_beta_zero ${boolean_list}) @@ -591,8 +594,8 @@ elseif(${TUNING_TARGET} STREQUAL "NVIDIA_GPU") set_complex_list(data_list_c "${supported_types}" "false") foreach(data ${data_list_c}) add_gemm_configuration( - "${data}" 64 "false" "false" "true" - 64 8 8 8 8 1 1 2 2 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + "${data}" 256 "false" "false" "true" + 64 2 2 16 16 1 1 2 2 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") endforeach() endif() # BLAS_ENABLE_COMPLEX else() # default cpu backend diff --git a/src/interface/blas3/backend/amd_gpu.hpp b/src/interface/blas3/backend/amd_gpu.hpp index a425b2f2a..f494f25b9 100644 --- a/src/interface/blas3/backend/amd_gpu.hpp +++ b/src/interface/blas3/backend/amd_gpu.hpp @@ -161,10 +161,10 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, static constexpr int tileWgSize = ClSize / sizeof(element_t); /* Tall & Skinny matrices. */ #ifdef GEMM_TALL_SKINNY_SUPPORT - if (batch_size == 1 && (_M / _N > 8 || _N / _M > 8) && (!s_a && !s_b)) { + if (batch_size == 1 && (_M / _N > 8 || _N / _M > 8)) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, true, true, true, - ClSize, Tile<1, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, + ClSize, Tile<1, 4, tileWgSize, tileWgSize>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::tall_skinny), static_cast(gemm_vectorization_t::none), is_beta_zero, 1, @@ -177,7 +177,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, if (_M * _N <= 65536) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, false, false, - ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, + ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, @@ -188,7 +188,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, } else { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, false, false, - ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, + ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, diff --git a/src/interface/blas3/backend/default_cpu.hpp b/src/interface/blas3/backend/default_cpu.hpp index 14c0cd337..e62348363 100644 --- a/src/interface/blas3/backend/default_cpu.hpp +++ b/src/interface/blas3/backend/default_cpu.hpp @@ -119,7 +119,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, if (_M <= 256 && _N <= 256 && _K <= 256) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<2, 2, 4, 4>, _t_a, _t_b, s_a, s_b, + 64, Tile<2, 2, 4, 4>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::no_local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, @@ -130,7 +130,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, } else { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<8, 8, 4, 4>, _t_a, _t_b, s_a, s_b, + 64, Tile<8, 8, 4, 4>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::no_local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::partial), is_beta_zero, 1, diff --git a/src/interface/blas3/backend/intel_gpu.hpp b/src/interface/blas3/backend/intel_gpu.hpp index a0ce6f52a..8d788c9b5 100644 --- a/src/interface/blas3/backend/intel_gpu.hpp +++ b/src/interface/blas3/backend/intel_gpu.hpp @@ -222,11 +222,11 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { #ifdef GEMM_TALL_SKINNY_SUPPORT - if (!s_a && !s_b && batch_size == 1) { + if (batch_size == 1) { constexpr int wg_size = sizeof(element_t) == 16 ? 4 : 8; return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, true, true, true, 64, - Tile<4, 4, wg_size, wg_size>, _t_a, _t_b, s_a, s_b, + Tile<4, 4, wg_size, wg_size>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::tall_skinny), static_cast(gemm_vectorization_t::none), is_beta_zero, 1, @@ -239,7 +239,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, if (_M <= 128 && _N <= 128) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, true, false, false, 64, - Tile<4, 4, 8, 8>, _t_a, _t_b, s_a, s_b, + Tile<4, 4, 8, 8>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, @@ -247,10 +247,10 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); - } else if (_t_b && !_t_a && !s_a && !s_b) { + } else if (_t_b && !_t_a) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b, + 64, Tile<8, 8, 8, 8>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::no_local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::partial), is_beta_zero, 1, @@ -261,7 +261,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, } else { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<4, 8, 16, 8>, _t_a, _t_b, s_a, s_b, + 64, Tile<4, 8, 16, 8>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, diff --git a/src/interface/blas3/backend/nvidia_gpu.hpp b/src/interface/blas3/backend/nvidia_gpu.hpp index 7d555d902..13966172e 100644 --- a/src/interface/blas3/backend/nvidia_gpu.hpp +++ b/src/interface/blas3/backend/nvidia_gpu.hpp @@ -183,9 +183,9 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, false, false, true, 64, - Tile<8, 8, 8, 8, 1, 1, 2, 2, 1, 1, 1, 1, 1, float, float>, _t_a, _t_b, - s_a, s_b, static_cast(gemm_memory_t::local), + container_0_t, container_1_t, container_2_t, 256, false, false, true, 64, + Tile<2, 2, 16, 16, 1, 1, 2, 2, 1, 1, 1, 1, 1, float, float>, _t_a, _t_b, + false, false, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, static_cast(gemm_batch_type_t::strided),