diff --git a/cmake/CmakeFunctionHelper.cmake b/cmake/CmakeFunctionHelper.cmake index ddd1e61f3..1be6ab8a8 100644 --- a/cmake/CmakeFunctionHelper.cmake +++ b/cmake/CmakeFunctionHelper.cmake @@ -457,23 +457,19 @@ elseif(${TUNING_TARGET} STREQUAL "NVIDIA_GPU") string(SUBSTRING ${DPCPP_SYCL_ARCH} ${start_idx} "2" sm_val) endif() endif() + # Joint Matrix specific GEMM configurations (only for float) + if(${start_idx} AND ${sm_val} GREATER_EQUAL "80") + add_gemm_configuration( + "float" 128 "false" "true" "true" + 128 2 4 16 8 16 2 1 1 1 1 16 16 16 cl::sycl::half float "local" "standard" "none" 1 "strided" "true") + add_gemm_configuration( + "float" 128 "false" "true" "true" + 128 4 8 16 8 16 2 1 1 1 1 16 16 16 cl::sycl::half float "local" "standard" "none" 1 "strided" "true") + add_gemm_configuration( + "float" 256 "false" "true" "true" + 128 8 8 16 16 16 2 1 1 1 1 16 16 16 cl::sycl::half float "local" "standard" "none" 1 "strided" "true") + endif() foreach(data ${supported_types}) - # Joint Matrix specific GEMM configurations (only for float) - if(${start_idx} AND ${sm_val} GREATER_EQUAL "80") - add_gemm_configuration( - "float" 128 "false" "true" "true" - 128 2 4 16 8 16 2 1 1 1 1 16 16 16 cl::sycl::half float "local" "standard" "none" 1 "strided" "true") - add_gemm_configuration( - "float" 128 "false" "true" "true" - 128 4 8 16 8 16 2 1 1 1 1 16 16 16 cl::sycl::half float "local" "standard" "none" 1 "strided" "true") - add_gemm_configuration( - "float" 256 "false" "true" "true" - 128 8 8 16 16 16 2 1 1 1 1 16 16 16 cl::sycl::half float "local" "standard" "none" 1 "strided" "true") - add_gemm_configuration( - "${data}" 64 "false" "false" "true" - 64 8 8 8 8 1 1 2 2 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false") - endif() - # Non-Joint Matrix specific GEMM Configurations add_gemm_configuration( "${data}" 64 "false" "false" "false" diff --git a/src/interface/blas3/backend/nvidia_gpu.hpp b/src/interface/blas3/backend/nvidia_gpu.hpp index f630bf761..aeb678704 100644 --- a/src/interface/blas3/backend/nvidia_gpu.hpp +++ b/src/interface/blas3/backend/nvidia_gpu.hpp @@ -55,7 +55,10 @@ typename sb_handle_t::event_t _gemm( #ifdef SB_ENABLE_JOINT_MATRIX const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX"); - if (en_joint_matrix != NULL && *en_joint_matrix == '1' && !s_a && !s_b) { + if (en_joint_matrix != NULL && *en_joint_matrix == '1' && !s_a && !s_b && + std::is_same::type, float>::value && + std::is_same::type, float>::value && + std::is_same::type, float>::value) { if (_M > 1024 && _N > 1024) { return blas::Gemm_Launcher< container_0_t, container_1_t, container_2_t, 256, false, true, true,