From 81bf6b774ca2cddc31ea09b8feeacbaaabcb8599 Mon Sep 17 00:00:00 2001 From: pgorlani <92453485+pgorlani@users.noreply.github.com> Date: Fri, 5 Jan 2024 11:17:55 +0000 Subject: [PATCH] Fix unit tests for GEMM batch interleaved (#490) --- test/unittest/blas3/blas3_gemm_batched_test.cpp | 16 ++++++++-------- test/unittest/blas3/blas3_gemm_common.hpp | 2 ++ 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/test/unittest/blas3/blas3_gemm_batched_test.cpp b/test/unittest/blas3/blas3_gemm_batched_test.cpp index 824bf656b..74257b5cf 100644 --- a/test/unittest/blas3/blas3_gemm_batched_test.cpp +++ b/test/unittest/blas3/blas3_gemm_batched_test.cpp @@ -29,7 +29,7 @@ template const auto BetaNonZeroLDMatch = ::testing::Combine( ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0), // offset + ::testing::Values(0, 33), // offset ::testing::Values(5), // batch ::testing::Values(63, 128), // m ::testing::Values(63, 128), // n @@ -48,7 +48,7 @@ GENERATE_GEMM_TEST(BatchGemm, BetaNonZeroLDMatch); template const auto BetaNonZeroLDMultiplied = ::testing::Combine( ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0), // offset + ::testing::Values(0, 33), // offset ::testing::Values(1, 5), // batch ::testing::Values(63, 128, 129), // m ::testing::Values(63, 128, 129), // n @@ -87,7 +87,7 @@ GENERATE_GEMM_TEST(BatchGemm, BetaNonZeroLDMatchAlpha0); template const auto BetaNonZeroLDMultipliedAlpha0 = ::testing::Combine( ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0), // offset + ::testing::Values(0, 33), // offset ::testing::Values(5), // batch ::testing::Values(63), // m ::testing::Values(63), // n @@ -107,7 +107,7 @@ GENERATE_GEMM_TEST(BatchGemm, BetaNonZeroLDMultipliedAlpha0); template const auto DefaultGemmAndGemmBatched = ::testing::Combine(::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0), // offset + ::testing::Values(0, 33), // offset ::testing::Values(1, 5), // batch ::testing::Values(63, 128), // m ::testing::Values(63, 128), // n @@ -128,7 +128,7 @@ GENERATE_GEMM_STRIDED_BATCHED_TEST(BatchStridedGemm, DefaultGemmAndGemmBatched); template const auto AllStridedBatched = ::testing::Combine(::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0), // offset + ::testing::Values(0, 33), // offset ::testing::Values(5), // batch ::testing::Values(128), // m ::testing::Values(128), // n @@ -150,7 +150,7 @@ GENERATE_GEMM_STRIDED_BATCHED_TEST(BatchStridedGemm, AllStridedBatched); template const auto CplxBetaNonZeroLDMatch = ::testing::Combine( ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0), // offset + ::testing::Values(0, 33), // offset ::testing::Values(3), // batch ::testing::Values(63, 128), // m ::testing::Values(63, 128), // n @@ -169,7 +169,7 @@ GENERATE_CPLX_GEMM_TEST(BatchGemm, CplxBetaNonZeroLDMatch); template const auto CplxDefaultGemmAndGemmBatched = ::testing::Combine( ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0), // offset + ::testing::Values(0, 33), // offset ::testing::Values(1, 4), // batch ::testing::Values(63, 128), // m ::testing::Values(63, 128), // n @@ -191,7 +191,7 @@ GENERATE_CPLXGEMM_STRIDED_BATCHED_TEST(BatchStridedGemm, template const auto CplxAllStridedBatched = ::testing::Combine( ::testing::Values("usm", "buf"), // allocation type - ::testing::Values(0), // offset + ::testing::Values(0, 33), // offset ::testing::Values(3), // batch ::testing::Values(128), // m ::testing::Values(128), // n diff --git a/test/unittest/blas3/blas3_gemm_common.hpp b/test/unittest/blas3/blas3_gemm_common.hpp index b9bd04e04..2cd832a99 100644 --- a/test/unittest/blas3/blas3_gemm_common.hpp +++ b/test/unittest/blas3/blas3_gemm_common.hpp @@ -56,6 +56,7 @@ inline std::vector strided_to_interleaved( const std::vector& input, int offset, int ld_rows, int ld_cols, int batchs) { std::vector output(input.size()); + for (int o = 0; o < offset; ++o) output[o] = input[o]; for (int c = 0; c < ld_cols; ++c) { for (int r = 0; r < ld_rows; ++r) { for (int b = 0; b < batchs; ++b) { @@ -73,6 +74,7 @@ inline std::vector interleaved_to_strided( const std::vector& input, int offset, int ld_rows, int ld_cols, int batchs) { std::vector output(input.size()); + for (int o = 0; o < offset; ++o) output[o] = input[o]; for (int b = 0; b < batchs; ++b) { for (int c = 0; c < ld_cols; ++c) { for (int r = 0; r < ld_rows; ++r) {