From ac26cfe2e2d2a883a926b64c2c6f7dbe889a0a52 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Mon, 11 Sep 2023 13:16:40 +0100 Subject: [PATCH 01/16] Added complex type support to gemm kernels --- CMakeLists.txt | 1 + cmake/CmakeFunctionHelper.cmake | 119 ++++++++++++++- common/include/common/common_utils.hpp | 31 ++++ include/blas_meta.h | 28 ++++ include/operations/blas_constants.h | 10 ++ src/interface/blas3/backend/amd_gpu.hpp | 75 ++++++++- src/interface/blas3/backend/default_cpu.hpp | 70 ++++++++- src/interface/blas3/backend/intel_gpu.hpp | 84 +++++++++- src/interface/blas3/backend/nvidia_gpu.hpp | 46 +++++- src/interface/gemm_interface.hpp | 29 ++-- src/operations/blas1_trees.hpp | 11 ++ src/operations/blas3/gemm_common.hpp | 22 +++ src/operations/blas3/gemm_load_store.hpp | 144 ++++++++++++++++++ src/operations/blas3/gemm_local.hpp | 2 +- .../blas3/gemm_no_local_full_vec.hpp | 45 ++++-- .../blas3/gemm_no_local_partial_vec.hpp | 25 +-- src/operations/blas3/gemm_partial_local.hpp | 4 +- 17 files changed, 677 insertions(+), 69 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a6b85f570..1037b1098 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -220,6 +220,7 @@ option(BUILD_CUBLAS_BENCHMARKS "Whether to build cuBLAS benchmarks" OFF) option(BUILD_ROCBLAS_BENCHMARKS "Whether to build rocBLAS benchmarks" OFF) option(BUILD_ACL_BENCHMARKS "Whether to build ARM Compute Library benchmarks" OFF) option(BLAS_BUILD_SAMPLES "Whether to build portBLAS samples" ON) +option(BLAS_ENABLE_COMPLEX "Whether to enable complex data type for supported operators" ON) if (INSTALL_HEADER_ONLY AND BLAS_ENABLE_BENCHMARK) message(STATUS "Benchmarks are disabled when installing portBLAS in header only mode") set(BLAS_ENABLE_BENCHMARK OFF) diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index 1be6ab8a8..8b411d5e6 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -36,10 +36,30 @@ function(cpp_type output data) if (${data} STREQUAL "half") set(${output} "cl::sycl::half" PARENT_SCOPE) return() + elseif(${data} STREQUAL "complex") + set(${output} "cl::sycl::ext::oneapi::experimental::complex" PARENT_SCOPE) + return() + elseif(${data} STREQUAL "complex") + set(${output} "cl::sycl::ext::oneapi::experimental::complex" PARENT_SCOPE) + return() endif() set(${output} "${data}" PARENT_SCOPE) endfunction() +function(set_complex_list output input append) + set(output_temp "") + if(${append} STREQUAL "true") + foreach(data ${input}) + list(APPEND output_temp "${data};complex<${data}>") + endforeach(data) + else() + foreach(data ${input}) + list(APPEND output_temp "complex<${data}>") + endforeach(data) + endif() + set(${output} ${output_temp} PARENT_SCOPE) +endfunction(set_complex_list) + ## represent the list of bolean options set(boolean_list "true" "false") @@ -56,6 +76,9 @@ function(sanitize_file_name output file_name) set(${output} "${file_name}" PARENT_SCOPE) endfunction() +#List of operators supporting Complex Data types +set(COMPLEX_OPS "gemm" "gemm_launcher" "scal") + function(set_target_compile_def in_target) #setting compiler flag for backend if(${TUNING_TARGET} STREQUAL "INTEL_GPU") @@ -84,16 +107,31 @@ function(set_target_compile_def in_target) message(STATUS "Gemm vectorization support enabled for target ${in_target}") target_compile_definitions(${in_target} PUBLIC GEMM_VECTORIZATION_SUPPORT=1) endif() - + #setting const data type support if(BLAS_ENABLE_CONST_INPUT) target_compile_definitions(${in_target} PUBLIC BLAS_ENABLE_CONST_INPUT=1) endif() + #setting complex support + if(${BLAS_ENABLE_COMPLEX}) + if("${in_target}" IN_LIST COMPLEX_OPS) + message(STATUS "Complex Data type support enabled for target ${in_target}") + target_compile_definitions(${in_target} PUBLIC BLAS_ENABLE_COMPLEX=1) + endif() + endif() endfunction() # blas unary function for generating source code function(generate_blas_objects blas_level func) set(LOCATION "${PORTBLAS_GENERATED_SRC}/${blas_level}/${func}/") - foreach(data ${data_list}) + set(data_list_c ${data_list}) + # Extend data_list to complex for each data in list + # if target function is in COMPLEX_OPS + if(BLAS_ENABLE_COMPLEX) + if("${func}" IN_LIST COMPLEX_OPS) + set_complex_list(data_list_c "${data_list}" "true") + endif() + endif() + foreach(data ${data_list_c}) cpp_type(cpp_data ${data}) foreach(index ${index_list}) foreach(increment ${index_list}) @@ -234,7 +272,11 @@ function(add_gemm_configuration batch_type use_joint_matrix ) - if(NOT ("${data}" IN_LIST data_list)) + set(data_list_c ${data_list}) + if(BLAS_ENABLE_COMPLEX) + set_complex_list(data_list_c "${data_list}" "true") + endif() + if(NOT ("${data}" IN_LIST data_list_c)) # Data type not enabled, skip configuration return() endif() @@ -380,12 +422,36 @@ if(${TUNING_TARGET} STREQUAL "INTEL_GPU") "${data}" 64 "false" "false" "false" 64 4 4 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false") endforeach() + if(BLAS_ENABLE_COMPLEX) + # Extract list of complex for each data in supported_types + # list for complex specific gemm configurations + set(data_list_c) + set_complex_list(data_list_c "${supported_types}" "false") + foreach(data ${data_list_c}) + add_gemm_configuration( + "${data}" 64 "true" "false" "false" + 64 4 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + add_gemm_configuration( + "${data}" 64 "false" "false" "false" + 64 4 8 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + add_gemm_configuration( + "${data}" 64 "false" "false" "false" + 64 8 8 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false") + add_gemm_configuration( + "${data}" 32 "true" "true" "true" + 64 2 1 8 4 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false") + endforeach() + endif() # BLAS_ENABLE_COMPLEX elseif(${TUNING_TARGET} STREQUAL "POWER_VR" AND NOT IMGDNN_DIR) set(supported_types "float" "half" ) - foreach(data ${supported_types}) + set(data_list_c ${supported_types}) + if(BLAS_ENABLE_COMPLEX) + set_complex_list(data_list_c "${supported_types}" "false") + endif() + foreach(data ${data_list_c}) add_gemm_configuration( "${data}" 96 "true" "false" "false" 16 4 6 12 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") @@ -445,6 +511,23 @@ elseif(${TUNING_TARGET} STREQUAL "AMD_GPU") # need investigation "${data}" 64 "false" "false" "false" 64 4 4 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false") endforeach() + if(BLAS_ENABLE_COMPLEX) + # Extract list of complex for each data in supported_types + # list for complex specific gemm configurations + set(data_list_c) + set_complex_list(data_list_c "${supported_types}" "false") + foreach(data ${data_list_c}) + add_gemm_configuration( + "${data}" 256 "true" "true" "true" + 64 1 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 2 "strided" "false") + add_gemm_configuration( + "${data}" 256 "false" "false" "false" + 64 1 1 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + add_gemm_configuration( + "${data}" 256 "false" "false" "false" + 64 4 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + endforeach() + endif() # BLAS_ENABLE_COMPLEX elseif(${TUNING_TARGET} STREQUAL "NVIDIA_GPU") set(supported_types "float" @@ -486,7 +569,18 @@ elseif(${TUNING_TARGET} STREQUAL "NVIDIA_GPU") add_gemm_configuration( "${data}" 256 "false" "true" "true" 128 8 8 16 16 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + endforeach() + if(BLAS_ENABLE_COMPLEX) + # Extract list of complex for each data in supported_types + # list for complex specific gemm configurations + set(data_list_c) + set_complex_list(data_list_c "${supported_types}" "false") + foreach(data ${data_list_c}) + add_gemm_configuration( + "${data}" 64 "false" "false" "true" + 64 8 8 8 8 1 1 2 2 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") endforeach() + endif() # BLAS_ENABLE_COMPLEX else() # default cpu backend set(supported_types "float" @@ -513,6 +607,23 @@ else() # default cpu backend "${data}" 64 "false" "false" "false" 64 2 2 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false" "false") endforeach() + if(BLAS_ENABLE_COMPLEX) + # Extract list of complex for each data in supported_types + # list for complex specific gemm configurations + set(data_list_c) + set_complex_list(data_list_c "${supported_types}" "false") + foreach(data ${data_list_c}) + add_gemm_configuration( + "${data}" 64 "false" "false" "false" + 64 2 2 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "full" 1 "strided" "false" "false") + add_gemm_configuration( + "${data}" 64 "false" "false" "false" + 64 8 8 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false" "false") + add_gemm_configuration( + "${data}" 64 "false" "false" "false" + 64 2 2 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false" "false") + endforeach() + endif() # BLAS_ENABLE_COMPLEX endif() add_library(${func} OBJECT ${gemm_sources}) set_target_compile_def(${func}) diff --git a/common/include/common/common_utils.hpp b/common/include/common/common_utils.hpp index a569ed2ff..df14ac062 100644 --- a/common/include/common/common_utils.hpp +++ b/common/include/common/common_utils.hpp @@ -1372,6 +1372,37 @@ static inline std::vector random_data(size_t size) { return v; } +#ifdef BLAS_ENABLE_COMPLEX +template +static inline complex_std random_scalar() { + scalar_t rl = 1e-3 * ((rand() % 2000) - 1000); + scalar_t im = 1e-3 * ((rand() % 2000) - 1000); + return complex_std({rl, im}); +} + +template +static inline complex_std random_scalar(scalar_t rangeMin, + scalar_t rangeMax) { + static std::random_device rd; + static std::default_random_engine gen(rd()); + std::uniform_real_distribution disRl(rangeMin, rangeMax); + std::uniform_real_distribution disIm(rangeMin, rangeMax); + + return complex_std({disRl(gen), disIm(gen)}); +} + +template +static inline std::vector> random_data(size_t size) { + std::vector> v = + std::vector>(size); + + for (scalar_t& e : v) { + e = random_scalar(scalar_t{-2}, scalar_t{5}); + } + return v; +} +#endif + /** * @breif Fills a lower or upper triangular matrix suitable for TRSM testing * @param A The matrix to fill. Size must be at least m * lda diff --git a/include/blas_meta.h b/include/blas_meta.h index 6bad4be98..a7634dbca 100644 --- a/include/blas_meta.h +++ b/include/blas_meta.h @@ -29,6 +29,11 @@ #include #include #include +#ifdef BLAS_ENABLE_COMPLEX +#define SYCL_EXT_ONEAPI_COMPLEX +#include +#include +#endif namespace blas { @@ -190,6 +195,29 @@ struct is_sycl_scalar : std::false_type {}; template <> struct is_sycl_scalar : std::false_type {}; +#ifdef BLAS_ENABLE_COMPLEX +// SYCL Complex type alias +template +using complex_sycl = typename cl::sycl::ext::oneapi::experimental::complex; + +template +struct is_complex_sycl + : std::integral_constant> || + std::is_same_v>> {}; + +// STD Complex type alias +template +using complex_std = typename std::complex; + +template +struct is_complex_std + : std::integral_constant> || + std::is_same_v>> {}; + +#endif + } // namespace blas #endif // BLAS_META_H diff --git a/include/operations/blas_constants.h b/include/operations/blas_constants.h index 103c78152..5fc4afb82 100644 --- a/include/operations/blas_constants.h +++ b/include/operations/blas_constants.h @@ -210,6 +210,16 @@ struct constant, Indicator> { } }; +#ifdef BLAS_ENABLE_COMPLEX +template +struct constant, Indicator> { + constexpr static PORTBLAS_INLINE complex_sycl value() { + return complex_sycl(constant::value(), + constant::value()); + } +}; +#endif + #ifdef BLAS_DATA_TYPE_HALF template <> struct constant diff --git a/src/interface/blas3/backend/amd_gpu.hpp b/src/interface/blas3/backend/amd_gpu.hpp index be864ae76..3aff8dd46 100644 --- a/src/interface/blas3/backend/amd_gpu.hpp +++ b/src/interface/blas3/backend/amd_gpu.hpp @@ -33,13 +33,18 @@ namespace backend { template -typename sb_handle_t::event_t _gemm( - sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, - element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, - container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, - container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, - gemm_batch_type_t batch_type, - const typename sb_handle_t::event_t& _dependencies) { +#ifdef BLAS_ENABLE_COMPLEX +typename std::enable_if::value, + typename sb_handle_t::event_t>::type +#else +typename sb_handle_t::event_t +#endif +_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, + container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { static constexpr int ClSize = 64; static constexpr int tileWgSize = ClSize / sizeof(element_t); if (batch_type == gemm_batch_type_t::interleaved) { @@ -142,6 +147,62 @@ typename sb_handle_t::event_t _gemm( batch_size, _dependencies); } } + +// Complex Configurations +#ifdef BLAS_ENABLE_COMPLEX +template +typename std::enable_if::value, + typename sb_handle_t::event_t>::type +_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, + container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { + static constexpr int ClSize = 64; +/* Tall & Skinny matrices. */ +#ifdef GEMM_TALL_SKINNY_SUPPORT + if (batch_size == 1 && (_M / _N > 8 || _N / _M > 8) && (!s_a && !s_b)) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, true, true, true, + ClSize, Tile<1, 4, 8, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::tall_skinny), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, + _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, + batch_size, _dependencies); + } +#endif + if (_M * _N <= 65536) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, false, false, + ClSize, Tile<1, 1, 8, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, + _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, + batch_size, _dependencies); + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, false, false, + ClSize, Tile<4, 4, 8, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, + _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, + batch_size, _dependencies); + } +} +#endif + } // namespace backend } // namespace gemm } // namespace blas diff --git a/src/interface/blas3/backend/default_cpu.hpp b/src/interface/blas3/backend/default_cpu.hpp index 17868991e..44a99d1fe 100644 --- a/src/interface/blas3/backend/default_cpu.hpp +++ b/src/interface/blas3/backend/default_cpu.hpp @@ -33,13 +33,18 @@ namespace backend { template -typename sb_handle_t::event_t _gemm( - sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, - element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, - container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, - container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, - gemm_batch_type_t batch_type, - const typename sb_handle_t::event_t& _dependencies) { +#ifdef BLAS_ENABLE_COMPLEX +typename std::enable_if::value, + typename sb_handle_t::event_t>::type +#else +typename sb_handle_t::event_t +#endif +_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, + container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { if (batch_type == gemm_batch_type_t::interleaved) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, @@ -101,6 +106,57 @@ typename sb_handle_t::event_t _gemm( #endif } + +// Complex Configurations +#ifdef BLAS_ENABLE_COMPLEX +template +typename std::enable_if::value, + typename sb_handle_t::event_t>::type +_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, + container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { + if (_M <= 128 && _N <= 128 && _K <= 128 && !s_a && !s_b) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, false, false, false, + 64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::no_local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, + _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, + batch_size, _dependencies); + } else if (!s_a && !s_b) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, false, false, false, + 64, Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::no_local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::partial), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, + _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, + batch_size, _dependencies); + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, false, false, false, + 64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, + _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, + batch_size, _dependencies); + } +} +#endif + } // namespace backend } // namespace gemm } // namespace blas diff --git a/src/interface/blas3/backend/intel_gpu.hpp b/src/interface/blas3/backend/intel_gpu.hpp index 8fcb3e3a8..e22274008 100644 --- a/src/interface/blas3/backend/intel_gpu.hpp +++ b/src/interface/blas3/backend/intel_gpu.hpp @@ -32,13 +32,18 @@ namespace backend { template -typename sb_handle_t::event_t _gemm( - sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, - element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, - container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, - container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, - gemm_batch_type_t batch_type, - const typename sb_handle_t::event_t& _dependencies) { +#ifdef BLAS_ENABLE_COMPLEX +typename std::enable_if::value, + typename sb_handle_t::event_t>::type +#else +typename sb_handle_t::event_t +#endif +_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, + container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { if (batch_type == gemm_batch_type_t::interleaved) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, @@ -206,6 +211,71 @@ typename sb_handle_t::event_t _gemm( batch_size, _dependencies); } } + +// Complex Configurations +#ifdef BLAS_ENABLE_COMPLEX +template +typename std::enable_if::value, + typename sb_handle_t::event_t>::type +_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, + container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { +#ifdef GEMM_TALL_SKINNY_SUPPORT + if (!s_a && !s_b && batch_size == 1) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 32, true, true, true, 64, + Tile<2, 1, 8, 4>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::tall_skinny), + static_cast(gemm_vectorization_t::none), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, + _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, + batch_size, _dependencies); + } +#endif + if (_M <= 128 && _N <= 128) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, true, false, false, 64, + Tile<4, 4, 8, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, + _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, + batch_size, _dependencies); + } else if (_t_b && !_t_a && !s_a && !s_b) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, false, false, false, + 64, Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::no_local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::partial), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, + _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, + batch_size, _dependencies); + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, false, false, false, + 64, Tile<4, 8, 16, 8>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, + _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, + batch_size, _dependencies); + } +} +#endif + } // namespace backend } // namespace gemm } // namespace blas diff --git a/src/interface/blas3/backend/nvidia_gpu.hpp b/src/interface/blas3/backend/nvidia_gpu.hpp index aeb678704..f13a95d2e 100644 --- a/src/interface/blas3/backend/nvidia_gpu.hpp +++ b/src/interface/blas3/backend/nvidia_gpu.hpp @@ -33,13 +33,18 @@ namespace backend { template -typename sb_handle_t::event_t _gemm( - sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, - element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, - container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, - container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, - gemm_batch_type_t batch_type, - const typename sb_handle_t::event_t& _dependencies) { +#ifdef BLAS_ENABLE_COMPLEX +typename std::enable_if::value, + typename sb_handle_t::event_t>::type +#else +typename sb_handle_t::event_t +#endif +_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, + container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { if (batch_type == gemm_batch_type_t::interleaved) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, @@ -167,6 +172,33 @@ typename sb_handle_t::event_t _gemm( _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); } + +// Complex Configurations +#ifdef BLAS_ENABLE_COMPLEX +template +typename std::enable_if::value, + typename sb_handle_t::event_t>::type +_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, + element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, + container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, + container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, + gemm_batch_type_t batch_type, + const typename sb_handle_t::event_t& _dependencies) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 64, false, false, true, 64, + Tile<8, 8, 8, 8, 1, 1, 2, 2, 1, 1, 1, 1, 1, float, float>, _t_a, _t_b, + s_a, s_b, static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided), + false>::template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, + _ldc, _stridec, batch_size, _dependencies); +} +#endif + } // namespace backend } // namespace gemm } // namespace blas diff --git a/src/interface/gemm_interface.hpp b/src/interface/gemm_interface.hpp index a5c2c7bb3..f5b7383e6 100644 --- a/src/interface/gemm_interface.hpp +++ b/src/interface/gemm_interface.hpp @@ -48,6 +48,18 @@ namespace blas { */ namespace internal { +// Check whether value is zero (complex & float/double) +template +inline bool isZero(const T& value) { +#ifdef BLAS_ENABLE_COMPLEX + if constexpr (is_complex_sycl::value) { + using value_t = typename T::value_type; + return (value == T(value_t(0), value_t(0))); + } +#endif + return (value == static_cast(0)); +} + template @@ -73,15 +85,14 @@ typename sb_handle_t::event_t _gemm_is_beta_zero( container_2_t _C, index_t _ldc, index_t _stridec, index_t batch_size, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { - return ((_beta == static_cast(0)) - ? _gemm_platform_specific<_t_a, _t_b, s_a, s_b, true>( - sb_handle, _M, _N, _K, _alpha, a_, _lda, _stridea, b_, _ldb, - _strideb, _beta, _C, _ldc, _stridec, batch_size, batch_type, - _dependencies) - : _gemm_platform_specific<_t_a, _t_b, s_a, s_b, false>( - sb_handle, _M, _N, _K, _alpha, a_, _lda, _stridea, b_, _ldb, - _strideb, _beta, _C, _ldc, _stridec, batch_size, batch_type, - _dependencies)); + return isZero(_beta) ? _gemm_platform_specific<_t_a, _t_b, s_a, s_b, true>( + sb_handle, _M, _N, _K, _alpha, a_, _lda, _stridea, + b_, _ldb, _strideb, _beta, _C, _ldc, _stridec, + batch_size, batch_type, _dependencies) + : _gemm_platform_specific<_t_a, _t_b, s_a, s_b, false>( + sb_handle, _M, _N, _K, _alpha, a_, _lda, _stridea, + b_, _ldb, _strideb, _beta, _C, _ldc, _stridec, + batch_size, batch_type, _dependencies); } template > { static element_t get_scalar(element_t &scalar) { return scalar; } }; +#ifdef BLAS_ENABLE_COMPLEX +/*! DetectScalar (for sycl::complex) + * @brief See Detect Scalar. + */ +template +struct DetectScalar> { + using element_t = complex_sycl; + static element_t get_scalar(element_t &scalar) { return scalar; } +}; +#endif + /*! get_scalar. * @brief Template autodecuction function for DetectScalar. */ diff --git a/src/operations/blas3/gemm_common.hpp b/src/operations/blas3/gemm_common.hpp index 4966b9f13..6923f492b 100644 --- a/src/operations/blas3/gemm_common.hpp +++ b/src/operations/blas3/gemm_common.hpp @@ -33,6 +33,28 @@ namespace blas { +#ifdef BLAS_ENABLE_COMPLEX +template +static PORTBLAS_INLINE T +mul_add(T a, T b, T c, + typename std::enable_if::value>::type * = 0) { + return (a * b + c); +} + +template +static PORTBLAS_INLINE T +mul_add(T a, T b, T c, + typename std::enable_if::value>::type * = 0) { + return (sycl::mad(a, b, c)); +} +#else + +template +static PORTBLAS_INLINE T mul_add(T a, T b, T c) { + return (sycl::mad(a, b, c)); +} +#endif + template struct type_string { static const char *get_value() { return "unknown"; } diff --git a/src/operations/blas3/gemm_load_store.hpp b/src/operations/blas3/gemm_load_store.hpp index ef44cbfe6..7ae45ce5d 100644 --- a/src/operations/blas3/gemm_load_store.hpp +++ b/src/operations/blas3/gemm_load_store.hpp @@ -125,5 +125,149 @@ struct Packetize { } }; +#ifdef BLAS_ENABLE_COMPLEX +/*! @brief vec_complex is an intermediate wrapper of sycl::complex used in + * Packetize. It serves as a temporary workaround to the upcoming + * sycl::vec container + * github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_complex.asciidoc + * and only supports size = 1. + * @tparam DataT Complex type of the vector's data + * @tparam NumElements Elements count of the vector (only 1 is supported) + */ +template +class vec_complex { + static_assert(NumElements == 1, + "Vector wrapper arround sycl::complex of size>1 unsupported."); + using address_t = cl::sycl::access::address_space; + using decorated_t = cl::sycl::access::decorated; + using DataType = DataT; + static constexpr int getNumElements() { return NumElements; } + size_t size() const noexcept { return NumElements; } + + private: + DataType m_Data; + + public: + vec_complex() = default; + + constexpr vec_complex(const vec_complex &rhs) = default; + constexpr vec_complex(vec_complex &&rhs) = default; + constexpr vec_complex &operator=(const vec_complex &rhs) = default; + + vec_complex(const DataType &rhs_data) : m_Data{rhs_data} {} + + // Conversion operator (valid with NumElements==1) + operator DataT() const { return m_Data; } + + // Subscript operators + DataT &operator[](int i) { + assert(i < NumElements); + return (m_Data); + } + const DataT &operator[](int i) const { + assert(i < NumElements); + return (m_Data); + } + + // Binary Ops + // Multiply + vec_complex operator*(const vec_complex &rhs) { + return (vec_complex{m_Data * static_cast(rhs)}); + } + + vec_complex operator*(const DataType &rhs) { + return (vec_complex{m_Data * rhs}); + } + + // Compound Multiply + vec_complex &operator*=(const DataType &rhs) { + this->m_Data = this->m_Data * rhs; + return (*this); + } + + vec_complex &operator*=(const vec_complex &rhs) { + this->m_Data = this->m_Data * static_cast(rhs); + return (*this); + } + + // Add + vec_complex operator+(const vec_complex &rhs) { + return (vec_complex{m_Data + static_cast(rhs)}); + } + + vec_complex operator+(const DataType &rhs) { + return (vec_complex{m_Data + rhs}); + } + + // Compound Add + vec_complex &operator+=(const DataType &rhs) { + this->m_Data = this->m_Data * rhs; + return (*this); + } + + vec_complex &operator+=(const vec_complex &rhs) { + this->m_Data = this->m_Data + static_cast(rhs); + return (*this); + } + + // Load + template + void load(size_t Offset, + cl::sycl::multi_ptr Ptr) { + m_Data = *(Ptr + Offset * NumElements); + } + + // Store + template + void store(size_t Offset, + cl::sycl::multi_ptr Ptr) const { + *(Ptr + Offset * NumElements) = m_Data; + } +}; + +/*! @brief Partial specialization of the Packetize class dedicated to +sycl::complex types. It contains static methods for loading and storing size=1 +complex packets from/to memory. +* @tparam vector_size The desired vector size to be used. Only size = 1 is +supported so far. +* @tparam value_t The complex type of the matrix data. +*/ +template +struct Packetize, index_t> { + // Vectorization is not enabled for complex, always set to 1 + using value_t = complex_sycl; + using PacketType = vec_complex; + static constexpr int packet_size = 1; + template + static PORTBLAS_INLINE constexpr bool check_size() { + return true; + } + + /*! @brief Performs a non-vectorised load of sycl::complex data element while + * whether block is internal or not since vectorization is not enabled for + * complex types yet. + * @tparam trans Whether the source matrix is transposed or not. + * @tparam internal True if the current block is internal and no bounds + * checking is required. + * @tparam ld The leading dimension of the destination memory. */ + template + static PORTBLAS_INLINE void load(const bool in_range, SrcPointerType src, + DestPointerType dest, + EdgePredicate edge_in_range) { + *(dest) = in_range ? *(src) : value_t{(T)0, (T)0}; + } + + /*! @brief Store a size = 1 vector packet of sycl::complex data into local + * memory (whether source is transposed or not since it's only 1 element). + * @tparam trans Whether the source matrix is transposed or not. + * @tparam ld The leading dimension of the destination memory.*/ + template + static PORTBLAS_INLINE void store(PacketType &packet, DestPointerType dest) { + *dest = packet[0]; + } +}; +#endif + } // namespace blas #endif // PORTBLAS_BLAS3_GEMM_LOAD_STORE_HPP diff --git a/src/operations/blas3/gemm_local.hpp b/src/operations/blas3/gemm_local.hpp index 9b1c1c98b..db0fe6f14 100644 --- a/src/operations/blas3/gemm_local.hpp +++ b/src/operations/blas3/gemm_local.hpp @@ -754,7 +754,7 @@ class Gemm(reg_a[l], reg_b, reg_res[j * item_rows + l]); } } A = A + ldsa; diff --git a/src/operations/blas3/gemm_no_local_full_vec.hpp b/src/operations/blas3/gemm_no_local_full_vec.hpp index a5dc683f3..732cc9568 100644 --- a/src/operations/blas3/gemm_no_local_full_vec.hpp +++ b/src/operations/blas3/gemm_no_local_full_vec.hpp @@ -69,6 +69,7 @@ class Gemm::type; using address_t = cl::sycl::access::address_space; using packetize_t = Packetize; + using vector_t = typename packetize_t::PacketType; static constexpr int local_memory_size = 0; /*! @brief The number of rows processed by each work item */ static constexpr index_t item_rows = tile_type::item_rows; @@ -114,8 +115,8 @@ class Gemm(check_boundary( dim_m_c_start + j * wg_rows, dim_n_c_start + i * wg_cols))) { - cl::sycl::vec out_vec{}; + using l_vector_t = + typename Packetize::PacketType; + l_vector_t out_vec{}; out_vec.template load( 0, cl::sycl::multi_ptr( @@ -552,7 +555,9 @@ class Gemm(is_valid_row(j * ptr_next + work_per_load - 1)); - cl::sycl::vec in_vec{}; + using l_vector_t = + typename Packetize::PacketType; + l_vector_t in_vec{}; if (in_range) { // if in range perform a vectorised load in_vec.template load( @@ -630,7 +635,9 @@ class Gemm(is_valid_col(work_per_load - 1)); - cl::sycl::vec in_vec{}; + using l_vector_t = + typename Packetize::PacketType; + l_vector_t in_vec{}; if (in_range) { // if in range perform a vectorised load in_vec.template load( @@ -705,7 +712,9 @@ class Gemm(is_valid_row(work_per_load - 1)) && do_check(is_valid_col(col_ofs)); - cl::sycl::vec in_vec{}; + using l_vector_t = + typename Packetize::PacketType; + l_vector_t in_vec{}; if (in_range) { // If in range perform a vectorised load. in_vec.template load( @@ -768,7 +777,9 @@ class Gemm(is_valid_row(row_ofs)) && do_check(is_valid_col(work_per_load - 1)); - cl::sycl::vec in_vec{}; + using l_vector_t = + typename Packetize::PacketType; + l_vector_t in_vec{}; if (in_range) { // If in range perform a vectorised load. in_vec.template load( @@ -808,7 +819,7 @@ class Gemm(reg_a[j], reg_b[i], reg_res[i * item_rows + j]); } } } @@ -860,7 +871,7 @@ class Gemm(reg_a[j], *reg_b, reg_res[j]); } } @@ -887,11 +898,11 @@ class Gemm PORTBLAS_INLINE void store(PointerType C, element_t *reg_res, - const index_t &dim_m_c_start, - const index_t &dim_n_c_start, - const check_boundary &chk_boundary, - const bool out_of_range, - const index_t &ldc) noexcept { + const index_t &dim_m_c_start, + const index_t &dim_n_c_start, + const check_boundary &chk_boundary, + const bool out_of_range, + const index_t &ldc) noexcept { if (out_of_range) { return; } @@ -901,7 +912,9 @@ class Gemm(chk_boundary(dim_m_c_start + j * wg_rows, dim_n_c_start + i * wg_cols))) { - cl::sycl::vec out_vec{}; + using l_vector_t = + typename Packetize::PacketType; + l_vector_t out_vec{}; out_vec.template load( 0, cl::sycl::multi_ptr( diff --git a/src/operations/blas3/gemm_no_local_partial_vec.hpp b/src/operations/blas3/gemm_no_local_partial_vec.hpp index eb3d19473..189de963b 100644 --- a/src/operations/blas3/gemm_no_local_partial_vec.hpp +++ b/src/operations/blas3/gemm_no_local_partial_vec.hpp @@ -69,6 +69,7 @@ class Gemm::type; using address_t = cl::sycl::access::address_space; using packetize_t = Packetize; + using vector_t = typename packetize_t::PacketType; static constexpr int local_memory_size = 0; /*! @brief The number of rows processed by each work item */ static constexpr index_t item_rows = tile_type::item_rows; @@ -458,7 +459,9 @@ class Gemm(chk_boundary(index + (work_per_load - 1))); - cl::sycl::vec in_vec{0}; + using l_vector_t = + typename Packetize::PacketType; + l_vector_t in_vec{0}; if (in_range) { in_vec.template load( 0, @@ -488,7 +491,7 @@ class Gemm(reg_a[j], reg_b[i], reg_res[i * item_rows + j]); } } } @@ -502,7 +505,9 @@ class Gemm PORTBLAS_INLINE typename std::enable_if::type store_packet( element_t *reg, OutputPointerType out_ptr) { - cl::sycl::vec out_vec{0}; + using l_vector_t = + typename Packetize::PacketType; + l_vector_t out_vec{0}; out_vec.template load( 0, cl::sycl::multi_ptr(reg)); @@ -531,11 +536,11 @@ class Gemm PORTBLAS_INLINE void store(PointerType C, element_t *reg_res, - const index_t &dim_m_c_start, - const index_t &dim_n_c_start, - const check_boundary &chk_boundary, - const bool out_of_range, - const index_t &ldc) noexcept { + const index_t &dim_m_c_start, + const index_t &dim_n_c_start, + const check_boundary &chk_boundary, + const bool out_of_range, + const index_t &ldc) noexcept { if (out_of_range) { return; } @@ -545,7 +550,9 @@ class Gemm(chk_boundary(dim_m_c_start + j * wg_rows, dim_n_c_start + i * wg_cols))) { - cl::sycl::vec out_vec{0}; + using l_vector_t = + typename Packetize::PacketType; + l_vector_t out_vec{0}; out_vec.template load( 0, cl::sycl::multi_ptr( diff --git a/src/operations/blas3/gemm_partial_local.hpp b/src/operations/blas3/gemm_partial_local.hpp index a9de19fb8..a6f8bf30a 100644 --- a/src/operations/blas3/gemm_partial_local.hpp +++ b/src/operations/blas3/gemm_partial_local.hpp @@ -309,8 +309,8 @@ class GemmPartial( + privateLhs, privateRhs, private_res[wLPTM + idx]); lhs_index += tile_type::wg_rows; } From a4d6b8fac9d6543088ec318ad41cdfe8b06e6b38 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Mon, 11 Sep 2023 13:17:45 +0100 Subject: [PATCH 02/16] Added unit tests for complex type gemm operators --- common/include/common/float_comparison.hpp | 98 ++++- .../include/common/system_reference_blas.hpp | 15 + test/blas_test.hpp | 48 ++- test/blas_test_macros.hpp | 42 +++ test/unittest/CMakeLists.txt | 5 + .../blas3/blas3_gemm_batched_test.cpp | 64 ++++ test/unittest/blas3/blas3_gemm_common.hpp | 338 +++++++++++++++++- .../blas3/blas3_gemm_tall_skinny_test.cpp | 78 ++++ test/unittest/blas3/blas3_gemm_test.cpp | 118 ++++++ 9 files changed, 803 insertions(+), 3 deletions(-) diff --git a/common/include/common/float_comparison.hpp b/common/include/common/float_comparison.hpp index 43f8f578b..e244f0d5a 100644 --- a/common/include/common/float_comparison.hpp +++ b/common/include/common/float_comparison.hpp @@ -28,6 +28,9 @@ #include #include +#ifdef BLAS_ENABLE_COMPLEX +#include +#endif #ifdef BLAS_DATA_TYPE_HALF #if SYCL_LANGUAGE_VERSION < 202000 @@ -65,6 +68,23 @@ scalar_t abs(scalar_t value) noexcept { return std::abs(value); } +#ifdef BLAS_ENABLE_COMPLEX +template +bool isnan(std::complex value) noexcept { + return (isnan(value.imag()) || isnan(value.imag())); +} + +template +bool isinf(std::complex value) noexcept { + return (isinf(value.imag()) || isinf(value.imag())); +} + +template +scalar_t abs(std::complex value) noexcept { + return std::abs(value); +} +#endif + #ifdef BLAS_DATA_TYPE_HALF template <> inline bool isnan(cl::sycl::half value) noexcept { @@ -172,7 +192,7 @@ inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2) { return true; } - const scalar_t absolute_diff = utils::abs(scalar1 - scalar2); + const auto absolute_diff = utils::abs(scalar1 - scalar2); // Close to zero, the relative error doesn't work, use absolute error if (scalar1 == scalar_t{0} || scalar2 == scalar_t{0} || @@ -212,6 +232,37 @@ inline bool compare_vectors(std::vector const& vec, return true; } +#ifdef BLAS_ENABLE_COMPLEX +/** + * Compare two vectors of complex data and returns false if the difference is + * not acceptable. The second vector is considered the reference. + * @tparam scalar_t the type of complex underying data present in the input + * vectors + * @tparam epilon_t the type used as tolerance. + */ +template +inline bool compare_vectors(std::vector> const& vec, + std::vector> const& ref, + std::ostream& err_stream = std::cerr, + std::string end_line = "\n") { + if (vec.size() != ref.size()) { + err_stream << "Error: tried to compare vectors of different sizes" + << std::endl; + return false; + } + + for (int i = 0; i < vec.size(); ++i) { + if (!almost_equal, epsilon_t>(vec[i], ref[i])) { + err_stream << "Value mismatch at index " << i << ": (" << vec[i].real() + << "," << vec[i].imag() << "); expected (" << ref[i].real() + << "," << ref[i].imag() << ")" << end_line; + return false; + } + } + return true; +} +#endif + /** * Compare two vectors at a given stride and window (unit_vec_size) and returns * false if the difference is not acceptable. The second vector is considered @@ -253,6 +304,51 @@ inline bool compare_vectors_strided(std::vector const& vec, return true; } +#ifdef BLAS_ENABLE_COMPLEX +/** + * Compare two vectors of complex data at a given stride and window and returns + * false if the difference is not acceptable. The second vector is considered + * the reference. + * @tparam scalar_t the type of the complex underying data present in the input + * vectors + * @tparam epsilon_t the type used as tolerance. + * @param stride is the stride between two consecutive 'windows' + * @param window is the size of a comparison window + */ +template +inline bool compare_vectors_strided( + std::vector> const& vec, + std::vector> const& ref, int stride, int window, + std::ostream& err_stream = std::cerr, std::string end_line = "\n") { + if (vec.size() != ref.size()) { + err_stream << "Error: tried to compare vectors of different sizes" + << std::endl; + return false; + } + + int k = 0; + + // Loop over windows + while (window + (k + 1) * stride < vec.size()) { + // Loop within a window + for (int i = 0; i < window; ++i) { + auto index = i + k * stride; + if (!almost_equal, epsilon_t>(vec[index], + ref[index])) { + err_stream << "Value mismatch at index " << index << ": (" + << vec[index].real() << "," << vec[index].imag() + << "); expected (" << ref[index].real() << "," + << ref[index].imag() << ")" << end_line; + return false; + } + } + k += 1; + } + + return true; +} +#endif + } // namespace utils #endif // UTILS_FLOAT_COMPARISON_H_ diff --git a/common/include/common/system_reference_blas.hpp b/common/include/common/system_reference_blas.hpp index afcb4f5e4..cd07e27cf 100644 --- a/common/include/common/system_reference_blas.hpp +++ b/common/include/common/system_reference_blas.hpp @@ -133,6 +133,12 @@ auto blas_system_function(floatfn_t ffn, doublefn_t dfn) return BlasSystemFunction::get(ffn, dfn); } +template +auto blas_cplx_system_function(floatfn_t ffn, doublefn_t dfn) + -> decltype(BlasSystemFunction::get(ffn, dfn)) { + return BlasSystemFunction::get(ffn, dfn); +} + // ======= // Level 1 // ======= @@ -378,6 +384,15 @@ void gemm(const char *transA, const char *transB, int m, int n, int k, lda, b, ldb, beta, c, ldc); } +template +void cgemm(const char *transA, const char *transB, int m, int n, int k, + const void *alpha, const void *a, int lda, const void *b, int ldb, + const void *beta, void *c, int ldc) { + auto func = blas_cplx_system_function(&cblas_cgemm, &cblas_zgemm); + func(CblasColMajor, c_trans(*transA), c_trans(*transB), m, n, k, alpha, a, + lda, b, ldb, beta, c, ldc); +} + template void trsm(const char *side, const char *uplo, const char *trans, const char *diag, int m, int n, scalar_t alpha, const scalar_t A[], diff --git a/test/blas_test.hpp b/test/blas_test.hpp index 1d0f39de3..70a32d61c 100644 --- a/test/blas_test.hpp +++ b/test/blas_test.hpp @@ -149,6 +149,34 @@ static inline void fill_random(std::vector &vec) { fill_random_with_range(vec, scalar_t{-2}, scalar_t{5}); } +#ifdef BLAS_ENABLE_COMPLEX +/** + * @brief Generates a random vector of std::complex values, using a + * uniform distribution. + * @param vec Input vector to fill + * @param rangeMin Minimum value for the uniform distribution (real & imag) + * @param rangeMax Maximum value for the uniform distribution (real & imag) + */ +template +static inline void fill_random_with_range( + std::vector> &vec, scalar_t rangeMin, + scalar_t rangeMax) { + for (complex_std &e : vec) { + e = complex_std{random_scalar(rangeMin, rangeMax), + random_scalar(rangeMin, rangeMax)}; + } +} + +/** + * @brief Generates a random vector of std::complex values, using a + * uniform distribution. + */ +template +static inline void fill_random(std::vector> &vec) { + fill_random_with_range(vec, scalar_t{-2}, scalar_t{5}); +} +#endif + /** * @brief Fills a lower or upper triangular matrix suitable for TRSM testing * @param A The matrix to fill. Size must be at least m * lda @@ -165,7 +193,7 @@ static inline void fill_random(std::vector &vec) { * @param unused Value to put in the unused parts of the matrix */ template -static inline void fill_trsm_matrix(std::vector& A, size_t k, +static inline void fill_trsm_matrix(std::vector &A, size_t k, size_t lda, char uplo, char unit_diag, scalar_t diag = scalar_t{1}, scalar_t unused = scalar_t{0}) { @@ -262,6 +290,24 @@ struct dump_arg_helper { } }; +#ifdef BLAS_ENABLE_COMPLEX +/** Specialization of dump_arg_helper for std::complex types. + * This is required to split the real & imag parts properly and avoid + * by-default parentheses format. + **/ +template +struct dump_arg_helper< + T, typename std::enable_if::value>::type> { + inline void operator()(std::ostream &ss, T f) { + using scalar_t = typename T::value_type; + dump_arg_helper{}(ss, f.real()); + ss << "r"; + dump_arg_helper{}(ss, f.imag()); + ss << "i"; + } +}; +#endif + /** * Type of the tested api */ diff --git a/test/blas_test_macros.hpp b/test/blas_test_macros.hpp index 5b4cf979c..89e733e60 100644 --- a/test/blas_test_macros.hpp +++ b/test/blas_test_macros.hpp @@ -93,6 +93,36 @@ combination, name_generator) #endif // BLAS_DATA_TYPE_HALF +#ifdef BLAS_ENABLE_COMPLEX +#define BLAS_REGISTER_TEST_CPLX_S_CUSTOM_NAME(test_suite, class_name, \ + test_function, combination_t, \ + combination, name_generator) \ + class class_name##CplxFloat \ + : public ::testing::TestWithParam> {}; \ + TEST_P(class_name##CplxFloat, test) { test_function(GetParam()); }; \ + INSTANTIATE_TEST_SUITE_P(test_suite, class_name##CplxFloat, \ + combination, name_generator); +#else +#define BLAS_REGISTER_TEST_CPLX_S_CUSTOM_NAME(test_suite, class_name, \ + test_function, combination_t, \ + combination, name_generator) +#endif // BLAS_ENABLE_COMPLEX + +#if defined(BLAS_DATA_TYPE_DOUBLE) & defined(BLAS_ENABLE_COMPLEX) +#define BLAS_REGISTER_TEST_CPLX_D_CUSTOM_NAME(test_suite, class_name, \ + test_function, combination_t, \ + combination, name_generator) \ + class class_name##CplxDouble \ + : public ::testing::TestWithParam> {}; \ + TEST_P(class_name##CplxDouble, test) { test_function(GetParam()); }; \ + INSTANTIATE_TEST_SUITE_P(test_suite, class_name##CplxDouble, \ + combination, name_generator); +#else +#define BLAS_REGISTER_TEST_CPLX_D_CUSTOM_NAME(test_suite, class_name, \ + test_function, combination_t, \ + combination, name_generator) +#endif // BLAS_ENABLE_COMPLEX & BLAS_ENABLE_COMPLEX + /** Registers test for all supported data types * @param test_suite Name of the test suite * @param class_name Base name of the test class @@ -115,6 +145,18 @@ combination_t, combination, \ name_generator); +#ifdef BLAS_ENABLE_COMPLEX +#define BLAS_REGISTER_CPLX_TEST_CUSTOM_NAME(test_suite, class_name, \ + test_function, combination_t, \ + combination, name_generator) \ + BLAS_REGISTER_TEST_CPLX_S_CUSTOM_NAME(test_suite, class_name, test_function, \ + combination_t, combination, \ + name_generator); \ + BLAS_REGISTER_TEST_CPLX_D_CUSTOM_NAME(test_suite, class_name, test_function, \ + combination_t, combination, \ + name_generator); +#endif // BLAS_ENABLE_COMPLEX + /** Registers test for all supported data types * @see BLAS_REGISTER_TEST_CUSTOM_NAME */ diff --git a/test/unittest/CMakeLists.txt b/test/unittest/CMakeLists.txt index 4f824238d..b4d2b0a3b 100644 --- a/test/unittest/CMakeLists.txt +++ b/test/unittest/CMakeLists.txt @@ -116,6 +116,11 @@ foreach(blas_test ${SYCL_UNITTEST_SRCS}) if(STRESS_TESTING) target_compile_definitions(${test_exec} PRIVATE STRESS_TESTING) endif() + if(${BLAS_ENABLE_COMPLEX}) + if(${test_exec} MATCHES "gemm") + target_compile_definitions(${test_exec} PRIVATE BLAS_ENABLE_COMPLEX=1) + endif() + endif() target_compile_definitions(${test_exec} PRIVATE -DBLAS_INDEX_T=${BLAS_TEST_INDEX_TYPE}) target_link_libraries(${test_exec} PRIVATE gtest_main Clara::Clara blas::blas portblas) target_include_directories(${test_exec} PRIVATE ${CBLAS_INCLUDE} ${PORTBLAS_COMMON_INCLUDE_DIR}) diff --git a/test/unittest/blas3/blas3_gemm_batched_test.cpp b/test/unittest/blas3/blas3_gemm_batched_test.cpp index 1ce9413bd..6794ff56c 100644 --- a/test/unittest/blas3/blas3_gemm_batched_test.cpp +++ b/test/unittest/blas3/blas3_gemm_batched_test.cpp @@ -145,3 +145,67 @@ const auto AllStridedBatched = ::testing::Values(1, 2, 3) // stride_c_mul ); GENERATE_GEMM_STRIDED_BATCHED_TEST(BatchStridedGemm, AllStridedBatched); + +#ifdef BLAS_ENABLE_COMPLEX +template +const auto CplxBetaNonZeroLDMatch = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(5), // batch + ::testing::Values(63, 128), // m + ::testing::Values(63, 128), // n + ::testing::Values(63, 128), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.5, 1.0}), // alpha + ::testing::Values>({1.5, 3.0}), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(BatchGemm, CplxBetaNonZeroLDMatch); + +template +const auto CplxDefaultGemmAndGemmBatched = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(1, 5), // batch + ::testing::Values(63, 128), // m + ::testing::Values(63, 128), // n + ::testing::Values(63, 128), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({2.5, 1.0}), // alpha + ::testing::Values>({1.5, 3.0}), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(1), // stride_a_mul + ::testing::Values(1), // stride_b_mul + ::testing::Values(1) // stride_c_mul +); +GENERATE_CPLXGEMM_STRIDED_BATCHED_TEST(BatchStridedGemm, + CplxDefaultGemmAndGemmBatched); + +template +const auto CplxAllStridedBatched = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(5), // batch + ::testing::Values(128), // m + ::testing::Values(128), // n + ::testing::Values(128), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({2.5, 1.0}), // alpha + ::testing::Values>({1.5, 3.0}), // beta + ::testing::Values(2), // lda_mul + ::testing::Values(3), // ldb_mul + ::testing::Values(4), // ldc_mul + ::testing::Values(0, 1, 2), // stride_a_mul + ::testing::Values(0, 1, 2), // stride_b_mul + ::testing::Values(1, 2, 3) // stride_c_mul +); +GENERATE_CPLXGEMM_STRIDED_BATCHED_TEST(BatchStridedGemm, CplxAllStridedBatched); +#endif diff --git a/test/unittest/blas3/blas3_gemm_common.hpp b/test/unittest/blas3/blas3_gemm_common.hpp index 48bd28128..d28baa99a 100644 --- a/test/unittest/blas3/blas3_gemm_common.hpp +++ b/test/unittest/blas3/blas3_gemm_common.hpp @@ -37,6 +37,19 @@ using gemm_batched_strided_arguments_t = std::tuple; +#ifdef BLAS_ENABLE_COMPLEX +template +using gemm_cplx_arguments_t = + std::tuple, std::complex, int, int, int, + gemm_batch_type_t>; + +template +using gemm_cplx_batched_strided_arguments_t = + std::tuple, std::complex, int, int, int, int, int, int>; +#endif + // Convert batch_type=strided to interleaved on the host template inline std::vector strided_to_interleaved( @@ -383,4 +396,327 @@ static std::string generate_batched_strided_name( BLAS_REGISTER_TEST_CUSTOM_NAME(test_suite, test_suite##combination, \ verify_gemm, \ gemm_batched_strided_arguments_t, \ - combination, generate_batched_strided_name); \ No newline at end of file + combination, generate_batched_strided_name); + +#ifdef BLAS_ENABLE_COMPLEX + +template +inline void verify_gemm(const gemm_cplx_arguments_t arguments) { + std::string alloc; + index_t offset; + index_t batch; + index_t m; + index_t n; + index_t k; + char transa; + char transb; + complex_std alpha; + complex_std beta; + index_t lda_mul; + index_t ldb_mul; + index_t ldc_mul; + gemm_batch_type_t batch_type; + std::tie(alloc, offset, batch, m, n, k, transa, transb, alpha, beta, lda_mul, + ldb_mul, ldc_mul, batch_type) = arguments; + + const char ta_str[2] = {transa, '\0'}; + const char tb_str[2] = {transb, '\0'}; + + auto q = make_queue(); + blas::SB_Handle sb_handle(q); + + const index_t lda = ((transa != 'n') ? k : m) * lda_mul; + const index_t ldb = ((transb != 'n') ? n : k) * ldb_mul; + const index_t ldc = m * ldc_mul; + + const index_t size_a = m * k * lda_mul; + const index_t size_b = k * n * ldb_mul; + const index_t size_c = m * n * ldc_mul; + + const index_t buffer_size_a = batch * size_a + offset; + const index_t buffer_size_b = batch * size_b + offset; + const index_t buffer_size_c = batch * size_c + offset; + + std::vector> a_m(buffer_size_a); + std::vector> b_m(buffer_size_b); + std::vector> c_m_gpu(buffer_size_c); + + fill_random(a_m); + fill_random(b_m); + fill_random(c_m_gpu); + std::vector> c_m_cpu = c_m_gpu; + + // Use system blas to create a reference output + for (int i = 0; i < batch; ++i) { + reference_blas::cgemm( + ta_str, tb_str, m, n, k, reinterpret_cast(&alpha), + reinterpret_cast(a_m.data() + i * size_a + offset), lda, + reinterpret_cast(b_m.data() + i * size_b + offset), ldb, + reinterpret_cast(&beta), + reinterpret_cast(c_m_cpu.data() + i * size_c + offset), ldc); + } + + if (batch > 1 && batch_type == gemm_batch_type_t::interleaved) { + // Interleaved batched gemm unsupported + GTEST_SKIP(); + } + + auto m_a_gpu = blas::helper::allocate>( + buffer_size_a, q); + auto m_b_gpu = blas::helper::allocate>( + buffer_size_b, q); + auto m_c_gpu = blas::helper::allocate>( + buffer_size_c, q); + + auto copy_a = blas::helper::copy_to_device( + q, reinterpret_cast*>(a_m.data()), m_a_gpu, + buffer_size_a); + auto copy_b = blas::helper::copy_to_device( + q, reinterpret_cast*>(b_m.data()), m_b_gpu, + buffer_size_b); + auto copy_c = blas::helper::copy_to_device( + q, reinterpret_cast*>(c_m_gpu.data()), m_c_gpu, + buffer_size_c); + + complex_sycl alpha_sycl(alpha); + complex_sycl beta_sycl(beta); + + // portBLAS GEMM implementation + typename blas::SB_Handle::event_t gemm_event; + if (batch == index_t(1)) { + gemm_event = _gemm(sb_handle, transa, transb, m, n, k, alpha_sycl, + m_a_gpu + offset, lda, m_b_gpu + offset, ldb, beta_sycl, + m_c_gpu + offset, ldc, {copy_a, copy_b, copy_c}); + } else { + return; + _gemm_batched(sb_handle, transa, transb, m, n, k, alpha, m_a_gpu + offset, + lda, m_b_gpu + offset, ldb, beta, m_c_gpu + offset, ldc, + batch, batch_type, {copy_a, copy_b, copy_c}); + } + sb_handle.wait(gemm_event); + + auto event = blas::helper::copy_to_host( + q, m_c_gpu, reinterpret_cast*>(c_m_gpu.data()), + buffer_size_c); + sb_handle.wait(event); + + const bool isAlmostEqual = utils::compare_vectors(c_m_gpu, c_m_cpu); + ASSERT_TRUE(isAlmostEqual); + + helper::deallocate(m_a_gpu, q); + helper::deallocate(m_b_gpu, q); + helper::deallocate(m_c_gpu, q); +} + +template +inline void verify_gemm(const gemm_cplx_arguments_t arguments) { + std::string alloc; + index_t offset; + index_t batch; + index_t m; + index_t n; + index_t k; + char transa; + char transb; + complex_std alpha; + complex_std beta; + index_t lda_mul; + index_t ldb_mul; + index_t ldc_mul; + gemm_batch_type_t batch_type; + std::tie(alloc, offset, batch, m, n, k, transa, transb, alpha, beta, lda_mul, + ldb_mul, ldc_mul, batch_type) = arguments; + + if (alloc == "usm") { +#ifdef SB_ENABLE_USM + verify_gemm(arguments); +#else + GTEST_SKIP(); +#endif + } else { + verify_gemm(arguments); + } +} + +template +static std::string generate_cplx_name( + const ::testing::TestParamInfo>& info) { + std::string alloc; + int offset, batch, m, n, k, ldaMul, ldbMul, ldcMul; + char transa, transb; + complex_std alpha, beta; + gemm_batch_type_t batchType; + BLAS_GENERATE_NAME(info.param, alloc, offset, batch, m, n, k, transa, transb, + alpha, beta, ldaMul, ldbMul, ldcMul, batchType); +} + +template +inline void verify_gemm( + const gemm_cplx_batched_strided_arguments_t arguments) { + std::string alloc; + index_t offset; + index_t batch; + index_t m; + index_t n; + index_t k; + char transa; + char transb; + complex_std alpha; + complex_std beta; + index_t lda_mul; + index_t ldb_mul; + index_t ldc_mul; + index_t stride_a_mul; + index_t stride_b_mul; + index_t stride_c_mul; + std::tie(alloc, offset, batch, m, n, k, transa, transb, alpha, beta, lda_mul, + ldb_mul, ldc_mul, stride_a_mul, stride_b_mul, stride_c_mul) = + arguments; + + const char ta_str[2] = {transa, '\0'}; + const char tb_str[2] = {transb, '\0'}; + + auto q = make_queue(); + blas::SB_Handle sb_handle(q); + + const index_t lda = ((transa != 'n') ? k : m) * lda_mul; + const index_t ldb = ((transb != 'n') ? n : k) * ldb_mul; + const index_t ldc = m * ldc_mul; + + const index_t size_a = m * k * lda_mul; + const index_t size_b = k * n * ldb_mul; + const index_t size_c = m * n * ldc_mul; + + const index_t stride_a = stride_a_mul * size_a; + const index_t stride_b = stride_b_mul * size_b; + const index_t stride_c = stride_c_mul * size_c; + + const index_t buffer_size_a = size_a + (batch - 1) * stride_a + offset; + const index_t buffer_size_b = size_b + (batch - 1) * stride_b + offset; + const index_t buffer_size_c = size_c + (batch - 1) * stride_c + offset; + + std::vector> a_m(buffer_size_a); + std::vector> b_m(buffer_size_b); + std::vector> c_m_gpu(buffer_size_c); + + fill_random(a_m); + fill_random(b_m); + fill_random(c_m_gpu); + std::vector> c_m_cpu = c_m_gpu; + + // Use system blas to create a reference output + for (int i = 0; i < batch; ++i) { + reference_blas::cgemm( + ta_str, tb_str, m, n, k, reinterpret_cast(&alpha), + reinterpret_cast(a_m.data() + i * stride_a + offset), lda, + reinterpret_cast(b_m.data() + i * stride_b + offset), ldb, + reinterpret_cast(&beta), + reinterpret_cast(c_m_cpu.data() + i * stride_c + offset), ldc); + } + + auto m_a_gpu = blas::helper::allocate>( + buffer_size_a, q); + auto m_b_gpu = blas::helper::allocate>( + buffer_size_b, q); + auto m_c_gpu = blas::helper::allocate>( + buffer_size_c, q); + + auto copy_a = blas::helper::copy_to_device( + q, reinterpret_cast*>(a_m.data()), m_a_gpu, + buffer_size_a); + auto copy_b = blas::helper::copy_to_device( + q, reinterpret_cast*>(b_m.data()), m_b_gpu, + buffer_size_b); + auto copy_c = blas::helper::copy_to_device( + q, reinterpret_cast*>(c_m_gpu.data()), m_c_gpu, + buffer_size_c); + + complex_sycl alpha_sycl(alpha); + complex_sycl beta_sycl(beta); + + // portBLAS GEMM STRIDED BATCHED implementation + auto gemm_batched_event = _gemm_strided_batched( + sb_handle, transa, transb, m, n, k, alpha_sycl, m_a_gpu + offset, lda, + stride_a, m_b_gpu + offset, ldb, stride_b, beta_sycl, m_c_gpu + offset, + ldc, stride_c, batch, {copy_a, copy_b, copy_c}); + + sb_handle.wait({gemm_batched_event}); + auto event = blas::helper::copy_to_host( + q, m_c_gpu, reinterpret_cast*>(c_m_gpu.data()), + buffer_size_c); + sb_handle.wait(event); + + const bool isAlmostEqual = + (stride_c_mul == 1) + ? utils::compare_vectors(c_m_gpu, c_m_cpu) + : utils::compare_vectors_strided(c_m_gpu, c_m_cpu, stride_c, size_c); + ASSERT_TRUE(isAlmostEqual); + + helper::deallocate(m_a_gpu, q); + helper::deallocate(m_b_gpu, q); + helper::deallocate(m_c_gpu, q); +} + +template +inline void verify_gemm( + const gemm_cplx_batched_strided_arguments_t arguments) { + std::string alloc; + index_t offset; + index_t batch; + index_t m; + index_t n; + index_t k; + char transa; + char transb; + complex_std alpha; + complex_std beta; + index_t lda_mul; + index_t ldb_mul; + index_t ldc_mul; + index_t stride_a_mul; + index_t stride_b_mul; + index_t stride_c_mul; + std::tie(alloc, offset, batch, m, n, k, transa, transb, alpha, beta, lda_mul, + ldb_mul, ldc_mul, stride_a_mul, stride_b_mul, stride_c_mul) = + arguments; + + if (alloc == "usm") { +#ifdef SB_ENABLE_USM + verify_gemm(arguments); +#endif + } else { + verify_gemm(arguments); + } +} + +template +static std::string generate_cplx_batched_strided_name( + const ::testing::TestParamInfo>& + info) { + std::string alloc; + int offset, batch, m, n, k, ldaMul, ldbMul, ldcMul, stride_a_mul, + stride_b_mul, stride_c_mul; + char transa, transb; + complex_std alpha, beta; + BLAS_GENERATE_NAME(info.param, alloc, offset, batch, m, n, k, transa, transb, + alpha, beta, ldaMul, ldbMul, ldcMul, stride_a_mul, + stride_b_mul, stride_c_mul); +} + +/** Registers GEMM test for all supported complex data types + * @param test_suite Name of the test suite + * @param combination Combinations object + * @see BLAS_REGISTER_TEST_CUSTOM_NAME + */ +#define GENERATE_CPLX_GEMM_TEST(test_suite, combination) \ + BLAS_REGISTER_CPLX_TEST_CUSTOM_NAME(test_suite, test_suite##combination, \ + verify_gemm, gemm_cplx_arguments_t, \ + combination, generate_cplx_name); + +#define GENERATE_CPLXGEMM_STRIDED_BATCHED_TEST(test_suite, combination) \ + BLAS_REGISTER_CPLX_TEST_CUSTOM_NAME( \ + test_suite, test_suite##combination, verify_gemm, \ + gemm_cplx_batched_strided_arguments_t, combination, \ + generate_cplx_batched_strided_name); + +#endif diff --git a/test/unittest/blas3/blas3_gemm_tall_skinny_test.cpp b/test/unittest/blas3/blas3_gemm_tall_skinny_test.cpp index 5e156b7c5..4eeee3cde 100644 --- a/test/unittest/blas3/blas3_gemm_tall_skinny_test.cpp +++ b/test/unittest/blas3/blas3_gemm_tall_skinny_test.cpp @@ -101,3 +101,81 @@ const auto OffsetNonZero = ::testing::Combine( ::testing::Values(gemm_batch_type_t::strided) // batch_type ); GENERATE_GEMM_TEST(TallSkinnyGemm, OffsetNonZero); + +#ifdef BLAS_ENABLE_COMPLEX +template +const auto CplxBetaNonZeroLDMatch = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(1), // batch + ::testing::Values(7, 65), // m + ::testing::Values(9, 126), // n + ::testing::Values(2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.5, 1.5}), // alpha + ::testing::Values>({0.5, 0.5}), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(TallSkinnyGemm, CplxBetaNonZeroLDMatch); + +template +const auto CplxBetaNonZeroLDMultiplied = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(1), // batch + ::testing::Values(7, 65), // m + ::testing::Values(9, 126), // n + ::testing::Values(2049), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.5, 0.5}), // alpha + ::testing::Values>({0.5, 1.5}), // beta + ::testing::Values(2), // lda_mul + ::testing::Values(3), // ldb_mul + ::testing::Values(4), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(TallSkinnyGemm, CplxBetaNonZeroLDMultiplied); + +template +const auto CplxBetaZero = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(1), // batch + ::testing::Values(7), // m + ::testing::Values(9), // n + ::testing::Values(1026), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.5, 2.0}), // alpha + ::testing::Values>({0.0, 0.0}), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(TallSkinnyGemm, CplxBetaZero); + +template +const auto CplxOffsetNonZero = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(10), // offset + ::testing::Values(1), // batch + ::testing::Values(7), // m + ::testing::Values(9), // n + ::testing::Values(1026), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.5, 2.5}), // alpha + ::testing::Values>({0.5, 1.5}), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(TallSkinnyGemm, CplxOffsetNonZero); +#endif diff --git a/test/unittest/blas3/blas3_gemm_test.cpp b/test/unittest/blas3/blas3_gemm_test.cpp index e5d4a4122..acf4c85d8 100644 --- a/test/unittest/blas3/blas3_gemm_test.cpp +++ b/test/unittest/blas3/blas3_gemm_test.cpp @@ -139,3 +139,121 @@ const auto LargeBetaNonZeroLDMatch = ::testing::Combine( ::testing::Values(gemm_batch_type_t::strided) // batch_type ); GENERATE_GEMM_TEST(Gemm, LargeBetaNonZeroLDMatch); + +#ifdef BLAS_ENABLE_COMPLEX +template +const auto CplxSmallBetaNonZeroLDMatch = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 16, 32), // m + ::testing::Values(11, 16, 32), // n + ::testing::Values(16, 17), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.5, 1.0}), // alpha + ::testing::Values>({1.5, 3.0}), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(Gemm, CplxSmallBetaNonZeroLDMatch); + +template +const auto CplxSmallBetaZeroLDMatch = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 32), // m + ::testing::Values(11, 32), // n + ::testing::Values(17), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.5, 1.0}), // alpha + ::testing::Values>({1.5, 3.0}), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(Gemm, CplxSmallBetaZeroLDMatch); + +template +const auto CplxSmallBetaZeroLDMultiplied = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(1), // batch + ::testing::Values(11, 32), // m + ::testing::Values(11, 32), // n + ::testing::Values(17), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.5, 3.0}), // alpha + ::testing::Values>({0.0, 0.0}), // beta + ::testing::Values(2), // lda_mul + ::testing::Values(3), // ldb_mul + ::testing::Values(4), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(Gemm, CplxSmallBetaZeroLDMultiplied); + +template +const auto CplxAlphaZero = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0, 10), // offset + ::testing::Values(1), // batch + ::testing::Values(16), // m + ::testing::Values(16), // n + ::testing::Values(17), // k + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values>({0.0, 0.0}), // alpha + ::testing::Values(std::complex{0.0, 0.0}, + std::complex{1.0, 0.0}), // beta + ::testing::Values(1, 2), // lda_mul + ::testing::Values(1, 2), // ldb_mul + ::testing::Values(1, 2), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(Gemm, CplxAlphaZero); + +template +const auto CplxOffsetNonZero = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(1, 10), // offset + ::testing::Values(1), // batch + ::testing::Values(16, 63), // m + ::testing::Values(16, 63), // n + ::testing::Values(17, 63), // k + ::testing::Values('n'), // transa + ::testing::Values('n'), // transb + ::testing::Values>({1.0, 1.0}), // alpha + ::testing::Values>({1.0, 1.0}), // beta + ::testing::Values(1, 2), // lda_mul + ::testing::Values(1, 2), // ldb_mul + ::testing::Values(1, 2), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(Gemm, CplxOffsetNonZero); + +template +const auto CplxLargeBetaNonZeroLDMatch = ::testing::Combine( + ::testing::Values("usm", "buf"), // allocation type + ::testing::Values(0), // offset + ::testing::Values(1), // batch + ::testing::Values(253, 511), // m + ::testing::Values(257, 511), // n + ::testing::Values(253, 511), // k + ::testing::Values('n', 't'), // transa + ::testing::Values('n', 't'), // transb + ::testing::Values>({1.0, 1.0}), // alpha + ::testing::Values>({1.0, 1.0}), // beta + ::testing::Values(1), // lda_mul + ::testing::Values(1), // ldb_mul + ::testing::Values(1), // ldc_mul + ::testing::Values(gemm_batch_type_t::strided) // batch_type +); +GENERATE_CPLX_GEMM_TEST(Gemm, CplxLargeBetaNonZeroLDMatch); + +#endif From 58b8c350a5853244b18291e341724602096eed23 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Mon, 11 Sep 2023 14:20:43 +0100 Subject: [PATCH 03/16] Minor fixes --- cmake/CmakeFunctionHelper.cmake | 6 +----- src/operations/blas3/gemm_common.hpp | 7 ++++--- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index 8b411d5e6..fff0f923f 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -447,11 +447,7 @@ elseif(${TUNING_TARGET} STREQUAL "POWER_VR" AND NOT IMGDNN_DIR) "float" "half" ) - set(data_list_c ${supported_types}) - if(BLAS_ENABLE_COMPLEX) - set_complex_list(data_list_c "${supported_types}" "false") - endif() - foreach(data ${data_list_c}) + foreach(data ${supported_types}) add_gemm_configuration( "${data}" 96 "true" "false" "false" 16 4 6 12 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") diff --git a/src/operations/blas3/gemm_common.hpp b/src/operations/blas3/gemm_common.hpp index 6923f492b..f1817e70a 100644 --- a/src/operations/blas3/gemm_common.hpp +++ b/src/operations/blas3/gemm_common.hpp @@ -45,13 +45,13 @@ template static PORTBLAS_INLINE T mul_add(T a, T b, T c, typename std::enable_if::value>::type * = 0) { - return (sycl::mad(a, b, c)); + return (cl::sycl::mad(a, b, c)); } #else template static PORTBLAS_INLINE T mul_add(T a, T b, T c) { - return (sycl::mad(a, b, c)); + return (cl::sycl::mad(a, b, c)); } #endif @@ -84,7 +84,8 @@ template PORTBLAS_INLINE std::string Tile::get_type_string() noexcept { + ItemBatchs, WgBatchs, jm_M, jm_N, jm_K, inp_jmT, + out_jmT>::get_type_string() noexcept { std::ostringstream str{}; str << "Tile<" << item_rows << ", " << item_cols << ", " << wg_rows << ", " << wg_cols << ", " << sg_rows << ", " << sg_cols << ", " << tl_rows From bfc56ba9393166519bc98f9054c487ce940e8d1e Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Mon, 11 Sep 2023 16:38:05 +0100 Subject: [PATCH 04/16] Typo fix --- common/include/common/float_comparison.hpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/common/include/common/float_comparison.hpp b/common/include/common/float_comparison.hpp index e244f0d5a..1222ccc41 100644 --- a/common/include/common/float_comparison.hpp +++ b/common/include/common/float_comparison.hpp @@ -71,12 +71,12 @@ scalar_t abs(scalar_t value) noexcept { #ifdef BLAS_ENABLE_COMPLEX template bool isnan(std::complex value) noexcept { - return (isnan(value.imag()) || isnan(value.imag())); + return (isnan(value.real()) || isnan(value.imag())); } template bool isinf(std::complex value) noexcept { - return (isinf(value.imag()) || isinf(value.imag())); + return (isinf(value.real()) || isinf(value.imag())); } template From 3f80316c9cd53af9ab777d55eea4736a565cd6f1 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Wed, 13 Sep 2023 22:46:47 +0100 Subject: [PATCH 05/16] amd gpu config --- cmake/CmakeFunctionHelper.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index fff0f923f..c50688a6d 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -515,7 +515,7 @@ elseif(${TUNING_TARGET} STREQUAL "AMD_GPU") # need investigation foreach(data ${data_list_c}) add_gemm_configuration( "${data}" 256 "true" "true" "true" - 64 1 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 2 "strided" "false") + 64 1 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false") add_gemm_configuration( "${data}" 256 "false" "false" "false" 64 1 1 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") From 2eeb03f05d607bf4582c19eab9e3d64d534fe65a Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Wed, 20 Sep 2023 10:53:10 +0100 Subject: [PATCH 06/16] De-coupling complex & scalar enable_if statements --- cmake/CmakeFunctionHelper.cmake | 12 +++++++++--- src/interface/blas3/backend/amd_gpu.hpp | 6 +----- src/interface/blas3/backend/default_cpu.hpp | 6 +----- src/interface/blas3/backend/intel_gpu.hpp | 11 ++++------- src/interface/blas3/backend/nvidia_gpu.hpp | 6 +----- src/interface/gemm_interface.hpp | 18 +++++++++++------- src/operations/blas3/gemm_common.hpp | 10 ++-------- 7 files changed, 29 insertions(+), 40 deletions(-) diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index c50688a6d..0c30a48d5 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -437,9 +437,15 @@ if(${TUNING_TARGET} STREQUAL "INTEL_GPU") add_gemm_configuration( "${data}" 64 "false" "false" "false" 64 8 8 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false") - add_gemm_configuration( - "${data}" 32 "true" "true" "true" - 64 2 1 8 4 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false") + if (${data} STREQUAL "complex") + add_gemm_configuration( + "${data}" 64 "true" "true" "true" + 64 4 4 4 4 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false") + else() + add_gemm_configuration( + "${data}" 64 "true" "true" "true" + 64 4 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false") + endif() endforeach() endif() # BLAS_ENABLE_COMPLEX elseif(${TUNING_TARGET} STREQUAL "POWER_VR" AND NOT IMGDNN_DIR) diff --git a/src/interface/blas3/backend/amd_gpu.hpp b/src/interface/blas3/backend/amd_gpu.hpp index 3aff8dd46..3ec103620 100644 --- a/src/interface/blas3/backend/amd_gpu.hpp +++ b/src/interface/blas3/backend/amd_gpu.hpp @@ -33,12 +33,8 @@ namespace backend { template -#ifdef BLAS_ENABLE_COMPLEX -typename std::enable_if::value, +typename std::enable_if::value, typename sb_handle_t::event_t>::type -#else -typename sb_handle_t::event_t -#endif _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, diff --git a/src/interface/blas3/backend/default_cpu.hpp b/src/interface/blas3/backend/default_cpu.hpp index 44a99d1fe..1b7dfd680 100644 --- a/src/interface/blas3/backend/default_cpu.hpp +++ b/src/interface/blas3/backend/default_cpu.hpp @@ -33,12 +33,8 @@ namespace backend { template -#ifdef BLAS_ENABLE_COMPLEX -typename std::enable_if::value, +typename std::enable_if::value, typename sb_handle_t::event_t>::type -#else -typename sb_handle_t::event_t -#endif _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, diff --git a/src/interface/blas3/backend/intel_gpu.hpp b/src/interface/blas3/backend/intel_gpu.hpp index e22274008..a0ce6f52a 100644 --- a/src/interface/blas3/backend/intel_gpu.hpp +++ b/src/interface/blas3/backend/intel_gpu.hpp @@ -32,12 +32,8 @@ namespace backend { template -#ifdef BLAS_ENABLE_COMPLEX -typename std::enable_if::value, +typename std::enable_if::value, typename sb_handle_t::event_t>::type -#else -typename sb_handle_t::event_t -#endif _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, @@ -227,9 +223,10 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, const typename sb_handle_t::event_t& _dependencies) { #ifdef GEMM_TALL_SKINNY_SUPPORT if (!s_a && !s_b && batch_size == 1) { + constexpr int wg_size = sizeof(element_t) == 16 ? 4 : 8; return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 32, true, true, true, 64, - Tile<2, 1, 8, 4>, _t_a, _t_b, s_a, s_b, + container_0_t, container_1_t, container_2_t, 64, true, true, true, 64, + Tile<4, 4, wg_size, wg_size>, _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::tall_skinny), static_cast(gemm_vectorization_t::none), is_beta_zero, 1, diff --git a/src/interface/blas3/backend/nvidia_gpu.hpp b/src/interface/blas3/backend/nvidia_gpu.hpp index f13a95d2e..7d555d902 100644 --- a/src/interface/blas3/backend/nvidia_gpu.hpp +++ b/src/interface/blas3/backend/nvidia_gpu.hpp @@ -33,12 +33,8 @@ namespace backend { template -#ifdef BLAS_ENABLE_COMPLEX -typename std::enable_if::value, +typename std::enable_if::value, typename sb_handle_t::event_t>::type -#else -typename sb_handle_t::event_t -#endif _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea, container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta, diff --git a/src/interface/gemm_interface.hpp b/src/interface/gemm_interface.hpp index f5b7383e6..8e90a4b82 100644 --- a/src/interface/gemm_interface.hpp +++ b/src/interface/gemm_interface.hpp @@ -50,16 +50,20 @@ namespace internal { // Check whether value is zero (complex & float/double) template -inline bool isZero(const T& value) { -#ifdef BLAS_ENABLE_COMPLEX - if constexpr (is_complex_sycl::value) { - using value_t = typename T::value_type; - return (value == T(value_t(0), value_t(0))); - } -#endif +inline typename std::enable_if::value, bool>::type isZero( + const T& value) { return (value == static_cast(0)); } +#ifdef BLAS_ENABLE_COMPLEX +template +inline typename std::enable_if::value, bool>::type isZero( + const T& value) { + using value_t = typename T::value_type; + return (value == T(value_t(0), value_t(0))); +} +#endif + template diff --git a/src/operations/blas3/gemm_common.hpp b/src/operations/blas3/gemm_common.hpp index f1817e70a..670dc340d 100644 --- a/src/operations/blas3/gemm_common.hpp +++ b/src/operations/blas3/gemm_common.hpp @@ -40,20 +40,14 @@ mul_add(T a, T b, T c, typename std::enable_if::value>::type * = 0) { return (a * b + c); } +#endif template static PORTBLAS_INLINE T mul_add(T a, T b, T c, - typename std::enable_if::value>::type * = 0) { - return (cl::sycl::mad(a, b, c)); -} -#else - -template -static PORTBLAS_INLINE T mul_add(T a, T b, T c) { + typename std::enable_if::value>::type * = 0) { return (cl::sycl::mad(a, b, c)); } -#endif template struct type_string { From 06705b31a176409abb00952f68b6583745b4c09a Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Wed, 20 Sep 2023 18:00:18 +0100 Subject: [PATCH 07/16] Added static asserts on vector size when using cplx data --- src/operations/blas3/gemm_interleaved.hpp | 20 ++++++++------ src/operations/blas3/gemm_local.hpp | 26 ++++++++++++------- .../blas3/gemm_no_local_full_vec.hpp | 6 +++++ .../blas3/gemm_no_local_partial_vec.hpp | 14 +++++++--- 4 files changed, 44 insertions(+), 22 deletions(-) diff --git a/src/operations/blas3/gemm_interleaved.hpp b/src/operations/blas3/gemm_interleaved.hpp index 551bb465a..66629033e 100644 --- a/src/operations/blas3/gemm_interleaved.hpp +++ b/src/operations/blas3/gemm_interleaved.hpp @@ -146,6 +146,11 @@ class Gemm::value, + "Interleaved GEMM is not supported for Complex Data types"); +#endif + input_t a_; input_t b_; output_t c_; @@ -159,10 +164,9 @@ class Gemm PORTBLAS_INLINE void compute_panel(check_t boundary_check, index_t m_stride, - index_t n_stride, index_t mb_start, - index_t m_start, index_t n_start, - in_ptr_t A, in_ptr_t B, out_ptr_t C) { + index_t n_stride, index_t mb_start, + index_t m_start, index_t n_start, + in_ptr_t A, in_ptr_t B, out_ptr_t C) { packet_type reg_a[item_rows * item_batchs / VectorSize]; packet_type reg_b[item_cols * item_batchs / VectorSize]; packet_type reg_res[item_rows * item_cols * item_batchs / VectorSize]; @@ -482,7 +486,7 @@ class Gemm::value) || + is_sycl_scalar::value, + "Vector size should be equal to 1 for Complex Data types"); +#endif + //! @brief leading dimension of block of A in local static constexpr index_t ldsa = block_rows + nbc_a; //! @brief leading dimension of block of B in local @@ -162,8 +168,8 @@ class Gemm PORTBLAS_INLINE void eval(local_memory_t scratch_acc, - const cl::sycl::nd_item<1> &id) noexcept { + const cl::sycl::nd_item<1> &id) noexcept { index_t m = a_.get_size_row(); index_t n = b_.get_size_col(); const index_t k = a_.get_size_col(); @@ -546,9 +552,9 @@ class Gemm PORTBLAS_INLINE void store_output_block(index_t, index_t mc, index_t nc, - OutputPointerType C, index_t ldc, - element_t *reg_res, - const bool out_of_range) noexcept { + OutputPointerType C, index_t ldc, + element_t *reg_res, + const bool out_of_range) noexcept { if (out_of_range) { return; } @@ -726,9 +732,9 @@ class Gemm PORTBLAS_INLINE void compute_block_gemm(index_t, InputPointerType B, - InputPointerType A, element_t *reg_a, - element_t ®_b, - element_t *reg_res) noexcept { + InputPointerType A, element_t *reg_a, + element_t ®_b, + element_t *reg_res) noexcept { // NOTE: Adding "#pragma unroll" here reduces performance on AMD R9 // Nano. // Seems that the small reduction of arithmetic operations does @@ -781,7 +787,7 @@ class Gemm static PORTBLAS_INLINE typename std::enable_if::type sync_smem( const cl::sycl::nd_item<1> &id, index_t &ofs_sign, P &s, - Ps &... ss) noexcept { + Ps &...ss) noexcept { s += ofs_sign * o; sync_smem(id, ofs_sign, ss...); } diff --git a/src/operations/blas3/gemm_no_local_full_vec.hpp b/src/operations/blas3/gemm_no_local_full_vec.hpp index 732cc9568..df1ce6bd7 100644 --- a/src/operations/blas3/gemm_no_local_full_vec.hpp +++ b/src/operations/blas3/gemm_no_local_full_vec.hpp @@ -104,6 +104,12 @@ class Gemm(), "If vectorization is enabled item_cols must equal the packet_size"); +#ifdef BLAS_ENABLE_COMPLEX + static_assert((VectorSize == 1 && is_complex_sycl::value) || + is_sycl_scalar::value, + "Vector size should be equal to 1 for Complex Data types"); +#endif + input_t a_; input_t b_; output_t c_; diff --git a/src/operations/blas3/gemm_no_local_partial_vec.hpp b/src/operations/blas3/gemm_no_local_partial_vec.hpp index 189de963b..02a42e938 100644 --- a/src/operations/blas3/gemm_no_local_partial_vec.hpp +++ b/src/operations/blas3/gemm_no_local_partial_vec.hpp @@ -100,6 +100,12 @@ class Gemm::value) || + is_sycl_scalar::value, + "Vector size should be equal to 1 for Complex Data types"); +#endif + input_t a_; input_t b_; output_t c_; @@ -111,8 +117,8 @@ class Gemm PORTBLAS_INLINE void load(PointerType ptr, element_t *reg, const index_t &ld, - index_t index, const check_boundary &chk_boundary, - const bool out_of_range) noexcept { + index_t index, const check_boundary &chk_boundary, + const bool out_of_range) noexcept { if (out_of_range) { return; } From f7179c51a5e6edf93b623f4c70f07dc44eeb5874 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Thu, 21 Sep 2023 16:33:12 +0100 Subject: [PATCH 08/16] Fixes to amd gpu configs --- cmake/CmakeFunctionHelper.cmake | 30 +++++++++----- src/interface/blas3/backend/amd_gpu.hpp | 53 +++++++++++++------------ 2 files changed, 48 insertions(+), 35 deletions(-) diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index 0c30a48d5..8dedc9857 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -519,15 +519,27 @@ elseif(${TUNING_TARGET} STREQUAL "AMD_GPU") # need investigation set(data_list_c) set_complex_list(data_list_c "${supported_types}" "false") foreach(data ${data_list_c}) - add_gemm_configuration( - "${data}" 256 "true" "true" "true" - 64 1 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false") - add_gemm_configuration( - "${data}" 256 "false" "false" "false" - 64 1 1 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") - add_gemm_configuration( - "${data}" 256 "false" "false" "false" - 64 4 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + if (${data} STREQUAL "complex") + add_gemm_configuration( + "${data}" 256 "true" "true" "true" + 64 1 4 4 4 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false") + add_gemm_configuration( + "${data}" 256 "false" "false" "false" + 64 1 1 4 4 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + add_gemm_configuration( + "${data}" 256 "false" "false" "false" + 64 4 4 4 4 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + else() + add_gemm_configuration( + "${data}" 256 "true" "true" "true" + 64 1 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false") + add_gemm_configuration( + "${data}" 256 "false" "false" "false" + 64 1 1 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + add_gemm_configuration( + "${data}" 256 "false" "false" "false" + 64 4 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + endif() endforeach() endif() # BLAS_ENABLE_COMPLEX elseif(${TUNING_TARGET} STREQUAL "NVIDIA_GPU") diff --git a/src/interface/blas3/backend/amd_gpu.hpp b/src/interface/blas3/backend/amd_gpu.hpp index 3ec103620..a425b2f2a 100644 --- a/src/interface/blas3/backend/amd_gpu.hpp +++ b/src/interface/blas3/backend/amd_gpu.hpp @@ -119,29 +119,29 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, } } else #endif // GEMM_TALL_SKINNY_SUPPORT - if (_M * _N <= 65536) { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 256, false, false, false, - ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 1, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, - _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, - batch_size, _dependencies); - } else { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 256, false, false, false, - ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 2, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, - _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, - batch_size, _dependencies); - } + if (_M * _N <= 65536) { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, false, false, + ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 1, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } else { + return blas::Gemm_Launcher< + container_0_t, container_1_t, container_2_t, 256, false, false, false, + ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, + static_cast(gemm_memory_t::local), + static_cast(gemm_algorithm_t::standard), + static_cast(gemm_vectorization_t::full), is_beta_zero, 2, + static_cast(gemm_batch_type_t::strided)>:: + template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, + _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, + _stridec, batch_size, _dependencies); + } } // Complex Configurations @@ -158,12 +158,13 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { static constexpr int ClSize = 64; + static constexpr int tileWgSize = ClSize / sizeof(element_t); /* Tall & Skinny matrices. */ #ifdef GEMM_TALL_SKINNY_SUPPORT if (batch_size == 1 && (_M / _N > 8 || _N / _M > 8) && (!s_a && !s_b)) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, true, true, true, - ClSize, Tile<1, 4, 8, 8>, _t_a, _t_b, s_a, s_b, + ClSize, Tile<1, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::tall_skinny), static_cast(gemm_vectorization_t::none), is_beta_zero, 1, @@ -176,7 +177,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, if (_M * _N <= 65536) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, false, false, - ClSize, Tile<1, 1, 8, 8>, _t_a, _t_b, s_a, s_b, + ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, @@ -187,7 +188,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, } else { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, false, false, - ClSize, Tile<4, 4, 8, 8>, _t_a, _t_b, s_a, s_b, + ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, From 9ba82fe8190f06ea443c8fe7dbd16abba4e2ec85 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Mon, 25 Sep 2023 10:07:42 +0100 Subject: [PATCH 09/16] Addressed PR comments --- common/include/common/common_utils.hpp | 16 +++++----- include/blas_meta.h | 10 ++----- test/blas_test.hpp | 8 ++--- test/unittest/blas3/blas3_gemm_common.hpp | 36 +++++++++++------------ 4 files changed, 33 insertions(+), 37 deletions(-) diff --git a/common/include/common/common_utils.hpp b/common/include/common/common_utils.hpp index df14ac062..26916483b 100644 --- a/common/include/common/common_utils.hpp +++ b/common/include/common/common_utils.hpp @@ -1374,27 +1374,27 @@ static inline std::vector random_data(size_t size) { #ifdef BLAS_ENABLE_COMPLEX template -static inline complex_std random_scalar() { +static inline std::complex random_scalar() { scalar_t rl = 1e-3 * ((rand() % 2000) - 1000); scalar_t im = 1e-3 * ((rand() % 2000) - 1000); - return complex_std({rl, im}); + return std::complex(rl, im); } template -static inline complex_std random_scalar(scalar_t rangeMin, - scalar_t rangeMax) { +static inline std::complex random_scalar(scalar_t rangeMin, + scalar_t rangeMax) { static std::random_device rd; static std::default_random_engine gen(rd()); std::uniform_real_distribution disRl(rangeMin, rangeMax); std::uniform_real_distribution disIm(rangeMin, rangeMax); - return complex_std({disRl(gen), disIm(gen)}); + return std::complex(disRl(gen), disIm(gen)); } template -static inline std::vector> random_data(size_t size) { - std::vector> v = - std::vector>(size); +static inline std::vector> random_data(size_t size) { + std::vector> v = + std::vector>(size); for (scalar_t& e : v) { e = random_scalar(scalar_t{-2}, scalar_t{5}); diff --git a/include/blas_meta.h b/include/blas_meta.h index a7634dbca..d39a395f5 100644 --- a/include/blas_meta.h +++ b/include/blas_meta.h @@ -167,7 +167,7 @@ int append_vector(vector_t &lhs_vector, vector_t const &rhs_vector) { template first_vector_t concatenate_vectors(first_vector_t first_vector, - other_vector_t &&... other_vectors) { + other_vector_t &&...other_vectors) { int first_Vector_size = static_cast(first_vector.size()); int s[] = {vec_total_size(first_Vector_size, other_vectors)..., 0}; first_vector.reserve(first_Vector_size); @@ -206,15 +206,11 @@ struct is_complex_sycl std::is_same_v> || std::is_same_v>> {}; -// STD Complex type alias -template -using complex_std = typename std::complex; - template struct is_complex_std : std::integral_constant> || - std::is_same_v>> {}; + std::is_same_v> || + std::is_same_v>> {}; #endif diff --git a/test/blas_test.hpp b/test/blas_test.hpp index 70a32d61c..d159109db 100644 --- a/test/blas_test.hpp +++ b/test/blas_test.hpp @@ -161,9 +161,9 @@ template static inline void fill_random_with_range( std::vector> &vec, scalar_t rangeMin, scalar_t rangeMax) { - for (complex_std &e : vec) { - e = complex_std{random_scalar(rangeMin, rangeMax), - random_scalar(rangeMin, rangeMax)}; + for (std::complex &e : vec) { + e = std::complex{random_scalar(rangeMin, rangeMax), + random_scalar(rangeMin, rangeMax)}; } } @@ -172,7 +172,7 @@ static inline void fill_random_with_range( * uniform distribution. */ template -static inline void fill_random(std::vector> &vec) { +static inline void fill_random(std::vector> &vec) { fill_random_with_range(vec, scalar_t{-2}, scalar_t{5}); } #endif diff --git a/test/unittest/blas3/blas3_gemm_common.hpp b/test/unittest/blas3/blas3_gemm_common.hpp index d28baa99a..3aacf4244 100644 --- a/test/unittest/blas3/blas3_gemm_common.hpp +++ b/test/unittest/blas3/blas3_gemm_common.hpp @@ -410,8 +410,8 @@ inline void verify_gemm(const gemm_cplx_arguments_t arguments) { index_t k; char transa; char transb; - complex_std alpha; - complex_std beta; + std::complex alpha; + std::complex beta; index_t lda_mul; index_t ldb_mul; index_t ldc_mul; @@ -437,14 +437,14 @@ inline void verify_gemm(const gemm_cplx_arguments_t arguments) { const index_t buffer_size_b = batch * size_b + offset; const index_t buffer_size_c = batch * size_c + offset; - std::vector> a_m(buffer_size_a); - std::vector> b_m(buffer_size_b); - std::vector> c_m_gpu(buffer_size_c); + std::vector> a_m(buffer_size_a); + std::vector> b_m(buffer_size_b); + std::vector> c_m_gpu(buffer_size_c); fill_random(a_m); fill_random(b_m); fill_random(c_m_gpu); - std::vector> c_m_cpu = c_m_gpu; + std::vector> c_m_cpu = c_m_gpu; // Use system blas to create a reference output for (int i = 0; i < batch; ++i) { @@ -518,8 +518,8 @@ inline void verify_gemm(const gemm_cplx_arguments_t arguments) { index_t k; char transa; char transb; - complex_std alpha; - complex_std beta; + std::complex alpha; + std::complex beta; index_t lda_mul; index_t ldb_mul; index_t ldc_mul; @@ -544,7 +544,7 @@ static std::string generate_cplx_name( std::string alloc; int offset, batch, m, n, k, ldaMul, ldbMul, ldcMul; char transa, transb; - complex_std alpha, beta; + std::complex alpha, beta; gemm_batch_type_t batchType; BLAS_GENERATE_NAME(info.param, alloc, offset, batch, m, n, k, transa, transb, alpha, beta, ldaMul, ldbMul, ldcMul, batchType); @@ -561,8 +561,8 @@ inline void verify_gemm( index_t k; char transa; char transb; - complex_std alpha; - complex_std beta; + std::complex alpha; + std::complex beta; index_t lda_mul; index_t ldb_mul; index_t ldc_mul; @@ -595,14 +595,14 @@ inline void verify_gemm( const index_t buffer_size_b = size_b + (batch - 1) * stride_b + offset; const index_t buffer_size_c = size_c + (batch - 1) * stride_c + offset; - std::vector> a_m(buffer_size_a); - std::vector> b_m(buffer_size_b); - std::vector> c_m_gpu(buffer_size_c); + std::vector> a_m(buffer_size_a); + std::vector> b_m(buffer_size_b); + std::vector> c_m_gpu(buffer_size_c); fill_random(a_m); fill_random(b_m); fill_random(c_m_gpu); - std::vector> c_m_cpu = c_m_gpu; + std::vector> c_m_cpu = c_m_gpu; // Use system blas to create a reference output for (int i = 0; i < batch; ++i) { @@ -668,8 +668,8 @@ inline void verify_gemm( index_t k; char transa; char transb; - complex_std alpha; - complex_std beta; + std::complex alpha; + std::complex beta; index_t lda_mul; index_t ldb_mul; index_t ldc_mul; @@ -697,7 +697,7 @@ static std::string generate_cplx_batched_strided_name( int offset, batch, m, n, k, ldaMul, ldbMul, ldcMul, stride_a_mul, stride_b_mul, stride_c_mul; char transa, transb; - complex_std alpha, beta; + std::complex alpha, beta; BLAS_GENERATE_NAME(info.param, alloc, offset, batch, m, n, k, transa, transb, alpha, beta, ldaMul, ldbMul, ldcMul, stride_a_mul, stride_b_mul, stride_c_mul); From 007727c24cf6d88a64d0f79f429652aca3027984 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Mon, 25 Sep 2023 20:24:12 +0100 Subject: [PATCH 10/16] minor fixes --- common/include/common/common_utils.hpp | 31 -------------------------- 1 file changed, 31 deletions(-) diff --git a/common/include/common/common_utils.hpp b/common/include/common/common_utils.hpp index 26916483b..a569ed2ff 100644 --- a/common/include/common/common_utils.hpp +++ b/common/include/common/common_utils.hpp @@ -1372,37 +1372,6 @@ static inline std::vector random_data(size_t size) { return v; } -#ifdef BLAS_ENABLE_COMPLEX -template -static inline std::complex random_scalar() { - scalar_t rl = 1e-3 * ((rand() % 2000) - 1000); - scalar_t im = 1e-3 * ((rand() % 2000) - 1000); - return std::complex(rl, im); -} - -template -static inline std::complex random_scalar(scalar_t rangeMin, - scalar_t rangeMax) { - static std::random_device rd; - static std::default_random_engine gen(rd()); - std::uniform_real_distribution disRl(rangeMin, rangeMax); - std::uniform_real_distribution disIm(rangeMin, rangeMax); - - return std::complex(disRl(gen), disIm(gen)); -} - -template -static inline std::vector> random_data(size_t size) { - std::vector> v = - std::vector>(size); - - for (scalar_t& e : v) { - e = random_scalar(scalar_t{-2}, scalar_t{5}); - } - return v; -} -#endif - /** * @breif Fills a lower or upper triangular matrix suitable for TRSM testing * @param A The matrix to fill. Size must be at least m * lda From 7f76dfd2ad3b081512758b9f2fccc52496efe9c5 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Wed, 27 Sep 2023 22:28:22 +0100 Subject: [PATCH 11/16] fixed bug in cmake & added readme description to complex --- CMakeLists.txt | 3 ++- README.md | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 1037b1098..09785078f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -106,6 +106,7 @@ if(IMGDNN_DIR) endif() option(BLAS_ENABLE_EXTENSIONS "Whether to enable portBLAS extensions" ON) +option(BLAS_ENABLE_COMPLEX "Whether to enable complex data type for supported operators" ON) # CmakeFunctionHelper has to be included after any options that it depends on are declared. # These include: @@ -115,6 +116,7 @@ option(BLAS_ENABLE_EXTENSIONS "Whether to enable portBLAS extensions" ON) # * BLAS_DATA_TYPES # * BLAS_INDEX_TYPES # * NAIVE_GEMM +# * BLAS_ENABLE_COMPLEX include(CmakeFunctionHelper) if (INSTALL_HEADER_ONLY) @@ -220,7 +222,6 @@ option(BUILD_CUBLAS_BENCHMARKS "Whether to build cuBLAS benchmarks" OFF) option(BUILD_ROCBLAS_BENCHMARKS "Whether to build rocBLAS benchmarks" OFF) option(BUILD_ACL_BENCHMARKS "Whether to build ARM Compute Library benchmarks" OFF) option(BLAS_BUILD_SAMPLES "Whether to build portBLAS samples" ON) -option(BLAS_ENABLE_COMPLEX "Whether to enable complex data type for supported operators" ON) if (INSTALL_HEADER_ONLY AND BLAS_ENABLE_BENCHMARK) message(STATUS "Benchmarks are disabled when installing portBLAS in header only mode") set(BLAS_ENABLE_BENCHMARK OFF) diff --git a/README.md b/README.md index 5720ae145..c5383b73f 100644 --- a/README.md +++ b/README.md @@ -463,7 +463,7 @@ Some of the supported options are: | `BLAS_ENABLE_EXTENSIONS` | `ON`/`OFF` | Determines whether to enable portBLAS extensions (`ON` by default) | | `BLAS_DATA_TYPES` | `half;float;double` | Determines the floating-point types to instantiate BLAS operations for. Default is `float` | | `BLAS_INDEX_TYPES` | `int32_t;int64_t` | Determines the type(s) to use for `index_t` and `increment_t`. Default is `int` | - +| `BLAS_ENABLE_COMPLEX` | `ON`/`OFF` | Determines whether to enable Complex data type support *(GEMM Kernels only)* (`ON` by default) | ### Cross-Compile (ComputeCpp Only) From 148a2eaa9ddbfc4e6fc86149396fa12998b058ed Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Mon, 2 Oct 2023 13:42:59 +0100 Subject: [PATCH 12/16] Reduced complex gemm tests cases sizes --- .../blas3/blas3_gemm_batched_test.cpp | 6 ++--- .../blas3/blas3_gemm_tall_skinny_test.cpp | 6 ++--- test/unittest/blas3/blas3_gemm_test.cpp | 24 +++++++++---------- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/test/unittest/blas3/blas3_gemm_batched_test.cpp b/test/unittest/blas3/blas3_gemm_batched_test.cpp index 6794ff56c..824bf656b 100644 --- a/test/unittest/blas3/blas3_gemm_batched_test.cpp +++ b/test/unittest/blas3/blas3_gemm_batched_test.cpp @@ -151,7 +151,7 @@ template const auto CplxBetaNonZeroLDMatch = ::testing::Combine( ::testing::Values("usm", "buf"), // allocation type ::testing::Values(0), // offset - ::testing::Values(5), // batch + ::testing::Values(3), // batch ::testing::Values(63, 128), // m ::testing::Values(63, 128), // n ::testing::Values(63, 128), // k @@ -170,7 +170,7 @@ template const auto CplxDefaultGemmAndGemmBatched = ::testing::Combine( ::testing::Values("usm", "buf"), // allocation type ::testing::Values(0), // offset - ::testing::Values(1, 5), // batch + ::testing::Values(1, 4), // batch ::testing::Values(63, 128), // m ::testing::Values(63, 128), // n ::testing::Values(63, 128), // k @@ -192,7 +192,7 @@ template const auto CplxAllStridedBatched = ::testing::Combine( ::testing::Values("usm", "buf"), // allocation type ::testing::Values(0), // offset - ::testing::Values(5), // batch + ::testing::Values(3), // batch ::testing::Values(128), // m ::testing::Values(128), // n ::testing::Values(128), // k diff --git a/test/unittest/blas3/blas3_gemm_tall_skinny_test.cpp b/test/unittest/blas3/blas3_gemm_tall_skinny_test.cpp index 4eeee3cde..95abb271a 100644 --- a/test/unittest/blas3/blas3_gemm_tall_skinny_test.cpp +++ b/test/unittest/blas3/blas3_gemm_tall_skinny_test.cpp @@ -127,9 +127,9 @@ const auto CplxBetaNonZeroLDMultiplied = ::testing::Combine( ::testing::Values("usm", "buf"), // allocation type ::testing::Values(0), // offset ::testing::Values(1), // batch - ::testing::Values(7, 65), // m - ::testing::Values(9, 126), // n - ::testing::Values(2049), // k + ::testing::Values(7, 33), // m + ::testing::Values(9, 63), // n + ::testing::Values(1026), // k ::testing::Values('n', 't'), // transa ::testing::Values('n', 't'), // transb ::testing::Values>({1.5, 0.5}), // alpha diff --git a/test/unittest/blas3/blas3_gemm_test.cpp b/test/unittest/blas3/blas3_gemm_test.cpp index acf4c85d8..f7cae4630 100644 --- a/test/unittest/blas3/blas3_gemm_test.cpp +++ b/test/unittest/blas3/blas3_gemm_test.cpp @@ -146,8 +146,8 @@ const auto CplxSmallBetaNonZeroLDMatch = ::testing::Combine( ::testing::Values("usm", "buf"), // allocation type ::testing::Values(0), // offset ::testing::Values(1), // batch - ::testing::Values(11, 16, 32), // m - ::testing::Values(11, 16, 32), // n + ::testing::Values(11, 33), // m + ::testing::Values(11, 33), // n ::testing::Values(16, 17), // k ::testing::Values('n', 't'), // transa ::testing::Values('n', 't'), // transb @@ -171,7 +171,7 @@ const auto CplxSmallBetaZeroLDMatch = ::testing::Combine( ::testing::Values('n', 't'), // transa ::testing::Values('n', 't'), // transb ::testing::Values>({1.5, 1.0}), // alpha - ::testing::Values>({1.5, 3.0}), // beta + ::testing::Values>({0.0, 0.0}), // beta ::testing::Values(1), // lda_mul ::testing::Values(1), // ldb_mul ::testing::Values(1), // ldc_mul @@ -184,16 +184,16 @@ const auto CplxSmallBetaZeroLDMultiplied = ::testing::Combine( ::testing::Values("usm", "buf"), // allocation type ::testing::Values(0), // offset ::testing::Values(1), // batch - ::testing::Values(11, 32), // m - ::testing::Values(11, 32), // n + ::testing::Values(11, 33), // m + ::testing::Values(11, 33), // n ::testing::Values(17), // k ::testing::Values('n', 't'), // transa ::testing::Values('n', 't'), // transb ::testing::Values>({1.5, 3.0}), // alpha ::testing::Values>({0.0, 0.0}), // beta ::testing::Values(2), // lda_mul - ::testing::Values(3), // ldb_mul - ::testing::Values(4), // ldc_mul + ::testing::Values(2), // ldb_mul + ::testing::Values(3), // ldc_mul ::testing::Values(gemm_batch_type_t::strided) // batch_type ); GENERATE_CPLX_GEMM_TEST(Gemm, CplxSmallBetaZeroLDMultiplied); @@ -242,13 +242,13 @@ const auto CplxLargeBetaNonZeroLDMatch = ::testing::Combine( ::testing::Values("usm", "buf"), // allocation type ::testing::Values(0), // offset ::testing::Values(1), // batch - ::testing::Values(253, 511), // m - ::testing::Values(257, 511), // n - ::testing::Values(253, 511), // k + ::testing::Values(63, 253), // m + ::testing::Values(63, 253), // n + ::testing::Values(63, 253), // k ::testing::Values('n', 't'), // transa ::testing::Values('n', 't'), // transb - ::testing::Values>({1.0, 1.0}), // alpha - ::testing::Values>({1.0, 1.0}), // beta + ::testing::Values>({1.0, 1.5}), // alpha + ::testing::Values>({1.5, 1.0}), // beta ::testing::Values(1), // lda_mul ::testing::Values(1), // ldb_mul ::testing::Values(1), // ldc_mul From 49e0e01594d510755af1171a46137e5562c707d7 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Wed, 4 Oct 2023 11:56:25 +0100 Subject: [PATCH 13/16] removed unused legacy complex data utils --- include/operations/blas_constants.h | 8 -------- src/operations/blas1_trees.hpp | 18 ------------------ 2 files changed, 26 deletions(-) diff --git a/include/operations/blas_constants.h b/include/operations/blas_constants.h index 5fc4afb82..637f23f95 100644 --- a/include/operations/blas_constants.h +++ b/include/operations/blas_constants.h @@ -202,14 +202,6 @@ struct constant, const_val::collapse> { } }; -template -struct constant, Indicator> { - constexpr static PORTBLAS_INLINE std::complex value() { - return std::complex(constant::value(), - constant::value()); - } -}; - #ifdef BLAS_ENABLE_COMPLEX template struct constant, Indicator> { diff --git a/src/operations/blas1_trees.hpp b/src/operations/blas1_trees.hpp index ff51a7915..1b079c98b 100644 --- a/src/operations/blas1_trees.hpp +++ b/src/operations/blas1_trees.hpp @@ -90,24 +90,6 @@ struct DetectScalar { }; #endif // BLAS_DATA_TYPE_HALF -/*! DetectScalar. - * @brief See Detect Scalar. - */ -template <> -struct DetectScalar> { - using element_t = std::complex; - static element_t get_scalar(element_t &scalar) { return scalar; } -}; - -/*! DetectScalar. - * @brief See Detect Scalar. - */ -template <> -struct DetectScalar> { - using element_t = std::complex; - static element_t get_scalar(element_t &scalar) { return scalar; } -}; - #ifdef BLAS_ENABLE_COMPLEX /*! DetectScalar (for sycl::complex) * @brief See Detect Scalar. From d86cfc8598101f657890f6798f69637fef60be9c Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Wed, 4 Oct 2023 13:20:05 +0100 Subject: [PATCH 14/16] Tuned gemm complex for cpu --- cmake/CmakeFunctionHelper.cmake | 7 ++----- src/interface/blas3/backend/default_cpu.hpp | 19 ++++--------------- 2 files changed, 6 insertions(+), 20 deletions(-) diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index 8dedc9857..ef7fc22d6 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -629,13 +629,10 @@ else() # default cpu backend foreach(data ${data_list_c}) add_gemm_configuration( "${data}" 64 "false" "false" "false" - 64 2 2 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "full" 1 "strided" "false" "false") + 64 2 2 4 4 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "full" 1 "strided" "false" "false") add_gemm_configuration( "${data}" 64 "false" "false" "false" - 64 8 8 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false" "false") - add_gemm_configuration( - "${data}" 64 "false" "false" "false" - 64 2 2 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false" "false") + 64 8 8 4 4 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false" "false") endforeach() endif() # BLAS_ENABLE_COMPLEX endif() diff --git a/src/interface/blas3/backend/default_cpu.hpp b/src/interface/blas3/backend/default_cpu.hpp index 1b7dfd680..14c0cd337 100644 --- a/src/interface/blas3/backend/default_cpu.hpp +++ b/src/interface/blas3/backend/default_cpu.hpp @@ -116,10 +116,10 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, container_2_t _c, index_t _ldc, index_t _stridec, index_t batch_size, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { - if (_M <= 128 && _N <= 128 && _K <= 128 && !s_a && !s_b) { + if (_M <= 256 && _N <= 256 && _K <= 256) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b, + 64, Tile<2, 2, 4, 4>, _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::no_local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, @@ -127,10 +127,10 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); - } else if (!s_a && !s_b) { + } else { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b, + 64, Tile<8, 8, 4, 4>, _t_a, _t_b, s_a, s_b, static_cast(gemm_memory_t::no_local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::partial), is_beta_zero, 1, @@ -138,17 +138,6 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); - } else { - return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<2, 2, 8, 8>, _t_a, _t_b, s_a, s_b, - static_cast(gemm_memory_t::local), - static_cast(gemm_algorithm_t::standard), - static_cast(gemm_vectorization_t::full), is_beta_zero, 1, - static_cast(gemm_batch_type_t::strided)>:: - template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, - _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, - batch_size, _dependencies); } } #endif From 2dc363db653413af39acd416f09d57c101de3004 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Fri, 13 Oct 2023 17:06:21 +0100 Subject: [PATCH 15/16] Separated complex gemm load store & addressed PR comments --- doc/Gemm.md | 2 +- src/operations/blas3/gemm_load_store.hpp | 144 --------------- .../blas3/gemm_load_store_complex.hpp | 174 ++++++++++++++++++ src/operations/blas3/gemm_local.hpp | 3 + .../blas3/gemm_no_local_full_vec.hpp | 3 + .../blas3/gemm_no_local_partial_vec.hpp | 3 + test/unittest/blas3/blas3_gemm_common.hpp | 10 +- 7 files changed, 189 insertions(+), 150 deletions(-) create mode 100644 src/operations/blas3/gemm_load_store_complex.hpp diff --git a/doc/Gemm.md b/doc/Gemm.md index 0264e3d4c..653549212 100644 --- a/doc/Gemm.md +++ b/doc/Gemm.md @@ -100,7 +100,7 @@ The core of the `GEMM` computation is as follows: ## Vectorized Loading/Storing -Many of the `GEMM` kernels support vectorized loads/stores using functions located in `gemm_load_store.hpp` in `src/operations/blas3/` . +Many of the `GEMM` kernels support vectorized loads/stores using functions located in `gemm_load_store.hpp` in `src/operations/blas3/`*(this feature is limited to non-complex data types)*. These functions are pretty simple but there are some special considerations for how they are used, particularly around whether the matrices are transposed or not. If a matrix is transposed this changes the data layout such that elements are no longer contiguous in memory. diff --git a/src/operations/blas3/gemm_load_store.hpp b/src/operations/blas3/gemm_load_store.hpp index 7ae45ce5d..ef44cbfe6 100644 --- a/src/operations/blas3/gemm_load_store.hpp +++ b/src/operations/blas3/gemm_load_store.hpp @@ -125,149 +125,5 @@ struct Packetize { } }; -#ifdef BLAS_ENABLE_COMPLEX -/*! @brief vec_complex is an intermediate wrapper of sycl::complex used in - * Packetize. It serves as a temporary workaround to the upcoming - * sycl::vec container - * github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_complex.asciidoc - * and only supports size = 1. - * @tparam DataT Complex type of the vector's data - * @tparam NumElements Elements count of the vector (only 1 is supported) - */ -template -class vec_complex { - static_assert(NumElements == 1, - "Vector wrapper arround sycl::complex of size>1 unsupported."); - using address_t = cl::sycl::access::address_space; - using decorated_t = cl::sycl::access::decorated; - using DataType = DataT; - static constexpr int getNumElements() { return NumElements; } - size_t size() const noexcept { return NumElements; } - - private: - DataType m_Data; - - public: - vec_complex() = default; - - constexpr vec_complex(const vec_complex &rhs) = default; - constexpr vec_complex(vec_complex &&rhs) = default; - constexpr vec_complex &operator=(const vec_complex &rhs) = default; - - vec_complex(const DataType &rhs_data) : m_Data{rhs_data} {} - - // Conversion operator (valid with NumElements==1) - operator DataT() const { return m_Data; } - - // Subscript operators - DataT &operator[](int i) { - assert(i < NumElements); - return (m_Data); - } - const DataT &operator[](int i) const { - assert(i < NumElements); - return (m_Data); - } - - // Binary Ops - // Multiply - vec_complex operator*(const vec_complex &rhs) { - return (vec_complex{m_Data * static_cast(rhs)}); - } - - vec_complex operator*(const DataType &rhs) { - return (vec_complex{m_Data * rhs}); - } - - // Compound Multiply - vec_complex &operator*=(const DataType &rhs) { - this->m_Data = this->m_Data * rhs; - return (*this); - } - - vec_complex &operator*=(const vec_complex &rhs) { - this->m_Data = this->m_Data * static_cast(rhs); - return (*this); - } - - // Add - vec_complex operator+(const vec_complex &rhs) { - return (vec_complex{m_Data + static_cast(rhs)}); - } - - vec_complex operator+(const DataType &rhs) { - return (vec_complex{m_Data + rhs}); - } - - // Compound Add - vec_complex &operator+=(const DataType &rhs) { - this->m_Data = this->m_Data * rhs; - return (*this); - } - - vec_complex &operator+=(const vec_complex &rhs) { - this->m_Data = this->m_Data + static_cast(rhs); - return (*this); - } - - // Load - template - void load(size_t Offset, - cl::sycl::multi_ptr Ptr) { - m_Data = *(Ptr + Offset * NumElements); - } - - // Store - template - void store(size_t Offset, - cl::sycl::multi_ptr Ptr) const { - *(Ptr + Offset * NumElements) = m_Data; - } -}; - -/*! @brief Partial specialization of the Packetize class dedicated to -sycl::complex types. It contains static methods for loading and storing size=1 -complex packets from/to memory. -* @tparam vector_size The desired vector size to be used. Only size = 1 is -supported so far. -* @tparam value_t The complex type of the matrix data. -*/ -template -struct Packetize, index_t> { - // Vectorization is not enabled for complex, always set to 1 - using value_t = complex_sycl; - using PacketType = vec_complex; - static constexpr int packet_size = 1; - template - static PORTBLAS_INLINE constexpr bool check_size() { - return true; - } - - /*! @brief Performs a non-vectorised load of sycl::complex data element while - * whether block is internal or not since vectorization is not enabled for - * complex types yet. - * @tparam trans Whether the source matrix is transposed or not. - * @tparam internal True if the current block is internal and no bounds - * checking is required. - * @tparam ld The leading dimension of the destination memory. */ - template - static PORTBLAS_INLINE void load(const bool in_range, SrcPointerType src, - DestPointerType dest, - EdgePredicate edge_in_range) { - *(dest) = in_range ? *(src) : value_t{(T)0, (T)0}; - } - - /*! @brief Store a size = 1 vector packet of sycl::complex data into local - * memory (whether source is transposed or not since it's only 1 element). - * @tparam trans Whether the source matrix is transposed or not. - * @tparam ld The leading dimension of the destination memory.*/ - template - static PORTBLAS_INLINE void store(PacketType &packet, DestPointerType dest) { - *dest = packet[0]; - } -}; -#endif - } // namespace blas #endif // PORTBLAS_BLAS3_GEMM_LOAD_STORE_HPP diff --git a/src/operations/blas3/gemm_load_store_complex.hpp b/src/operations/blas3/gemm_load_store_complex.hpp new file mode 100644 index 000000000..7b1eb769b --- /dev/null +++ b/src/operations/blas3/gemm_load_store_complex.hpp @@ -0,0 +1,174 @@ +/*************************************************************************** + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename gemm_load_store_complex.hpp + * + **************************************************************************/ + +#ifndef PORTBLAS_BLAS3_GEMM_LOAD_STORE_CPLX_HPP +#define PORTBLAS_BLAS3_GEMM_LOAD_STORE_CPLX_HPP + +namespace blas { +#ifdef BLAS_ENABLE_COMPLEX +/*! @brief vec_complex is an intermediate wrapper of sycl::complex used in + * Packetize. It serves as a temporary workaround to the upcoming + * sycl::vec container + * github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_complex.asciidoc + * and only supports size = 1. + * @tparam DataT Complex type of the vector's data + * @tparam NumElements Elements count of the vector (only 1 is supported) + */ +template +class vec_complex { + static_assert(NumElements == 1, + "Vector wrapper arround sycl::complex of size>1 unsupported."); + using address_t = cl::sycl::access::address_space; + using decorated_t = cl::sycl::access::decorated; + using DataType = DataT; + static constexpr int getNumElements() { return NumElements; } + size_t size() const noexcept { return NumElements; } + + private: + DataType m_Data; + + public: + vec_complex() = default; + + constexpr vec_complex(const vec_complex &rhs) = default; + constexpr vec_complex(vec_complex &&rhs) = default; + constexpr vec_complex &operator=(const vec_complex &rhs) = default; + + vec_complex(const DataType &rhs_data) : m_Data{rhs_data} {} + + // Conversion operator (valid with NumElements==1) + operator DataT() const { return m_Data; } + + // Subscript operators + DataT &operator[](int i) { + assert(i < NumElements); + return (m_Data); + } + const DataT &operator[](int i) const { + assert(i < NumElements); + return (m_Data); + } + + // Binary Ops + // Multiply + vec_complex operator*(const vec_complex &rhs) { + return (vec_complex{m_Data * static_cast(rhs)}); + } + + vec_complex operator*(const DataType &rhs) { + return (vec_complex{m_Data * rhs}); + } + + // Compound Multiply + vec_complex &operator*=(const DataType &rhs) { + this->m_Data = this->m_Data * rhs; + return (*this); + } + + vec_complex &operator*=(const vec_complex &rhs) { + this->m_Data = this->m_Data * static_cast(rhs); + return (*this); + } + + // Add + vec_complex operator+(const vec_complex &rhs) { + return (vec_complex{m_Data + static_cast(rhs)}); + } + + vec_complex operator+(const DataType &rhs) { + return (vec_complex{m_Data + rhs}); + } + + // Compound Add + vec_complex &operator+=(const DataType &rhs) { + this->m_Data = this->m_Data * rhs; + return (*this); + } + + vec_complex &operator+=(const vec_complex &rhs) { + this->m_Data = this->m_Data + static_cast(rhs); + return (*this); + } + + // Load + template + void load(size_t Offset, + cl::sycl::multi_ptr Ptr) { + m_Data = *(Ptr + Offset * NumElements); + } + + // Store + template + void store(size_t Offset, + cl::sycl::multi_ptr Ptr) const { + *(Ptr + Offset * NumElements) = m_Data; + } +}; + +/*! @brief Partial specialization of the Packetize class dedicated to +sycl::complex types. It contains static methods for loading and storing size=1 +complex packets from/to memory. +* @tparam vector_size The desired vector size to be used. Only size = 1 is +supported so far. +* @tparam value_t The complex type of the matrix data. +*/ +template +struct Packetize, index_t> { + // Vectorization is not enabled for complex, always set to 1 + using value_t = complex_sycl; + using PacketType = vec_complex; + static constexpr int packet_size = 1; + template + static PORTBLAS_INLINE constexpr bool check_size() { + return true; + } + + /*! @brief Performs a non-vectorised load of sycl::complex data element while + * whether block is internal or not since vectorization is not enabled for + * complex types yet. + * @tparam trans Whether the source matrix is transposed or not. + * @tparam internal True if the current block is internal and no bounds + * checking is required. + * @tparam ld The leading dimension of the destination memory. */ + template + static PORTBLAS_INLINE void load(const bool in_range, SrcPointerType src, + DestPointerType dest, + EdgePredicate edge_in_range) { + *(dest) = in_range ? *(src) : value_t{(T)0, (T)0}; + } + + /*! @brief Store a size = 1 vector packet of sycl::complex data into local + * memory (whether source is transposed or not since it's only 1 element). + * @tparam trans Whether the source matrix is transposed or not. + * @tparam ld The leading dimension of the destination memory.*/ + template + static PORTBLAS_INLINE void store(PacketType &packet, DestPointerType dest) { + *dest = packet[0]; + } +}; +#endif +} // namespace blas + +#endif // PORTBLAS_BLAS3_GEMM_LOAD_STORE_CPLX_HPP diff --git a/src/operations/blas3/gemm_local.hpp b/src/operations/blas3/gemm_local.hpp index 870349c48..0ca182918 100644 --- a/src/operations/blas3/gemm_local.hpp +++ b/src/operations/blas3/gemm_local.hpp @@ -27,6 +27,9 @@ #include "gemm_common.hpp" #include "gemm_load_store.hpp" +#ifdef BLAS_ENABLE_COMPLEX +#include "gemm_load_store_complex.hpp" +#endif namespace blas { diff --git a/src/operations/blas3/gemm_no_local_full_vec.hpp b/src/operations/blas3/gemm_no_local_full_vec.hpp index df1ce6bd7..77cbafbbf 100644 --- a/src/operations/blas3/gemm_no_local_full_vec.hpp +++ b/src/operations/blas3/gemm_no_local_full_vec.hpp @@ -27,6 +27,9 @@ #include "gemm_common.hpp" #include "gemm_load_store.hpp" +#ifdef BLAS_ENABLE_COMPLEX +#include "gemm_load_store_complex.hpp" +#endif namespace blas { diff --git a/src/operations/blas3/gemm_no_local_partial_vec.hpp b/src/operations/blas3/gemm_no_local_partial_vec.hpp index 02a42e938..ba26ef67f 100644 --- a/src/operations/blas3/gemm_no_local_partial_vec.hpp +++ b/src/operations/blas3/gemm_no_local_partial_vec.hpp @@ -27,6 +27,9 @@ #include "gemm_common.hpp" #include "gemm_load_store.hpp" +#ifdef BLAS_ENABLE_COMPLEX +#include "gemm_load_store_complex.hpp" +#endif namespace blas { diff --git a/test/unittest/blas3/blas3_gemm_common.hpp b/test/unittest/blas3/blas3_gemm_common.hpp index 3aacf4244..b9bd04e04 100644 --- a/test/unittest/blas3/blas3_gemm_common.hpp +++ b/test/unittest/blas3/blas3_gemm_common.hpp @@ -419,6 +419,11 @@ inline void verify_gemm(const gemm_cplx_arguments_t arguments) { std::tie(alloc, offset, batch, m, n, k, transa, transb, alpha, beta, lda_mul, ldb_mul, ldc_mul, batch_type) = arguments; + if (batch > 1 && batch_type == gemm_batch_type_t::interleaved) { + // Interleaved batched gemm unsupported with complex data types + GTEST_SKIP(); + } + const char ta_str[2] = {transa, '\0'}; const char tb_str[2] = {transb, '\0'}; @@ -456,11 +461,6 @@ inline void verify_gemm(const gemm_cplx_arguments_t arguments) { reinterpret_cast(c_m_cpu.data() + i * size_c + offset), ldc); } - if (batch > 1 && batch_type == gemm_batch_type_t::interleaved) { - // Interleaved batched gemm unsupported - GTEST_SKIP(); - } - auto m_a_gpu = blas::helper::allocate>( buffer_size_a, q); auto m_b_gpu = blas::helper::allocate>( From 6a0e010b7011bcd810a038c15ed3712ab76b4101 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Fri, 13 Oct 2023 17:16:24 +0100 Subject: [PATCH 16/16] Removed symm kernels generation from complex data types --- cmake/CmakeFunctionHelper.cmake | 7 +++++-- src/interface/blas3/backend/amd_gpu.hpp | 8 ++++---- src/interface/blas3/backend/default_cpu.hpp | 4 ++-- src/interface/blas3/backend/intel_gpu.hpp | 12 ++++++------ src/interface/blas3/backend/nvidia_gpu.hpp | 6 +++--- 5 files changed, 20 insertions(+), 17 deletions(-) diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index ef7fc22d6..2ae71bc5e 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -291,6 +291,9 @@ function(add_gemm_configuration cpp_type(cpp_data ${data}) foreach(symm_a ${boolean_list}) foreach(symm_b ${boolean_list}) + if ((${data} MATCHES "complex") AND (symm_a OR symm_b)) + continue() + endif() foreach(trans_a ${boolean_list}) foreach(trans_b ${boolean_list}) foreach(is_beta_zero ${boolean_list}) @@ -591,8 +594,8 @@ elseif(${TUNING_TARGET} STREQUAL "NVIDIA_GPU") set_complex_list(data_list_c "${supported_types}" "false") foreach(data ${data_list_c}) add_gemm_configuration( - "${data}" 64 "false" "false" "true" - 64 8 8 8 8 1 1 2 2 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") + "${data}" 256 "false" "false" "true" + 64 2 2 16 16 1 1 2 2 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") endforeach() endif() # BLAS_ENABLE_COMPLEX else() # default cpu backend diff --git a/src/interface/blas3/backend/amd_gpu.hpp b/src/interface/blas3/backend/amd_gpu.hpp index a425b2f2a..f494f25b9 100644 --- a/src/interface/blas3/backend/amd_gpu.hpp +++ b/src/interface/blas3/backend/amd_gpu.hpp @@ -161,10 +161,10 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, static constexpr int tileWgSize = ClSize / sizeof(element_t); /* Tall & Skinny matrices. */ #ifdef GEMM_TALL_SKINNY_SUPPORT - if (batch_size == 1 && (_M / _N > 8 || _N / _M > 8) && (!s_a && !s_b)) { + if (batch_size == 1 && (_M / _N > 8 || _N / _M > 8)) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, true, true, true, - ClSize, Tile<1, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, + ClSize, Tile<1, 4, tileWgSize, tileWgSize>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::tall_skinny), static_cast(gemm_vectorization_t::none), is_beta_zero, 1, @@ -177,7 +177,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, if (_M * _N <= 65536) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, false, false, - ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, + ClSize, Tile<1, 1, tileWgSize, tileWgSize>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, @@ -188,7 +188,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, } else { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, false, false, - ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, s_a, s_b, + ClSize, Tile<4, 4, tileWgSize, tileWgSize>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, diff --git a/src/interface/blas3/backend/default_cpu.hpp b/src/interface/blas3/backend/default_cpu.hpp index 14c0cd337..e62348363 100644 --- a/src/interface/blas3/backend/default_cpu.hpp +++ b/src/interface/blas3/backend/default_cpu.hpp @@ -119,7 +119,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, if (_M <= 256 && _N <= 256 && _K <= 256) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<2, 2, 4, 4>, _t_a, _t_b, s_a, s_b, + 64, Tile<2, 2, 4, 4>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::no_local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, @@ -130,7 +130,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, } else { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<8, 8, 4, 4>, _t_a, _t_b, s_a, s_b, + 64, Tile<8, 8, 4, 4>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::no_local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::partial), is_beta_zero, 1, diff --git a/src/interface/blas3/backend/intel_gpu.hpp b/src/interface/blas3/backend/intel_gpu.hpp index a0ce6f52a..8d788c9b5 100644 --- a/src/interface/blas3/backend/intel_gpu.hpp +++ b/src/interface/blas3/backend/intel_gpu.hpp @@ -222,11 +222,11 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { #ifdef GEMM_TALL_SKINNY_SUPPORT - if (!s_a && !s_b && batch_size == 1) { + if (batch_size == 1) { constexpr int wg_size = sizeof(element_t) == 16 ? 4 : 8; return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, true, true, true, 64, - Tile<4, 4, wg_size, wg_size>, _t_a, _t_b, s_a, s_b, + Tile<4, 4, wg_size, wg_size>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::tall_skinny), static_cast(gemm_vectorization_t::none), is_beta_zero, 1, @@ -239,7 +239,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, if (_M <= 128 && _N <= 128) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, true, false, false, 64, - Tile<4, 4, 8, 8>, _t_a, _t_b, s_a, s_b, + Tile<4, 4, 8, 8>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, @@ -247,10 +247,10 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, template _select_gemm(sb_handle, _M, _N, _K, _alpha, _a, _lda, _stridea, _b, _ldb, _strideb, _beta, _c, _ldc, _stridec, batch_size, _dependencies); - } else if (_t_b && !_t_a && !s_a && !s_b) { + } else if (_t_b && !_t_a) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<8, 8, 8, 8>, _t_a, _t_b, s_a, s_b, + 64, Tile<8, 8, 8, 8>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::no_local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::partial), is_beta_zero, 1, @@ -261,7 +261,7 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, } else { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 64, false, false, false, - 64, Tile<4, 8, 16, 8>, _t_a, _t_b, s_a, s_b, + 64, Tile<4, 8, 16, 8>, _t_a, _t_b, false, false, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, diff --git a/src/interface/blas3/backend/nvidia_gpu.hpp b/src/interface/blas3/backend/nvidia_gpu.hpp index 7d555d902..13966172e 100644 --- a/src/interface/blas3/backend/nvidia_gpu.hpp +++ b/src/interface/blas3/backend/nvidia_gpu.hpp @@ -183,9 +183,9 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K, gemm_batch_type_t batch_type, const typename sb_handle_t::event_t& _dependencies) { return blas::Gemm_Launcher< - container_0_t, container_1_t, container_2_t, 64, false, false, true, 64, - Tile<8, 8, 8, 8, 1, 1, 2, 2, 1, 1, 1, 1, 1, float, float>, _t_a, _t_b, - s_a, s_b, static_cast(gemm_memory_t::local), + container_0_t, container_1_t, container_2_t, 256, false, false, true, 64, + Tile<2, 2, 16, 16, 1, 1, 2, 2, 1, 1, 1, 1, 1, float, float>, _t_a, _t_b, + false, false, static_cast(gemm_memory_t::local), static_cast(gemm_algorithm_t::standard), static_cast(gemm_vectorization_t::full), is_beta_zero, 1, static_cast(gemm_batch_type_t::strided),