Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Enabled Complex data type for Gemm #462

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ if(IMGDNN_DIR)
endif()

option(BLAS_ENABLE_EXTENSIONS "Whether to enable portBLAS extensions" ON)
option(BLAS_ENABLE_COMPLEX "Whether to enable complex data type for supported operators" ON)

# CmakeFunctionHelper has to be included after any options that it depends on are declared.
# These include:
Expand All @@ -115,6 +116,7 @@ option(BLAS_ENABLE_EXTENSIONS "Whether to enable portBLAS extensions" ON)
# * BLAS_DATA_TYPES
# * BLAS_INDEX_TYPES
# * NAIVE_GEMM
# * BLAS_ENABLE_COMPLEX
include(CmakeFunctionHelper)

if (INSTALL_HEADER_ONLY)
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ Some of the supported options are:
| `BLAS_ENABLE_EXTENSIONS` | `ON`/`OFF` | Determines whether to enable portBLAS extensions (`ON` by default) |
| `BLAS_DATA_TYPES` | `half;float;double` | Determines the floating-point types to instantiate BLAS operations for. Default is `float` |
| `BLAS_INDEX_TYPES` | `int32_t;int64_t` | Determines the type(s) to use for `index_t` and `increment_t`. Default is `int` |

| `BLAS_ENABLE_COMPLEX` | `ON`/`OFF` | Determines whether to enable Complex data type support *(GEMM Kernels only)* (`ON` by default) |

### Cross-Compile (ComputeCpp Only)

Expand Down
131 changes: 128 additions & 3 deletions cmake/CmakeFunctionHelper.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,30 @@ function(cpp_type output data)
if (${data} STREQUAL "half")
set(${output} "cl::sycl::half" PARENT_SCOPE)
return()
elseif(${data} STREQUAL "complex<float>")
set(${output} "cl::sycl::ext::oneapi::experimental::complex<float>" PARENT_SCOPE)
return()
elseif(${data} STREQUAL "complex<double>")
set(${output} "cl::sycl::ext::oneapi::experimental::complex<double>" PARENT_SCOPE)
return()
endif()
set(${output} "${data}" PARENT_SCOPE)
endfunction()

# Builds a list of complex data-type names from a list of real data types.
#
# Arguments:
#   output - name of the variable, set in the caller's scope, that receives
#            the resulting list
#   input  - ;-separated list of real data types (e.g. "float;double")
#   append - "true" to keep every real type and follow it with its complex
#            counterpart (float;complex<float>;double;complex<double>);
#            any other value yields only the complex<...> entries
#
# Example:
#   set_complex_list(types "float;double" "true")
#   # types = float;complex<float>;double;complex<double>
function(set_complex_list output input append)
  set(output_temp "")
  foreach(data ${input})
    # The expansion is quoted so that an empty value cannot produce a
    # malformed if() and a bare value is never dereferenced as a variable.
    if("${append}" STREQUAL "true")
      list(APPEND output_temp "${data}")
    endif()
    list(APPEND output_temp "complex<${data}>")
  endforeach()
  set(${output} ${output_temp} PARENT_SCOPE)
endfunction()

## represent the list of boolean options
set(boolean_list "true" "false")

Expand All @@ -56,6 +76,9 @@ function(sanitize_file_name output file_name)
set(${output} "${file_name}" PARENT_SCOPE)
endfunction()

#List of operators supporting Complex Data types
set(COMPLEX_OPS "gemm" "gemm_launcher" "scal")

function(set_target_compile_def in_target)
#setting compiler flag for backend
if(${TUNING_TARGET} STREQUAL "INTEL_GPU")
Expand Down Expand Up @@ -84,16 +107,31 @@ function(set_target_compile_def in_target)
message(STATUS "Gemm vectorization support enabled for target ${in_target}")
target_compile_definitions(${in_target} PUBLIC GEMM_VECTORIZATION_SUPPORT=1)
endif()

