Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Commit

Permalink
De-coupling complex & scalar enable_if statements
Browse files Browse the repository at this point in the history
  • Loading branch information
OuadiElfarouki committed Sep 20, 2023
1 parent 0c8ab21 commit 4721c4f
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 40 deletions.
12 changes: 9 additions & 3 deletions cmake/CmakeFunctionHelper.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -437,9 +437,15 @@ if(${TUNING_TARGET} STREQUAL "INTEL_GPU")
add_gemm_configuration(
"${data}" 64 "false" "false" "false"
64 8 8 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false")
add_gemm_configuration(
"${data}" 32 "true" "true" "true"
64 2 1 8 4 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false")
if (${data} STREQUAL "complex<double>")
add_gemm_configuration(
"${data}" 64 "true" "true" "true"
64 4 4 4 4 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false")
else()
add_gemm_configuration(
"${data}" 64 "true" "true" "true"
64 4 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false")
endif()
endforeach()
endif() # BLAS_ENABLE_COMPLEX
elseif(${TUNING_TARGET} STREQUAL "POWER_VR" AND NOT IMGDNN_DIR)
Expand Down
6 changes: 1 addition & 5 deletions src/interface/blas3/backend/amd_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,8 @@ namespace backend {
template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename element_t, typename index_t>
#ifdef BLAS_ENABLE_COMPLEX
typename std::enable_if<!is_complex_sycl<element_t>::value,
typename std::enable_if<is_sycl_scalar<element_t>::value,
typename sb_handle_t::event_t>::type
#else
typename sb_handle_t::event_t
#endif
_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
Expand Down
6 changes: 1 addition & 5 deletions src/interface/blas3/backend/default_cpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,8 @@ namespace backend {
template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename element_t, typename index_t>
#ifdef BLAS_ENABLE_COMPLEX
typename std::enable_if<!is_complex_sycl<element_t>::value,
typename std::enable_if<is_sycl_scalar<element_t>::value,
typename sb_handle_t::event_t>::type
#else
typename sb_handle_t::event_t
#endif
_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
Expand Down
11 changes: 4 additions & 7 deletions src/interface/blas3/backend/intel_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,8 @@ namespace backend {
template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename element_t, typename index_t>
#ifdef BLAS_ENABLE_COMPLEX
typename std::enable_if<!is_complex_sycl<element_t>::value,
typename std::enable_if<is_sycl_scalar<element_t>::value,
typename sb_handle_t::event_t>::type
#else
typename sb_handle_t::event_t
#endif
_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
Expand Down Expand Up @@ -227,9 +223,10 @@ _gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
const typename sb_handle_t::event_t& _dependencies) {
#ifdef GEMM_TALL_SKINNY_SUPPORT
if (!s_a && !s_b && batch_size == 1) {
constexpr int wg_size = sizeof(element_t) == 16 ? 4 : 8;
return blas::Gemm_Launcher<
container_0_t, container_1_t, container_2_t, 32, true, true, true, 64,
Tile<2, 1, 8, 4>, _t_a, _t_b, s_a, s_b,
container_0_t, container_1_t, container_2_t, 64, true, true, true, 64,
Tile<4, 4, wg_size, wg_size>, _t_a, _t_b, s_a, s_b,
static_cast<int>(gemm_memory_t::local),
static_cast<int>(gemm_algorithm_t::tall_skinny),
static_cast<int>(gemm_vectorization_t::none), is_beta_zero, 1,
Expand Down
6 changes: 1 addition & 5 deletions src/interface/blas3/backend/nvidia_gpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,8 @@ namespace backend {
template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename element_t, typename index_t>
#ifdef BLAS_ENABLE_COMPLEX
typename std::enable_if<!is_complex_sycl<element_t>::value,
typename std::enable_if<is_sycl_scalar<element_t>::value,
typename sb_handle_t::event_t>::type
#else
typename sb_handle_t::event_t
#endif
_gemm(sb_handle_t& sb_handle, index_t _M, index_t _N, index_t _K,
element_t _alpha, container_0_t _a, index_t _lda, index_t _stridea,
container_1_t _b, index_t _ldb, index_t _strideb, element_t _beta,
Expand Down
18 changes: 11 additions & 7 deletions src/interface/gemm_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,20 @@ namespace internal {

// Check whether value is zero (complex & float/double)
template <typename T>
inline bool isZero(const T& value) {
#ifdef BLAS_ENABLE_COMPLEX
if constexpr (is_complex_sycl<T>::value) {
using value_t = typename T::value_type;
return (value == T(value_t(0), value_t(0)));
}
#endif
inline typename std::enable_if<is_sycl_scalar<T>::value, bool>::type isZero(
const T& value) {
return (value == static_cast<T>(0));
}

#ifdef BLAS_ENABLE_COMPLEX
// Zero test for complex element types. This overload takes part in overload
// resolution only when is_complex_sycl<T>::value is true, and it is compiled
// only when BLAS_ENABLE_COMPLEX is defined.
//
// Returns true when `value` compares equal to a T constructed from two zero
// components of type T::value_type (presumably the real and imaginary
// parts - confirm against the complex type in use).
template <typename T>
inline typename std::enable_if<is_complex_sycl<T>::value, bool>::type isZero(
    const T& value) {
  using component_t = typename T::value_type;
  const T complex_zero(component_t(0), component_t(0));
  return (value == complex_zero);
}
#endif

template <bool _t_a, bool _t_b, bool s_a, bool s_b, bool is_beta_zero,
typename sb_handle_t, typename container_0_t, typename container_1_t,
typename container_2_t, typename element_t, typename index_t>
Expand Down
10 changes: 2 additions & 8 deletions src/operations/blas3/gemm_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,20 +40,14 @@ mul_add(T a, T b, T c,
typename std::enable_if<is_complex_sycl<T>::value>::type * = 0) {
return (a * b + c);
}
#endif

template <typename T>
static PORTBLAS_INLINE T
mul_add(T a, T b, T c,
typename std::enable_if<!is_complex_sycl<T>::value>::type * = 0) {
return (cl::sycl::mad(a, b, c));
}
#else

template <typename T>
static PORTBLAS_INLINE T mul_add(T a, T b, T c) {
typename std::enable_if<is_sycl_scalar<T>::value>::type * = 0) {
return (cl::sycl::mad(a, b, c));
}
#endif

template <typename T>
struct type_string {
Expand Down

0 comments on commit 4721c4f

Please sign in to comment.