Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Commit

Permalink
Separated complex gemm load store & addressed PR comments
Browse files Browse the repository at this point in the history
  • Loading branch information
OuadiElfarouki committed Oct 23, 2023
1 parent d86cfc8 commit 2dc363d
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 150 deletions.
2 changes: 1 addition & 1 deletion doc/Gemm.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ The core of the `GEMM` computation is as follows:

## Vectorized Loading/Storing

Many of the `GEMM` kernels support vectorized loads/stores using functions located in `gemm_load_store.hpp` in `src/operations/blas3/` .
Many of the `GEMM` kernels support vectorized loads/stores using functions located in `gemm_load_store.hpp` in `src/operations/blas3/` *(this feature is limited to non-complex data types)*.
These functions are pretty simple but there are some special considerations for how they are used, particularly around whether the matrices are transposed or not.
If a matrix is transposed this changes the data layout such that elements are no longer contiguous in memory.

Expand Down
144 changes: 0 additions & 144 deletions src/operations/blas3/gemm_load_store.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,149 +125,5 @@ struct Packetize {
}
};

#ifdef BLAS_ENABLE_COMPLEX
/*! @brief Single-element stand-in for a vector of sycl::complex values, used
 * as the packet type by Packetize.
 *
 * It serves as a temporary workaround to the upcoming
 * sycl::vec<sycl::complex> container, see
 * github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_complex.asciidoc
 * Only NumElements == 1 is supported; any other size is rejected at compile
 * time by the static_assert below.
 * @tparam DataT Complex type of the vector's data
 * @tparam NumElements Elements count of the vector (only 1 is supported)
 */
template <typename DataT, int NumElements = 1>
class vec_complex {
  static_assert(NumElements == 1,
                "Vector wrapper arround sycl::complex of size>1 unsupported.");
  // The declarations below precede any access specifier and are therefore
  // private to the class.
  using address_t = cl::sycl::access::address_space;
  using decorated_t = cl::sycl::access::decorated;
  using DataType = DataT;
  static constexpr int getNumElements() { return NumElements; }
  size_t size() const noexcept { return NumElements; }

 private:
  // The single wrapped element.
  DataType m_Data;

 public:
  vec_complex() = default;

  constexpr vec_complex(const vec_complex &rhs) = default;
  constexpr vec_complex(vec_complex &&rhs) = default;
  constexpr vec_complex &operator=(const vec_complex &rhs) = default;

  // Implicit construction from one element is intentional: it lets a plain
  // DataT value be used wherever a packet is expected.
  vec_complex(const DataType &rhs_data) : m_Data{rhs_data} {}

  // Conversion operator (valid with NumElements==1)
  operator DataT() const { return m_Data; }

  // Subscript operators. Whatever the index, the single wrapped element is
  // returned; the assert only guards against out-of-range indices.
  DataT &operator[](int i) {
    assert(i < NumElements);
    return (m_Data);
  }
  const DataT &operator[](int i) const {
    assert(i < NumElements);
    return (m_Data);
  }

  // Binary Ops. They do not modify *this, so they are const-qualified and can
  // be applied to const packets as well.
  // Multiply
  vec_complex operator*(const vec_complex &rhs) const {
    return (vec_complex{m_Data * static_cast<DataT>(rhs)});
  }

  vec_complex operator*(const DataType &rhs) const {
    return (vec_complex{m_Data * rhs});
  }

  // Compound Multiply
  vec_complex &operator*=(const DataType &rhs) {
    this->m_Data = this->m_Data * rhs;
    return (*this);
  }

  vec_complex &operator*=(const vec_complex &rhs) {
    this->m_Data = this->m_Data * static_cast<DataT>(rhs);
    return (*this);
  }

  // Add
  vec_complex operator+(const vec_complex &rhs) const {
    return (vec_complex{m_Data + static_cast<DataT>(rhs)});
  }

  vec_complex operator+(const DataType &rhs) const {
    return (vec_complex{m_Data + rhs});
  }

  // Compound Add: accumulates rhs into the wrapped element, matching the
  // vec_complex overload below.
  vec_complex &operator+=(const DataType &rhs) {
    this->m_Data = this->m_Data + rhs;
    return (*this);
  }

  vec_complex &operator+=(const vec_complex &rhs) {
    this->m_Data = this->m_Data + static_cast<DataT>(rhs);
    return (*this);
  }

