Commit
Merge commit '07023209bfc88c06a9f06b655da6d25e6208f9fa'
whitneywhtsang committed Jan 13, 2025
2 parents 8b236c8 + 0702320 commit e7e7ed5
Showing 18 changed files with 432 additions and 345 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -22,7 +22,7 @@ repos:
rev: v0.7.1
hooks:
- id: ruff
files: '^(python|^third_party/proton|benchmarks|third_party/intel|scripts)/.*'
files: '(^python|^third_party/proton|^third_party/amd|^benchmarks|^third_party/intel|^scripts)/.*'
args: ["--fix", "--exit-non-zero-on-fix"]
exclude: |
(?x)(
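For reference, the functional change above is twofold: third_party/amd is added to the ruff file filter, and every alternative is now individually anchored with ^. A quick way to sanity-check the pattern outside of pre-commit (an illustrative Python snippet with made-up paths, assuming pre-commit applies the files pattern with re.search semantics):

import re

# The updated ruff "files" pattern from the hunk above.
pattern = re.compile(
    r'(^python|^third_party/proton|^third_party/amd|^benchmarks|^third_party/intel|^scripts)/.*')

for path in ["python/triton/compiler.py",   # matches
             "third_party/amd/backend.py",   # now matches (amd was added)
             "docs/conf.py",                 # no match
             "xbenchmarks/run.py"]:          # no match: every alternative is anchored
    print(path, bool(pattern.search(path)))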
18 changes: 6 additions & 12 deletions include/triton/Dialect/Triton/IR/Dialect.h
@@ -55,19 +55,13 @@ class DialectInferLayoutInterface

// Tries to compute the encoding for the result of a reshape operation that
// makes the reshape a "nop", i.e. the same GPU threads contain the same
// elements as before the reshape using legacy layouts. This is not always
// possible (in which case we fallback to using LinearLayouts)
// In the future we'll always use LinearLayouts
// elements as before the reshape. Note that this is not always possible (in
// which case you'd need to choose a different layout for the input to the
// reshape).
virtual LogicalResult
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
std::optional<Location> loc) const = 0;

// Check if two layouts are structurally the same, even if their names are
// different
virtual LogicalResult verifyLayoutsAreEqual(ArrayRef<int64_t> shape,
Attribute expected, Attribute got,
Location loc) const = 0;
inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
std::optional<Location> loc) const = 0;

virtual LogicalResult
inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
7 changes: 4 additions & 3 deletions lib/Analysis/Utility.cpp
@@ -225,14 +225,15 @@ bool ReduceOpHelper::isSupportedLayout() {
}

auto srcLayout = getSrcLayout();
if (isa<BlockedEncodingAttr, LinearEncodingAttr, SliceEncodingAttr>(
srcLayout)) {
if (isa<BlockedEncodingAttr>(srcLayout)) {
return true;
}

if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(srcLayout)) {
return mmaLayout.supportReduction();
}
if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(srcLayout)) {
return true;
}
return false;
}

30 changes: 16 additions & 14 deletions lib/Dialect/Triton/IR/Ops.cpp
@@ -8,7 +8,6 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Tools/LinearLayout.h"
#include "llvm/Support/ErrorHandling.h"

namespace mlir {
@@ -702,21 +701,24 @@ LogicalResult ReshapeOp::verify() {
"encodings, or (b) neither does.");
}

if (!srcEnc || getAllowReorder()) {
return success();
if (srcEnc && !getAllowReorder()) {
Attribute inferredDstEnc;
if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
->inferReshapeOpNoReorderEncoding(srcTy.getShape(), srcEnc,
dstTy.getShape(), inferredDstEnc,
getLoc())
.failed()) {
return emitError("This reshape is impossible without reordering, but "
"reordering is not allowed. Try choosing a different "
"encoding for the input tensor (or allow reordering).");
}
if (inferredDstEnc != dstEnc) {
return emitError("Expected result encoding ")
<< inferredDstEnc << " but was " << dstEnc;
}
}

// Check that we can infer the dst encoding from the src encoding
// and that the inferred dst encoding is the same as the given dst encoding
Attribute inferredDstEnc;
auto result =
cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, dstTy.getShape(),
inferredDstEnc, getLoc());
assert(succeeded(result));
return cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
->verifyLayoutsAreEqual(dstTy.getShape(), inferredDstEnc, dstEnc,
getLoc());
return success();
}
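As a reading aid, the verifier above reduces to the following control flow. This is a plain-Python sketch with a hypothetical infer_no_reorder_encoding stand-in for the interface call, not the actual MLIR API:

def verify_reshape(src_enc, dst_enc, src_shape, dst_shape, allow_reorder,
                   infer_no_reorder_encoding):
    # infer_no_reorder_encoding stands in for
    # DialectInferLayoutInterface::inferReshapeOpNoReorderEncoding and returns
    # (ok, inferred_dst_enc).
    if src_enc is not None and not allow_reorder:
        ok, inferred = infer_no_reorder_encoding(src_shape, src_enc, dst_shape)
        if not ok:
            raise ValueError("This reshape is impossible without reordering, but "
                             "reordering is not allowed. Try choosing a different "
                             "encoding for the input tensor (or allow reordering).")
        if inferred != dst_enc:
            raise ValueError(f"Expected result encoding {inferred} but was {dst_enc}")
    # Otherwise (no encoding, or reordering allowed) the op verifies trivially.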

//-- FpToFpOp --
125 changes: 42 additions & 83 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1470,7 +1470,7 @@ SmallVector<unsigned> basesPerDim(const LinearLayout::BasesT &namedBases,

SmallVector<unsigned> ret(rank, 1);
auto nonZero = [](auto val) { return val != 0; };
int nonZeroIdx = 0;
int nonZeroIdx = -1;
for (const auto &basis : bases) {
auto it = std::find_if(basis.begin(), basis.end(), nonZero);
// Bases can have one or zero non-zero elements
@@ -1482,6 +1482,7 @@ SmallVector<unsigned> basesPerDim(const LinearLayout::BasesT &namedBases,
} else if (!skipBroadcast) {
// If we've seen a non-zero basis, we double the size of the previous dim
// This is just needed to count the CTAsPerCGA
assert(nonZeroIdx != -1);
ret[nonZeroIdx] *= 2;
}
}
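For intuition, the counting loop above behaves roughly like this plain-Python sketch (illustrative only; it assumes each basis has at most one non-zero entry, as the comment states, and that the branch elided by the diff records that entry's index and doubles the corresponding dim):

def bases_per_dim(bases, rank, skip_broadcast=True):
    # bases: basis vectors for one hardware dimension of the linear layout; each
    # is a list of `rank` ints with at most one non-zero entry.
    ret = [1] * rank
    non_zero_idx = -1
    for basis in bases:
        non_zero = [i for i, v in enumerate(basis) if v != 0]
        if non_zero:
            non_zero_idx = non_zero[0]
            ret[non_zero_idx] *= 2
        elif not skip_broadcast:
            # A broadcast (all-zero) basis doubles the previously seen dim; this
            # is why non_zero_idx must already be valid (the new assert above).
            assert non_zero_idx != -1
            ret[non_zero_idx] *= 2
    return ret

print(bases_per_dim([[1, 0], [2, 0], [0, 1]], rank=2))  # [4, 2]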
@@ -1626,14 +1627,12 @@ LinearEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

SmallVector<unsigned>
LinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type) const {
// When broadcasting the layout the shape changes, otherwise the shape is
// the same as the shape of the tensor
// We can either have BroadcastOp with SameOperandsAndResultEncoding, or keep
// the invariant that the shape of the LL is that of the tensor
// We choose the former for BC
auto ll = *toLinearLayout(shape);
return basesPerDim(ll, StringAttr::get(getContext(), "register"),
/*skipBroadcast=*/false);
// We can relax this assert by calling toLinearLayout rather than
// getLinearLayout
SmallVector<int32_t> shapeVec(shape.begin(), shape.end());
assert(shapeVec == llvm::to_vector(getLinearLayout().getOutDimSizes()));
auto ll = getLinearLayout();
return basesPerDim(ll, StringAttr::get(getContext(), "register"));
}

@@ -2706,8 +2705,8 @@ struct TritonGPUInferLayoutInterface
// contains elements [a,b,c,d] before the reshape, it contains those same
// elements after the reshape, they're just "renamed".
//
// Using legacy layouts, a dst encoding that satisfies this property may not
// exist. Here are some positive and negative examples.
// A dst encoding that satisfies this property does not exist for all inputs.
// Here are some positive and negative examples.
//
// - NOT OK: 4x4 order=[0,1] -> 16. Reshape merges elements so
// dim 1 is the fastest-changing in the dst, but the src has the opposite
@@ -2721,19 +2720,17 @@ struct TritonGPUInferLayoutInterface
// - OK: 32x4 sizePerThread=[4,4] -> 128. dst with sizePerThread=[16] will
// contain the same elements as before.
//
// With linear layouts, we can always find a dst encoding that satisfies
// this property. See inferReshapeOpEncoding.
//
// Users of this function require that it is symmetrical: if
// (srcShape,srcEnc,dstShape) => dstEnc, then (dstShape,dstEnc,srcShape) =>
// srcEnc.
LogicalResult inferReshapeOpLegacyEncoding(ArrayRef<int64_t> srcShape,
Attribute srcEnc,
ArrayRef<int64_t> dstShape,
Attribute &dstEnc) const {
LogicalResult
inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
std::optional<Location> loc) const override {
auto src = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
if (!src) {
return failure();
return emitOptionalError(
loc, "Non-reordering reshape only supports BlockedEncoding");
}

// Nop reshape; we can always infer an encoding.
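To make the OK / NOT OK examples in the comment above concrete, here is a small numpy illustration of the element-order argument (not Triton code; C order stands in for a layout whose physical order is [1,0], Fortran order for [0,1]):

import numpy as np

a = np.arange(16).reshape(4, 4)

# A reshape 4x4 -> 16 logically lays out elements with dim 1 fastest-changing:
logical = a.reshape(16)                               # [ 0  1  2 ... 15]

# If the src layout's physical order is [1,0] (dim 1 fastest), flattening the
# tensor in that order gives the same sequence, so the reshape can be a nop:
print(np.array_equal(a.flatten(order="C"), logical))  # True

# If the physical order is [0,1] (dim 0 fastest), the sequences differ, so the
# same threads would no longer hold the same elements -> NOT OK without reordering:
print(np.array_equal(a.flatten(order="F"), logical))  # False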
@@ -2766,7 +2763,9 @@ struct TritonGPUInferLayoutInterface
// to handle CTASplitNum.
if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) ||
!all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) {
return failure();
return emitOptionalError(
loc, "Non-reordering reshape does not currently support multi-CTA "
"layouts other than the default layout.");
}

// Cowardly refuse to handle encodings where shape[dim] is not divisible by
@@ -2776,7 +2775,12 @@ struct TritonGPUInferLayoutInterface
for (int dim = 0; dim < srcShape.size(); dim++) {
if (srcShape[dim] >= subblock[dim] &&
srcShape[dim] % subblock[dim] != 0) {
return failure();
return emitOptionalError(loc,
"Can't do a non-reordering reshape because "
"the size of dimension ",
dim, " (", srcShape[dim], ")",
" is not divisible by ", name, "[", dim, "]",
" = ", subblock[dim]);
}
}
return success();
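The divisibility check above translates roughly to this plain-Python sketch, where subblock stands for one of sizePerThread / threadsPerWarp / warpsPerCTA:

def check_divisibility(src_shape, subblock, name):
    # Mirrors the loop above: a dimension that is at least as large as the
    # per-thread/warp/CTA sub-block must be evenly divisible by it.
    for dim, (size, sub) in enumerate(zip(src_shape, subblock)):
        if size >= sub and size % sub != 0:
            return (f"Can't do a non-reordering reshape because the size of "
                    f"dimension {dim} ({size}) is not divisible by "
                    f"{name}[{dim}] = {sub}")
    return None  # success

print(check_divisibility([32, 6], [4, 4], "sizePerThread"))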
@@ -2801,7 +2805,11 @@ struct TritonGPUInferLayoutInterface
// physical order, with `a` being the most major.
for (const auto &[srcDims, dstDims] : decomp) {
if (!isConsecutive(to_vector(reverse(gather(srcInvOrder, srcDims))))) {
return failure();
return emitOptionalError(loc,
"Cannot do a non-reordering reshape given "
"this src encoding order. Dimensions [",
join(srcDims),
"] must be physically consecutive.");
}
}

@@ -2848,7 +2856,11 @@ struct TritonGPUInferLayoutInterface
// Check that more-minor dims all have 1 in shapeRemaining.
for (int j = i + 1; j < srcDims.size(); j++) {
if (shapeRemaining[j] != 1) {
return failure();
return emitOptionalError(
loc,
"Invalid src encoding for non-reordering reshape. Must use "
"up sizePerThread / threadsPerWarp / warpsPerCTA for "
"more-minor dimensions before more major-dims can use them.");
}
}

@@ -2863,7 +2875,13 @@ struct TritonGPUInferLayoutInterface
// only if we're the most-major dimension of the chunk and in all
// future chunks, only this most-major dim has a non-1 size.
if (shapeRemaining[i] == 0 && i != 0) {
return failure();
return emitOptionalError(
loc,
"Invalid src encoding for non-reordering reshape. Block "
"size in dimension ",
dim,
" is larger than the shape that dimension, but this is only "
"allowed for the most-major dimension of a reshape chunk");
}
}
return success();
@@ -2953,65 +2971,6 @@ struct TritonGPUInferLayoutInterface
return success();
}

LogicalResult verifyLayoutsAreEqual(ArrayRef<int64_t> shape,
Attribute expected, Attribute got,
Location loc) const override {
if (expected == got) {
return success();
}
// Check whether the encodings are structurally the same.
auto expectedLL = triton::gpu::toLinearLayout(shape, expected);
auto gotLL = triton::gpu::toLinearLayout(shape, got);
if (expectedLL != gotLL) {
return emitError(loc, "Expected result encoding ")
<< expected << " but was " << got;
}
return success();
}

LogicalResult
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
std::optional<Location> loc) const override {
auto result =
inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc);
if (succeeded(result)) {
return result;
}

// If the legacy encoding failed use LinearLayouts.
// Once LinearLayouts are more widely used, we can remove
// inferReshapeOpLegacyEncoding and simply use LLs.
auto *ctx = getContext();
auto src = triton::gpu::toLinearLayout(srcShape, srcEnc);
if (!src) {
return emitOptionalError(loc,
"src encoding does not support linear layout");
}

if (product(srcShape) != product(dstShape)) {
return emitOptionalError(loc, "numel of dst shape does not match "
"numel of src shape");
}

auto newRank = dstShape.size();
SmallVector<std::pair<StringAttr, int32_t>> newOutDims;
for (auto [dim, size] :
llvm::zip(standardOutDimNames(ctx, newRank), dstShape)) {
newOutDims.emplace_back(dim, size);
}
auto srcOutDims = llvm::to_vector(src->getOutDimNames());
// reshapeOp assumes minor-to-major, so we need to transpose the out dims
// before the reshape
std::reverse(srcOutDims.begin(), srcOutDims.end());
std::reverse(newOutDims.begin(), newOutDims.end());
auto dst = src->transposeOuts(srcOutDims)
.reshapeOuts(newOutDims)
.transposeOuts(standardOutDimNames(ctx, newRank));
dstEnc = LinearEncodingAttr::get(ctx, dst);
return success();
}

LogicalResult
inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
std::optional<Location> loc) const override {
10 changes: 0 additions & 10 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -43,16 +43,6 @@ struct CanonicalizeConvertFromReshape
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
if (!convert)
return failure();
// If the layouts are structurally the same, the convert is trivial
auto srcType = convert.getSrc().getType();
auto dstType = convert.getType();
auto srcLL = toLinearLayout(srcType.getShape(), srcType.getEncoding());
auto dstLL = toLinearLayout(dstType.getShape(), dstType.getEncoding());
if (srcLL && dstLL && *srcLL == *dstLL) {
rewriter.replaceOpWithNewOp<triton::ReshapeOp>(
op, op.getType(), convert.getSrc(), op.getAllowReorder());
return mlir::success();
}
if (isExpensiveView(convert.getSrc().getType(), op.getType()))
return failure();
if (!op.getAllowReorder() || op.getEfficientLayout())
9 changes: 7 additions & 2 deletions lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp
@@ -1025,7 +1025,9 @@ void LayoutRematerialization::backwardRematerialization(
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristic to accommodate fused attention
RankedTensorType targetType = convertOp.getType();
if (isa<DotOperandEncodingAttr>(targetType.getEncoding()))
// We stop the rematerialization of linear layouts as we have to be a bit more
// careful with the heuristics for both correctness and perf
if (isa<DotOperandEncodingAttr, LinearEncodingAttr>(targetType.getEncoding()))
return;
Value oldV = convertOp.getSrc();
LDBG("check backward remat with source " << oldV << " encoding "
@@ -1067,8 +1069,11 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
ConvertLayoutOp convertOp) {
// we don't handle conversions to DotOperandEncodingAttr
// this is a heuristic to accommodate fused attention
// We stop the rematerialization of linear layouts as we have to be a bit more
// careful with the heuristics for both correctness and perf
RankedTensorType targetType = convertOp.getType();
if (isa<DotOperandEncodingAttr>(targetType.getEncoding()))
if (mlir::isa<DotOperandEncodingAttr, LinearEncodingAttr>(
targetType.getEncoding()))
return;

auto isExtOrBroadcastOp = [](Operation *op) {
15 changes: 8 additions & 7 deletions lib/Dialect/TritonGPU/Transforms/Utility.cpp
@@ -407,13 +407,14 @@ static Attribute inferReshapeOpDstEncoding(ArrayRef<int64_t> srcShape,
return {};

Attribute dstEnc;
auto result =
srcEnc.getDialect()
.getRegisteredInterface<triton::DialectInferLayoutInterface>()
->inferReshapeOpEncoding(srcShape, srcEnc, dstShape, dstEnc,
/*loc=*/std::nullopt);
assert(succeeded(result));
return dstEnc;
if (succeeded(
srcEnc.getDialect()
.getRegisteredInterface<triton::DialectInferLayoutInterface>()
->inferReshapeOpNoReorderEncoding(
srcShape, srcEnc, dstShape, dstEnc, /*loc=*/std::nullopt))) {
return dstEnc;
}
return {};
}

static Attribute inferDstEncoding(triton::ReshapeOp op, Attribute encoding) {
2 changes: 1 addition & 1 deletion python/triton/runtime/interpreter.py
@@ -993,7 +993,7 @@ def _set_attr(input, values, name):
lang.static_assert = _new_static_assert
lang.static_print = print
lang.dtype.to_ir = _new_to_ir
lang.multiple_of = partial(_set_attr, name="tt.divisiblity")
lang.multiple_of = partial(_set_attr, name="tt.divisibility")
lang.max_contiguous = partial(_set_attr, name="tt.contiguity")
lang.max_constancy = partial(_set_attr, name="tt.constancy")

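For context, tt.divisibility is the attribute the interpreter records when a kernel calls tl.multiple_of, so the spelling fix above matters for interpreter runs. A minimal, illustrative kernel using that hint (assuming the usual triton / triton.language imports; not part of this commit):

import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    block_start = tl.program_id(0) * BLOCK
    # The block start is a multiple of BLOCK; record that hint. In interpreter
    # mode this goes through _set_attr with name="tt.divisibility" (fixed above).
    block_start = tl.multiple_of(block_start, BLOCK)
    offs = block_start + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)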