From 114add9ef2897d25a48d57d96e1e360711c93a67 Mon Sep 17 00:00:00 2001 From: victor-eds Date: Tue, 3 Dec 2024 13:57:55 +0000 Subject: [PATCH] [XPU][OptRed] Define `triton_intel_gpu.simd_reduce` and use in optimized transposed reduction Define SIMD transpose-reduce operation performing a SIMD reduction while transposing the implicit SIMD matrix. See description definition for further context. Using this operation in the transpose reduction pass allows us to perform the optimization while not using SLM. Signed-off-by: victor-eds Co-authored-by: chengjunlu Signed-off-by: Victor Perez --- test/Conversion/intel/simd-reduce.mlir | 24 ++ .../optimize-reduction-simd.mlir | 289 +++++++++++++++ test/TritonIntelGPU/optimize-reduction.mlir | 303 ++++++++-------- test/TritonIntelGPU/tritonintelgpu.mlir | 14 + .../TritonIntelGPU/IR/TritonIntelGPUOps.td | 63 ++++ .../TritonIntelGPUToLLVM/VISAASMFormat.h | 330 ++++++++++++++++++ .../lib/TritonIntelGPUToLLVM/CMakeLists.txt | 1 + .../TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp | 118 ++++++- .../TritonIntelGPUToLLVM/VISAASMFormat.cpp | 237 +++++++++++++ .../OptimizeReductionLocality.cpp | 268 +++++++++++++- 10 files changed, 1486 insertions(+), 161 deletions(-) create mode 100644 test/Conversion/intel/simd-reduce.mlir create mode 100644 test/TritonIntelGPU/optimize-reduction-simd.mlir create mode 100644 third_party/intel/include/TritonIntelGPUToLLVM/VISAASMFormat.h create mode 100644 third_party/intel/lib/TritonIntelGPUToLLVM/VISAASMFormat.cpp diff --git a/test/Conversion/intel/simd-reduce.mlir b/test/Conversion/intel/simd-reduce.mlir new file mode 100644 index 0000000000..a4d328e0a0 --- /dev/null +++ b/test/Conversion/intel/simd-reduce.mlir @@ -0,0 +1,24 @@ +// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s + +// Basic 16x16 SIMD reduction. + +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32} { +// CHECK-LABEL: llvm.func spir_kernelcc @test_single( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct +// CHECK: %[[VAL_17:.*]] = llvm.mlir.poison : vector<16xf32> +// COM: Check we insert all tensor elements in a vector: +// CHECK-COUNT-16: llvm.insertelement +// CHECK: %[[VAL_50:.*]] = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "{\0A.decl temp_result v_type=G type=f num_elts=128 align=wordx32\0Aadd (M1_NM, 16) temp_result(0, 0)<1> $1(0, 0)<16;8,1> $1(0, 8)<16;8,1>\0Aadd (M1_NM, 16) temp_result(1, 0)<1> $1(2, 0)<16;8,1> $1(2, 8)<16;8,1>\0Aadd (M1_NM, 16) temp_result(2, 0)<1> $1(4, 0)<16;8,1> $1(4, 8)<16;8,1>\0Aadd (M1_NM, 16) temp_result(3, 0)<1> $1(6, 0)<16;8,1> $1(6, 8)<16;8,1>\0Aadd (M1_NM, 16) temp_result(4, 0)<1> $1(8, 0)<16;8,1> $1(8, 8)<16;8,1>\0Aadd (M1_NM, 16) temp_result(5, 0)<1> $1(10, 0)<16;8,1> $1(10, 8)<16;8,1>\0Aadd (M1_NM, 16) temp_result(6, 0)<1> $1(12, 0)<16;8,1> $1(12, 8)<16;8,1>\0Aadd (M1_NM, 16) temp_result(7, 0)<1> $1(14, 0)<16;8,1> $1(14, 8)<16;8,1>\0Aadd (M1_NM, 16) temp_result(0, 0)<1> temp_result(0, 0)<8;4,1> temp_result(0, 4)<8;4,1>\0Aadd (M1_NM, 16) temp_result(1, 0)<1> temp_result(2, 0)<8;4,1> temp_result(2, 4)<8;4,1>\0Aadd (M1_NM, 16) temp_result(2, 0)<1> temp_result(4, 0)<8;4,1> temp_result(4, 4)<8;4,1>\0Aadd (M1_NM, 16) temp_result(3, 0)<1> temp_result(6, 0)<8;4,1> temp_result(6, 4)<8;4,1>\0Aadd (M1_NM, 16) temp_result(0, 0)<1> temp_result(0, 0)<4;2,1> temp_result(0, 2)<4;2,1>\0Aadd (M1_NM, 16) temp_result(1, 0)<1> temp_result(2, 0)<4;2,1> temp_result(2, 2)<4;2,1>\0Aadd (M1_NM, 16) $0(0, 0)<1> temp_result(0, 0)<2;1,0> temp_result(0, 1)<2;1,0>\0A}", "=rw,rw" %{{.*}} : (vector<16xf32>) -> f32 +// COM: Check we obtain a single result, i.e., the SIMD reduction minimizes register usage. +// CHECK: %[[VAL_51:.*]] = llvm.mlir.undef : !llvm.struct<(f32)> +// CHECK: %[[VAL_52:.*]] = llvm.insertvalue %[[VAL_50]], %[[VAL_51]][0] : !llvm.struct<(f32)> +// CHECK: llvm.return %[[VAL_52]] : !llvm.struct<(f32)> +// CHECK: } + tt.func @test_single(%arg0: tensor<16x16xf32, #blocked>) -> tensor<16xf32, #blocked1> { + %0 = triton_intel_gpu.simd_reduce add %arg0 axis = 0 : tensor<16x16xf32, #blocked> -> tensor<16xf32, #blocked1> + tt.return %0 : tensor<16xf32, #blocked1> + } +} diff --git a/test/TritonIntelGPU/optimize-reduction-simd.mlir b/test/TritonIntelGPU/optimize-reduction-simd.mlir new file mode 100644 index 0000000000..e7748dea71 --- /dev/null +++ b/test/TritonIntelGPU/optimize-reduction-simd.mlir @@ -0,0 +1,289 @@ +// RUN: triton-opt %s --split-input-file -tritonintelgpu-optimize-reduction-locality | FileCheck %s + +// This test serves as a counterpart to optimize-reduction.mlir for cases in +// which the SIMD reduction is supported. + +// Test reduction in a single warp (16x16->16). + +// CHECK: #[[$ATTR_0:.+]] = #ttg.blocked<{sizePerThread = [1, 8, 1, 2, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1, 1, 1], order = [0, 1, 2, 3, 4, 5, 6]}> +// CHECK: #[[$ATTR_1:.+]] = #ttg.blocked<{sizePerThread = [1, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1], order = [0, 1, 2, 3]}> +// CHECK: #[[$ATTR_2:.+]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 1], order = [0, 1, 2]}> +// CHECK: #[[$ATTR_3:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}> +// CHECK: #[[$ATTR_4:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32} { + +// CHECK-LABEL: tt.func @test_single( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x16xf32, #[[$ATTR_4]]>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_4]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<16x16xf32, #[[$ATTR_4]]> -> tensor<16x8x1x2x1x1x1xf32, #[[$ATTR_0]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<16x8x1x2x1x1x1xf32, #[[$ATTR_0]]>) -> tensor<16x8x2x1x1x1xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<16x8x2x1x1x1xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>>) -> tensor<16x8x2x1x1xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>}>> +// CHECK: %[[VAL_10:.*]] = triton_intel_gpu.simd_reduce add %[[VAL_6]] axis = 0 : tensor<16x8x2x1x1xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>}>> -> tensor<16x1x1x1xf32, #[[$ATTR_1]]> +// CHECK: %[[VAL_11:.*]] = ttg.convert_layout %[[VAL_10]] : tensor<16x1x1x1xf32, #[[$ATTR_1]]> -> tensor<16x1x1x1xf32, #[[$ATTR_1]]> +// CHECK: %[[VAL_12:.*]] = tt.reshape %[[VAL_11]] allow_reorder efficient_layout : tensor<16x1x1x1xf32, #[[$ATTR_1]]> -> tensor<16x1x1xf32, #[[$ATTR_2]]> +// CHECK: %[[VAL_13:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ +// CHECK: ^bb0(%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32): +// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_14]], %[[VAL_15]] : f32 +// CHECK: tt.reduce.return %[[VAL_16]] : f32 +// CHECK: }) : (tensor<16x1x1xf32, #[[$ATTR_2]]>) -> tensor<16x1xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_2]]}>> +// CHECK: %[[VAL_17:.*]] = tt.reshape %[[VAL_13]] allow_reorder efficient_layout : tensor<16x1xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_2]]}>> -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_3]]}>> +// CHECK: %[[VAL_18:.*]] = ttg.convert_layout %[[VAL_17]] : tensor<16xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_3]]}>> -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_4]]}>> +// CHECK: tt.return %[[VAL_18]] : tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_4]]}>> +// CHECK: } + tt.func @test_single(%arg0: tensor<16x16xf32, #mma>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<16x16xf32, #mma>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> + } +} + +// ----- + +// Test reduction in two warps across the non-reduction dimension (32x16->32). + +// CHECK: #[[$ATTR_5:.+]] = #ttg.blocked<{sizePerThread = [1, 8, 1, 2, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1, 1, 2], order = [0, 1, 2, 3, 4, 5, 6]}> +// CHECK: #[[$ATTR_6:.+]] = #ttg.blocked<{sizePerThread = [1, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 2], order = [0, 1, 2, 3]}> +// CHECK: #[[$ATTR_7:.+]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 2], order = [0, 1, 2]}> +// CHECK: #[[$ATTR_8:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 2], order = [0, 1]}> +// CHECK: #[[$ATTR_9:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 1]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 16 : i32} { + +// CHECK-LABEL: tt.func @test_single_twice( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32, #[[$ATTR_9]]>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_9]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<32x16xf32, #[[$ATTR_9]]> -> tensor<16x8x1x2x1x1x2xf32, #[[$ATTR_5]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<16x8x1x2x1x1x2xf32, #[[$ATTR_5]]>) -> tensor<16x8x2x1x1x2xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<16x8x2x1x1x2xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>>) -> tensor<16x8x2x1x2xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>}>> +// CHECK: %[[VAL_10:.*]] = triton_intel_gpu.simd_reduce add %[[VAL_6]] axis = 0 : tensor<16x8x2x1x2xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>}>> -> tensor<16x1x1x2xf32, #[[$ATTR_6]]> +// CHECK: %[[VAL_11:.*]] = ttg.convert_layout %[[VAL_10]] : tensor<16x1x1x2xf32, #[[$ATTR_6]]> -> tensor<16x1x1x2xf32, #[[$ATTR_6]]> +// CHECK: %[[VAL_12:.*]] = tt.reshape %[[VAL_11]] allow_reorder efficient_layout : tensor<16x1x1x2xf32, #[[$ATTR_6]]> -> tensor<16x1x2xf32, #[[$ATTR_7]]> +// CHECK: %[[VAL_13:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ +// CHECK: ^bb0(%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32): +// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_14]], %[[VAL_15]] : f32 +// CHECK: tt.reduce.return %[[VAL_16]] : f32 +// CHECK: }) : (tensor<16x1x2xf32, #[[$ATTR_7]]>) -> tensor<16x2xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_7]]}>> +// CHECK: %[[VAL_17:.*]] = tt.reshape %[[VAL_13]] allow_reorder efficient_layout : tensor<16x2xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_7]]}>> -> tensor<32xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_8]]}>> +// CHECK: %[[VAL_18:.*]] = ttg.convert_layout %[[VAL_17]] : tensor<32xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_8]]}>> -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_9]]}>> +// CHECK: tt.return %[[VAL_18]] : tensor<32xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_9]]}>> +// CHECK: } + tt.func @test_single_twice(%arg0: tensor<32x16xf32, #mma>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<32x16xf32, #mma>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> + } +} + +// ----- + +// Test reduction in two warps across the reduction dimension (16x32->16). + +// CHECK: #[[$ATTR_10:.+]] = #ttg.blocked<{sizePerThread = [1, 8, 1, 2, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 2, 1, 1], order = [0, 1, 2, 3, 4, 5, 6]}> +// CHECK: #[[$ATTR_11:.+]] = #ttg.blocked<{sizePerThread = [1, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 2, 1], order = [0, 1, 2, 3]}> +// CHECK: #[[$ATTR_12:.+]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 2, 1], order = [0, 1, 2]}> +// CHECK: #[[$ATTR_13:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 1], order = [0, 1]}> +// CHECK: #[[$ATTR_14:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 2], repCluster = [2, 1]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 16 : i32} { + +// CHECK-LABEL: tt.func @test_two_warps_red( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32xf32, #[[$ATTR_14]]>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_14]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<16x32xf32, #[[$ATTR_14]]> -> tensor<16x8x1x2x2x1x1xf32, #[[$ATTR_10]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<16x8x1x2x2x1x1xf32, #[[$ATTR_10]]>) -> tensor<16x8x2x2x1x1xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<16x8x2x2x1x1xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>>) -> tensor<16x8x2x2x1xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>}>> +// CHECK: %[[VAL_10:.*]] = triton_intel_gpu.simd_reduce add %[[VAL_6]] axis = 0 : tensor<16x8x2x2x1xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>}>> -> tensor<16x1x2x1xf32, #[[$ATTR_11]]> +// CHECK: %[[VAL_11:.*]] = ttg.convert_layout %[[VAL_10]] : tensor<16x1x2x1xf32, #[[$ATTR_11]]> -> tensor<16x1x2x1xf32, #[[$ATTR_11]]> +// CHECK: %[[VAL_12:.*]] = tt.reshape %[[VAL_11]] allow_reorder efficient_layout : tensor<16x1x2x1xf32, #[[$ATTR_11]]> -> tensor<16x2x1xf32, #[[$ATTR_12]]> +// CHECK: %[[VAL_13:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ +// CHECK: ^bb0(%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32): +// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_14]], %[[VAL_15]] : f32 +// CHECK: tt.reduce.return %[[VAL_16]] : f32 +// CHECK: }) : (tensor<16x2x1xf32, #[[$ATTR_12]]>) -> tensor<16x1xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_12]]}>> +// CHECK: %[[VAL_17:.*]] = tt.reshape %[[VAL_13]] allow_reorder efficient_layout : tensor<16x1xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_12]]}>> -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_13]]}>> +// CHECK: %[[VAL_18:.*]] = ttg.convert_layout %[[VAL_17]] : tensor<16xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_13]]}>> -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_14]]}>> +// CHECK: tt.return %[[VAL_18]] : tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_14]]}>> +// CHECK: } + tt.func @test_two_warps_red(%arg0: tensor<16x32xf32, #mma>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<16x32xf32, #mma>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> + } +} + +// ----- + +// Test reduction in two warps across both dimensions (32x32->32). + +// CHECK: #[[$ATTR_15:.+]] = #ttg.blocked<{sizePerThread = [1, 8, 1, 2, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 2, 1, 2], order = [0, 1, 2, 3, 4, 5, 6]}> +// CHECK: #[[$ATTR_16:.+]] = #ttg.blocked<{sizePerThread = [1, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 2, 2], order = [0, 1, 2, 3]}> +// CHECK: #[[$ATTR_17:.+]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 2, 2], order = [0, 1, 2]}> +// CHECK: #[[$ATTR_18:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [0, 1]}> +// CHECK: #[[$ATTR_19:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 1]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32} { + +// CHECK-LABEL: tt.func @test_two_warps( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x32xf32, #[[$ATTR_19]]>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_19]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<32x32xf32, #[[$ATTR_19]]> -> tensor<16x8x1x2x2x1x2xf32, #[[$ATTR_15]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<16x8x1x2x2x1x2xf32, #[[$ATTR_15]]>) -> tensor<16x8x2x2x1x2xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<16x8x2x2x1x2xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>>) -> tensor<16x8x2x2x2xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>}>> +// CHECK: %[[VAL_10:.*]] = triton_intel_gpu.simd_reduce add %[[VAL_6]] axis = 0 : tensor<16x8x2x2x2xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>}>> -> tensor<16x1x2x2xf32, #[[$ATTR_16]]> +// CHECK: %[[VAL_11:.*]] = ttg.convert_layout %[[VAL_10]] : tensor<16x1x2x2xf32, #[[$ATTR_16]]> -> tensor<16x1x2x2xf32, #[[$ATTR_16]]> +// CHECK: %[[VAL_12:.*]] = tt.reshape %[[VAL_11]] allow_reorder efficient_layout : tensor<16x1x2x2xf32, #[[$ATTR_16]]> -> tensor<16x2x2xf32, #[[$ATTR_17]]> +// CHECK: %[[VAL_13:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ +// CHECK: ^bb0(%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32): +// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_14]], %[[VAL_15]] : f32 +// CHECK: tt.reduce.return %[[VAL_16]] : f32 +// CHECK: }) : (tensor<16x2x2xf32, #[[$ATTR_17]]>) -> tensor<16x2xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_17]]}>> +// CHECK: %[[VAL_17:.*]] = tt.reshape %[[VAL_13]] allow_reorder efficient_layout : tensor<16x2xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_17]]}>> -> tensor<32xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_18]]}>> +// CHECK: %[[VAL_18:.*]] = ttg.convert_layout %[[VAL_17]] : tensor<32xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_18]]}>> -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_19]]}>> +// CHECK: tt.return %[[VAL_18]] : tensor<32xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_19]]}>> +// CHECK: } + tt.func @test_two_warps(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<32x32xf32, #mma>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> + } +} + +// ----- + +// Test reduction across 2 warps in the reduction dimension and 4 in the non-reduction dimension. + +// CHECK: #[[$ATTR_20:.+]] = #ttg.blocked<{sizePerThread = [1, 8, 2, 2, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 2, 1, 4], order = [0, 1, 2, 3, 4, 5, 6]}> +// CHECK: #[[$ATTR_21:.+]] = #ttg.blocked<{sizePerThread = [1, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 2, 4], order = [0, 1, 2, 3]}> +// CHECK: #[[$ATTR_22:.+]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 2, 4], order = [0, 1, 2]}> +// CHECK: #[[$ATTR_23:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 4], order = [0, 1]}> +// CHECK: #[[$ATTR_24:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 2], A = [16, 8], B = [8, 32], C = [16, 32]}> +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 2]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} { +// CHECK-LABEL: tt.func @test( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<64x64xf32, #[[$ATTR_24]]>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_24]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<64x64xf32, #[[$ATTR_24]]> -> tensor<16x8x2x2x2x1x4xf32, #[[$ATTR_20]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<16x8x2x2x2x1x4xf32, #[[$ATTR_20]]>) -> tensor<16x8x2x2x1x4xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<16x8x2x2x1x4xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>>) -> tensor<16x8x2x2x4xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>}>> +// CHECK: %[[VAL_10:.*]] = triton_intel_gpu.simd_reduce max %[[VAL_6]] axis = 0 : tensor<16x8x2x2x4xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>}>> -> tensor<16x1x2x4xf32, #[[$ATTR_21]]> +// CHECK: %[[VAL_11:.*]] = ttg.convert_layout %[[VAL_10]] : tensor<16x1x2x4xf32, #[[$ATTR_21]]> -> tensor<16x1x2x4xf32, #[[$ATTR_21]]> +// CHECK: %[[VAL_12:.*]] = tt.reshape %[[VAL_11]] allow_reorder efficient_layout : tensor<16x1x2x4xf32, #[[$ATTR_21]]> -> tensor<16x2x4xf32, #[[$ATTR_22]]> +// CHECK: %[[VAL_13:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ +// CHECK: ^bb0(%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32): +// CHECK: %[[VAL_16:.*]] = arith.maxnumf %[[VAL_14]], %[[VAL_15]] : f32 +// CHECK: tt.reduce.return %[[VAL_16]] : f32 +// CHECK: }) : (tensor<16x2x4xf32, #[[$ATTR_22]]>) -> tensor<16x4xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_22]]}>> +// CHECK: %[[VAL_17:.*]] = tt.reshape %[[VAL_13]] allow_reorder efficient_layout : tensor<16x4xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_22]]}>> -> tensor<64xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_23]]}>> +// CHECK: %[[VAL_18:.*]] = ttg.convert_layout %[[VAL_17]] : tensor<64xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_23]]}>> -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_24]]}>> +// CHECK: tt.return %[[VAL_18]] : tensor<64xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_24]]}>> +// CHECK: } + tt.func @test(%arg0: tensor<64x64xf32, #mma>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.maxnumf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<64x64xf32, #mma>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<64xf32, #ttg.slice<{dim = 1, parent = #mma}>> + } +} + +// ----- + +// Test reduction across 2 warps in the reduction dimension and 4 in the non-reduction dimension with repCluster[0] = 4. + +// CHECK: #[[$ATTR_25:.+]] = #ttg.blocked<{sizePerThread = [1, 8, 2, 4, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 2, 1, 4], order = [0, 1, 2, 3, 4, 5, 6]}> +// CHECK: #[[$ATTR_26:.+]] = #ttg.blocked<{sizePerThread = [1, 2, 1, 1], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 2, 4], order = [0, 1, 2, 3]}> +// CHECK: #[[$ATTR_27:.+]] = #ttg.blocked<{sizePerThread = [2, 1, 1, 1], threadsPerWarp = [8, 2, 1, 1], warpsPerCTA = [1, 1, 2, 4], order = [0, 1, 2, 3]}> +// CHECK: #[[$ATTR_28:.+]] = #ttg.blocked<{sizePerThread = [2, 1, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 2, 4], order = [0, 1, 2]}> +// CHECK: #[[$ATTR_29:.+]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 16], warpsPerCTA = [2, 4], order = [0, 1]}> +// CHECK: #[[$ATTR_30:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [4, 2], A = [32, 8], B = [8, 32], C = [32, 32]}> +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [4, 2]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} { +// CHECK-LABEL: tt.func @test( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<128x64xf32, #[[$ATTR_30]]>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_30]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<128x64xf32, #[[$ATTR_30]]> -> tensor<16x8x2x4x2x1x4xf32, #[[$ATTR_25]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<16x8x2x4x2x1x4xf32, #[[$ATTR_25]]>) -> tensor<16x8x4x2x1x4xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_25]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<16x8x4x2x1x4xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_25]]}>>) -> tensor<16x8x4x2x4xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_25]]}>}>> +// CHECK: %[[VAL_10:.*]] = triton_intel_gpu.simd_reduce max %[[VAL_6]] axis = 0 : tensor<16x8x4x2x4xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_25]]}>}>> -> tensor<16x2x2x4xf32, #[[$ATTR_26]]> +// CHECK: %[[VAL_11:.*]] = ttg.convert_layout %[[VAL_10]] : tensor<16x2x2x4xf32, #[[$ATTR_26]]> -> tensor<16x2x2x4xf32, #[[$ATTR_27]]> +// CHECK: %[[VAL_12:.*]] = tt.reshape %[[VAL_11]] allow_reorder efficient_layout : tensor<16x2x2x4xf32, #[[$ATTR_27]]> -> tensor<32x2x4xf32, #[[$ATTR_28]]> +// CHECK: %[[VAL_13:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ +// CHECK: ^bb0(%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32): +// CHECK: %[[VAL_16:.*]] = arith.maxnumf %[[VAL_14]], %[[VAL_15]] : f32 +// CHECK: tt.reduce.return %[[VAL_16]] : f32 +// CHECK: }) : (tensor<32x2x4xf32, #[[$ATTR_28]]>) -> tensor<32x4xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_28]]}>> +// CHECK: %[[VAL_17:.*]] = tt.reshape %[[VAL_13]] allow_reorder efficient_layout : tensor<32x4xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_28]]}>> -> tensor<128xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_29]]}>> +// CHECK: %[[VAL_18:.*]] = ttg.convert_layout %[[VAL_17]] : tensor<128xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_29]]}>> -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_30]]}>> +// CHECK: tt.return %[[VAL_18]] : tensor<128xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_30]]}>> +// CHECK: } + tt.func @test(%arg0: tensor<128x64xf32, #mma>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.maxnumf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + } +} diff --git a/test/TritonIntelGPU/optimize-reduction.mlir b/test/TritonIntelGPU/optimize-reduction.mlir index a7bd409467..5847ba65a9 100644 --- a/test/TritonIntelGPU/optimize-reduction.mlir +++ b/test/TritonIntelGPU/optimize-reduction.mlir @@ -1,5 +1,8 @@ // RUN: triton-opt %s --split-input-file -tritonintelgpu-optimize-reduction-locality | FileCheck %s +// This test serves as a counterpart to optimize-reduction-simd.mlir for cases +// in which the SIMD reduction is not supported. + // Test reduction in a single warp (16x16->16). // CHECK: #[[$ATTR_0:.+]] = #ttg.blocked<{sizePerThread = [1, 8, 1, 2, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1, 1, 1], order = [0, 1, 2, 3, 4, 5, 6]}> @@ -12,41 +15,41 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32} { // CHECK-LABEL: tt.func @test_single( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x16xf32, #[[$ATTR_4]]>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_4]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<16x16xf32, #[[$ATTR_4]]> -> tensor<16x8x1x2x1x1x1xf32, #[[$ATTR_0]]> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x16xf16, #[[$ATTR_4]]>) -> tensor<16xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_4]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<16x16xf16, #[[$ATTR_4]]> -> tensor<16x8x1x2x1x1x1xf16, #[[$ATTR_0]]> // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ -// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): -// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 -// CHECK: tt.reduce.return %[[VAL_5]] : f32 -// CHECK: }) : (tensor<16x8x1x2x1x1x1xf32, #[[$ATTR_0]]>) -> tensor<16x8x2x1x1x1xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>> +// CHECK: ^bb0(%[[VAL_3:.*]]: f16, %[[VAL_4:.*]]: f16): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f16 +// CHECK: tt.reduce.return %[[VAL_5]] : f16 +// CHECK: }) : (tensor<16x8x1x2x1x1x1xf16, #[[$ATTR_0]]>) -> tensor<16x8x2x1x1x1xf16, #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>> // CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ -// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): -// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 -// CHECK: tt.reduce.return %[[VAL_9]] : f32 -// CHECK: }) : (tensor<16x8x2x1x1x1xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>>) -> tensor<16x8x2x1x1xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>}>> -// CHECK: %[[VAL_10:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<16x8x2x1x1xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>}>> -> tensor<16x8x2x1x1xf32, #[[$ATTR_1]]> -// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x1x1xf32, #[[$ATTR_1]]> -> tensor<16x16x1x1xf32, #[[$ATTR_2]]> +// CHECK: ^bb0(%[[VAL_7:.*]]: f16, %[[VAL_8:.*]]: f16): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f16 +// CHECK: tt.reduce.return %[[VAL_9]] : f16 +// CHECK: }) : (tensor<16x8x2x1x1x1xf16, #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>>) -> tensor<16x8x2x1x1xf16, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>}>> +// CHECK: %[[VAL_10:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<16x8x2x1x1xf16, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>}>> -> tensor<16x8x2x1x1xf16, #[[$ATTR_1]]> +// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x1x1xf16, #[[$ATTR_1]]> -> tensor<16x16x1x1xf16, #[[$ATTR_2]]> // CHECK: %[[VAL_12:.*]] = "tt.reduce"(%[[VAL_11]]) <{axis = 0 : i32}> ({ -// CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32): -// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32 -// CHECK: tt.reduce.return %[[VAL_15]] : f32 -// CHECK: }) : (tensor<16x16x1x1xf32, #[[$ATTR_2]]>) -> tensor<16x1x1xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_2]]}>> +// CHECK: ^bb0(%[[VAL_13:.*]]: f16, %[[VAL_14:.*]]: f16): +// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f16 +// CHECK: tt.reduce.return %[[VAL_15]] : f16 +// CHECK: }) : (tensor<16x16x1x1xf16, #[[$ATTR_2]]>) -> tensor<16x1x1xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_2]]}>> // CHECK: %[[VAL_16:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ -// CHECK: ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32): -// CHECK: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : f32 -// CHECK: tt.reduce.return %[[VAL_19]] : f32 -// CHECK: }) : (tensor<16x1x1xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_2]]}>>) -> tensor<16x1xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_2]]}>}>> -// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x1xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_2]]}>}>> -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_3]]}>> -// CHECK: %[[VAL_21:.*]] = ttg.convert_layout %[[VAL_20]] : tensor<16xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_3]]}>> -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_4]]}>> -// CHECK: tt.return %[[VAL_21]] : tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_4]]}>> +// CHECK: ^bb0(%[[VAL_17:.*]]: f16, %[[VAL_18:.*]]: f16): +// CHECK: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : f16 +// CHECK: tt.reduce.return %[[VAL_19]] : f16 +// CHECK: }) : (tensor<16x1x1xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_2]]}>>) -> tensor<16x1xf16, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_2]]}>}>> +// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x1xf16, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_2]]}>}>> -> tensor<16xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_3]]}>> +// CHECK: %[[VAL_21:.*]] = ttg.convert_layout %[[VAL_20]] : tensor<16xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_3]]}>> -> tensor<16xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_4]]}>> +// CHECK: tt.return %[[VAL_21]] : tensor<16xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_4]]}>> // CHECK: } - tt.func @test_single(%arg0: tensor<16x16xf32, #mma>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> { + tt.func @test_single(%arg0: tensor<16x16xf16, #mma>) -> tensor<16xf16, #ttg.slice<{dim = 1, parent = #mma}>> { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - tt.reduce.return %1 : f32 - }) : (tensor<16x16xf32, #mma>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> - tt.return %0 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> + ^bb0(%arg1: f16, %arg2: f16): + %1 = arith.addf %arg1, %arg2 : f16 + tt.reduce.return %1 : f16 + }) : (tensor<16x16xf16, #mma>) -> tensor<16xf16, #ttg.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<16xf16, #ttg.slice<{dim = 1, parent = #mma}>> } } @@ -64,41 +67,41 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 16 : i32} { // CHECK-LABEL: tt.func @test_single_twice( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32, #[[$ATTR_9]]>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_9]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<32x16xf32, #[[$ATTR_9]]> -> tensor<16x8x1x2x1x1x2xf32, #[[$ATTR_5]]> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf16, #[[$ATTR_9]]>) -> tensor<32xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_9]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<32x16xf16, #[[$ATTR_9]]> -> tensor<16x8x1x2x1x1x2xf16, #[[$ATTR_5]]> // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ -// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): -// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 -// CHECK: tt.reduce.return %[[VAL_5]] : f32 -// CHECK: }) : (tensor<16x8x1x2x1x1x2xf32, #[[$ATTR_5]]>) -> tensor<16x8x2x1x1x2xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>> +// CHECK: ^bb0(%[[VAL_3:.*]]: f16, %[[VAL_4:.*]]: f16): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f16 +// CHECK: tt.reduce.return %[[VAL_5]] : f16 +// CHECK: }) : (tensor<16x8x1x2x1x1x2xf16, #[[$ATTR_5]]>) -> tensor<16x8x2x1x1x2xf16, #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>> // CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ -// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): -// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 -// CHECK: tt.reduce.return %[[VAL_9]] : f32 -// CHECK: }) : (tensor<16x8x2x1x1x2xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>>) -> tensor<16x8x2x1x2xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>}>> -// CHECK: %[[VAL_10:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<16x8x2x1x2xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>}>> -> tensor<16x8x2x1x2xf32, #[[$ATTR_6]]> -// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x1x2xf32, #[[$ATTR_6]]> -> tensor<16x16x1x2xf32, #[[$ATTR_7]]> +// CHECK: ^bb0(%[[VAL_7:.*]]: f16, %[[VAL_8:.*]]: f16): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f16 +// CHECK: tt.reduce.return %[[VAL_9]] : f16 +// CHECK: }) : (tensor<16x8x2x1x1x2xf16, #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>>) -> tensor<16x8x2x1x2xf16, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>}>> +// CHECK: %[[VAL_10:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<16x8x2x1x2xf16, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>}>> -> tensor<16x8x2x1x2xf16, #[[$ATTR_6]]> +// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x1x2xf16, #[[$ATTR_6]]> -> tensor<16x16x1x2xf16, #[[$ATTR_7]]> // CHECK: %[[VAL_12:.*]] = "tt.reduce"(%[[VAL_11]]) <{axis = 0 : i32}> ({ -// CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32): -// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32 -// CHECK: tt.reduce.return %[[VAL_15]] : f32 -// CHECK: }) : (tensor<16x16x1x2xf32, #[[$ATTR_7]]>) -> tensor<16x1x2xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_7]]}>> +// CHECK: ^bb0(%[[VAL_13:.*]]: f16, %[[VAL_14:.*]]: f16): +// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f16 +// CHECK: tt.reduce.return %[[VAL_15]] : f16 +// CHECK: }) : (tensor<16x16x1x2xf16, #[[$ATTR_7]]>) -> tensor<16x1x2xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_7]]}>> // CHECK: %[[VAL_16:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ -// CHECK: ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32): -// CHECK: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : f32 -// CHECK: tt.reduce.return %[[VAL_19]] : f32 -// CHECK: }) : (tensor<16x1x2xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_7]]}>>) -> tensor<16x2xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_7]]}>}>> -// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x2xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_7]]}>}>> -> tensor<32xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_8]]}>> -// CHECK: %[[VAL_21:.*]] = ttg.convert_layout %[[VAL_20]] : tensor<32xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_8]]}>> -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_9]]}>> -// CHECK: tt.return %[[VAL_21]] : tensor<32xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_9]]}>> +// CHECK: ^bb0(%[[VAL_17:.*]]: f16, %[[VAL_18:.*]]: f16): +// CHECK: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : f16 +// CHECK: tt.reduce.return %[[VAL_19]] : f16 +// CHECK: }) : (tensor<16x1x2xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_7]]}>>) -> tensor<16x2xf16, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_7]]}>}>> +// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x2xf16, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_7]]}>}>> -> tensor<32xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_8]]}>> +// CHECK: %[[VAL_21:.*]] = ttg.convert_layout %[[VAL_20]] : tensor<32xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_8]]}>> -> tensor<32xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_9]]}>> +// CHECK: tt.return %[[VAL_21]] : tensor<32xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_9]]}>> // CHECK: } - tt.func @test_single_twice(%arg0: tensor<32x16xf32, #mma>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> { + tt.func @test_single_twice(%arg0: tensor<32x16xf16, #mma>) -> tensor<32xf16, #ttg.slice<{dim = 1, parent = #mma}>> { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - tt.reduce.return %1 : f32 - }) : (tensor<32x16xf32, #mma>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> - tt.return %0 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> + ^bb0(%arg1: f16, %arg2: f16): + %1 = arith.addf %arg1, %arg2 : f16 + tt.reduce.return %1 : f16 + }) : (tensor<32x16xf16, #mma>) -> tensor<32xf16, #ttg.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<32xf16, #ttg.slice<{dim = 1, parent = #mma}>> } } @@ -116,41 +119,41 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.thr module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 16 : i32} { // CHECK-LABEL: tt.func @test_two_warps_red( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32xf32, #[[$ATTR_14]]>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_14]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<16x32xf32, #[[$ATTR_14]]> -> tensor<16x8x1x2x2x1x1xf32, #[[$ATTR_10]]> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32xf16, #[[$ATTR_14]]>) -> tensor<16xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_14]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<16x32xf16, #[[$ATTR_14]]> -> tensor<16x8x1x2x2x1x1xf16, #[[$ATTR_10]]> // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ -// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): -// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 -// CHECK: tt.reduce.return %[[VAL_5]] : f32 -// CHECK: }) : (tensor<16x8x1x2x2x1x1xf32, #[[$ATTR_10]]>) -> tensor<16x8x2x2x1x1xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>> +// CHECK: ^bb0(%[[VAL_3:.*]]: f16, %[[VAL_4:.*]]: f16): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f16 +// CHECK: tt.reduce.return %[[VAL_5]] : f16 +// CHECK: }) : (tensor<16x8x1x2x2x1x1xf16, #[[$ATTR_10]]>) -> tensor<16x8x2x2x1x1xf16, #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>> // CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ -// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): -// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 -// CHECK: tt.reduce.return %[[VAL_9]] : f32 -// CHECK: }) : (tensor<16x8x2x2x1x1xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>>) -> tensor<16x8x2x2x1xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>}>> -// CHECK: %[[VAL_10:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<16x8x2x2x1xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>}>> -> tensor<16x8x2x2x1xf32, #[[$ATTR_11]]> -// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x2x1xf32, #[[$ATTR_11]]> -> tensor<16x16x2x1xf32, #[[$ATTR_12]]> +// CHECK: ^bb0(%[[VAL_7:.*]]: f16, %[[VAL_8:.*]]: f16): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f16 +// CHECK: tt.reduce.return %[[VAL_9]] : f16 +// CHECK: }) : (tensor<16x8x2x2x1x1xf16, #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>>) -> tensor<16x8x2x2x1xf16, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>}>> +// CHECK: %[[VAL_10:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<16x8x2x2x1xf16, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>}>> -> tensor<16x8x2x2x1xf16, #[[$ATTR_11]]> +// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x2x1xf16, #[[$ATTR_11]]> -> tensor<16x16x2x1xf16, #[[$ATTR_12]]> // CHECK: %[[VAL_12:.*]] = "tt.reduce"(%[[VAL_11]]) <{axis = 0 : i32}> ({ -// CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32): -// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32 -// CHECK: tt.reduce.return %[[VAL_15]] : f32 -// CHECK: }) : (tensor<16x16x2x1xf32, #[[$ATTR_12]]>) -> tensor<16x2x1xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_12]]}>> +// CHECK: ^bb0(%[[VAL_13:.*]]: f16, %[[VAL_14:.*]]: f16): +// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f16 +// CHECK: tt.reduce.return %[[VAL_15]] : f16 +// CHECK: }) : (tensor<16x16x2x1xf16, #[[$ATTR_12]]>) -> tensor<16x2x1xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_12]]}>> // CHECK: %[[VAL_16:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ -// CHECK: ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32): -// CHECK: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : f32 -// CHECK: tt.reduce.return %[[VAL_19]] : f32 -// CHECK: }) : (tensor<16x2x1xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_12]]}>>) -> tensor<16x1xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_12]]}>}>> -// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x1xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_12]]}>}>> -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_13]]}>> -// CHECK: %[[VAL_21:.*]] = ttg.convert_layout %[[VAL_20]] : tensor<16xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_13]]}>> -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_14]]}>> -// CHECK: tt.return %[[VAL_21]] : tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_14]]}>> +// CHECK: ^bb0(%[[VAL_17:.*]]: f16, %[[VAL_18:.*]]: f16): +// CHECK: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : f16 +// CHECK: tt.reduce.return %[[VAL_19]] : f16 +// CHECK: }) : (tensor<16x2x1xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_12]]}>>) -> tensor<16x1xf16, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_12]]}>}>> +// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x1xf16, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_12]]}>}>> -> tensor<16xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_13]]}>> +// CHECK: %[[VAL_21:.*]] = ttg.convert_layout %[[VAL_20]] : tensor<16xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_13]]}>> -> tensor<16xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_14]]}>> +// CHECK: tt.return %[[VAL_21]] : tensor<16xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_14]]}>> // CHECK: } - tt.func @test_two_warps_red(%arg0: tensor<16x32xf32, #mma>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> { + tt.func @test_two_warps_red(%arg0: tensor<16x32xf16, #mma>) -> tensor<16xf16, #ttg.slice<{dim = 1, parent = #mma}>> { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - tt.reduce.return %1 : f32 - }) : (tensor<16x32xf32, #mma>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> - tt.return %0 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> + ^bb0(%arg1: f16, %arg2: f16): + %1 = arith.addf %arg1, %arg2 : f16 + tt.reduce.return %1 : f16 + }) : (tensor<16x32xf16, #mma>) -> tensor<16xf16, #ttg.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<16xf16, #ttg.slice<{dim = 1, parent = #mma}>> } } @@ -168,41 +171,41 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.thr module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32} { // CHECK-LABEL: tt.func @test_two_warps( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x32xf32, #[[$ATTR_19]]>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_19]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<32x32xf32, #[[$ATTR_19]]> -> tensor<16x8x1x2x2x1x2xf32, #[[$ATTR_15]]> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x32xf16, #[[$ATTR_19]]>) -> tensor<32xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_19]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<32x32xf16, #[[$ATTR_19]]> -> tensor<16x8x1x2x2x1x2xf16, #[[$ATTR_15]]> // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ -// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): -// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 -// CHECK: tt.reduce.return %[[VAL_5]] : f32 -// CHECK: }) : (tensor<16x8x1x2x2x1x2xf32, #[[$ATTR_15]]>) -> tensor<16x8x2x2x1x2xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>> +// CHECK: ^bb0(%[[VAL_3:.*]]: f16, %[[VAL_4:.*]]: f16): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f16 +// CHECK: tt.reduce.return %[[VAL_5]] : f16 +// CHECK: }) : (tensor<16x8x1x2x2x1x2xf16, #[[$ATTR_15]]>) -> tensor<16x8x2x2x1x2xf16, #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>> // CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ -// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): -// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 -// CHECK: tt.reduce.return %[[VAL_9]] : f32 -// CHECK: }) : (tensor<16x8x2x2x1x2xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>>) -> tensor<16x8x2x2x2xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>}>> -// CHECK: %[[VAL_10:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<16x8x2x2x2xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>}>> -> tensor<16x8x2x2x2xf32, #[[$ATTR_16]]> -// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x2x2xf32, #[[$ATTR_16]]> -> tensor<16x16x2x2xf32, #[[$ATTR_17]]> +// CHECK: ^bb0(%[[VAL_7:.*]]: f16, %[[VAL_8:.*]]: f16): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f16 +// CHECK: tt.reduce.return %[[VAL_9]] : f16 +// CHECK: }) : (tensor<16x8x2x2x1x2xf16, #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>>) -> tensor<16x8x2x2x2xf16, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>}>> +// CHECK: %[[VAL_10:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<16x8x2x2x2xf16, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>}>> -> tensor<16x8x2x2x2xf16, #[[$ATTR_16]]> +// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x2x2xf16, #[[$ATTR_16]]> -> tensor<16x16x2x2xf16, #[[$ATTR_17]]> // CHECK: %[[VAL_12:.*]] = "tt.reduce"(%[[VAL_11]]) <{axis = 0 : i32}> ({ -// CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32): -// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32 -// CHECK: tt.reduce.return %[[VAL_15]] : f32 -// CHECK: }) : (tensor<16x16x2x2xf32, #[[$ATTR_17]]>) -> tensor<16x2x2xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_17]]}>> +// CHECK: ^bb0(%[[VAL_13:.*]]: f16, %[[VAL_14:.*]]: f16): +// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f16 +// CHECK: tt.reduce.return %[[VAL_15]] : f16 +// CHECK: }) : (tensor<16x16x2x2xf16, #[[$ATTR_17]]>) -> tensor<16x2x2xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_17]]}>> // CHECK: %[[VAL_16:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ -// CHECK: ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32): -// CHECK: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : f32 -// CHECK: tt.reduce.return %[[VAL_19]] : f32 -// CHECK: }) : (tensor<16x2x2xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_17]]}>>) -> tensor<16x2xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_17]]}>}>> -// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x2xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_17]]}>}>> -> tensor<32xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_18]]}>> -// CHECK: %[[VAL_21:.*]] = ttg.convert_layout %[[VAL_20]] : tensor<32xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_18]]}>> -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_19]]}>> -// CHECK: tt.return %[[VAL_21]] : tensor<32xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_19]]}>> +// CHECK: ^bb0(%[[VAL_17:.*]]: f16, %[[VAL_18:.*]]: f16): +// CHECK: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : f16 +// CHECK: tt.reduce.return %[[VAL_19]] : f16 +// CHECK: }) : (tensor<16x2x2xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_17]]}>>) -> tensor<16x2xf16, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_17]]}>}>> +// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x2xf16, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_17]]}>}>> -> tensor<32xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_18]]}>> +// CHECK: %[[VAL_21:.*]] = ttg.convert_layout %[[VAL_20]] : tensor<32xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_18]]}>> -> tensor<32xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_19]]}>> +// CHECK: tt.return %[[VAL_21]] : tensor<32xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_19]]}>> // CHECK: } - tt.func @test_two_warps(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> { + tt.func @test_two_warps(%arg0: tensor<32x32xf16, #mma>) -> tensor<32xf16, #ttg.slice<{dim = 1, parent = #mma}>> { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - tt.reduce.return %1 : f32 - }) : (tensor<32x32xf32, #mma>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> - tt.return %0 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> + ^bb0(%arg1: f16, %arg2: f16): + %1 = arith.addf %arg1, %arg2 : f16 + tt.reduce.return %1 : f16 + }) : (tensor<32x32xf16, #mma>) -> tensor<32xf16, #ttg.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<32xf16, #ttg.slice<{dim = 1, parent = #mma}>> } } @@ -219,41 +222,41 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} { // CHECK-LABEL: tt.func @test( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<64x64xf32, #[[$ATTR_24]]>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_24]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<64x64xf32, #[[$ATTR_24]]> -> tensor<16x8x2x2x2x1x4xf32, #[[$ATTR_20]]> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<64x64xf16, #[[$ATTR_24]]>) -> tensor<64xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_24]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<64x64xf16, #[[$ATTR_24]]> -> tensor<16x8x2x2x2x1x4xf16, #[[$ATTR_20]]> // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ -// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): -// CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32 -// CHECK: tt.reduce.return %[[VAL_5]] : f32 -// CHECK: }) : (tensor<16x8x2x2x2x1x4xf32, #[[$ATTR_20]]>) -> tensor<16x8x2x2x1x4xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>> +// CHECK: ^bb0(%[[VAL_3:.*]]: f16, %[[VAL_4:.*]]: f16): +// CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f16 +// CHECK: tt.reduce.return %[[VAL_5]] : f16 +// CHECK: }) : (tensor<16x8x2x2x2x1x4xf16, #[[$ATTR_20]]>) -> tensor<16x8x2x2x1x4xf16, #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>> // CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ -// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): -// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f32 -// CHECK: tt.reduce.return %[[VAL_9]] : f32 -// CHECK: }) : (tensor<16x8x2x2x1x4xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>>) -> tensor<16x8x2x2x4xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>}>> -// CHECK: %[[VAL_10:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<16x8x2x2x4xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>}>> -> tensor<16x8x2x2x4xf32, #[[$ATTR_21]]> -// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x2x4xf32, #[[$ATTR_21]]> -> tensor<16x16x2x4xf32, #[[$ATTR_22]]> +// CHECK: ^bb0(%[[VAL_7:.*]]: f16, %[[VAL_8:.*]]: f16): +// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f16 +// CHECK: tt.reduce.return %[[VAL_9]] : f16 +// CHECK: }) : (tensor<16x8x2x2x1x4xf16, #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>>) -> tensor<16x8x2x2x4xf16, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>}>> +// CHECK: %[[VAL_10:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<16x8x2x2x4xf16, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>}>> -> tensor<16x8x2x2x4xf16, #[[$ATTR_21]]> +// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x2x4xf16, #[[$ATTR_21]]> -> tensor<16x16x2x4xf16, #[[$ATTR_22]]> // CHECK: %[[VAL_12:.*]] = "tt.reduce"(%[[VAL_11]]) <{axis = 0 : i32}> ({ -// CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32): -// CHECK: %[[VAL_15:.*]] = arith.maxnumf %[[VAL_13]], %[[VAL_14]] : f32 -// CHECK: tt.reduce.return %[[VAL_15]] : f32 -// CHECK: }) : (tensor<16x16x2x4xf32, #[[$ATTR_22]]>) -> tensor<16x2x4xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_22]]}>> +// CHECK: ^bb0(%[[VAL_13:.*]]: f16, %[[VAL_14:.*]]: f16): +// CHECK: %[[VAL_15:.*]] = arith.maxnumf %[[VAL_13]], %[[VAL_14]] : f16 +// CHECK: tt.reduce.return %[[VAL_15]] : f16 +// CHECK: }) : (tensor<16x16x2x4xf16, #[[$ATTR_22]]>) -> tensor<16x2x4xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_22]]}>> // CHECK: %[[VAL_16:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ -// CHECK: ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32): -// CHECK: %[[VAL_19:.*]] = arith.maxnumf %[[VAL_17]], %[[VAL_18]] : f32 -// CHECK: tt.reduce.return %[[VAL_19]] : f32 -// CHECK: }) : (tensor<16x2x4xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_22]]}>>) -> tensor<16x4xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_22]]}>}>> -// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x4xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_22]]}>}>> -> tensor<64xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_23]]}>> -// CHECK: %[[VAL_21:.*]] = ttg.convert_layout %[[VAL_20]] : tensor<64xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_23]]}>> -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_24]]}>> -// CHECK: tt.return %[[VAL_21]] : tensor<64xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_24]]}>> +// CHECK: ^bb0(%[[VAL_17:.*]]: f16, %[[VAL_18:.*]]: f16): +// CHECK: %[[VAL_19:.*]] = arith.maxnumf %[[VAL_17]], %[[VAL_18]] : f16 +// CHECK: tt.reduce.return %[[VAL_19]] : f16 +// CHECK: }) : (tensor<16x2x4xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_22]]}>>) -> tensor<16x4xf16, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_22]]}>}>> +// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x4xf16, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_22]]}>}>> -> tensor<64xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_23]]}>> +// CHECK: %[[VAL_21:.*]] = ttg.convert_layout %[[VAL_20]] : tensor<64xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_23]]}>> -> tensor<64xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_24]]}>> +// CHECK: tt.return %[[VAL_21]] : tensor<64xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_24]]}>> // CHECK: } - tt.func @test(%arg0: tensor<64x64xf32, #mma>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #mma}>> { + tt.func @test(%arg0: tensor<64x64xf16, #mma>) -> tensor<64xf16, #ttg.slice<{dim = 1, parent = #mma}>> { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.maxnumf %arg1, %arg2 : f32 - tt.reduce.return %1 : f32 - }) : (tensor<64x64xf32, #mma>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #mma}>> - tt.return %0 : tensor<64xf32, #ttg.slice<{dim = 1, parent = #mma}>> + ^bb0(%arg1: f16, %arg2: f16): + %1 = arith.maxnumf %arg1, %arg2 : f16 + tt.reduce.return %1 : f16 + }) : (tensor<64x64xf16, #mma>) -> tensor<64xf16, #ttg.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<64xf16, #ttg.slice<{dim = 1, parent = #mma}>> } } @@ -274,24 +277,24 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr // CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<128x64xf32, #[[$ATTR_29]]> -> tensor<16x8x2x4x2x1x4xf32, #[[$ATTR_25]]> // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): -// CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: %[[VAL_5:.*]] = arith.mulf %[[VAL_3]], %[[VAL_4]] : f32 // CHECK: tt.reduce.return %[[VAL_5]] : f32 // CHECK: }) : (tensor<16x8x2x4x2x1x4xf32, #[[$ATTR_25]]>) -> tensor<16x8x4x2x1x4xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_25]]}>> // CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ // CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): -// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: %[[VAL_9:.*]] = arith.mulf %[[VAL_7]], %[[VAL_8]] : f32 // CHECK: tt.reduce.return %[[VAL_9]] : f32 // CHECK: }) : (tensor<16x8x4x2x1x4xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_25]]}>>) -> tensor<16x8x4x2x4xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_25]]}>}>> // CHECK: %[[VAL_10:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<16x8x4x2x4xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_25]]}>}>> -> tensor<16x8x4x2x4xf32, #[[$ATTR_26]]> // CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x4x2x4xf32, #[[$ATTR_26]]> -> tensor<16x32x2x4xf32, #[[$ATTR_27]]> // CHECK: %[[VAL_12:.*]] = "tt.reduce"(%[[VAL_11]]) <{axis = 0 : i32}> ({ // CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32): -// CHECK: %[[VAL_15:.*]] = arith.maxnumf %[[VAL_13]], %[[VAL_14]] : f32 +// CHECK: %[[VAL_15:.*]] = arith.mulf %[[VAL_13]], %[[VAL_14]] : f32 // CHECK: tt.reduce.return %[[VAL_15]] : f32 // CHECK: }) : (tensor<16x32x2x4xf32, #[[$ATTR_27]]>) -> tensor<32x2x4xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_27]]}>> // CHECK: %[[VAL_16:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ // CHECK: ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32): -// CHECK: %[[VAL_19:.*]] = arith.maxnumf %[[VAL_17]], %[[VAL_18]] : f32 +// CHECK: %[[VAL_19:.*]] = arith.mulf %[[VAL_17]], %[[VAL_18]] : f32 // CHECK: tt.reduce.return %[[VAL_19]] : f32 // CHECK: }) : (tensor<32x2x4xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_27]]}>>) -> tensor<32x4xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_27]]}>}>> // CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<32x4xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_27]]}>}>> -> tensor<128xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_28]]}>> @@ -301,7 +304,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr tt.func @test(%arg0: tensor<128x64xf32, #mma>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.maxnumf %arg1, %arg2 : f32 + %1 = arith.mulf %arg1, %arg2 : f32 tt.reduce.return %1 : f32 }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> tt.return %0 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> diff --git a/test/TritonIntelGPU/tritonintelgpu.mlir b/test/TritonIntelGPU/tritonintelgpu.mlir index 3d486780b2..155bf8e19b 100644 --- a/test/TritonIntelGPU/tritonintelgpu.mlir +++ b/test/TritonIntelGPU/tritonintelgpu.mlir @@ -58,3 +58,17 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr tt.return %res : tensor<16x16xf16> } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32} { + tt.func @triton_intel_gpu.simd_reduce(%arg0: tensor<16x16xf32, #blocked>) -> tensor<16xf32, #blocked1> { + // CHECK-LABEL: @triton_intel_gpu.simd_reduce + // CHECK: triton_intel_gpu.simd_reduce add %{{.*}} axis = 0 : tensor<16x16xf32, #blocked> -> tensor<16xf32, #blocked1> + %0 = triton_intel_gpu.simd_reduce add %arg0 axis = 0 : tensor<16x16xf32, #blocked> -> tensor<16xf32, #blocked1> + tt.return %0 : tensor<16xf32, #blocked1> + } +} diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUOps.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUOps.td index 85d33ed5c5..811c33767f 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUOps.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUOps.td @@ -12,6 +12,7 @@ include "triton/Dialect/Triton/IR/TritonTypes.td" include "triton/Dialect/Triton/IR/TritonAttrDefs.td" include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td" +include "intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td" include "intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td" include "intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUDialect.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -202,4 +203,66 @@ def TTIG_SubGroupTransposeOp let hasVerifier = 1; } +def TTIG_SIMDReduceOp : TTIG_Op<"simd_reduce", [Pure, SameOperandsAndResultElementType]> { + let summary = "SIMD reduction."; + let description = [{ + The `triton_intel_gpu.simd_reduce` operation performs a SIMD reduction. + Contrary to `tt.reduce`, when performing a warp reduction, the result is + non-uniform. + + The reduction axis must be in such a way that only a warp reduction is + performed, i.e., `sizePerThread[axis]`, `warpsPerCTA[axis]` and + `CTAsPerCGA[axis]` must be 1; and `shape[axis]` and `threadsPerWarp[axis]` + must be equal to the sub-group size. + + The output type must be compatible with the performed reduction. However, + ensuring this is up to the user. As a rule of thumb, the total number of + elements in the output tensor must be sub-group size smaller than in the + original one. Users should bear in mind a tensor like: + + ``` + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + ``` + + would be reduced to: + + ``` + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + ``` + + Example: + ```mlir + #blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}> + #blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}> + triton_intel_gpu.simd_reduce add %0 axis = 0 : tensor<16x16xf32, #blocked> -> tensor<16xf32, #blocked1> + // # 3D reduction: + #blocked = #ttg.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 2], order = [0, 1, 2]}> + #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [1, 2], order = [0, 1]}> + triton_intel_gpu.simd_reduce add %0 axis = 0 : tensor<16x16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> + ``` + }]; + let arguments = (ins TT_Tensor:$src, + TritonGEN_ReduceKindAttr: $op, + I32Attr:$axis); + let results = (outs TT_Tensor:$res); + let assemblyFormat = [{ + $op $src `axis` `=` $axis attr-dict `:` type($src) `->` type($res) + }]; +} + #endif diff --git a/third_party/intel/include/TritonIntelGPUToLLVM/VISAASMFormat.h b/third_party/intel/include/TritonIntelGPUToLLVM/VISAASMFormat.h new file mode 100644 index 0000000000..1f84514e50 --- /dev/null +++ b/third_party/intel/include/TritonIntelGPUToLLVM/VISAASMFormat.h @@ -0,0 +1,330 @@ +#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_VISA_ASM_FORMAT_H_ +#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_VISA_ASM_FORMAT_H_ + +#include "mlir/IR/Value.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include +#include + +namespace mlir { +class ConversionPatternRewriter; +class Location; + +namespace triton { +using llvm::StringRef; + +struct VISAInstr; +struct VISAInstrCommon; +struct VISAInstrExecution; + +// VISABuilder helps to manage a VISA asm program. +// +// A helper for building an ASM program, the objective of VISABuilder is to give +// a thin encapsulation and make the ASM code for MLIR LLVM Dialect clearer. +// Currently, several factors are introduced to reduce the need for mixing +// string and C++ if-else code. +// +// Usage: +// To build: +// ``` +// @$3 asm("@%3 add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k), "b"(p)); +// ``` +// ```cpp +// VISABuilder builder; +// auto& add = builder.create<>(); +// add.predicate(pVal).o("lo").o("u32"); // add any suffix +// // predicate here binds %0 to pVal, pVal is a mlir::Value +// +// auto* iOpr = builder.newOperand(iVal, "r"); // %1 bind to iVal +// auto* jOpr = builder.newOperand(jVal, "r"); // %2 bind to jVal +// auto* kOpr = builder.newOperand(kVal, "r"); // %3 bind to kVal +// add(iOpr, jOpr, kOpr).predicate(predVal); // set operands and predicate +// +// To get the asm code: +// builder.dump() +// ``` +// To get all the `mlir::Value` used in the VISA code, +// ```cpp +// builder.getAllMlirArgs() // get {pVal, iVal, jVal, kVal} +// ``` +// To get the string containing all the constraints separated by a comma: +// ```cpp +// builder.getConstraints() // get "=r,r,k" +// ``` +// VISABuilder can build a VISA asm with multiple instructions, sample code: +// ```cpp +// VISABuilder builder; +// auto& mov = builder.create("mov"); +// auto& cp = builder.create("cp"); +// mov(...); +// cp(...); +// ``` +// This will get a VISA code with two instructions. +// +// Similar to a C function, a declared VISAInstr instance can be launched +// multiple times with different operands, e.g. +// ```cpp +// auto& mov = builder.create("mov"); +// mov(... some operands ...); +// mov(... some different operands ...); +// ``` +// Finally, we will get a VISA code with two mov instructions. +// +// There are several derived instruction type for typical instructions, for +// example, the PtxIOInstr for ld and st instructions. +struct VISABuilder { + struct Operand { + std::string constraint; + Value value; + int idx{-1}; + llvm::SmallVector list; + std::function repr; + + // for list + Operand() = default; + Operand(const Operation &) = delete; + Operand(Value value, StringRef constraint) + : constraint(constraint), value(value) {} + + bool isList() const { return !value && constraint.empty(); } + + Operand *listAppend(Operand *arg) { + list.push_back(arg); + return this; + } + + Operand *listGet(size_t nth) const { + assert(nth < list.size()); + return list[nth]; + } + + std::string dump() const; + }; + + template + INSTR *create(Args &&...args) { + instrs.emplace_back(std::make_unique(this, args...)); + return static_cast(instrs.back().get()); + } + + // Create a list of operands. + Operand *newListOperand() { return newOperand(); } + + Operand *newListOperand(ArrayRef> items) { + auto *list = newOperand(); + for (auto &item : items) { + list->listAppend(newOperand(item.first, item.second)); + } + return list; + } + + Operand *newListOperand(unsigned count, mlir::Value val, + const std::string &constraint) { + auto *list = newOperand(); + for (unsigned i = 0; i < count; ++i) { + list->listAppend(newOperand(val, constraint)); + } + return list; + } + + Operand *newListOperand(unsigned count, const std::string &constraint) { + auto *list = newOperand(); + for (unsigned i = 0; i < count; ++i) { + list->listAppend(newOperand(constraint)); + } + return list; + } + + // Create a new operand. It will not add to operand list. + // @value: the MLIR value bind to this operand. + // @constraint: ASM operand constraint, .e.g. "=r" + // @formatter: extra format to represent this operand in ASM code, default is + // "%{0}".format(operand.idx). + Operand *newOperand(mlir::Value value, StringRef constraint, + std::function formatter = nullptr); + + // Create a new operand which is written to, that is, the constraint starts + // with "=", e.g. "=r". + // If the operand will be used in predicated execution, + // users may want to initialize it before use. + // Otherwise if the register is only used in the true branch or the false + // branch but not both, the register is undefined and ptxas can perform + // aggressive optimizations that may lead to incorrect results. + Operand *newOperand(StringRef constraint, bool init = false); + + // Create a new operand that is tied to a previous operand. In this case the + // asm would be permitted to write to an input register. Instead of providing + // constraint code for this operand, the constraint code of the tied operand + // is used. + Operand *newOperand(unsigned operandIndex); + + // Create a constant integer operand. + Operand *newConstantOperand(int64_t v); + // Create a constant operand with explicit code specified. + Operand *newConstantOperand(const std::string &v); + + Operand *newAddrOperand(mlir::Value addr, StringRef constraint, int off = 0); + + llvm::SmallVector getAllArgs() const; + + llvm::SmallVector getAllMLIRArgs() const; + + std::string getConstraints() const; + + std::string dump() const; + + mlir::Value launch(OpBuilder &rewriter, Location loc, Type resTy, + bool hasSideEffect = true, bool isAlignStack = false, + ArrayRef attrs = {}) const; + +private: + Operand *newOperand() { + argArchive.emplace_back(std::make_unique()); + return argArchive.back().get(); + } + + void initOperand(Operand *opr); + + // Make the operands in argArchive follow the provided \param order. + void reorderArgArchive(ArrayRef order) { + assert(order.size() == argArchive.size()); + // The order in argArchive is unnecessary when onlyAttachMLIRArgs=false, but + // it does necessary when onlyAttachMLIRArgs is true for the $0, $1... are + // determined by VISA code snippet passed from external. + sort(argArchive.begin(), argArchive.end(), + [&](std::unique_ptr &a, std::unique_ptr &b) { + auto ida = std::find(order.begin(), order.end(), a.get()); + auto idb = std::find(order.begin(), order.end(), b.get()); + assert(ida != order.end()); + assert(idb != order.end()); + return ida < idb; + }); + } + + friend struct VISAInstr; + friend struct VISAInstrCommon; + +protected: + llvm::SmallVector, 6> argArchive; + llvm::SmallVector, 2> instrs; + llvm::SmallVector, 4> executions; + int oprCounter{}; +}; + +// VISA instruction common interface. +// Put the generic logic for all the instructions here. +struct VISAInstrCommon { + explicit VISAInstrCommon(VISABuilder *builder) : builder(builder) {} + + using Operand = VISABuilder::Operand; + + // Set operands of this instruction. + VISAInstrExecution &operator()(llvm::ArrayRef oprs, + bool onlyAttachMLIRArgs = false); + +protected: + // "Call" the instruction with operands. + // \param oprs The operands of this instruction. + // \param onlyAttachMLIRArgs Indicate that it simply attach the MLIR Arguments + // to the inline Asm without generating the operand ids(such as $0, $1) in + // VISA code. + VISAInstrExecution &call(llvm::ArrayRef oprs, + bool onlyAttachMLIRArgs = false); + + VISABuilder *builder{}; + llvm::SmallVector instrParts; + + friend struct VISAInstrExecution; +}; + +template struct VISAInstrBase : public VISAInstrCommon { + using Operand = VISABuilder::Operand; + + explicit VISAInstrBase(VISABuilder *builder, const std::string &name) + : VISAInstrCommon(builder) { + o(name); + } + + // Append a suffix to the instruction. + // e.g. VISAInstr("add").o("s32") get a add.s32. + // A predicate is used to tell whether to apply the suffix, so that no if-else + // code needed. e.g. `VISAInstr("add").o("s32", isS32).o("u32", !isS32);` will + // get a `add.s32` if isS32 is true. + ConcreteT &o(const std::string &suffix, bool predicate = true) { + if (predicate) + instrParts.push_back(suffix); + return *static_cast(this); + } +}; + +struct VISAInstr : public VISAInstrBase { + using VISAInstrBase::VISAInstrBase; + + // Append a ".global" to the instruction. + VISAInstr &global(); + + // Append a ".shared" to the instruction. + VISAInstr &shared(); + + // Append a ".v[0-9]+" to the instruction + VISAInstr &v(int vecWidth, bool predicate = true); + + // Append a".b[0-9]+" to the instruction + VISAInstr &b(int width); +}; + +// Record the operands and context for "launching" a PtxInstr. +struct VISAInstrExecution { + using Operand = VISABuilder::Operand; + + llvm::SmallVector argsInOrder; + + VISAInstrExecution() = default; + explicit VISAInstrExecution(VISAInstrCommon *instr, + llvm::ArrayRef oprs, + bool onlyAttachMLIRArgs) + : argsInOrder(oprs.begin(), oprs.end()), instr(instr), + onlyAttachMLIRArgs(onlyAttachMLIRArgs) {} + + // Prefix a predicate to the instruction. + VISAInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") { + pred = instr->builder->newOperand(value, constraint); + return *this; + } + + // Prefix a !predicate to the instruction. + VISAInstrExecution &predicateNot(mlir::Value value, StringRef constraint) { + pred = instr->builder->newOperand(value, constraint); + pred->repr = [](int idx) { return "@!$" + std::to_string(idx); }; + return *this; + } + + std::string dump() const; + + SmallVector getArgList() const; + + VISAInstrCommon *instr{}; + Operand *pred{}; + bool onlyAttachMLIRArgs{}; +}; + +/// ====== Some instruction wrappers ====== +// We add the wrappers to make the usage more intuitive by avoiding mixing the +// VISA code with some trivial C++ code. + +struct VISACpAsyncLoadInstr : VISAInstrBase { + explicit VISACpAsyncLoadInstr(VISABuilder *builder, + triton::CacheModifier modifier) + : VISAInstrBase(builder, "cp.async") { + o(triton::stringifyCacheModifier(modifier).str()); + o("shared"); + o("global"); + } +}; + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt index 4e86cbd2f2..89659234a0 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt @@ -26,6 +26,7 @@ add_triton_library(TritonIntelGPUToLLVM TypeConverter.cpp Utility.cpp ViewOpToLLVM.cpp + VISAASMFormat.cpp DEPENDS TritonIntelGPUConversionPassIncGen diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp index 4b7efe346f..ca24c6e80f 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp @@ -1,9 +1,13 @@ +#include "intel/include/TritonIntelGPUToLLVM/VISAASMFormat.h" + #include "PatternTritonGPUOpToLLVM.h" #include "ReduceScanCommon.h" #include "Utility.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include +#include "llvm/Support/FormatVariadic.h" + using namespace mlir; using namespace mlir::triton; @@ -490,10 +494,122 @@ struct ReduceOpConversion rewriter.replaceOp(op, results); } }; + +class SIMDReduceOpConversion final + : public ConvertTritonGPUOpToLLVMPattern< + mlir::triton::gpu::intel::SIMDReduceOp> { +public: + using OpTy = mlir::triton::gpu::intel::SIMDReduceOp; + + SIMDReduceOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + static SmallVector> splitInBatches(ArrayRef srcValues, + size_t batchSize) { + SmallVector> batches; + for (; !srcValues.empty(); srcValues = srcValues.drop_front(batchSize)) + batches.push_back(srcValues.take_front(batchSize)); + return batches; + } + + LogicalResult + matchAndRewrite(OpTy simdReduce, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + Location loc = simdReduce->getLoc(); + + SmallVector srcValues = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + + // TODO: Implement relevant types. + if (!srcValues.front().getType().isF32()) + return failure(); + + // TODO: Implement mapping for all reduction kinds. + if (simdReduce.getOp() != TritonGEN::ReduceKind::ADD && + simdReduce.getOp() != TritonGEN::ReduceKind::MAX) + return failure(); + + auto mod = simdReduce->getParentOfType(); + constexpr unsigned supportedSIMDSize = 16; + unsigned subGroupSize = + triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + // Only SIMD 16 is supported. + if (subGroupSize != supportedSIMDSize) + return failure(); + assert(srcValues.size() % subGroupSize == 0 && + "Expecting splittable input"); + SmallVector results(srcValues.size() / subGroupSize); + constexpr StringLiteral asmTemplate = R"({ +.decl temp_result v_type=G type=f num_elts=128 align=wordx32 +{0} (M1_NM, 16) temp_result(0, 0)<1> $1(0, 0)<16;8,1> $1(0, 8)<16;8,1> +{0} (M1_NM, 16) temp_result(1, 0)<1> $1(2, 0)<16;8,1> $1(2, 8)<16;8,1> +{0} (M1_NM, 16) temp_result(2, 0)<1> $1(4, 0)<16;8,1> $1(4, 8)<16;8,1> +{0} (M1_NM, 16) temp_result(3, 0)<1> $1(6, 0)<16;8,1> $1(6, 8)<16;8,1> +{0} (M1_NM, 16) temp_result(4, 0)<1> $1(8, 0)<16;8,1> $1(8, 8)<16;8,1> +{0} (M1_NM, 16) temp_result(5, 0)<1> $1(10, 0)<16;8,1> $1(10, 8)<16;8,1> +{0} (M1_NM, 16) temp_result(6, 0)<1> $1(12, 0)<16;8,1> $1(12, 8)<16;8,1> +{0} (M1_NM, 16) temp_result(7, 0)<1> $1(14, 0)<16;8,1> $1(14, 8)<16;8,1> +{0} (M1_NM, 16) temp_result(0, 0)<1> temp_result(0, 0)<8;4,1> temp_result(0, 4)<8;4,1> +{0} (M1_NM, 16) temp_result(1, 0)<1> temp_result(2, 0)<8;4,1> temp_result(2, 4)<8;4,1> +{0} (M1_NM, 16) temp_result(2, 0)<1> temp_result(4, 0)<8;4,1> temp_result(4, 4)<8;4,1> +{0} (M1_NM, 16) temp_result(3, 0)<1> temp_result(6, 0)<8;4,1> temp_result(6, 4)<8;4,1> +{0} (M1_NM, 16) temp_result(0, 0)<1> temp_result(0, 0)<4;2,1> temp_result(0, 2)<4;2,1> +{0} (M1_NM, 16) temp_result(1, 0)<1> temp_result(2, 0)<4;2,1> temp_result(2, 2)<4;2,1> +{0} (M1_NM, 16) $0(0, 0)<1> temp_result(0, 0)<2;1,0> temp_result(0, 1)<2;1,0> +})"; + std::string batchedHorizontalReduce = + llvm::formatv(asmTemplate.data(), getASMOperation(simdReduce.getOp())) + .str(); + constexpr unsigned vecWidth = 16; + VectorType reduceTy = vec_ty(srcValues.front().getType(), vecWidth); + llvm::transform( + splitInBatches(srcValues, subGroupSize), std::begin(results), + [&](ArrayRef inputs) { + auto inputRange = llvm::enumerate(inputs); + Value batchedReduceVal = std::accumulate( + std::begin(inputRange), std::end(inputRange), + rewriter.create(loc, reduceTy).getRes(), + [reduceTy, loc, &rewriter](Value acc, auto entry) -> Value { + auto [index, src] = entry; + return insert_element(reduceTy, acc, src, i32_val(index)); + }); + VISABuilder vISABuilder; + VISAInstr &bReduceOp = *vISABuilder.create<>(batchedHorizontalReduce); + VISABuilder::Operand *res = vISABuilder.newOperand("=rw"); + VISABuilder::Operand *in = + vISABuilder.newOperand(batchedReduceVal, "rw"); + bReduceOp({res, in}, /*onlyAttachMLIRArgs=*/true); + Type resultTy = reduceTy.getElementType(); + return vISABuilder.launch(rewriter, loc, resultTy, true); + }); + Value packedRes = packLLElements(loc, getTypeConverter(), results, rewriter, + simdReduce.getRes().getType()); + rewriter.replaceOp(simdReduce, packedRes); + return success(); + } + +private: + static StringRef getASMOperation(TritonGEN::ReduceKind op) { + switch (op) { + case TritonGEN::ReduceKind::ADD: + return "add"; + case TritonGEN::ReduceKind::MAX: + return "max"; + default: + llvm_unreachable("Unhandled kind"); + } + } + + const TargetInfoBase &targetInfo; +}; } // namespace void mlir::triton::intel::populateReduceOpToLLVMPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, const TargetInfoBase &targetInfo, PatternBenefit benefit) { - patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, + targetInfo, benefit); } diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/VISAASMFormat.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/VISAASMFormat.cpp new file mode 100644 index 0000000000..52237cd6de --- /dev/null +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/VISAASMFormat.cpp @@ -0,0 +1,237 @@ +#include "intel/include/TritonIntelGPUToLLVM/VISAASMFormat.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Conversion/TritonGPUToLLVM/AsmFormat.h" +#include "llvm/Support/raw_ostream.h" +// TODO(Superjomn): unify to llvm::raw_string_ostream +#include + +namespace mlir { +namespace triton { + +VISAInstr::Operand * +VISABuilder::newOperand(mlir::Value value, StringRef constraint, + std::function formatter) { + argArchive.emplace_back(std::make_unique(value, constraint)); + auto *opr = argArchive.back().get(); + opr->repr = formatter; + opr->idx = oprCounter++; + return opr; +} + +void VISABuilder::initOperand(Operand *opr) { + auto numBits = 0; + // Derive numBits from the constraint. + if (opr->constraint[1] == 'c' || opr->constraint[1] == 'h') + numBits = 16; + else if (opr->constraint[1] == 'r') + numBits = 32; + else if (opr->constraint[1] == 'l') + numBits = 64; + else + llvm_unreachable(("Unknown constraint: " + opr->constraint).c_str()); + // If numBits is less than 16, we use 16 as default because VISA does not + // support 8-bit mov. + numBits = numBits < 16 ? 16 : numBits; + auto *zero = newConstantOperand(0); + auto &init = create<>("mov")->o("u" + std::to_string(numBits)); + init(opr, zero); +} + +VISABuilder::Operand *VISABuilder::newOperand(StringRef constraint, bool init) { + // Constraint should be something like "=rw" + assert(constraint[0] == '='); + auto *opr = newOperand(); + opr->idx = oprCounter++; + opr->constraint = constraint; + if (init) { + initOperand(opr); + } + return opr; +} + +VISABuilder::Operand *VISABuilder::newOperand(unsigned operandIndex) { + assert(operandIndex < oprCounter && "operand index out of range"); + auto *opr = newOperand(); + opr->idx = oprCounter++; + opr->constraint = std::to_string(operandIndex); + return opr; +} + +VISABuilder::Operand *VISABuilder::newConstantOperand(const std::string &v) { + argArchive.emplace_back(std::make_unique()); + argArchive.back()->repr = [v](int idx) { return v; }; + return argArchive.back().get(); +} + +VISABuilder::Operand *VISABuilder::newConstantOperand(int64_t v) { + std::stringstream ss; + ss << "0x" << std::hex << v; + return newConstantOperand(ss.str()); +} + +std::string VISABuilder::getConstraints() const { + auto args = getAllArgs(); + llvm::SmallVector argReprs; + for (auto arg : args) + argReprs.push_back(arg->constraint); + return strJoin(argReprs, ","); +} + +llvm::SmallVector VISABuilder::getAllMLIRArgs() const { + llvm::SmallVector res; + for (auto &arg : argArchive) { + if (!arg->isList() && arg->value) + res.push_back(arg->value); + } + return res; +} + +SmallVector VISABuilder::getAllArgs() const { + llvm::SmallVector res; + for (auto &x : argArchive) + if (!x->isList()) + res.push_back(x.get()); + return res; +} + +mlir::Value VISABuilder::launch(OpBuilder &rewriter, Location loc, Type resTy, + bool hasSideEffect, bool isAlignStack, + ArrayRef attrs) const { + auto *ctx = rewriter.getContext(); + auto inlineAsm = rewriter.create( + loc, resTy, getAllMLIRArgs(), // operands + dump(), // asm_string + getConstraints(), // constraints + hasSideEffect, // has_side_effects + isAlignStack, // is_align_stack + LLVM::AsmDialectAttr::get(ctx, + LLVM::AsmDialect::AD_ATT), // asm_dialect + ArrayAttr::get(ctx, attrs) // operand_attrs + ); + + return inlineAsm.getRes(); +} + +std::string VISAInstr::Operand::dump() const { + if (repr) + return repr(idx); + if (!isList()) + return "$" + std::to_string(idx); + + llvm::SmallVector oprs; + for (auto *opr : list) + oprs.push_back(opr->dump()); + return "{ " + strJoin(oprs, ", ") + " }"; +} + +VISAInstr::Operand *VISABuilder::newAddrOperand(mlir::Value addr, + StringRef constraint, int off) { + auto *opr = newOperand(addr, constraint); + opr->repr = [off](int idx) -> std::string { + std::stringstream ss; + ss << "[ $" << idx << " + " << off << " ]"; + return ss.str(); + }; + + return opr; +} + +std::string VISABuilder::dump() const { + llvm::SmallVector lines; + for (auto &exec : executions) { + lines.push_back(exec->dump()); + } + + return strJoin(lines, "\n\t"); +} + +VISAInstrExecution &VISAInstrCommon::call(ArrayRef oprs, + bool onlyAttachMLIRArgs) { + if (onlyAttachMLIRArgs) { + // Nearly impossible to make the $0,$1 in two VISA code snippets to point to + // the same MLIR values in onlyAttachMLIRArgs mode. + assert(builder->executions.empty() && + "builder can only hold a single execution when onlyAttachMIIRArgs " + "is true."); + builder->reorderArgArchive(oprs); + } + + builder->executions.emplace_back( + std::make_unique(this, oprs, onlyAttachMLIRArgs)); + + return *builder->executions.back(); +} + +VISAInstrExecution &VISAInstrCommon::operator()(ArrayRef oprs, + bool onlyAttachMLIRArgs) { + return call(oprs, onlyAttachMLIRArgs); +} + +std::string VISAInstrExecution::dump() const { + std::string osStr; + llvm::raw_string_ostream os(osStr); + + if (pred) { + if (!pred->repr) + os << "@" << pred->dump() << " "; + else + os << pred->repr(pred->idx) << " "; + } + + std::string instrRepr = strJoin(instr->instrParts, "."); + if (onlyAttachMLIRArgs) { + os << instrRepr; + os.flush(); + return osStr; + } + + llvm::SmallVector argReprs; + for (auto *arg : argsInOrder) { + argReprs.push_back(arg->dump()); + } + + std::string argsRepr = strJoin(argReprs, ", "); + + os << instrRepr << " " << argsRepr << ";"; + os.flush(); + return osStr; +} + +SmallVector +VISAInstrExecution::getArgList() const { + SmallVector args; + for (auto *arg : argsInOrder) { + if (arg->isList()) + args.insert(args.end(), arg->list.begin(), arg->list.end()); + else + args.push_back(arg); + } + return args; +} + +VISAInstr &VISAInstr::global() { + o("global"); + return *this; +} + +VISAInstr &VISAInstr::shared() { + o("shared"); + return *this; +} + +VISAInstr &VISAInstr::v(int vecWidth, bool predicate) { + if (vecWidth > 1) { + o("v" + std::to_string(vecWidth), predicate); + } + return *this; +} + +VISAInstr &VISAInstr::b(int width) { + o("b" + std::to_string(width)); + return *this; +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp index 610d82c453..abc3c0a20b 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp @@ -15,6 +15,8 @@ #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" + #define DEBUG_TYPE "tritonintelgpu-optimize-reduction-locality" namespace mlir::triton::gpu::intel { @@ -183,6 +185,7 @@ namespace { /// - Dimensions 1 and 3 refer to the original dimension 0 /// - Dimensions 0, and 2 refer to the original dimension 1 /// - Order is preserved + /// /// And on with step 3, after reducing on dimensions 0 and 1 (2 - 1 as 0 is /// squashed), we'd get: /// ``` @@ -200,6 +203,81 @@ namespace { /// | t3 | /// ``` /// And untranspose with a layout conversion to the original layout. + /// + /// In case the element type is `f32` and the reduction operation is a simple + /// add or max operation, we can use an optimized SIMD transposed reduction in + /// registers so no SLM transpose is needed. This would replace the two steps + /// above, but leading to a result with the same encoding. + /// + /// In order to do so, the SIMD reduction result is: + /// - Shape: [executionSize, + /// repeatCount * repCluster[0] / executionSize, + /// warpsPerCTA[1], + /// warpsPerCTA[0]] + /// - Encoding: `#ttg.blocked<{ + /// sizePerThread = [1, repeatCount * repCluster[0] / executionSize, 1, 1], + /// threadsPerWarp = [executionSize, 1, 1, 1], + /// warpsPerCTA = [1, 1, warpsPerCTA[1], warpsPerCTA[0]], + /// order = [0, 1, 2, 3]}>`. + /// ``` + /// warpsPerCTA[3] + /// <------------------------------------> + /// threadsPerWarp[0] + /// <------------------> + /// ^ t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... ^ + /// | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... | + /// sizePerThread[1] | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... | warpsPerCTA[2] + /// | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... | + /// ``` + /// + /// Note how the SIMD transposition leads to a different order compared to the + /// original approach above (and an additional dimension). Before going on to + /// the next step, we need to get to the same layout as above via a layout + /// conversion: + /// - Shape: [executionSize, + /// repeatCount * repCluster[0] / executionSize, + /// warpsPerCTA[1], + /// warpsPerCTA[0]] + /// - Encoding: `#ttg.blocked<{ + /// sizePerThread = [repeatCount * repCluster[0] / executionSize, 1, 1, 1], + /// threadsPerWarp = [executionSize^2 / (repeatCount * repCluster[0]), + /// repeatCount * repCluster[0] / executionSize, + /// 1, 1] + /// warpsPerCTA = [1, 1, warpsPerCTA[1], warpsPerCTA[0]], + /// order = [0, 1, 2, 3]}>`. + /// ``` + /// warpsPerCTA[3] + /// <------------------------------------> + /// sizePerThread[0] + /// <------------------> + /// ^ t0 t0 t0 t0 ... t0 tn1 tn1 tn1 ... tn1 ^ + /// | t1 t1 t1 t1 ... t1 tn2 tn2 tn2 ... tn2 | + /// threadsPerWarp[0,1] | t2 t2 t2 t2 ... t2 tn3 tn3 tn3 ... tn3 | warpsPerCTA[2] + /// | t3 t3 t3 t3 ... t3 tn4 tn4 tn4 ... tn4 | + /// ``` + /// + /// Followed by a reshape to get rid of one of the dimensions: + /// + /// - Shape: [repeatCount * repCluster[0], + /// warpsPerCTA[1], + /// warpsPerCTA[0]] + /// - Encoding: `#ttg.blocked<{ + /// sizePerThread = [repeatCount * repCluster[0] / executionSize, 1, 1], + /// threadsPerWarp = [executionSize, 1, 1] + /// warpsPerCTA = [1, warpsPerCTA[1], warpsPerCTA[0]], + /// order = [0, 1, 2]}>`. + /// ``` + /// warpsPerCTA[3] + /// <------------------------------------> + /// sizePerThread[0] + /// <------------------> + /// ^ t0 t0 t0 t0 ... t0 tn1 tn1 tn1 ... tn1 ^ + /// | t1 t1 t1 t1 ... t1 tn2 tn2 tn2 ... tn2 | + /// threadsPerWarp[0] | t2 t2 t2 t2 ... t2 tn3 tn3 tn3 ... tn3 | warpsPerCTA[2] + /// | t3 t3 t3 t3 ... t3 tn4 tn4 tn4 ... tn4 | + /// ``` + /// After this, the cross the warp reduction is performed and the same steps + /// to go back to the expected tensor type are performed as above. // clang-format on struct DpasOperandPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -209,6 +287,7 @@ struct DpasOperandPattern final : OpRewritePattern { static constexpr int preferredReductionAxis = 1; // Intermediate reductions + static constexpr int simdReductionAxis = 0; static constexpr int finalElementwiseReductionAxis = 0; static constexpr int finalWarpsReductionAxis = 1; static constexpr int innerElementwiseReductionAxis = 2; @@ -263,20 +342,39 @@ struct DpasOperandPattern final : OpRewritePattern { LLVM_DEBUG(llvm::dbgs() << "Performed initial elementwise reductions: " << operand << "\n"); - operand = convertLayoutForFinalReduction(op, rewriter, operand, encoding); + // Some combinations of operation + type can be represented as special SIMD + // reductions. + if (supportsSIMDReduction(op)) { + operand = performSIMDReduction(op, rewriter, operand, encoding); - LLVM_DEBUG(llvm::dbgs() - << "Converted layout for final reduction: " << operand << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "SIMD reduction performed: " << operand << "\n"); - operand = reshapeForFinalReduction(op, rewriter, operand, encoding); + operand = convertLayoutPostSIMDReduction(op, rewriter, operand, encoding); - LLVM_DEBUG(llvm::dbgs() - << "Reshaped for final reduction: " << operand << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Converted layout for final reduction: " + << operand << "\n"); - operand = performFinalElementwiseReduction(op, rewriter, operand); + operand = reshapePostSIMDReduction(op, rewriter, operand, encoding); - LLVM_DEBUG(llvm::dbgs() - << "Final elementwise reduction performed: " << operand << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "Reshaped for final reduction: " << operand << "\n"); + } else { + operand = convertLayoutForFinalReduction(op, rewriter, operand, encoding); + + LLVM_DEBUG(llvm::dbgs() << "Converted layout for final reduction: " + << operand << "\n"); + + operand = reshapeForFinalReduction(op, rewriter, operand, encoding); + + LLVM_DEBUG(llvm::dbgs() + << "Reshaped for final reduction: " << operand << "\n"); + + operand = performFinalElementwiseReduction(op, rewriter, operand); + + LLVM_DEBUG(llvm::dbgs() << "Final elementwise reduction performed: " + << operand << "\n"); + } operand = performFinalAcrossWarpsReduction(op, rewriter, operand); @@ -377,6 +475,157 @@ struct DpasOperandPattern final : OpRewritePattern { outerElementwiseReductionAxis); } + static bool supportsSIMDReduction(ReduceOp op) { + constexpr unsigned simd16 = 16; + auto mod = op->getParentOfType(); + if (triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod) != simd16) + return false; + + // TODO: Enable more element types when enabled in the lowering. + if (!cast(op.getSrcs().front().getType()) + .getElementType() + .isF32()) + return false; + + Region &combineOp = op.getCombineOp(); + if (combineOp.getBlocks().size() > 1) + return false; + Block &block = *combineOp.begin(); + Operation *yield = block.getTerminator(); + Operation *reduceOp = yield->getOperand(0).getDefiningOp(); + if (!reduceOp || reduceOp->getNumOperands() != 2 || + reduceOp->getNumResults() != 1) + return false; + if (reduceOp->getOperand(0) != block.getArgument(0) || + reduceOp->getOperand(1) != block.getArgument(1)) + return false; + + // TODO: Enable more operations when more are enabled in the lowering. + return TypeSwitch(reduceOp) + .Case([](auto) { return true; }) + .Default(false); + } + + static TritonGEN::ReduceKind getSIMDReductionKind(ReduceOp op) { + Region &combineOp = op.getCombineOp(); + assert(combineOp.getBlocks().size() <= 1 && "Unexpected number of blocks"); + Block &block = *combineOp.begin(); + Operation *yield = block.getTerminator(); + Operation *reduceOp = yield->getOperand(0).getDefiningOp(); + assert(reduceOp && reduceOp->getNumOperands() == 2 && + reduceOp->getNumResults() == 1 && + "Expecting sub-group reduction-like operation"); + assert(reduceOp->getOperand(0) == block.getArgument(0) && + reduceOp->getOperand(1) == block.getArgument(1) && + "Expecting sub-group reduction-like operation"); + + // TODO: Enable more operations when more are enabled in the lowering. + return TypeSwitch(reduceOp) + .Case([](arith::AddFOp) { return TritonGEN::ReduceKind::ADD; }) + .Case([](arith::MaxNumFOp) { return TritonGEN::ReduceKind::MAX; }); + } + + Value performSIMDReduction(ReduceOp op, PatternRewriter &rewriter, Value val, + DpasEncodingAttr dpasEncoding) const { + auto oldType = cast(val.getType()); + + constexpr size_t rank = 4; + std::array shape{ + dpasEncoding.getExecutionSize(), + dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] / + dpasEncoding.getExecutionSize(), + dpasEncoding.getWarpsPerCTA()[1], dpasEncoding.getWarpsPerCTA()[0]}; + std::array sizePerThread{ + 1, + dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] / + dpasEncoding.getExecutionSize(), + 1, 1}; + std::array threadsPerWarp{dpasEncoding.getExecutionSize(), + 1, 1, 1}; + std::array warpsPerCTA{1, 1, + dpasEncoding.getWarpsPerCTA()[1], + dpasEncoding.getWarpsPerCTA()[0]}; + constexpr std::array order{0, 1, 2, 3}; + CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank); + + auto encoding = rewriter.getAttr( + sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout); + + RankedTensorType::Builder type(oldType); + type.setShape(shape); + type.setEncoding(encoding); + + TritonGEN::ReduceKind redOp = getSIMDReductionKind(op); + return rewriter.create(op.getLoc(), + static_cast(type), + val, redOp, simdReductionAxis); + } + + Value convertLayoutPostSIMDReduction(ReduceOp op, PatternRewriter &rewriter, + Value val, + DpasEncodingAttr dpasEncoding) const { + auto oldType = cast(val.getType()); + + constexpr size_t rank = 4; + std::array sizePerThread{ + dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] / + dpasEncoding.getExecutionSize(), + 1, 1, 1}; + std::array threadsPerWarp{ + dpasEncoding.getExecutionSize() * dpasEncoding.getExecutionSize() / + (dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0]), + dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] / + dpasEncoding.getExecutionSize(), + 1, 1}; + std::array warpsPerCTA{1, 1, + dpasEncoding.getWarpsPerCTA()[1], + dpasEncoding.getWarpsPerCTA()[0]}; + constexpr std::array order{0, 1, 2, 3}; + CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank); + + auto encoding = rewriter.getAttr( + sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout); + + RankedTensorType::Builder type(oldType); + type.setEncoding(encoding); + + return rewriter.create( + op.getLoc(), static_cast(type), val); + } + + Value reshapePostSIMDReduction(ReduceOp op, PatternRewriter &rewriter, + Value val, + DpasEncodingAttr dpasEncoding) const { + auto oldType = cast(val.getType()); + + constexpr size_t rank = 3; + std::array shape{ + dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0], + dpasEncoding.getWarpsPerCTA()[1], dpasEncoding.getWarpsPerCTA()[0]}; + std::array sizePerThread{ + dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] / + dpasEncoding.getExecutionSize(), + 1, 1}; + std::array threadsPerWarp{dpasEncoding.getExecutionSize(), + 1, 1}; + std::array warpsPerCTA{1, dpasEncoding.getWarpsPerCTA()[1], + dpasEncoding.getWarpsPerCTA()[0]}; + constexpr std::array order{0, 1, 2}; + CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank); + + auto encoding = rewriter.getAttr( + sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout); + + RankedTensorType::Builder type(oldType); + type.setShape(shape); + type.setEncoding(encoding); + + return rewriter.create(op.getLoc(), + static_cast(type), val, + /*allow_reorder=*/true, + /*efficient_layout=*/true); + } + Value convertLayoutForFinalReduction(ReduceOp op, PatternRewriter &rewriter, Value val, DpasEncodingAttr dpasEncoding) const { @@ -411,7 +660,6 @@ struct DpasOperandPattern final : OpRewritePattern { Value val, DpasEncodingAttr dpasEncoding) const { auto oldType = cast(val.getType()); - ArrayRef oldShape = oldType.getShape(); constexpr size_t rank = 4; std::array shape{