  // Load: reads the element located at Ptr + Offset * NumElements into this
  // packet.
  template <address_t Space, decorated_t DecorateAddress>
  void load(size_t Offset,
            cl::sycl::multi_ptr<const DataT, Space, DecorateAddress> Ptr) {
    m_Data = *(Ptr + Offset * NumElements);
  }

  // Store: writes the wrapped element to Ptr + Offset * NumElements.
  template <address_t Space, decorated_t DecorateAddress>
  void store(size_t Offset,
             cl::sycl::multi_ptr<DataT, Space, DecorateAddress> Ptr) const {
    *(Ptr + Offset * NumElements) = m_Data;
  }
};

/*! @brief Packetize specialization for sycl::complex data. It exposes static
 * helpers that load and store complex values one element at a time, since the
 * packet size is fixed to 1 for complex types.
 * @tparam vector_size The requested vector size. This specialization always
 * uses packets of a single element regardless of its value.
 * @tparam T The scalar type underlying the complex type complex_sycl<T>.
 * @tparam index_t The index type used for the dimension and leading-dimension
 * template parameters.
 */
template <int vector_size, typename T, typename index_t>
struct Packetize<vector_size, complex_sycl<T>, index_t> {
  // Complex data is not vectorized: a packet always holds exactly one element.
  using value_t = complex_sycl<T>;
  using PacketType = vec_complex<value_t, 1>;
  static constexpr int packet_size = 1;

  /*! @brief Size check; unconditionally true for this specialization.
   * @tparam dimension The dimension being checked (not inspected here). */
  template <index_t dimension>
  static PORTBLAS_INLINE constexpr bool check_size() {
    return true;
  }

  /*! @brief Non-vectorised load of a single sycl::complex element. The same
   * path is taken whether or not the block is internal, because vectorization
   * is not enabled for complex types yet.
   * @tparam trans Whether the source matrix is transposed or not.
   * @tparam internal True if the current block is internal and no bounds
   * checking is required.
   * @tparam ld The leading dimension of the destination memory.
   * @param in_range When true the element pointed to by src is copied,
   * otherwise a complex zero is written.
   * @param src Pointer to the source element.
   * @param dest Pointer to the destination element.
   * @param edge_in_range Edge predicate; not used by this specialization. */
  template <bool trans, bool internal, index_t ld, typename SrcPointerType,
            typename DestPointerType, typename EdgePredicate>
  static PORTBLAS_INLINE void load(const bool in_range, SrcPointerType src,
                                   DestPointerType dest,
                                   EdgePredicate edge_in_range) {
    if (in_range) {
      *dest = *src;
    } else {
      // Out-of-range positions are filled with 0 + 0i.
      *dest = value_t{static_cast<T>(0), static_cast<T>(0)};
    }
  }

  /*! @brief Writes the single element held by a packet to dest. The transpose
   * flag has no effect because only one element is stored.
   * @tparam trans Whether the source matrix is transposed or not.
   * @tparam ld The leading dimension of the destination memory.
   * @param packet The size = 1 packet to store.
   * @param dest Pointer to the destination element. */
  template <bool trans, index_t ld, typename DestPointerType>
  static PORTBLAS_INLINE void store(PacketType &packet, DestPointerType dest) {
    const value_t &element = packet[0];
    *dest = element;
  }
};
#endif

} // namespace blas
#endif // PORTBLAS_BLAS3_GEMM_LOAD_STORE_HPP
174 changes: 174 additions & 0 deletions src/operations/blas3/gemm_load_store_complex.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
/***************************************************************************
* @license
* Copyright (C) Codeplay Software Limited
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* For your convenience, a copy of the License has been included in this
* repository.
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* portBLAS: BLAS implementation using SYCL
*
* @filename gemm_load_store_complex.hpp
*
**************************************************************************/

#ifndef PORTBLAS_BLAS3_GEMM_LOAD_STORE_CPLX_HPP
#define PORTBLAS_BLAS3_GEMM_LOAD_STORE_CPLX_HPP

