From 693e1244f9560f197add77fe2b8f81ef91bd3d6f Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Wed, 4 Oct 2023 13:20:05 +0100 Subject: [PATCH] Tuned gemm complex for cpu --- cmake/CmakeFunctionHelper.cmake | 7 ++----- src/interface/blas3/backend/default_cpu.hpp | 19 ++++--------------- 2 files changed, 6 insertions(+), 20 deletions(-) diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index 2af1db750..8c1303f2b 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -623,13 +623,10 @@ else() # default cpu backend foreach(data ${data_list_c}) add_gemm_configuration( "${data}" 64 "false" "false" "false" - 64 2 2 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "full" 1 "strided" "false" "false") + 64 2 2 4 4 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "full" 1 "strided" "false" "false") add_gemm_configuration( "${data}" 64 "false" "false" "false" - 64 8 8 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false" "false") - add_gemm_configuration( - "${data}" 64 "false" "false" "false" - 64 2 2 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false" "false") + 64 8 8 4 4 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false" "false") endforeach() endif() # BLAS_ENABLE_COMPLEX endif() diff --git a/src/interface/blas3/backend/default_cpu.hpp b/src/interface/blas3/backend/default_cpu.hpp index 1b7dfd680..14c0cd337 100644 --- a/src/interface/blas3/backend/default_cpu.hpp +++ b/src/interface/blas3/backend/default_cpu.hpp @@ -116,10 +116,10 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { - if (_M <= 128 && _N <= 128 && _K <= 128 && !s_a && !s_b) { + if (_M <= 256 && _N <= 256 && _K <= 256) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b, + 64, Tile<2, 2, 4, 4>, _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::no_local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, @@ -127,10 +127,10 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); - } else if (!s_a && !s_b) { + } else { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b, + 64, Tile<8, 8, 4, 4>, _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::no_local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::partial), is_beta_zero, 1, @@ -138,17 +138,6 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); - } else { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 1, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, - _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, - batch_size, _dependencies); } } #endif