From 37b841e02de6b7d0e421003ee07408fa721bf9ea Mon Sep 17 00:00:00 2001
From: Alex Baden
Date: Wed, 11 Dec 2024 21:22:09 +0000
Subject: [PATCH] Use block loads for post-DPAS vector computation

---
 test/TritonIntelGPU/blockptr_load.mlir |  31 ++
 .../LoadStoreOpToLLVM.cpp              | 161 ++++++++++++++++--
 2 files changed, 177 insertions(+), 15 deletions(-)

diff --git a/test/TritonIntelGPU/blockptr_load.mlir b/test/TritonIntelGPU/blockptr_load.mlir
index 2189722047..6b8ef0615d 100644
--- a/test/TritonIntelGPU/blockptr_load.mlir
+++ b/test/TritonIntelGPU/blockptr_load.mlir
@@ -28,6 +28,37 @@ module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32}
 
 // -----
 
+// CHECK-DAG: llvm.func spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f(vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32> attributes {convergent, memory_effects = #llvm.memory_effects, no_unwind, will_return}
+// CHECK-DAG: llvm.func spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
+// CHECK-DAG: llvm.func spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x1cPU3AS1viiiDv2_iPj(!llvm.ptr<1> {llvm.nonnull, llvm.readonly}, i32, i32, i32, vector<2xi32>, !llvm.ptr {llvm.nonnull, llvm.writeonly}) attributes {no_unwind, will_return}
+#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 4], order = [1, 0]}>
+#dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}>
+#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=2}>
+#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}>
+module attributes {"ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} {
+  tt.func public @matmul_no_scf_with_add_kernel(%arg0: !tt.ptr<f16>, %arg1: !tt.ptr<f16>, %arg2: !tt.ptr<f32>, %arg3: i64, %arg4: i64, %arg5: i64, %arg6: i64, %arg8: i64) {
+    %C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #dpas>
+    %c0_i32 = arith.constant 0 : i32
+    %c1_i64 = arith.constant 1 : i64
+    %ptrA = tt.make_tensor_ptr %arg0, [%arg3, %arg5], [%arg6, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x32xf16, #dot0>>
+    %ptrB = tt.make_tensor_ptr %arg1, [%arg5, %arg4], [%arg8, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<32x64xf16, #dot1>>
+    // CHECK-COUNT-2: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_16b_8r16x2cPU3AS1viiiDv2_iPt({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
+    // CHECK-COUNT-2: llvm.call spir_funccc @_Z52intel_sub_group_2d_block_read_transform_16b_32r16x1cPU3AS1viiiDv2_iPj({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
+    // CHECK-COUNT-8: llvm.call spir_funccc @_Z38intel_sub_group_f16_f16_matrix_mad_k16Dv8_sDv8_iDv8_f({{.*}}) {{.*}} : (vector<8xi16>, vector<8xi32>, vector<8xf32>) -> vector<8xf32>
+    %A = tt.load %ptrA {boundaryCheck = array<i32: 0, 1>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x32xf16, #dot0>>
+    %B = tt.load %ptrB {boundaryCheck = array<i32: 0, 1>, padding = 1 : i32, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<32x64xf16, #dot1>>
+    %D = tt.dot %A, %B, %C, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x64xf16, #dot1> -> tensor<64x64xf32, #dpas>
+    %ptrX = tt.make_tensor_ptr %arg2, [%arg3, %arg4], [%arg8, %c1_i64], [%c0_i32, %c0_i32] {order = array<i32: 1, 0>} : <tensor<64x64xf32, #dpas>>
+    // CHECK-COUNT-4: llvm.call spir_funccc @_Z41intel_sub_group_2d_block_read_32b_8r16x1cPU3AS1viiiDv2_iPj({{.*}}) {{.*}} : (!llvm.ptr<1>, i32, i32, i32, vector<2xi32>, !llvm.ptr) -> ()
+    %X = tt.load %ptrX {boundaryCheck = array<i32: 0, 1>, triton_intel_gpu.block_io = "row_major"} : !tt.ptr<tensor<64x64xf32, #dpas>>
+    // CHECK-COUNT-32: llvm.fadd {{.*}}, {{.*}}
+    %0 = arith.addf %D, %X : tensor<64x64xf32, #dpas>
+    tt.return
+  }
+}
+
+// -----
+
 #dpas = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [8, 4], repCluster = [1, 1], A = [8, 8], B = [8, 16], C = [8, 16]}>
 #dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}>
 #dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=1}>
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
index 622abed4fe..78e9ffe315 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp
@@ -499,7 +499,9 @@ struct LoadOpConversion
     auto tensorType = cast<RankedTensorType>(resultType);
 
     // Only lower loadOp with dpas layout encoding.
-    if (!hasDotDpasEncoding(tensorType))
+    auto encoding = tensorType.getEncoding();
+    const bool hasDpasLayout = isa<DpasEncodingAttr>(encoding);
+    if (!hasDpasLayout && !hasDotDpasEncoding(tensorType))
       return failure();
 
     Attribute blockIOAttr =
@@ -514,8 +516,11 @@ struct LoadOpConversion
                    "Only row_major or column_major is supported");
     const bool memoryRowMajor = (memoryLayoutInfo == "row_major");
 
-    DotOperandEncodingAttr dotLayout = getDotEncoding(tensorType).value();
-    auto dotOrder = dotLayout.getThreadOrder();
+    auto dpasLayout = hasDpasLayout
+                          ? cast<DpasEncodingAttr>(encoding)
+                          : cast<DpasEncodingAttr>(
+                                getDotEncoding(tensorType).value().getParent());
+    auto dotOrder = dpasLayout.getThreadOrder();
     size_t rank = dotOrder.size();
     const bool valueRowMajor =
         (dotOrder[rank - 2] == 1 && dotOrder[rank - 1] == 0);
@@ -524,10 +529,19 @@ struct LoadOpConversion
            "Only row_major or column_major is allowed");
     const bool isTransposeRequired = valueRowMajor ^ memoryRowMajor;
 
-    auto dpasLayout = cast<DpasEncodingAttr>(dotLayout.getParent());
+    auto getOpIdx = [&]() -> unsigned {
+      if (hasDpasLayout) {
+        return 2;
+      } else {
+        auto dotLayout = getDotEncoding(tensorType).value();
+        return dotLayout.getOpIdx();
+      }
+    };
 
-    const unsigned opIdx = dotLayout.getOpIdx();
+    const unsigned opIdx = getOpIdx();
     Type eltTy = tensorType.getElementType();
+    unsigned elemSizeInBits = eltTy.getIntOrFloatBitWidth();
+    const ArrayRef<int64_t> tensorShape = tensorType.getShape();
+
     unsigned numElems = getTotalElemsPerThread(resultType);
     SmallVector<int64_t> numReps =
@@ -543,6 +557,123 @@ struct LoadOpConversion
     SmallVector<Value> multiDimWarpId =
         delinearize(rewriter, loc, warpId, warpsPerCTA, dpasOrder);
 
+    if (hasDpasLayout) {
+      // A block load with the DPAS layout but without the DotDpasLayout is
+      // expected to follow the ordering of the DPAS output. For a 2D block
+      // load, the rows are distributed across work items/SIMD lanes and each
+      // work item receives a column vector to process. This matches the DPAS
+      // layout, since the DPAS output likewise distributes rows across work
+      // items.
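+      // Example: with the 8x16 DPAS C tile (getDPASInstShapeC) and 16 lanes
+      // per sub-group used in the test above, each 2D block load hands every
+      // lane an 8-element column vector (8 rows * 16 columns / 16 lanes),
+      // which is exactly the elemsPerLane vector length computed below.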
+      if (isTransposeRequired) {
+        // TODO: this would likely require a shuffle to match the expected
+        // ordering coming out of the DPAS layout and requires more
+        // investigation.
+        return failure();
+      }
+
+      MLIRContext *ctx = rewriter.getContext();
+
+      Value elemSizeInBytes = i32_val(elemSizeInBits / 8);
+
+      SmallVector<unsigned> elemsPerInstr = dpasLayout.getDPASInstShapeC();
+      int64_t elemsPerLane = product<unsigned>(elemsPerInstr) / threadsPerWarp;
+      Type load2DGenXType =
+          LLVM::getFixedVectorType(IntegerType::get(ctx, elemSizeInBits),
+                                   elemsPerLane); // make it an opaque type.
+
+      auto [base, baseWidth, baseHeight, rowStride, colStride, offsetBaseX,
+            offsetBaseY] =
+          getValuesFromBlockPointerStruct(adaptor.getPtr(), rewriter);
+      baseWidth = trunc(i32_ty, baseWidth);
+      baseHeight = trunc(i32_ty, baseHeight);
+
+      auto pitch = trunc(i32_ty, rowStride);
+
+      SmallVector<unsigned> repClusterShape = dpasLayout.getShapeC();
+      unsigned outerDimWarpNum =
+          std::min<unsigned>(warpsPerCTA[rank - 2],
+                             mlir::ceil<unsigned>(tensorShape[rank - 2],
+                                                  repClusterShape[rank - 2]));
+      unsigned innerDimWarpNum =
+          std::min<unsigned>(warpsPerCTA[rank - 1],
+                             mlir::ceil<unsigned>(tensorShape[rank - 1],
+                                                  repClusterShape[rank - 1]));
+      Value outerDimWarpId =
+          urem(multiDimWarpId[rank - 2], i32_val(outerDimWarpNum));
+      Value innerDimWarpId =
+          urem(multiDimWarpId[rank - 1], i32_val(innerDimWarpNum));
+      int64_t numRepOuter = numReps[1];
+      int64_t numRepInner = numReps[2];
+
+      std::array<unsigned, 2> replicaStride = {
+          outerDimWarpNum * repClusterShape[rank - 2],
+          innerDimWarpNum * repClusterShape[rank - 1]};
+      std::array<unsigned, 2> warpStride = {repClusterShape[rank - 2],
+                                            repClusterShape[rank - 1]};
+
+      Value dimWarpId0 = mul(outerDimWarpId, i32_val(warpStride[0]));
+      Value dimWarpId1 = mul(innerDimWarpId, i32_val(warpStride[1]));
+      Value warpId0Offset = add(dimWarpId0, offsetBaseY);
+      Value warpId1Offset = add(dimWarpId1, offsetBaseX);
+
+      ArrayRef<unsigned> repCluster = dpasLayout.getRepCluster();
+      unsigned valOffset = 0;
+
+      SmallVector<Value> unpackedLoadedVals;
+
+      for (int m = 0; m < numRepOuter; ++m) {
+        for (int n = 0; n < numRepInner; ++n) {
+          for (int repM = 0; repM < repCluster[0]; ++repM) {
+
+            Value offsetY =
+                add(warpId0Offset,
+                    i32_val(m * replicaStride[0] + repM * elemsPerInstr[0]));
+            for (int repN = 0; repN < repCluster[1]; ++repN) {
+              Value offsetX =
+                  add(warpId1Offset,
+                      i32_val(n * replicaStride[1] + repN * elemsPerInstr[1]));
+
+              auto load2dOp = rewriter.create<TritonGEN::Matrix2DBlockLoadOp>(
+                  loc, load2DGenXType,
+                  /*ptr*/ base,
+                  /*base_width*/ mul(baseWidth, elemSizeInBytes),
+                  /*base_height*/ baseHeight,
+                  /*base_pitch*/ mul(pitch, elemSizeInBytes),
+                  /*x*/ trunc(i32_ty, offsetX),
+                  /*y*/ trunc(i32_ty, offsetY),
+                  /*elem_size_in_bits*/ elemSizeInBits,
+                  /*tile_width*/ elemsPerInstr[1],
+                  /*tile_height*/ elemsPerInstr[0],
+                  /*v_blocks*/ 1,
+                  /*transpose*/ false,
+                  /*vnni_transform*/ false);
+              if (failed(load2dOp.verify())) {
+                // Explicitly invoke verifier because `triton_gen` ops are
+                // immediately lowered further to a builtin call.
+                return failure();
+              }
+
+              Value ret = bitcast(
+                  load2dOp, LLVM::getFixedVectorType(eltTy, elemsPerLane));
+
+              for (size_t i = 0; i < elemsPerLane; i++) {
+                Value loaded = extract_element(eltTy, ret, i32_val(i));
+                unpackedLoadedVals.push_back(loaded);
+              }
+            }
+          }
+        }
+      }
+
+      TritonGPUToLLVMTypeConverter *typeConverter = getTypeConverter();
+      Type llvmResultStructTy = typeConverter->convertType(op.getType());
+      Value resultStruct = packLLElements(
+          loc, typeConverter, unpackedLoadedVals, rewriter, llvmResultStructTy);
+      rewriter.replaceOp(op, {resultStruct});
+
+      return success();
+    }
+
     bool isOperandA = (opIdx == 0);
     SmallVector<unsigned> dpasInstShape = isOperandA
                                               ? dpasLayout.getDPASInstShapeA()
                                               : dpasLayout.getDPASInstShapeB();
@@ -573,11 +704,11 @@ struct LoadOpConversion
     // input operands to DPAS.
     // TODO: add support for int4 and int2.
     unsigned opsPerChannel = dpasLayout.getOpsPerChannel();
-    unsigned elemBits = eltTy.getIntOrFloatBitWidth();
-    if ((opsPerChannel == 4 && elemBits == 8) ||
-        (opsPerChannel == 2 && elemBits == 16) ||
-        (opsPerChannel == 1 && elemBits == 32)) {
-      loadResultElemType = (isOperandA && elemBits != 32) ? i16_ty : i32_ty;
+    if ((opsPerChannel == 4 && elemSizeInBits == 8) ||
+        (opsPerChannel == 2 && elemSizeInBits == 16) ||
+        (opsPerChannel == 1 && elemSizeInBits == 32)) {
+      loadResultElemType =
+          (isOperandA && elemSizeInBits != 32) ? i16_ty : i32_ty;
       packedElemsPerLanePerDPASInst =
           isOperandA ? elemsPerLanePerDPASInst / (opsPerChannel == 4 ? 2 : 1)
                      : elemsPerLanePerDPASInst / opsPerChannel;
@@ -651,7 +782,7 @@ struct LoadOpConversion
 
     // PVC 2D load supports 64 bytes per row at most. Load multiple dot operands
     // by enlarging the vBlocks.
-    unsigned totalBytesPerRowPerDPASOp = tileWidth * elemBits / 8;
+    unsigned totalBytesPerRowPerDPASOp = tileWidth * elemSizeInBits / 8;
     numOperandsPer2DloadN =
         std::min(numOperandsPer2DloadN, 64 / totalBytesPerRowPerDPASOp);
     vBlocks = numOperandsPer2DloadN;
@@ -695,12 +826,12 @@ struct LoadOpConversion
     baseWidth = trunc(i32_ty, baseWidth);
     baseHeight = trunc(i32_ty, baseHeight);
 
-    unsigned originalElemBits = elemBits;
+    const unsigned originalElemBits = elemSizeInBits;
     if (isTransposeRequired) {
       // adjust the block io parameter to align HW's limitations on
       // transposing load.
       tileWidth = tileWidth / (32 / originalElemBits);
-      elemBits = 32;
+      elemSizeInBits = 32;
     }
     Value elemSizeInBytes = i32_val(originalElemBits / 8);
@@ -744,14 +875,14 @@ struct LoadOpConversion
             /*base_pitch*/ mul(pitch, elemSizeInBytes),
             /*x*/ trunc(i32_ty, offsetX),
             /*y*/ trunc(i32_ty, offsetY),
-            /*elem_size_in_bits*/ elemBits,
+            /*elem_size_in_bits*/ elemSizeInBits,
             /*tile_width*/ tileWidth,
             /*tile_height*/ tileHeight,
             /*v_blocks*/ vBlocks,
             /*transpose*/ isTransposeRequired,
             /*vnni_transform*/
             (usePackedType && !isOperandA && !isTransposeRequired &&
-             eltTy.getIntOrFloatBitWidth() != 32));
+             originalElemBits != 32));
         if (failed(load2dOp.verify())) {
           // Explicitly invoke verifier because `triton_gen` ops are
           // immediately lowered further to a builtin call.
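
For readers following the new `hasDpasLayout` path, the offset arithmetic above can be reproduced with a small standalone program. This is an editorial sketch, not code from the patch: it assumes the configuration of the test added above (64x64xf32 output, 8x16 DPAS C tile, warpsPerCTA = [4, 2], repCluster = [1, 1], 16 lanes per sub-group), and it derives numRepOuter/numRepInner by plain division where the patch calls getDPASRepetitions.

    // Offset-arithmetic sketch for the hasDpasLayout block-load path
    // (illustrative only; variable names mirror the patch).
    #include <algorithm>
    #include <cstdio>

    int main() {
      const int tensorShape[2] = {64, 64};  // 64x64xf32 C tensor
      const int elemsPerInstr[2] = {8, 16}; // DPAS C tile (getDPASInstShapeC)
      const int shapeC[2] = {8, 16};        // getShapeC with repCluster = [1, 1]
      const int warpsPerCTA[2] = {4, 2};
      const int threadsPerWarp = 16;

      // Each 8r16x1c 2D block load yields 8 * 16 / 16 = 8 f32 values per lane.
      const int elemsPerLane =
          elemsPerInstr[0] * elemsPerInstr[1] / threadsPerWarp;

      const int outerDimWarpNum =
          std::min(warpsPerCTA[0], tensorShape[0] / shapeC[0]); // 4
      const int innerDimWarpNum =
          std::min(warpsPerCTA[1], tensorShape[1] / shapeC[1]); // 2

      // Simplified stand-in for getDPASRepetitions (numReps[1], numReps[2]).
      const int numRepOuter = tensorShape[0] / (shapeC[0] * outerDimWarpNum); // 2
      const int numRepInner = tensorShape[1] / (shapeC[1] * innerDimWarpNum); // 2

      const int replicaStride[2] = {outerDimWarpNum * shapeC[0],
                                    innerDimWarpNum * shapeC[1]}; // {32, 32}

      // Offsets issued by warp (0, 0); warp (i, j) adds (i * 8, j * 16).
      for (int m = 0; m < numRepOuter; ++m)
        for (int n = 0; n < numRepInner; ++n)
          std::printf("2d block load at y=%2d x=%2d (%d elems/lane)\n",
                      m * replicaStride[0], n * replicaStride[1], elemsPerLane);
      return 0;
    }

The four loads per warp, at 8 values per lane, account for 32 f32 results per lane; that is what the test's CHECK-COUNT-4 (block reads) and CHECK-COUNT-32 (llvm.fadd) lines assert. Note also that the transposed case deliberately returns failure(): the pattern declines, so such loads keep using the existing non-block lowering until the ordering question in the TODO is resolved.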