namespace blas {
#ifdef BLAS_ENABLE_COMPLEX
/*! @brief Single-element stand-in for a vector of sycl::complex values, used
 * as the packet type by Packetize.
 *
 * It serves as a temporary workaround to the upcoming
 * sycl::vec<sycl::complex> container, see
 * github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_complex.asciidoc
 * Only NumElements == 1 is supported; any other size is rejected at compile
 * time by the static_assert below.
 * @tparam DataT Complex type of the vector's data
 * @tparam NumElements Elements count of the vector (only 1 is supported)
 */
template <typename DataT, int NumElements = 1>
class vec_complex {
  static_assert(NumElements == 1,
                "Vector wrapper arround sycl::complex of size>1 unsupported.");
  // The declarations below precede any access specifier and are therefore
  // private to the class.
  using address_t = cl::sycl::access::address_space;
  using decorated_t = cl::sycl::access::decorated;
  using DataType = DataT;
  static constexpr int getNumElements() { return NumElements; }
  size_t size() const noexcept { return NumElements; }

 private:
  // The single wrapped element.
  DataType m_Data;

 public:
  vec_complex() = default;

  constexpr vec_complex(const vec_complex &rhs) = default;
  constexpr vec_complex(vec_complex &&rhs) = default;
  constexpr vec_complex &operator=(const vec_complex &rhs) = default;

  // Implicit construction from one element is intentional: it lets a plain
  // DataT value be used wherever a packet is expected.
  vec_complex(const DataType &rhs_data) : m_Data{rhs_data} {}

  // Conversion operator (valid with NumElements==1)
  operator DataT() const { return m_Data; }

  // Subscript operators. Whatever the index, the single wrapped element is
  // returned; the assert only guards against out-of-range indices.
  DataT &operator[](int i) {
    assert(i < NumElements);
    return (m_Data);
  }
  const DataT &operator[](int i) const {
    assert(i < NumElements);
    return (m_Data);
  }

  // Binary Ops. They do not modify *this, so they are const-qualified and can
  // be applied to const packets as well.
  // Multiply
  vec_complex operator*(const vec_complex &rhs) const {
    return (vec_complex{m_Data * static_cast<DataT>(rhs)});
  }

  vec_complex operator*(const DataType &rhs) const {
    return (vec_complex{m_Data * rhs});
  }

  // Compound Multiply
  vec_complex &operator*=(const DataType &rhs) {
    this->m_Data = this->m_Data * rhs;
    return (*this);
  }

  vec_complex &operator*=(const vec_complex &rhs) {
    this->m_Data = this->m_Data * static_cast<DataT>(rhs);
    return (*this);
  }

  // Add
  vec_complex operator+(const vec_complex &rhs) const {
    return (vec_complex{m_Data + static_cast<DataT>(rhs)});
  }

  vec_complex operator+(const DataType &rhs) const {
    return (vec_complex{m_Data + rhs});
  }

  // Compound Add: accumulates rhs into the wrapped element, matching the
  // vec_complex overload below.
  vec_complex &operator+=(const DataType &rhs) {
    this->m_Data = this->m_Data + rhs;
    return (*this);
  }

  vec_complex &operator+=(const vec_complex &rhs) {
    this->m_Data = this->m_Data + static_cast<DataT>(rhs);
    return (*this);
  }

  // Load: reads the element located at Ptr + Offset * NumElements into this
  // packet.
  template <address_t Space, decorated_t DecorateAddress>
  void load(size_t Offset,
            cl::sycl::multi_ptr<const DataT, Space, DecorateAddress> Ptr) {
    m_Data = *(Ptr + Offset * NumElements);
  }

  // Store: writes the wrapped element to Ptr + Offset * NumElements.
  template <address_t Space, decorated_t DecorateAddress>
  void store(size_t Offset,
             cl::sycl::multi_ptr<DataT, Space, DecorateAddress> Ptr) const {
    *(Ptr + Offset * NumElements) = m_Data;
  }
};

/*! @brief Packetize specialization for sycl::complex data. It exposes static
 * helpers that load and store complex values one element at a time, since the
 * packet size is fixed to 1 for complex types.
 * @tparam vector_size The requested vector size. This specialization always
 * uses packets of a single element regardless of its value.
 * @tparam T The scalar type underlying the complex type complex_sycl<T>.
 * @tparam index_t The index type used for the dimension and leading-dimension
 * template parameters.
 */
template <int vector_size, typename T, typename index_t>
struct Packetize<vector_size, complex_sycl<T>, index_t> {
  // Complex data is not vectorized: a packet always holds exactly one element.
  using value_t = complex_sycl<T>;
  using PacketType = vec_complex<value_t, 1>;
  static constexpr int packet_size = 1;

