Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Commit

Permalink
removed unnecessary wg_size template param from Transpose kernels
Browse files Browse the repository at this point in the history
  • Loading branch information
OuadiElfarouki committed Dec 27, 2023
1 parent b24e4ef commit 19eae59
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 111 deletions.
47 changes: 17 additions & 30 deletions include/operations/extension/transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,15 @@ namespace blas {
* @tparam in_place Whether the transpose is in or out of place
* @tparam Tile_size Tiling size used explicitly in the local memory kernel, and
* used to compute work-group size in the non-local memory case.
* @tparam wg_size work group size
* @tparam cl_size cache line size
* @tparam local_memory Whether to use local memory
* @tparam in_t The input matrix type
* @tparam out_t The output matrix type
* @tparam element_t The scaling factor type
*
*/
template <bool in_place, int Tile_size, int wg_size, int cl_size,
bool local_memory, typename in_t, typename out_t, typename element_t>
template <bool in_place, int Tile_size, int cl_size, bool local_memory,
typename in_t, typename out_t, typename element_t>
class Transpose {
public:
using index_t = typename in_t::index_t;
Expand All @@ -76,10 +75,6 @@ class Transpose {
index_t tile_count_n_;
// Total number of tiles used to cover the matrix
index_t tile_count_total_;
// Number of Inner WG Tiles
static constexpr const index_t inner_tile_size_ = wg_size / Tile_size;
static constexpr const index_t inner_tile_count_ =
Tile_size / inner_tile_size_;
// Batch size when using batched transpose
index_t batch_size_;
// Number of contiguous elements to be used in local memory to avoid bank
Expand Down Expand Up @@ -131,17 +126,15 @@ class Transpose {
/*!
@brief Generator/factory for Transpose trees.
*/
template <bool in_place, int Tile_size, int wg_size, int cl_size,
bool local_memory, typename in_t, typename out_t, typename element_t,
typename index_t>
Transpose<in_place, Tile_size, wg_size, cl_size, local_memory, in_t, out_t,
element_t>
template <bool in_place, int Tile_size, int cl_size, bool local_memory,
typename in_t, typename out_t, typename element_t, typename index_t>
Transpose<in_place, Tile_size, cl_size, local_memory, in_t, out_t, element_t>
make_transpose(in_t &A, index_t inc_a, index_t &stride_a, out_t &At,
index_t inc_a_t, index_t &stride_at, element_t &alpha,
index_t &batch_size) {
return Transpose<in_place, Tile_size, wg_size, cl_size, local_memory, in_t,
out_t, element_t>(A, inc_a, stride_a, At, inc_a_t, stride_at,
alpha, batch_size);
return Transpose<in_place, Tile_size, cl_size, local_memory, in_t, out_t,
element_t>(A, inc_a, stride_a, At, inc_a_t, stride_at, alpha,
batch_size);
}

/*!
Expand All @@ -160,7 +153,6 @@ make_transpose(in_t &A, index_t inc_a, index_t &stride_a, out_t &At,
* by the template parameter both_trans.
* @tparam Tile_size Tiling size used explicitly in the local memory kernel, and
* used to compute work-group size in the non-local memory case.
* @tparam wg_size work group size
* @tparam cl_size cache line size
* @tparam local_memory Whether to use local memory
* @tparam in1_t The input matrix A type
Expand All @@ -169,9 +161,8 @@ make_transpose(in_t &A, index_t inc_a, index_t &stride_a, out_t &At,
* @tparam element_t The scaling factor type
*
*/
template <bool both_trans, int Tile_size, int wg_size, int cl_size,
bool local_memory, typename in1_t, typename in2_t, typename out_t,
typename element_t>
template <bool both_trans, int Tile_size, int cl_size, bool local_memory,
typename in1_t, typename in2_t, typename out_t, typename element_t>
class TransposeAdd {
public:
using index_t = typename in1_t::index_t;
Expand All @@ -197,10 +188,6 @@ class TransposeAdd {
index_t tile_count_n_;
// Total number of tiles used to cover the matrix
index_t tile_count_total_;
// Inner WG Tiles
static constexpr const index_t inner_tile_size_ = wg_size / Tile_size;
static constexpr const index_t inner_tile_count_ =
Tile_size / inner_tile_size_;
// Batch size when using batched transpose
index_t batch_size_;
// Number of contiguous elements to be used in local memory to avoid bank
Expand Down Expand Up @@ -254,16 +241,16 @@ class TransposeAdd {
/*!
* @brief Generator/factory for Transpose-Add trees.
*/
template <bool both_trans, int Tile_size, int wg_size, int cl_size,
bool local_memory, typename in1_t, typename in2_t, typename out_t,
typename element_t, typename index_t>
TransposeAdd<both_trans, Tile_size, wg_size, cl_size, local_memory, in1_t,
in2_t, out_t, element_t>
template <bool both_trans, int Tile_size, int cl_size, bool local_memory,
typename in1_t, typename in2_t, typename out_t, typename element_t,
typename index_t>
TransposeAdd<both_trans, Tile_size, cl_size, local_memory, in1_t, in2_t, out_t,
element_t>
make_transpose_add(in1_t &A, index_t stride_a, in2_t &B, index_t stride_b,
out_t &C, index_t stride_c, element_t &alpha,
element_t &beta, index_t batch_size) {
return TransposeAdd<both_trans, Tile_size, wg_size, cl_size, local_memory,
in1_t, in2_t, out_t, element_t>(
return TransposeAdd<both_trans, Tile_size, cl_size, local_memory, in1_t,
in2_t, out_t, element_t>(
A, stride_a, B, stride_b, C, stride_c, alpha, beta, batch_size);
}

Expand Down
4 changes: 2 additions & 2 deletions src/interface/extension_interface.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ typename sb_handle_t::event_t _transpose_outplace_impl(

// Transpose expression Tree
auto trans_scale_tree =
make_transpose<false, Tile_size, wg_size, cl_size, local_memory>(
make_transpose<false, Tile_size, cl_size, local_memory>(
in_view, _inc_in, _stride_in, out_view, _inc_out, _stride_out, _alpha,
_batch_size);

Expand Down Expand Up @@ -234,7 +234,7 @@ typename sb_handle_t::event_t _transpose_add_impl(

// Transpose Add expression Tree
auto trans_scale_tree =
make_transpose_add<both_trans, Tile_size, wg_size, cl_size, local_memory>(
make_transpose_add<both_trans, Tile_size, cl_size, local_memory>(
A_view, _stride_a, B_view, _stride_b, C_view, _stride_c, _alpha,
_beta, _batch_size);

Expand Down
Loading

0 comments on commit 19eae59

Please sign in to comment.