From 2dc363db653413af39acd416f09d57c101de3004 Mon Sep 17 00:00:00 2001 From: Ouadie EL FAROUKI Date: Fri, 13 Oct 2023 17:06:21 +0100 Subject: [PATCH] Separated complex gemm load store & addressed PR comments --- doc/Gemm.md | 2 +- src/operations/blas3/gemm_load_store.hpp | 144 --------------- .../blas3/gemm_load_store_complex.hpp | 174 ++++++++++++++++++ src/operations/blas3/gemm_local.hpp | 3 + .../blas3/gemm_no_local_full_vec.hpp | 3 + .../blas3/gemm_no_local_partial_vec.hpp | 3 + test/unittest/blas3/blas3_gemm_common.hpp | 10 +- 7 files changed, 189 insertions(+), 150 deletions(-) create mode 100644 src/operations/blas3/gemm_load_store_complex.hpp diff --git a/doc/Gemm.md b/doc/Gemm.md index 0264e3d4c..653549212 100644 --- a/doc/Gemm.md +++ b/doc/Gemm.md @@ -100,7 +100,7 @@ The core of the `GEMM` computation is as follows: ## Vectorized Loading/Storing -Many of the `GEMM` kernels support vectorized loads/stores using functions located in `gemm_load_store.hpp` in `src/operations/blas3/` . +Many of the `GEMM` kernels support vectorized loads/stores using functions located in `gemm_load_store.hpp` in `src/operations/blas3/`*(this feature is limited to non-complex data types)*. These functions are pretty simple but there are some special considerations for how they are used, particularly around whether the matrices are transposed or not. If a matrix is transposed this changes the data layout such that elements are no longer contiguous in memory. diff --git a/src/operations/blas3/gemm_load_store.hpp b/src/operations/blas3/gemm_load_store.hpp index 7ae45ce5d..ef44cbfe6 100644 --- a/src/operations/blas3/gemm_load_store.hpp +++ b/src/operations/blas3/gemm_load_store.hpp @@ -125,149 +125,5 @@ struct Packetize { } }; -#ifdef BLAS_ENABLE_COMPLEX -/*! @brief vec_complex is an intermediate wrapper of sycl::complex used in - * Packetize. It serves as a temporary workaround to the upcoming - * sycl::vec container - * github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_complex.asciidoc - * and only supports size = 1. - * @tparam DataT Complex type of the vector's data - * @tparam NumElements Elements count of the vector (only 1 is supported) - */ -template -class vec_complex { - static_assert(NumElements == 1, - "Vector wrapper arround sycl::complex of size>1 unsupported."); - using address_t = cl::sycl::access::address_space; - using decorated_t = cl::sycl::access::decorated; - using DataType = DataT; - static constexpr int getNumElements() { return NumElements; } - size_t size() const noexcept { return NumElements; } - - private: - DataType m_Data; - - public: - vec_complex() = default; - - constexpr vec_complex(const vec_complex &rhs) = default; - constexpr vec_complex(vec_complex &&rhs) = default; - constexpr vec_complex &operator=(const vec_complex &rhs) = default; - - vec_complex(const DataType &rhs_data) : m_Data{rhs_data} {} - - // Conversion operator (valid with NumElements==1) - operator DataT() const { return m_Data; } - - // Subscript operators - DataT &operator[](int i) { - assert(i < NumElements); - return (m_Data); - } - const DataT &operator[](int i) const { - assert(i < NumElements); - return (m_Data); - } - - // Binary Ops - // Multiply - vec_complex operator*(const vec_complex &rhs) { - return (vec_complex{m_Data * static_cast(rhs)}); - } - - vec_complex operator*(const DataType &rhs) { - return (vec_complex{m_Data * rhs}); - } - - // Compound Multiply - vec_complex &operator*=(const DataType &rhs) { - this->m_Data = this->m_Data * rhs; - return (*this); - } - - vec_complex &operator*=(const vec_complex &rhs) { - this->m_Data = this->m_Data * static_cast(rhs); - return (*this); - } - - // Add - vec_complex operator+(const vec_complex &rhs) { - return (vec_complex{m_Data + static_cast(rhs)}); - } - - vec_complex operator+(const DataType &rhs) { - return (vec_complex{m_Data + rhs}); - } - - // Compound Add - vec_complex &operator+=(const DataType &rhs) { - this->m_Data = this->m_Data * rhs; - return (*this); - } - - vec_complex &operator+=(const vec_complex &rhs) { - this->m_Data = this->m_Data + static_cast(rhs); - return (*this); - } - - // Load - template - void load(size_t Offset, - cl::sycl::multi_ptr Ptr) { - m_Data = *(Ptr + Offset * NumElements); - } - - // Store - template - void store(size_t Offset, - cl::sycl::multi_ptr Ptr) const { - *(Ptr + Offset * NumElements) = m_Data; - } -}; - -/*! @brief Partial specialization of the Packetize class dedicated to -sycl::complex types. It contains static methods for loading and storing size=1 -complex packets from/to memory. -* @tparam vector_size The desired vector size to be used. Only size = 1 is -supported so far. -* @tparam value_t The complex type of the matrix data. -*/ -template -struct Packetize, index_t> { - // Vectorization is not enabled for complex, always set to 1 - using value_t = complex_sycl; - using PacketType = vec_complex; - static constexpr int packet_size = 1; - template - static PORTBLAS_INLINE constexpr bool check_size() { - return true; - } - - /*! @brief Performs a non-vectorised load of sycl::complex data element while - * whether block is internal or not since vectorization is not enabled for - * complex types yet. - * @tparam trans Whether the source matrix is transposed or not. - * @tparam internal True if the current block is internal and no bounds - * checking is required. - * @tparam ld The leading dimension of the destination memory. */ - template - static PORTBLAS_INLINE void load(const bool in_range, SrcPointerType src, - DestPointerType dest, - EdgePredicate edge_in_range) { - *(dest) = in_range ? *(src) : value_t{(T)0, (T)0}; - } - - /*! @brief Store a size = 1 vector packet of sycl::complex data into local - * memory (whether source is transposed or not since it's only 1 element). - * @tparam trans Whether the source matrix is transposed or not. - * @tparam ld The leading dimension of the destination memory.*/ - template - static PORTBLAS_INLINE void store(PacketType &packet, DestPointerType dest) { - *dest = packet[0]; - } -}; -#endif - } // namespace blas #endif // PORTBLAS_BLAS3_GEMM_LOAD_STORE_HPP diff --git a/src/operations/blas3/gemm_load_store_complex.hpp b/src/operations/blas3/gemm_load_store_complex.hpp new file mode 100644 index 000000000..7b1eb769b --- /dev/null +++ b/src/operations/blas3/gemm_load_store_complex.hpp @@ -0,0 +1,174 @@ +/*************************************************************************** + * @license + * Copyright (C) Codeplay Software Limited + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * For your convenience, a copy of the License has been included in this + * repository. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * portBLAS: BLAS implementation using SYCL + * + * @filename gemm_load_store_complex.hpp + * + **************************************************************************/ + +#ifndef PORTBLAS_BLAS3_GEMM_LOAD_STORE_CPLX_HPP +#define PORTBLAS_BLAS3_GEMM_LOAD_STORE_CPLX_HPP + +namespace blas { +#ifdef BLAS_ENABLE_COMPLEX +/*! @brief vec_complex is an intermediate wrapper of sycl::complex used in + * Packetize. It serves as a temporary workaround to the upcoming + * sycl::vec container + * github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_complex.asciidoc + * and only supports size = 1. + * @tparam DataT Complex type of the vector's data + * @tparam NumElements Elements count of the vector (only 1 is supported) + */ +template +class vec_complex { + static_assert(NumElements == 1, + "Vector wrapper arround sycl::complex of size>1 unsupported."); + using address_t = cl::sycl::access::address_space; + using decorated_t = cl::sycl::access::decorated; + using DataType = DataT; + static constexpr int getNumElements() { return NumElements; } + size_t size() const noexcept { return NumElements; } + + private: + DataType m_Data; + + public: + vec_complex() = default; + + constexpr vec_complex(const vec_complex &rhs) = default; + constexpr vec_complex(vec_complex &&rhs) = default; + constexpr vec_complex &operator=(const vec_complex &rhs) = default; + + vec_complex(const DataType &rhs_data) : m_Data{rhs_data} {} + + // Conversion operator (valid with NumElements==1) + operator DataT() const { return m_Data; } + + // Subscript operators + DataT &operator[](int i) { + assert(i < NumElements); + return (m_Data); + } + const DataT &operator[](int i) const { + assert(i < NumElements); + return (m_Data); + } + + // Binary Ops + // Multiply + vec_complex operator*(const vec_complex &rhs) { + return (vec_complex{m_Data * static_cast(rhs)}); + } + + vec_complex operator*(const DataType &rhs) { + return (vec_complex{m_Data * rhs}); + } + + // Compound Multiply + vec_complex &operator*=(const DataType &rhs) { + this->m_Data = this->m_Data * rhs; + return (*this); + } + + vec_complex &operator*=(const vec_complex &rhs) { + this->m_Data = this->m_Data * static_cast(rhs); + return (*this); + } + + // Add + vec_complex operator+(const vec_complex &rhs) { + return (vec_complex{m_Data + static_cast(rhs)}); + } + + vec_complex operator+(const DataType &rhs) { + return (vec_complex{m_Data + rhs}); + } + + // Compound Add + vec_complex &operator+=(const DataType &rhs) { + this->m_Data = this->m_Data * rhs; + return (*this); + } + + vec_complex &operator+=(const vec_complex &rhs) { + this->m_Data = this->m_Data + static_cast(rhs); + return (*this); + } + + // Load + template + void load(size_t Offset, + cl::sycl::multi_ptr Ptr) { + m_Data = *(Ptr + Offset * NumElements); + } + + // Store + template + void store(size_t Offset, + cl::sycl::multi_ptr Ptr) const { + *(Ptr + Offset * NumElements) = m_Data; + } +}; + +/*! @brief Partial specialization of the Packetize class dedicated to +sycl::complex types. It contains static methods for loading and storing size=1 +complex packets from/to memory. +* @tparam vector_size The desired vector size to be used. Only size = 1 is +supported so far. +* @tparam value_t The complex type of the matrix data. +*/ +template +struct Packetize, index_t> { + // Vectorization is not enabled for complex, always set to 1 + using value_t = complex_sycl; + using PacketType = vec_complex; + static constexpr int packet_size = 1; + template + static PORTBLAS_INLINE constexpr bool check_size() { + return true; + } + + /*! @brief Performs a non-vectorised load of sycl::complex data element while + * whether block is internal or not since vectorization is not enabled for + * complex types yet. + * @tparam trans Whether the source matrix is transposed or not. + * @tparam internal True if the current block is internal and no bounds + * checking is required. + * @tparam ld The leading dimension of the destination memory. */ + template + static PORTBLAS_INLINE void load(const bool in_range, SrcPointerType src, + DestPointerType dest, + EdgePredicate edge_in_range) { + *(dest) = in_range ? *(src) : value_t{(T)0, (T)0}; + } + + /*! @brief Store a size = 1 vector packet of sycl::complex data into local + * memory (whether source is transposed or not since it's only 1 element). + * @tparam trans Whether the source matrix is transposed or not. + * @tparam ld The leading dimension of the destination memory.*/ + template + static PORTBLAS_INLINE void store(PacketType &packet, DestPointerType dest) { + *dest = packet[0]; + } +}; +#endif +} // namespace blas + +#endif // PORTBLAS_BLAS3_GEMM_LOAD_STORE_CPLX_HPP diff --git a/src/operations/blas3/gemm_local.hpp b/src/operations/blas3/gemm_local.hpp index 870349c48..0ca182918 100644 --- a/src/operations/blas3/gemm_local.hpp +++ b/src/operations/blas3/gemm_local.hpp @@ -27,6 +27,9 @@ #include "gemm_common.hpp" #include "gemm_load_store.hpp" +#ifdef BLAS_ENABLE_COMPLEX +#include "gemm_load_store_complex.hpp" +#endif namespace blas { diff --git a/src/operations/blas3/gemm_no_local_full_vec.hpp b/src/operations/blas3/gemm_no_local_full_vec.hpp index df1ce6bd7..77cbafbbf 100644 --- a/src/operations/blas3/gemm_no_local_full_vec.hpp +++ b/src/operations/blas3/gemm_no_local_full_vec.hpp @@ -27,6 +27,9 @@ #include "gemm_common.hpp" #include "gemm_load_store.hpp" +#ifdef BLAS_ENABLE_COMPLEX +#include "gemm_load_store_complex.hpp" +#endif namespace blas { diff --git a/src/operations/blas3/gemm_no_local_partial_vec.hpp b/src/operations/blas3/gemm_no_local_partial_vec.hpp index 02a42e938..ba26ef67f 100644 --- a/src/operations/blas3/gemm_no_local_partial_vec.hpp +++ b/src/operations/blas3/gemm_no_local_partial_vec.hpp @@ -27,6 +27,9 @@ #include "gemm_common.hpp" #include "gemm_load_store.hpp" +#ifdef BLAS_ENABLE_COMPLEX +#include "gemm_load_store_complex.hpp" +#endif namespace blas { diff --git a/test/unittest/blas3/blas3_gemm_common.hpp b/test/unittest/blas3/blas3_gemm_common.hpp index 3aacf4244..b9bd04e04 100644 --- a/test/unittest/blas3/blas3_gemm_common.hpp +++ b/test/unittest/blas3/blas3_gemm_common.hpp @@ -419,6 +419,11 @@ inline void verify_gemm(const gemm_cplx_arguments_t arguments) { std::tie(alloc, offset, batch, m, n, k, transa, transb, alpha, beta, lda_mul, ldb_mul, ldc_mul, batch_type) = arguments; + if (batch > 1 && batch_type == gemm_batch_type_t::interleaved) { + // Interleaved batched gemm unsupported with complex data types + GTEST_SKIP(); + } + const char ta_str[2] = {transa, '\0'}; const char tb_str[2] = {transb, '\0'}; @@ -456,11 +461,6 @@ inline void verify_gemm(const gemm_cplx_arguments_t arguments) { reinterpret_cast(c_m_cpu.data() + i * size_c + offset), ldc); } - if (batch > 1 && batch_type == gemm_batch_type_t::interleaved) { - // Interleaved batched gemm unsupported - GTEST_SKIP(); - } - auto m_a_gpu = blas::helper::allocate>( buffer_size_a, q); auto m_b_gpu = blas::helper::allocate>(