#setting const data type support
if(BLAS_ENABLE_CONST_INPUT)
target_compile_definitions(${in_target} PUBLIC BLAS_ENABLE_CONST_INPUT=1)
endif()
#setting complex support
if(${BLAS_ENABLE_COMPLEX})
if("${in_target}" IN_LIST COMPLEX_OPS)
message(STATUS "Complex Data type support enabled for target ${in_target}")
target_compile_definitions(${in_target} PUBLIC BLAS_ENABLE_COMPLEX=1)
endif()
endif()
endfunction()

# blas unary function for generating source code
function(generate_blas_objects blas_level func)
set(LOCATION "${PORTBLAS_GENERATED_SRC}/${blas_level}/${func}/")
foreach(data ${data_list})
set(data_list_c ${data_list})
# Extend data_list to complex<data> for each data in list
# if target function is in COMPLEX_OPS
if(BLAS_ENABLE_COMPLEX)
if("${func}" IN_LIST COMPLEX_OPS)
set_complex_list(data_list_c "${data_list}" "true")
endif()
endif()
foreach(data ${data_list_c})
cpp_type(cpp_data ${data})
foreach(index ${index_list})
foreach(increment ${index_list})
Expand Down Expand Up @@ -234,7 +272,11 @@ function(add_gemm_configuration
batch_type
use_joint_matrix
)
if(NOT ("${data}" IN_LIST data_list))
set(data_list_c ${data_list})
if(BLAS_ENABLE_COMPLEX)
set_complex_list(data_list_c "${data_list}" "true")
endif()
if(NOT ("${data}" IN_LIST data_list_c))
# Data type not enabled, skip configuration
return()
endif()
Expand All @@ -249,6 +291,9 @@ function(add_gemm_configuration
cpp_type(cpp_data ${data})
foreach(symm_a ${boolean_list})
foreach(symm_b ${boolean_list})
if ((${data} MATCHES "complex") AND (symm_a OR symm_b))
continue()
endif()
foreach(trans_a ${boolean_list})
foreach(trans_b ${boolean_list})
foreach(is_beta_zero ${boolean_list})
Expand Down Expand Up @@ -380,6 +425,32 @@ if(${TUNING_TARGET} STREQUAL "INTEL_GPU")
"${data}" 64 "false" "false" "false"
64 4 4 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false")
endforeach()
if(BLAS_ENABLE_COMPLEX)
# Extract list of complex<data> for each data in supported_types
# list for complex<data> specific gemm configurations
set(data_list_c)
set_complex_list(data_list_c "${supported_types}" "false")
foreach(data ${data_list_c})
add_gemm_configuration(
"${data}" 64 "true" "false" "false"
64 4 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
add_gemm_configuration(
"${data}" 64 "false" "false" "false"
64 4 8 16 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
add_gemm_configuration(
"${data}" 64 "false" "false" "false"
64 8 8 8 8 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false")
if (${data} STREQUAL "complex<double>")
add_gemm_configuration(
"${data}" 64 "true" "true" "true"
64 4 4 4 4 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false")
else()
add_gemm_configuration(
"${data}" 64 "true" "true" "true"
64 4 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false")
endif()
endforeach()
endif() # BLAS_ENABLE_COMPLEX
elseif(${TUNING_TARGET} STREQUAL "POWER_VR" AND NOT IMGDNN_DIR)
set(supported_types
"float"
Expand Down Expand Up @@ -445,6 +516,35 @@ elseif(${TUNING_TARGET} STREQUAL "AMD_GPU") # need investigation
"${data}" 64 "false" "false" "false"
64 4 4 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false")
endforeach()
if(BLAS_ENABLE_COMPLEX)
# Extract list of complex<data> for each data in supported_types
# list for complex<data> specific gemm configurations
set(data_list_c)
set_complex_list(data_list_c "${supported_types}" "false")
foreach(data ${data_list_c})
if (${data} STREQUAL "complex<double>")
add_gemm_configuration(
"${data}" 256 "true" "true" "true"
64 1 4 4 4 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false")
add_gemm_configuration(
"${data}" 256 "false" "false" "false"
64 1 1 4 4 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
add_gemm_configuration(
"${data}" 256 "false" "false" "false"
64 4 4 4 4 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
else()
add_gemm_configuration(
"${data}" 256 "true" "true" "true"
64 1 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "tall_skinny" "none" 1 "strided" "false")
add_gemm_configuration(
"${data}" 256 "false" "false" "false"
64 1 1 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
add_gemm_configuration(
"${data}" 256 "false" "false" "false"
64 4 4 8 8 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
endif()
endforeach()
endif() # BLAS_ENABLE_COMPLEX
elseif(${TUNING_TARGET} STREQUAL "NVIDIA_GPU")
set(supported_types
"float"
Expand Down Expand Up @@ -486,7 +586,18 @@ elseif(${TUNING_TARGET} STREQUAL "NVIDIA_GPU")
add_gemm_configuration(
"${data}" 256 "false" "true" "true"
128 8 8 16 16 1 1 1 1 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
endforeach()
if(BLAS_ENABLE_COMPLEX)
# Extract list of complex<data> for each data in supported_types
# list for complex<data> specific gemm configurations
set(data_list_c)
set_complex_list(data_list_c "${supported_types}" "false")
foreach(data ${data_list_c})
add_gemm_configuration(
"${data}" 256 "false" "false" "true"
64 2 2 16 16 1 1 2 2 1 1 1 1 1 float float "local" "standard" "full" 1 "strided" "false")
endforeach()
endif() # BLAS_ENABLE_COMPLEX
else() # default cpu backend
set(supported_types
"float"
Expand All @@ -513,6 +624,20 @@ else() # default cpu backend
"${data}" 64 "false" "false" "false"
64 2 2 4 4 1 1 1 1 4 4 1 1 1 float float "no_local" "standard" "full" 4 "interleaved" "false" "false")
endforeach()
if(BLAS_ENABLE_COMPLEX)
# Extract list of complex<data> for each data in supported_types
# list for complex<data> specific gemm configurations
set(data_list_c)
set_complex_list(data_list_c "${supported_types}" "false")
foreach(data ${data_list_c})
add_gemm_configuration(
"${data}" 64 "false" "false" "false"
64 2 2 4 4 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "full" 1 "strided" "false" "false")
add_gemm_configuration(
"${data}" 64 "false" "false" "false"
64 8 8 4 4 1 1 1 1 1 1 1 1 1 float float "no_local" "standard" "partial" 1 "strided" "false" "false")
endforeach()
endif() # BLAS_ENABLE_COMPLEX
endif()
add_library(${func} OBJECT ${gemm_sources})
set_target_compile_def(${func})
Expand Down
98 changes: 97 additions & 1 deletion common/include/common/float_comparison.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@

#include <cmath>
#include <iostream>
#ifdef BLAS_ENABLE_COMPLEX
#include <complex>
#endif

#ifdef BLAS_DATA_TYPE_HALF
#if SYCL_LANGUAGE_VERSION < 202000
Expand Down Expand Up @@ -65,6 +68,23 @@ scalar_t abs(scalar_t value) noexcept {
return std::abs(value);
}

#ifdef BLAS_ENABLE_COMPLEX
/**
 * Complex overload of isnan.
 * @return true when the scalar isnan helper reports true for the real or the
 * imaginary component of value.
 */
template <typename scalar_t>
bool isnan(std::complex<scalar_t> value) noexcept {
  if (isnan<scalar_t>(value.real())) {
    return true;
  }
  return isnan<scalar_t>(value.imag());
}

/**
 * Complex overload of isinf.
 * @return true when the scalar isinf helper reports true for the real or the
 * imaginary component of value.
 */
template <typename scalar_t>
bool isinf(std::complex<scalar_t> value) noexcept {
  if (isinf<scalar_t>(value.real())) {
    return true;
  }
  return isinf<scalar_t>(value.imag());
}

/**
 * Complex overload of abs.
 * @return the magnitude of value, as computed by std::abs, expressed in the
 * underlying real type.
 */
template <typename scalar_t>
scalar_t abs(std::complex<scalar_t> value) noexcept {
  const scalar_t magnitude = std::abs(value);
  return magnitude;
}
#endif

#ifdef BLAS_DATA_TYPE_HALF
template <>
inline bool isnan<cl::sycl::half>(cl::sycl::half value) noexcept {
Expand Down Expand Up @@ -172,7 +192,7 @@ inline bool almost_equal(scalar_t const& scalar1, scalar_t const& scalar2) {
return true;
}

const scalar_t absolute_diff = utils::abs(scalar1 - scalar2);
const auto absolute_diff = utils::abs(scalar1 - scalar2);

// Close to zero, the relative error doesn't work, use absolute error
if (scalar1 == scalar_t{0} || scalar2 == scalar_t{0} ||
Expand Down Expand Up @@ -212,6 +232,37 @@ inline bool compare_vectors(std::vector<scalar_t> const& vec,
return true;
}

#ifdef BLAS_ENABLE_COMPLEX
/**
 * Compare two vectors of complex data and returns false if the difference is
 * not acceptable. The second vector is considered the reference.
 * @tparam scalar_t the type of the complex underlying data present in the
 * input vectors
 * @tparam epsilon_t the type used as tolerance.
 * @param vec the vector under test
 * @param ref the reference vector
 * @param err_stream stream receiving the diagnostic on failure
 * @param end_line terminator appended after a value-mismatch message
 * @return true if both vectors have the same size and every pair of elements
 * passes almost_equal, false otherwise (stops at the first mismatch)
 */
template <typename scalar_t, typename epsilon_t = scalar_t>
inline bool compare_vectors(std::vector<std::complex<scalar_t>> const& vec,
                            std::vector<std::complex<scalar_t>> const& ref,
                            std::ostream& err_stream = std::cerr,
                            std::string end_line = "\n") {
  if (vec.size() != ref.size()) {
    err_stream << "Error: tried to compare vectors of different sizes"
               << std::endl;
    return false;
  }

  // The index uses the container's own size type so that the loop bound is
  // compared without a signed/unsigned conversion.
  for (decltype(vec.size()) i = 0; i < vec.size(); ++i) {
    if (!almost_equal<std::complex<scalar_t>, epsilon_t>(vec[i], ref[i])) {
      err_stream << "Value mismatch at index " << i << ": (" << vec[i].real()
                 << "," << vec[i].imag() << "); expected (" << ref[i].real()
                 << "," << ref[i].imag() << ")" << end_line;
      return false;
    }
  }
  return true;
}
#endif

/**
* Compare two vectors at a given stride and window (unit_vec_size) and returns
* false if the difference is not acceptable. The second vector is considered
Expand Down Expand Up @@ -253,6 +304,51 @@ inline bool compare_vectors_strided(std::vector<scalar_t> const& vec,
return true;
}

#ifdef BLAS_ENABLE_COMPLEX
/**
 * Strided comparison of two vectors of complex data. Only the first 'window'
 * elements of each 'stride'-sized step are compared; the second vector is
 * considered the reference.
 * @tparam scalar_t the type of the complex underlying data present in the
 * input vectors
 * @tparam epsilon_t the type used as tolerance.
 * @param vec the vector under test
 * @param ref the reference vector
 * @param stride is the stride between two consecutive 'windows'
 * @param window is the size of a comparison window
 * @param err_stream stream receiving the diagnostic on failure
 * @param end_line terminator appended after a value-mismatch message
 * @return false on a size mismatch or on the first element pair rejected by
 * almost_equal, true otherwise
 */
template <typename scalar_t, typename epsilon_t = scalar_t>
inline bool compare_vectors_strided(
    std::vector<std::complex<scalar_t>> const& vec,
    std::vector<std::complex<scalar_t>> const& ref, int stride, int window,
    std::ostream& err_stream = std::cerr, std::string end_line = "\n") {
  if (vec.size() != ref.size()) {
    err_stream << "Error: tried to compare vectors of different sizes"
               << std::endl;
    return false;
  }

  // Walk the windows one stride at a time.
  // NOTE(review): the bound stops as soon as window + (k + 1) * stride reaches
  // vec.size(), so windows close to the end of the vector may not be visited;
  // confirm this is intended.
  for (int k = 0; window + (k + 1) * stride < vec.size(); ++k) {
    const int window_start = k * stride;
    // Check every element inside the current window.
    for (int offset = 0; offset < window; ++offset) {
      const auto index = window_start + offset;
      const bool matches = almost_equal<std::complex<scalar_t>, epsilon_t>(
          vec[index], ref[index]);
      if (!matches) {
        err_stream << "Value mismatch at index " << index << ": ("
                   << vec[index].real() << "," << vec[index].imag()
                   << "); expected (" << ref[index].real() << ","
                   << ref[index].imag() << ")" << end_line;
        return false;
      }
    }
  }

  return true;
}
#endif

} // namespace utils

#endif // UTILS_FLOAT_COMPARISON_H_
15 changes: 15 additions & 0 deletions common/include/common/system_reference_blas.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,12 @@ auto blas_system_function(floatfn_t ffn, doublefn_t dfn)
return BlasSystemFunction<scalar_t>::get(ffn, dfn);
}

/**
 * Picks between two system BLAS entry points for complex routines by
 * forwarding both candidates to BlasSystemFunction<scalar_t>::get.
 * NOTE(review): presumably ffn is the single-precision and dfn the
 * double-precision routine; confirm against BlasSystemFunction.
 * @tparam scalar_t type used to drive the selection
 * @param ffn first candidate function
 * @param dfn second candidate function
 * @return whatever BlasSystemFunction<scalar_t>::get returns for the pair
 */
template <typename scalar_t, typename floatfn_t, typename doublefn_t>
auto blas_cplx_system_function(floatfn_t ffn, doublefn_t dfn)
    -> decltype(BlasSystemFunction<scalar_t>::get(ffn, dfn)) {
  using selector_t = BlasSystemFunction<scalar_t>;
  return selector_t::get(ffn, dfn);
}

// =======
// Level 1
// =======
Expand Down Expand Up @@ -378,6 +384,15 @@ void gemm(const char *transA, const char *transB, int m, int n, int k,
lda, b, ldb, beta, c, ldc);
}

/**
 * Complex GEMM on the system CBLAS, always invoked with CblasColMajor.
 * The routine actually called (cblas_cgemm or cblas_zgemm) is chosen by
 * blas_cplx_system_function<scalar_t>.
 * NOTE(review): scalar_t presumably names the underlying real precision
 * (float or double); confirm against the callers.
 * @param transA, transB pointers to the transpose flag characters, converted
 * with c_trans
 * @param m, n, k matrix dimensions forwarded to CBLAS
 * @param alpha, beta type-erased pointers to the scaling factors
 * @param a, b type-erased input matrices with leading dimensions lda, ldb
 * @param c type-erased output matrix with leading dimension ldc
 */
template <typename scalar_t>
void cgemm(const char *transA, const char *transB, int m, int n, int k,
           const void *alpha, const void *a, int lda, const void *b, int ldb,
           const void *beta, void *c, int ldc) {
  auto cblas_routine =
      blas_cplx_system_function<scalar_t>(&cblas_cgemm, &cblas_zgemm);
  const auto op_a = c_trans(*transA);
  const auto op_b = c_trans(*transB);
  cblas_routine(CblasColMajor, op_a, op_b, m, n, k, alpha, a, lda, b, ldb,
                beta, c, ldc);
}

template <typename scalar_t>
void trsm(const char *side, const char *uplo, const char *trans,
const char *diag, int m, int n, scalar_t alpha, const scalar_t A[],
Expand Down
Loading