Skip to content

Commit

Permalink
Use block loads for post-dpas vector computation
Browse files Browse the repository at this point in the history
  • Loading branch information
alexbaden committed Dec 13, 2024
1 parent ffc2601 commit 37b841e
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 15 deletions.
31 changes: 31 additions & 0 deletions test/TritonIntelGPU/blockptr_load.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,37 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}

// -----

// CHECK-DAG: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {convergent, memory_effects = #llvm.memory_effects<other = none, argMem = none, inaccessibleMem = none>, no_unwind, will_return}
// CHECK-DAG: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
// CHECK-DAG: llvm.func spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x1cPU3AS1viiiDv2_iPj(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 4], order = [1, 0]}>
#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=2}>
#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} {
tt.func public @matmul_no_scf_with_add_kernel(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f32>, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg8: i64) {
%C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #dpas>
%c0_i32 = arith.constant 0 : i32
%c1_i64 = arith.constant 1 : i64
%ptrA = tt.make_tensor_ptr %arg0, [%arg3, %arg5], [%arg6, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #dot0>>
%ptrB = tt.make_tensor_ptr %arg1, [%arg5, %arg4], [%arg8, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16, #dot1>>
// CHECK-COUNT-2: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
// CHECK-COUNT-2: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x1cPU3AS1viiiDv2_iPj({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
// CHECK-COUNT-8: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f({{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
%A = tt.load %ptrA {boundaryCheck = array<i32: 1>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #dot0>>
%B = tt.load %ptrB {boundaryCheck = array<i32: 0>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x64xf16, #dot1>>
%D = tt.dot %A, %B, %C, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x64xf16, #dot1> -> tensor<64x64xf32, #dpas>
%ptrX = tt.make_tensor_ptr %arg2, [%arg3, %arg4], [%arg8, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf32, #dpas>>
// CHECK-COUNT-4: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_32b_8r16x1cPU3AS1viiiDv2_iPj({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
%X = tt.load %ptrX {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x64xf32, #dpas>>
// CHECK-COUNT-32: llvm.fadd {{.*}}, {{.*}}
%0 = arith.addf %D, %X : tensor<64x64xf32, #dpas>
tt.return
}
}

// -----

#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [1, 1], A = [8, 8], B = [8, 16], C = [8, 16]}>
#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=1}>
Expand Down
161 changes: 146 additions & 15 deletions third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -499,7 +499,9 @@ struct LoadOpConversion
auto tensorType = cast<RankedTensorType>(resultType);

// Only lower loadOp with dpas layout encoding.
if (!hasDotDpasEncoding(tensorType))
auto encoding = tensorType.getEncoding();
const bool hasDpasLayout = isa<DpasEncodingAttr>(encoding);
if (!hasDpasLayout && !hasDotDpasEncoding(tensorType))
return failure();

Attribute blockIOAttr =
Expand All @@ -514,8 +516,11 @@ struct LoadOpConversion
"Only row_major or column_major is supported");
const bool memoryRowMajor = (memoryLayoutInfo == "row_major");

DotOperandEncodingAttr dotLayout = getDotEncoding(tensorType).value();
auto dotOrder = dotLayout.getThreadOrder();
auto dpasLayout = hasDpasLayout
? cast<DpasEncodingAttr>(encoding)
: cast<DpasEncodingAttr>(
getDotEncoding(tensorType).value().getParent());
auto dotOrder = dpasLayout.getThreadOrder();
size_t rank = dotOrder.size();
const bool valueRowMajor =
(dotOrder[rank - 2] == 1 && dotOrder[rank - 1] == 0);
Expand All @@ -524,10 +529,19 @@ struct LoadOpConversion
"Only row_major or column_major is allowed");
const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;

auto dpasLayout = cast<DpasEncodingAttr>(dotLayout.getParent());
auto getOpIdx = [&]() -> unsigned {
if (hasDpasLayout) {
return 2;
} else {
auto dotLayout = getDotEncoding(tensorType).value();
return dotLayout.getOpIdx();
}
};

const unsigned opIdx = dotLayout.getOpIdx();
const unsigned opIdx = getOpIdx();
Type eltTy = tensorType.getElementType();
unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();

const ArrayRef<int64_t> tensorShape = tensorType.getShape();
unsigned numElems = getTotalElemsPerThread(resultType);
SmallVector<int64_t> numReps =
Expand All @@ -543,6 +557,123 @@ struct LoadOpConversion
SmallVector<Value> multiDimWarpId =
delinearize(rewriter, loc, warpId, warpsPerCTA, dpasOrder);

if (hasDpasLayout) {
// A block load with the DPAS layout but without the DotDpasLayout is
// expected to follow the ordering of the DPAS output. For a 2D block
// load, the rows are distributed across work items/SIMD lanes and the
// column vectors are available for each work item to process. This layout
// aligns to the DPAS layout as the DPAS operation output layout
// distributes rows across work items.
if (isTransposeRequired) {
// TODO: this would likely require a shuffle to match the expected
// ordering coming out of the DPAS layout and requires more
// investigation
return failure();
}

MLIRContext *ctx = rewriter.getContext();

Value elemSizeInBytes = i32_val(elemSizeInBits / 8);

SmallVector<unsigned> elemsPerInstr = dpasLayout.getDPASInstShapeC();
int64_t elemsPerLane = product<unsigned>(elemsPerInstr) / threadsPerWarp;
Type load2DGenXType =
LLVM::getFixedVectorType(IntegerType::get(ctx, elemSizeInBits),
elemsPerLane); // make it opaque type.

auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX,
offsetBaseY] =
getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter);
baseWidth = trunc(i32_ty, baseWidth);
baseHeight = trunc(i32_ty, baseHeight);

auto pitch = trunc(i32_ty, rowStride);

SmallVector<unsigned> repClusterShape = dpasLayout.getShapeC();
unsigned outerDimWarpNum =
std::min<unsigned>(warpsPerCTA[rank - 2],
mlir::ceil<unsigned>(tensorShape[rank - 2],
repClusterShape[rank - 2]));
unsigned innerDimWarpNum =
std::min<unsigned>(warpsPerCTA[rank - 1],
mlir::ceil<unsigned>(tensorShape[rank - 1],
repClusterShape[rank - 1]));
Value outerDimWarpId =
urem(multiDimWarpId[rank - 2], i32_val(outerDimWarpNum));
Value innerDimWarpId =
urem(multiDimWarpId[rank - 1], i32_val(innerDimWarpNum));
int64_t numRepOuter = numReps[1];
int64_t numRepInner = numReps[2];

std::array<unsigned, 2> replicaStride = {
outerDimWarpNum * repClusterShape[rank - 2],
innerDimWarpNum * repClusterShape[rank - 1]};
std::array<unsigned, 2> warpStride = {repClusterShape[rank - 2],
repClusterShape[rank - 1]};

Value dimWarpId0 = mul(outerDimWarpId, i32_val(warpStride[0]));
Value dimWarpId1 = mul(innerDimWarpId, i32_val(warpStride[1]));
Value warpId0Offset = add(dimWarpId0, offsetBaseY);
Value warpId1Offset = add(dimWarpId1, offsetBaseX);

ArrayRef<unsigned> repCluster = dpasLayout.getRepCluster();
unsigned valOffset = 0;

SmallVector<Value> unpackedLoadedVals;

for (int m = 0; m < numRepOuter; ++m) {
for (int n = 0; n < numRepInner; ++n) {
for (int repM = 0; repM < repCluster[0]; ++repM) {

Value offsetY =
add(warpId0Offset,
i32_val(m * replicaStride[0] + repM * elemsPerInstr[0]));
for (int repN = 0; repN < repCluster[1]; ++repN) {
Value offsetX =
add(warpId1Offset,
i32_val(n * replicaStride[1] + repN * elemsPerInstr[1]));

auto load2dOp = rewriter.create<TritonGEN::Matrix2DBlockLoadOp>(
loc, load2DGenXType,
/*ptr*/ base,
/*base_width*/ mul(baseWidth, elemSizeInBytes),
/*base_height*/ baseHeight,
/*base_pitch*/ mul(pitch, elemSizeInBytes),
/*x*/ trunc(i32_ty, offsetX),
/*y*/ trunc(i32_ty, offsetY),
/*elem_size_in_bits*/ elemSizeInBits,
/*tile_width*/ elemsPerInstr[1],
/*tile_height*/ elemsPerInstr[0],
/*v_blocks*/ 1,
/*transpose*/ false,
/*vnni_transform*/ false);
if (failed(load2dOp.verify())) {
// Explicitly invoke verifier because `triton_gen` ops are
// immediately lowered further to a builtin call.
return failure();
}

Value ret = bitcast(
load2dOp, LLVM::getFixedVectorType(eltTy, elemsPerLane));

for (size_t i = 0; i < elemsPerLane; i++) {
Value loaded = extract_element(eltTy, ret, i32_val(i));
unpackedLoadedVals.push_back(loaded);
}
}
}
}
}

TritonGPUToLLVMTypeConverter *typeConverter = getTypeConverter();
Type llvmResultStructTy = typeConverter->convertType(op.getType());
Value resultStruct = packLLElements(
loc, typeConverter, unpackedLoadedVals, rewriter, llvmResultStructTy);
rewriter.replaceOp(op, {resultStruct});

return success();
}

bool isOperandA = (opIdx == 0);
SmallVector<unsigned> dpasInstShape = isOperandA
? dpasLayout.getDPASInstShapeA()
Expand Down Expand Up @@ -573,11 +704,11 @@ struct LoadOpConversion
// input operands to DPAS.
// TODO: add support for int4 and int2.
unsigned opsPerChannel = dpasLayout.getOpsPerChannel();
unsigned elemBits = eltTy.getIntOrFloatBitWidth();
if ((opsPerChannel == 4 && elemBits == 8) ||
(opsPerChannel == 2 && elemBits == 16) ||
(opsPerChannel == 1 && elemBits == 32)) {
loadResultElemType = (isOperandA && elemBits != 32) ? i16_ty : i32_ty;
if ((opsPerChannel == 4 && elemSizeInBits == 8) ||
(opsPerChannel == 2 && elemSizeInBits == 16) ||
(opsPerChannel == 1 && elemSizeInBits == 32)) {
loadResultElemType =
(isOperandA && elemSizeInBits != 32) ? i16_ty : i32_ty;
packedElemsPerLanePerDPASInst =
isOperandA ? elemsPerLanePerDPASInst / (opsPerChannel == 4 ? 2 : 1)
: elemsPerLanePerDPASInst / opsPerChannel;
Expand Down Expand Up @@ -651,7 +782,7 @@ struct LoadOpConversion

// PVC 2D load supports 64 bytes per row at most. Load multiple dot operands
// by enlarging the vBlocks.
unsigned totalBytesPerRowPerDPASOp = tileWidth * elemBits / 8;
unsigned totalBytesPerRowPerDPASOp = tileWidth * elemSizeInBits / 8;
numOperandsPer2DloadN =
std::min(numOperandsPer2DloadN, 64 / totalBytesPerRowPerDPASOp);
vBlocks = numOperandsPer2DloadN;
Expand Down Expand Up @@ -695,12 +826,12 @@ struct LoadOpConversion
baseWidth = trunc(i32_ty, baseWidth);
baseHeight = trunc(i32_ty, baseHeight);

unsigned originalElemBits = elemBits;
const unsigned originalElemBits = elemSizeInBits;
if (isTransposeRequired) {
// adjust the block io parameter to align HW's limitations on
// transposing load.
tileWidth = tileWidth / (32 / originalElemBits);
elemBits = 32;
elemSizeInBits = 32;
}
Value elemSizeInBytes = i32_val(originalElemBits / 8);

Expand Down Expand Up @@ -744,14 +875,14 @@ struct LoadOpConversion
/*base_pitch*/ mul(pitch, elemSizeInBytes),
/*x*/ trunc(i32_ty, offsetX),
/*y*/ trunc(i32_ty, offsetY),
/*elem_size_in_bits*/ elemBits,
/*elem_size_in_bits*/ elemSizeInBits,
/*tile_width*/ tileWidth,
/*tile_height*/ tileHeight,
/*v_blocks*/ vBlocks,
/*transpose*/ isTransposeRequired,
/*vnni_transform*/
(usePackedType && !isOperandA && !isTransposeRequired &&
eltTy.getIntOrFloatBitWidth() != 32));
originalElemBits != 32));
if (failed(load2dOp.verify())) {
// Explicitly invoke verifier because `triton_gen` ops are
// immediately lowered further to a builtin call.
Expand Down

0 comments on commit 37b841e

Please sign in to comment.