Skip to content
This repository has been archived by the owner on Jan 13, 2025. It is now read-only.

Commit

Permalink
Fix unit tests for GEMM batch interleaved (#490)
Browse files Browse the repository at this point in the history
  • Loading branch information
pgorlani authored Jan 5, 2024
1 parent 29e2c93 commit 81bf6b7
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
16 changes: 8 additions & 8 deletions test/unittest/blas3/blas3_gemm_batched_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
template <typename scalar_t>
const auto BetaNonZeroLDMatch = ::testing::Combine(
::testing::Values("usm", "buf"), // allocation type
- ::testing::Values(0), // offset
+ ::testing::Values(0, 33), // offset
::testing::Values(5), // batch
::testing::Values(63, 128), // m
::testing::Values(63, 128), // n
Expand All @@ -48,7 +48,7 @@ GENERATE_GEMM_TEST(BatchGemm, BetaNonZeroLDMatch);
template <typename scalar_t>
const auto BetaNonZeroLDMultiplied = ::testing::Combine(
::testing::Values("usm", "buf"), // allocation type
- ::testing::Values(0), // offset
+ ::testing::Values(0, 33), // offset
::testing::Values(1, 5), // batch
::testing::Values(63, 128, 129), // m
::testing::Values(63, 128, 129), // n
Expand Down Expand Up @@ -87,7 +87,7 @@ GENERATE_GEMM_TEST(BatchGemm, BetaNonZeroLDMatchAlpha0);
template <typename scalar_t>
const auto BetaNonZeroLDMultipliedAlpha0 = ::testing::Combine(
::testing::Values("usm", "buf"), // allocation type
- ::testing::Values(0), // offset
+ ::testing::Values(0, 33), // offset
::testing::Values(5), // batch
::testing::Values(63), // m
::testing::Values(63), // n
Expand All @@ -107,7 +107,7 @@ GENERATE_GEMM_TEST(BatchGemm, BetaNonZeroLDMultipliedAlpha0);
template <typename scalar_t>
const auto DefaultGemmAndGemmBatched =
::testing::Combine(::testing::Values("usm", "buf"), // allocation type
- ::testing::Values(0), // offset
+ ::testing::Values(0, 33), // offset
::testing::Values(1, 5), // batch
::testing::Values(63, 128), // m
::testing::Values(63, 128), // n
Expand All @@ -128,7 +128,7 @@ GENERATE_GEMM_STRIDED_BATCHED_TEST(BatchStridedGemm, DefaultGemmAndGemmBatched);
template <typename scalar_t>
const auto AllStridedBatched =
::testing::Combine(::testing::Values("usm", "buf"), // allocation type
- ::testing::Values(0), // offset
+ ::testing::Values(0, 33), // offset
::testing::Values(5), // batch
::testing::Values(128), // m
::testing::Values(128), // n
Expand All @@ -150,7 +150,7 @@ GENERATE_GEMM_STRIDED_BATCHED_TEST(BatchStridedGemm, AllStridedBatched);
template <typename scalar_t>
const auto CplxBetaNonZeroLDMatch = ::testing::Combine(
::testing::Values("usm", "buf"), // allocation type
- ::testing::Values(0), // offset
+ ::testing::Values(0, 33), // offset
::testing::Values(3), // batch
::testing::Values(63, 128), // m
::testing::Values(63, 128), // n
Expand All @@ -169,7 +169,7 @@ GENERATE_CPLX_GEMM_TEST(BatchGemm, CplxBetaNonZeroLDMatch);
template <typename scalar_t>
const auto CplxDefaultGemmAndGemmBatched = ::testing::Combine(
::testing::Values("usm", "buf"), // allocation type
- ::testing::Values(0), // offset
+ ::testing::Values(0, 33), // offset
::testing::Values(1, 4), // batch
::testing::Values(63, 128), // m
::testing::Values(63, 128), // n
Expand All @@ -191,7 +191,7 @@ GENERATE_CPLXGEMM_STRIDED_BATCHED_TEST(BatchStridedGemm,
template <typename scalar_t>
const auto CplxAllStridedBatched = ::testing::Combine(
::testing::Values("usm", "buf"), // allocation type
- ::testing::Values(0), // offset
+ ::testing::Values(0, 33), // offset
::testing::Values(3), // batch
::testing::Values(128), // m
::testing::Values(128), // n
Expand Down
2 changes: 2 additions & 0 deletions test/unittest/blas3/blas3_gemm_common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ inline std::vector<scalar_t> strided_to_interleaved(
const std::vector<scalar_t>& input, int offset, int ld_rows, int ld_cols,
int batchs) {
std::vector<scalar_t> output(input.size());
+ for (int o = 0; o < offset; ++o) output[o] = input[o];
for (int c = 0; c < ld_cols; ++c) {
for (int r = 0; r < ld_rows; ++r) {
for (int b = 0; b < batchs; ++b) {
Expand All @@ -73,6 +74,7 @@ inline std::vector<scalar_t> interleaved_to_strided(
const std::vector<scalar_t>& input, int offset, int ld_rows, int ld_cols,
int batchs) {
std::vector<scalar_t> output(input.size());
+ for (int o = 0; o < offset; ++o) output[o] = input[o];
for (int b = 0; b < batchs; ++b) {
for (int c = 0; c < ld_cols; ++c) {
for (int r = 0; r < ld_rows; ++r) {
Expand Down

0 comments on commit 81bf6b7

Please sign in to comment.