From 8d29746fa2a33be6d126e177012f208337b3217c Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI <104583441+OuadiElfarouki@users.noreply.github.com> Date: Thu, 26 Oct 2023 16:18:54 +0100 Subject: [PATCH] Added benchmarks for GEMM complex types (#465) --- benchmark/cublas/CMakeLists.txt | 8 + benchmark/cublas/blas3/gemm.cpp | 168 +++++++++++ benchmark/cublas/blas3/gemm_batched.cpp | 205 ++++++++++++- .../cublas/blas3/gemm_batched_strided.cpp | 200 +++++++++++++ benchmark/cublas/utils.hpp | 10 + benchmark/portblas/CMakeLists.txt | 8 + benchmark/portblas/blas3/gemm.cpp | 185 ++++++++++++ benchmark/portblas/blas3/gemm_batched.cpp | 217 +++++++++++++- .../portblas/blas3/gemm_batched_strided.cpp | 230 ++++++++++++++- benchmark/rocblas/CMakeLists.txt | 9 +- benchmark/rocblas/blas3/gemm.cpp | 183 ++++++++++++ benchmark/rocblas/blas3/gemm_batched.cpp | 200 +++++++++++++ .../rocblas/blas3/gemm_batched_strided.cpp | 217 ++++++++++++++ .../include/common/blas3_state_counters.hpp | 60 ++++ common/include/common/common_utils.hpp | 273 +++++++++++++++++- common/include/common/set_benchmark_label.hpp | 18 ++ 16 files changed, 2177 insertions(+), 14 deletions(-) diff --git a/benchmark/cublas/CMakeLists.txt b/benchmark/cublas/CMakeLists.txt index 250278fac..ad3b4ed05 100644 --- a/benchmark/cublas/CMakeLists.txt +++ b/benchmark/cublas/CMakeLists.txt @@ -74,12 +74,20 @@ set(sources extension/omatadd.cpp ) +# Operators supporting COMPLEX types benchmarking +set(CPLX_OPS "gemm" "gemm_batched" "gemm_batched_strided") + # Add individual benchmarks for each method foreach(cublas_bench ${sources}) get_filename_component(bench_cublas_exec ${cublas_bench} NAME_WE) add_executable(bench_cublas_${bench_cublas_exec} ${cublas_bench} main.cpp) target_link_libraries(bench_cublas_${bench_cublas_exec} PRIVATE benchmark CUDA::toolkit CUDA::cublas CUDA::cudart portblas Clara::Clara bench_info) target_compile_definitions(bench_cublas_${bench_cublas_exec} PRIVATE -DBLAS_INDEX_T=${BLAS_BENCHMARK_INDEX_TYPE}) + if(${BLAS_ENABLE_COMPLEX}) + if("${bench_cublas_exec}" IN_LIST CPLX_OPS) + target_compile_definitions(bench_cublas_${bench_cublas_exec} PRIVATE BLAS_ENABLE_COMPLEX=1) + endif() + endif() add_sycl_to_target( TARGET bench_cublas_${bench_cublas_exec} SOURCES ${cublas_bench} diff --git a/benchmark/cublas/blas3/gemm.cpp b/benchmark/cublas/blas3/gemm.cpp index 5a103d032..c74c9e98e 100644 --- a/benchmark/cublas/blas3/gemm.cpp +++ b/benchmark/cublas/blas3/gemm.cpp @@ -38,6 +38,18 @@ static inline void cublas_routine(args_t&&... args) { return; } +#ifdef BLAS_ENABLE_COMPLEX +template +static inline void cublas_cplx_routine(args_t&&... args) { + if constexpr (std::is_same_v) { + CUBLAS_CHECK(cublasCgemm(std::forward(args)...)); + } else if constexpr (std::is_same_v) { + CUBLAS_CHECK(cublasZgemm(std::forward(args)...)); + } + return; +} +#endif + template void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, int t1, int t2, index_t m, index_t k, index_t n, scalar_t alpha, scalar_t beta, @@ -168,6 +180,162 @@ void register_benchmark(blas_benchmark::Args& args, } } +#ifdef BLAS_ENABLE_COMPLEX +template +using cudaComplex = typename std::conditional::type; + +template +void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, int t1, + int t2, index_t m, index_t k, index_t n, std::complex alpha, + std::complex beta, bool* success) { + // initialize the state label + blas_benchmark::utils::set_benchmark_label>(state); + + // Standard test setup. + std::string t1s = blas_benchmark::utils::from_transpose_enum( + static_cast(t1)); + std::string t2s = blas_benchmark::utils::from_transpose_enum( + static_cast(t2)); + const char* t_a = t1s.c_str(); + const char* t_b = t2s.c_str(); + + index_t lda = t_a[0] == 'n' ? m : k; + index_t ldb = t_b[0] == 'n' ? k : n; + index_t ldc = m; + + blas_benchmark::utils::init_level_3_cplx_counters< + blas_benchmark::utils::Level3Op::gemm, scalar_t>(state, beta, m, n, k, + static_cast(1)); + + cublasHandle_t& cuda_handle = *cuda_handle_ptr; + + // Matrices + std::vector> a = + blas_benchmark::utils::random_cplx_data(m * k); + std::vector> b = + blas_benchmark::utils::random_cplx_data(k * n); + std::vector> c = + blas_benchmark::utils::const_cplx_data(m * n, 0); + + blas_benchmark::utils::CUDAVector> a_gpu( + m * k, reinterpret_cast*>(a.data())); + blas_benchmark::utils::CUDAVector> b_gpu( + k * n, reinterpret_cast*>(b.data())); + blas_benchmark::utils::CUDAVector> c_gpu( + n * m, reinterpret_cast*>(c.data())); + + cublasOperation_t c_t_a = (*t_a == 'n') ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t c_t_b = (*t_b == 'n') ? CUBLAS_OP_N : CUBLAS_OP_T; + + cudaComplex cuBeta{beta.real(), beta.imag()}; + cudaComplex cuAlpha{alpha.real(), alpha.imag()}; + +#ifdef BLAS_VERIFY_BENCHMARK + // Run a first time with a verification of the results + std::vector> c_ref = c; + + reference_blas::cgemm(t_a, t_b, m, n, k, + reinterpret_cast(&alpha), + reinterpret_cast(a.data()), lda, + reinterpret_cast(b.data()), ldb, + reinterpret_cast(&beta), + reinterpret_cast(c_ref.data()), ldc); + std::vector> c_temp = c; + { + blas_benchmark::utils::CUDAVector, true> c_temp_gpu( + m * n, reinterpret_cast*>(c_temp.data())); + cublas_cplx_routine(cuda_handle, c_t_a, c_t_b, m, n, k, &cuAlpha, + a_gpu, lda, b_gpu, ldb, &cuBeta, c_temp_gpu, + ldc); + } + + std::ostringstream err_stream; + if (!utils::compare_vectors(c_temp, c_ref, err_stream, "")) { + const std::string& err_str = err_stream.str(); + state.SkipWithError(err_str.c_str()); + *success = false; + }; +#endif + auto blas_warmup = [&]() -> void { + cublas_cplx_routine(cuda_handle, c_t_a, c_t_b, m, n, k, &cuAlpha, + a_gpu, lda, b_gpu, ldb, &cuBeta, c_gpu, ldc); + return; + }; + + cudaEvent_t start; + cudaEvent_t stop; + CUDA_CHECK(cudaEventCreate(&start)); + CUDA_CHECK(cudaEventCreate(&stop)); + + auto blas_method_def = [&]() -> std::vector { + CUDA_CHECK(cudaEventRecord(start)); + cublas_cplx_routine(cuda_handle, c_t_a, c_t_b, m, n, k, &cuAlpha, + a_gpu, lda, b_gpu, ldb, &cuBeta, c_gpu, ldc); + CUDA_CHECK(cudaEventRecord(stop)); + CUDA_CHECK(cudaEventSynchronize(stop)); + return std::vector{start, stop}; + }; + + // Warmup + blas_benchmark::utils::warmup(blas_warmup); + CUDA_CHECK(cudaStreamSynchronize(NULL)); + + blas_benchmark::utils::init_counters(state); + + // Measure + for (auto _ : state) { + // Run + std::tuple times = + blas_benchmark::utils::timef_cuda(blas_method_def); + + // Report + blas_benchmark::utils::update_counters(state, times); + } + + state.SetItemsProcessed(state.iterations() * state.counters["n_fl_ops"]); + state.SetBytesProcessed(state.iterations() * + state.counters["bytes_processed"]); + + blas_benchmark::utils::calc_avg_counters(state); + + CUDA_CHECK(cudaEventDestroy(start)); + CUDA_CHECK(cudaEventDestroy(stop)); +}; + +template +void register_cplx_benchmark(blas_benchmark::Args& args, + cublasHandle_t* cuda_handle_ptr, bool* success) { + auto gemm_params = + blas_benchmark::utils::get_blas3_cplx_params(args); + for (auto p : gemm_params) { + std::string t1s, t2s; + index_t m, n, k; + scalar_t alpha_r, alpha_i, beta_r, beta_i; + + std::tie(t1s, t2s, m, k, n, alpha_r, alpha_i, beta_r, beta_i) = p; + int t1 = static_cast(blas_benchmark::utils::to_transpose_enum(t1s)); + int t2 = static_cast(blas_benchmark::utils::to_transpose_enum(t2s)); + std::complex alpha{alpha_r, alpha_i}; + std::complex beta{beta_r, beta_i}; + + auto BM_lambda = [&](benchmark::State& st, cublasHandle_t* cuda_handle_ptr, + int t1, int t2, index_t m, index_t k, index_t n, + std::complex alpha, + std::complex beta, bool* success) { + run(st, cuda_handle_ptr, t1, t2, m, k, n, alpha, beta, success); + }; + benchmark::RegisterBenchmark( + blas_benchmark::utils::get_name>( + t1s, t2s, m, k, n, blas_benchmark::utils::MEM_TYPE_USM) + .c_str(), + BM_lambda, cuda_handle_ptr, t1, t2, m, k, n, alpha, beta, success) + ->UseRealTime(); + } +} + +#endif + namespace blas_benchmark { void create_benchmark(blas_benchmark::Args& args, cublasHandle_t* cuda_handle_ptr, bool* success) { diff --git a/benchmark/cublas/blas3/gemm_batched.cpp b/benchmark/cublas/blas3/gemm_batched.cpp index 4cce28ff5..c0c50631f 100644 --- a/benchmark/cublas/blas3/gemm_batched.cpp +++ b/benchmark/cublas/blas3/gemm_batched.cpp @@ -38,6 +38,18 @@ static inline void cublas_routine(args_t&&... args) { return; } +#ifdef BLAS_ENABLE_COMPLEX +template +static inline void cublas_cplx_routine(args_t&&... args) { + if constexpr (std::is_same_v) { + CUBLAS_CHECK(cublasCgemmBatched(std::forward(args)...)); + } else if constexpr (std::is_same_v) { + CUBLAS_CHECK(cublasZgemmBatched(std::forward(args)...)); + } + return; +} +#endif + template void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, index_t t1, index_t t2, index_t m, index_t k, index_t n, scalar_t alpha, @@ -164,7 +176,7 @@ void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, index_t t1, state.counters["bytes_processed"]); blas_benchmark::utils::calc_avg_counters(state); - + CUDA_CHECK(cudaEventDestroy(start)); CUDA_CHECK(cudaEventDestroy(stop)); }; @@ -209,6 +221,197 @@ void register_benchmark(blas_benchmark::Args& args, } } +#ifdef BLAS_ENABLE_COMPLEX +template +using cudaComplex = typename std::conditional::type; +template +void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, index_t t1, + index_t t2, index_t m, index_t k, index_t n, + std::complex alpha, std::complex beta, + index_t batch_count, int batch_type_i, bool* success) { + // initialize the state label + blas_benchmark::utils::set_benchmark_label>(state); + + // Standard setup + std::string t1s = blas_benchmark::utils::from_transpose_enum( + static_cast(t1)); + std::string t2s = blas_benchmark::utils::from_transpose_enum( + static_cast(t2)); + const char* t_a = t1s.c_str(); + const char* t_b = t2s.c_str(); + auto batch_type = static_cast(batch_type_i); + + index_t lda = t_a[0] == 'n' ? m : k; + index_t ldb = t_b[0] == 'n' ? k : n; + index_t ldc = m; + + blas_benchmark::utils::init_level_3_cplx_counters< + blas_benchmark::utils::Level3Op::gemm_batched, scalar_t>( + state, beta, m, n, k, batch_count); + + cublasHandle_t& cuda_handle = *cuda_handle_ptr; + + const index_t size_a = m * k; + const index_t size_b = k * n; + const index_t size_c = m * n; + + // Matrices + std::vector> a = + blas_benchmark::utils::random_cplx_data(size_a * batch_count); + std::vector> b = + blas_benchmark::utils::random_cplx_data(size_b * batch_count); + std::vector> c = + blas_benchmark::utils::const_cplx_data(size_c * batch_count, 0); + + blas_benchmark::utils::CUDAVectorBatched> d_A_array( + size_a, batch_count, reinterpret_cast*>(a.data())); + blas_benchmark::utils::CUDAVectorBatched> d_B_array( + size_b, batch_count, reinterpret_cast*>(b.data())); + blas_benchmark::utils::CUDAVectorBatched> d_C_array( + size_c, batch_count); + + cublasOperation_t c_t_a = (*t_a == 'n') ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t c_t_b = (*t_b == 'n') ? CUBLAS_OP_N : CUBLAS_OP_T; + + cudaComplex cuBeta{beta.real(), beta.imag()}; + cudaComplex cuAlpha{alpha.real(), alpha.imag()}; + +#ifdef BLAS_VERIFY_BENCHMARK + // Run a first time with a verification of the results + { + std::vector> c_ref = c; + auto _base = [=](index_t dim0, index_t dim1, index_t idx) { + return dim0 * dim1 * idx; + }; + for (int batch_idx = 0; batch_idx < batch_count; batch_idx++) { + reference_blas::cgemm( + t_a, t_b, m, n, k, reinterpret_cast(&alpha), + reinterpret_cast(a.data() + _base(m, k, batch_idx)), lda, + reinterpret_cast(b.data() + _base(k, n, batch_idx)), ldb, + reinterpret_cast(&beta), + reinterpret_cast(c_ref.data() + _base(m, n, batch_idx)), ldc); + } + + std::vector> c_temp(size_c * batch_count); + + { + blas_benchmark::utils::CUDAVectorBatched, true> + c_temp_gpu(n * m, batch_count, + reinterpret_cast*>(c_temp.data())); + cublas_cplx_routine( + cuda_handle, c_t_a, c_t_b, m, n, k, &cuAlpha, + d_A_array.get_batch_array(), lda, d_B_array.get_batch_array(), ldb, + &cuBeta, c_temp_gpu.get_batch_array(), ldc, batch_count); + } + + std::ostringstream err_stream; + for (int i = 0; i < batch_count; ++i) { + if (!utils::compare_vectors(c_temp, c_ref, err_stream, "")) { + const std::string& err_str = err_stream.str(); + state.SkipWithError(err_str.c_str()); + *success = false; + }; + } + + } // close scope for verify benchmark +#endif + + auto blas_warmup = [&]() -> void { + cublas_cplx_routine( + cuda_handle, c_t_a, c_t_b, m, n, k, &cuAlpha, + d_A_array.get_batch_array(), lda, d_B_array.get_batch_array(), ldb, + &cuBeta, d_C_array.get_batch_array(), ldc, batch_count); + return; + }; + + cudaEvent_t start, stop; + CUDA_CHECK(cudaEventCreate(&start)); + CUDA_CHECK(cudaEventCreate(&stop)); + + auto blas_method_def = [&]() -> std::vector { + CUDA_CHECK(cudaEventRecord(start)); + cublas_cplx_routine( + cuda_handle, c_t_a, c_t_b, m, n, k, &cuAlpha, + d_A_array.get_batch_array(), lda, d_B_array.get_batch_array(), ldb, + &cuBeta, d_C_array.get_batch_array(), ldc, batch_count); + CUDA_CHECK(cudaEventRecord(stop)); + CUDA_CHECK(cudaEventSynchronize(stop)); + return std::vector{start, stop}; + }; + + // Warmup + blas_benchmark::utils::warmup(blas_method_def); + CUDA_CHECK(cudaStreamSynchronize(NULL)); + + blas_benchmark::utils::init_counters(state); + + // Measure + for (auto _ : state) { + // Run + std::tuple times = + blas_benchmark::utils::timef_cuda(blas_method_def); + + // Report + blas_benchmark::utils::update_counters(state, times); + } + + state.SetItemsProcessed(state.iterations() * state.counters["n_fl_ops"]); + state.SetBytesProcessed(state.iterations() * + state.counters["bytes_processed"]); + + blas_benchmark::utils::calc_avg_counters(state); + + CUDA_CHECK(cudaEventDestroy(start)); + CUDA_CHECK(cudaEventDestroy(stop)); +}; + +template +void register_cplx_benchmark(blas_benchmark::Args& args, + cublasHandle_t* cuda_handle_ptr, bool* success) { + auto gemm_batched_params = + blas_benchmark::utils::get_gemm_cplx_batched_params(args); + + for (auto p : gemm_batched_params) { + std::string t1s, t2s; + index_t m, n, k, batch_count; + scalar_t alpha_r, alpha_i, beta_r, beta_i; + int batch_type; + + std::tie(t1s, t2s, m, k, n, alpha_r, alpha_i, beta_r, beta_i, batch_count, + batch_type) = p; + std::complex alpha{alpha_r, alpha_i}; + std::complex beta{beta_r, beta_i}; + + if (batch_type == 1) { + std::cerr << "interleaved memory for gemm_batched operator is not " + "supported by cuBLAS\n"; + continue; + } + + int t1 = static_cast(blas_benchmark::utils::to_transpose_enum(t1s)); + int t2 = static_cast(blas_benchmark::utils::to_transpose_enum(t2s)); + + auto BM_lambda = [&](benchmark::State& st, cublasHandle_t* cuda_handle_ptr, + int t1, int t2, index_t m, index_t k, index_t n, + std::complex alpha, + std::complex beta, index_t batch_count, + int batch_type, bool* success) { + run(st, cuda_handle_ptr, t1, t2, m, k, n, alpha, beta, + batch_count, batch_type, success); + }; + benchmark::RegisterBenchmark( + blas_benchmark::utils::get_name>( + t1s, t2s, m, k, n, batch_count, batch_type, + blas_benchmark::utils::MEM_TYPE_USM) + .c_str(), + BM_lambda, cuda_handle_ptr, t1, t2, m, k, n, alpha, beta, batch_count, + batch_type, success) + ->UseRealTime(); + } +} +#endif + namespace blas_benchmark { void create_benchmark(blas_benchmark::Args& args, cublasHandle_t* cuda_handle_ptr, bool* success) { diff --git a/benchmark/cublas/blas3/gemm_batched_strided.cpp b/benchmark/cublas/blas3/gemm_batched_strided.cpp index d96b7adfe..beb81fb4c 100644 --- a/benchmark/cublas/blas3/gemm_batched_strided.cpp +++ b/benchmark/cublas/blas3/gemm_batched_strided.cpp @@ -38,6 +38,18 @@ static inline void cublas_routine(args_t&&... args) { return; } +#ifdef BLAS_ENABLE_COMPLEX +template +static inline void cublas_cplx_routine(args_t&&... args) { + if constexpr (std::is_same_v) { + CUBLAS_CHECK(cublasCgemmStridedBatched(std::forward(args)...)); + } else if constexpr (std::is_same_v) { + CUBLAS_CHECK(cublasZgemmStridedBatched(std::forward(args)...)); + } + return; +} +#endif + template void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, int t1, int t2, index_t m, index_t k, index_t n, scalar_t alpha, scalar_t beta, @@ -208,6 +220,194 @@ void register_benchmark(blas_benchmark::Args& args, } } +#ifdef BLAS_ENABLE_COMPLEX +template +using cudaComplex = typename std::conditional::type; + +template +void run(benchmark::State& state, cublasHandle_t* cuda_handle_ptr, int t1, + int t2, index_t m, index_t k, index_t n, std::complex alpha, + std::complex beta, index_t batch_size, index_t stride_a_mul, + index_t stride_b_mul, index_t stride_c_mul, bool* success) { + // initialize the state label + blas_benchmark::utils::set_benchmark_label>(state); + + // Standard test setup. + std::string t1s = blas_benchmark::utils::from_transpose_enum( + static_cast(t1)); + std::string t2s = blas_benchmark::utils::from_transpose_enum( + static_cast(t2)); + const char* t_a = t1s.c_str(); + const char* t_b = t2s.c_str(); + + const bool trA = t_a[0] == 'n'; + const bool trB = t_b[0] == 'n'; + + index_t lda = trA ? m : k; + index_t ldb = trB ? k : n; + index_t ldc = m; + + blas_benchmark::utils::init_level_3_cplx_counters< + blas_benchmark::utils::Level3Op::gemm_batched_strided, scalar_t>( + state, beta, m, n, k, batch_size, stride_a_mul, stride_b_mul, + stride_c_mul); + + cublasHandle_t& cuda_handle = *cuda_handle_ptr; + + // Data sizes + // Elementary matrices + const index_t a_size = m * k; + const index_t b_size = k * n; + const index_t c_size = m * n; + // Strides + const index_t stride_a = stride_a_mul * a_size; + const index_t stride_b = stride_b_mul * b_size; + const index_t stride_c = stride_c_mul * c_size; + // Batched matrices + const int size_a_batch = a_size + (batch_size - 1) * stride_a; + const int size_b_batch = b_size + (batch_size - 1) * stride_b; + const int size_c_batch = c_size + (batch_size - 1) * stride_c; + + // Matrices (Total size is equal to matrix size x batch_size since we're using + // default striding values) + std::vector> a = + blas_benchmark::utils::random_cplx_data(size_a_batch); + std::vector> b = + blas_benchmark::utils::random_cplx_data(size_b_batch); + std::vector> c = + blas_benchmark::utils::const_cplx_data(size_c_batch, 0); + + blas_benchmark::utils::CUDAVector> a_gpu( + size_a_batch, reinterpret_cast*>(a.data())); + blas_benchmark::utils::CUDAVector> b_gpu( + size_b_batch, reinterpret_cast*>(b.data())); + blas_benchmark::utils::CUDAVector> c_gpu( + size_c_batch, reinterpret_cast*>(c.data())); + + cublasOperation_t c_t_a = trA ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t c_t_b = trB ? CUBLAS_OP_N : CUBLAS_OP_T; + + cudaComplex cuBeta{beta.real(), beta.imag()}; + cudaComplex cuAlpha{alpha.real(), alpha.imag()}; + +#ifdef BLAS_VERIFY_BENCHMARK + // Run a first time with a verification of the results + std::vector> c_ref = c; + for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { + reference_blas::cgemm( + t_a, t_b, m, n, k, reinterpret_cast(&alpha), + reinterpret_cast(a.data() + batch_idx * stride_a), lda, + reinterpret_cast(b.data() + batch_idx * stride_b), ldb, + reinterpret_cast(&beta), + reinterpret_cast(c_ref.data() + batch_idx * stride_c), ldc); + } + + std::vector> c_temp = c; + { + blas_benchmark::utils::CUDAVector, true> c_temp_gpu( + size_c_batch, reinterpret_cast*>(c_temp.data())); + cublas_cplx_routine( + cuda_handle, c_t_a, c_t_b, m, n, k, &cuAlpha, a_gpu, lda, stride_a, + b_gpu, ldb, stride_b, &cuBeta, c_temp_gpu, ldc, stride_c, batch_size); + } + + std::ostringstream err_stream; + if (!utils::compare_vectors_strided(c_temp, c_ref, stride_c, c_size, + err_stream, "")) { + const std::string& err_str = err_stream.str(); + state.SkipWithError(err_str.c_str()); + *success = false; + }; +#endif + + auto blas_warmup = [&]() -> void { + cublas_cplx_routine(cuda_handle, c_t_a, c_t_b, m, n, k, &cuAlpha, + a_gpu, lda, stride_a, b_gpu, ldb, stride_b, + &cuBeta, c_gpu, ldc, stride_c, batch_size); + return; + }; + + cudaEvent_t start, stop; + CUDA_CHECK(cudaEventCreate(&start)); + CUDA_CHECK(cudaEventCreate(&stop)); + + auto blas_method_def = [&]() -> std::vector { + CUDA_CHECK(cudaEventRecord(start)); + cublas_cplx_routine(cuda_handle, c_t_a, c_t_b, m, n, k, &cuAlpha, + a_gpu, lda, stride_a, b_gpu, ldb, stride_b, + &cuBeta, c_gpu, ldc, stride_c, batch_size); + CUDA_CHECK(cudaEventRecord(stop)); + CUDA_CHECK(cudaEventSynchronize(stop)); + return std::vector{start, stop}; + }; + + // Warmup + blas_benchmark::utils::warmup(blas_warmup); + CUDA_CHECK(cudaStreamSynchronize(NULL)); + + blas_benchmark::utils::init_counters(state); + + // Measure + for (auto _ : state) { + // Run + std::tuple times = + blas_benchmark::utils::timef_cuda(blas_method_def); + + // Report + blas_benchmark::utils::update_counters(state, times); + } + + state.SetItemsProcessed(state.iterations() * state.counters["n_fl_ops"]); + state.SetBytesProcessed(state.iterations() * + state.counters["bytes_processed"]); + + blas_benchmark::utils::calc_avg_counters(state); + + CUDA_CHECK(cudaEventDestroy(start)); + CUDA_CHECK(cudaEventDestroy(stop)); +}; + +template +void register_cplx_benchmark(blas_benchmark::Args& args, + cublasHandle_t* cuda_handle_ptr, bool* success) { + auto gemm_batched_strided_params = + blas_benchmark::utils::get_gemm_batched_strided_cplx_params( + args); + + for (auto p : gemm_batched_strided_params) { + std::string t1s, t2s; + index_t m, n, k, batch_size, stride_a_mul, stride_b_mul, stride_c_mul; + scalar_t alpha_r, alpha_i, beta_r, beta_i; + std::tie(t1s, t2s, m, k, n, alpha_r, alpha_i, beta_r, beta_i, batch_size, + stride_a_mul, stride_b_mul, stride_c_mul) = p; + int t1 = static_cast(blas_benchmark::utils::to_transpose_enum(t1s)); + int t2 = static_cast(blas_benchmark::utils::to_transpose_enum(t2s)); + std::complex alpha{alpha_r, alpha_i}; + std::complex beta{beta_r, beta_i}; + + auto BM_lambda = [&](benchmark::State& st, cublasHandle_t* cuda_handle_ptr, + int t1, int t2, index_t m, index_t k, index_t n, + std::complex alpha, + std::complex beta, index_t batch_size, + index_t strd_a_mul, index_t strd_b_mul, + index_t strd_c_mul, bool* success) { + run(st, cuda_handle_ptr, t1, t2, m, k, n, alpha, beta, + batch_size, strd_a_mul, strd_b_mul, strd_c_mul, success); + }; + benchmark::RegisterBenchmark( + blas_benchmark::utils::get_name>( + t1s, t2s, m, k, n, batch_size, stride_a_mul, stride_b_mul, + stride_c_mul, blas_benchmark::utils::MEM_TYPE_USM) + .c_str(), + BM_lambda, cuda_handle_ptr, t1, t2, m, k, n, alpha, beta, batch_size, + stride_a_mul, stride_b_mul, stride_c_mul, success) + ->UseRealTime(); + } +} + +#endif + namespace blas_benchmark { void create_benchmark(blas_benchmark::Args& args, cublasHandle_t* cuda_handle_ptr, bool* success) { diff --git a/benchmark/cublas/utils.hpp b/benchmark/cublas/utils.hpp index eeaee7371..362fdce51 100644 --- a/benchmark/cublas/utils.hpp +++ b/benchmark/cublas/utils.hpp @@ -33,6 +33,7 @@ #include "portblas.h" #include +#include #include #include #include @@ -179,6 +180,15 @@ class CUDAVectorBatched : private CUDADeviceMemory { } } + CUDAVectorBatched(size_t matrix_size, size_t batch_count, T* h_v) + : CUDAVectorBatched(matrix_size, batch_count) { + if constexpr (CopyToHost) h_data = h_v; + for (int i = 0; i < batch_count; ++i) { + CUDA_CHECK(cudaMemcpy(d_data[i], &h_v[matrix_size * i], + sizeof(T) * c_matrix_size, cudaMemcpyHostToDevice)); + } + } + ~CUDAVectorBatched() { if constexpr (CopyToHost) { for (int i = 0; i < c_batch_count; ++i) { diff --git a/benchmark/portblas/CMakeLists.txt b/benchmark/portblas/CMakeLists.txt index 785996422..4ac3fdeaa 100644 --- a/benchmark/portblas/CMakeLists.txt +++ b/benchmark/portblas/CMakeLists.txt @@ -75,12 +75,20 @@ if(${BLAS_ENABLE_EXTENSIONS}) list(APPEND sources extension/reduction.cpp) endif() +# Operators supporting COMPLEX types benchmarking +set(CPLX_OPS "gemm" "gemm_batched" "gemm_batched_strided") + # Add individual benchmarks for each method foreach(portblas_bench ${sources}) get_filename_component(bench_exec ${portblas_bench} NAME_WE) add_executable(bench_${bench_exec} ${portblas_bench} main.cpp) target_link_libraries(bench_${bench_exec} PRIVATE benchmark Clara::Clara portblas bench_info) target_compile_definitions(bench_${bench_exec} PRIVATE -DBLAS_INDEX_T=${BLAS_BENCHMARK_INDEX_TYPE}) + if(${BLAS_ENABLE_COMPLEX}) + if("${bench_exec}" IN_LIST CPLX_OPS) + target_compile_definitions(bench_${bench_exec} PRIVATE BLAS_ENABLE_COMPLEX=1) + endif() + endif() add_sycl_to_target( TARGET bench_${bench_exec} SOURCES ${portblas_bench} diff --git a/benchmark/portblas/blas3/gemm.cpp b/benchmark/portblas/blas3/gemm.cpp index 51d4869a8..27bb90650 100644 --- a/benchmark/portblas/blas3/gemm.cpp +++ b/benchmark/portblas/blas3/gemm.cpp @@ -177,6 +177,191 @@ void register_benchmark(blas_benchmark::Args& args, #endif } +#ifdef BLAS_ENABLE_COMPLEX +template +void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, int t1, + int t2, index_t m, index_t k, index_t n, std::complex alpha, + std::complex beta, bool* success) { + // initialize the state label + blas_benchmark::utils::set_benchmark_label>( + state, sb_handle_ptr->get_queue()); + + // Standard test setup. + std::string t1s = blas_benchmark::utils::from_transpose_enum( + static_cast(t1)); + std::string t2s = blas_benchmark::utils::from_transpose_enum( + static_cast(t2)); + const char* t_a = t1s.c_str(); + const char* t_b = t2s.c_str(); + + index_t lda = t_a[0] == 'n' ? m : k; + index_t ldb = t_b[0] == 'n' ? k : n; + index_t ldc = m; + + blas_benchmark::utils::init_level_3_cplx_counters< + blas_benchmark::utils::Level3Op::gemm, scalar_t>(state, beta, m, n, k, + static_cast(1)); + + blas::SB_Handle& sb_handle = *sb_handle_ptr; + auto q = sb_handle.get_queue(); + + // Matrices + std::vector> a = + blas_benchmark::utils::random_cplx_data(m * k); + std::vector> b = + blas_benchmark::utils::random_cplx_data(k * n); + std::vector> c = + blas_benchmark::utils::const_cplx_data(m * n, 0); + + auto a_gpu = + blas::helper::allocate>(m * k, q); + auto b_gpu = + blas::helper::allocate>(k * n, q); + auto c_gpu = + blas::helper::allocate>(m * n, q); + + auto copy_a = blas::helper::copy_to_device( + q, reinterpret_cast*>(a.data()), a_gpu, + m * k); + auto copy_b = blas::helper::copy_to_device( + q, reinterpret_cast*>(b.data()), b_gpu, + n * k); + auto copy_c = blas::helper::copy_to_device( + q, reinterpret_cast*>(c.data()), c_gpu, + m * n); + + sb_handle.wait({copy_a, copy_b, copy_c}); + + // Kernel expects sycl::complex and not std::complex data + blas::complex_sycl alpha_sycl(alpha); + blas::complex_sycl beta_sycl(beta); + +#ifdef BLAS_VERIFY_BENCHMARK + // Run a first time with a verification of the results + std::vector> c_ref = c; + reference_blas::cgemm(t_a, t_b, m, n, k, + reinterpret_cast(&alpha), + reinterpret_cast(a.data()), lda, + reinterpret_cast(b.data()), ldb, + reinterpret_cast(&beta), + reinterpret_cast(c_ref.data()), ldc); + + std::vector> c_temp = c; + + { + auto c_temp_gpu = + blas::helper::allocate>(m * n, + q); + auto copy_temp = blas::helper::copy_to_device( + q, reinterpret_cast*>(c_temp.data()), + c_temp_gpu, m * n); + sb_handle.wait(copy_temp); + auto gemm_event = _gemm(sb_handle, *t_a, *t_b, m, n, k, alpha_sycl, a_gpu, + lda, b_gpu, ldb, beta_sycl, c_temp_gpu, ldc); + sb_handle.wait(gemm_event); + auto copy_out = blas::helper::copy_to_host( + q, c_temp_gpu, + reinterpret_cast*>(c_temp.data()), m * n); + sb_handle.wait(copy_out); + + blas::helper::deallocate(c_temp_gpu, q); + } + + std::ostringstream err_stream; + if (!utils::compare_vectors(c_temp, c_ref, err_stream, "")) { + const std::string& err_str = err_stream.str(); + state.SkipWithError(err_str.c_str()); + *success = false; + }; +#endif + + auto blas_method_def = [&]() -> std::vector { + auto event = _gemm(sb_handle, *t_a, *t_b, m, n, k, alpha_sycl, a_gpu, lda, + b_gpu, ldb, beta_sycl, c_gpu, ldc); + sb_handle.wait(event); + return event; + }; + + // Warmup + blas_benchmark::utils::warmup(blas_method_def); + sb_handle.wait(); + + blas_benchmark::utils::init_counters(state); + + // Measure + for (auto _ : state) { + // Run + std::tuple times = + blas_benchmark::utils::timef(blas_method_def); + + // Report + blas_benchmark::utils::update_counters(state, times); + } + + state.SetItemsProcessed(state.iterations() * state.counters["n_fl_ops"]); + state.SetBytesProcessed(state.iterations() * + state.counters["bytes_processed"]); + + blas_benchmark::utils::calc_avg_counters(state); + + blas::helper::deallocate(a_gpu, q); + blas::helper::deallocate(b_gpu, q); + blas::helper::deallocate(c_gpu, q); +}; + +/*! @brief Register & run benchmark of complex data types gemm. + * Function is similar to register_benchmark + * + * @tparam scalar_t element data type of underlying complex (float or double) + * @tparam mem_alloc USM or Buffer memory allocation + */ +template +void register_cplx_benchmark(blas::SB_Handle* sb_handle_ptr, bool* success, + std::string mem_type, + std::vector> params) { + for (auto p : params) { + std::string t1s, t2s; + index_t m, n, k; + scalar_t alpha_r, alpha_i, beta_r, beta_i; + + std::tie(t1s, t2s, m, k, n, alpha_r, alpha_i, beta_r, beta_i) = p; + int t1 = static_cast(blas_benchmark::utils::to_transpose_enum(t1s)); + int t2 = static_cast(blas_benchmark::utils::to_transpose_enum(t2s)); + std::complex alpha{alpha_r, alpha_i}; + std::complex beta{beta_r, beta_i}; + + auto BM_lambda = [&](benchmark::State& st, blas::SB_Handle* sb_handle_ptr, + int t1, int t2, index_t m, index_t k, index_t n, + std::complex alpha, + std::complex beta, bool* success) { + run(st, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, + success); + }; + benchmark::RegisterBenchmark( + blas_benchmark::utils::get_name>( + t1s, t2s, m, k, n, mem_type) + .c_str(), + BM_lambda, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, success) + ->UseRealTime(); + } +} + +template +void register_cplx_benchmark(blas_benchmark::Args& args, + blas::SB_Handle* sb_handle_ptr, bool* success) { + auto gemm_params = + blas_benchmark::utils::get_blas3_cplx_params(args); + register_cplx_benchmark( + sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, + gemm_params); +#ifdef SB_ENABLE_USM + register_cplx_benchmark( + sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, gemm_params); +#endif +} + +#endif + namespace blas_benchmark { void create_benchmark(blas_benchmark::Args& args, blas::SB_Handle* sb_handle_ptr, bool* success) { diff --git a/benchmark/portblas/blas3/gemm_batched.cpp b/benchmark/portblas/blas3/gemm_batched.cpp index 959f9eae7..aabd9449a 100644 --- a/benchmark/portblas/blas3/gemm_batched.cpp +++ b/benchmark/portblas/blas3/gemm_batched.cpp @@ -225,8 +225,8 @@ void register_benchmark(blas::SB_Handle* sb_handle_ptr, bool* success, }; benchmark::RegisterBenchmark( blas_benchmark::utils::get_name( - t1s, t2s, m, k, n, batch_size, batch_type, - mem_type).c_str(), + t1s, t2s, m, k, n, batch_size, batch_type, mem_type) + .c_str(), BM_lambda, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, batch_size, batch_type, success) ->UseRealTime(); @@ -239,13 +239,222 @@ void register_benchmark(blas_benchmark::Args& args, auto gemm_batched_params = blas_benchmark::utils::get_gemm_batched_params(args); register_benchmark( - sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, gemm_batched_params); + sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, + gemm_batched_params); #ifdef SB_ENABLE_USM register_benchmark( - sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, gemm_batched_params); + sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, + gemm_batched_params); #endif } +#ifdef BLAS_ENABLE_COMPLEX +template +void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, int t1, + int t2, index_t m, index_t k, index_t n, std::complex alpha, + std::complex beta, index_t batch_size, int batch_type_i, + bool* success) { + // initialize the state label + blas_benchmark::utils::set_benchmark_label>( + state, sb_handle_ptr->get_queue()); + + // Standard test setup. + std::string t1s = blas_benchmark::utils::from_transpose_enum( + static_cast(t1)); + std::string t2s = blas_benchmark::utils::from_transpose_enum( + static_cast(t2)); + const char* t_a = t1s.c_str(); + const char* t_b = t2s.c_str(); + auto batch_type = static_cast(batch_type_i); + + index_t lda = t_a[0] == 'n' ? m : k; + index_t ldb = t_b[0] == 'n' ? k : n; + index_t ldc = m; + + blas_benchmark::utils::init_level_3_cplx_counters< + blas_benchmark::utils::Level3Op::gemm_batched, scalar_t>( + state, beta, m, n, k, batch_size); + + blas::SB_Handle& sb_handle = *sb_handle_ptr; + auto q = sb_handle.get_queue(); + + // Matrices + std::vector> a = + blas_benchmark::utils::random_cplx_data(m * k * batch_size); + std::vector> b = + blas_benchmark::utils::random_cplx_data(k * n * batch_size); + std::vector> c = + blas_benchmark::utils::const_cplx_data(m * n * batch_size, + scalar_t(0)); + +#ifdef BLAS_VERIFY_BENCHMARK + // Run a first time with a verification of the results + std::vector> c_ref = c; + auto _base = [=](index_t dim0, index_t dim1, index_t idx) { + return dim0 * dim1 * idx; + }; + for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { + reference_blas::cgemm( + t_a, t_b, m, n, k, reinterpret_cast(&alpha), + reinterpret_cast(a.data() + _base(m, k, batch_idx)), lda, + reinterpret_cast(b.data() + _base(k, n, batch_idx)), ldb, + reinterpret_cast(&beta), + reinterpret_cast(c_ref.data() + _base(m, n, batch_idx)), ldc); + } + +#endif // BLAS_VERIFY_BENCHMARK + + auto a_gpu = blas::helper::allocate>( + m * k * batch_size, q); + auto b_gpu = blas::helper::allocate>( + k * n * batch_size, q); + auto c_gpu = blas::helper::allocate>( + m * n * batch_size, q); + + auto copy_a = blas::helper::copy_to_device( + q, reinterpret_cast*>(a.data()), a_gpu, + m * k * batch_size); + auto copy_b = blas::helper::copy_to_device( + q, reinterpret_cast*>(b.data()), b_gpu, + n * k * batch_size); + auto copy_c = blas::helper::copy_to_device( + q, reinterpret_cast*>(c.data()), c_gpu, + m * n * batch_size); + + sb_handle.wait({copy_a, copy_b, copy_c}); + + // Kernel expects sycl::complex and not std::complex data + blas::complex_sycl alpha_sycl(alpha); + blas::complex_sycl beta_sycl(beta); + +#ifdef BLAS_VERIFY_BENCHMARK + std::vector> c_temp = c; + { + auto c_temp_gpu = + blas::helper::allocate>( + m * n * batch_size, q); + auto copy_temp = blas::helper::copy_to_device( + q, reinterpret_cast*>(c_temp.data()), + c_temp_gpu, m * n * batch_size); + sb_handle.wait(copy_temp); + auto gemm_batched_event = _gemm_batched( + sb_handle, *t_a, *t_b, m, n, k, alpha_sycl, a_gpu, lda, b_gpu, ldb, + beta_sycl, c_temp_gpu, ldc, batch_size, batch_type); + sb_handle.wait(gemm_batched_event); + auto copy_out = blas::helper::copy_to_host( + q, c_temp_gpu, + reinterpret_cast*>(c_temp.data()), + m * n * batch_size); + sb_handle.wait(copy_out); + + blas::helper::deallocate(c_temp_gpu, q); + } + + std::ostringstream err_stream; + if (!utils::compare_vectors(c_temp, c_ref, err_stream, "")) { + const std::string& err_str = err_stream.str(); + state.SkipWithError(err_str.c_str()); + *success = false; + }; +#endif // BLAS_VERIFY_BENCHMARK + + auto blas_method_def = [&]() -> std::vector { + auto event = _gemm_batched(sb_handle, *t_a, *t_b, m, n, k, alpha_sycl, + a_gpu, lda, b_gpu, ldb, beta_sycl, c_gpu, ldc, + batch_size, batch_type); + sb_handle.wait(event); + return event; + }; + + // Warmup + blas_benchmark::utils::warmup(blas_method_def); + sb_handle.wait(); + + blas_benchmark::utils::init_counters(state); + + // Measure + for (auto _ : state) { + // Run + std::tuple times = + blas_benchmark::utils::timef(blas_method_def); + + // Report + blas_benchmark::utils::update_counters(state, times); + } + + state.SetItemsProcessed(state.iterations() * state.counters["n_fl_ops"]); + state.SetBytesProcessed(state.iterations() * + state.counters["bytes_processed"]); + + blas_benchmark::utils::calc_avg_counters(state); + + blas::helper::deallocate(a_gpu, q); + blas::helper::deallocate(b_gpu, q); + blas::helper::deallocate(c_gpu, q); +}; + +/*! @brief Register & run benchmark of complex data types gemm batched. + * Function is similar to register_benchmark + * + * @tparam scalar_t element data type of underlying complex (float or double) + * @tparam mem_alloc USM or Buffer memory allocation + */ +template +void register_cplx_benchmark( + blas::SB_Handle* sb_handle_ptr, bool* success, std::string mem_type, + std::vector> params) { + for (auto p : params) { + std::string t1s, t2s; + index_t m, n, k, batch_size; + scalar_t alpha_r, alpha_i, beta_r, beta_i; + int batch_type; + std::tie(t1s, t2s, m, k, n, alpha_r, alpha_i, beta_r, beta_i, batch_size, + batch_type) = p; + // Only batch_type == strided is supported with Complex data + if (batch_type == 1) { + std::cerr << "Interleaved memory for gemm_batched operator is not " + "supported whith complex data type\n"; + continue; + } + int t1 = static_cast(blas_benchmark::utils::to_transpose_enum(t1s)); + int t2 = static_cast(blas_benchmark::utils::to_transpose_enum(t2s)); + std::complex alpha{alpha_r, alpha_i}; + std::complex beta{beta_r, beta_i}; + + auto BM_lambda = [&](benchmark::State& st, blas::SB_Handle* sb_handle_ptr, + int t1, int t2, index_t m, index_t k, index_t n, + std::complex alpha, + std::complex beta, index_t batch_size, + int batch_type, bool* success) { + run(st, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, + batch_size, batch_type, success); + }; + benchmark::RegisterBenchmark( + blas_benchmark::utils::get_name>( + t1s, t2s, m, k, n, batch_size, batch_type, mem_type) + .c_str(), + BM_lambda, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, batch_size, + batch_type, success) + ->UseRealTime(); + } +} + +template +void register_cplx_benchmark(blas_benchmark::Args& args, + blas::SB_Handle* sb_handle_ptr, bool* success) { + auto gemm_batched_params = + blas_benchmark::utils::get_gemm_cplx_batched_params(args); + register_cplx_benchmark( + sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, + gemm_batched_params); +#ifdef SB_ENABLE_USM + register_cplx_benchmark( + sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, + gemm_batched_params); +#endif +} +#endif + namespace blas_benchmark { void create_benchmark(blas_benchmark::Args& args, blas::SB_Handle* sb_handle_ptr, bool* success) { diff --git a/benchmark/portblas/blas3/gemm_batched_strided.cpp b/benchmark/portblas/blas3/gemm_batched_strided.cpp index 0fdb29db9..a24a2a188 100644 --- a/benchmark/portblas/blas3/gemm_batched_strided.cpp +++ b/benchmark/portblas/blas3/gemm_batched_strided.cpp @@ -195,7 +195,8 @@ void register_benchmark( benchmark::RegisterBenchmark( blas_benchmark::utils::get_name( t1s, t2s, m, k, n, batch_size, stride_a_mul, stride_b_mul, - stride_c_mul, mem_type).c_str(), + stride_c_mul, mem_type) + .c_str(), BM_lambda, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, batch_size, stride_a_mul, stride_b_mul, stride_c_mul, success) ->UseRealTime(); @@ -208,13 +209,236 @@ void register_benchmark(blas_benchmark::Args& args, auto gemm_batched_strided_params = blas_benchmark::utils::get_gemm_batched_strided_params(args); register_benchmark( - sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, gemm_batched_strided_params); + sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, + gemm_batched_strided_params); #ifdef SB_ENABLE_USM register_benchmark( - sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, gemm_batched_strided_params); + sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, + gemm_batched_strided_params); #endif } +#ifdef BLAS_ENABLE_COMPLEX +template +void run(benchmark::State& state, blas::SB_Handle* sb_handle_ptr, int t1, + int t2, index_t m, index_t k, index_t n, std::complex alpha, + std::complex beta, index_t batch_size, index_t stride_a_mul, + index_t stride_b_mul, index_t stride_c_mul, bool* success) { + // initialize the state label + blas_benchmark::utils::set_benchmark_label>( + state, sb_handle_ptr->get_queue()); + + // Standard test setup. + std::string t1s = blas_benchmark::utils::from_transpose_enum( + static_cast(t1)); + std::string t2s = blas_benchmark::utils::from_transpose_enum( + static_cast(t2)); + const char* t_a = t1s.c_str(); + const char* t_b = t2s.c_str(); + + const bool trA = t_a[0] != 'n'; + const bool trB = t_b[0] != 'n'; + + index_t lda = trA ? k : m; + index_t ldb = trB ? n : k; + index_t ldc = m; + + blas_benchmark::utils::init_level_3_cplx_counters< + blas_benchmark::utils::Level3Op::gemm_batched_strided, scalar_t>( + state, beta, m, n, k, batch_size, stride_a_mul, stride_b_mul, + stride_c_mul); + + blas::SB_Handle& sb_handle = *sb_handle_ptr; + auto q = sb_handle.get_queue(); + + // Data sizes + // Elementary matrices + const index_t a_size = m * k; + const index_t b_size = k * n; + const index_t c_size = m * n; + // Strides + const index_t stride_a = stride_a_mul * a_size; + const index_t stride_b = stride_b_mul * b_size; + const index_t stride_c = stride_c_mul * c_size; + // Batched matrices + const int size_a_batch = a_size + (batch_size - 1) * stride_a; + const int size_b_batch = b_size + (batch_size - 1) * stride_b; + const int size_c_batch = c_size + (batch_size - 1) * stride_c; + + // Matrices + std::vector> a = + blas_benchmark::utils::random_cplx_data(size_a_batch); + std::vector> b = + blas_benchmark::utils::random_cplx_data(size_b_batch); + std::vector> c = + blas_benchmark::utils::const_cplx_data(size_c_batch, + scalar_t(0)); + +#ifdef BLAS_VERIFY_BENCHMARK + // Run a first time with a verification of the results + std::vector> c_ref = c; + for (int batch_idx = 0; batch_idx < batch_size; batch_idx++) { + reference_blas::cgemm( + t_a, t_b, m, n, k, reinterpret_cast(&alpha), + reinterpret_cast(a.data() + batch_idx * stride_a), lda, + reinterpret_cast(b.data() + batch_idx * stride_b), ldb, + reinterpret_cast(&beta), + reinterpret_cast(c_ref.data() + batch_idx * stride_c), ldc); + } + +#endif + + auto a_gpu = blas::helper::allocate>( + size_a_batch, q); + auto b_gpu = blas::helper::allocate>( + size_b_batch, q); + auto c_gpu = blas::helper::allocate>( + size_c_batch, q); + + auto copy_a = blas::helper::copy_to_device( + q, reinterpret_cast*>(a.data()), a_gpu, + size_a_batch); + auto copy_b = blas::helper::copy_to_device( + q, reinterpret_cast*>(b.data()), b_gpu, + size_b_batch); + auto copy_c = blas::helper::copy_to_device( + q, reinterpret_cast*>(c.data()), c_gpu, + size_c_batch); + + sb_handle.wait({copy_a, copy_b, copy_c}); + + // Kernel expects sycl::complex and not std::complex data + blas::complex_sycl alpha_sycl(alpha); + blas::complex_sycl beta_sycl(beta); + +#ifdef BLAS_VERIFY_BENCHMARK + std::vector> c_temp = c; + { + auto c_temp_gpu = + blas::helper::allocate>( + size_c_batch, q); + auto copy_temp = blas::helper::copy_to_device( + q, reinterpret_cast*>(c_temp.data()), + c_temp_gpu, size_c_batch); + sb_handle.wait(copy_temp); + auto gemm_batched_strided_event = _gemm_strided_batched( + sb_handle, *t_a, *t_b, m, n, k, alpha_sycl, a_gpu, lda, stride_a, b_gpu, + ldb, stride_b, beta_sycl, c_temp_gpu, ldc, stride_c, batch_size); + sb_handle.wait(gemm_batched_strided_event); + + auto copy_out = blas::helper::copy_to_host( + q, c_temp_gpu, + reinterpret_cast*>(c_temp.data()), + size_c_batch); + sb_handle.wait(copy_out); + + blas::helper::deallocate(c_temp_gpu, q); + } + + std::ostringstream err_stream; + if (!::utils::compare_vectors_strided(c_temp, c_ref, stride_c, + c_size, err_stream, "")) { + const std::string& err_str = err_stream.str(); + state.SkipWithError(err_str.c_str()); + *success = false; + }; +#endif + + auto blas_method_def = [&]() -> std::vector { + auto event = _gemm_strided_batched( + sb_handle, *t_a, *t_b, m, n, k, alpha_sycl, a_gpu, lda, stride_a, b_gpu, + ldb, stride_b, beta_sycl, c_gpu, ldc, stride_c, batch_size); + sb_handle.wait(event); + return event; + }; + + // Warmup + blas_benchmark::utils::warmup(blas_method_def); + sb_handle.wait(); + + blas_benchmark::utils::init_counters(state); + + // Measure + for (auto _ : state) { + // Run + std::tuple times = + blas_benchmark::utils::timef(blas_method_def); + + // Report + blas_benchmark::utils::update_counters(state, times); + } + + state.SetItemsProcessed(state.iterations() * state.counters["n_fl_ops"]); + state.SetBytesProcessed(state.iterations() * + state.counters["bytes_processed"]); + + blas_benchmark::utils::calc_avg_counters(state); + + blas::helper::deallocate(a_gpu, q); + blas::helper::deallocate(b_gpu, q); + blas::helper::deallocate(c_gpu, q); +}; + +/*! @brief Register & run benchmark of complex data types gemm batched strided. + * Function is similar to register_benchmark + * + * @tparam scalar_t element data type of underlying complex (float or double) + * @tparam mem_alloc USM or Buffer memory allocation + */ +template +void register_cplx_benchmark( + blas::SB_Handle* sb_handle_ptr, bool* success, std::string mem_type, + std::vector> params) { + for (auto p : params) { + std::string t1s, t2s; + index_t m, n, k, batch_size, stride_a_mul, stride_b_mul, stride_c_mul; + scalar_t alpha_r, alpha_i, beta_r, beta_i; + + std::tie(t1s, t2s, m, k, n, alpha_r, alpha_i, beta_r, beta_i, batch_size, + stride_a_mul, stride_b_mul, stride_c_mul) = p; + int t1 = static_cast(blas_benchmark::utils::to_transpose_enum(t1s)); + int t2 = static_cast(blas_benchmark::utils::to_transpose_enum(t2s)); + std::complex alpha{alpha_r, alpha_i}; + std::complex beta{beta_r, beta_i}; + + auto BM_lambda = [&](benchmark::State& st, blas::SB_Handle* sb_handle_ptr, + int t1, int t2, index_t m, index_t k, index_t n, + std::complex alpha, + std::complex beta, index_t batch_size, + index_t stride_a_mul, index_t stride_b_mul, + index_t stride_c_mul, bool* success) { + run(st, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, + batch_size, stride_a_mul, stride_b_mul, + stride_c_mul, success); + }; + benchmark::RegisterBenchmark( + blas_benchmark::utils::get_name>( + t1s, t2s, m, k, n, batch_size, stride_a_mul, stride_b_mul, + stride_c_mul, mem_type) + .c_str(), + BM_lambda, sb_handle_ptr, t1, t2, m, k, n, alpha, beta, batch_size, + stride_a_mul, stride_b_mul, stride_c_mul, success) + ->UseRealTime(); + } +} + +template +void register_cplx_benchmark(blas_benchmark::Args& args, + blas::SB_Handle* sb_handle_ptr, bool* success) { + auto gemm_batched_strided_params = + blas_benchmark::utils::get_gemm_batched_strided_cplx_params( + args); + register_cplx_benchmark( + sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_BUFFER, + gemm_batched_strided_params); +#ifdef SB_ENABLE_USM + register_cplx_benchmark( + sb_handle_ptr, success, blas_benchmark::utils::MEM_TYPE_USM, + gemm_batched_strided_params); +#endif +} +#endif + namespace blas_benchmark { void create_benchmark(blas_benchmark::Args& args, blas::SB_Handle* sb_handle_ptr, bool* success) { diff --git a/benchmark/rocblas/CMakeLists.txt b/benchmark/rocblas/CMakeLists.txt index caa884725..64a559931 100644 --- a/benchmark/rocblas/CMakeLists.txt +++ b/benchmark/rocblas/CMakeLists.txt @@ -77,6 +77,9 @@ set(sources ) +# Operators supporting COMPLEX types benchmarking +set(CPLX_OPS "gemm" "gemm_batched" "gemm_batched_strided") + # Add individual benchmarks for each method foreach(rocblas_benchmark ${sources}) get_filename_component(rocblas_bench_exec ${rocblas_benchmark} NAME_WE) @@ -84,7 +87,11 @@ foreach(rocblas_benchmark ${sources}) target_link_libraries(bench_rocblas_${rocblas_bench_exec} PRIVATE benchmark Clara::Clara roc::rocblas bench_info) target_compile_definitions(bench_rocblas_${rocblas_bench_exec} PRIVATE -DBLAS_INDEX_T=${BLAS_BENCHMARK_INDEX_TYPE}) target_include_directories(bench_rocblas_${rocblas_bench_exec} PRIVATE ${PORTBLAS_INCLUDE} ${rocblas_INCLUDE_DIRS} ${CBLAS_INCLUDE} ${BLAS_BENCH} ${PORTBLAS_COMMON_INCLUDE_DIR}) - + if(${BLAS_ENABLE_COMPLEX}) + if("${rocblas_bench_exec}" IN_LIST CPLX_OPS) + target_compile_definitions(bench_rocblas_${rocblas_bench_exec} PRIVATE BLAS_ENABLE_COMPLEX=1) + endif() + endif() # Even though rocblas does not use sycl, the common tools indirectly include sycl headers. add_sycl_to_target( TARGET bench_rocblas_${rocblas_bench_exec} diff --git a/benchmark/rocblas/blas3/gemm.cpp b/benchmark/rocblas/blas3/gemm.cpp index b403bafec..ca07ba2ba 100644 --- a/benchmark/rocblas/blas3/gemm.cpp +++ b/benchmark/rocblas/blas3/gemm.cpp @@ -38,6 +38,18 @@ static inline void rocblas_gemm_f(args_t&&... args) { return; } +#ifdef BLAS_ENABLE_COMPLEX +template +static inline void rocblas_cplx_gemm_f(args_t&&... args) { + if constexpr (std::is_same_v) { + CHECK_ROCBLAS_STATUS(rocblas_cgemm(std::forward(args)...)); + } else if constexpr (std::is_same_v) { + CHECK_ROCBLAS_STATUS(rocblas_zgemm(std::forward(args)...)); + } + return; +} +#endif + template void run(benchmark::State& state, rocblas_handle& rb_handle, int t_a_i, int t_b_i, index_t m, index_t k, index_t n, scalar_t alpha, @@ -183,6 +195,177 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle, } } +#ifdef BLAS_ENABLE_COMPLEX +template +using rocComplex = + typename std::conditional::type; + +template +void run(benchmark::State& state, rocblas_handle& rb_handle, int t_a_i, + int t_b_i, index_t m, index_t k, index_t n, + std::complex alpha, std::complex beta, + bool* success) { + // initialize the state label + blas_benchmark::utils::set_benchmark_label>(state); + + // Standard test setup. + std::string t_a = blas_benchmark::utils::from_transpose_enum( + static_cast(t_a_i)); + std::string t_b = blas_benchmark::utils::from_transpose_enum( + static_cast(t_b_i)); + const char* t_a_str = t_a.c_str(); + const char* t_b_str = t_b.c_str(); + + index_t lda = t_a_str[0] == 'n' ? m : k; + index_t ldb = t_b_str[0] == 'n' ? k : n; + index_t ldc = m; + + blas_benchmark::utils::init_level_3_cplx_counters< + blas_benchmark::utils::Level3Op::gemm, scalar_t>(state, beta, m, n, k); + + // Matrix options (rocBLAS) + const rocblas_operation trans_a_rb = + t_a_str[0] == 'n' ? rocblas_operation_none : rocblas_operation_transpose; + const rocblas_operation trans_b_rb = + t_b_str[0] == 'n' ? rocblas_operation_none : rocblas_operation_transpose; + + // rocBLAS complex alpha & beta + rocComplex rocBeta{beta.real(), beta.imag()}; + rocComplex rocAlpha{alpha.real(), alpha.imag()}; + + // Data sizes + const int a_size = m * k; + const int b_size = k * n; + const int c_size = m * n; + + // Matrices + std::vector> a = + blas_benchmark::utils::random_cplx_data(a_size); + std::vector> b = + blas_benchmark::utils::random_cplx_data(b_size); + std::vector> c = + blas_benchmark::utils::const_cplx_data(c_size, 0); + + { + // Device memory allocation & H2D copy + blas_benchmark::utils::HIPVector> a_gpu( + a_size, reinterpret_cast*>(a.data())); + blas_benchmark::utils::HIPVector> b_gpu( + b_size, reinterpret_cast*>(b.data())); + blas_benchmark::utils::HIPVector> c_gpu( + c_size, reinterpret_cast*>(c.data())); + +#ifdef BLAS_VERIFY_BENCHMARK + // Reference gemm + std::vector> c_ref = c; + reference_blas::cgemm( + t_a_str, t_b_str, m, n, k, reinterpret_cast(&alpha), + reinterpret_cast(a.data()), lda, + reinterpret_cast(b.data()), ldb, + reinterpret_cast(&beta), + reinterpret_cast(c_ref.data()), ldc); + + // Rocblas verification gemm + std::vector> c_temp = c; + { + blas_benchmark::utils::HIPVector, true> c_temp_gpu( + c_size, reinterpret_cast*>(c_temp.data())); + rocblas_cplx_gemm_f(rb_handle, trans_a_rb, trans_b_rb, m, n, k, + &rocAlpha, a_gpu, lda, b_gpu, ldb, &rocBeta, + c_temp_gpu, ldc); + } + + std::ostringstream err_stream; + if (!utils::compare_vectors(c_temp, c_ref, err_stream, "")) { + const std::string& err_str = err_stream.str(); + state.SkipWithError(err_str.c_str()); + *success = false; + }; +#endif + + auto blas_warmup = [&]() -> void { + rocblas_cplx_gemm_f(rb_handle, trans_a_rb, trans_b_rb, m, n, k, + &rocAlpha, a_gpu, lda, b_gpu, ldb, &rocBeta, + c_gpu, ldc); + return; + }; + + hipEvent_t start, stop; + CHECK_HIP_ERROR(hipEventCreate(&start)); + CHECK_HIP_ERROR(hipEventCreate(&stop)); + + auto blas_method_def = [&]() -> std::vector { + CHECK_HIP_ERROR(hipEventRecord(start, NULL)); + rocblas_cplx_gemm_f(rb_handle, trans_a_rb, trans_b_rb, m, n, k, + &rocAlpha, a_gpu, lda, b_gpu, ldb, &rocBeta, + c_gpu, ldc); + CHECK_HIP_ERROR(hipEventRecord(stop, NULL)); + CHECK_HIP_ERROR(hipEventSynchronize(stop)); + return std::vector{start, stop}; + }; + + // Warmup + blas_benchmark::utils::warmup(blas_warmup); + CHECK_HIP_ERROR(hipStreamSynchronize(NULL)); + + blas_benchmark::utils::init_counters(state); + + // Measure + for (auto _ : state) { + // Run + std::tuple times = + blas_benchmark::utils::timef_hip(blas_method_def); + + // Report + blas_benchmark::utils::update_counters(state, times); + } + + state.SetBytesProcessed(state.iterations() * + state.counters["bytes_processed"]); + state.SetItemsProcessed(state.iterations() * state.counters["n_fl_ops"]); + + blas_benchmark::utils::calc_avg_counters(state); + + CHECK_HIP_ERROR(hipEventDestroy(start)); + CHECK_HIP_ERROR(hipEventDestroy(stop)); + } // release device memory via utils::DeviceVector destructors +}; + +template +void register_cplx_benchmark(blas_benchmark::Args& args, + rocblas_handle& rb_handle, bool* success) { + auto gemm_params = + blas_benchmark::utils::get_blas3_cplx_params(args); + + for (auto p : gemm_params) { + std::string t_a, t_b; + index_t m, n, k; + scalar_t alpha_r, alpha_i, beta_r, beta_i; + + std::tie(t_a, t_b, m, k, n, alpha_r, alpha_i, beta_r, beta_i) = p; + int t_a_i = static_cast(blas_benchmark::utils::to_transpose_enum(t_a)); + int t_b_i = static_cast(blas_benchmark::utils::to_transpose_enum(t_b)); + std::complex alpha{alpha_r, alpha_i}; + std::complex beta{beta_r, beta_i}; + + auto BM_lambda = [&](benchmark::State& st, rocblas_handle rb_handle, + int t1i, int t2i, index_t m, index_t k, index_t n, + std::complex alpha, + std::complex beta, bool* success) { + run(st, rb_handle, t1i, t2i, m, k, n, alpha, beta, success); + }; + benchmark::RegisterBenchmark( + blas_benchmark::utils::get_name>( + t_a, t_b, m, k, n, blas_benchmark::utils::MEM_TYPE_USM) + .c_str(), + BM_lambda, rb_handle, t_a_i, t_b_i, m, k, n, alpha, beta, success) + ->UseRealTime(); + } +} + +#endif + namespace blas_benchmark { void create_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle, bool* success) { diff --git a/benchmark/rocblas/blas3/gemm_batched.cpp b/benchmark/rocblas/blas3/gemm_batched.cpp index 4cfb1418d..40147d5ff 100644 --- a/benchmark/rocblas/blas3/gemm_batched.cpp +++ b/benchmark/rocblas/blas3/gemm_batched.cpp @@ -38,6 +38,18 @@ static inline void rocblas_gemm_batched_f(args_t&&... args) { return; } +#ifdef BLAS_ENABLE_COMPLEX +template +static inline void rocblas_cplx_gemm_batched_f(args_t&&... args) { + if constexpr (std::is_same_v) { + CHECK_ROCBLAS_STATUS(rocblas_cgemm_batched(std::forward(args)...)); + } else if constexpr (std::is_same_v) { + CHECK_ROCBLAS_STATUS(rocblas_zgemm_batched(std::forward(args)...)); + } + return; +} +#endif + template void run(benchmark::State& state, rocblas_handle& rb_handle, index_t t_a_i, index_t t_b_i, index_t m, index_t k, index_t n, scalar_t alpha, @@ -209,6 +221,194 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle, } } +#ifdef BLAS_ENABLE_COMPLEX +template +using rocComplex = + typename std::conditional::type; +template +void run(benchmark::State& state, rocblas_handle& rb_handle, index_t t_a_i, + index_t t_b_i, index_t m, index_t k, index_t n, + std::complex alpha, std::complex beta, + index_t batch_size, int batch_type_i, bool* success) { + // initialize the state label + blas_benchmark::utils::set_benchmark_label>(state); + + // Standard setup + std::string t_a = blas_benchmark::utils::from_transpose_enum( + static_cast(t_a_i)); + std::string t_b = blas_benchmark::utils::from_transpose_enum( + static_cast(t_b_i)); + const char* t_a_str = t_a.c_str(); + const char* t_b_str = t_b.c_str(); + auto batch_type = static_cast(batch_type_i); + + const bool trA = (t_a_str[0] == 'n'); + const bool trB = (t_b_str[0] == 'n'); + + index_t lda = trA ? m : k; + index_t ldb = trB ? k : n; + index_t ldc = m; + + blas_benchmark::utils::init_level_3_cplx_counters< + blas_benchmark::utils::Level3Op::gemm_batched, scalar_t>( + state, beta, m, n, k, batch_size); + + // Matrix options (rocBLAS) + const rocblas_operation trans_a_rb = + trA ? rocblas_operation_none : rocblas_operation_transpose; + const rocblas_operation trans_b_rb = + trB ? rocblas_operation_none : rocblas_operation_transpose; + + // rocBLAS complex alpha & beta + rocComplex rocBeta{beta.real(), beta.imag()}; + rocComplex rocAlpha{alpha.real(), alpha.imag()}; + + // Data sizes + const int a_size = m * k; + const int b_size = k * n; + const int c_size = m * n; + + // Matrices + std::vector> a = + blas_benchmark::utils::random_cplx_data(a_size * batch_size); + std::vector> b = + blas_benchmark::utils::random_cplx_data(b_size * batch_size); + std::vector> c = + blas_benchmark::utils::const_cplx_data(c_size * batch_size, 0); + + { + // Device memory allocation & H2D copy + blas_benchmark::utils::HIPVectorBatched> a_batched_gpu( + a_size, batch_size, reinterpret_cast*>(a.data())); + blas_benchmark::utils::HIPVectorBatched> b_batched_gpu( + b_size, batch_size, reinterpret_cast*>(b.data())); + blas_benchmark::utils::HIPVectorBatched> c_batched_gpu( + c_size, batch_size); + +#ifdef BLAS_VERIFY_BENCHMARK + // Reference batched gemm + std::vector> c_ref = c; + for (int batch = 0; batch < batch_size; batch++) { + reference_blas::cgemm( + t_a_str, t_b_str, m, n, k, reinterpret_cast(&alpha), + reinterpret_cast(a.data() + batch * a_size), lda, + reinterpret_cast(b.data() + batch * b_size), ldb, + reinterpret_cast(&beta), + reinterpret_cast(c_ref.data() + batch * c_size), ldc); + } + + // Rocblas verification gemm_batched + std::vector> c_temp = c; + { + blas_benchmark::utils::HIPVectorBatched, true> + c_temp_gpu(c_size, batch_size, + reinterpret_cast*>(c_temp.data())); + rocblas_cplx_gemm_batched_f( + rb_handle, trans_a_rb, trans_b_rb, m, n, k, &rocAlpha, a_batched_gpu, + lda, b_batched_gpu, ldb, &rocBeta, c_temp_gpu, ldc, batch_size); + } + + std::ostringstream err_stream; + if (!utils::compare_vectors(c_temp, c_ref, err_stream, "")) { + const std::string& err_str = err_stream.str(); + state.SkipWithError(err_str.c_str()); + *success = false; + }; +#endif + + auto blas_warmup = [&]() -> void { + rocblas_cplx_gemm_batched_f( + rb_handle, trans_a_rb, trans_b_rb, m, n, k, &rocAlpha, a_batched_gpu, + lda, b_batched_gpu, ldb, &rocBeta, c_batched_gpu, ldc, batch_size); + return; + }; + + hipEvent_t start, stop; + CHECK_HIP_ERROR(hipEventCreate(&start)); + CHECK_HIP_ERROR(hipEventCreate(&stop)); + + auto blas_method_def = [&]() -> std::vector { + CHECK_HIP_ERROR(hipEventRecord(start, NULL)); + rocblas_cplx_gemm_batched_f( + rb_handle, trans_a_rb, trans_b_rb, m, n, k, &rocAlpha, a_batched_gpu, + lda, b_batched_gpu, ldb, &rocBeta, c_batched_gpu, ldc, batch_size); + CHECK_HIP_ERROR(hipEventRecord(stop, NULL)); + CHECK_HIP_ERROR(hipEventSynchronize(stop)); + return std::vector{start, stop}; + }; + + // Warmup + blas_benchmark::utils::warmup(blas_warmup); + CHECK_HIP_ERROR(hipStreamSynchronize(NULL)); + + blas_benchmark::utils::init_counters(state); + + // Measure + for (auto _ : state) { + // Run + std::tuple times = + blas_benchmark::utils::timef_hip(blas_method_def); + + // Report + blas_benchmark::utils::update_counters(state, times); + } + + state.SetBytesProcessed(state.iterations() * + state.counters["bytes_processed"]); + state.SetItemsProcessed(state.iterations() * state.counters["n_fl_ops"]); + + blas_benchmark::utils::calc_avg_counters(state); + + CHECK_HIP_ERROR(hipEventDestroy(start)); + CHECK_HIP_ERROR(hipEventDestroy(stop)); + } // release device memory via utils::DeviceVector destructors +}; + +template +void register_cplx_benchmark(blas_benchmark::Args& args, + rocblas_handle& rb_handle, bool* success) { + auto gemm_batched_params = + blas_benchmark::utils::get_gemm_cplx_batched_params(args); + + for (auto p : gemm_batched_params) { + std::string t_a, t_b; + index_t m, n, k, batch_size; + scalar_t alpha_r, alpha_i, beta_r, beta_i; + int batch_type; + std::tie(t_a, t_b, m, k, n, alpha_r, alpha_i, beta_r, beta_i, batch_size, + batch_type) = p; + std::complex alpha{alpha_r, alpha_i}; + std::complex beta{beta_r, beta_i}; + if (batch_type == 1) { + std::cerr << "interleaved memory for gemm_batched operator is not " + "supported by rocBLAS\n"; + continue; + } + + int t_a_i = static_cast(blas_benchmark::utils::to_transpose_enum(t_a)); + int t_b_i = static_cast(blas_benchmark::utils::to_transpose_enum(t_b)); + + auto BM_lambda = [&](benchmark::State& st, rocblas_handle rb_handle, + int t_a_i, int t_b_i, index_t m, index_t k, index_t n, + std::complex alpha, + std::complex beta, index_t batch_size, + int batch_type, bool* success) { + run(st, rb_handle, t_a_i, t_b_i, m, k, n, alpha, beta, + batch_size, batch_type, success); + }; + benchmark::RegisterBenchmark( + blas_benchmark::utils::get_name>( + t_a, t_b, m, k, n, batch_size, batch_type, + blas_benchmark::utils::MEM_TYPE_USM) + .c_str(), + BM_lambda, rb_handle, t_a_i, t_b_i, m, k, n, alpha, beta, batch_size, + batch_type, success) + ->UseRealTime(); + } +} +#endif + namespace blas_benchmark { void create_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle, bool* success) { diff --git a/benchmark/rocblas/blas3/gemm_batched_strided.cpp b/benchmark/rocblas/blas3/gemm_batched_strided.cpp index 15dac9896..3ecbff82c 100644 --- a/benchmark/rocblas/blas3/gemm_batched_strided.cpp +++ b/benchmark/rocblas/blas3/gemm_batched_strided.cpp @@ -40,6 +40,20 @@ static inline void rocblas_gemm_strided_batched(args_t&&... args) { return; } +#ifdef BLAS_ENABLE_COMPLEX +template +static inline void rocblas_cplx_gemm_strided_batched(args_t&&... args) { + if constexpr (std::is_same_v) { + CHECK_ROCBLAS_STATUS( + rocblas_cgemm_strided_batched(std::forward(args)...)); + } else if constexpr (std::is_same_v) { + CHECK_ROCBLAS_STATUS( + rocblas_zgemm_strided_batched(std::forward(args)...)); + } + return; +} +#endif + template void run(benchmark::State& state, rocblas_handle& rb_handle, int t_a_i, int t_b_i, index_t m, index_t k, index_t n, scalar_t alpha, @@ -219,6 +233,209 @@ void register_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle, } } +#ifdef BLAS_ENABLE_COMPLEX +template +using rocComplex = + typename std::conditional::type; + +template +void run(benchmark::State& state, rocblas_handle& rb_handle, int t_a_i, + int t_b_i, index_t m, index_t k, index_t n, + std::complex alpha, std::complex beta, + index_t batch_size, index_t stride_a_mul, index_t stride_b_mul, + index_t stride_c_mul, bool* success) { + // initialize the state label + blas_benchmark::utils::set_benchmark_label>(state); + + // Standard test setup. + std::string t_a = blas_benchmark::utils::from_transpose_enum( + static_cast(t_a_i)); + std::string t_b = blas_benchmark::utils::from_transpose_enum( + static_cast(t_b_i)); + const char* t_a_str = t_a.c_str(); + const char* t_b_str = t_b.c_str(); + + const bool trA = (t_a_str[0] == 'n'); + const bool trB = (t_b_str[0] == 'n'); + + index_t lda = trA ? m : k; + index_t ldb = trB ? k : n; + index_t ldc = m; + + blas_benchmark::utils::init_level_3_cplx_counters< + blas_benchmark::utils::Level3Op::gemm_batched_strided, scalar_t>( + state, beta, m, n, k, batch_size, stride_a_mul, stride_b_mul, + stride_c_mul); + + // Matrix options (rocBLAS) + const rocblas_operation trans_a_rb = + trA ? rocblas_operation_none : rocblas_operation_transpose; + const rocblas_operation trans_b_rb = + trB ? rocblas_operation_none : rocblas_operation_transpose; + + // rocBLAS complex alpha & beta + rocComplex rocBeta{beta.real(), beta.imag()}; + rocComplex rocAlpha{alpha.real(), alpha.imag()}; + + // Data sizes + // Elementary matrices + const index_t a_size = m * k; + const index_t b_size = k * n; + const index_t c_size = m * n; + // Strides + const index_t stride_a = stride_a_mul * a_size; + const index_t stride_b = stride_b_mul * b_size; + const index_t stride_c = stride_c_mul * c_size; + // Batched matrices + const int size_a_batch = a_size + (batch_size - 1) * stride_a; + const int size_b_batch = b_size + (batch_size - 1) * stride_b; + const int size_c_batch = c_size + (batch_size - 1) * stride_c; + + // Matrices + std::vector> a = + blas_benchmark::utils::random_cplx_data(size_a_batch); + std::vector> b = + blas_benchmark::utils::random_cplx_data(size_b_batch); + std::vector> c = + blas_benchmark::utils::const_cplx_data(size_c_batch, 0); + + { + // Device memory allocation & H2D copy + blas_benchmark::utils::HIPVectorBatchedStrided> + a_batched_gpu(a_size, batch_size, stride_a, + reinterpret_cast*>(a.data())); + blas_benchmark::utils::HIPVectorBatchedStrided> + b_batched_gpu(b_size, batch_size, stride_b, + reinterpret_cast*>(b.data())); + blas_benchmark::utils::HIPVectorBatchedStrided> + c_batched_gpu(c_size, batch_size, stride_c, + reinterpret_cast*>(c.data())); + +#ifdef BLAS_VERIFY_BENCHMARK + // Reference gemm batched strided (strided loop of gemm) + std::vector> c_ref = c; + for (int batch = 0; batch < batch_size; batch++) { + reference_blas::cgemm( + t_a_str, t_b_str, m, n, k, reinterpret_cast(&alpha), + reinterpret_cast(a.data() + batch * stride_a), lda, + reinterpret_cast(b.data() + batch * stride_b), ldb, + reinterpret_cast(&beta), + reinterpret_cast(c_ref.data() + batch * stride_c), ldc); + } + + // Rocblas verification gemm_batched_strided + std::vector> c_temp = c; + { + blas_benchmark::utils::HIPVectorBatchedStrided, true> + c_temp_gpu(c_size, batch_size, stride_c, + reinterpret_cast*>(c_temp.data())); + rocblas_cplx_gemm_strided_batched( + rb_handle, trans_a_rb, trans_b_rb, m, n, k, &rocAlpha, a_batched_gpu, + lda, stride_a, b_batched_gpu, ldb, stride_b, &rocBeta, c_temp_gpu, + ldc, stride_c, batch_size); + } + + std::ostringstream err_stream; + if (!utils::compare_vectors_strided(c_temp, c_ref, stride_c, c_size, + err_stream, "")) { + const std::string& err_str = err_stream.str(); + state.SkipWithError(err_str.c_str()); + *success = false; + }; +#endif + + auto blas_warmup = [&]() -> void { + rocblas_cplx_gemm_strided_batched( + rb_handle, trans_a_rb, trans_b_rb, m, n, k, &rocAlpha, a_batched_gpu, + lda, stride_a, b_batched_gpu, ldb, stride_b, &rocBeta, c_batched_gpu, + ldc, stride_c, batch_size); + return; + }; + + hipEvent_t start, stop; + CHECK_HIP_ERROR(hipEventCreate(&start)); + CHECK_HIP_ERROR(hipEventCreate(&stop)); + + auto blas_method_def = [&]() -> std::vector { + CHECK_HIP_ERROR(hipEventRecord(start, NULL)); + rocblas_cplx_gemm_strided_batched( + rb_handle, trans_a_rb, trans_b_rb, m, n, k, &rocAlpha, a_batched_gpu, + lda, stride_a, b_batched_gpu, ldb, stride_b, &rocBeta, c_batched_gpu, + ldc, stride_c, batch_size); + CHECK_HIP_ERROR(hipEventRecord(stop, NULL)); + CHECK_HIP_ERROR(hipEventSynchronize(stop)); + return std::vector{start, stop}; + }; + + // Warmup + blas_benchmark::utils::warmup(blas_warmup); + CHECK_HIP_ERROR(hipStreamSynchronize(NULL)); + + blas_benchmark::utils::init_counters(state); + + // Measure + for (auto _ : state) { + // Run + std::tuple times = + blas_benchmark::utils::timef_hip(blas_method_def); + + // Report + blas_benchmark::utils::update_counters(state, times); + } + + state.SetBytesProcessed(state.iterations() * + state.counters["bytes_processed"]); + state.SetItemsProcessed(state.iterations() * state.counters["n_fl_ops"]); + + blas_benchmark::utils::calc_avg_counters(state); + + CHECK_HIP_ERROR(hipEventDestroy(start)); + CHECK_HIP_ERROR(hipEventDestroy(stop)); + } // release device memory via utils::DeviceVector destructors +}; + +template +void register_cplx_benchmark(blas_benchmark::Args& args, + rocblas_handle& rb_handle, bool* success) { + auto gemm_batched_strided_params = + blas_benchmark::utils::get_gemm_batched_strided_cplx_params( + args); + + for (auto p : gemm_batched_strided_params) { + std::string t_a, t_b; + index_t m, n, k, batch_size, stride_a_mul, stride_b_mul, stride_c_mul; + scalar_t alpha_r, alpha_i, beta_r, beta_i; + + std::tie(t_a, t_b, m, k, n, alpha_r, alpha_i, beta_r, beta_i, batch_size, + stride_a_mul, stride_b_mul, stride_c_mul) = p; + int t_a_i = static_cast(blas_benchmark::utils::to_transpose_enum(t_a)); + int t_b_i = static_cast(blas_benchmark::utils::to_transpose_enum(t_b)); + std::complex alpha{alpha_r, alpha_i}; + std::complex beta{beta_r, beta_i}; + + auto BM_lambda = [&](benchmark::State& st, rocblas_handle rb_handle, + int t1i, int t2i, index_t m, index_t k, index_t n, + std::complex alpha, + std::complex beta, index_t batch_size, + index_t strd_a_mul, index_t strd_b_mul, + index_t strd_c_mul, bool* success) { + run(st, rb_handle, t1i, t2i, m, k, n, alpha, beta, batch_size, + strd_a_mul, strd_b_mul, strd_c_mul, success); + }; + benchmark::RegisterBenchmark( + blas_benchmark::utils::get_name>( + t_a, t_b, m, k, n, batch_size, stride_a_mul, stride_b_mul, + stride_c_mul, blas_benchmark::utils::MEM_TYPE_USM) + .c_str(), + BM_lambda, rb_handle, t_a_i, t_b_i, m, k, n, alpha, beta, batch_size, + stride_a_mul, stride_b_mul, stride_c_mul, success) + ->UseRealTime(); + } +} + +#endif + namespace blas_benchmark { void create_benchmark(blas_benchmark::Args& args, rocblas_handle& rb_handle, bool* success) { diff --git a/common/include/common/blas3_state_counters.hpp b/common/include/common/blas3_state_counters.hpp index c7515eb07..68e332773 100644 --- a/common/include/common/blas3_state_counters.hpp +++ b/common/include/common/blas3_state_counters.hpp @@ -76,6 +76,66 @@ init_level_3_counters(benchmark::State& state, scalar_t beta = 0, index_t m = 0, return; } +#ifdef BLAS_ENABLE_COMPLEX +template +inline typename std::enable_if::type +init_level_3_cplx_counters( + benchmark::State& state, + std::complex beta = std::complex(0, 0), index_t m = 0, + index_t n = 0, index_t k = 0, index_t batch_size = 1, + index_t stride_a_mul = 1, index_t stride_b_mul = 1, + index_t stride_c_mul = 1) { + // Google-benchmark counters are double. + double beta_real_d = static_cast(beta.real()); + double beta_imag_d = static_cast(beta.imag()); + double m_d = static_cast(m); + double n_d = static_cast(n); + double k_d = static_cast(k); + double batch_size_d = static_cast(batch_size); + state.counters["beta_real"] = beta_real_d; + state.counters["beta_imag"] = beta_real_d; + state.counters["m"] = m_d; + state.counters["n"] = n_d; + state.counters["k"] = k_d; + state.counters["batch_size"] = batch_size_d; + if constexpr (op == Level3Op::gemm_batched_strided) { + double stride_a_mul_d = static_cast(stride_a_mul); + double stride_b_mul_d = static_cast(stride_b_mul); + double stride_c_mul_d = static_cast(stride_c_mul); + + state.counters["stride_a_mul"] = stride_a_mul_d; + state.counters["stride_b_mul"] = stride_b_mul_d; + state.counters["stride_c_mul"] = stride_c_mul_d; + } + + // Counters here should be reviewed as pure real/imaginary cases result in + // less flops + + bool beta_zero = (beta.real() == scalar_t{0}) && (beta.imag() == scalar_t{0}); + + const double nflops_AtimesB = + k_d * m_d * n_d * 6 + k_d * m_d * n_d * 2; // MulFlops + AddFlops + double nflops_timesAlpha = m_d * n_d * 6; + const double nflops_addBetaC = + beta_zero ? 0 : 6 * m_d * n_d + 2 * m_d * n_d; // MulFlops + AddFlops + const double nflops_tot = + (nflops_AtimesB + nflops_timesAlpha + nflops_addBetaC) * batch_size_d; + state.counters["n_fl_ops"] = nflops_tot; + + const double mem_readA = m_d * k_d; + const double mem_readB = k_d * n_d; + const double mem_writeC = m_d * n_d; + const double mem_readC = beta_zero ? 0 : m_d * n_d; + const double total_mem = (mem_readA + mem_readB + mem_readC + mem_writeC) * + batch_size_d * sizeof(scalar_t) * 2; + state.counters["bytes_processed"] = total_mem; + return; +} + +#endif + template inline typename std::enable_if::type init_level_3_counters(benchmark::State& state, scalar_t beta = 0, index_t m = 0, diff --git a/common/include/common/common_utils.hpp b/common/include/common/common_utils.hpp index a569ed2ff..251ee9b7f 100644 --- a/common/include/common/common_utils.hpp +++ b/common/include/common/common_utils.hpp @@ -53,6 +53,24 @@ using gemm_batched_strided_param_t = std::tuple; +#ifdef BLAS_ENABLE_COMPLEX +template +using blas3_cplx_param_t = + std::tuple; + +template +using gemm_batched_strided_cplx_param_t = + std::tuple; + +template +using gemm_batched_cplx_param_t = + std::tuple; +#endif + using reduction_param_t = std::tuple; template @@ -485,6 +503,157 @@ static inline std::vector> get_blas3_params( } } +#ifdef BLAS_ENABLE_COMPLEX +/** + * @fn get_blas3_cplx_params for complex data type + * @brief Returns a vector containing the blas 3 benchmark cplx parameters, + * either read from a file according to the command-line args, or the default + * ones. So far only used/supported for GEMM & its batched extensions. + */ +template +static inline std::vector> get_blas3_cplx_params( + Args& args) { + if (args.csv_param.empty()) { + warning_no_csv(); + std::vector> blas3_default; + constexpr index_t dmin = 32, dmax = 8192; + std::vector dtranspose = {"n", "t"}; + std::complex alpha{1, 1}; + std::complex beta{1, 1}; + for (std::string& t1 : dtranspose) { + for (std::string& t2 : dtranspose) { + for (index_t m = dmin; m <= dmax; m *= 8) { + for (index_t k = dmin; k <= dmax; k *= 8) { + for (index_t n = dmin; n <= dmax; n *= 8) { + blas3_default.push_back( + std::make_tuple(t1, t2, m, k, n, alpha.real(), alpha.imag(), + beta.real(), beta.imag())); + } + } + } + } + } + return blas3_default; + } else { + return parse_csv_file>( + args.csv_param, [&](std::vector& v) { + if (v.size() != 9) { + throw std::runtime_error( + "invalid number of parameters (9 expected)"); + } + try { + return std::make_tuple( + v[0].c_str(), v[1].c_str(), str_to_int(v[2]), + str_to_int(v[3]), str_to_int(v[4]), + str_to_scalar(v[5]), str_to_scalar(v[6]), + str_to_scalar(v[7]), str_to_scalar(v[8])); + } catch (...) { + throw std::runtime_error("invalid parameter"); + } + }); + } +} + +/** + * @fn get_gemm_batched_strided_cplx_params for complex data type + * @brief Returns a vector containing the gemm_batched_strided cplx benchmark + * parameters, either read from a file according to the command-line args, or + * the default ones. + */ +template +inline std::vector> +get_gemm_batched_strided_cplx_params(Args& args) { + if (args.csv_param.empty()) { + warning_no_csv(); + std::vector> + gemm_batched_strided_default; + constexpr index_t dmin = 128, dmax = 8192; + std::vector dtranspose = {"n", "t"}; + std::complex alpha{1, 1}; + std::complex beta{1, 1}; + index_t batch_size = 8; + for (std::string& t1 : dtranspose) { + for (std::string& t2 : dtranspose) { + for (index_t m = dmin; m <= dmax; m *= 8) { + gemm_batched_strided_default.push_back( + std::make_tuple(t1, t2, m, m, m, alpha.real(), alpha.imag(), + beta.real(), beta.imag(), batch_size, 2, 2, 2)); + } + } + } + return gemm_batched_strided_default; + } else { + return parse_csv_file>( + args.csv_param, [&](std::vector& v) { + if (v.size() != 13) { + throw std::runtime_error( + "invalid number of parameters (13 expected)"); + } + try { + return std::make_tuple( + v[0].c_str(), v[1].c_str(), str_to_int(v[2]), + str_to_int(v[3]), str_to_int(v[4]), + str_to_scalar(v[5]), str_to_scalar(v[6]), + str_to_scalar(v[7]), str_to_scalar(v[8]), + str_to_int(v[9]), str_to_int(v[10]), + str_to_int(v[11]), str_to_int(v[12])); + } catch (...) { + std::throw_with_nested(std::runtime_error("invalid parameter")); + } + }); + } +} + +/** + * @fn get_gemm_cplx_batched_params + * @brief Returns a vector containing the gemm_batched cplx benchmark + * parameters, either read from a file according to the command-line args, or + * the default ones. + */ +template +inline std::vector> +get_gemm_cplx_batched_params(Args& args) { + if (args.csv_param.empty()) { + warning_no_csv(); + std::vector> gemm_batched_default; + constexpr index_t dmin = 128, dmax = 8192; + std::vector dtranspose = {"n", "t"}; + std::complex alpha{1, 1}; + std::complex beta{1, 1}; + index_t batch_size = 8; + int batch_type = 0; + for (std::string& t1 : dtranspose) { + for (std::string& t2 : dtranspose) { + for (index_t n = dmin; n <= dmax; n *= 8) { + gemm_batched_default.push_back(std::make_tuple( + t1, t2, n, n, n, alpha.real(), alpha.imag(), beta.real(), + beta.imag(), batch_size, batch_type)); + } + } + } + return gemm_batched_default; + } else { + return parse_csv_file>( + args.csv_param, [&](std::vector& v) { + if (v.size() != 11) { + throw std::runtime_error( + "invalid number of parameters (11 expected)"); + } + try { + return std::make_tuple( + v[0].c_str(), v[1].c_str(), str_to_int(v[2]), + str_to_int(v[3]), str_to_int(v[4]), + str_to_scalar(v[5]), str_to_scalar(v[6]), + str_to_scalar(v[7]), str_to_scalar(v[8]), + str_to_int(v[9]), str_to_batch_type(v[10])); + } catch (...) { + std::throw_with_nested(std::runtime_error("invalid parameter")); + } + }); + } +} +#endif + /** * @fn get_gemm_batched_params * @brief Returns a vector containing the gemm_batched benchmark parameters, @@ -1334,6 +1503,17 @@ inline std::string get_type_name() { return "double"; } +#ifdef BLAS_ENABLE_COMPLEX +template <> +inline std::string get_type_name>() { + return "complex"; +} +template <> +inline std::string get_type_name>() { + return "complex"; +} +#endif + /** * @fn random_scalar * @brief Generates a random scalar value, using an arbitrary low quality @@ -1372,6 +1552,67 @@ static inline std::vector random_data(size_t size) { return v; } +#ifdef BLAS_ENABLE_COMPLEX +/** + * @fn random_cplx_scalar + * @brief Generates a random complex value, using an arbitrary low quality + * algorithm. + */ +template +static inline std::complex random_cplx_scalar() { + scalar_t rl = 1e-3 * ((rand() % 2000) - 1000); + scalar_t im = 1e-3 * ((rand() % 2000) - 1000); + return std::complex(rl, im); +} + +/** + * @brief Generates a random complex in the specified range of its underlying + * data elements (real & imag) + * @param rangeMin range minimum + * @param rangeMax range maximum + */ +template +static inline std::complex random_cplx_scalar(scalar_t rangeMin, + scalar_t rangeMax) { + static std::random_device rd; + static std::default_random_engine gen(rd()); + std::uniform_real_distribution disRl(rangeMin, rangeMax); + std::uniform_real_distribution disIm(rangeMin, rangeMax); + + return std::complex(disRl(gen), disIm(gen)); +} + +/** + * @fn random_cplx_data + * @brief Generates a random vector of complex values, using a uniform + * distribution of the underlying data elements (real & imag). + */ +template +static inline std::vector> random_cplx_data( + size_t size) { + std::vector> v(size); + + for (std::complex& e : v) { + e = random_cplx_scalar(scalar_t{-2}, scalar_t{5}); + } + return v; +} + +/** + * @fn const_cplx_data + * @brief Generates a vector of constant complex values, of a given length. + */ +template +static inline std::vector> const_cplx_data( + size_t size, scalar_t const_value = 0) { + std::vector> v(size); + std::complex const_cplx_value{const_value, const_value}; + std::fill(v.begin(), v.end(), const_cplx_value); + return v; +} + +#endif // BLAS_ENABLE_COMPLEX + /** * @breif Fills a lower or upper triangular matrix suitable for TRSM testing * @param A The matrix to fill. Size must be at least m * lda @@ -1575,17 +1816,39 @@ static inline void calc_avg_counters(benchmark::State& state) { #define BLAS_REGISTER_BENCHMARK_HALF(args, sb_handle_ptr, success) #endif // BLAS_DATA_TYPE_HALF +#ifdef BLAS_ENABLE_COMPLEX +/** Registers benchmark for the float complex data type + * @see BLAS_REGISTER_BENCHMARK + */ +#define BLAS_REGISTER_BENCHMARK_CPLX_FLOAT(args, sb_handle_ptr, success) \ + register_cplx_benchmark(args, sb_handle_ptr, success) +#else +#define BLAS_REGISTER_BENCHMARK_CPLX_FLOAT(args, sb_handle_ptr, success) +#endif + +#if defined(BLAS_ENABLE_COMPLEX) & defined(BLAS_DATA_TYPE_DOUBLE) +/** Registers benchmark for the double complex data type + * @see BLAS_REGISTER_BENCHMARK + */ +#define BLAS_REGISTER_BENCHMARK_CPLX_DOUBLE(args, sb_handle_ptr, success) \ + register_cplx_benchmark(args, sb_handle_ptr, success) +#else +#define BLAS_REGISTER_BENCHMARK_CPLX_DOUBLE(args, sb_handle_ptr, success) +#endif + /** Registers benchmark for all supported data types. * Expects register_benchmark to exist. * @param args Reference to blas_benchmark::Args * @param sb_handle_ptr Pointer to blas::SB_Handle * @param[out] success Pointer to boolean indicating success */ -#define BLAS_REGISTER_BENCHMARK(args, sb_handle_ptr, success) \ - do { \ - BLAS_REGISTER_BENCHMARK_FLOAT(args, sb_handle_ptr, success); \ - BLAS_REGISTER_BENCHMARK_DOUBLE(args, sb_handle_ptr, success); \ - BLAS_REGISTER_BENCHMARK_HALF(args, sb_handle_ptr, success); \ +#define BLAS_REGISTER_BENCHMARK(args, sb_handle_ptr, success) \ + do { \ + BLAS_REGISTER_BENCHMARK_FLOAT(args, sb_handle_ptr, success); \ + BLAS_REGISTER_BENCHMARK_DOUBLE(args, sb_handle_ptr, success); \ + BLAS_REGISTER_BENCHMARK_HALF(args, sb_handle_ptr, success); \ + BLAS_REGISTER_BENCHMARK_CPLX_FLOAT(args, sb_handle_ptr, success); \ + BLAS_REGISTER_BENCHMARK_CPLX_DOUBLE(args, sb_handle_ptr, success); \ } while (false) #endif diff --git a/common/include/common/set_benchmark_label.hpp b/common/include/common/set_benchmark_label.hpp index b1d4c3ca7..9495a3195 100644 --- a/common/include/common/set_benchmark_label.hpp +++ b/common/include/common/set_benchmark_label.hpp @@ -28,6 +28,10 @@ #include #include +#ifdef BLAS_ENABLE_COMPLEX +#define SYCL_EXT_ONEAPI_COMPLEX +#include +#endif #ifdef BUILD_CUBLAS_BENCHMARKS #include @@ -178,6 +182,20 @@ inline void add_datatype_info( } #endif // BLAS_DATA_TYPE_HALF +#ifdef BLAS_ENABLE_COMPLEX +template <> +inline void add_datatype_info>( + std::map& key_value_map) { + key_value_map["@datatype"] = "complex"; +} + +template <> +inline void add_datatype_info>( + std::map& key_value_map) { + key_value_map["@datatype"] = "complex"; +} +#endif // BLAS_ENABLE_COMPLEX + } // namespace datatype_info inline void set_label(benchmark::State& state,