  /*! @brief Size check; unconditionally true for this specialization.
   * @tparam dimension The dimension being checked (not inspected here). */
  template <index_t dimension>
  static PORTBLAS_INLINE constexpr bool check_size() {
    return true;
  }

  /*! @brief Non-vectorised load of a single sycl::complex element. The same
   * path is taken whether or not the block is internal, because vectorization
   * is not enabled for complex types yet.
   * @tparam trans Whether the source matrix is transposed or not.
   * @tparam internal True if the current block is internal and no bounds
   * checking is required.
   * @tparam ld The leading dimension of the destination memory.
   * @param in_range When true the element pointed to by src is copied,
   * otherwise a complex zero is written.
   * @param src Pointer to the source element.
   * @param dest Pointer to the destination element.
   * @param edge_in_range Edge predicate; not used by this specialization. */
  template <bool trans, bool internal, index_t ld, typename SrcPointerType,
            typename DestPointerType, typename EdgePredicate>
  static PORTBLAS_INLINE void load(const bool in_range, SrcPointerType src,
                                   DestPointerType dest,
                                   EdgePredicate edge_in_range) {
    if (in_range) {
      *dest = *src;
    } else {
      // Out-of-range positions are filled with 0 + 0i.
      *dest = value_t{static_cast<T>(0), static_cast<T>(0)};
    }
  }

  /*! @brief Writes the single element held by a packet to dest. The transpose
   * flag has no effect because only one element is stored.
   * @tparam trans Whether the source matrix is transposed or not.
   * @tparam ld The leading dimension of the destination memory.
   * @param packet The size = 1 packet to store.
   * @param dest Pointer to the destination element. */
  template <bool trans, index_t ld, typename DestPointerType>
  static PORTBLAS_INLINE void store(PacketType &packet, DestPointerType dest) {
    const value_t &element = packet[0];
    *dest = element;
  }
};
#endif
} // namespace blas

#endif // PORTBLAS_BLAS3_GEMM_LOAD_STORE_CPLX_HPP
3 changes: 3 additions & 0 deletions src/operations/blas3/gemm_local.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@

#include "gemm_common.hpp"
#include "gemm_load_store.hpp"
#ifdef BLAS_ENABLE_COMPLEX
#include "gemm_load_store_complex.hpp"
#endif

namespace blas {

Expand Down
3 changes: 3 additions & 0 deletions src/operations/blas3/gemm_no_local_full_vec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@

#include "gemm_common.hpp"
#include "gemm_load_store.hpp"
#ifdef BLAS_ENABLE_COMPLEX
#include "gemm_load_store_complex.hpp"
#endif

namespace blas {

Expand Down
3 changes: 3 additions & 0 deletions src/operations/blas3/gemm_no_local_partial_vec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@

#include "gemm_common.hpp"
#include "gemm_load_store.hpp"
#ifdef BLAS_ENABLE_COMPLEX
#include "gemm_load_store_complex.hpp"
#endif

namespace blas {

Expand Down
10 changes: 5 additions & 5 deletions test/unittest/blas3/blas3_gemm_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,11 @@ inline void verify_gemm(const gemm_cplx_arguments_t<scalar_t> arguments) {
std::tie(alloc, offset, batch, m, n, k, transa, transb, alpha, beta, lda_mul,
ldb_mul, ldc_mul, batch_type) = arguments;

if (batch > 1 && batch_type == gemm_batch_type_t::interleaved) {
// Interleaved batched gemm unsupported with complex data types
GTEST_SKIP();
}

const char ta_str[2] = {transa, '\0'};
const char tb_str[2] = {transb, '\0'};

Expand Down Expand Up @@ -456,11 +461,6 @@ inline void verify_gemm(const gemm_cplx_arguments_t<scalar_t> arguments) {
reinterpret_cast<void*>(c_m_cpu.data() + i * size_c + offset), ldc);
}

if (batch > 1 && batch_type == gemm_batch_type_t::interleaved) {
// Interleaved batched gemm unsupported
GTEST_SKIP();
}

auto m_a_gpu = blas::helper::allocate<mem_alloc, complex_sycl<scalar_t>>(
buffer_size_a, q);
auto m_b_gpu = blas::helper::allocate<mem_alloc, complex_sycl<scalar_t>>(
Expand Down

0 comments on commit 2dc363d

Please sign in to comment.