diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp index 31cea35ef8..e6825e8a2a 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp @@ -32,7 +32,9 @@ namespace { SmallVector getWarpsPerTile(tt::DotOp dotOp, ttg::intel::DpasEncodingAttr::DPASCapability dpasCap, - const ArrayRef shape, unsigned numWarps) { + const ArrayRef shape, unsigned numWarps, + const SmallVector &order) { + auto filter = [&dotOp](Operation *op) { return op->getParentRegion() == dotOp->getParentRegion(); }; @@ -64,7 +66,7 @@ getWarpsPerTile(tt::DotOp dotOp, uint32_t colRowRatio = ceil(dpasCap.executionSize, dpasCap.repeatCount); - int rowDim = rank - 2, colDim = rank - 1; + int rowDim = order[rank - 2], colDim = order[rank - 1]; do { if (ret[rowDim] * ret[colDim] >= numWarps) break; @@ -78,7 +80,6 @@ getWarpsPerTile(tt::DotOp dotOp, ret[colDim] *= 2; } } while (true); - return ret; } @@ -117,8 +118,22 @@ class BlockedToDPAS : public OpRewritePattern { Type elemType = oldAType.getElementType(); unsigned opsPerChan = ttg::intel::DpasEncodingAttr::getOpsPerChannel(elemType); + + SmallVector order = {0, 1}; + Operation *aOp = a.getDefiningOp(); + if (aOp && isa(aOp)) { + auto valueToConvert = aOp->getOperand(0); + aOp = valueToConvert.getDefiningOp(); + } + if (aOp && isa(aOp)) { + assert(aOp->getNumResults() == 1); + Attribute layout = + cast(aOp->getResult(0).getType()).getEncoding(); + order = triton::gpu::getOrder(layout); + } + SmallVector warpsPerTile = - getWarpsPerTile(dotOp, dpasCap, retShape, numWarps); + getWarpsPerTile(dotOp, dpasCap, retShape, numWarps, order); size_t rank = retShape.size(); SmallVector repCluster(rank, 1);