From 199fd8a239068318e94d39843c4c676f44883bd3 Mon Sep 17 00:00:00 2001 From: Keren Zhou Date: Sat, 11 Jan 2025 11:28:16 -0500 Subject: [PATCH 01/15] [FRONTEND][NFC] Remove unused strings (#5578) --- python/triton/__init__.py | 2 -- python/triton/compiler/__init__.py | 5 +---- python/triton/language/__init__.py | 2 -- 3 files changed, 1 insertion(+), 8 deletions(-) diff --git a/python/triton/__init__.py b/python/triton/__init__.py index 1c7058567a..c9ad47de26 100644 --- a/python/triton/__init__.py +++ b/python/triton/__init__.py @@ -33,7 +33,6 @@ "compile", "Config", "heuristics", - "impl", "InterpreterError", "jit", "JITFunction", @@ -41,7 +40,6 @@ "language", "MockTensor", "next_power_of_2", - "ops", "OutOfResources", "reinterpret", "runtime", diff --git a/python/triton/compiler/__init__.py b/python/triton/compiler/__init__.py index a05efd7e08..f055926fa8 100644 --- a/python/triton/compiler/__init__.py +++ b/python/triton/compiler/__init__.py @@ -1,7 +1,4 @@ from .compiler import CompiledKernel, ASTSource, IRSource, compile, make_backend, LazyDict from .errors import CompilationError -__all__ = [ - "compile", "make_backend", "ASTSource", "IRSource", "AttrsDescriptor", "CompiledKernel", "CompilationError", - "LazyDict" -] +__all__ = ["compile", "make_backend", "ASTSource", "IRSource", "CompiledKernel", "CompilationError", "LazyDict"] diff --git a/python/triton/language/__init__.py b/python/triton/language/__init__.py index 1a8ddf9f92..b4cb1df50e 100644 --- a/python/triton/language/__init__.py +++ b/python/triton/language/__init__.py @@ -199,7 +199,6 @@ "int32", "int64", "int8", - "ir", "join", "load", "log", @@ -249,7 +248,6 @@ "swizzle2d", "tensor", "trans", - "triton", "tuple", "uint16", "uint32", From 22ac44735eab5d2d0d5485db2cc6237cb15cf967 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sat, 11 Jan 2025 23:33:33 -0800 Subject: [PATCH 02/15] [FRONTEND][BACKEND] plumb `fast_math` attribute from scaled_dot frontend to LLVM codegen. Ignore NaN when set. (#5582) --- .../Conversion/TritonGPUToLLVM/Utility.h | 3 ++- include/triton/Dialect/Triton/IR/TritonOps.td | 3 ++- .../Dialect/TritonGPU/IR/TritonGPUOps.td | 11 +++++++---- lib/Conversion/TritonGPUToLLVM/Utility.cpp | 10 ++++++---- .../TritonGPU/Transforms/AccelerateMatmul.cpp | 19 ++++++++++++------- python/src/ir.cc | 10 ++++++---- python/triton/language/core.py | 6 ++++-- python/triton/language/semantic.py | 5 +++-- test/Conversion/tritongpu_to_llvm.mlir | 10 +++++++++- test/TritonGPU/accelerate-matmul.mlir | 12 ++++++------ .../AccelerateAMDMatmul.cpp | 10 ++++++---- .../UpcastMXFPToLLVM.cpp | 4 ++-- 12 files changed, 65 insertions(+), 38 deletions(-) diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 58da0311ad..e69964ebca 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -480,7 +480,8 @@ SmallVector convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc, ArrayRef values); // Scale a mxfp4 value by a given scale. -Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale); +Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale, + bool fastMath); } // namespace LLVM diff --git a/include/triton/Dialect/Triton/IR/TritonOps.td b/include/triton/Dialect/Triton/IR/TritonOps.td index c0a6887b64..baaa037a33 100644 --- a/include/triton/Dialect/Triton/IR/TritonOps.td +++ b/include/triton/Dialect/Triton/IR/TritonOps.td @@ -690,7 +690,8 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure, Optional>:$lhs_scale, Optional>:$rhs_scale, TT_ScaleDotElemTypeAttr:$lhs_type, - TT_ScaleDotElemTypeAttr:$rhs_type + TT_ScaleDotElemTypeAttr:$rhs_type, + BoolAttr:$fastMath ); let results = (outs TT_FloatTensor:$d); diff --git a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td index 546de144f0..ea5ed593df 100644 --- a/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td +++ b/include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td @@ -292,10 +292,13 @@ def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure]> { Compute the bf16 encoded in the given mxfp number as per https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf }]; - let arguments = (ins - TT_Tensor:$src, - TT_Tensor:$scale, - TT_ScaleDotElemTypeAttr:$fp_type); + let arguments = ( + ins + TT_Tensor:$src, + TT_Tensor:$scale, + TT_ScaleDotElemTypeAttr:$fp_type, + BoolAttr:$fastMath + ); let results = (outs TT_Tensor:$result); let assemblyFormat = [{ diff --git a/lib/Conversion/TritonGPUToLLVM/Utility.cpp b/lib/Conversion/TritonGPUToLLVM/Utility.cpp index 05dcbe1c2d..9e82a21ba7 100644 --- a/lib/Conversion/TritonGPUToLLVM/Utility.cpp +++ b/lib/Conversion/TritonGPUToLLVM/Utility.cpp @@ -904,13 +904,15 @@ SmallVector convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc, return results; } -Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, - Value scale) { +Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale, + bool fastMath) { Value vBf16 = bitcast(v, bf16_ty); - Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty); - Value scaleIsNan = icmp_eq(scale, i8_val(0xff)); Value scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty); Value scaledBf16 = fmul(vBf16, scaleBf16); + if (fastMath) + return scaledBf16; + Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty); + Value scaleIsNan = icmp_eq(scale, i8_val(0xff)); // Account for NaN in the scale as per the mxfp specification. return select(scaleIsNan, nanBf16, scaledBf16); }; diff --git a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp index 5f3dbf3cbe..09df997782 100644 --- a/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp +++ b/lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp @@ -398,6 +398,7 @@ class DecomposeScaledBlocked auto scale = scaledDotOp.getLhsScale(); auto aType = scaledDotOp.getLhsType(); auto bType = scaledDotOp.getRhsType(); + bool fastMath = scaledDotOp.getFastMath(); auto rank = oldRetType.getShape().size(); if (rank != 2) @@ -510,7 +511,8 @@ class DecomposeScaledBlocked newScaleEncoding = LinearEncodingAttr::get(ctx, std::move(newLL)); } - a = createArg(rewriter, a, 0, aType, newAEncoding, scale, newScaleEncoding); + a = createArg(rewriter, a, 0, aType, newAEncoding, scale, newScaleEncoding, + fastMath); Operation *newDot = nullptr; if (versionMajor == 2) { @@ -518,7 +520,8 @@ class DecomposeScaledBlocked assert(bType != ScaleDotElemType::E2M1 && "NYI: rhs scale for fp4"); auto newBEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaEnc, bKWidth); b = createArg(rewriter, b, 1, bType, newBEncoding, - /*scale=*/std::nullopt, /*scaleEncoding=*/std::nullopt); + /*scale=*/std::nullopt, /*scaleEncoding=*/std::nullopt, + fastMath); newDot = rewriter.create(scaledDotOp.getLoc(), newRetType, a, b, newAcc); } else { @@ -541,7 +544,7 @@ class DecomposeScaledBlocked createArg(mlir::PatternRewriter &rewriter, TypedValue v, int idx, ScaleDotElemType type, std::optional vEncoding, std::optional> opt_scale, - std::optional scaleEncoding) const { + std::optional scaleEncoding, bool fastMath) const { auto ctx = rewriter.getContext(); // Create a new tensor with a given encoding or remove the encoding auto maybeWithEncoding = @@ -576,7 +579,7 @@ class DecomposeScaledBlocked auto retTy = triton::gpu::UpcastMXFPOp::deduceOutputType( ret, type, Builder(v.getContext()).getBF16Type()); ret = rewriter.create(v.getLoc(), retTy, ret, - scale, type); + scale, type, fastMath); } return ret; } @@ -589,6 +592,7 @@ class DecomposeScaledBlocked auto scale = scaledDotOp.getLhsScale(); auto aType = scaledDotOp.getLhsType(); auto bType = scaledDotOp.getRhsType(); + bool fastMath = scaledDotOp.getFastMath(); // create a DotOp to be passed in to getMMAVersionSafe // We don't pass encodings as we just want to get the type and shape @@ -597,7 +601,7 @@ class DecomposeScaledBlocked // end up in the graph RankedTensorType aTType = createArg(rewriter, a, 0, aType, /*vEncoding=*/std::nullopt, scale, - /*scaleEncoding=*/std::nullopt) + /*scaleEncoding=*/std::nullopt, fastMath) .getType(); auto aTypeNoEnc = RankedTensorType::get(aTType.getShape(), aTType.getElementType()); @@ -605,7 +609,8 @@ class DecomposeScaledBlocked RankedTensorType bTType = createArg(rewriter, b, 1, bType, /*vEncoding=*/std::nullopt, - /*scale=*/std::nullopt, /*scaleEncoding=*/std::nullopt) + /*scale=*/std::nullopt, /*scaleEncoding=*/std::nullopt, + fastMath) .getType(); auto bTypeNoEnc = RankedTensorType::get(bTType.getShape(), bTType.getElementType()); @@ -752,7 +757,7 @@ static Operation *transposeDotOp(DotScaledOp dotOp) { Value result = builder.create( dotOp.getLoc(), cTransposed.getType(), rhsTransposed, lhsTransposed, cTransposed, dotOp.getRhsScale(), dotOp.getLhsScale(), dotOp.getRhsType(), - dotOp.getLhsType()); + dotOp.getLhsType(), dotOp.getFastMath()); Operation *transposedResult = builder.create(result.getLoc(), result, transOrder); dotOp.replaceAllUsesWith(transposedResult); diff --git a/python/src/ir.cc b/python/src/ir.cc index b6b0c846fe..6c31946d66 100644 --- a/python/src/ir.cc +++ b/python/src/ir.cc @@ -1513,10 +1513,12 @@ void init_triton_ir(py::module &&m) { std::optional &lhs_scale, ScaleDotElemType lhs_format, mlir::Value &rhs, std::optional &rhs_scale, - ScaleDotElemType rhs_format, mlir::Value &c) -> mlir::Value { - return self.create( - c.getType(), lhs, rhs, c, lhs_scale.value_or(Value()), - rhs_scale.value_or(Value()), lhs_format, rhs_format); + ScaleDotElemType rhs_format, bool fast_math, + mlir::Value &c) -> mlir::Value { + return self.create(c.getType(), lhs, rhs, c, + lhs_scale.value_or(Value()), + rhs_scale.value_or(Value()), + lhs_format, rhs_format, fast_math); }) .def("create_floor", [](TritonOpBuilder &self, Value &val) -> Value { diff --git a/python/triton/language/core.py b/python/triton/language/core.py index a07cf5dc68..b22cd86a15 100644 --- a/python/triton/language/core.py +++ b/python/triton/language/core.py @@ -1733,7 +1733,8 @@ def dot(input, other, acc=None, input_precision=None, allow_tf32=None, max_num_i @builtin -def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, out_dtype=float32, _builder=None): +def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, fast_math=False, acc=None, out_dtype=float32, + _builder=None): """ Returns the matrix product of two blocks in microscaling format. @@ -1763,7 +1764,8 @@ def dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc=None, """ out_dtype = _constexpr_to_value(out_dtype) assert out_dtype == float32, "Only float32 is supported for out_dtype at the moment" - return semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, acc, out_dtype, _builder) + return semantic.dot_scaled(lhs, lhs_scale, lhs_format, rhs, rhs_scale, rhs_format, fast_math, acc, out_dtype, + _builder) # ----------------------- diff --git a/python/triton/language/semantic.py b/python/triton/language/semantic.py index 62e2538acc..7939d7f5af 100644 --- a/python/triton/language/semantic.py +++ b/python/triton/language/semantic.py @@ -1562,7 +1562,8 @@ def _bitcast_to_fp_type(val: tl.tensor, float_format: str, builder: ir.builder): def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.tensor, rhs_scale: Optional[tl.tensor], - rhs_format: str, acc: tl.tensor | None, out_dtype: tl.dtype, builder: ir.builder) -> tl.tensor: + rhs_format: str, fast_math: bool, acc: tl.tensor | None, out_dtype: tl.dtype, + builder: ir.builder) -> tl.tensor: assert lhs.type.is_block() and rhs.type.is_block() #TODO: validate types. lhs_rank = len(lhs.shape) @@ -1601,7 +1602,7 @@ def dot_scaled(lhs: tl.tensor, lhs_scale: tl.tensor, lhs_format: str, rhs: tl.te lhs_scale_handle = None if lhs_scale_is_none else lhs_scale.handle return tl.tensor( builder.create_dot_scaled(lhs.handle, lhs_scale_handle, lhs_format_enum, rhs.handle, rhs_scale_handle, - rhs_format_enum, acc_handle), ret_ty) + rhs_format_enum, fast_math, acc_handle), ret_ty) # ===----------------------------------------------------------------------===// diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index 382e4eb9c6..db8b664e76 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -2110,7 +2110,15 @@ tt.func @upcast_mxfp(%arg0: tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #m // CHECK-COUNT-4: llvm.inline_asm // CHECK-COUNT-2: nvvm.shfl.sync // CHECK-COUNT-32: llvm.fmul - %0 = ttg.upcast_mxfp %arg0, %arg1 fp_type = e2m1 : tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, tensor<32x2xi8, #linear> -> tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + // CHECK: llvm.icmp + // CHECK: llvm.select + %0 = ttg.upcast_mxfp %arg0, %arg1 fp_type = e2m1 {fastMath = false} : tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, tensor<32x2xi8, #linear> -> tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> + // CHECK-COUNT-4: llvm.inline_asm + // CHECK-COUNT-2: nvvm.shfl.sync + // CHECK-COUNT-32: llvm.fmul + // CHECK-NOT: llvm.icmp + // CHECK-NOT: llvm.select + %1 = ttg.upcast_mxfp %arg0, %arg1 fp_type = e2m1 {fastMath = true} : tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>>, tensor<32x2xi8, #linear> -> tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>> tt.return } diff --git a/test/TritonGPU/accelerate-matmul.mlir b/test/TritonGPU/accelerate-matmul.mlir index 17180a3924..e07813af7b 100644 --- a/test/TritonGPU/accelerate-matmul.mlir +++ b/test/TritonGPU/accelerate-matmul.mlir @@ -204,10 +204,10 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- %b_bf16: tensor<64x128xbf16, #blocked> ) -> tensor<128x128xf32, #blocked> { // CHECK: ttg.convert_layout {{.*}} : tensor<128x2xi8, #blocked1> -> tensor<128x2xi8, {{.*}}> - // CHECK: ttg.upcast_mxfp {{.*}}, {{.*}} fp_type = e2m1 : tensor<128x32xi8, #ttg.dot_op<{{.*}}>>, tensor<128x2xi8, {{.*}}> -> tensor<128x64xbf16, #ttg.dot_op<{{.*}}>> - // CHECK: ttng.warp_group_dot + // CHECK: ttg.upcast_mxfp {{.*}}, {{.*}} fp_type = e2m1 {fastMath = false} : tensor<128x32xi8, #ttg.dot_op<{{.*}}>>, tensor<128x2xi8, {{.*}}> -> tensor<128x64xbf16, #ttg.dot_op<{{.*}}>> + // CHECK-NEXT: ttng.warp_group_dot {{.*}} %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> - %result = tt.dot_scaled %a scale %scale, %b_bf16, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked> + %result = tt.dot_scaled %a scale %scale, %b_bf16, %cst lhs = e2m1 rhs = bf16 {fastMath = false} : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked> tt.return %result : tensor<128x128xf32, #blocked> } @@ -220,9 +220,9 @@ module attributes {"ttg.target" = "cuda:90", "ttg.num-ctas" = 1 : i32, "ttg.num- ) -> tensor<128x128xf32, #blocked> { %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> // CHECK: ttg.convert_layout {{.*}} : tensor<128x2xi8, #blocked1> -> tensor<128x2xi8, #[[LINEAR]]> - // CHECK: ttg.upcast_mxfp {{.*}}, {{.*}} fp_type = e2m1 : tensor<128x32xi8, #ttg.dot_op<{{.*}}>>, tensor<128x2xi8, #[[LINEAR]]> -> tensor<128x64xbf16, #ttg.dot_op<{{.*}}>> + // CHECK: ttg.upcast_mxfp {{.*}}, {{.*}} fp_type = e2m1 {fastMath = true} : tensor<128x32xi8, #ttg.dot_op<{{.*}}>>, tensor<128x2xi8, #[[LINEAR]]> -> tensor<128x64xbf16, #ttg.dot_op<{{.*}}>> // CHECK: tt.dot - %result = tt.dot_scaled %a scale %scale, %b_fp8, %cst lhs = e2m1 rhs = e4m3 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xf8E4M3FN, #blocked> -> tensor<128x128xf32, #blocked> + %result = tt.dot_scaled %a scale %scale, %b_fp8, %cst lhs = e2m1 rhs = e4m3 {fastMath = true} : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xf8E4M3FN, #blocked> -> tensor<128x128xf32, #blocked> tt.return %result : tensor<128x128xf32, #blocked> } } @@ -246,7 +246,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ %0 = scf.for %arg4 = %c0_i32 to %c100_i32 step %c1_i32 iter_args(%arg5 = %cst) -> (tensor<128x32xf32, #blocked1>) : i32 { // CHECK-DAG: tt.trans %{{.*}} {order = array} : tensor<128x64xf8E4M3FN, #{{.*}}> -> tensor<64x128xf8E4M3FN, #{{.*}}> // CHECK-DAG: tt.trans %a{{.*}} {order = array} : tensor<32x32xi8, #{{.*}}> -> tensor<32x32xi8, #{{.*}}> - %3 = tt.dot_scaled %arg0, %arg1 scale %arg2, %arg5 lhs = e4m3 rhs = e2m1 : tensor<128x64xf8E4M3FN, #blocked> * tensor<32x32xi8, #blocked1>, tensor<32x2xi8, #blocked2> -> tensor<128x32xf32, #blocked1> + %3 = tt.dot_scaled %arg0, %arg1 scale %arg2, %arg5 lhs = e4m3 rhs = e2m1 {fastMath = false}: tensor<128x64xf8E4M3FN, #blocked> * tensor<32x32xi8, #blocked1>, tensor<32x2xi8, #blocked2> -> tensor<128x32xf32, #blocked1> // CHECK: tt.dot // CHECK-NOT: tt.trans // CHECK: scf.yield diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index 4e50775511..5efa81c912 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -613,7 +613,7 @@ class ScaledBlockedToMFMA final : public OpRewritePattern { ctx, {1, 1}, threadsPerWarp, blockWarpsPerCTA, {1, 0}, ctaLayout); auto upcastMXFP = [&](TensorValue v, TensorValue scale, - ScaleDotElemType elemType) -> Value { + ScaleDotElemType elemType, bool fastMath) -> Value { if (!scale) return v; @@ -629,11 +629,13 @@ class ScaledBlockedToMFMA final : public OpRewritePattern { auto outputType = ttg::UpcastMXFPOp::deduceOutputType(v, elemType, outputElemType); return rewriter.create(dotOp.getLoc(), outputType, v, - convOp, elemType); + convOp, elemType, fastMath); }; - Value scaledA = upcastMXFP(a, aScale, dotOp.getLhsType()); - Value scaledB = upcastMXFP(b, bScale, dotOp.getRhsType()); + Value scaledA = + upcastMXFP(a, aScale, dotOp.getLhsType(), dotOp.getFastMath()); + Value scaledB = + upcastMXFP(b, bScale, dotOp.getRhsType(), dotOp.getFastMath()); auto newDot = rewriter.create(dotOp.getLoc(), newRetType, scaledA, scaledB, newAcc); rewriter.replaceOpWithNewOp(dotOp, oldRetType, diff --git a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp index 47c7fcc063..96648179c4 100644 --- a/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp +++ b/third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -149,8 +149,8 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { for (int k = 0; k < kWidth; ++k) { auto idx = 32 * i + 16 * mxfp + rep * 2 * kWidth + subTile * kWidth + k; - xVals[idx] = - LLVM::mxfpScaleBf16(rewriter, loc, xVals[idx], si[subTile]); + xVals[idx] = LLVM::mxfpScaleBf16(rewriter, loc, xVals[idx], + si[subTile], op.getFastMath()); } } } From a3095b3b3addcaaabc036f03fe94ac8aaa8dbe6d Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Sun, 12 Jan 2025 18:20:09 +0100 Subject: [PATCH 03/15] Remove `examples` folder (#5574) Signed-off-by: Anatoly Myachev --- python/examples/copy_strided.py | 25 ------------------------- python/examples/empty.py | 13 ------------- 2 files changed, 38 deletions(-) delete mode 100644 python/examples/copy_strided.py delete mode 100644 python/examples/empty.py diff --git a/python/examples/copy_strided.py b/python/examples/copy_strided.py deleted file mode 100644 index 8d21c42d9f..0000000000 --- a/python/examples/copy_strided.py +++ /dev/null @@ -1,25 +0,0 @@ -import triton -import triton.language as tl -import triton.compiler as tc - - -# triton kernel -@triton.jit -def kernel(X, stride_xm, # - Z, stride_zn, # - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr): - off_m = tl.arange(0, BLOCK_M) - off_n = tl.arange(0, BLOCK_N) - Xs = X + off_m[:, None] * stride_xm + off_n[None, :] * 1 - Zs = Z + off_m[:, None] * 1 + off_n[None, :] * stride_zn - tl.store(Zs, tl.load(Xs)) - - -src = tc.ASTSource( - fn=kernel, - constants={"BLOCK_M": 64, "BLOCK_N": 64}, - signature="*fp32,i32,*fp32,i32", -) - -ret = triton.compile(src) -print(ret.asm["ttgir"]) diff --git a/python/examples/empty.py b/python/examples/empty.py deleted file mode 100644 index bff6d1e949..0000000000 --- a/python/examples/empty.py +++ /dev/null @@ -1,13 +0,0 @@ -import torch - -import triton -import triton.language as tl - - -@triton.jit -def kernel(X, stride_xm, stride_xn, BLOCK: tl.constexpr): - pass - - -X = torch.randn(1, device="cuda") -pgm = kernel[(1, )](X, 1, 1, BLOCK=1024) From f7e6775a7757fcbb28c5c7403de9841e74a77d99 Mon Sep 17 00:00:00 2001 From: Samuel Ginzburg Date: Sun, 12 Jan 2025 12:28:52 -0500 Subject: [PATCH 04/15] [AMD] Pass down atomics memscope through lowering (#5580) # Overview Atomics in triton have two optional attributes: 1) `sem` -- describing the memory semantics of the operation 2) `scope` -- describing which threads will see the effect of a memory operation (e.g., GPU, CTA) Presently, the `scope` is ignored by the AMD backend and defaults to `agent`-scope in the emitted LLVM (which roughly corresponds to `gpu` memscope in triton). This is correct (in most cases? maybe not all?), as this is a "stricter" scope than CTA (and I'm guessing it is rare that system scope is needed for AMD kernels, so no bugs have shown up). That being said, emitting atomics at CTA scope can be more efficient since there can be fewer cache invalidations/barriers. I think that this is fixable by just passing through the attribute to the generated `llvm.atomicrmw` op. There are some additional optimizations potentially possible (e.g., !amdgpu.no.remote.memory, since Triton doesn't support this today), but it isn't clear to me if those would have any real impact on end-to-end performance and those optimizations would be specific to the `sys`-scope that doesn't appear to be frequently used. # Testing I added a lit test to ensure that the generated LLVM instructions have the correct sem/scope attributes for atomicrmw, but I also ran the following 386 unit tests locally on an MI300x: ```bash pytest test/unit/language/test_core.py -k test_atomic_ ``` I then locally ran some kernels with the scope set to CTA/SYSTEM to make sure that they worked. --- test/Conversion/amd/tritongpu_to_llvm.mlir | 43 ++++++++++++++++++ .../TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp | 44 ++++++++++++++++--- 2 files changed, 81 insertions(+), 6 deletions(-) diff --git a/test/Conversion/amd/tritongpu_to_llvm.mlir b/test/Conversion/amd/tritongpu_to_llvm.mlir index 8f4fbee399..ecc12a94c0 100644 --- a/test/Conversion/amd/tritongpu_to_llvm.mlir +++ b/test/Conversion/amd/tritongpu_to_llvm.mlir @@ -209,3 +209,46 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr tt.return } } + + +// ----- + +#blocked0 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} { + // CHECK-LABEL: atomicrmw_scope_memsemantics + tt.func @atomicrmw_scope_memsemantics(%arg0 : tensor<128x!tt.ptr, #blocked0>, %arg1 : tensor<128xi1, #blocked0>, %arg2 : tensor<128xf32, #blocked0>) { + // relaxed + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} monotonic + %0 = tt.atomic_rmw fadd, relaxed, sys, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) monotonic + %1 = tt.atomic_rmw fadd, relaxed, gpu, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"workgroup"}}) monotonic + %2 = tt.atomic_rmw fadd, relaxed, cta, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + + // acquire + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} acquire + %3 = tt.atomic_rmw fadd, acquire, sys, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) acquire + %4 = tt.atomic_rmw fadd, acquire, gpu, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"workgroup"}}) acquire + %5 = tt.atomic_rmw fadd, acquire, cta, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + + // release + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} release + %6 = tt.atomic_rmw fadd, release, sys, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) release + %7 = tt.atomic_rmw fadd, release, gpu, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"workgroup"}}) release + %8 = tt.atomic_rmw fadd, release, cta, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + + // acq_rel + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} acq_rel + %9 = tt.atomic_rmw fadd, acq_rel, sys, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"agent"}}) acq_rel + %10 = tt.atomic_rmw fadd, acq_rel, gpu, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + // CHECK: llvm.atomicrmw {{.*}}, {{.*}} syncscope({{"workgroup"}}) acq_rel + %11 = tt.atomic_rmw fadd, acq_rel, cta, %arg0, %arg2, %arg1 : (tensor<128x!tt.ptr, #blocked0>, tensor<128xf32, #blocked0>, tensor<128xi1, #blocked0>) -> tensor<128xf32, #blocked0> + + tt.return + } +} diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp index 825697e0e9..53dac1e96f 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp @@ -227,6 +227,29 @@ struct LoadStoreConversionBase { return axisAnalysisPass.getPtrAlignment(ptr); } + std::optional + getAMDGPUMemScopeStr(MemSyncScope scope) const { + // See: https://llvm.org/docs/AMDGPUUsage.html#memory-scopes + auto scopeStr = ""; + switch (scope) { + case MemSyncScope::SYSTEM: + // The default AMDHSA LLVM Sync Scope is "system", so no string is + // provided here + scopeStr = ""; + break; + case MemSyncScope::GPU: + scopeStr = "agent"; + break; + case MemSyncScope::CTA: + scopeStr = "workgroup"; + break; + default: + return std::nullopt; + } + + return scopeStr; + } + protected: const AMD::TargetInfo &targetInfo; ModuleAxisInfoAnalysis &axisAnalysisPass; @@ -601,6 +624,10 @@ struct AtomicCASOpConversion auto memOrdering = op.getSem(); auto atomicMemOrdering = getMemoryOrdering(memOrdering); + auto scope = op.getScope(); + auto scopeStr = getAMDGPUMemScopeStr(scope); + if (!scopeStr) + return failure(); // deal with tensor or scalar auto valueTy = op.getResult().getType(); @@ -643,7 +670,7 @@ struct AtomicCASOpConversion auto failureOrdering = LLVM::AtomicOrdering::monotonic; auto cmpxchg = rewriter.create( loc, casPtr, casCmp, casVal, successOrdering, failureOrdering, - StringRef("agent")); + StringRef(scopeStr.value())); // Extract the new_loaded value from the pair. Value ret = extract_val(valueElemTy, cmpxchg, i); @@ -852,8 +879,13 @@ struct AtomicRMWOpConversion mask = and_(mask, icmp_eq(urem(tid, i32_val(2)), i32_val(0))); auto memOrdering = op.getSem(); + auto scope = op.getScope(); auto atomicMemOrdering = getMemoryOrdering(memOrdering); + auto scopeStr = getAMDGPUMemScopeStr(scope); + if (!scopeStr) + return failure(); + auto vecTy = vec_ty(valueElemTy, vec); auto retType = vec == 1 ? valueElemTy : vecTy; retType = useDppForPackedF16 ? packF16Ty : retType; @@ -907,11 +939,11 @@ struct AtomicRMWOpConversion auto maybeKind = matchAtomicOp(atomicRmwAttr); // TODO: use rocdl.raw.buffer.atomic from ROCDL dialect to use efficient // atomics for MI-* series of AMD GPU. - Value atom = - rewriter - .create(loc, *maybeKind, rmwPtr, operand, - atomicMemOrdering, StringRef("agent")) - .getResult(); + Value atom = rewriter + .create(loc, *maybeKind, rmwPtr, + operand, atomicMemOrdering, + StringRef(scopeStr.value())) + .getResult(); if (!tensorTy) { if (atomicNeedsSharedMemory(op.getResult())) { Value atomPtr = From 9649f7108cf967ae55df7765323afd86b1db1a03 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Sun, 12 Jan 2025 10:35:57 -0800 Subject: [PATCH 05/15] [AMD] Bypass NaN check for fast math scaled dot (#5584) Following https://github.com/triton-lang/triton/pull/5582. --- .../TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp | 27 ++++++++++++------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp index cd02807db7..8f058ffe6f 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/UpcastMXFPToLLVM.cpp @@ -53,12 +53,14 @@ SmallVector convertMxfp4x2ToFp16x2(RewriterBase &rewriter, Location loc, return results; } -Value mxfpScaleFp16(RewriterBase &rewriter, Location loc, Value v, - Value scale) { +Value mxfpScaleFp16(RewriterBase &rewriter, Location loc, Value v, Value scale, + bool fastMath) { Value scaleF32 = bitcast(shl(zext(i32_ty, scale), i32_val(23)), f32_ty); Value scaleF16 = LLVM::AMD::cvtFp32ToFp16(loc, rewriter, scaleF32, RoundingMode::RTNE); Value mulF16 = fmul(v, scaleF16); + if (fastMath) + return mulF16; // Account for NaN in the scale as per the mxfp specification. Value scaleIsNan = icmp_eq(scale, i8_val(0xff)); Value nanF16 = bitcast(i16_val(0x7c01), f16_ty); @@ -72,16 +74,19 @@ Value mxfpScaleFp16(RewriterBase &rewriter, Location loc, Value v, // handles v * scale multiplication using fp32 VALU ops. LLVM backend can do it // for us, just with unnecessary overheads. Value mxfpScaleBf16ViaF32(RewriterBase &rewriter, Location loc, Value v, - Value scale) { + Value scale, bool fastMath) { Value c16 = i32_val(16); Value vF32 = bitcast(shl(zext(i32_ty, bitcast(v, i16_ty)), c16), f32_ty); Value scaleF32 = bitcast(shl(zext(i32_ty, scale), i32_val(23)), f32_ty); Value mulF32 = fmul(vF32, scaleF32); Value mulI16 = trunc(i16_ty, lshr(bitcast(mulF32, i32_ty), c16)); + Value mulBf16 = bitcast(mulI16, bf16_ty); + if (fastMath) + return mulBf16; // Account for NaN in the scale as per the mxfp specification. Value scaleIsNan = icmp_eq(scale, i8_val(0xff)); Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty); - return select(scaleIsNan, nanBf16, bitcast(mulI16, bf16_ty)); + return select(scaleIsNan, nanBf16, mulBf16); }; class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { @@ -166,9 +171,10 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { for (int j = 0; j < 32; ++j) { int index = 32 * i + j; xVals[index] = - useFp16 ? mxfpScaleFp16(rewriter, loc, xVals[index], si[j / 16]) + useFp16 ? mxfpScaleFp16(rewriter, loc, xVals[index], si[j / 16], + op.getFastMath()) : mxfpScaleBf16ViaF32(rewriter, loc, xVals[index], - si[j / 16]); + si[j / 16], op.getFastMath()); } } } else { @@ -190,10 +196,11 @@ class UpcastMXFPOpPattern : public ConvertOpToLLVMPattern { for (int j = 0; j < 32; ++j) { int index = 32 * i + j; - xVals[index] = - useFp16 - ? mxfpScaleFp16(rewriter, loc, xVals[index], si[j / 16]) - : mxfpScaleBf16ViaF32(rewriter, loc, xVals[index], si[j / 8]); + xVals[index] = useFp16 + ? mxfpScaleFp16(rewriter, loc, xVals[index], + si[j / 16], op.getFastMath()) + : mxfpScaleBf16ViaF32(rewriter, loc, xVals[index], + si[j / 8], op.getFastMath()); } } } From 6b41bcfceeb1c7e813bf5e63a3a8f006cf43c9c2 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 13 Jan 2025 01:41:24 +0100 Subject: [PATCH 06/15] Enable ruff-pre-commit for `third_party/proton` (#5586) Signed-off-by: Anatoly Myachev --- .pre-commit-config.yaml | 2 +- third_party/proton/proton/language.py | 1 - third_party/proton/test/test_record.py | 1 - 3 files changed, 1 insertion(+), 3 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a85e54d05d..be8ab74f46 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,7 @@ repos: rev: v0.7.1 hooks: - id: ruff - files: '^python/.*' + files: '(^python|^third_party/proton)/.*' args: ["--fix", "--exit-non-zero-on-fix"] exclude: | (?x)( diff --git a/third_party/proton/proton/language.py b/third_party/proton/proton/language.py index d923f60c6a..b88934b216 100644 --- a/third_party/proton/proton/language.py +++ b/third_party/proton/proton/language.py @@ -1,4 +1,3 @@ -from triton._C.libtriton import ir from triton.language import core as tl from triton.language.core import builtin import warnings diff --git a/third_party/proton/test/test_record.py b/third_party/proton/test/test_record.py index 0c623c3784..57a2337908 100644 --- a/third_party/proton/test/test_record.py +++ b/third_party/proton/test/test_record.py @@ -1,5 +1,4 @@ import torch -import pytest import pathlib import triton From 7cc6799ddb76a18830874259bcaf2da59484c684 Mon Sep 17 00:00:00 2001 From: Philippe Tillet Date: Sun, 12 Jan 2025 18:11:57 -0800 Subject: [PATCH 07/15] [FRONTEND] capability override bugfix (#5590) there is currently a weird bug causing capability overrides to persist when users pass `arch=None`. Rather than making `CUDABackend.sw_capability` stateful, we now retrieve capability lazily from compilation options also fix an amd bug encountered in the wild --- python/triton/compiler/compiler.py | 2 +- third_party/amd/backend/compiler.py | 6 ++-- third_party/nvidia/backend/compiler.py | 42 +++++++++++--------------- 3 files changed, 22 insertions(+), 28 deletions(-) diff --git a/python/triton/compiler/compiler.py b/python/triton/compiler/compiler.py index 406107a6c4..59423b7bba 100644 --- a/python/triton/compiler/compiler.py +++ b/python/triton/compiler/compiler.py @@ -269,7 +269,7 @@ def compile(src, target=None, options=None): ir.load_dialects(context) backend.load_dialects(context) - codegen_fns = backend.get_codegen_implementation() + codegen_fns = backend.get_codegen_implementation(options) module_map = backend.get_module_map() # try: module = src.make_ir(options, codegen_fns, module_map, context) diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index c3ccded47a..6019e01e8c 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -110,7 +110,7 @@ def __init__(self, target: GPUTarget) -> None: self.binary_ext = "hsaco" def parse_options(self, opts) -> Any: - args = {'arch': self.target.arch} + args = {'arch': os.getenv("TRITON_OVERRIDE_ARCH", self.target.arch)} if "supported_fp8_dtypes" not in opts: supported_fp8_dtypes = set(HIPOptions.supported_fp8_dtypes) @@ -120,7 +120,7 @@ def parse_options(self, opts) -> Any: if "enable_fp_fusion" not in opts: args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1" - args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() if k in opts}) + args.update({k: opts[k] for k in HIPOptions.__dataclass_fields__.keys() if k in opts and opts[k] is not None}) return HIPOptions(**args) def pack_metadata(self, metadata): @@ -133,7 +133,7 @@ def pack_metadata(self, metadata): metadata.cluster_dims[2], ) - def get_codegen_implementation(self): + def get_codegen_implementation(self, options): codegen_fns = {"min_dot_size": min_dot_size(self.target)} return codegen_fns diff --git a/third_party/nvidia/backend/compiler.py b/third_party/nvidia/backend/compiler.py index e498d6019e..a1b5fc31ef 100644 --- a/third_party/nvidia/backend/compiler.py +++ b/third_party/nvidia/backend/compiler.py @@ -155,36 +155,28 @@ def _parse_arch(self, arch): def __init__(self, target: GPUTarget) -> None: super().__init__(target) - # Capability can be overrided to limit feature set to a specific version - self.hw_capability = target.arch - self.sw_capability = self.hw_capability - arch = os.getenv("TRITON_OVERRIDE_ARCH") - if arch is not None: - self.sw_capability = self._parse_arch(arch) - # HW Capability is used to determine the binary format - self.hw_capability = target.arch - assert isinstance(self.hw_capability, int) - assert isinstance(self.sw_capability, int) self.binary_ext = "cubin" def parse_options(self, opts) -> Any: - args = {k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts} + args = {'arch': os.getenv("TRITON_OVERRIDE_ARCH", f"sm{self.target.arch}")} + args.update({k: opts[k] for k in CUDAOptions.__dataclass_fields__.keys() if k in opts if opts[k] is not None}) + capability = int(self._parse_arch(args["arch"])) + if "supported_fp8_dtypes" not in args: supported_fp8_dtypes = set(CUDAOptions.supported_fp8_dtypes) - if self.sw_capability >= 89: + if capability >= 89: supported_fp8_dtypes.add("fp8e4nv") args["supported_fp8_dtypes"] = tuple(sorted(supported_fp8_dtypes)) if "deprecated_fp8_dtypes" not in args: - if self.sw_capability >= 90: + if capability >= 90: args["deprecated_fp8_dtypes"] = ("fp8e4b15", ) if "enable_fp_fusion" not in args: args["enable_fp_fusion"] = os.getenv("TRITON_DEFAULT_FP_FUSION", "1") == "1" - if args.get("arch", None) is not None: - self.sw_capability = self._parse_arch(args["arch"]) - args["max_num_imprecise_acc_default"] = 2**30 if self.sw_capability == 90 else 0 + args["max_num_imprecise_acc_default"] = 2**30 if capability == 90 else 0 + return CUDAOptions(**args) def pack_metadata(self, metadata): @@ -197,12 +189,13 @@ def pack_metadata(self, metadata): metadata.cluster_dims[2], ) - def get_codegen_implementation(self): + def get_codegen_implementation(self, options): import triton.language.extra.cuda as cuda + capability = int(self._parse_arch(options.arch)) codegen_fns = { "convert_custom_types": - cuda.convert_custom_float8_sm80 if self.sw_capability >= 80 else cuda.convert_custom_float8_sm70, - "min_dot_size": min_dot_size(self.target) + cuda.convert_custom_float8_sm80 if capability >= 80 else cuda.convert_custom_float8_sm70, "min_dot_size": + min_dot_size(self.target) } return codegen_fns @@ -411,13 +404,14 @@ def make_cubin(src, metadata, opt, capability): return cubin def add_stages(self, stages, options): + capability = self._parse_arch(options.arch) stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options) - stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, self.sw_capability) - stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, self.sw_capability) - stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.hw_capability) - stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.hw_capability) + stages["ttgir"] = lambda src, metadata: self.make_ttgir(src, metadata, options, capability) + stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options, capability) + stages["ptx"] = lambda src, metadata: self.make_ptx(src, metadata, options, self.target.arch) + stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, self.target.arch) @functools.lru_cache() def hash(self): version = get_ptxas_version() - return f'{version}-{self.sw_capability}-{self.hw_capability}' + return f'{version}-{self.target.arch}' From 7db39a91dfaedcf5e333a33d801c8fee2a5204a7 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 13 Jan 2025 11:33:17 +0100 Subject: [PATCH 08/15] Proper use of `subprocess.check_call` in `third_party/proton/test/test_cmd.py` (#5588) Relates to https://github.com/triton-lang/triton/pull/5537 --------- Signed-off-by: Anatoly Myachev --- third_party/proton/test/test_cmd.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/third_party/proton/test/test_cmd.py b/third_party/proton/test/test_cmd.py index 620dcd5691..75d364fecd 100644 --- a/third_party/proton/test/test_cmd.py +++ b/third_party/proton/test/test_cmd.py @@ -7,8 +7,7 @@ def test_help(): # Only check if the viewer can be invoked - ret = subprocess.check_call(["proton", "-h"], stdout=subprocess.DEVNULL) - assert ret == 0 + subprocess.check_call(["proton", "-h"], stdout=subprocess.DEVNULL) def is_hip(): @@ -22,14 +21,13 @@ def test_exec(mode, tmp_path: pathlib.Path): temp_file = tmp_path / "test_exec.hatchet" name = str(temp_file.with_suffix("")) if mode == "script": - ret = subprocess.check_call(["proton", "-n", name, helper_file, "test"], stdout=subprocess.DEVNULL) + subprocess.check_call(["proton", "-n", name, helper_file, "test"], stdout=subprocess.DEVNULL) elif mode == "python": - ret = subprocess.check_call(["python3", "-m", "triton.profiler.proton", "-n", name, helper_file, "test"], - stdout=subprocess.DEVNULL) + subprocess.check_call(["python3", "-m", "triton.profiler.proton", "-n", name, helper_file, "test"], + stdout=subprocess.DEVNULL) elif mode == "pytest": - ret = subprocess.check_call(["proton", "-n", name, "pytest", "-k", "test_main", helper_file], - stdout=subprocess.DEVNULL) - assert ret == 0 + subprocess.check_call(["proton", "-n", name, "pytest", "-k", "test_main", helper_file], + stdout=subprocess.DEVNULL) with temp_file.open() as f: data = json.load(f, ) kernels = data[0]["children"] From 3ed479f2f91a1d94dacb547115d357f5ce3219d8 Mon Sep 17 00:00:00 2001 From: apgoucher Date: Mon, 13 Jan 2025 14:53:36 +0000 Subject: [PATCH 09/15] Revert "Revert "Reverting #5389 (#5528)" (#5555)" (#5592) This reverts commit 70359fa26ca916139db22e230944cd968e8ef399 which was causing some of our internal tests to fail. Co-authored-by: Adam P. Goucher --- include/triton/Dialect/Triton/IR/Dialect.h | 18 +- lib/Analysis/Utility.cpp | 7 +- lib/Dialect/Triton/IR/Ops.cpp | 30 +- lib/Dialect/TritonGPU/IR/Dialect.cpp | 125 ++---- lib/Dialect/TritonGPU/IR/Ops.cpp | 10 - .../Transforms/RemoveLayoutConversions.cpp | 9 +- lib/Dialect/TritonGPU/Transforms/Utility.cpp | 15 +- test/Conversion/reduce_to_llvm.mlir | 70 ---- test/Conversion/tritongpu_to_llvm.mlir | 25 -- test/TritonGPU/combine.mlir | 23 -- unittest/Dialect/TritonGPU/DialectTest.cpp | 372 +++++++++++++++--- 11 files changed, 396 insertions(+), 308 deletions(-) delete mode 100644 test/Conversion/reduce_to_llvm.mlir diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h index 39d006cc65..56a1aa7032 100644 --- a/include/triton/Dialect/Triton/IR/Dialect.h +++ b/include/triton/Dialect/Triton/IR/Dialect.h @@ -55,19 +55,13 @@ class DialectInferLayoutInterface // Tries to compute the encoding for the result of a reshape operation that // makes the reshape a "nop", i.e. the same GPU threads contain the same - // elements as before the reshape using legacy layouts. This is not always - // possible (in which case we fallback to using LinearLayouts) - // In the future we'll always use LinearLayouts + // elements as before the reshape. Note that this is not always possible (in + // which case you'd need to choose a different layout for the input to the + // reshape). virtual LogicalResult - inferReshapeOpEncoding(ArrayRef srcShape, Attribute srcEnc, - ArrayRef dstShape, Attribute &dstEnc, - std::optional loc) const = 0; - - // Check if two layouts are structurally the same, even if their names are - // different - virtual LogicalResult verifyLayoutsAreEqual(ArrayRef shape, - Attribute expected, Attribute got, - Location loc) const = 0; + inferReshapeOpNoReorderEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const = 0; virtual LogicalResult inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc, diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 8dfbcd4adf..0d534761a4 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -219,14 +219,15 @@ bool ReduceOpHelper::isSupportedLayout() { } auto srcLayout = getSrcLayout(); - if (isa( - srcLayout)) { + if (isa(srcLayout)) { return true; } - if (auto mmaLayout = dyn_cast(srcLayout)) { return mmaLayout.supportReduction(); } + if (auto sliceLayout = dyn_cast(srcLayout)) { + return true; + } return false; } diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 876d04e56b..b0ce2287d1 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -8,7 +8,6 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/Triton/IR/Utility.h" -#include "triton/Tools/LinearLayout.h" #include "llvm/Support/ErrorHandling.h" namespace mlir { @@ -702,21 +701,24 @@ LogicalResult ReshapeOp::verify() { "encodings, or (b) neither does."); } - if (!srcEnc || getAllowReorder()) { - return success(); + if (srcEnc && !getAllowReorder()) { + Attribute inferredDstEnc; + if (cast(&srcEnc.getDialect()) + ->inferReshapeOpNoReorderEncoding(srcTy.getShape(), srcEnc, + dstTy.getShape(), inferredDstEnc, + getLoc()) + .failed()) { + return emitError("This reshape is impossible without reordering, but " + "reordering is not allowed. Try choosing a different " + "encoding for the input tensor (or allow reordering)."); + } + if (inferredDstEnc != dstEnc) { + return emitError("Expected result encoding ") + << inferredDstEnc << " but was " << dstEnc; + } } - // Check that we can infer the dst encoding from the src encoding - // and that the inferred dst encoding is the same as the given dst encoding - Attribute inferredDstEnc; - auto result = - cast(&srcEnc.getDialect()) - ->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, dstTy.getShape(), - inferredDstEnc, getLoc()); - assert(succeeded(result)); - return cast(&srcEnc.getDialect()) - ->verifyLayoutsAreEqual(dstTy.getShape(), inferredDstEnc, dstEnc, - getLoc()); + return success(); } //-- FpToFpOp -- diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 9c1b0a4145..d24cb73260 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1441,7 +1441,7 @@ SmallVector basesPerDim(const LinearLayout::BasesT &namedBases, SmallVector ret(rank, 1); auto nonZero = [](auto val) { return val != 0; }; - int nonZeroIdx = 0; + int nonZeroIdx = -1; for (const auto &basis : bases) { auto it = std::find_if(basis.begin(), basis.end(), nonZero); // Bases can have one or zero non-zero elements @@ -1453,6 +1453,7 @@ SmallVector basesPerDim(const LinearLayout::BasesT &namedBases, } else if (!skipBroadcast) { // If we've seen a non-zero basis, we double the size of the previous dim // This is just needed to count the CTAsPerCGA + assert(nonZeroIdx != -1); ret[nonZeroIdx] *= 2; } } @@ -1597,14 +1598,12 @@ LinearEncodingAttr::toLinearLayout(ArrayRef shape) const { SmallVector LinearEncodingAttr::getElemsPerThread(ArrayRef shape, Type) const { - // When broadcasting the layout the shape changes, otherwise the shape is - // the same as the shape of the tensor - // We can either have BroadcastOp with SameOperandsAndResultEncoding, or keep - // the invariant that the shape of the LL is that of the tensor - // We choose the former for BC - auto ll = *toLinearLayout(shape); - return basesPerDim(ll, StringAttr::get(getContext(), "register"), - /*skipBroadcast=*/false); + // We can relax this assert by calling toLinearLayout rather than + // getLinearLayout + SmallVector shapeVec(shape.begin(), shape.end()); + assert(shapeVec == llvm::to_vector(getLinearLayout().getOutDimSizes())); + auto ll = getLinearLayout(); + return basesPerDim(ll, StringAttr::get(getContext(), "register")); } // Start of Selection @@ -2674,8 +2673,8 @@ struct TritonGPUInferLayoutInterface // contains elements [a,b,c,d] before the reshape, it contains those same // elements after the reshape, they're just "renamed". // - // Using legacy layouts, a dst encoding that satisfies this property may not - // exist. Here are some positive and negative examples. + // A dst encoding that satisfies this property does not exist for all inputs. + // Here are some positive and negative examples. // // - NOT OK: 4x4 order=[0,1] -> 16. Reshape merges elements so // dim 1 is the fastest-changing in the dst, but the src has the opposite @@ -2689,19 +2688,17 @@ struct TritonGPUInferLayoutInterface // - OK: 32x4 sizePerThread=[4,4] -> 128. dst with sizePerThread=[16] will // contain the same elements as before. // - // With linear layouts, we can always find a dst encoding that satisfies - // this property. See inferReshapeOpEncoding. - // // Users of this function require that it is symmetrical: if // (srcShape,srcEnc,dstShape) => dstEnc, then (dstShape,dstEnc,srcShape) => // srcEnc. - LogicalResult inferReshapeOpLegacyEncoding(ArrayRef srcShape, - Attribute srcEnc, - ArrayRef dstShape, - Attribute &dstEnc) const { + LogicalResult + inferReshapeOpNoReorderEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const override { auto src = mlir::dyn_cast(srcEnc); if (!src) { - return failure(); + return emitOptionalError( + loc, "Non-reordering reshape only supports BlockedEncoding"); } // Nop reshape; we can always infer an encoding. @@ -2734,7 +2731,9 @@ struct TritonGPUInferLayoutInterface // to handle CTASplitNum. if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) || !all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) { - return failure(); + return emitOptionalError( + loc, "Non-reordering reshape does not currently support multi-CTA " + "layouts other than the default layout."); } // Cowardly refuse to handle encodings where shape[dim] is not divisible by @@ -2744,7 +2743,12 @@ struct TritonGPUInferLayoutInterface for (int dim = 0; dim < srcShape.size(); dim++) { if (srcShape[dim] >= subblock[dim] && srcShape[dim] % subblock[dim] != 0) { - return failure(); + return emitOptionalError(loc, + "Can't do a non-reordering reshape because " + "the size of dimension ", + dim, " (", srcShape[dim], ")", + " is not divisible by ", name, "[", dim, "]", + " = ", subblock[dim]); } } return success(); @@ -2769,7 +2773,11 @@ struct TritonGPUInferLayoutInterface // physical order, with `a` being the most major. for (const auto &[srcDims, dstDims] : decomp) { if (!isConsecutive(to_vector(reverse(gather(srcInvOrder, srcDims))))) { - return failure(); + return emitOptionalError(loc, + "Cannot do a non-reordering reshape given " + "this src encoding order. Dimensions [", + join(srcDims), + "] must be physically consecutive."); } } @@ -2816,7 +2824,11 @@ struct TritonGPUInferLayoutInterface // Check that more-minor dims all have 1 in shapeRemaining. for (int j = i + 1; j < srcDims.size(); j++) { if (shapeRemaining[j] != 1) { - return failure(); + return emitOptionalError( + loc, + "Invalid src encoding for non-reordering reshape. Must use " + "up sizePerThread / threadsPerWarp / warpsPerCTA for " + "more-minor dimensions before more major-dims can use them."); } } @@ -2831,7 +2843,13 @@ struct TritonGPUInferLayoutInterface // only if we're the most-major dimension of the chunk and in all // future chunks, only this most-major dim has a non-1 size. if (shapeRemaining[i] == 0 && i != 0) { - return failure(); + return emitOptionalError( + loc, + "Invalid src encoding for non-reordering reshape. Block " + "size in dimension ", + dim, + " is larger than the shape that dimension, but this is only " + "allowed for the most-major dimension of a reshape chunk"); } } return success(); @@ -2921,65 +2939,6 @@ struct TritonGPUInferLayoutInterface return success(); } - LogicalResult verifyLayoutsAreEqual(ArrayRef shape, - Attribute expected, Attribute got, - Location loc) const override { - if (expected == got) { - return success(); - } - // Check whether the encodings are structurally the same. - auto expectedLL = triton::gpu::toLinearLayout(shape, expected); - auto gotLL = triton::gpu::toLinearLayout(shape, got); - if (expectedLL != gotLL) { - return emitError(loc, "Expected result encoding ") - << expected << " but was " << got; - } - return success(); - } - - LogicalResult - inferReshapeOpEncoding(ArrayRef srcShape, Attribute srcEnc, - ArrayRef dstShape, Attribute &dstEnc, - std::optional loc) const override { - auto result = - inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc); - if (succeeded(result)) { - return result; - } - - // If the legacy encoding failed use LinearLayouts. - // Once LinearLayouts are more widely used, we can remove - // inferReshapeOpLegacyEncoding and simply use LLs. - auto *ctx = getContext(); - auto src = triton::gpu::toLinearLayout(srcShape, srcEnc); - if (!src) { - return emitOptionalError(loc, - "src encoding does not support linear layout"); - } - - if (product(srcShape) != product(dstShape)) { - return emitOptionalError(loc, "numel of dst shape does not match " - "numel of src shape"); - } - - auto newRank = dstShape.size(); - SmallVector> newOutDims; - for (auto [dim, size] : - llvm::zip(standardOutDimNames(ctx, newRank), dstShape)) { - newOutDims.emplace_back(dim, size); - } - auto srcOutDims = llvm::to_vector(src->getOutDimNames()); - // reshapeOp assumes minor-to-major, so we need to transpose the out dims - // before the reshape - std::reverse(srcOutDims.begin(), srcOutDims.end()); - std::reverse(newOutDims.begin(), newOutDims.end()); - auto dst = src->transposeOuts(srcOutDims) - .reshapeOuts(newOutDims) - .transposeOuts(standardOutDimNames(ctx, newRank)); - dstEnc = LinearEncodingAttr::get(ctx, dst); - return success(); - } - LogicalResult inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc, std::optional loc) const override { diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index f2088f3a84..39c5f31d98 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -42,16 +42,6 @@ struct CanonicalizeConvertFromReshape auto convert = op.getSrc().getDefiningOp(); if (!convert) return failure(); - // If the layouts are structurally the same, the convert is trivial - auto srcType = convert.getSrc().getType(); - auto dstType = convert.getType(); - auto srcLL = toLinearLayout(srcType.getShape(), srcType.getEncoding()); - auto dstLL = toLinearLayout(dstType.getShape(), dstType.getEncoding()); - if (srcLL && dstLL && *srcLL == *dstLL) { - rewriter.replaceOpWithNewOp( - op, op.getType(), convert.getSrc(), op.getAllowReorder()); - return mlir::success(); - } if (isExpensiveView(convert.getSrc().getType(), op.getType())) return failure(); if (!op.getAllowReorder() || op.getEfficientLayout()) diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 1f93d894b5..2b9e12c3ac 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -1025,7 +1025,9 @@ void LayoutRematerialization::backwardRematerialization( // we don't handle conversions to DotOperandEncodingAttr // this is a heuristic to accommodate fused attention RankedTensorType targetType = convertOp.getType(); - if (isa(targetType.getEncoding())) + // We stop the rematerialization of linear layouts as we have to be a bit more + // careful with the heuristics for both correctness and perf + if (isa(targetType.getEncoding())) return; Value oldV = convertOp.getSrc(); LDBG("check backward remat with source " << oldV << " encoding " @@ -1067,8 +1069,11 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( ConvertLayoutOp convertOp) { // we don't handle conversions to DotOperandEncodingAttr // this is a heuristics to accommodate fused attention + // We stop the rematerialization of linear layouts as we have to be a bit more + // careful with the heuristics for both correctness and perf RankedTensorType targetType = convertOp.getType(); - if (isa(targetType.getEncoding())) + if (mlir::isa( + targetType.getEncoding())) return; auto isExtOrBroadcastOp = [](Operation *op) { diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 27cb71638f..46dfce695c 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -407,13 +407,14 @@ static Attribute inferReshapeOpDstEncoding(ArrayRef srcShape, return {}; Attribute dstEnc; - auto result = - srcEnc.getDialect() - .getRegisteredInterface() - ->inferReshapeOpEncoding(srcShape, srcEnc, dstShape, dstEnc, - /*loc=*/std::nullopt); - assert(succeeded(result)); - return dstEnc; + if (succeeded( + srcEnc.getDialect() + .getRegisteredInterface() + ->inferReshapeOpNoReorderEncoding( + srcShape, srcEnc, dstShape, dstEnc, /*loc=*/std::nullopt))) { + return dstEnc; + } + return {}; } static Attribute inferDstEncoding(triton::ReshapeOp op, Attribute encoding) { diff --git a/test/Conversion/reduce_to_llvm.mlir b/test/Conversion/reduce_to_llvm.mlir deleted file mode 100644 index 0bbcecbd93..0000000000 --- a/test/Conversion/reduce_to_llvm.mlir +++ /dev/null @@ -1,70 +0,0 @@ -// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm --convert-nv-gpu-to-llvm | mlir-translate -mlir-to-llvmir | opt -S -O1 | FileCheck %s - -#linear = #ttg.linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}> - -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { - -// CHECK-LABEL: @reduce_linear_layout -tt.func private @reduce_linear_layout(%arg0: tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> { - // CHECK-NEXT: [[SRC0:%.*]] = extractvalue {{.*}} %0, 0 - // CHECK-NEXT: [[SRC1:%.*]] = extractvalue {{.*}} %0, 1 - // CHECK-NEXT: [[SRC2:%.*]] = extractvalue {{.*}} %0, 2 - // CHECK-NEXT: [[SRC3:%.*]] = extractvalue {{.*}} %0, 3 - - // The layout looks lke - // [[ T0:0, T32:0, T0:1, T32:1, ... - // [ T4:0, T36:0, T4:1, T36:1, ... - // [ T0:2, T32:2, T0:3, T32:3, ... - // [ T4:2, T36:2, T4:3, T36:3, - // ... - // - // A reduction along axis=0 consists of adding registers (0, 2) and (1, 3) - // before shuffling. - // - // Columns along axis=0 are contained within a warp, so reduction arcoss warps - // is not needed. - - // Reduce within threads - // CHECK-NEXT: [[SUM0:%.*]] = add i32 [[SRC0]], [[SRC2]] - // CHECK-NEXT: [[SUM1:%.*]] = add i32 [[SRC1]], [[SRC3]] - - // Reduce within warp. - // CHECK-NEXT: [[W0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[SUM0]], i32 16, i32 31) - // CHECK-NEXT: [[WSUM0:%.*]] = add i32 [[W0]], [[SUM0]] - // CHECK-NEXT: [[W1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM0]], i32 8, i32 31) - // CHECK-NEXT: [[WSUM1:%.*]] = add i32 [[WSUM0]], [[W1]] - // CHECK-NEXT: [[W2:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM1]], i32 4, i32 31) - // CHECK-NEXT: [[WSUM2:%.*]] = add i32 [[WSUM1]], [[W2]] - // CHECK-NEXT: [[W3:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM2]], i32 2, i32 31) - // CHECK-NEXT: [[WSUM3:%.*]] = add i32 [[WSUM2]], [[W3]] - - // CHECK-NEXT: [[W4:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[SUM1]], i32 16, i32 31) - // CHECK-NEXT: [[WSUM4:%.*]] = add i32 [[W4]], [[SUM1]] - // CHECK-NEXT: [[W5:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM4]], i32 8, i32 31) - // CHECK-NEXT: [[WSUM5:%.*]] = add i32 [[WSUM4]], [[W5]] - // CHECK-NEXT: [[W6:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM5]], i32 4, i32 31) - // CHECK-NEXT: [[WSUM6:%.*]] = add i32 [[WSUM5]], [[W6]] - // CHECK-NEXT: [[W7:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM6]], i32 2, i32 31) - // CHECK-NEXT: [[WSUM7:%.*]] = add i32 [[WSUM6]], [[W7]] - - // CHECK-NEXT: [[DST0:%.*]] = insertvalue { i32, i32 } undef, i32 [[WSUM3]], 0 - // CHECK-NEXT: [[DST1:%.*]] = insertvalue { i32, i32 } [[DST0]], i32 [[WSUM7]], 1 - - %0 = "tt.reduce"(%arg0) ({ - ^bb0(%arg1: i32, %arg2: i32): - %1 = arith.addi %arg1, %arg2 : i32 - tt.reduce.return %1 : i32 - }) {axis = 0 : i32} : (tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> - - // CHECK-NEXT: ret { i32, i32 } [[DST1]] - tt.return %0 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> -} - -tt.func @anchor(%ptr: !llvm.ptr, %arg0: tensor<32x16xi32, #linear>) { - %0 = tt.call @reduce_linear_layout(%arg0) : (tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> - %1 = builtin.unrealized_conversion_cast %0 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> to !llvm.struct<(i32, i32)> - llvm.store volatile %1, %ptr : !llvm.struct<(i32, i32)>, !llvm.ptr - tt.return -} - -} diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index db8b664e76..a16fdd81a5 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -2123,28 +2123,3 @@ tt.func @upcast_mxfp(%arg0: tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #m } } - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [1, 1, 16], threadsPerWarp = [4, 4, 2], warpsPerCTA = [8, 1, 1], order = [2, 1, 0]}> -#linear = #ttg.linear<{register = [[0, 0], [0, 0], [0, 0], [0, 0]], lane = [[0, 0], [0, 1], [0, 2], [1, 0], [2, 0]], warp = [[4, 0], [8, 0], [16, 0]], block = []}> - -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - -// CHECK-LABEL: expand_dims_linear_layout -tt.func private @expand_dims_linear_layout() -> tensor<1x4xi32, #linear> { - %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #linear}>> - %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #linear}>> -> tensor<1x4xi32, #linear> - // CHECK: return %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> - tt.return %1 : tensor<1x4xi32, #linear> -} - -// CHECK-LABEL: reshape_linear_layout_broadcasting -tt.func private @reshape_linear_layout_broadcasting(%arg0: tensor<32x4xbf16, #linear>) -> tensor<32x4x1xbf16, #blocked> { - // CHECK-COUNT-16: extractvalue - // CHECK-COUNT-16: insertvalue - %0 = tt.reshape %arg0 : tensor<32x4xbf16, #linear> -> tensor<32x4x1xbf16, #blocked> - tt.return %0 : tensor<32x4x1xbf16, #blocked> -} - -} diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index ab124f82ab..cd45d1ee05 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -2829,26 +2829,3 @@ tt.func @remat_across_regions(%arg0: i1, %arg1: tensor<8x8xf32, #blocked>) { } } - -// ----- - -#linear = #ttg.linear<{register = [[1, 0], [0, 8], [0, 16]], lane = [[2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[0, 2], [0, 4]], block = []}> -#blocked = #ttg.blocked<{sizePerThread = [2, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [1, 0]}> - -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - -// CHECK-LABEL: reduce_linear_layouts -tt.func @reduce_linear_layouts(%arg0: tensor<32x32xi32, #linear>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>> { - // CHECK-NOT: convert_layout - %0 = ttg.convert_layout %arg0 : tensor<32x32xi32, #linear> -> tensor<32x32xi32, #blocked> - // CHECK-NEXT: tt.reduce - %1 = "tt.reduce" (%0) ({ - ^bb0(%arg1: i32, %arg2: i32): - tt.reduce.return %arg1 : i32 - // CHECK: (tensor<32x32xi32, #linear>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}> - }) {axis = 1 : i32} : (tensor<32x32xi32, #blocked>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %2 = ttg.convert_layout %1 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>> - tt.return %2 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>> -} - -} diff --git a/unittest/Dialect/TritonGPU/DialectTest.cpp b/unittest/Dialect/TritonGPU/DialectTest.cpp index a3cc65605f..3fab64c0e8 100644 --- a/unittest/Dialect/TritonGPU/DialectTest.cpp +++ b/unittest/Dialect/TritonGPU/DialectTest.cpp @@ -77,6 +77,135 @@ int64_t getFlatIdx(ArrayRef idx, ArrayRef shape, return flatIdx; } +// Represents the many indices of one element of a tensor with a +// BlockedEncoding. +// +// The purpose of this class is we can say, if two MultiIdx's have the same +// flatFoo values before and after a reshape, then the same GPU thread contains +// the same element (and the reshape is a nop, at least for that element). +struct MultiIdx { + using Vec = SmallVector; + + // Logical index into the tensor. + Vec idx; + + // If the tensor's encoding has e.g. numPerThread = [2,2], then idxInThread + // tells us which of the four elements per thread this is. Same for idxInWarp + // and idxInCTA. + Vec idxInThread; + Vec idxInWarp; + Vec idxInCTA; + + // If the tensor's encoding defines a block of size [x,y,z], the tensor itself + // may be larger than this, comprising multiple blocks. This tells us which + // block we're in. + Vec idxOuter; + + // flatIdx is flattened according to the tensor's logical order (i.e. ignoring + // the encoding). The others are flattened according to the tensor's physical + // encoding. + int64_t flatIdx; + int64_t flatIdxInThread; + int64_t flatIdxInWarp; + int64_t flatIdxInCTA; + int64_t flatIdxOuter; +}; + +bool sameFlatIdxs(const MultiIdx &a, const MultiIdx &b) { + return a.flatIdx == b.flatIdx && // + a.flatIdxInThread == b.flatIdxInThread && + a.flatIdxInWarp == b.flatIdxInWarp && + a.flatIdxInCTA == b.flatIdxInCTA && // + a.flatIdxOuter == b.flatIdxOuter; +} + +std::string multiIdxsToString(ArrayRef> idxs) { + std::stringstream ss; + for (const auto &idxPtr : idxs) { + const MultiIdx &idx = *idxPtr; + ss // + << " [" << triton::join(idx.idx, ",") << "] (" << idx.flatIdx << ") " + << "elem=[" << triton::join(idx.idxInThread, ",") << "] (" + << idx.flatIdxInThread << ") " + << "thread=[" << triton::join(idx.idxInWarp, ",") << "] (" + << idx.flatIdxInWarp << ") " + << "warp=[" << triton::join(idx.idxInCTA, ",") << "] (" + << idx.flatIdxInCTA << ") " + << "outer=[" << triton::join(idx.idxOuter, ",") << "] (" + << idx.flatIdxOuter << ")\n"; + } + return ss.str(); +} + +std::vector> getMultiIdxs(ArrayRef shape, + BlockedEncodingAttr enc) { + using Vec = MultiIdx::Vec; + + const unsigned rank = shape.size(); + auto sizePerThread = enc.getSizePerThread(); + auto threadsPerWarp = enc.getThreadsPerWarp(); + auto warpsPerCTA = enc.getWarpsPerCTA(); + auto order = enc.getOrder(); + + Vec numBlocks; + for (int i = 0; i < rank; i++) { + numBlocks.push_back(ceil( + shape[i], sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i])); + } + + Vec idxInThread(rank, 0); + Vec idxInWarp(rank, 0); + Vec idxInCTA(rank, 0); + Vec idxOuter(rank, 0); + + int64_t nElems = product(sizePerThread) * product(threadsPerWarp) * + product(warpsPerCTA) * product(numBlocks); + + // We eventually sort this array, and if the elements are plain MultiIdx + // elements rather than pointers, we have to swap them, which ends up being + // expensive. + std::vector> elems; + elems.reserve(nElems); + + for (int64_t i = 0; i < nElems; i++) { + auto e = std::make_unique(); + e->idxInThread = idxInThread; + e->idxInWarp = idxInWarp; + e->idxInCTA = idxInCTA; + e->idxOuter = idxOuter; + + for (int i = 0; i < rank; i++) { + e->idx.push_back( // + idxInThread[i] + // + idxInWarp[i] * sizePerThread[i] + + idxInCTA[i] * sizePerThread[i] * threadsPerWarp[i] + + idxOuter[i] * sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i]); + } + + e->flatIdxInThread = getFlatIdx(e->idxInThread, sizePerThread, order); + e->flatIdxInWarp = getFlatIdx(e->idxInWarp, threadsPerWarp, order); + e->flatIdxInCTA = getFlatIdx(e->idxInCTA, warpsPerCTA, order); + e->flatIdxOuter = getFlatIdx(e->idxOuter, numBlocks, order); + e->flatIdx = getFlatIdx(e->idx, shape, + llvm::to_vector(llvm::reverse(llvm::seq(rank)))); + + elems.push_back(std::move(e)); + + if (advance(idxInThread, sizePerThread, order)) { + if (advance(idxInWarp, threadsPerWarp, order)) { + if (advance(idxInCTA, warpsPerCTA, order)) { + advance(idxOuter, numBlocks, order); + } + } + } + } + llvm::sort(elems, [](const std::unique_ptr &a, + const std::unique_ptr &b) { + return a->flatIdx < b->flatIdx; + }); + return elems; +} + class InferLayoutTest : public ::testing::Test { public: InferLayoutTest() @@ -92,12 +221,25 @@ class InferLayoutTest : public ::testing::Test { /*static*/ MLIRContext InferLayoutTest::ctx; +// The optional outparam couldReshape tells the caller whether the reshape +// worked. You might want this to be a return value instead, but gtest ASSERT +// and FAIL have an implicit `return`, so only work in fns that return void. void testReshape(RankedTensorType srcTy, RankedTensorType dstTy, std::optional expectedDstEnc, + std::optional expectSuccess, DialectInferLayoutInterface *inferLayout, - bool longErrors = true) { + bool longErrors = true, bool *couldReshape = nullptr) { + std::unique_ptr couldReshapeStorage; + if (!couldReshape) { + couldReshapeStorage = std::make_unique(); + couldReshape = couldReshapeStorage.get(); + } + *couldReshape = false; MLIRContext *ctx = srcTy.getContext(); + ASSERT_TRUE(expectSuccess || !dstTy.getEncoding()) + << "dstTy shouldn't have an expected encoding if we're expecting the " + "reshape to be impossible!"; // Capture any errors from calling inferReshapeNoOpReorderEncoding, so we can // print them if we expected the reshape to succeed but it failed. @@ -107,17 +249,29 @@ void testReshape(RankedTensorType srcTy, RankedTensorType dstTy, { ScopedDiagnosticHandler scopedHandler( ctx, [&](Diagnostic &diag) { diags.push_back(" - " + diag.str()); }); - result = inferLayout->inferReshapeOpEncoding( + result = inferLayout->inferReshapeOpNoReorderEncoding( srcTy.getShape(), srcTy.getEncoding(), dstTy.getShape(), inferredEnc, UnknownLoc::get(ctx)); } - // We expect the reshape to succeed as long as the inputs have the same - // number of elements - EXPECT_TRUE(succeeded(result)) - << "Expected reshape to succeed, but it didn't! Error(s):\n" - << join(diags, "\n"); + if (!expectSuccess.has_value() && !succeeded(result)) { + // We didn't know whether or not it was supposed to succeed, and it didn't. + // Test passes! + return; + } + + if (expectSuccess.has_value() && !*expectSuccess) { + EXPECT_FALSE(succeeded(result)) + << "Expected reshape to be impossible, but got dst encoding: " + << stringifyLLVMType(inferredEnc); + *couldReshape = true; + return; + } + if (!succeeded(result)) { + FAIL() << "Expected reshape to succeed, but it didn't! Error(s):\n" + << join(diags, "\n"); + } if (auto expectedEnc = dstTy.getEncoding()) { EXPECT_EQ(inferredEnc, expectedEnc); } @@ -125,14 +279,12 @@ void testReshape(RankedTensorType srcTy, RankedTensorType dstTy, // We know that infer(srcShape, srcEnc, dstShape) => dstEnc. Check that it // works the other way around too: infer(dstShape, dstEnc, srcShape) => // srcEnc. (This is an invariant of the inference function.) - // Even more, we check that the inferred encoding is structurally the same as - // the src encoding, showing that the inference is consistent. { std::vector diags; ScopedDiagnosticHandler scopedHandler( ctx, [&](Diagnostic &diag) { diags.push_back(" - " + diag.str()); }); Attribute inferredSrcEnc; - auto result = inferLayout->inferReshapeOpEncoding( + auto result = inferLayout->inferReshapeOpNoReorderEncoding( dstTy.getShape(), inferredEnc, srcTy.getShape(), inferredSrcEnc, UnknownLoc::get(ctx)); EXPECT_TRUE(succeeded(result)) @@ -140,40 +292,56 @@ void testReshape(RankedTensorType srcTy, RankedTensorType dstTy, << " " << stringifyLLVMType(inferredEnc) << " -> " << triton::join(srcTy.getShape(), "x") << "failed:\n" << join(diags, "\n"); - auto srcLinear = toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); - auto inferredSrcLinear = toLinearLayout(srcTy.getShape(), inferredSrcEnc); - EXPECT_EQ(inferredSrcLinear, srcLinear) - << "Inverse encoding inference (" << triton::join(dstTy.getShape(), "x") - << " " << stringifyLLVMType(inferredEnc) << " -> " - << triton::join(srcTy.getShape(), "x") - << " gave the wrong result. Expected " << srcLinear->toString() - << " but " - << "got " << inferredSrcLinear->toString() << ".\n"; + if (succeeded(result)) { + EXPECT_EQ(inferredSrcEnc, srcTy.getEncoding()) + << "Inverse encoding inference (" + << triton::join(dstTy.getShape(), "x") << " " + << stringifyLLVMType(inferredEnc) << " -> " + << triton::join(srcTy.getShape(), "x") + << " gave the wrong result. Expected " + << stringifyLLVMType(srcTy.getEncoding()) << " but got " + << stringifyLLVMType(inferredSrcEnc) << ".\n"; + } } - // The funtional characterisation of resize is that, if we have a srcLayout - // and a dstLayout, then the flattened layouts are views of the same data - // when considered as C-contiguous. - auto makeFlattenedCContig = [](ArrayRef shape, Attribute layout) { - auto ctx = layout.getContext(); - auto linear = *toLinearLayout(shape, layout); - auto dims = standardOutDimNames(ctx, shape.size()); - std::reverse(dims.begin(), dims.end()); - return linear.transposeOuts(dims).reshapeOuts( - {{dims.back(), linear.getTotalOutDimSize()}}); - }; - EXPECT_EQ(makeFlattenedCContig(srcTy.getShape(), srcTy.getEncoding()), - makeFlattenedCContig(dstTy.getShape(), inferredEnc)); + std::vector> srcMultiIdxs = + getMultiIdxs(SmallVector(srcTy.getShape()), + mlir::cast(srcTy.getEncoding())); + + std::vector> dstMultiIdxs = + getMultiIdxs(SmallVector(dstTy.getShape()), + mlir::cast(inferredEnc)); + + if (srcMultiIdxs.size() != dstMultiIdxs.size() || + !llvm::all_of(llvm::zip_equal(srcMultiIdxs, dstMultiIdxs), + [](const auto &pair) { + const auto &[a, b] = pair; + return sameFlatIdxs(*a, *b); + })) { + SCOPED_TRACE(longErrors ? "dst indices:\n" + multiIdxsToString(dstMultiIdxs) + : ""); + SCOPED_TRACE(longErrors ? "src indices:\n" + multiIdxsToString(srcMultiIdxs) + : ""); + ADD_FAILURE() << "Reified indices do not match for encodings:\n" + << " src: [" << triton::join(srcTy.getShape(), "x") << "] " + << stringifyLLVMType(srcTy.getEncoding()) << "\n" + << " dst: [" << triton::join(dstTy.getShape(), "x") << "] " + << stringifyLLVMType(inferredEnc); + } else { + *couldReshape = true; + } } -class InferReshapeOpEncodingTest +class InferReshapeOpNoReorderEncodingTest : public InferLayoutTest, public ::testing::WithParamInterface< - std::tuple> {}; + std::tuple> {}; -TEST_P(InferReshapeOpEncodingTest, DoIt) { +TEST_P(InferReshapeOpNoReorderEncodingTest, DoIt) { std::string srcTyStr = expandTyStr(std::get<0>(GetParam())); std::string dstTyStr = expandTyStr(std::get<1>(GetParam())); + bool expectSuccess = std::get<2>(GetParam()); auto src = mlir::parseType(srcTyStr, &ctx); if (!src) @@ -189,7 +357,7 @@ TEST_P(InferReshapeOpEncodingTest, DoIt) { } testReshape(cast(src), cast(dst), - expectedDstEnc, inferLayout, /*longErrors=*/true); + expectedDstEnc, expectSuccess, inferLayout, /*longErrors=*/true); } // A testcase of {a, b, c} means: @@ -200,72 +368,158 @@ TEST_P(InferReshapeOpEncodingTest, DoIt) { // encoding that makes the reshape a nop, and // - if b has an encoding, check that the inferred encoding matches b's. INSTANTIATE_TEST_SUITE_P( - Reshapes, InferReshapeOpEncodingTest, - ::testing::ValuesIn(std::vector>({ + Reshapes, InferReshapeOpNoReorderEncodingTest, + ::testing::ValuesIn(std::vector< + std::tuple>({ // Use raw strings in here so clang-format doesn't try to wrap them. {R"(T<128x64xf32, #B<{spt=[1,1], tpw=[1,32], wpc=[1,1], ord=[1,0]}>>)", - R"(T<8192xf32, #B<{spt=[1], tpw=[32], wpc=[1], ord=[0]}>>)"}, + R"(T<8192xf32, #B<{spt=[1], tpw=[32], wpc=[1], ord=[0]}>>)", + true}, {R"(T<128xf32, #B<{spt=[4], tpw=[32], wpc=[1], ord=[0]}>>)", - R"(T<32x4xf32, #B<{spt=[1,4], tpw=[32,1], wpc=[1,1], ord=[1,0]}>>)"}, + R"(T<32x4xf32, #B<{spt=[1,4], tpw=[32,1], wpc=[1,1], ord=[1,0]}>>)", + true}, {R"(T<128xf32, #B<{spt=[4], tpw=[32], wpc=[1], ord=[0]}>>)", - R"(T<16x8xf32, #B<{spt=[1,4], tpw=[16,2], wpc=[1,1], ord=[1,0]}>>)"}, + R"(T<16x8xf32, #B<{spt=[1,4], tpw=[16,2], wpc=[1,1], ord=[1,0]}>>)", + true}, {R"(T<32x32xf32, #B<{spt=[2,2], tpw=[32,1], wpc=[1,1], ord=[1,0]}>>)", - "T<1024xf32>"}, + "T<128xf32>", false}, {R"(T<32x4xf32, #B<{spt=[1,4], tpw=[32,1], wpc=[1,1], ord=[1,0]}>>)", - R"(T<2x16x2x2xf32, #B<{spt=[1,1,2,2], tpw=[2,16,1,1], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)"}, + R"(T<2x16x2x2xf32, #B<{spt=[1,1,2,2], tpw=[2,16,1,1], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)", + true}, {R"(T<4x32xf32, #B<{spt=[4,1], tpw=[1,32], wpc=[1,1], ord=[0,1]}>>)", - R"(T<2x2x2x16xf32, #B<{spt=[2,2,1,1], tpw=[1,1,2,16], wpc=[1,1,1,1], ord=[1,0,3,2]}>>)"}, + R"(T<2x2x2x16xf32, #B<{spt=[2,2,1,1], tpw=[1,1,2,16], wpc=[1,1,1,1], ord=[1,0,3,2]}>>)", + true}, {R"(T<32x32xf32, #B<{spt=[4,4], tpw=[4,8], wpc=[1,1], ord=[1,0]}>>)", - R"(T<2x16x2x16xf32, #B<{spt=[1,4,1,4], tpw=[1,4,2,4], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)"}, + R"(T<2x16x2x16xf32, #B<{spt=[1,4,1,4], tpw=[1,4,2,4], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)", + true}, {R"(T<32x32xf32, #B<{spt=[4,4], tpw=[4,8], wpc=[1,1], ord=[1,0]}>>)", - R"(T<16x2x16x2xf32, #B<{spt=[2,2,2,2], tpw=[4,1,8,1], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)"}, + R"(T<16x2x16x2xf32, #B<{spt=[2,2,2,2], tpw=[4,1,8,1], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)", + true}, {R"(T<32x32xf32, #B<{spt=[4,4], tpw=[4,8], wpc=[1,1], ord=[0,1]}>>)", - R"(T<16x2x16x2xf32>)"}, + R"(T<16x2x16x2xf32>)", true}, // nop reshape, but the block size is 2x larger than the tensor. {R"(T<4x2x2x4xf32, #B<{spt=[2,1,1,2], tpw=[2,1,1,2], wpc=[2,2,1,1], ord=[0,3,1,2]}>>)", - R"(T<4x2x2x4xf32, #B<{spt=[2,1,1,2], tpw=[2,1,1,2], wpc=[2,2,1,1], ord=[0,3,1,2]}>>)"}, + R"(T<4x2x2x4xf32, #B<{spt=[2,1,1,2], tpw=[2,1,1,2], wpc=[2,2,1,1], ord=[0,3,1,2]}>>)", + true}, {R"(T<2x4x2x4xf32, #B<{spt=[1,2,2,1], tpw=[1,2,1,2], wpc=[1,2,2,1], ord=[2,1,0,3]}>>)", - R"(T<4x2x2x4xf32>)"}, + R"(T<4x2x2x4xf32>)", false}, {R"(T<1x2x2x4xf32, #B<{spt=[1,32,4,4], tpw=[4,4,16,16], wpc=[8,8,8,1], ord=[0,1,2,3]}>>)", - R"(T<2x2x4x1xf32>)"}, + R"(T<2x2x4x1xf32>)", false}, {R"(T<2x2x2x2xf32, #B<{spt=[2,2,2,2], tpw=[1,1,1,1], wpc=[1,1,1,1], ord=[1,0,3,2]}>>)", - R"(T<4x4xf32>)"}, + R"(T<4x4xf32>)", true}, {R"(T<16x8xf32, #B<{spt=[1,2], tpw=[2,4], wpc=[2,1], ord=[1,0]}>>)", - R"(T<128xf32>)"}, + R"(T<128xf32>)", true}, {R"(T<16x1x8xf32, #B<{spt=[8,1,1], tpw=[2,1,1], wpc=[1,1,8], ord=[2,1,0]}>>)", - R"(T<128x1xf32>)"}, + R"(T<128x1xf32>)", false}, {R"(T<16x1x8xf32, #B<{spt=[1,1,8], tpw=[2,1,1], wpc=[8,1,1], ord=[2,1,0]}>>)", - R"(T<128x1xf32>)"}, + R"(T<128x1xf32>)", true}, {R"(T<32x32xf32, #B<{spt=[1,2], tpw=[1,8], wpc=[1,1], ord=[1,0]}>>)", - R"(T<1024xf32>)"}, + R"(T<1024xf32>)", true}, {R"(T<4x4xf32, #B<{spt=[1,1], tpw=[2,4], wpc=[2,1], ord=[0,1]}>>)", - R"(T<16xf32>)"}, + R"(T<16xf32>)", false}, {R"(T<32xf32, #B<{spt=[2], tpw=[32], wpc=[2], ord=[0]}>>)", - R"(T<16x2xf32, #B<{spt=[1,2], tpw=[32,1], wpc=[2,1], ord=[1,0]}>>)"}, + R"(T<16x2xf32, #B<{spt=[1,2], tpw=[32,1], wpc=[2,1], ord=[1,0]}>>)", + true}, {R"(T<2x1x2xf32, #B<{spt=[2,1,1], tpw=[2,1,2], wpc=[4,1,8], ord=[2,1,0]}>>)", - R"(T<2x2xf32, #B<{spt=[2,1], tpw=[2,2], wpc=[4,8], ord=[1,0]}>>)"}, + R"(T<2x2xf32, #B<{spt=[2,1], tpw=[2,2], wpc=[4,8], ord=[1,0]}>>)", + true}, }))); +TEST_F(InferLayoutTest, FuzzReshape) { + const int numTests = 1000; // Increase to get more coverage. + + std::minstd_rand rng(/*seed=*/0); + auto randPow2Vec = [&](int rank, int maxPow2) { + SmallVector ret; + for (int i = 0; i < rank; i++) { + int pow2 = std::uniform_int_distribution(0, maxPow2)(rng); + if (pow2 == maxPow2 && maxPow2 > 0) { + maxPow2--; + } + ret.push_back(1 << pow2); + } + return ret; + }; + + int numSuccess = 0; + for (int i = 0; i < numTests; i++) { + SCOPED_TRACE("iteration " + std::to_string(i)); + int rank = std::uniform_int_distribution(1, 4)(rng); + + SmallVector srcShape( + convertType(randPow2Vec(rank, /*maxPow2=*/4))); + SmallVector dstShape = srcShape; + std::shuffle(dstShape.begin(), dstShape.end(), rng); + + // Optionally merge some dimensions in dst. + for (int i = 1; i < dstShape.size(); i++) { + if (std::uniform_real_distribution(0, 1)(rng) > 1.0 / rank) { + dstShape[i - 1] *= dstShape[i]; + dstShape.erase(dstShape.begin() + i); + i--; + } + } + + SmallVector sizePerThread = randPow2Vec(rank, /*maxPow2=*/3); + SmallVector threadsPerWarp = randPow2Vec(rank, /*maxPow2=*/3); + SmallVector warpsPerCTA = randPow2Vec(rank, /*maxPow2=*/3); + + SmallVector order(llvm::to_vector(llvm::seq(rank))); + std::shuffle(order.begin(), order.end(), rng); + + auto ctaLayout = CTALayoutAttr::get( + &ctx, SmallVector(rank, 1), SmallVector(rank, 1), + llvm::to_vector(llvm::reverse(llvm::seq(rank)))); + + auto srcTy = RankedTensorType::get( + srcShape, FloatType::getF32(&ctx), + BlockedEncodingAttr::get(&ctx, sizePerThread, threadsPerWarp, + warpsPerCTA, order, ctaLayout)); + auto dstTy = RankedTensorType::get(dstShape, FloatType::getF32(&ctx)); + + bool couldReshape = false; + testReshape(srcTy, dstTy, /*expectedDstEnc=*/std::nullopt, + /*expectSuccess=*/std::nullopt, inferLayout, + /*longErrors=*/false, &couldReshape); + if (couldReshape) + numSuccess++; + } + + // We don't expect or want 100% success, but if only a tiny fraction of tests + // actually exercise the successful reshape logic, then that gives us bad + // coverage. I'm currently getting 35% success, which seems good enough, + // especially since the successful cases take a lot longer to run because of + // the MultiIdx checks (so we're spending most of our time on successful + // cases, even if they're only 1/3 of the iterations). + // + // Run ctest with --verbose to see this output. For example: + // $ cd python/build/cmake.blah.blah + // $ ninja + // $ $(git rev-parse --show-toplevel)/.venv/bin/ctest --verbose + printf("Fuzz success rate: %d/%d = %.2f%%\n", numSuccess, numTests, + 100.0 * numSuccess / numTests); +} + class AMDLayoutTest : public ::testing::Test { public: AMDLayoutTest() { From 4523d38fe75dd67fc39b6c370ae6b2be6c97e387 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 13 Jan 2025 15:56:47 +0100 Subject: [PATCH 10/15] Don't use designated initializers in `MatmulLoopPipeline.cpp` as it relates to c++20 (#5585) Signed-off-by: Anatoly Myachev --- .../TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp index 722e27d8bf..eeeef753be 100644 --- a/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Pipeliner/MatmulLoopPipeline.cpp @@ -879,7 +879,8 @@ createAsyncOps(scf::ForOp &forOp, llvm::MapVector stageGroups; for (auto &[loadOp, info] : loadToInfo) { - AsyncLoad asyncLoad = {.loadOp = loadOp}; + AsyncLoad asyncLoad; + asyncLoad.loadOp = loadOp; bool isTMALoad = false; int numBuffers = info.distToUse; // For MMAv3, we need an extra buffer as this is assumed in the wgmma From 194a21f2b7a20c18635da16fd0b71d9584719062 Mon Sep 17 00:00:00 2001 From: Anatoly Myachev Date: Mon, 13 Jan 2025 16:07:29 +0100 Subject: [PATCH 11/15] Enable ruff-pre-commit for `third_party/amd` (#5589) Signed-off-by: Anatoly Myachev --- .pre-commit-config.yaml | 2 +- third_party/amd/backend/driver.py | 2 +- third_party/amd/python/test/test_extract_slice.py | 4 ---- 3 files changed, 2 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index be8ab74f46..690f1c282a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,7 @@ repos: rev: v0.7.1 hooks: - id: ruff - files: '(^python|^third_party/proton)/.*' + files: '(^python|^third_party/proton|^third_party/amd)/.*' args: ["--fix", "--exit-non-zero-on-fix"] exclude: | (?x)( diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py index fd02d53270..ca712f9040 100644 --- a/third_party/amd/backend/driver.py +++ b/third_party/amd/backend/driver.py @@ -37,7 +37,7 @@ class DlPhdrInfo(ctypes.Structure): # Load libc and get the dl_iterate_phdr symbol. try: dl_iterate_phdr = ctypes.CDLL('libc.so.6').dl_iterate_phdr - except: + except Exception: return None # argtypes must use c_char_p to accept create_string_buffer. dl_iterate_phdr.argtypes = [callback_t, c_char_p] diff --git a/third_party/amd/python/test/test_extract_slice.py b/third_party/amd/python/test/test_extract_slice.py index a9c7df4754..5d24080861 100644 --- a/third_party/amd/python/test/test_extract_slice.py +++ b/third_party/amd/python/test/test_extract_slice.py @@ -1,11 +1,7 @@ -import tempfile - -import numpy as np import pytest import torch import triton -import triton.language as tl from triton._internal_testing import is_hip From e8ef0bbeb1fa9461b7fc742b08a8588659e3a18f Mon Sep 17 00:00:00 2001 From: Lixun Zhang Date: Mon, 13 Jan 2025 10:31:27 -0600 Subject: [PATCH 12/15] [AMD] Disable swap operands for fp8 matmul (#5577) We found regressions for moe kernel with fp8 inputs. This PR effectively reverts part of #4767 and disables the swap-operand feature for fp8 inputs matmul kernels for now while we investigate the regression. --- .../AccelerateAMDMatmul.cpp | 27 ++++++++++++++++--- 1 file changed, 24 insertions(+), 3 deletions(-) diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index 5efa81c912..678ac0c54b 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -337,6 +337,23 @@ class BlockedToMFMA : public OpRewritePattern { : OpRewritePattern(context, benefit), mfmaVersion(mfmaVersion), nonKDim(nonKDim), kPack(kPack) {} + bool isChainDot(tt::DotOp &dotOp) const { + auto filter = [&dotOp](Operation *op) { + return op->getParentRegion() == dotOp->getParentRegion(); + }; + ForwardSliceOptions fwdOpt; + fwdOpt.filter = filter; + BackwardSliceOptions bwdOpt; + bwdOpt.omitBlockArguments = true; + bwdOpt.filter = filter; + auto slices = getSlice(dotOp, bwdOpt, fwdOpt); + for (Operation *op : slices) { + if (isa(op) && (op != dotOp)) + return true; + } + return false; + } + bool isSecondDot(tt::DotOp &dotOp) const { auto filter = [&dotOp](Operation *op) { return op->getParentRegion() == dotOp->getParentRegion(); @@ -391,12 +408,16 @@ class BlockedToMFMA : public OpRewritePattern { auto warpsPerTile = warpsPerTileMFMA(dotOp, retShape, numWarps, {mDim, nDim}); - // Always use transposed mfma layout. This enables larger vectorization - // for global store instructions + // Use transposed mfma layout to enable larger vectorization for global + // store instructions, except for fp8 matmul kernels due to regression + // TODO (lixun): investigate the regression and enable this feature again + auto aElemTy = mfmaInstr.getElementTypeA(); + bool isFP8 = aElemTy.isFloat8E5M2FNUZ() || aElemTy.isFloat8E4M3FNUZ(); + bool isTransposed = isChainDot(dotOp) || !isFP8; mfmaEnc = ttg::AMDMfmaEncodingAttr::get( oldRetType.getContext(), /*versionMajor*/ mfmaVersion, /*versionMinor*/ 0, warpsPerTile, - /*instrShape*/ mDim, nDim, /*isTransposed*/ true, CTALayout); + /*instrShape*/ mDim, nDim, isTransposed, CTALayout); Type mfmaAccType; if (oldRetType.getElementType().isIntOrIndex()) From 07023209bfc88c06a9f06b655da6d25e6208f9fa Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Mon, 13 Jan 2025 18:04:04 +0100 Subject: [PATCH 13/15] [INTERPRETER] Fix typo in attribute name (#5593) --- python/triton/runtime/interpreter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index ee32b62fdd..bcf4cd8c0d 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -993,7 +993,7 @@ def _set_attr(input, values, name): lang.static_assert = _new_static_assert lang.static_print = print lang.dtype.to_ir = _new_to_ir - lang.multiple_of = partial(_set_attr, name="tt.divisiblity") + lang.multiple_of = partial(_set_attr, name="tt.divisibility") lang.max_contiguous = partial(_set_attr, name="tt.contiguity") lang.max_constancy = partial(_set_attr, name="tt.constancy") From f501e9767ef6e4bd3c18fc6f050adf1a2f893549 Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Mon, 13 Jan 2025 18:26:05 +0000 Subject: [PATCH 14/15] [Intel] Plumb fast_math attribute from scaled_dot frontend to LLVM codegen Signed-off-by: Whitney Tsang --- .../TritonIntelGPU/accelerate-matmul-pvc.mlir | 12 +++++----- .../AccelerateMatmul.cpp | 24 ++++++++++--------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/test/TritonIntelGPU/accelerate-matmul-pvc.mlir b/test/TritonIntelGPU/accelerate-matmul-pvc.mlir index aa40fbc1d6..a0ed54dc79 100644 --- a/test/TritonIntelGPU/accelerate-matmul-pvc.mlir +++ b/test/TritonIntelGPU/accelerate-matmul-pvc.mlir @@ -222,14 +222,14 @@ module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warp // CHECK: [[C:%.*]] = ttg.convert_layout [[CST]] : tensor<128x128xf32, [[BLOCKED2]]> -> tensor<128x128xf32, [[DPAS]]> // CHECK: [[CVT_ARG0:%.*]] = ttg.convert_layout [[ARG0]] : tensor<128x32xi8, [[BLOCKED]]> -> tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> // CHECK: [[CVT_ARG1:%.*]] = ttg.convert_layout [[ARG1]] : tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x2xi8, [[BLOCKED3]]> - // CHECK: [[UPCAST:%.*]] = ttg.upcast_mxfp [[CVT_ARG0]], [[CVT_ARG1]] fp_type = e2m1 : tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>, tensor<128x2xi8, [[BLOCKED3]]> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>> + // CHECK: [[UPCAST:%.*]] = ttg.upcast_mxfp [[CVT_ARG0]], [[CVT_ARG1]] fp_type = e2m1 {fastMath = false} : tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>, tensor<128x2xi8, [[BLOCKED3]]> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>> // CHECK: [[A:%.*]] = ttg.convert_layout [[UPCAST]] : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> // CHECK: [[B:%.*]] = ttg.convert_layout [[ARG2]] : tensor<64x128xbf16, [[BLOCKED2]]> -> tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> // CHECK: [[D:%.*]] = tt.dot [[A]], [[B]], [[C]] : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> * tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> -> tensor<128x128xf32, [[DPAS]]> // CHECK: [[RES:%.*]] = ttg.convert_layout {{.*}} : tensor<128x128xf32, [[DPAS]]> -> tensor<128x128xf32, [[BLOCKED2]]> %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> - %dot_res1 = tt.dot_scaled %a scale %scale, %b, %cst lhs = e2m1 rhs = bf16 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked> + %dot_res1 = tt.dot_scaled %a scale %scale, %b, %cst lhs = e2m1 rhs = bf16 {fastMath = false} : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xbf16, #blocked> -> tensor<128x128xf32, #blocked> tt.return %dot_res1 : tensor<128x128xf32, #blocked> } @@ -239,7 +239,7 @@ module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warp // CHECK: [[C:%.*]] = ttg.convert_layout [[CST]] : tensor<128x128xf32, [[BLOCKED2]]> -> tensor<128x128xf32, [[DPAS]]> // CHECK: [[CVT_ARG0:%.*]] = ttg.convert_layout %arg0 : tensor<128x32xi8, [[BLOCKED]]> -> tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> // CHECK: [[CVT_ARG1:%.*]] = ttg.convert_layout %arg1 : tensor<128x2xi8, [[BLOCKED1]]> -> tensor<128x2xi8, [[BLOCKED3]]> - // CHECK: [[UPCAST:%.*]] = ttg.upcast_mxfp [[CVT_ARG0]], [[CVT_ARG1]] fp_type = e2m1 : tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>, tensor<128x2xi8, [[BLOCKED3]]> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>> + // CHECK: [[UPCAST:%.*]] = ttg.upcast_mxfp [[CVT_ARG0]], [[CVT_ARG1]] fp_type = e2m1 {fastMath = true} : tensor<128x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>, tensor<128x2xi8, [[BLOCKED3]]> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>> // CHECK: [[A:%.*]] = ttg.convert_layout [[UPCAST]] : tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>> -> tensor<128x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> // CHECK: [[CVT_ARG2:%.*]] = ttg.convert_layout [[ARG2]] : tensor<64x128xf8E4M3FN, [[BLOCKED2]]> -> tensor<64x128xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> // CHECK: [[B:%.*]] = tt.fp_to_fp [[CVT_ARG2]] : tensor<64x128xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> -> tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> @@ -247,7 +247,7 @@ module attributes {"ttg.target" = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warp // CHECK: [[RES:%.*]] = ttg.convert_layout {{.*}} : tensor<128x128xf32, [[DPAS]]> -> tensor<128x128xf32, [[BLOCKED2]]> %cst = arith.constant dense<0.000000e+00> : tensor<128x128xf32, #blocked> - %result = tt.dot_scaled %a scale %scale, %b, %cst lhs = e2m1 rhs = e4m3 : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xf8E4M3FN, #blocked> -> tensor<128x128xf32, #blocked> + %result = tt.dot_scaled %a scale %scale, %b, %cst lhs = e2m1 rhs = e4m3 {fastMath = true} : tensor<128x32xi8, #blocked2>, tensor<128x2xi8, #blocked1> * tensor<64x128xf8E4M3FN, #blocked> -> tensor<128x128xf32, #blocked> tt.return %result : tensor<128x128xf32, #blocked> } } @@ -285,14 +285,14 @@ module attributes {ttg.target = "xpu", "ttg.num-ctas" = 1 : i32, "ttg.num-warps" // CHECK: [[C:%.*]] = ttg.convert_layout [[ARG5]] : tensor<32x128xf32, [[BLOCKED4]]> -> tensor<32x128xf32, [[DPAS]]> // CHECK: [[CVT_ARG1:%.*]] = ttg.convert_layout [[TRANS_B]] : tensor<32x32xi8, [[BLOCKED4]]> -> tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> // CHECK: [[CVT_ARG2:%.*]] = ttg.convert_layout [[ARG2]] : tensor<32x2xi8, [[BLOCKED2]]> -> tensor<32x2xi8, [[BLOCKED6]]> - // CHECK: [[UPCAST:%.*]] = ttg.upcast_mxfp [[CVT_ARG1]], [[CVT_ARG2]] fp_type = e2m1 : tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>, tensor<32x2xi8, [[BLOCKED6]]> -> tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>> + // CHECK: [[UPCAST:%.*]] = ttg.upcast_mxfp [[CVT_ARG1]], [[CVT_ARG2]] fp_type = e2m1 {fastMath = false} : tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>>, tensor<32x2xi8, [[BLOCKED6]]> -> tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>> // CHECK: [[A:%.*]] = ttg.convert_layout [[UPCAST]] : tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS1]], kWidth = 4}>> -> tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> // CHECK: [[CVT_ARG0:%.*]] = ttg.convert_layout [[TRANS_A]] : tensor<64x128xf8E4M3FN, [[BLOCKED5]]> -> tensor<64x128xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> // CHECK: [[B:%.*]] = tt.fp_to_fp [[CVT_ARG0]] : tensor<64x128xf8E4M3FN, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> -> tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> // CHECK: [[D:%.*]] = tt.dot [[A]], [[B]], [[C]] : tensor<32x64xbf16, #ttg.dot_op<{opIdx = 0, parent = [[DPAS]], kWidth = 2}>> * tensor<64x128xbf16, #ttg.dot_op<{opIdx = 1, parent = [[DPAS]], kWidth = 2}>> -> tensor<32x128xf32, [[DPAS]]> // CHECK: [[RES:%.*]] = ttg.convert_layout [[D]] : tensor<32x128xf32, [[DPAS]]> -> tensor<32x128xf32, [[BLOCKED4]]> // CHECK: scf.yield [[RES]] : tensor<32x128xf32, [[BLOCKED4]]> - %3 = tt.dot_scaled %a, %b scale %scale, %arg5 lhs = e4m3 rhs = e2m1 : tensor<128x64xf8E4M3FN, #blocked> * tensor<32x32xi8, #blocked1>, tensor<32x2xi8, #blocked2> -> tensor<128x32xf32, #blocked1> + %3 = tt.dot_scaled %a, %b scale %scale, %arg5 lhs = e4m3 rhs = e2m1 {fastMath = false} : tensor<128x64xf8E4M3FN, #blocked> * tensor<32x32xi8, #blocked1>, tensor<32x2xi8, #blocked2> -> tensor<128x32xf32, #blocked1> scf.yield %3 : tensor<128x32xf32, #blocked1> } // CHECK: [[TRUNC:%.*]] = arith.truncf [[DOT_RES]] : tensor<32x128xf32, [[BLOCKED4]]> to tensor<32x128xbf16, [[BLOCKED4]]> diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp index fd2fab7c07..6e27980603 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp @@ -237,9 +237,10 @@ class DecomposeScaledBlocked : public OpRewritePattern { TensorValue newAcc = convertAccumulator(scaledDotOp, dpasEnc, rewriter); RankedTensorType newRetType = newAcc.getType(); - std::tie(a, b) = convertOperands( - {a, aElemType, aScale}, {b, bElemType, bScale}, dpasEnc, newRetType, - scaledDotOp->getParentOfType(), rewriter); + std::tie(a, b) = + convertOperands({a, aElemType, aScale}, {b, bElemType, bScale}, + scaledDotOp.getFastMath(), dpasEnc, newRetType, + scaledDotOp->getParentOfType(), rewriter); auto newDot = rewriter.create(scaledDotOp.getLoc(), newRetType, a, b, newAcc); @@ -256,7 +257,7 @@ class DecomposeScaledBlocked : public OpRewritePattern { }; std::pair - convertOperands(OpDescriptor aDesc, OpDescriptor bDesc, + convertOperands(OpDescriptor aDesc, OpDescriptor bDesc, bool fastMath, ttgi::DpasEncodingAttr dpasEnc, RankedTensorType newRetType, ModuleOp mod, PatternRewriter &rewriter) const { assert((aDesc.scale || bDesc.scale) && "No scale provided"); @@ -265,7 +266,7 @@ class DecomposeScaledBlocked : public OpRewritePattern { if (aDesc.scale) { TensorValue newA = convertScaledOperand( - aDesc, dpasEnc, newRetType, mod, rewriter); + aDesc, fastMath, dpasEnc, newRetType, mod, rewriter); TensorValue newB = convertUnscaledOperand( bDesc, dpasEnc, newRetType, rewriter); @@ -274,7 +275,7 @@ class DecomposeScaledBlocked : public OpRewritePattern { TensorValue newB = convertScaledOperand( - bDesc, dpasEnc, newRetType, mod, rewriter); + bDesc, fastMath, dpasEnc, newRetType, mod, rewriter); TensorValue newA = convertUnscaledOperand( aDesc, dpasEnc, newRetType, rewriter); @@ -282,7 +283,7 @@ class DecomposeScaledBlocked : public OpRewritePattern { } template - TensorValue convertScaledOperand(OpDescriptor opDesc, + TensorValue convertScaledOperand(OpDescriptor opDesc, bool fastMath, ttg::intel::DpasEncodingAttr dpasEnc, RankedTensorType retType, ModuleOp mod, PatternRewriter &rewriter) const { @@ -318,7 +319,8 @@ class DecomposeScaledBlocked : public OpRewritePattern { newOpEncoding.getCTAOrder(), CTALayout); TensorValue scale = createScale(opDesc.scale, newScaleEncoding, rewriter); - auto upcastOp = createUpcastMxfpOp(op, scale, opDesc.elemType, rewriter); + auto upcastOp = + createUpcastMxfpOp(op, scale, opDesc.elemType, fastMath, rewriter); if (opDesc.elemType == tt::ScaleDotElemType::E2M1) { auto resultType = cast(upcastOp.getType()); auto newRetType = RankedTensorType::get( @@ -416,7 +418,7 @@ class DecomposeScaledBlocked : public OpRewritePattern { } TensorValue createUpcastMxfpOp(TensorValue v, TensorValue scale, - tt::ScaleDotElemType elemType, + tt::ScaleDotElemType elemType, bool fastMath, PatternRewriter &rewriter) const { if (!scale) return v; @@ -424,7 +426,7 @@ class DecomposeScaledBlocked : public OpRewritePattern { auto retTy = triton::gpu::UpcastMXFPOp::deduceOutputType( v, elemType, Builder(v.getContext()).getBF16Type()); return rewriter.create(v.getLoc(), retTy, v, scale, - elemType); + elemType, fastMath); } }; @@ -605,7 +607,7 @@ static tt::TransOp transposeDotOp(tt::DotScaledOp dotOp) { auto result = builder.create( dotOp.getLoc(), cTransposed.getType(), rhsTransposed, lhsTransposed, cTransposed, dotOp.getRhsScale(), dotOp.getLhsScale(), dotOp.getRhsType(), - dotOp.getLhsType()); + dotOp.getLhsType(), dotOp.getFastMath()); auto transOp = builder.create(result.getLoc(), result, transOrder); dotOp.replaceAllUsesWith(transOp.getOperation()); From 0dfb74409a01ac631a9138fefad56952346e2ece Mon Sep 17 00:00:00 2001 From: Whitney Tsang Date: Mon, 13 Jan 2025 18:29:23 +0000 Subject: [PATCH 15/15] Fix `TypeError: XPUBackend.get_codegen_implementation() takes 1 positional argument but 2 were given` Signed-off-by: Whitney Tsang --- third_party/intel/backend/compiler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/intel/backend/compiler.py b/third_party/intel/backend/compiler.py index 361f57dc00..761c473c7a 100644 --- a/third_party/intel/backend/compiler.py +++ b/third_party/intel/backend/compiler.py @@ -189,7 +189,7 @@ def parse_options(self, opts) -> Any: def pack_metadata(self, metadata): return metadata - def get_codegen_implementation(self): + def get_codegen_implementation(self, options): from triton.language.extra.intel import convert_custom_float8 codegen_fns = {} codegen_fns["convert_custom_types"] = convert_custom_float8