Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Commit

Permalink
Added static asserts on vector size when using cplx data
Browse files Browse the repository at this point in the history
  • Loading branch information
OuadiElfarouki committed Sep 20, 2023
1 parent 4721c4f commit 369c21a
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 24 deletions.
20 changes: 12 additions & 8 deletions src/operations/blas3/gemm_interleaved.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,11 @@ class Gemm<input_t, output_t, /* DoubleBuffer = */ false, /* NbcA = */ false,
static_assert(item_batchs % VectorSize == 0,
"Item batch must be divisible by vector size");

#ifdef BLAS_ENABLE_COMPLEX
static_assert(!is_complex_sycl<element_t>::value,
"Interleaved GEMM is not supported for Complex Data types");
#endif

input_t a_;
input_t b_;
output_t c_;
Expand All @@ -159,10 +164,9 @@ class Gemm<input_t, output_t, /* DoubleBuffer = */ false, /* NbcA = */ false,
const index_t ldc_;
const index_t batch_size_;
PORTBLAS_INLINE Gemm(input_t A, input_t B, output_t C, element_t alpha,
element_t beta, index_t batch_size,
index_t /*unused stride_a*/,
index_t /*unused stride_b*/,
index_t /*unused stride_c*/)
element_t beta, index_t batch_size,
index_t /*unused stride_a*/, index_t /*unused stride_b*/,
index_t /*unused stride_c*/)
: a_(A),
b_(B),
c_(C),
Expand Down Expand Up @@ -280,9 +284,9 @@ class Gemm<input_t, output_t, /* DoubleBuffer = */ false, /* NbcA = */ false,
template <bool need_check_boundary, typename check_t, typename in_ptr_t,
typename out_ptr_t>
PORTBLAS_INLINE void compute_panel(check_t boundary_check, index_t m_stride,
index_t n_stride, index_t mb_start,
index_t m_start, index_t n_start,
in_ptr_t A, in_ptr_t B, out_ptr_t C) {
index_t n_stride, index_t mb_start,
index_t m_start, index_t n_start,
in_ptr_t A, in_ptr_t B, out_ptr_t C) {
packet_type reg_a[item_rows * item_batchs / VectorSize];
packet_type reg_b[item_cols * item_batchs / VectorSize];
packet_type reg_res[item_rows * item_cols * item_batchs / VectorSize];
Expand Down Expand Up @@ -482,7 +486,7 @@ class Gemm<input_t, output_t, /* DoubleBuffer = */ false, /* NbcA = */ false,
* @param reg_res 2D register array used to store the result C
*/
PORTBLAS_INLINE void compute_block(packet_type *reg_a, packet_type *reg_b,
packet_type *reg_res) noexcept {
packet_type *reg_res) noexcept {
#pragma unroll
for (int i = 0; i < item_cols; ++i) {
#pragma unroll
Expand Down
26 changes: 16 additions & 10 deletions src/operations/blas3/gemm_local.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,12 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
static_assert(cl_elems % packetize_t::packet_size == 0,
"Cache line size must be a multiple of packet_size");

#ifdef BLAS_ENABLE_COMPLEX
static_assert((VectorSize == 1 && is_complex_sycl<element_t>::value) ||
is_sycl_scalar<element_t>::value,
"Vector size should be equal to 1 for Complex Data types");
#endif

//! @brief leading dimension of block of A in local
static constexpr index_t ldsa = block_rows + nbc_a;
//! @brief leading dimension of block of B in local
Expand All @@ -166,8 +172,8 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
index_t stridec_;

PORTBLAS_INLINE Gemm(input_t A, input_t B, output_t C, element_t alpha,
element_t beta, index_t batch_size, index_t stride_a,
index_t stride_b, index_t stride_c)
element_t beta, index_t batch_size, index_t stride_a,
index_t stride_b, index_t stride_c)
: a_(A),
b_(B),
c_(C),
Expand Down Expand Up @@ -252,7 +258,7 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
*/
template <typename local_memory_t>
PORTBLAS_INLINE void eval(local_memory_t scratch_acc,
const cl::sycl::nd_item<1> &id) noexcept {
const cl::sycl::nd_item<1> &id) noexcept {
index_t m = a_.get_size_row();
index_t n = b_.get_size_col();
const index_t k = a_.get_size_col();
Expand Down Expand Up @@ -552,9 +558,9 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,

template <bool check_m_limit, bool check_n_limit, typename OutputPointerType>
PORTBLAS_INLINE void store_output_block(index_t, index_t mc, index_t nc,
OutputPointerType C, index_t ldc,
element_t *reg_res,
const bool out_of_range) noexcept {
OutputPointerType C, index_t ldc,
element_t *reg_res,
const bool out_of_range) noexcept {
if (out_of_range) {
return;
}
Expand Down Expand Up @@ -734,9 +740,9 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
*/
template <bool check_m_limit, bool check_n_limit, typename InputPointerType>
PORTBLAS_INLINE void compute_block_gemm(index_t, InputPointerType B,
InputPointerType A, element_t *reg_a,
element_t &reg_b,
element_t *reg_res) noexcept {
InputPointerType A, element_t *reg_a,
element_t &reg_b,
element_t *reg_res) noexcept {
// NOTE: Adding "#pragma unroll" here reduces performance on AMD R9
// Nano.
// Seems that the small reduction of arithmetic operations does
Expand Down Expand Up @@ -786,7 +792,7 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, TileType,
template <bool db, index_t o, index_t... os, typename P, typename... Ps>
static PORTBLAS_INLINE typename std::enable_if<db>::type sync_smem(
const cl::sycl::nd_item<1> &id, index_t &ofs_sign, P &s,
Ps &... ss) noexcept {
Ps &...ss) noexcept {
s += ofs_sign * o;
sync_smem<db, os...>(id, ofs_sign, ss...);
}
Expand Down
10 changes: 8 additions & 2 deletions src/operations/blas3/gemm_no_local_full_vec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,12 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, tile_type,
packetize_t::template check_size<item_cols>(),
"If vectorization is enabled item_cols must equal the packet_size");

#ifdef BLAS_ENABLE_COMPLEX
static_assert((VectorSize == 1 && is_complex_sycl<element_t>::value) ||
is_sycl_scalar<element_t>::value,
"Vector size should be equal to 1 for Complex Data types");
#endif

input_t a_;
input_t b_;
output_t c_;
Expand All @@ -115,8 +121,8 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, tile_type,
index_t stridec_;

PORTBLAS_INLINE Gemm(input_t A, input_t B, output_t C, element_t alpha,
element_t beta, index_t batch_size, index_t stride_a,
index_t stride_b, index_t stride_c)
element_t beta, index_t batch_size, index_t stride_a,
index_t stride_b, index_t stride_c)
: a_(A),
b_(B),
c_(C),
Expand Down
14 changes: 10 additions & 4 deletions src/operations/blas3/gemm_no_local_partial_vec.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,12 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, tile_type,
static_assert(item_cols % packetize_t::packet_size == 0,
"Item cols must be a multiple of the vector packet size");

#ifdef BLAS_ENABLE_COMPLEX
static_assert((VectorSize == 1 && is_complex_sycl<element_t>::value) ||
is_sycl_scalar<element_t>::value,
"Vector size should be equal to 1 for Complex Data types");
#endif

input_t a_;
input_t b_;
output_t c_;
Expand All @@ -111,8 +117,8 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, tile_type,
index_t stridec_;

PORTBLAS_INLINE Gemm(input_t A, input_t B, output_t C, element_t alpha,
element_t beta, index_t batch_size, index_t stride_a,
index_t stride_b, index_t stride_c)
element_t beta, index_t batch_size, index_t stride_a,
index_t stride_b, index_t stride_c)
: a_(A),
b_(B),
c_(C),
Expand Down Expand Up @@ -446,8 +452,8 @@ class Gemm<input_t, output_t, DoubleBuffer, NbcA, NbcB, ClSize, tile_type,
index_t work_per_load, typename PointerType,
typename check_boundary>
PORTBLAS_INLINE void load(PointerType ptr, element_t *reg, const index_t &ld,
index_t index, const check_boundary &chk_boundary,
const bool out_of_range) noexcept {
index_t index, const check_boundary &chk_boundary,
const bool out_of_range) noexcept {
if (out_of_range) {
return;
}
Expand Down

0 comments on commit 369c21a

Please sign in to comment.