Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Commit

Permalink
removed extra attributes from transpose kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
OuadiElfarouki committed Dec 21, 2023
1 parent 947f66d commit c6f2d0f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 26 deletions.
16 changes: 0 additions & 16 deletions include/operations/extension/transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,11 +80,6 @@ class Transpose {
static constexpr const index_t inner_tile_size_ = wg_size / Tile_size;
static constexpr const index_t inner_tile_count_ =
Tile_size / inner_tile_size_;
// Minimum number of Tile-multiple rows & columns to cover the matrices
index_t M_pad_;
index_t N_pad_;
// Total size of Tile-multiple covering matrix
index_t size_pad_;
// Batch size when using batched transpose
index_t batch_size_;
// Number of contiguous elements to be used in local memory to avoid bank
Expand Down Expand Up @@ -115,9 +110,6 @@ class Transpose {
stride_a_(stride_a),
stride_at_(stride_at),
inc_at_(inc_at),
M_pad_(tile_count_m_ * Tile_size),
N_pad_(tile_count_n_ * Tile_size),
size_pad_(M_pad_ * N_pad_),
batch_size_(batch_size) {}

index_t get_size() const;
Expand Down Expand Up @@ -209,11 +201,6 @@ class TransposeAdd {
static constexpr const index_t inner_tile_size_ = wg_size / Tile_size;
static constexpr const index_t inner_tile_count_ =
Tile_size / inner_tile_size_;
// Minimum number of Tile-multiple rows & columns to cover the output matrix
index_t M_pad_;
index_t N_pad_;
// Total size of Tile-multiple covering matrix
index_t size_pad_;
// Batch size when using batched transpose
index_t batch_size_;
// Number of contiguous elements to be used in local memory to avoid bank
Expand Down Expand Up @@ -246,9 +233,6 @@ class TransposeAdd {
tile_count_m_((M_ - 1) / Tile_size + 1),
tile_count_n_((N_ - 1) / Tile_size + 1),
tile_count_total_(tile_count_m_ * tile_count_n_),
M_pad_(tile_count_m_ * Tile_size),
N_pad_(tile_count_n_ * Tile_size),
size_pad_(M_pad_ * N_pad_),
batch_size_(batch_size) {}

index_t get_size() const;
Expand Down
22 changes: 12 additions & 10 deletions src/operations/extension/transpose.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ Transpose<in_place, Tile_size, wg_size, cl_size, local_memory, in_t, out_t,
element_t>::get_size() const {
// Smallest TileSize square-multiple containing input/output matrices times
// batch_size
return (size_pad_ * batch_size_);
return (tile_count_total_ * Tile_size * Tile_size * batch_size_);
}

template <bool in_place, int Tile_size, int wg_size, int cl_size,
Expand Down Expand Up @@ -254,7 +254,7 @@ PORTBLAS_INLINE typename in1_t::index_t
TransposeAdd<both_trans, Tile_size, wg_size, cl_size, local_memory, in1_t,
in2_t, out_t, element_t>::get_size() const {
// Smallest TileSize square-multiple containing input/output matrices
return (size_pad_ * batch_size_);
return (tile_count_total_ * Tile_size * Tile_size * batch_size_);
}

template <bool both_trans, int Tile_size, int wg_size, int cl_size,
Expand All @@ -276,10 +276,10 @@ TransposeAdd<both_trans, Tile_size, wg_size, cl_size, local_memory, in1_t,
* @param in_a_idx [output] the input A global-memory index
* @param in_b_idx [output] the input B global-memory index
* @param out_idx [output] the output C global-memory index
* @param i [output] the global row-index (A & B when both_trans -> [0,N_], B &
*C otherwise -> [0,M_])
* @param j [output] the global col-index (A & B when both_trans -> [0,M_], B &
*C otherwise -> [0,N_])
* @param i [output] the global row-index (A & B when both_trans -> [0,N_], B
*& C otherwise -> [0,M_])
* @param j [output] the global col-index (A & B when both_trans -> [0,M_], B
*& C otherwise -> [0,N_])
*/
template <bool both_trans, int Tile_size, int wg_size, int cl_size,
bool local_memory, typename in1_t, typename in2_t, typename out_t,
Expand Down Expand Up @@ -461,7 +461,8 @@ TransposeAdd<both_trans, Tile_size, wg_size, cl_size, local_memory, in1_t,
// Compute & Copy sum/scaled input to local memory (before transpose)
for (index_t l = 0; l < inner_tile_count_; l++) {
if (j_block_start + jl + l * inner_tile_size_ < M_) {
// Compute & Copy sum/scaled input to local memory (before transpose)
// Compute & Copy sum/scaled input to local memory (before
// transpose)
local[in_local_id +
l * (get_non_bank_conflict_line_size() + 1) *
(inner_tile_size_ / get_num_tiles_per_line())] =
Expand Down Expand Up @@ -490,7 +491,8 @@ TransposeAdd<both_trans, Tile_size, wg_size, cl_size, local_memory, in1_t,
if (j_block_start + il < N_) {
for (index_t l = 0; l < inner_tile_count_; l++) {
if (i_block_start + jl + l * inner_tile_size_ < M_) {
// Compute & Copy sum/scaled input to local memory (before transpose)
// Compute & Copy sum/scaled input to local memory (before
// transpose)
local[in_local_id +
l * (get_non_bank_conflict_line_size() + 1) *
(inner_tile_size_ / get_num_tiles_per_line())] =
Expand All @@ -501,8 +503,8 @@ TransposeAdd<both_trans, Tile_size, wg_size, cl_size, local_memory, in1_t,

id.barrier(cl::sycl::access::fence_space::local_space);

// Transposed copy of previous output from local memory and scaled addition
// with 2nd non transposed matrix B
// Transposed copy of previous output from local memory and scaled
// addition with 2nd non transposed matrix B
if (i_block_start + il < M_) {
for (index_t l = 0; l < inner_tile_count_; l++) {
if (j_block_start + jl + l * inner_tile_size_ < N_) {
Expand Down

0 comments on commit c6f2d0f

Please sign in to comment.