Merge OpenAI Triton commit 0702320 #3149

Merged: 17 commits, Jan 13, 2025

Commits
199fd8a
[FRONTEND][NFC] Remove unused strings (#5578)
Jokeren Jan 11, 2025
22ac447
[FRONTEND][BACKEND] plumb `fast_math` attribute from scaled_dot front…
ptillet Jan 12, 2025
a3095b3
Remove `examples` folder (#5574)
anmyachev Jan 12, 2025
f7e6775
[AMD] Pass down atomics memscope through lowering (#5580)
SamGinzburg Jan 12, 2025
9649f71
[AMD] Bypass NaN check for fast math scaled dot (#5584)
antiagainst Jan 12, 2025
6b41bcf
Enable ruff-pre-commit for `third_party/proton` (#5586)
anmyachev Jan 13, 2025
7cc6799
[FRONTEND] capability override bugfix (#5590)
ptillet Jan 13, 2025
7db39a9
Proper use of `subprocess.check_call` in `third_party/proton/test/tes…
anmyachev Jan 13, 2025
3ed479f
Revert "Revert "Reverting #5389 (#5528)" (#5555)" (#5592)
apgoucher Jan 13, 2025
4523d38
Don't use designated initializers in `MatmulLoopPipeline.cpp` as it r…
anmyachev Jan 13, 2025
194a21f
Enable ruff-pre-commit for `third_party/amd` (#5589)
anmyachev Jan 13, 2025
e8ef0bb
[AMD] Disable swap operands for fp8 matmul (#5577)
zhanglx13 Jan 13, 2025
0702320
[INTERPRETER] Fix typo in attribute name (#5593)
matthias-springer Jan 13, 2025
0dfcad6
Merge commit '7cc6799ddb76a18830874259bcaf2da59484c684'
whitneywhtsang Jan 13, 2025
f501e97
[Intel] Plumb fast_math attribute from scaled_dot frontend to LLVM co…
whitneywhtsang Jan 13, 2025
0dfb744
Fix `TypeError: XPUBackend.get_codegen_implementation() takes 1 posit…
whitneywhtsang Jan 13, 2025
865cfae
Merge commit '07023209bfc88c06a9f06b655da6d25e6208f9fa'
whitneywhtsang Jan 13, 2025
Files changed
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -22,7 +22,7 @@ repos:
rev: v0.7.1
hooks:
- id: ruff
files: '^(python|benchmarks|third_party/intel|scripts)/.*'
files: '(^python|^third_party/proton|^third_party/amd|^benchmarks|^third_party/intel|^scripts)/.*'
args: ["--fix", "--exit-non-zero-on-fix"]
exclude: |
(?x)(
3 changes: 2 additions & 1 deletion include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -468,7 +468,8 @@ SmallVector<Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc,
ArrayRef<Value> values);

// Scale a mxfp4 value by a given scale.
Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale);
Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale,
bool fastMath);

} // namespace LLVM

18 changes: 6 additions & 12 deletions include/triton/Dialect/Triton/IR/Dialect.h
@@ -55,19 +55,13 @@ class DialectInferLayoutInterface

// Tries to compute the encoding for the result of a reshape operation that
// makes the reshape a "nop", i.e. the same GPU threads contain the same
// elements as before the reshape using legacy layouts. This is not always
// possible (in which case we fallback to using LinearLayouts)
// In the future we'll always use LinearLayouts
// elements as before the reshape. Note that this is not always possible (in
// which case you'd need to choose a different layout for the input to the
// reshape).
virtual LogicalResult
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
std::optional<Location> loc) const = 0;

// Check if two layouts are structurally the same, even if their names are
// different
virtual LogicalResult verifyLayoutsAreEqual(ArrayRef<int64_t> shape,
Attribute expected, Attribute got,
Location loc) const = 0;
inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
std::optional<Location> loc) const = 0;

virtual LogicalResult
inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
3 changes: 2 additions & 1 deletion include/triton/Dialect/Triton/IR/TritonOps.td
@@ -690,7 +690,8 @@ def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
Optional<RankedTensorOf<[I8]>>:$lhs_scale,
Optional<RankedTensorOf<[I8]>>:$rhs_scale,
TT_ScaleDotElemTypeAttr:$lhs_type,
TT_ScaleDotElemTypeAttr:$rhs_type
TT_ScaleDotElemTypeAttr:$rhs_type,
BoolAttr:$fastMath
);

let results = (outs TT_FloatTensor:$d);
11 changes: 7 additions & 4 deletions include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td
@@ -292,10 +292,13 @@ def TTG_UpcastMXFPOp : TTG_Op<"upcast_mxfp", [Pure]> {
Compute the bf16 encoded in the given mxfp number as per
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
}];
let arguments = (ins
TT_Tensor:$src,
TT_Tensor:$scale,
TT_ScaleDotElemTypeAttr:$fp_type);
let arguments = (
ins
TT_Tensor:$src,
TT_Tensor:$scale,
TT_ScaleDotElemTypeAttr:$fp_type,
BoolAttr:$fastMath
);
let results = (outs TT_Tensor:$result);

let assemblyFormat = [{
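Note: the op description above defers to the OCP MX spec for how the bf16 value is recovered from an mxfp input. As a rough illustration of the scalar decode for the fp4 (E2M1) element type, here is a hedged standalone sketch; the bit layout and bias follow the spec, the helper name decodeE2M1 is made up, and the real lowering (e.g. convertMxfp4x2ToBf16x2) works on packed pairs and produces bf16 before the shared scale is applied:

```cpp
#include <cstdint>
#include <cstdio>

// Decode one fp4 (E2M1) nibble: 1 sign bit, 2 exponent bits (bias 1),
// 1 mantissa bit. Representable magnitudes: 0, 0.5, 1, 1.5, 2, 3, 4, 6.
float decodeE2M1(uint8_t nibble) {
  int sign = (nibble >> 3) & 1;
  int exp = (nibble >> 1) & 3;
  int man = nibble & 1;
  float mag = exp == 0 ? 0.5f * man // subnormal: 0 or 0.5
                       : (1.0f + 0.5f * man) * float(1 << (exp - 1));
  return sign ? -mag : mag;
}

int main() {
  for (unsigned n = 0; n < 8; ++n) // the eight non-negative code points
    std::printf("0x%x -> %g\n", n, decodeE2M1(static_cast<uint8_t>(n)));
}
```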
7 changes: 4 additions & 3 deletions lib/Analysis/Utility.cpp
@@ -225,14 +225,15 @@ bool ReduceOpHelper::isSupportedLayout() {
}

auto srcLayout = getSrcLayout();
if (isa<BlockedEncodingAttr, LinearEncodingAttr, SliceEncodingAttr>(
srcLayout)) {
if (isa<BlockedEncodingAttr>(srcLayout)) {
return true;
}

if (auto mmaLayout = dyn_cast<MmaEncodingTrait>(srcLayout)) {
return mmaLayout.supportReduction();
}
if (auto sliceLayout = dyn_cast<SliceEncodingAttr>(srcLayout)) {
return true;
}
return false;
}

10 changes: 6 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -904,13 +904,15 @@ SmallVector<Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc,
return results;
}

Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v,
Value scale) {
Value mxfpScaleBf16(RewriterBase &rewriter, Location loc, Value v, Value scale,
bool fastMath) {
Value vBf16 = bitcast(v, bf16_ty);
Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty);
Value scaleIsNan = icmp_eq(scale, i8_val(0xff));
Value scaleBf16 = bitcast(shl(zext(i16_ty, scale), i16_val(7)), bf16_ty);
Value scaledBf16 = fmul(vBf16, scaleBf16);
if (fastMath)
return scaledBf16;
Value nanBf16 = bitcast(i16_val(0x7fff), bf16_ty);
Value scaleIsNan = icmp_eq(scale, i8_val(0xff));
// Account for NaN in the scale as per the mxfp specification.
return select(scaleIsNan, nanBf16, scaledBf16);
};
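Note: the fastMath flag above only skips the NaN-in-scale select. A minimal standalone sketch of the behavioral difference, assuming the E8M0 interpretation of the scale byte (bias 127, 0xff encodes NaN) from the OCP MX spec; plain float stands in for bf16 and the helper name mxfpScaleRef is illustrative, so this is not the actual lowering:

```cpp
#include <cmath>
#include <cstdint>
#include <cstdio>

// Reference model of one mxfp scale step: the 8-bit scale is an E8M0
// exponent (bias 127); the byte 0xff encodes NaN per the OCP MX spec.
float mxfpScaleRef(float v, uint8_t scale, bool fastMath) {
  float s = std::ldexp(1.0f, int(scale) - 127); // 2^(scale - 127)
  float scaled = v * s;
  if (fastMath)
    return scaled; // fast math: skip the NaN-in-scale check entirely
  return scale == 0xff ? std::nanf("") : scaled;
}

int main() {
  std::printf("%g\n", mxfpScaleRef(1.5f, 0x7f, false)); // 1.5 * 2^0 = 1.5
  std::printf("%g\n", mxfpScaleRef(1.5f, 0xff, false)); // nan: scale encodes NaN
  std::printf("%g\n", mxfpScaleRef(1.5f, 0xff, true));  // check skipped: inf instead of nan
}
```

The tradeoff is the usual fast-math one: a comparison and a select saved per element, at the cost of not propagating a NaN scale.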
30 changes: 16 additions & 14 deletions lib/Dialect/Triton/IR/Ops.cpp
@@ -8,7 +8,6 @@
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/Triton/IR/Types.h"
#include "triton/Dialect/Triton/IR/Utility.h"
#include "triton/Tools/LinearLayout.h"
#include "llvm/Support/ErrorHandling.h"

namespace mlir {
@@ -702,21 +701,24 @@ LogicalResult ReshapeOp::verify() {
"encodings, or (b) neither does.");
}

if (!srcEnc || getAllowReorder()) {
return success();
if (srcEnc && !getAllowReorder()) {
Attribute inferredDstEnc;
if (cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
->inferReshapeOpNoReorderEncoding(srcTy.getShape(), srcEnc,
dstTy.getShape(), inferredDstEnc,
getLoc())
.failed()) {
return emitError("This reshape is impossible without reordering, but "
"reordering is not allowed. Try choosing a different "
"encoding for the input tensor (or allow reordering).");
}
if (inferredDstEnc != dstEnc) {
return emitError("Expected result encoding ")
<< inferredDstEnc << " but was " << dstEnc;
}
}

// Check that we can infer the dst encoding from the src encoding
// and that the inferred dst encoding is the same as the given dst encoding
Attribute inferredDstEnc;
auto result =
cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
->inferReshapeOpEncoding(srcTy.getShape(), srcEnc, dstTy.getShape(),
inferredDstEnc, getLoc());
assert(succeeded(result));
return cast<DialectInferLayoutInterface>(&srcEnc.getDialect())
->verifyLayoutsAreEqual(dstTy.getShape(), inferredDstEnc, dstEnc,
getLoc());
return success();
}

//-- FpToFpOp --
125 changes: 42 additions & 83 deletions lib/Dialect/TritonGPU/IR/Dialect.cpp
@@ -1470,7 +1470,7 @@ SmallVector<unsigned> basesPerDim(const LinearLayout::BasesT &namedBases,

SmallVector<unsigned> ret(rank, 1);
auto nonZero = [](auto val) { return val != 0; };
int nonZeroIdx = 0;
int nonZeroIdx = -1;
for (const auto &basis : bases) {
auto it = std::find_if(basis.begin(), basis.end(), nonZero);
// Bases can have one or zero non-zero elements
@@ -1482,6 +1482,7 @@ SmallVector<unsigned> basesPerDim(const LinearLayout::BasesT &namedBases,
} else if (!skipBroadcast) {
// If we've seen a non-zero basis, we double the size of the previous dim
// This is just needed to count the CTAsPerCGA
assert(nonZeroIdx != -1);
ret[nonZeroIdx] *= 2;
}
}
@@ -1626,14 +1627,12 @@ LinearEncodingAttr::toLinearLayout(ArrayRef<int64_t> shape) const {

SmallVector<unsigned>
LinearEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape, Type) const {
// When broadcasting the layout the shape changes, otherwise the shape is
// the same as the shape of the tensor
// We can either have BroadcastOp with SameOperandsAndResultEncoding, or keep
// the invariant that the shape of the LL is that of the tensor
// We choose the former for BC
auto ll = *toLinearLayout(shape);
return basesPerDim(ll, StringAttr::get(getContext(), "register"),
/*skipBroadcast=*/false);
// We can relax this assert by calling toLinearLayout rather than
// getLinearLayout
SmallVector<int32_t> shapeVec(shape.begin(), shape.end());
assert(shapeVec == llvm::to_vector(getLinearLayout().getOutDimSizes()));
auto ll = getLinearLayout();
return basesPerDim(ll, StringAttr::get(getContext(), "register"));
}

// Start of Selection
@@ -2706,8 +2705,8 @@ struct TritonGPUInferLayoutInterface
// contains elements [a,b,c,d] before the reshape, it contains those same
// elements after the reshape, they're just "renamed".
//
// Using legacy layouts, a dst encoding that satisfies this property may not
// exist. Here are some positive and negative examples.
// A dst encoding that satisfies this property does not exist for all inputs.
// Here are some positive and negative examples.
//
// - NOT OK: 4x4 order=[0,1] -> 16. Reshape merges elements so
// dim 1 is the fastest-changing in the dst, but the src has the opposite
@@ -2721,19 +2720,17 @@ struct TritonGPUInferLayoutInterface
// - OK: 32x4 sizePerThread=[4,4] -> 128. dst with sizePerThread=[16] will
// contain the same elements as before.
//
// With linear layouts, we can always find a dst encoding that satisfies
// this property. See inferReshapeOpEncoding.
//
// Users of this function require that it is symmetrical: if
// (srcShape,srcEnc,dstShape) => dstEnc, then (dstShape,dstEnc,srcShape) =>
// srcEnc.
LogicalResult inferReshapeOpLegacyEncoding(ArrayRef<int64_t> srcShape,
Attribute srcEnc,
ArrayRef<int64_t> dstShape,
Attribute &dstEnc) const {
LogicalResult
inferReshapeOpNoReorderEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
std::optional<Location> loc) const override {
auto src = mlir::dyn_cast<BlockedEncodingAttr>(srcEnc);
if (!src) {
return failure();
return emitOptionalError(
loc, "Non-reordering reshape only supports BlockedEncoding");
}

// Nop reshape; we can always infer an encoding.
@@ -2766,7 +2763,9 @@ struct TritonGPUInferLayoutInterface
// to handle CTASplitNum.
if (!all_of(src.getCTAsPerCGA(), [](int32_t x) { return x == 1; }) ||
!all_of(src.getCTASplitNum(), [](int32_t x) { return x == 1; })) {
return failure();
return emitOptionalError(
loc, "Non-reordering reshape does not currently support multi-CTA "
"layouts other than the default layout.");
}

// Cowardly refuse to handle encodings where shape[dim] is not divisible by
@@ -2776,7 +2775,12 @@ struct TritonGPUInferLayoutInterface
for (int dim = 0; dim < srcShape.size(); dim++) {
if (srcShape[dim] >= subblock[dim] &&
srcShape[dim] % subblock[dim] != 0) {
return failure();
return emitOptionalError(loc,
"Can't do a non-reordering reshape because "
"the size of dimension ",
dim, " (", srcShape[dim], ")",
" is not divisible by ", name, "[", dim, "]",
" = ", subblock[dim]);
}
}
return success();
@@ -2801,7 +2805,11 @@ struct TritonGPUInferLayoutInterface
// physical order, with `a` being the most major.
for (const auto &[srcDims, dstDims] : decomp) {
if (!isConsecutive(to_vector(reverse(gather(srcInvOrder, srcDims))))) {
return failure();
return emitOptionalError(loc,
"Cannot do a non-reordering reshape given "
"this src encoding order. Dimensions [",
join(srcDims),
"] must be physically consecutive.");
}
}

@@ -2848,7 +2856,11 @@ struct TritonGPUInferLayoutInterface
// Check that more-minor dims all have 1 in shapeRemaining.
for (int j = i + 1; j < srcDims.size(); j++) {
if (shapeRemaining[j] != 1) {
return failure();
return emitOptionalError(
loc,
"Invalid src encoding for non-reordering reshape. Must use "
"up sizePerThread / threadsPerWarp / warpsPerCTA for "
"more-minor dimensions before more major-dims can use them.");
}
}

@@ -2863,7 +2875,13 @@ struct TritonGPUInferLayoutInterface
// only if we're the most-major dimension of the chunk and in all
// future chunks, only this most-major dim has a non-1 size.
if (shapeRemaining[i] == 0 && i != 0) {
return failure();
return emitOptionalError(
loc,
"Invalid src encoding for non-reordering reshape. Block "
"size in dimension ",
dim,
" is larger than the shape that dimension, but this is only "
"allowed for the most-major dimension of a reshape chunk");
}
}
return success();
@@ -2953,65 +2971,6 @@ struct TritonGPUInferLayoutInterface
return success();
}

LogicalResult verifyLayoutsAreEqual(ArrayRef<int64_t> shape,
Attribute expected, Attribute got,
Location loc) const override {
if (expected == got) {
return success();
}
// Check whether the encodings are structurally the same.
auto expectedLL = triton::gpu::toLinearLayout(shape, expected);
auto gotLL = triton::gpu::toLinearLayout(shape, got);
if (expectedLL != gotLL) {
return emitError(loc, "Expected result encoding ")
<< expected << " but was " << got;
}
return success();
}

LogicalResult
inferReshapeOpEncoding(ArrayRef<int64_t> srcShape, Attribute srcEnc,
ArrayRef<int64_t> dstShape, Attribute &dstEnc,
std::optional<Location> loc) const override {
auto result =
inferReshapeOpLegacyEncoding(srcShape, srcEnc, dstShape, dstEnc);
if (succeeded(result)) {
return result;
}

// If the legacy encoding failed use LinearLayouts.
// Once LinearLayouts are more widely used, we can remove
// inferReshapeOpLegacyEncoding and simply use LLs.
auto *ctx = getContext();
auto src = triton::gpu::toLinearLayout(srcShape, srcEnc);
if (!src) {
return emitOptionalError(loc,
"src encoding does not support linear layout");
}

if (product(srcShape) != product(dstShape)) {
return emitOptionalError(loc, "numel of dst shape does not match "
"numel of src shape");
}

auto newRank = dstShape.size();
SmallVector<std::pair<StringAttr, int32_t>> newOutDims;
for (auto [dim, size] :
llvm::zip(standardOutDimNames(ctx, newRank), dstShape)) {
newOutDims.emplace_back(dim, size);
}
auto srcOutDims = llvm::to_vector(src->getOutDimNames());
// reshapeOp assumes minor-to-major, so we need to transpose the out dims
// before the reshape
std::reverse(srcOutDims.begin(), srcOutDims.end());
std::reverse(newOutDims.begin(), newOutDims.end());
auto dst = src->transposeOuts(srcOutDims)
.reshapeOuts(newOutDims)
.transposeOuts(standardOutDimNames(ctx, newRank));
dstEnc = LinearEncodingAttr::get(ctx, dst);
return success();
}

LogicalResult
inferJoinOpEncoding(Attribute srcEnc, Attribute &dstEnc,
std::optional<Location> loc) const override {
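Note: the non-reordering-reshape comment above uses 4x4 -> 16 as its OK / NOT OK examples. A tiny standalone model of that property, assuming order[0] is the fastest-changing dimension (as for BlockedEncodingAttr) and that reshape flattens elements in row-major order; the helper names are made up and this is not the inference algorithm itself:

```cpp
#include <array>
#include <cstdio>

// Flat physical index of logical element (i, j) of a 4x4 tensor whose
// fastest-changing dimension is order[0].
int physicalIndex(int i, int j, std::array<int, 2> order) {
  return order[0] == 1 ? i * 4 + j  // dim 1 fastest: matches row-major
                       : j * 4 + i; // dim 0 fastest: column-major
}

// The reshape to 16 is a "nop" only if every element already sits at its
// row-major flat index, i.e. no data movement between threads is needed.
bool reshapeIsNop(std::array<int, 2> order) {
  for (int i = 0; i < 4; ++i)
    for (int j = 0; j < 4; ++j)
      if (physicalIndex(i, j, order) != i * 4 + j)
        return false;
  return true;
}

int main() {
  std::printf("4x4 order=[1,0] -> 16: %s\n", reshapeIsNop({1, 0}) ? "nop" : "needs reorder");
  std::printf("4x4 order=[0,1] -> 16: %s\n", reshapeIsNop({0, 1}) ? "nop" : "needs reorder");
}
```

Running it reports order=[1,0] as a nop and order=[0,1] as needing a reorder, matching the comment's OK and NOT OK cases.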
10 changes: 0 additions & 10 deletions lib/Dialect/TritonGPU/IR/Ops.cpp
@@ -43,16 +43,6 @@ struct CanonicalizeConvertFromReshape
auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
if (!convert)
return failure();
// If the layouts are structurally the same, the convert is trivial
auto srcType = convert.getSrc().getType();
auto dstType = convert.getType();
auto srcLL = toLinearLayout(srcType.getShape(), srcType.getEncoding());
auto dstLL = toLinearLayout(dstType.getShape(), dstType.getEncoding());
if (srcLL && dstLL && *srcLL == *dstLL) {
rewriter.replaceOpWithNewOp<triton::ReshapeOp>(
op, op.getType(), convert.getSrc(), op.getAllowReorder());
return mlir::success();
}
if (isExpensiveView(convert.getSrc().getType(), op.getType()))
return failure();
if (!op.getAllowReorder() || op.getEfficientLayout())