From 821800fb720a82403a4488d90ea8233cca45b918 Mon Sep 17 00:00:00 2001 From: tomflinda Date: Thu, 9 Jan 2025 16:05:58 +0800 Subject: [PATCH] [SYCLomatic][PTX] Refine migration of PTX asm instruction "lop3.b32" (#2592) Signed-off-by: chenwei.sun --- clang/lib/DPCT/RulesAsm/AsmMigration.cpp | 50 +++++----- clang/runtime/dpct-rt/include/dpct/util.hpp | 100 ++++++++++++++++++++ clang/test/dpct/asm/lop3.cu | 13 +++ 3 files changed, 140 insertions(+), 23 deletions(-) diff --git a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp index ac0104aa1de9..b4c560b2fc6c 100644 --- a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp +++ b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp @@ -943,24 +943,27 @@ class SYCLGen : public SYCLGenBase { return SYCLGenError(); OS() << " = "; - std::string Op[3]; - for (auto Idx : llvm::seq(0, 3)) { + std::string Op[4]; + for (auto Idx : llvm::seq(0, 4)) { if (tryEmitStmt(Op[Idx], I->getInputOperand(Idx))) return SYCLGenError(); } - if (!isa(I->getInputOperand(3))) - return SYCLGenError(); - unsigned Imm = dyn_cast(I->getInputOperand(3)) - ->getValue() - .getZExtValue(); + if (!isa(I->getInputOperand(3))) { + OS() << MapNames::getDpctNamespace() << "ternary_logic_op(" << Op[0] + << ", " << Op[1] << ", " << Op[2] << ", " << Op[3] << ")"; + + } else { + unsigned Imm = dyn_cast(I->getInputOperand(3)) + ->getValue() + .getZExtValue(); #define EMPTY nullptr #define EMPTY4 EMPTY, EMPTY, EMPTY, EMPTY #define EMPTY16 EMPTY4, EMPTY4, EMPTY4, EMPTY4 - constexpr const char *FastMap[256] = { - /*0x00*/ "0", - // clang-format off + constexpr const char *FastMap[256] = { + /*0x00*/ "0", + // clang-format off EMPTY16, EMPTY4, EMPTY4, EMPTY, /*0x1a*/ "({0} & {1} | {2}) ^ {0}", EMPTY, EMPTY, EMPTY, @@ -988,12 +991,12 @@ class SYCLGen : public SYCLGenBase { EMPTY16, EMPTY, EMPTY, EMPTY, /*0xfe*/ "{0} | {1} | {2}", /*0xff*/ "uint32_t(-1)"}; - // clang-format on + // clang-format on #undef EMPTY16 #undef EMPTY4 #undef EMPTY - // clang-format off + // clang-format off constexpr const char *SlowMap[8] = { /* 0x01*/ "(~{0} & ~{1} & ~{2})", /* 0x02*/ "(~{0} & ~{1} & {2})", @@ -1004,20 +1007,21 @@ class SYCLGen : public SYCLGenBase { /* 0x40*/ "({0} & {1} & ~{2})", /* 0x80*/ "({0} & {1} & {2})" }; - // clang-format on + // clang-format on - if (FastMap[Imm]) { - OS() << llvm::formatv(FastMap[Imm], Op[0], Op[1], Op[2]); - } else { - SmallVector Templates; - for (auto Bit : llvm::seq(0, 8)) { - if (Imm & (1U << Bit)) { - Templates.push_back( - llvm::formatv(SlowMap[Bit], Op[0], Op[1], Op[2]).str()); + if (FastMap[Imm]) { + OS() << llvm::formatv(FastMap[Imm], Op[0], Op[1], Op[2]); + } else { + SmallVector Templates; + for (auto Bit : llvm::seq(0, 8)) { + if (Imm & (1U << Bit)) { + Templates.push_back( + llvm::formatv(SlowMap[Bit], Op[0], Op[1], Op[2]).str()); + } } - } - OS() << llvm::join(Templates, " | "); + OS() << llvm::join(Templates, " | "); + } } endstmt(); diff --git a/clang/runtime/dpct-rt/include/dpct/util.hpp b/clang/runtime/dpct-rt/include/dpct/util.hpp index 91f40c479976..26697704ae38 100644 --- a/clang/runtime/dpct-rt/include/dpct/util.hpp +++ b/clang/runtime/dpct-rt/include/dpct/util.hpp @@ -1205,6 +1205,106 @@ template struct nth_argument_type { using type = decltype(helper(std::declval())); }; +/// \brief The function performs bitwise logical operations on three input +/// values of \p a, \p b and \p c based on the specified 8-bit truth table \p +/// lut and return the result +/// +/// \param [in] a Input value +/// \param [in] b Input value +/// \param [in] c Input value +/// \param [in] lut truth table for looking up +/// \returns The result +inline uint32_t ternary_logic_op(uint32_t a, uint32_t b, uint32_t c, + uint8_t lut) { + uint32_t result = 0; + + switch (lut) { + case 0x0: + result = 0; + break; + case 0x1: + result = ~a & ~b & ~c; + break; + case 0x2: + result = ~a & ~b & c; + case 0x4: + result = ~a & b & ~c; + break; + case 0x8: + result = ~a & b & c; + break; + case 0x10: + result = a & ~b & ~c; + break; + case 0x20: + result = a & ~b & c; + break; + case 0x40: + result = a & b & ~c; + break; + case 0x80: + result = a & b & c; + break; + case 0x1a: + result = (a & b | c) ^ a; + break; + case 0x1e: + result = a ^ (b | c); + break; + case 0x2d: + result = ~a ^ (~b & c); + break; + case 0x78: + result = a ^ (b & c); + break; + case 0x96: + result = a ^ b ^ c; + break; + case 0xb4: + result = a ^ (b & ~c); + break; + case 0xb8: + result = a ^ (b & (c ^ a)); + break; + case 0xd2: + result = a ^ (~b & c); + break; + case 0xe8: + result = a & (b | c) | (b & c); + break; + case 0xea: + result = a & b | c; + break; + case 0xfe: + result = a | b | c; + break; + case 0xff: + result = -1; + break; + default: { + if (lut & 0x01) + result |= ~a & ~b & ~c; + if (lut & 0x02) + result |= ~a & ~b & c; + if (lut & 0x04) + result |= ~a & b & ~c; + if (lut & 0x08) + result |= ~a & b & c; + if (lut & 0x10) + result |= a & ~b & ~c; + if (lut & 0x20) + result |= a & ~b & c; + if (lut & 0x40) + result |= a & b & ~c; + if (lut & 0x80) + result |= a & b & c; + break; + } + } + + return result; +} + #ifdef _WIN32 #define DPCT_EXPORT __declspec(dllexport) #else diff --git a/clang/test/dpct/asm/lop3.cu b/clang/test/dpct/asm/lop3.cu index a793632f7853..c6bca1eb50c1 100644 --- a/clang/test/dpct/asm/lop3.cu +++ b/clang/test/dpct/asm/lop3.cu @@ -35,4 +35,17 @@ __device__ int hard(int a) { asm("lop3.b32 %0, %1, %2, 3, 0x1C;" : "=r"(d4) : "r"(a + B), "r"(B)); return d4; } + +// CHECK: template inline T lop3(T a, T b, T c) { +// CHECK-NEXT: T res; +// CHECK-NEXT: res = dpct::ternary_logic_op(a, b, c, lut); +// CHECK-NEXT: return res; +// CHECK-NEXT:} +template __device__ inline T lop3(T a, T b, T c) { + T res; + asm volatile("lop3.b32 %0, %1, %2, %3, %4;\n" + : "=r"(res) + : "r"(a), "r"(b), "r"(c), "n"(lut)); + return res; +} // clang-format on