diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a6a18b254a..c171b060d2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -22,7 +22,7 @@ repos: rev: v0.7.1 hooks: - id: ruff - files: '^(python|^third_party/proton|benchmarks|third_party/intel|scripts)/.*' + files: '(^python|^third_party/proton|^third_party/amd|^benchmarks|^third_party/intel|^scripts)/.*' args: ["--fix", "--exit-non-zero-on-fix"] exclude: | (?x)( diff --git a/include/triton/Dialect/Triton/IR/Dialect.h b/include/triton/Dialect/Triton/IR/Dialect.h index 39d006cc65..56a1aa7032 100644 --- a/include/triton/Dialect/Triton/IR/Dialect.h +++ b/include/triton/Dialect/Triton/IR/Dialect.h @@ -55,19 +55,13 @@ class DialectInferLayoutInterface // Tries to compute the encoding for the result of a reshape operation that // makes the reshape a "nop", i.e. the same GPU threads contain the same - // elements as before the reshape using legacy layouts. This is not always - // possible (in which case we fallback to using LinearLayouts) - // In the future we'll always use LinearLayouts + // elements as before the reshape. Note that this is not always possible (in + // which case you'd need to choose a different layout for the input to the + // reshape). virtual LogicalResult - inferReshapeOpEncoding(ArrayRef srcShape, Attribute srcEnc, - ArrayRef dstShape, Attribute &dstEnc, - std::optional loc) const = 0; - - // Check if two layouts are structurally the same, even if their names are - // different - virtual LogicalResult verifyLayoutsAreEqual(ArrayRef shape, - Attribute expected, Attribute got, - Location loc) const = 0; + inferReshapeOpNoReorderEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const = 0; virtual LogicalResult inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc, diff --git a/lib/Analysis/Utility.cpp b/lib/Analysis/Utility.cpp index 8f76285249..2b10095fe7 100644 --- a/lib/Analysis/Utility.cpp +++ b/lib/Analysis/Utility.cpp @@ -225,14 +225,15 @@ bool ReduceOpHelper::isSupportedLayout() { } auto srcLayout = getSrcLayout(); - if (isa( - srcLayout)) { + if (isa(srcLayout)) { return true; } - if (auto mmaLayout = dyn_cast(srcLayout)) { return mmaLayout.supportReduction(); } + if (auto sliceLayout = dyn_cast(srcLayout)) { + return true; + } return false; } diff --git a/lib/Dialect/Triton/IR/Ops.cpp b/lib/Dialect/Triton/IR/Ops.cpp index 876d04e56b..b0ce2287d1 100644 --- a/lib/Dialect/Triton/IR/Ops.cpp +++ b/lib/Dialect/Triton/IR/Ops.cpp @@ -8,7 +8,6 @@ #include "triton/Dialect/Triton/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Types.h" #include "triton/Dialect/Triton/IR/Utility.h" -#include "triton/Tools/LinearLayout.h" #include "llvm/Support/ErrorHandling.h" namespace mlir { @@ -702,21 +701,24 @@ LogicalResult ReshapeOp::verify() { "encodings, or (b) neither does."); } - if (!srcEnc || getAllowReorder()) { - return success(); + if (srcEnc && !getAllowReorder()) { + Attribute inferredDstEnc; + if (cast(&srcEnc.getDialect()) + ->inferReshapeOpNoReorderEncoding(srcTy.getShape(), srcEnc, + dstTy.getShape(), inferredDstEnc, + getLoc()) + .failed()) { + return emitError("This reshape is impossible without reordering, but " + "reordering is not allowed. Try choosing a different " + "encoding for the input tensor (or allow reordering)."); + } + if (inferredDstEnc != dstEnc) { + return emitError("Expected result encoding ") + << inferredDstEnc << " but was " << dstEnc; + } } - // Check that we can infer the dst encoding from the src encoding - // and that the inferred dst encoding is the same as the given dst encoding - Attribute inferredDstEnc; - auto result = - cast(&srcEnc.getDialect()) - ->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, dstTy.getShape(), - inferredDstEnc, getLoc()); - assert(succeeded(result)); - return cast(&srcEnc.getDialect()) - ->verifyLayoutsAreEqual(dstTy.getShape(), inferredDstEnc, dstEnc, - getLoc()); + return success(); } //-- FpToFpOp -- diff --git a/lib/Dialect/TritonGPU/IR/Dialect.cpp b/lib/Dialect/TritonGPU/IR/Dialect.cpp index 143a36ba05..ad54ff0c93 100644 --- a/lib/Dialect/TritonGPU/IR/Dialect.cpp +++ b/lib/Dialect/TritonGPU/IR/Dialect.cpp @@ -1470,7 +1470,7 @@ SmallVector basesPerDim(const LinearLayout::BasesT &namedBases, SmallVector ret(rank, 1); auto nonZero = [](auto val) { return val != 0; }; - int nonZeroIdx = 0; + int nonZeroIdx = -1; for (const auto &basis : bases) { auto it = std::find_if(basis.begin(), basis.end(), nonZero); // Bases can have one or zero non-zero elements @@ -1482,6 +1482,7 @@ SmallVector basesPerDim(const LinearLayout::BasesT &namedBases, } else if (!skipBroadcast) { // If we've seen a non-zero basis, we double the size of the previous dim // This is just needed to count the CTAsPerCGA + assert(nonZeroIdx != -1); ret[nonZeroIdx] *= 2; } } @@ -1626,14 +1627,12 @@ LinearEncodingAttr::toLinearLayout(ArrayRef shape) const { SmallVector LinearEncodingAttr::getElemsPerThread(ArrayRef shape, Type) const { - // When broadcasting the layout the shape changes, otherwise the shape is - // the same as the shape of the tensor - // We can either have BroadcastOp with SameOperandsAndResultEncoding, or keep - // the invariant that the shape of the LL is that of the tensor - // We choose the former for BC - auto ll = *toLinearLayout(shape); - return basesPerDim(ll, StringAttr::get(getContext(), "register"), - /*skipBroadcast=*/false); + // We can relax this assert by calling toLinearLayout rather than + // getLinearLayout + SmallVector shapeVec(shape.begin(), shape.end()); + assert(shapeVec == llvm::to_vector(getLinearLayout().getOutDimSizes())); + auto ll = getLinearLayout(); + return basesPerDim(ll, StringAttr::get(getContext(), "register")); } // Start of Selection @@ -2706,8 +2705,8 @@ struct TritonGPUInferLayoutInterface // contains elements [a,b,c,d] before the reshape, it contains those same // elements after the reshape, they're just "renamed". // - // Using legacy layouts, a dst encoding that satisfies this property may not - // exist. Here are some positive and negative examples. + // A dst encoding that satisfies this property does not exist for all inputs. + // Here are some positive and negative examples. // // - NOT OK: 4x4 order=[0,1] -> 16. Reshape merges elements so // dim 1 is the fastest-changing in the dst, but the src has the opposite @@ -2721,19 +2720,17 @@ struct TritonGPUInferLayoutInterface // - OK: 32x4 sizePerThread=[4,4] -> 128. dst with sizePerThread=[16] will // contain the same elements as before. // - // With linear layouts, we can always find a dst encoding that satisfies - // this property. See inferReshapeOpEncoding. - // // Users of this function require that it is symmetrical: if // (srcShape,srcEnc,dstShape) => dstEnc, then (dstShape,dstEnc,srcShape) => // srcEnc. - LogicalResult inferReshapeOpLegacyEncoding(ArrayRef srcShape, - Attribute srcEnc, - ArrayRef dstShape, - Attribute &dstEnc) const { + LogicalResult + inferReshapeOpNoReorderEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const override { auto src = mlir::dyn_cast(srcEnc); if (!src) { - return failure(); + return emitOptionalError( + loc, "Non-reordering reshape only supports BlockedEncoding"); } // Nop reshape; we can always infer an encoding. @@ -2766,7 +2763,9 @@ struct TritonGPUInferLayoutInterface // to handle CTASplitNum. if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) || !all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) { - return failure(); + return emitOptionalError( + loc, "Non-reordering reshape does not currently support multi-CTA " + "layouts other than the default layout."); } // Cowardly refuse to handle encodings where shape[dim] is not divisible by @@ -2776,7 +2775,12 @@ struct TritonGPUInferLayoutInterface for (int dim = 0; dim < srcShape.size(); dim++) { if (srcShape[dim] >= subblock[dim] && srcShape[dim] % subblock[dim] != 0) { - return failure(); + return emitOptionalError(loc, + "Can't do a non-reordering reshape because " + "the size of dimension ", + dim, " (", srcShape[dim], ")", + " is not divisible by ", name, "[", dim, "]", + " = ", subblock[dim]); } } return success(); @@ -2801,7 +2805,11 @@ struct TritonGPUInferLayoutInterface // physical order, with `a` being the most major. for (const auto &[srcDims, dstDims] : decomp) { if (!isConsecutive(to_vector(reverse(gather(srcInvOrder, srcDims))))) { - return failure(); + return emitOptionalError(loc, + "Cannot do a non-reordering reshape given " + "this src encoding order. Dimensions [", + join(srcDims), + "] must be physically consecutive."); } } @@ -2848,7 +2856,11 @@ struct TritonGPUInferLayoutInterface // Check that more-minor dims all have 1 in shapeRemaining. for (int j = i + 1; j < srcDims.size(); j++) { if (shapeRemaining[j] != 1) { - return failure(); + return emitOptionalError( + loc, + "Invalid src encoding for non-reordering reshape. Must use " + "up sizePerThread / threadsPerWarp / warpsPerCTA for " + "more-minor dimensions before more major-dims can use them."); } } @@ -2863,7 +2875,13 @@ struct TritonGPUInferLayoutInterface // only if we're the most-major dimension of the chunk and in all // future chunks, only this most-major dim has a non-1 size. if (shapeRemaining[i] == 0 && i != 0) { - return failure(); + return emitOptionalError( + loc, + "Invalid src encoding for non-reordering reshape. Block " + "size in dimension ", + dim, + " is larger than the shape that dimension, but this is only " + "allowed for the most-major dimension of a reshape chunk"); } } return success(); @@ -2953,65 +2971,6 @@ struct TritonGPUInferLayoutInterface return success(); } - LogicalResult verifyLayoutsAreEqual(ArrayRef shape, - Attribute expected, Attribute got, - Location loc) const override { - if (expected == got) { - return success(); - } - // Check whether the encodings are structurally the same. - auto expectedLL = triton::gpu::toLinearLayout(shape, expected); - auto gotLL = triton::gpu::toLinearLayout(shape, got); - if (expectedLL != gotLL) { - return emitError(loc, "Expected result encoding ") - << expected << " but was " << got; - } - return success(); - } - - LogicalResult - inferReshapeOpEncoding(ArrayRef srcShape, Attribute srcEnc, - ArrayRef dstShape, Attribute &dstEnc, - std::optional loc) const override { - auto result = - inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc); - if (succeeded(result)) { - return result; - } - - // If the legacy encoding failed use LinearLayouts. - // Once LinearLayouts are more widely used, we can remove - // inferReshapeOpLegacyEncoding and simply use LLs. - auto *ctx = getContext(); - auto src = triton::gpu::toLinearLayout(srcShape, srcEnc); - if (!src) { - return emitOptionalError(loc, - "src encoding does not support linear layout"); - } - - if (product(srcShape) != product(dstShape)) { - return emitOptionalError(loc, "numel of dst shape does not match " - "numel of src shape"); - } - - auto newRank = dstShape.size(); - SmallVector> newOutDims; - for (auto [dim, size] : - llvm::zip(standardOutDimNames(ctx, newRank), dstShape)) { - newOutDims.emplace_back(dim, size); - } - auto srcOutDims = llvm::to_vector(src->getOutDimNames()); - // reshapeOp assumes minor-to-major, so we need to transpose the out dims - // before the reshape - std::reverse(srcOutDims.begin(), srcOutDims.end()); - std::reverse(newOutDims.begin(), newOutDims.end()); - auto dst = src->transposeOuts(srcOutDims) - .reshapeOuts(newOutDims) - .transposeOuts(standardOutDimNames(ctx, newRank)); - dstEnc = LinearEncodingAttr::get(ctx, dst); - return success(); - } - LogicalResult inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc, std::optional loc) const override { diff --git a/lib/Dialect/TritonGPU/IR/Ops.cpp b/lib/Dialect/TritonGPU/IR/Ops.cpp index 895a7e769e..bcd348e809 100644 --- a/lib/Dialect/TritonGPU/IR/Ops.cpp +++ b/lib/Dialect/TritonGPU/IR/Ops.cpp @@ -43,16 +43,6 @@ struct CanonicalizeConvertFromReshape auto convert = op.getSrc().getDefiningOp(); if (!convert) return failure(); - // If the layouts are structurally the same, the convert is trivial - auto srcType = convert.getSrc().getType(); - auto dstType = convert.getType(); - auto srcLL = toLinearLayout(srcType.getShape(), srcType.getEncoding()); - auto dstLL = toLinearLayout(dstType.getShape(), dstType.getEncoding()); - if (srcLL && dstLL && *srcLL == *dstLL) { - rewriter.replaceOpWithNewOp( - op, op.getType(), convert.getSrc(), op.getAllowReorder()); - return mlir::success(); - } if (isExpensiveView(convert.getSrc().getType(), op.getType())) return failure(); if (!op.getAllowReorder() || op.getEfficientLayout()) diff --git a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp index 1f93d894b5..2b9e12c3ac 100644 --- a/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp @@ -1025,7 +1025,9 @@ void LayoutRematerialization::backwardRematerialization( // we don't handle conversions to DotOperandEncodingAttr // this is a heuristic to accommodate fused attention RankedTensorType targetType = convertOp.getType(); - if (isa(targetType.getEncoding())) + // We stop the rematerialization of linear layouts as we have to be a bit more + // careful with the heuristics for both correctness and perf + if (isa(targetType.getEncoding())) return; Value oldV = convertOp.getSrc(); LDBG("check backward remat with source " << oldV << " encoding " @@ -1067,8 +1069,11 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast( ConvertLayoutOp convertOp) { // we don't handle conversions to DotOperandEncodingAttr // this is a heuristics to accommodate fused attention + // We stop the rematerialization of linear layouts as we have to be a bit more + // careful with the heuristics for both correctness and perf RankedTensorType targetType = convertOp.getType(); - if (isa(targetType.getEncoding())) + if (mlir::isa( + targetType.getEncoding())) return; auto isExtOrBroadcastOp = [](Operation *op) { diff --git a/lib/Dialect/TritonGPU/Transforms/Utility.cpp b/lib/Dialect/TritonGPU/Transforms/Utility.cpp index 27cb71638f..46dfce695c 100644 --- a/lib/Dialect/TritonGPU/Transforms/Utility.cpp +++ b/lib/Dialect/TritonGPU/Transforms/Utility.cpp @@ -407,13 +407,14 @@ static Attribute inferReshapeOpDstEncoding(ArrayRef srcShape, return {}; Attribute dstEnc; - auto result = - srcEnc.getDialect() - .getRegisteredInterface() - ->inferReshapeOpEncoding(srcShape, srcEnc, dstShape, dstEnc, - /*loc=*/std::nullopt); - assert(succeeded(result)); - return dstEnc; + if (succeeded( + srcEnc.getDialect() + .getRegisteredInterface() + ->inferReshapeOpNoReorderEncoding( + srcShape, srcEnc, dstShape, dstEnc, /*loc=*/std::nullopt))) { + return dstEnc; + } + return {}; } static Attribute inferDstEncoding(triton::ReshapeOp op, Attribute encoding) { diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py index ee32b62fdd..bcf4cd8c0d 100644 --- a/python/triton/runtime/interpreter.py +++ b/python/triton/runtime/interpreter.py @@ -993,7 +993,7 @@ def _set_attr(input, values, name): lang.static_assert = _new_static_assert lang.static_print = print lang.dtype.to_ir = _new_to_ir - lang.multiple_of = partial(_set_attr, name="tt.divisiblity") + lang.multiple_of = partial(_set_attr, name="tt.divisibility") lang.max_contiguous = partial(_set_attr, name="tt.contiguity") lang.max_constancy = partial(_set_attr, name="tt.constancy") diff --git a/test/Conversion/reduce_to_llvm.mlir b/test/Conversion/reduce_to_llvm.mlir deleted file mode 100644 index 0bbcecbd93..0000000000 --- a/test/Conversion/reduce_to_llvm.mlir +++ /dev/null @@ -1,70 +0,0 @@ -// RUN: triton-opt %s --allocate-shared-memory --convert-triton-gpu-to-llvm --convert-nv-gpu-to-llvm | mlir-translate -mlir-to-llvmir | opt -S -O1 | FileCheck %s - -#linear = #ttg.linear<{register = [[0, 2], [2, 0]], lane = [[0, 8], [8, 0], [1, 0], [4, 0], [16, 0]], warp = [[0, 1], [0, 4]], block = []}> - -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} { - -// CHECK-LABEL: @reduce_linear_layout -tt.func private @reduce_linear_layout(%arg0: tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> { - // CHECK-NEXT: [[SRC0:%.*]] = extractvalue {{.*}} %0, 0 - // CHECK-NEXT: [[SRC1:%.*]] = extractvalue {{.*}} %0, 1 - // CHECK-NEXT: [[SRC2:%.*]] = extractvalue {{.*}} %0, 2 - // CHECK-NEXT: [[SRC3:%.*]] = extractvalue {{.*}} %0, 3 - - // The layout looks lke - // [[ T0:0, T32:0, T0:1, T32:1, ... - // [ T4:0, T36:0, T4:1, T36:1, ... - // [ T0:2, T32:2, T0:3, T32:3, ... - // [ T4:2, T36:2, T4:3, T36:3, - // ... - // - // A reduction along axis=0 consists of adding registers (0, 2) and (1, 3) - // before shuffling. - // - // Columns along axis=0 are contained within a warp, so reduction arcoss warps - // is not needed. - - // Reduce within threads - // CHECK-NEXT: [[SUM0:%.*]] = add i32 [[SRC0]], [[SRC2]] - // CHECK-NEXT: [[SUM1:%.*]] = add i32 [[SRC1]], [[SRC3]] - - // Reduce within warp. - // CHECK-NEXT: [[W0:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[SUM0]], i32 16, i32 31) - // CHECK-NEXT: [[WSUM0:%.*]] = add i32 [[W0]], [[SUM0]] - // CHECK-NEXT: [[W1:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM0]], i32 8, i32 31) - // CHECK-NEXT: [[WSUM1:%.*]] = add i32 [[WSUM0]], [[W1]] - // CHECK-NEXT: [[W2:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM1]], i32 4, i32 31) - // CHECK-NEXT: [[WSUM2:%.*]] = add i32 [[WSUM1]], [[W2]] - // CHECK-NEXT: [[W3:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM2]], i32 2, i32 31) - // CHECK-NEXT: [[WSUM3:%.*]] = add i32 [[WSUM2]], [[W3]] - - // CHECK-NEXT: [[W4:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[SUM1]], i32 16, i32 31) - // CHECK-NEXT: [[WSUM4:%.*]] = add i32 [[W4]], [[SUM1]] - // CHECK-NEXT: [[W5:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM4]], i32 8, i32 31) - // CHECK-NEXT: [[WSUM5:%.*]] = add i32 [[WSUM4]], [[W5]] - // CHECK-NEXT: [[W6:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM5]], i32 4, i32 31) - // CHECK-NEXT: [[WSUM6:%.*]] = add i32 [[WSUM5]], [[W6]] - // CHECK-NEXT: [[W7:%.*]] = tail call i32 @llvm.nvvm.shfl.sync.bfly.i32(i32 -1, i32 [[WSUM6]], i32 2, i32 31) - // CHECK-NEXT: [[WSUM7:%.*]] = add i32 [[WSUM6]], [[W7]] - - // CHECK-NEXT: [[DST0:%.*]] = insertvalue { i32, i32 } undef, i32 [[WSUM3]], 0 - // CHECK-NEXT: [[DST1:%.*]] = insertvalue { i32, i32 } [[DST0]], i32 [[WSUM7]], 1 - - %0 = "tt.reduce"(%arg0) ({ - ^bb0(%arg1: i32, %arg2: i32): - %1 = arith.addi %arg1, %arg2 : i32 - tt.reduce.return %1 : i32 - }) {axis = 0 : i32} : (tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> - - // CHECK-NEXT: ret { i32, i32 } [[DST1]] - tt.return %0 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> -} - -tt.func @anchor(%ptr: !llvm.ptr, %arg0: tensor<32x16xi32, #linear>) { - %0 = tt.call @reduce_linear_layout(%arg0) : (tensor<32x16xi32, #linear>) -> tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> - %1 = builtin.unrealized_conversion_cast %0 : tensor<16xi32, #ttg.slice<{dim = 0, parent = #linear}>> to !llvm.struct<(i32, i32)> - llvm.store volatile %1, %ptr : !llvm.struct<(i32, i32)>, !llvm.ptr - tt.return -} - -} diff --git a/test/Conversion/tritongpu_to_llvm.mlir b/test/Conversion/tritongpu_to_llvm.mlir index db8b664e76..a16fdd81a5 100644 --- a/test/Conversion/tritongpu_to_llvm.mlir +++ b/test/Conversion/tritongpu_to_llvm.mlir @@ -2123,28 +2123,3 @@ tt.func @upcast_mxfp(%arg0: tensor<32x32xi8, #ttg.dot_op<{opIdx = 0, parent = #m } } - -// ----- - -#blocked = #ttg.blocked<{sizePerThread = [1, 1, 16], threadsPerWarp = [4, 4, 2], warpsPerCTA = [8, 1, 1], order = [2, 1, 0]}> -#linear = #ttg.linear<{register = [[0, 0], [0, 0], [0, 0], [0, 0]], lane = [[0, 0], [0, 1], [0, 2], [1, 0], [2, 0]], warp = [[4, 0], [8, 0], [16, 0]], block = []}> - -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - -// CHECK-LABEL: expand_dims_linear_layout -tt.func private @expand_dims_linear_layout() -> tensor<1x4xi32, #linear> { - %0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #linear}>> - %1 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32, #ttg.slice<{dim = 0, parent = #linear}>> -> tensor<1x4xi32, #linear> - // CHECK: return %{{.*}} : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32, i32)> - tt.return %1 : tensor<1x4xi32, #linear> -} - -// CHECK-LABEL: reshape_linear_layout_broadcasting -tt.func private @reshape_linear_layout_broadcasting(%arg0: tensor<32x4xbf16, #linear>) -> tensor<32x4x1xbf16, #blocked> { - // CHECK-COUNT-16: extractvalue - // CHECK-COUNT-16: insertvalue - %0 = tt.reshape %arg0 : tensor<32x4xbf16, #linear> -> tensor<32x4x1xbf16, #blocked> - tt.return %0 : tensor<32x4x1xbf16, #blocked> -} - -} diff --git a/test/TritonGPU/combine.mlir b/test/TritonGPU/combine.mlir index ab124f82ab..cd45d1ee05 100644 --- a/test/TritonGPU/combine.mlir +++ b/test/TritonGPU/combine.mlir @@ -2829,26 +2829,3 @@ tt.func @remat_across_regions(%arg0: i1, %arg1: tensor<8x8xf32, #blocked>) { } } - -// ----- - -#linear = #ttg.linear<{register = [[1, 0], [0, 8], [0, 16]], lane = [[2, 0], [4, 0], [8, 0], [16, 0], [0, 1]], warp = [[0, 2], [0, 4]], block = []}> -#blocked = #ttg.blocked<{sizePerThread = [2, 4], threadsPerWarp = [16, 2], warpsPerCTA = [1, 4], order = [1, 0]}> - -module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "cuda:90", "ttg.threads-per-warp" = 32 : i32} { - -// CHECK-LABEL: reduce_linear_layouts -tt.func @reduce_linear_layouts(%arg0: tensor<32x32xi32, #linear>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>> { - // CHECK-NOT: convert_layout - %0 = ttg.convert_layout %arg0 : tensor<32x32xi32, #linear> -> tensor<32x32xi32, #blocked> - // CHECK-NEXT: tt.reduce - %1 = "tt.reduce" (%0) ({ - ^bb0(%arg1: i32, %arg2: i32): - tt.reduce.return %arg1 : i32 - // CHECK: (tensor<32x32xi32, #linear>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}> - }) {axis = 1 : i32} : (tensor<32x32xi32, #blocked>) -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> - %2 = ttg.convert_layout %1 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #blocked}>> -> tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>> - tt.return %2 : tensor<32xi32, #ttg.slice<{dim = 1, parent = #linear}>> -} - -} diff --git a/third_party/amd/backend/driver.py b/third_party/amd/backend/driver.py index fd02d53270..ca712f9040 100644 --- a/third_party/amd/backend/driver.py +++ b/third_party/amd/backend/driver.py @@ -37,7 +37,7 @@ class DlPhdrInfo(ctypes.Structure): # Load libc and get the dl_iterate_phdr symbol. try: dl_iterate_phdr = ctypes.CDLL('libc.so.6').dl_iterate_phdr - except: + except Exception: return None # argtypes must use c_char_p to accept create_string_buffer. dl_iterate_phdr.argtypes = [callback_t, c_char_p] diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp index 5efa81c912..678ac0c54b 100644 --- a/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp +++ b/third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp @@ -337,6 +337,23 @@ class BlockedToMFMA : public OpRewritePattern { : OpRewritePattern(context, benefit), mfmaVersion(mfmaVersion), nonKDim(nonKDim), kPack(kPack) {} + bool isChainDot(tt::DotOp &dotOp) const { + auto filter = [&dotOp](Operation *op) { + return op->getParentRegion() == dotOp->getParentRegion(); + }; + ForwardSliceOptions fwdOpt; + fwdOpt.filter = filter; + BackwardSliceOptions bwdOpt; + bwdOpt.omitBlockArguments = true; + bwdOpt.filter = filter; + auto slices = getSlice(dotOp, bwdOpt, fwdOpt); + for (Operation *op : slices) { + if (isa(op) && (op != dotOp)) + return true; + } + return false; + } + bool isSecondDot(tt::DotOp &dotOp) const { auto filter = [&dotOp](Operation *op) { return op->getParentRegion() == dotOp->getParentRegion(); @@ -391,12 +408,16 @@ class BlockedToMFMA : public OpRewritePattern { auto warpsPerTile = warpsPerTileMFMA(dotOp, retShape, numWarps, {mDim, nDim}); - // Always use transposed mfma layout. This enables larger vectorization - // for global store instructions + // Use transposed mfma layout to enable larger vectorization for global + // store instructions, except for fp8 matmul kernels due to regression + // TODO (lixun): investigate the regression and enable this feature again + auto aElemTy = mfmaInstr.getElementTypeA(); + bool isFP8 = aElemTy.isFloat8E5M2FNUZ() || aElemTy.isFloat8E4M3FNUZ(); + bool isTransposed = isChainDot(dotOp) || !isFP8; mfmaEnc = ttg::AMDMfmaEncodingAttr::get( oldRetType.getContext(), /*versionMajor*/ mfmaVersion, /*versionMinor*/ 0, warpsPerTile, - /*instrShape*/ mDim, nDim, /*isTransposed*/ true, CTALayout); + /*instrShape*/ mDim, nDim, isTransposed, CTALayout); Type mfmaAccType; if (oldRetType.getElementType().isIntOrIndex()) diff --git a/third_party/amd/python/test/test_extract_slice.py b/third_party/amd/python/test/test_extract_slice.py index a9c7df4754..5d24080861 100644 --- a/third_party/amd/python/test/test_extract_slice.py +++ b/third_party/amd/python/test/test_extract_slice.py @@ -1,11 +1,7 @@ -import tempfile - -import numpy as np import pytest import torch import triton -import triton.language as tl from triton._internal_testing import is_hip diff --git a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp index de2f019fd8..c6ffa50aa1 100644 --- a/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp +++ b/third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp @@ -646,26 +646,10 @@ struct TritonIntelGPUInferLayoutInterface return success(); } - LogicalResult verifyLayoutsAreEqual(ArrayRef shape, - Attribute expected, Attribute got, - Location loc) const override { - if (expected == got) { - return success(); - } - // Check whether the encodings are structurally the same. - auto expectedLL = triton::gpu::toLinearLayout(shape, expected); - auto gotLL = triton::gpu::toLinearLayout(shape, got); - if (expectedLL != gotLL) { - return emitError(loc, "Expected result encoding ") - << expected << " but was " << got; - } - return success(); - } - LogicalResult - inferReshapeOpEncoding(ArrayRef srcShape, Attribute srcEnc, - ArrayRef dstShape, Attribute &dstEnc, - std::optional loc) const override { + inferReshapeOpNoReorderEncoding(ArrayRef srcShape, Attribute srcEnc, + ArrayRef dstShape, Attribute &dstEnc, + std::optional loc) const override { // TODO return failure(); } diff --git a/third_party/proton/test/test_cmd.py b/third_party/proton/test/test_cmd.py index 620dcd5691..75d364fecd 100644 --- a/third_party/proton/test/test_cmd.py +++ b/third_party/proton/test/test_cmd.py @@ -7,8 +7,7 @@ def test_help(): # Only check if the viewer can be invoked - ret = subprocess.check_call(["proton", "-h"], stdout=subprocess.DEVNULL) - assert ret == 0 + subprocess.check_call(["proton", "-h"], stdout=subprocess.DEVNULL) def is_hip(): @@ -22,14 +21,13 @@ def test_exec(mode, tmp_path: pathlib.Path): temp_file = tmp_path / "test_exec.hatchet" name = str(temp_file.with_suffix("")) if mode == "script": - ret = subprocess.check_call(["proton", "-n", name, helper_file, "test"], stdout=subprocess.DEVNULL) + subprocess.check_call(["proton", "-n", name, helper_file, "test"], stdout=subprocess.DEVNULL) elif mode == "python": - ret = subprocess.check_call(["python3", "-m", "triton.profiler.proton", "-n", name, helper_file, "test"], - stdout=subprocess.DEVNULL) + subprocess.check_call(["python3", "-m", "triton.profiler.proton", "-n", name, helper_file, "test"], + stdout=subprocess.DEVNULL) elif mode == "pytest": - ret = subprocess.check_call(["proton", "-n", name, "pytest", "-k", "test_main", helper_file], - stdout=subprocess.DEVNULL) - assert ret == 0 + subprocess.check_call(["proton", "-n", name, "pytest", "-k", "test_main", helper_file], + stdout=subprocess.DEVNULL) with temp_file.open() as f: data = json.load(f, ) kernels = data[0]["children"] diff --git a/unittest/Dialect/TritonGPU/DialectTest.cpp b/unittest/Dialect/TritonGPU/DialectTest.cpp index a3cc65605f..3fab64c0e8 100644 --- a/unittest/Dialect/TritonGPU/DialectTest.cpp +++ b/unittest/Dialect/TritonGPU/DialectTest.cpp @@ -77,6 +77,135 @@ int64_t getFlatIdx(ArrayRef idx, ArrayRef shape, return flatIdx; } +// Represents the many indices of one element of a tensor with a +// BlockedEncoding. +// +// The purpose of this class is we can say, if two MultiIdx's have the same +// flatFoo values before and after a reshape, then the same GPU thread contains +// the same element (and the reshape is a nop, at least for that element). +struct MultiIdx { + using Vec = SmallVector; + + // Logical index into the tensor. + Vec idx; + + // If the tensor's encoding has e.g. numPerThread = [2,2], then idxInThread + // tells us which of the four elements per thread this is. Same for idxInWarp + // and idxInCTA. + Vec idxInThread; + Vec idxInWarp; + Vec idxInCTA; + + // If the tensor's encoding defines a block of size [x,y,z], the tensor itself + // may be larger than this, comprising multiple blocks. This tells us which + // block we're in. + Vec idxOuter; + + // flatIdx is flattened according to the tensor's logical order (i.e. ignoring + // the encoding). The others are flattened according to the tensor's physical + // encoding. + int64_t flatIdx; + int64_t flatIdxInThread; + int64_t flatIdxInWarp; + int64_t flatIdxInCTA; + int64_t flatIdxOuter; +}; + +bool sameFlatIdxs(const MultiIdx &a, const MultiIdx &b) { + return a.flatIdx == b.flatIdx && // + a.flatIdxInThread == b.flatIdxInThread && + a.flatIdxInWarp == b.flatIdxInWarp && + a.flatIdxInCTA == b.flatIdxInCTA && // + a.flatIdxOuter == b.flatIdxOuter; +} + +std::string multiIdxsToString(ArrayRef> idxs) { + std::stringstream ss; + for (const auto &idxPtr : idxs) { + const MultiIdx &idx = *idxPtr; + ss // + << " [" << triton::join(idx.idx, ",") << "] (" << idx.flatIdx << ") " + << "elem=[" << triton::join(idx.idxInThread, ",") << "] (" + << idx.flatIdxInThread << ") " + << "thread=[" << triton::join(idx.idxInWarp, ",") << "] (" + << idx.flatIdxInWarp << ") " + << "warp=[" << triton::join(idx.idxInCTA, ",") << "] (" + << idx.flatIdxInCTA << ") " + << "outer=[" << triton::join(idx.idxOuter, ",") << "] (" + << idx.flatIdxOuter << ")\n"; + } + return ss.str(); +} + +std::vector> getMultiIdxs(ArrayRef shape, + BlockedEncodingAttr enc) { + using Vec = MultiIdx::Vec; + + const unsigned rank = shape.size(); + auto sizePerThread = enc.getSizePerThread(); + auto threadsPerWarp = enc.getThreadsPerWarp(); + auto warpsPerCTA = enc.getWarpsPerCTA(); + auto order = enc.getOrder(); + + Vec numBlocks; + for (int i = 0; i < rank; i++) { + numBlocks.push_back(ceil( + shape[i], sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i])); + } + + Vec idxInThread(rank, 0); + Vec idxInWarp(rank, 0); + Vec idxInCTA(rank, 0); + Vec idxOuter(rank, 0); + + int64_t nElems = product(sizePerThread) * product(threadsPerWarp) * + product(warpsPerCTA) * product(numBlocks); + + // We eventually sort this array, and if the elements are plain MultiIdx + // elements rather than pointers, we have to swap them, which ends up being + // expensive. + std::vector> elems; + elems.reserve(nElems); + + for (int64_t i = 0; i < nElems; i++) { + auto e = std::make_unique(); + e->idxInThread = idxInThread; + e->idxInWarp = idxInWarp; + e->idxInCTA = idxInCTA; + e->idxOuter = idxOuter; + + for (int i = 0; i < rank; i++) { + e->idx.push_back( // + idxInThread[i] + // + idxInWarp[i] * sizePerThread[i] + + idxInCTA[i] * sizePerThread[i] * threadsPerWarp[i] + + idxOuter[i] * sizePerThread[i] * threadsPerWarp[i] * warpsPerCTA[i]); + } + + e->flatIdxInThread = getFlatIdx(e->idxInThread, sizePerThread, order); + e->flatIdxInWarp = getFlatIdx(e->idxInWarp, threadsPerWarp, order); + e->flatIdxInCTA = getFlatIdx(e->idxInCTA, warpsPerCTA, order); + e->flatIdxOuter = getFlatIdx(e->idxOuter, numBlocks, order); + e->flatIdx = getFlatIdx(e->idx, shape, + llvm::to_vector(llvm::reverse(llvm::seq(rank)))); + + elems.push_back(std::move(e)); + + if (advance(idxInThread, sizePerThread, order)) { + if (advance(idxInWarp, threadsPerWarp, order)) { + if (advance(idxInCTA, warpsPerCTA, order)) { + advance(idxOuter, numBlocks, order); + } + } + } + } + llvm::sort(elems, [](const std::unique_ptr &a, + const std::unique_ptr &b) { + return a->flatIdx < b->flatIdx; + }); + return elems; +} + class InferLayoutTest : public ::testing::Test { public: InferLayoutTest() @@ -92,12 +221,25 @@ class InferLayoutTest : public ::testing::Test { /*static*/ MLIRContext InferLayoutTest::ctx; +// The optional outparam couldReshape tells the caller whether the reshape +// worked. You might want this to be a return value instead, but gtest ASSERT +// and FAIL have an implicit `return`, so only work in fns that return void. void testReshape(RankedTensorType srcTy, RankedTensorType dstTy, std::optional expectedDstEnc, + std::optional expectSuccess, DialectInferLayoutInterface *inferLayout, - bool longErrors = true) { + bool longErrors = true, bool *couldReshape = nullptr) { + std::unique_ptr couldReshapeStorage; + if (!couldReshape) { + couldReshapeStorage = std::make_unique(); + couldReshape = couldReshapeStorage.get(); + } + *couldReshape = false; MLIRContext *ctx = srcTy.getContext(); + ASSERT_TRUE(expectSuccess || !dstTy.getEncoding()) + << "dstTy shouldn't have an expected encoding if we're expecting the " + "reshape to be impossible!"; // Capture any errors from calling inferReshapeNoOpReorderEncoding, so we can // print them if we expected the reshape to succeed but it failed. @@ -107,17 +249,29 @@ void testReshape(RankedTensorType srcTy, RankedTensorType dstTy, { ScopedDiagnosticHandler scopedHandler( ctx, [&](Diagnostic &diag) { diags.push_back(" - " + diag.str()); }); - result = inferLayout->inferReshapeOpEncoding( + result = inferLayout->inferReshapeOpNoReorderEncoding( srcTy.getShape(), srcTy.getEncoding(), dstTy.getShape(), inferredEnc, UnknownLoc::get(ctx)); } - // We expect the reshape to succeed as long as the inputs have the same - // number of elements - EXPECT_TRUE(succeeded(result)) - << "Expected reshape to succeed, but it didn't! Error(s):\n" - << join(diags, "\n"); + if (!expectSuccess.has_value() && !succeeded(result)) { + // We didn't know whether or not it was supposed to succeed, and it didn't. + // Test passes! + return; + } + + if (expectSuccess.has_value() && !*expectSuccess) { + EXPECT_FALSE(succeeded(result)) + << "Expected reshape to be impossible, but got dst encoding: " + << stringifyLLVMType(inferredEnc); + *couldReshape = true; + return; + } + if (!succeeded(result)) { + FAIL() << "Expected reshape to succeed, but it didn't! Error(s):\n" + << join(diags, "\n"); + } if (auto expectedEnc = dstTy.getEncoding()) { EXPECT_EQ(inferredEnc, expectedEnc); } @@ -125,14 +279,12 @@ void testReshape(RankedTensorType srcTy, RankedTensorType dstTy, // We know that infer(srcShape, srcEnc, dstShape) => dstEnc. Check that it // works the other way around too: infer(dstShape, dstEnc, srcShape) => // srcEnc. (This is an invariant of the inference function.) - // Even more, we check that the inferred encoding is structurally the same as - // the src encoding, showing that the inference is consistent. { std::vector diags; ScopedDiagnosticHandler scopedHandler( ctx, [&](Diagnostic &diag) { diags.push_back(" - " + diag.str()); }); Attribute inferredSrcEnc; - auto result = inferLayout->inferReshapeOpEncoding( + auto result = inferLayout->inferReshapeOpNoReorderEncoding( dstTy.getShape(), inferredEnc, srcTy.getShape(), inferredSrcEnc, UnknownLoc::get(ctx)); EXPECT_TRUE(succeeded(result)) @@ -140,40 +292,56 @@ void testReshape(RankedTensorType srcTy, RankedTensorType dstTy, << " " << stringifyLLVMType(inferredEnc) << " -> " << triton::join(srcTy.getShape(), "x") << "failed:\n" << join(diags, "\n"); - auto srcLinear = toLinearLayout(srcTy.getShape(), srcTy.getEncoding()); - auto inferredSrcLinear = toLinearLayout(srcTy.getShape(), inferredSrcEnc); - EXPECT_EQ(inferredSrcLinear, srcLinear) - << "Inverse encoding inference (" << triton::join(dstTy.getShape(), "x") - << " " << stringifyLLVMType(inferredEnc) << " -> " - << triton::join(srcTy.getShape(), "x") - << " gave the wrong result. Expected " << srcLinear->toString() - << " but " - << "got " << inferredSrcLinear->toString() << ".\n"; + if (succeeded(result)) { + EXPECT_EQ(inferredSrcEnc, srcTy.getEncoding()) + << "Inverse encoding inference (" + << triton::join(dstTy.getShape(), "x") << " " + << stringifyLLVMType(inferredEnc) << " -> " + << triton::join(srcTy.getShape(), "x") + << " gave the wrong result. Expected " + << stringifyLLVMType(srcTy.getEncoding()) << " but got " + << stringifyLLVMType(inferredSrcEnc) << ".\n"; + } } - // The funtional characterisation of resize is that, if we have a srcLayout - // and a dstLayout, then the flattened layouts are views of the same data - // when considered as C-contiguous. - auto makeFlattenedCContig = [](ArrayRef shape, Attribute layout) { - auto ctx = layout.getContext(); - auto linear = *toLinearLayout(shape, layout); - auto dims = standardOutDimNames(ctx, shape.size()); - std::reverse(dims.begin(), dims.end()); - return linear.transposeOuts(dims).reshapeOuts( - {{dims.back(), linear.getTotalOutDimSize()}}); - }; - EXPECT_EQ(makeFlattenedCContig(srcTy.getShape(), srcTy.getEncoding()), - makeFlattenedCContig(dstTy.getShape(), inferredEnc)); + std::vector> srcMultiIdxs = + getMultiIdxs(SmallVector(srcTy.getShape()), + mlir::cast(srcTy.getEncoding())); + + std::vector> dstMultiIdxs = + getMultiIdxs(SmallVector(dstTy.getShape()), + mlir::cast(inferredEnc)); + + if (srcMultiIdxs.size() != dstMultiIdxs.size() || + !llvm::all_of(llvm::zip_equal(srcMultiIdxs, dstMultiIdxs), + [](const auto &pair) { + const auto &[a, b] = pair; + return sameFlatIdxs(*a, *b); + })) { + SCOPED_TRACE(longErrors ? "dst indices:\n" + multiIdxsToString(dstMultiIdxs) + : ""); + SCOPED_TRACE(longErrors ? "src indices:\n" + multiIdxsToString(srcMultiIdxs) + : ""); + ADD_FAILURE() << "Reified indices do not match for encodings:\n" + << " src: [" << triton::join(srcTy.getShape(), "x") << "] " + << stringifyLLVMType(srcTy.getEncoding()) << "\n" + << " dst: [" << triton::join(dstTy.getShape(), "x") << "] " + << stringifyLLVMType(inferredEnc); + } else { + *couldReshape = true; + } } -class InferReshapeOpEncodingTest +class InferReshapeOpNoReorderEncodingTest : public InferLayoutTest, public ::testing::WithParamInterface< - std::tuple> {}; + std::tuple> {}; -TEST_P(InferReshapeOpEncodingTest, DoIt) { +TEST_P(InferReshapeOpNoReorderEncodingTest, DoIt) { std::string srcTyStr = expandTyStr(std::get<0>(GetParam())); std::string dstTyStr = expandTyStr(std::get<1>(GetParam())); + bool expectSuccess = std::get<2>(GetParam()); auto src = mlir::parseType(srcTyStr, &ctx); if (!src) @@ -189,7 +357,7 @@ TEST_P(InferReshapeOpEncodingTest, DoIt) { } testReshape(cast(src), cast(dst), - expectedDstEnc, inferLayout, /*longErrors=*/true); + expectedDstEnc, expectSuccess, inferLayout, /*longErrors=*/true); } // A testcase of {a, b, c} means: @@ -200,72 +368,158 @@ TEST_P(InferReshapeOpEncodingTest, DoIt) { // encoding that makes the reshape a nop, and // - if b has an encoding, check that the inferred encoding matches b's. INSTANTIATE_TEST_SUITE_P( - Reshapes, InferReshapeOpEncodingTest, - ::testing::ValuesIn(std::vector>({ + Reshapes, InferReshapeOpNoReorderEncodingTest, + ::testing::ValuesIn(std::vector< + std::tuple>({ // Use raw strings in here so clang-format doesn't try to wrap them. {R"(T<128x64xf32, #B<{spt=[1,1], tpw=[1,32], wpc=[1,1], ord=[1,0]}>>)", - R"(T<8192xf32, #B<{spt=[1], tpw=[32], wpc=[1], ord=[0]}>>)"}, + R"(T<8192xf32, #B<{spt=[1], tpw=[32], wpc=[1], ord=[0]}>>)", + true}, {R"(T<128xf32, #B<{spt=[4], tpw=[32], wpc=[1], ord=[0]}>>)", - R"(T<32x4xf32, #B<{spt=[1,4], tpw=[32,1], wpc=[1,1], ord=[1,0]}>>)"}, + R"(T<32x4xf32, #B<{spt=[1,4], tpw=[32,1], wpc=[1,1], ord=[1,0]}>>)", + true}, {R"(T<128xf32, #B<{spt=[4], tpw=[32], wpc=[1], ord=[0]}>>)", - R"(T<16x8xf32, #B<{spt=[1,4], tpw=[16,2], wpc=[1,1], ord=[1,0]}>>)"}, + R"(T<16x8xf32, #B<{spt=[1,4], tpw=[16,2], wpc=[1,1], ord=[1,0]}>>)", + true}, {R"(T<32x32xf32, #B<{spt=[2,2], tpw=[32,1], wpc=[1,1], ord=[1,0]}>>)", - "T<1024xf32>"}, + "T<128xf32>", false}, {R"(T<32x4xf32, #B<{spt=[1,4], tpw=[32,1], wpc=[1,1], ord=[1,0]}>>)", - R"(T<2x16x2x2xf32, #B<{spt=[1,1,2,2], tpw=[2,16,1,1], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)"}, + R"(T<2x16x2x2xf32, #B<{spt=[1,1,2,2], tpw=[2,16,1,1], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)", + true}, {R"(T<4x32xf32, #B<{spt=[4,1], tpw=[1,32], wpc=[1,1], ord=[0,1]}>>)", - R"(T<2x2x2x16xf32, #B<{spt=[2,2,1,1], tpw=[1,1,2,16], wpc=[1,1,1,1], ord=[1,0,3,2]}>>)"}, + R"(T<2x2x2x16xf32, #B<{spt=[2,2,1,1], tpw=[1,1,2,16], wpc=[1,1,1,1], ord=[1,0,3,2]}>>)", + true}, {R"(T<32x32xf32, #B<{spt=[4,4], tpw=[4,8], wpc=[1,1], ord=[1,0]}>>)", - R"(T<2x16x2x16xf32, #B<{spt=[1,4,1,4], tpw=[1,4,2,4], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)"}, + R"(T<2x16x2x16xf32, #B<{spt=[1,4,1,4], tpw=[1,4,2,4], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)", + true}, {R"(T<32x32xf32, #B<{spt=[4,4], tpw=[4,8], wpc=[1,1], ord=[1,0]}>>)", - R"(T<16x2x16x2xf32, #B<{spt=[2,2,2,2], tpw=[4,1,8,1], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)"}, + R"(T<16x2x16x2xf32, #B<{spt=[2,2,2,2], tpw=[4,1,8,1], wpc=[1,1,1,1], ord=[3,2,1,0]}>>)", + true}, {R"(T<32x32xf32, #B<{spt=[4,4], tpw=[4,8], wpc=[1,1], ord=[0,1]}>>)", - R"(T<16x2x16x2xf32>)"}, + R"(T<16x2x16x2xf32>)", true}, // nop reshape, but the block size is 2x larger than the tensor. {R"(T<4x2x2x4xf32, #B<{spt=[2,1,1,2], tpw=[2,1,1,2], wpc=[2,2,1,1], ord=[0,3,1,2]}>>)", - R"(T<4x2x2x4xf32, #B<{spt=[2,1,1,2], tpw=[2,1,1,2], wpc=[2,2,1,1], ord=[0,3,1,2]}>>)"}, + R"(T<4x2x2x4xf32, #B<{spt=[2,1,1,2], tpw=[2,1,1,2], wpc=[2,2,1,1], ord=[0,3,1,2]}>>)", + true}, {R"(T<2x4x2x4xf32, #B<{spt=[1,2,2,1], tpw=[1,2,1,2], wpc=[1,2,2,1], ord=[2,1,0,3]}>>)", - R"(T<4x2x2x4xf32>)"}, + R"(T<4x2x2x4xf32>)", false}, {R"(T<1x2x2x4xf32, #B<{spt=[1,32,4,4], tpw=[4,4,16,16], wpc=[8,8,8,1], ord=[0,1,2,3]}>>)", - R"(T<2x2x4x1xf32>)"}, + R"(T<2x2x4x1xf32>)", false}, {R"(T<2x2x2x2xf32, #B<{spt=[2,2,2,2], tpw=[1,1,1,1], wpc=[1,1,1,1], ord=[1,0,3,2]}>>)", - R"(T<4x4xf32>)"}, + R"(T<4x4xf32>)", true}, {R"(T<16x8xf32, #B<{spt=[1,2], tpw=[2,4], wpc=[2,1], ord=[1,0]}>>)", - R"(T<128xf32>)"}, + R"(T<128xf32>)", true}, {R"(T<16x1x8xf32, #B<{spt=[8,1,1], tpw=[2,1,1], wpc=[1,1,8], ord=[2,1,0]}>>)", - R"(T<128x1xf32>)"}, + R"(T<128x1xf32>)", false}, {R"(T<16x1x8xf32, #B<{spt=[1,1,8], tpw=[2,1,1], wpc=[8,1,1], ord=[2,1,0]}>>)", - R"(T<128x1xf32>)"}, + R"(T<128x1xf32>)", true}, {R"(T<32x32xf32, #B<{spt=[1,2], tpw=[1,8], wpc=[1,1], ord=[1,0]}>>)", - R"(T<1024xf32>)"}, + R"(T<1024xf32>)", true}, {R"(T<4x4xf32, #B<{spt=[1,1], tpw=[2,4], wpc=[2,1], ord=[0,1]}>>)", - R"(T<16xf32>)"}, + R"(T<16xf32>)", false}, {R"(T<32xf32, #B<{spt=[2], tpw=[32], wpc=[2], ord=[0]}>>)", - R"(T<16x2xf32, #B<{spt=[1,2], tpw=[32,1], wpc=[2,1], ord=[1,0]}>>)"}, + R"(T<16x2xf32, #B<{spt=[1,2], tpw=[32,1], wpc=[2,1], ord=[1,0]}>>)", + true}, {R"(T<2x1x2xf32, #B<{spt=[2,1,1], tpw=[2,1,2], wpc=[4,1,8], ord=[2,1,0]}>>)", - R"(T<2x2xf32, #B<{spt=[2,1], tpw=[2,2], wpc=[4,8], ord=[1,0]}>>)"}, + R"(T<2x2xf32, #B<{spt=[2,1], tpw=[2,2], wpc=[4,8], ord=[1,0]}>>)", + true}, }))); +TEST_F(InferLayoutTest, FuzzReshape) { + const int numTests = 1000; // Increase to get more coverage. + + std::minstd_rand rng(/*seed=*/0); + auto randPow2Vec = [&](int rank, int maxPow2) { + SmallVector ret; + for (int i = 0; i < rank; i++) { + int pow2 = std::uniform_int_distribution(0, maxPow2)(rng); + if (pow2 == maxPow2 && maxPow2 > 0) { + maxPow2--; + } + ret.push_back(1 << pow2); + } + return ret; + }; + + int numSuccess = 0; + for (int i = 0; i < numTests; i++) { + SCOPED_TRACE("iteration " + std::to_string(i)); + int rank = std::uniform_int_distribution(1, 4)(rng); + + SmallVector srcShape( + convertType(randPow2Vec(rank, /*maxPow2=*/4))); + SmallVector dstShape = srcShape; + std::shuffle(dstShape.begin(), dstShape.end(), rng); + + // Optionally merge some dimensions in dst. + for (int i = 1; i < dstShape.size(); i++) { + if (std::uniform_real_distribution(0, 1)(rng) > 1.0 / rank) { + dstShape[i - 1] *= dstShape[i]; + dstShape.erase(dstShape.begin() + i); + i--; + } + } + + SmallVector sizePerThread = randPow2Vec(rank, /*maxPow2=*/3); + SmallVector threadsPerWarp = randPow2Vec(rank, /*maxPow2=*/3); + SmallVector warpsPerCTA = randPow2Vec(rank, /*maxPow2=*/3); + + SmallVector order(llvm::to_vector(llvm::seq(rank))); + std::shuffle(order.begin(), order.end(), rng); + + auto ctaLayout = CTALayoutAttr::get( + &ctx, SmallVector(rank, 1), SmallVector(rank, 1), + llvm::to_vector(llvm::reverse(llvm::seq(rank)))); + + auto srcTy = RankedTensorType::get( + srcShape, FloatType::getF32(&ctx), + BlockedEncodingAttr::get(&ctx, sizePerThread, threadsPerWarp, + warpsPerCTA, order, ctaLayout)); + auto dstTy = RankedTensorType::get(dstShape, FloatType::getF32(&ctx)); + + bool couldReshape = false; + testReshape(srcTy, dstTy, /*expectedDstEnc=*/std::nullopt, + /*expectSuccess=*/std::nullopt, inferLayout, + /*longErrors=*/false, &couldReshape); + if (couldReshape) + numSuccess++; + } + + // We don't expect or want 100% success, but if only a tiny fraction of tests + // actually exercise the successful reshape logic, then that gives us bad + // coverage. I'm currently getting 35% success, which seems good enough, + // especially since the successful cases take a lot longer to run because of + // the MultiIdx checks (so we're spending most of our time on successful + // cases, even if they're only 1/3 of the iterations). + // + // Run ctest with --verbose to see this output. For example: + // $ cd python/build/cmake.blah.blah + // $ ninja + // $ $(git rev-parse --show-toplevel)/.venv/bin/ctest --verbose + printf("Fuzz success rate: %d/%d = %.2f%%\n", numSuccess, numTests, + 100.0 * numSuccess / numTests); +} + class AMDLayoutTest : public ::testing::Test { public: AMDLayoutTest() {