diff --git a/src/operations/blas3/gemm_interleaved.hpp b/src/operations/blas3/gemm_interleaved.hpp index 551bb465a..66629033e 100644 --- a/src/operations/blas3/gemm_interleaved.hpp +++ b/src/operations/blas3/gemm_interleaved.hpp @@ -146,6 +146,11 @@ class Gemm::value, + "Interleaved GEMM is not supported for Complex Data types"); +#endif + input_t a_; input_t b_; output_t c_; @@ -159,10 +164,9 @@ class Gemm PORTBLAS_INLINE void compute_panel(check_t boundary_check, index_t m_stride, - index_t n_stride, index_t mb_start, - index_t m_start, index_t n_start, - in_ptr_t A, in_ptr_t B, out_ptr_t C) { + index_t n_stride, index_t mb_start, + index_t m_start, index_t n_start, + in_ptr_t A, in_ptr_t B, out_ptr_t C) { packet_type reg_a[item_rows * item_batchs / VectorSize]; packet_type reg_b[item_cols * item_batchs / VectorSize]; packet_type reg_res[item_rows * item_cols * item_batchs / VectorSize]; @@ -482,7 +486,7 @@ class Gemm::value) || + is_sycl_scalar::value, + "Vector size should be equal to 1 for Complex Data types"); +#endif + //! @brief leading dimension of block of A in local static constexpr index_t ldsa = block_rows + nbc_a; //! @brief leading dimension of block of B in local @@ -166,8 +172,8 @@ class Gemm PORTBLAS_INLINE void eval(local_memory_t scratch_acc, - const cl::sycl::nd_item<1> &id) noexcept { + const cl::sycl::nd_item<1> &id) noexcept { index_t m = a_.get_size_row(); index_t n = b_.get_size_col(); const index_t k = a_.get_size_col(); @@ -552,9 +558,9 @@ class Gemm PORTBLAS_INLINE void store_output_block(index_t, index_t mc, index_t nc, - OutputPointerType C, index_t ldc, - element_t *reg_res, - const bool out_of_range) noexcept { + OutputPointerType C, index_t ldc, + element_t *reg_res, + const bool out_of_range) noexcept { if (out_of_range) { return; } @@ -734,9 +740,9 @@ class Gemm PORTBLAS_INLINE void compute_block_gemm(index_t, InputPointerType B, - InputPointerType A, element_t *reg_a, - element_t ®_b, - element_t *reg_res) noexcept { + InputPointerType A, element_t *reg_a, + element_t ®_b, + element_t *reg_res) noexcept { // NOTE: Adding "#pragma unroll" here reduces performance on AMD R9 // Nano. // Seems that the small reduction of arithmetic operations does @@ -786,7 +792,7 @@ class Gemm static PORTBLAS_INLINE typename std::enable_if::type sync_smem( const cl::sycl::nd_item<1> &id, index_t &ofs_sign, P &s, - Ps &... ss) noexcept { + Ps &...ss) noexcept { s += ofs_sign * o; sync_smem(id, ofs_sign, ss...); } diff --git a/src/operations/blas3/gemm_no_local_full_vec.hpp b/src/operations/blas3/gemm_no_local_full_vec.hpp index 7e686ca04..df1ce6bd7 100644 --- a/src/operations/blas3/gemm_no_local_full_vec.hpp +++ b/src/operations/blas3/gemm_no_local_full_vec.hpp @@ -104,6 +104,12 @@ class Gemm(), "If vectorization is enabled item_cols must equal the packet_size"); +#ifdef BLAS_ENABLE_COMPLEX + static_assert((VectorSize == 1 && is_complex_sycl::value) || + is_sycl_scalar::value, + "Vector size should be equal to 1 for Complex Data types"); +#endif + input_t a_; input_t b_; output_t c_; @@ -115,8 +121,8 @@ class Gemm::value) || + is_sycl_scalar::value, + "Vector size should be equal to 1 for Complex Data types"); +#endif + input_t a_; input_t b_; output_t c_; @@ -111,8 +117,8 @@ class Gemm PORTBLAS_INLINE void load(PointerType ptr, element_t *reg, const index_t &ld, - index_t index, const check_boundary &chk_boundary, - const bool out_of_range) noexcept { + index_t index, const check_boundary &chk_boundary, + const bool out_of_range) noexcept { if (out_of_range) { return; }