Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Commit

Permalink
Fix errors for double datatype + sm_80
Browse files — browse the repository at this point in the history
  • Loading branch information
pgorlani committed Oct 20, 2023
1 parent eae0f8b commit c3915fb
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 17 deletions.
28 changes: 12 additions & 16 deletions cmake/CmakeFunctionHelper.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -457,23 +457,19 @@ elseif(${TUNING_TARGET} STREQUAL "NVIDIA_GPU")
 string(SUBSTRING ${DPCPP_SYCL_ARCH} ${start_idx} "2" sm_val)
 endif()
 endif()
+# Joint Matrix specific GEMM configurations (only for float)
+if(${start_idx} AND ${sm_val} GREATER_EQUAL "80")
+add_gemm_configuration(
+"float" 128 "false" "true" "true"
+128 2 4 16 8 16 2 1 1 1 1 16 16 16 cl::sycl::half float "local" "standard" "none" 1 "strided" "true")
+add_gemm_configuration(
+"float" 128 "false" "true" "true"
+128 4 8 16 8 16 2 1 1 1 1 16 16 16 cl::sycl::half float "local" "standard" "none" 1 "strided" "true")
+add_gemm_configuration(
+"float" 256 "false" "true" "true"
+128 8 8 16 16 16 2 1 1 1 1 16 16 16 cl::sycl::half float "local" "standard" "none" 1 "strided" "true")
+endif()
 foreach(data ${supported_types})
-# Joint Matrix specific GEMM configurations (only for float)
-if(${start_idx} AND ${sm_val} GREATER_EQUAL "80")
-add_gemm_configuration(
-"float" 128 "false" "true" "true"
-128 2 4 16 8 16 2 1 1 1 1 16 16 16 cl::sycl::half float "local" "standard" "none" 1 "strided" "true")
-add_gemm_configuration(
-"float" 128 "false" "true" "true"
-128 4 8 16 8 16 2 1 1 1 1 16 16 16 cl::sycl::half float "local" "standard" "none" 1 "strided" "true")
-add_gemm_configuration(
-"float" 256 "false" "true" "true"
-128 8 8 16 16 16 2 1 1 1 1 16 16 16 cl::sycl::half float "local" "standard" "none" 1 "strided" "true")
-add_gemm_configuration(
-"${data}" 64 "false" "false" "true"
-64 8 8 8 8 1 1 2 2 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
-endif()
-
 # Non-Joint Matrix specific GEMM Configurations
 add_gemm_configuration(
 "${data}" 64 "false" "false" "false"
Expand Down
5 changes: 4 additions & 1 deletion src/interface/blas3/backend/nvidia_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,10 @@ typename sb_handle_t::event_t _gemm(

 #ifdef SB_ENABLE_JOINT_MATRIX
 const char* en_joint_matrix = std::getenv("SB_ENABLE_JOINT_MATRIX");
-if (en_joint_matrix != NULL && *en_joint_matrix == '1' && !s_a && !s_b) {
+if (en_joint_matrix != NULL && *en_joint_matrix == '1' && !s_a && !s_b &&
+std::is_same<typename ValueType<container_0_t>::type, float>::value &&
+std::is_same<typename ValueType<container_1_t>::type, float>::value &&
+std::is_same<typename ValueType<container_2_t>::type, float>::value) {
 if (_M > 1024 && _N > 1024) {
 return blas::Gemm_Launcher<
 container_0_t, container_1_t, container_2_t, 256, false, true, true,
Expand Down

0 comments on commit c3915fb

Please sign in to comment.