diff --git a/test/Conversion/intel/simd-reduce.mlir b/test/Conversion/intel/simd-reduce.mlir new file mode 100644 index 0000000000..a4d328e0a0 --- /dev/null +++ b/test/Conversion/intel/simd-reduce.mlir @@ -0,0 +1,24 @@ +// RUN: triton-opt %s -split-input-file --intel-allocate-shared-memory --convert-triton-intel-gpu-to-llvm | FileCheck %s + +// Basic 16x16 SIMD reduction. + +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32} { +// CHECK-LABEL: llvm.func spir_kernelcc @test_single( +// CHECK-SAME: %[[VAL_0:.*]]: !llvm.struct +// CHECK: %[[VAL_17:.*]] = llvm.mlir.poison : vector<16xf32> +// COM: Check we insert all tensor elements in a vector: +// CHECK-COUNT-16: llvm.insertelement +// CHECK: %[[VAL_50:.*]] = llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "{\0A.decl temp_result v_type=G type=f num_elts=128 align=wordx32\0Aadd (M1_NM, 16) temp_result(0, 0)<1> $1(0, 0)<16;8,1> $1(0, 8)<16;8,1>\0Aadd (M1_NM, 16) temp_result(1, 0)<1> $1(2, 0)<16;8,1> $1(2, 8)<16;8,1>\0Aadd (M1_NM, 16) temp_result(2, 0)<1> $1(4, 0)<16;8,1> $1(4, 8)<16;8,1>\0Aadd (M1_NM, 16) temp_result(3, 0)<1> $1(6, 0)<16;8,1> $1(6, 8)<16;8,1>\0Aadd (M1_NM, 16) temp_result(4, 0)<1> $1(8, 0)<16;8,1> $1(8, 8)<16;8,1>\0Aadd (M1_NM, 16) temp_result(5, 0)<1> $1(10, 0)<16;8,1> $1(10, 8)<16;8,1>\0Aadd (M1_NM, 16) temp_result(6, 0)<1> $1(12, 0)<16;8,1> $1(12, 8)<16;8,1>\0Aadd (M1_NM, 16) temp_result(7, 0)<1> $1(14, 0)<16;8,1> $1(14, 8)<16;8,1>\0Aadd (M1_NM, 16) temp_result(0, 0)<1> temp_result(0, 0)<8;4,1> temp_result(0, 4)<8;4,1>\0Aadd (M1_NM, 16) temp_result(1, 0)<1> temp_result(2, 0)<8;4,1> temp_result(2, 4)<8;4,1>\0Aadd (M1_NM, 16) temp_result(2, 0)<1> temp_result(4, 0)<8;4,1> temp_result(4, 4)<8;4,1>\0Aadd (M1_NM, 16) temp_result(3, 0)<1> temp_result(6, 0)<8;4,1> temp_result(6, 4)<8;4,1>\0Aadd (M1_NM, 16) temp_result(0, 0)<1> temp_result(0, 0)<4;2,1> temp_result(0, 2)<4;2,1>\0Aadd (M1_NM, 16) temp_result(1, 0)<1> temp_result(2, 0)<4;2,1> temp_result(2, 2)<4;2,1>\0Aadd (M1_NM, 16) $0(0, 0)<1> temp_result(0, 0)<2;1,0> temp_result(0, 1)<2;1,0>\0A}", "=rw,rw" %{{.*}} : (vector<16xf32>) -> f32 +// COM: Check we obtain a single result, i.e., the SIMD reduction minimizes register usage. +// CHECK: %[[VAL_51:.*]] = llvm.mlir.undef : !llvm.struct<(f32)> +// CHECK: %[[VAL_52:.*]] = llvm.insertvalue %[[VAL_50]], %[[VAL_51]][0] : !llvm.struct<(f32)> +// CHECK: llvm.return %[[VAL_52]] : !llvm.struct<(f32)> +// CHECK: } + tt.func @test_single(%arg0: tensor<16x16xf32, #blocked>) -> tensor<16xf32, #blocked1> { + %0 = triton_intel_gpu.simd_reduce add %arg0 axis = 0 : tensor<16x16xf32, #blocked> -> tensor<16xf32, #blocked1> + tt.return %0 : tensor<16xf32, #blocked1> + } +} diff --git a/test/TritonIntelGPU/optimize-reduction-simd.mlir b/test/TritonIntelGPU/optimize-reduction-simd.mlir new file mode 100644 index 0000000000..e7748dea71 --- /dev/null +++ b/test/TritonIntelGPU/optimize-reduction-simd.mlir @@ -0,0 +1,289 @@ +// RUN: triton-opt %s --split-input-file -tritonintelgpu-optimize-reduction-locality | FileCheck %s + +// This test serves as a counterpart to optimize-reduction.mlir for cases in +// which the SIMD reduction is supported. + +// Test reduction in a single warp (16x16->16). 
+ +// CHECK: #[[$ATTR_0:.+]] = #ttg.blocked<{sizePerThread = [1, 8, 1, 2, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1, 1, 1], order = [0, 1, 2, 3, 4, 5, 6]}> +// CHECK: #[[$ATTR_1:.+]] = #ttg.blocked<{sizePerThread = [1, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1], order = [0, 1, 2, 3]}> +// CHECK: #[[$ATTR_2:.+]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 1], order = [0, 1, 2]}> +// CHECK: #[[$ATTR_3:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 1], order = [0, 1]}> +// CHECK: #[[$ATTR_4:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 1], repCluster = [2, 1]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32} { + +// CHECK-LABEL: tt.func @test_single( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x16xf32, #[[$ATTR_4]]>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_4]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<16x16xf32, #[[$ATTR_4]]> -> tensor<16x8x1x2x1x1x1xf32, #[[$ATTR_0]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<16x8x1x2x1x1x1xf32, #[[$ATTR_0]]>) -> tensor<16x8x2x1x1x1xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<16x8x2x1x1x1xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>>) -> tensor<16x8x2x1x1xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>}>> +// CHECK: %[[VAL_10:.*]] = triton_intel_gpu.simd_reduce add %[[VAL_6]] axis = 0 : tensor<16x8x2x1x1xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>}>> -> tensor<16x1x1x1xf32, #[[$ATTR_1]]> +// CHECK: %[[VAL_11:.*]] = ttg.convert_layout %[[VAL_10]] : tensor<16x1x1x1xf32, #[[$ATTR_1]]> -> tensor<16x1x1x1xf32, #[[$ATTR_1]]> +// CHECK: %[[VAL_12:.*]] = tt.reshape %[[VAL_11]] allow_reorder efficient_layout : tensor<16x1x1x1xf32, #[[$ATTR_1]]> -> tensor<16x1x1xf32, #[[$ATTR_2]]> +// CHECK: %[[VAL_13:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ +// CHECK: ^bb0(%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32): +// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_14]], %[[VAL_15]] : f32 +// CHECK: tt.reduce.return %[[VAL_16]] : f32 +// CHECK: }) : (tensor<16x1x1xf32, #[[$ATTR_2]]>) -> tensor<16x1xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_2]]}>> +// CHECK: %[[VAL_17:.*]] = tt.reshape %[[VAL_13]] allow_reorder efficient_layout : tensor<16x1xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_2]]}>> -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_3]]}>> +// CHECK: %[[VAL_18:.*]] = ttg.convert_layout %[[VAL_17]] : tensor<16xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_3]]}>> -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_4]]}>> +// 
CHECK: tt.return %[[VAL_18]] : tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_4]]}>> +// CHECK: } + tt.func @test_single(%arg0: tensor<16x16xf32, #mma>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<16x16xf32, #mma>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> + } +} + +// ----- + +// Test reduction in two warps across the non-reduction dimension (32x16->32). + +// CHECK: #[[$ATTR_5:.+]] = #ttg.blocked<{sizePerThread = [1, 8, 1, 2, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1, 1, 2], order = [0, 1, 2, 3, 4, 5, 6]}> +// CHECK: #[[$ATTR_6:.+]] = #ttg.blocked<{sizePerThread = [1, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 1, 2], order = [0, 1, 2, 3]}> +// CHECK: #[[$ATTR_7:.+]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 2], order = [0, 1, 2]}> +// CHECK: #[[$ATTR_8:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 2], order = [0, 1]}> +// CHECK: #[[$ATTR_9:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 1], repCluster = [2, 1]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 16 : i32} { + +// CHECK-LABEL: tt.func @test_single_twice( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32, #[[$ATTR_9]]>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_9]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<32x16xf32, #[[$ATTR_9]]> -> tensor<16x8x1x2x1x1x2xf32, #[[$ATTR_5]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<16x8x1x2x1x1x2xf32, #[[$ATTR_5]]>) -> tensor<16x8x2x1x1x2xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<16x8x2x1x1x2xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>>) -> tensor<16x8x2x1x2xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>}>> +// CHECK: %[[VAL_10:.*]] = triton_intel_gpu.simd_reduce add %[[VAL_6]] axis = 0 : tensor<16x8x2x1x2xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>}>> -> tensor<16x1x1x2xf32, #[[$ATTR_6]]> +// CHECK: %[[VAL_11:.*]] = ttg.convert_layout %[[VAL_10]] : tensor<16x1x1x2xf32, #[[$ATTR_6]]> -> tensor<16x1x1x2xf32, #[[$ATTR_6]]> +// CHECK: %[[VAL_12:.*]] = tt.reshape %[[VAL_11]] allow_reorder efficient_layout : tensor<16x1x1x2xf32, #[[$ATTR_6]]> -> tensor<16x1x2xf32, #[[$ATTR_7]]> +// CHECK: %[[VAL_13:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ +// CHECK: ^bb0(%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32): +// 
CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_14]], %[[VAL_15]] : f32 +// CHECK: tt.reduce.return %[[VAL_16]] : f32 +// CHECK: }) : (tensor<16x1x2xf32, #[[$ATTR_7]]>) -> tensor<16x2xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_7]]}>> +// CHECK: %[[VAL_17:.*]] = tt.reshape %[[VAL_13]] allow_reorder efficient_layout : tensor<16x2xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_7]]}>> -> tensor<32xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_8]]}>> +// CHECK: %[[VAL_18:.*]] = ttg.convert_layout %[[VAL_17]] : tensor<32xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_8]]}>> -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_9]]}>> +// CHECK: tt.return %[[VAL_18]] : tensor<32xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_9]]}>> +// CHECK: } + tt.func @test_single_twice(%arg0: tensor<32x16xf32, #mma>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<32x16xf32, #mma>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> + } +} + +// ----- + +// Test reduction in two warps across the reduction dimension (16x32->16). + +// CHECK: #[[$ATTR_10:.+]] = #ttg.blocked<{sizePerThread = [1, 8, 1, 2, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 2, 1, 1], order = [0, 1, 2, 3, 4, 5, 6]}> +// CHECK: #[[$ATTR_11:.+]] = #ttg.blocked<{sizePerThread = [1, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 2, 1], order = [0, 1, 2, 3]}> +// CHECK: #[[$ATTR_12:.+]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 2, 1], order = [0, 1, 2]}> +// CHECK: #[[$ATTR_13:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 1], order = [0, 1]}> +// CHECK: #[[$ATTR_14:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [1, 2], repCluster = [2, 1]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 16 : i32} { + +// CHECK-LABEL: tt.func @test_two_warps_red( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32xf32, #[[$ATTR_14]]>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_14]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<16x32xf32, #[[$ATTR_14]]> -> tensor<16x8x1x2x2x1x1xf32, #[[$ATTR_10]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<16x8x1x2x2x1x1xf32, #[[$ATTR_10]]>) -> tensor<16x8x2x2x1x1xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<16x8x2x2x1x1xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>>) -> tensor<16x8x2x2x1xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>}>> +// 
CHECK: %[[VAL_10:.*]] = triton_intel_gpu.simd_reduce add %[[VAL_6]] axis = 0 : tensor<16x8x2x2x1xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>}>> -> tensor<16x1x2x1xf32, #[[$ATTR_11]]> +// CHECK: %[[VAL_11:.*]] = ttg.convert_layout %[[VAL_10]] : tensor<16x1x2x1xf32, #[[$ATTR_11]]> -> tensor<16x1x2x1xf32, #[[$ATTR_11]]> +// CHECK: %[[VAL_12:.*]] = tt.reshape %[[VAL_11]] allow_reorder efficient_layout : tensor<16x1x2x1xf32, #[[$ATTR_11]]> -> tensor<16x2x1xf32, #[[$ATTR_12]]> +// CHECK: %[[VAL_13:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ +// CHECK: ^bb0(%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32): +// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_14]], %[[VAL_15]] : f32 +// CHECK: tt.reduce.return %[[VAL_16]] : f32 +// CHECK: }) : (tensor<16x2x1xf32, #[[$ATTR_12]]>) -> tensor<16x1xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_12]]}>> +// CHECK: %[[VAL_17:.*]] = tt.reshape %[[VAL_13]] allow_reorder efficient_layout : tensor<16x1xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_12]]}>> -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_13]]}>> +// CHECK: %[[VAL_18:.*]] = ttg.convert_layout %[[VAL_17]] : tensor<16xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_13]]}>> -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_14]]}>> +// CHECK: tt.return %[[VAL_18]] : tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_14]]}>> +// CHECK: } + tt.func @test_two_warps_red(%arg0: tensor<16x32xf32, #mma>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<16x32xf32, #mma>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> + } +} + +// ----- + +// Test reduction in two warps across both dimensions (32x32->32). 
+ +// CHECK: #[[$ATTR_15:.+]] = #ttg.blocked<{sizePerThread = [1, 8, 1, 2, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 2, 1, 2], order = [0, 1, 2, 3, 4, 5, 6]}> +// CHECK: #[[$ATTR_16:.+]] = #ttg.blocked<{sizePerThread = [1, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 2, 2], order = [0, 1, 2, 3]}> +// CHECK: #[[$ATTR_17:.+]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 2, 2], order = [0, 1, 2]}> +// CHECK: #[[$ATTR_18:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [0, 1]}> +// CHECK: #[[$ATTR_19:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 1], A = [16, 8], B = [8, 16], C = [16, 16]}> +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [2, 2], repCluster = [2, 1]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32} { + +// CHECK-LABEL: tt.func @test_two_warps( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x32xf32, #[[$ATTR_19]]>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_19]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<32x32xf32, #[[$ATTR_19]]> -> tensor<16x8x1x2x2x1x2xf32, #[[$ATTR_15]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<16x8x1x2x2x1x2xf32, #[[$ATTR_15]]>) -> tensor<16x8x2x2x1x2xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<16x8x2x2x1x2xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>>) -> tensor<16x8x2x2x2xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>}>> +// CHECK: %[[VAL_10:.*]] = triton_intel_gpu.simd_reduce add %[[VAL_6]] axis = 0 : tensor<16x8x2x2x2xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>}>> -> tensor<16x1x2x2xf32, #[[$ATTR_16]]> +// CHECK: %[[VAL_11:.*]] = ttg.convert_layout %[[VAL_10]] : tensor<16x1x2x2xf32, #[[$ATTR_16]]> -> tensor<16x1x2x2xf32, #[[$ATTR_16]]> +// CHECK: %[[VAL_12:.*]] = tt.reshape %[[VAL_11]] allow_reorder efficient_layout : tensor<16x1x2x2xf32, #[[$ATTR_16]]> -> tensor<16x2x2xf32, #[[$ATTR_17]]> +// CHECK: %[[VAL_13:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ +// CHECK: ^bb0(%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32): +// CHECK: %[[VAL_16:.*]] = arith.addf %[[VAL_14]], %[[VAL_15]] : f32 +// CHECK: tt.reduce.return %[[VAL_16]] : f32 +// CHECK: }) : (tensor<16x2x2xf32, #[[$ATTR_17]]>) -> tensor<16x2xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_17]]}>> +// CHECK: %[[VAL_17:.*]] = tt.reshape %[[VAL_13]] allow_reorder efficient_layout : tensor<16x2xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_17]]}>> -> tensor<32xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_18]]}>> +// CHECK: %[[VAL_18:.*]] = ttg.convert_layout %[[VAL_17]] : tensor<32xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_18]]}>> -> tensor<32xf32, #ttg.slice<{dim = 1, 
parent = #[[$ATTR_19]]}>> +// CHECK: tt.return %[[VAL_18]] : tensor<32xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_19]]}>> +// CHECK: } + tt.func @test_two_warps(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.addf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<32x32xf32, #mma>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> + } +} + +// ----- + +// Test reduction across 2 warps in the reduction dimension and 4 in the non-reduction dimension. + +// CHECK: #[[$ATTR_20:.+]] = #ttg.blocked<{sizePerThread = [1, 8, 2, 2, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 2, 1, 4], order = [0, 1, 2, 3, 4, 5, 6]}> +// CHECK: #[[$ATTR_21:.+]] = #ttg.blocked<{sizePerThread = [1, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 2, 4], order = [0, 1, 2, 3]}> +// CHECK: #[[$ATTR_22:.+]] = #ttg.blocked<{sizePerThread = [1, 1, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 2, 4], order = [0, 1, 2]}> +// CHECK: #[[$ATTR_23:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 4], order = [0, 1]}> +// CHECK: #[[$ATTR_24:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 2], A = [16, 8], B = [8, 32], C = [16, 32]}> +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [2, 2]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} { +// CHECK-LABEL: tt.func @test( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<64x64xf32, #[[$ATTR_24]]>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_24]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<64x64xf32, #[[$ATTR_24]]> -> tensor<16x8x2x2x2x1x4xf32, #[[$ATTR_20]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<16x8x2x2x2x1x4xf32, #[[$ATTR_20]]>) -> tensor<16x8x2x2x1x4xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<16x8x2x2x1x4xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>>) -> tensor<16x8x2x2x4xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>}>> +// CHECK: %[[VAL_10:.*]] = triton_intel_gpu.simd_reduce max %[[VAL_6]] axis = 0 : tensor<16x8x2x2x4xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>}>> -> tensor<16x1x2x4xf32, #[[$ATTR_21]]> +// CHECK: %[[VAL_11:.*]] = ttg.convert_layout %[[VAL_10]] : tensor<16x1x2x4xf32, #[[$ATTR_21]]> -> tensor<16x1x2x4xf32, #[[$ATTR_21]]> +// CHECK: %[[VAL_12:.*]] = tt.reshape %[[VAL_11]] allow_reorder efficient_layout : tensor<16x1x2x4xf32, #[[$ATTR_21]]> -> tensor<16x2x4xf32, #[[$ATTR_22]]> +// CHECK: %[[VAL_13:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ 
+// CHECK: ^bb0(%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32): +// CHECK: %[[VAL_16:.*]] = arith.maxnumf %[[VAL_14]], %[[VAL_15]] : f32 +// CHECK: tt.reduce.return %[[VAL_16]] : f32 +// CHECK: }) : (tensor<16x2x4xf32, #[[$ATTR_22]]>) -> tensor<16x4xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_22]]}>> +// CHECK: %[[VAL_17:.*]] = tt.reshape %[[VAL_13]] allow_reorder efficient_layout : tensor<16x4xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_22]]}>> -> tensor<64xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_23]]}>> +// CHECK: %[[VAL_18:.*]] = ttg.convert_layout %[[VAL_17]] : tensor<64xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_23]]}>> -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_24]]}>> +// CHECK: tt.return %[[VAL_18]] : tensor<64xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_24]]}>> +// CHECK: } + tt.func @test(%arg0: tensor<64x64xf32, #mma>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.maxnumf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<64x64xf32, #mma>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<64xf32, #ttg.slice<{dim = 1, parent = #mma}>> + } +} + +// ----- + +// Test reduction across 2 warps in the reduction dimension and 4 in the non-reduction dimension with repCluster[0] = 4. + +// CHECK: #[[$ATTR_25:.+]] = #ttg.blocked<{sizePerThread = [1, 8, 2, 4, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 2, 1, 4], order = [0, 1, 2, 3, 4, 5, 6]}> +// CHECK: #[[$ATTR_26:.+]] = #ttg.blocked<{sizePerThread = [1, 2, 1, 1], threadsPerWarp = [16, 1, 1, 1], warpsPerCTA = [1, 1, 2, 4], order = [0, 1, 2, 3]}> +// CHECK: #[[$ATTR_27:.+]] = #ttg.blocked<{sizePerThread = [2, 1, 1, 1], threadsPerWarp = [8, 2, 1, 1], warpsPerCTA = [1, 1, 2, 4], order = [0, 1, 2, 3]}> +// CHECK: #[[$ATTR_28:.+]] = #ttg.blocked<{sizePerThread = [2, 1, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 2, 4], order = [0, 1, 2]}> +// CHECK: #[[$ATTR_29:.+]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [1, 16], warpsPerCTA = [2, 4], order = [0, 1]}> +// CHECK: #[[$ATTR_30:.+]] = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [4, 2], A = [32, 8], B = [8, 32], C = [32, 32]}> +#mma = #triton_intel_gpu.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 1, threadsPerWarp = 16, warpsPerCTA = [4, 2], repCluster = [4, 2]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} { +// CHECK-LABEL: tt.func @test( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<128x64xf32, #[[$ATTR_30]]>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_30]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<128x64xf32, #[[$ATTR_30]]> -> tensor<16x8x2x4x2x1x4xf32, #[[$ATTR_25]]> +// CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ +// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): +// CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: tt.reduce.return %[[VAL_5]] : f32 +// CHECK: }) : (tensor<16x8x2x4x2x1x4xf32, #[[$ATTR_25]]>) -> tensor<16x8x4x2x1x4xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_25]]}>> +// CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ +// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): +// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], 
%[[VAL_8]] : f32 +// CHECK: tt.reduce.return %[[VAL_9]] : f32 +// CHECK: }) : (tensor<16x8x4x2x1x4xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_25]]}>>) -> tensor<16x8x4x2x4xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_25]]}>}>> +// CHECK: %[[VAL_10:.*]] = triton_intel_gpu.simd_reduce max %[[VAL_6]] axis = 0 : tensor<16x8x4x2x4xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_25]]}>}>> -> tensor<16x2x2x4xf32, #[[$ATTR_26]]> +// CHECK: %[[VAL_11:.*]] = ttg.convert_layout %[[VAL_10]] : tensor<16x2x2x4xf32, #[[$ATTR_26]]> -> tensor<16x2x2x4xf32, #[[$ATTR_27]]> +// CHECK: %[[VAL_12:.*]] = tt.reshape %[[VAL_11]] allow_reorder efficient_layout : tensor<16x2x2x4xf32, #[[$ATTR_27]]> -> tensor<32x2x4xf32, #[[$ATTR_28]]> +// CHECK: %[[VAL_13:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ +// CHECK: ^bb0(%[[VAL_14:.*]]: f32, %[[VAL_15:.*]]: f32): +// CHECK: %[[VAL_16:.*]] = arith.maxnumf %[[VAL_14]], %[[VAL_15]] : f32 +// CHECK: tt.reduce.return %[[VAL_16]] : f32 +// CHECK: }) : (tensor<32x2x4xf32, #[[$ATTR_28]]>) -> tensor<32x4xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_28]]}>> +// CHECK: %[[VAL_17:.*]] = tt.reshape %[[VAL_13]] allow_reorder efficient_layout : tensor<32x4xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_28]]}>> -> tensor<128xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_29]]}>> +// CHECK: %[[VAL_18:.*]] = ttg.convert_layout %[[VAL_17]] : tensor<128xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_29]]}>> -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_30]]}>> +// CHECK: tt.return %[[VAL_18]] : tensor<128xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_30]]}>> +// CHECK: } + tt.func @test(%arg0: tensor<128x64xf32, #mma>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> { + %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ + ^bb0(%arg1: f32, %arg2: f32): + %1 = arith.maxnumf %arg1, %arg2 : f32 + tt.reduce.return %1 : f32 + }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> + } +} diff --git a/test/TritonIntelGPU/optimize-reduction.mlir b/test/TritonIntelGPU/optimize-reduction.mlir index a7bd409467..5847ba65a9 100644 --- a/test/TritonIntelGPU/optimize-reduction.mlir +++ b/test/TritonIntelGPU/optimize-reduction.mlir @@ -1,5 +1,8 @@ // RUN: triton-opt %s --split-input-file -tritonintelgpu-optimize-reduction-locality | FileCheck %s +// This test serves as a counterpart to optimize-reduction-simd.mlir for cases +// in which the SIMD reduction is not supported. + // Test reduction in a single warp (16x16->16). 
// CHECK: #[[$ATTR_0:.+]] = #ttg.blocked<{sizePerThread = [1, 8, 1, 2, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1, 1, 1, 1], warpsPerCTA = [1, 1, 1, 1, 1, 1, 1], order = [0, 1, 2, 3, 4, 5, 6]}> @@ -12,41 +15,41 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32} { // CHECK-LABEL: tt.func @test_single( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x16xf32, #[[$ATTR_4]]>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_4]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<16x16xf32, #[[$ATTR_4]]> -> tensor<16x8x1x2x1x1x1xf32, #[[$ATTR_0]]> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x16xf16, #[[$ATTR_4]]>) -> tensor<16xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_4]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<16x16xf16, #[[$ATTR_4]]> -> tensor<16x8x1x2x1x1x1xf16, #[[$ATTR_0]]> // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ -// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): -// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 -// CHECK: tt.reduce.return %[[VAL_5]] : f32 -// CHECK: }) : (tensor<16x8x1x2x1x1x1xf32, #[[$ATTR_0]]>) -> tensor<16x8x2x1x1x1xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>> +// CHECK: ^bb0(%[[VAL_3:.*]]: f16, %[[VAL_4:.*]]: f16): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f16 +// CHECK: tt.reduce.return %[[VAL_5]] : f16 +// CHECK: }) : (tensor<16x8x1x2x1x1x1xf16, #[[$ATTR_0]]>) -> tensor<16x8x2x1x1x1xf16, #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>> // CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ -// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): -// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 -// CHECK: tt.reduce.return %[[VAL_9]] : f32 -// CHECK: }) : (tensor<16x8x2x1x1x1xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>>) -> tensor<16x8x2x1x1xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>}>> -// CHECK: %[[VAL_10:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<16x8x2x1x1xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>}>> -> tensor<16x8x2x1x1xf32, #[[$ATTR_1]]> -// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x1x1xf32, #[[$ATTR_1]]> -> tensor<16x16x1x1xf32, #[[$ATTR_2]]> +// CHECK: ^bb0(%[[VAL_7:.*]]: f16, %[[VAL_8:.*]]: f16): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f16 +// CHECK: tt.reduce.return %[[VAL_9]] : f16 +// CHECK: }) : (tensor<16x8x2x1x1x1xf16, #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>>) -> tensor<16x8x2x1x1xf16, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>}>> +// CHECK: %[[VAL_10:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<16x8x2x1x1xf16, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_0]]}>}>> -> tensor<16x8x2x1x1xf16, #[[$ATTR_1]]> +// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x1x1xf16, #[[$ATTR_1]]> -> tensor<16x16x1x1xf16, #[[$ATTR_2]]> // CHECK: %[[VAL_12:.*]] = "tt.reduce"(%[[VAL_11]]) <{axis = 0 : i32}> ({ -// CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32): -// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32 -// CHECK: tt.reduce.return %[[VAL_15]] : f32 -// CHECK: }) : (tensor<16x16x1x1xf32, #[[$ATTR_2]]>) -> tensor<16x1x1xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_2]]}>> +// CHECK: ^bb0(%[[VAL_13:.*]]: f16, %[[VAL_14:.*]]: 
f16): +// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f16 +// CHECK: tt.reduce.return %[[VAL_15]] : f16 +// CHECK: }) : (tensor<16x16x1x1xf16, #[[$ATTR_2]]>) -> tensor<16x1x1xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_2]]}>> // CHECK: %[[VAL_16:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ -// CHECK: ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32): -// CHECK: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : f32 -// CHECK: tt.reduce.return %[[VAL_19]] : f32 -// CHECK: }) : (tensor<16x1x1xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_2]]}>>) -> tensor<16x1xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_2]]}>}>> -// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x1xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_2]]}>}>> -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_3]]}>> -// CHECK: %[[VAL_21:.*]] = ttg.convert_layout %[[VAL_20]] : tensor<16xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_3]]}>> -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_4]]}>> -// CHECK: tt.return %[[VAL_21]] : tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_4]]}>> +// CHECK: ^bb0(%[[VAL_17:.*]]: f16, %[[VAL_18:.*]]: f16): +// CHECK: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : f16 +// CHECK: tt.reduce.return %[[VAL_19]] : f16 +// CHECK: }) : (tensor<16x1x1xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_2]]}>>) -> tensor<16x1xf16, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_2]]}>}>> +// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x1xf16, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_2]]}>}>> -> tensor<16xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_3]]}>> +// CHECK: %[[VAL_21:.*]] = ttg.convert_layout %[[VAL_20]] : tensor<16xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_3]]}>> -> tensor<16xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_4]]}>> +// CHECK: tt.return %[[VAL_21]] : tensor<16xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_4]]}>> // CHECK: } - tt.func @test_single(%arg0: tensor<16x16xf32, #mma>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> { + tt.func @test_single(%arg0: tensor<16x16xf16, #mma>) -> tensor<16xf16, #ttg.slice<{dim = 1, parent = #mma}>> { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - tt.reduce.return %1 : f32 - }) : (tensor<16x16xf32, #mma>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> - tt.return %0 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> + ^bb0(%arg1: f16, %arg2: f16): + %1 = arith.addf %arg1, %arg2 : f16 + tt.reduce.return %1 : f16 + }) : (tensor<16x16xf16, #mma>) -> tensor<16xf16, #ttg.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<16xf16, #ttg.slice<{dim = 1, parent = #mma}>> } } @@ -64,41 +67,41 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 16 : i32} { // CHECK-LABEL: tt.func @test_single_twice( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf32, #[[$ATTR_9]]>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_9]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<32x16xf32, #[[$ATTR_9]]> -> tensor<16x8x1x2x1x1x2xf32, #[[$ATTR_5]]> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x16xf16, #[[$ATTR_9]]>) -> tensor<32xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_9]]}>> { 
+// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<32x16xf16, #[[$ATTR_9]]> -> tensor<16x8x1x2x1x1x2xf16, #[[$ATTR_5]]> // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ -// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): -// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 -// CHECK: tt.reduce.return %[[VAL_5]] : f32 -// CHECK: }) : (tensor<16x8x1x2x1x1x2xf32, #[[$ATTR_5]]>) -> tensor<16x8x2x1x1x2xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>> +// CHECK: ^bb0(%[[VAL_3:.*]]: f16, %[[VAL_4:.*]]: f16): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f16 +// CHECK: tt.reduce.return %[[VAL_5]] : f16 +// CHECK: }) : (tensor<16x8x1x2x1x1x2xf16, #[[$ATTR_5]]>) -> tensor<16x8x2x1x1x2xf16, #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>> // CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ -// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): -// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 -// CHECK: tt.reduce.return %[[VAL_9]] : f32 -// CHECK: }) : (tensor<16x8x2x1x1x2xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>>) -> tensor<16x8x2x1x2xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>}>> -// CHECK: %[[VAL_10:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<16x8x2x1x2xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>}>> -> tensor<16x8x2x1x2xf32, #[[$ATTR_6]]> -// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x1x2xf32, #[[$ATTR_6]]> -> tensor<16x16x1x2xf32, #[[$ATTR_7]]> +// CHECK: ^bb0(%[[VAL_7:.*]]: f16, %[[VAL_8:.*]]: f16): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f16 +// CHECK: tt.reduce.return %[[VAL_9]] : f16 +// CHECK: }) : (tensor<16x8x2x1x1x2xf16, #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>>) -> tensor<16x8x2x1x2xf16, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>}>> +// CHECK: %[[VAL_10:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<16x8x2x1x2xf16, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_5]]}>}>> -> tensor<16x8x2x1x2xf16, #[[$ATTR_6]]> +// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x1x2xf16, #[[$ATTR_6]]> -> tensor<16x16x1x2xf16, #[[$ATTR_7]]> // CHECK: %[[VAL_12:.*]] = "tt.reduce"(%[[VAL_11]]) <{axis = 0 : i32}> ({ -// CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32): -// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32 -// CHECK: tt.reduce.return %[[VAL_15]] : f32 -// CHECK: }) : (tensor<16x16x1x2xf32, #[[$ATTR_7]]>) -> tensor<16x1x2xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_7]]}>> +// CHECK: ^bb0(%[[VAL_13:.*]]: f16, %[[VAL_14:.*]]: f16): +// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f16 +// CHECK: tt.reduce.return %[[VAL_15]] : f16 +// CHECK: }) : (tensor<16x16x1x2xf16, #[[$ATTR_7]]>) -> tensor<16x1x2xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_7]]}>> // CHECK: %[[VAL_16:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ -// CHECK: ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32): -// CHECK: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : f32 -// CHECK: tt.reduce.return %[[VAL_19]] : f32 -// CHECK: }) : (tensor<16x1x2xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_7]]}>>) -> tensor<16x2xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_7]]}>}>> -// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x2xf32, 
#ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_7]]}>}>> -> tensor<32xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_8]]}>> -// CHECK: %[[VAL_21:.*]] = ttg.convert_layout %[[VAL_20]] : tensor<32xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_8]]}>> -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_9]]}>> -// CHECK: tt.return %[[VAL_21]] : tensor<32xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_9]]}>> +// CHECK: ^bb0(%[[VAL_17:.*]]: f16, %[[VAL_18:.*]]: f16): +// CHECK: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : f16 +// CHECK: tt.reduce.return %[[VAL_19]] : f16 +// CHECK: }) : (tensor<16x1x2xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_7]]}>>) -> tensor<16x2xf16, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_7]]}>}>> +// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x2xf16, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_7]]}>}>> -> tensor<32xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_8]]}>> +// CHECK: %[[VAL_21:.*]] = ttg.convert_layout %[[VAL_20]] : tensor<32xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_8]]}>> -> tensor<32xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_9]]}>> +// CHECK: tt.return %[[VAL_21]] : tensor<32xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_9]]}>> // CHECK: } - tt.func @test_single_twice(%arg0: tensor<32x16xf32, #mma>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> { + tt.func @test_single_twice(%arg0: tensor<32x16xf16, #mma>) -> tensor<32xf16, #ttg.slice<{dim = 1, parent = #mma}>> { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - tt.reduce.return %1 : f32 - }) : (tensor<32x16xf32, #mma>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> - tt.return %0 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> + ^bb0(%arg1: f16, %arg2: f16): + %1 = arith.addf %arg1, %arg2 : f16 + tt.reduce.return %1 : f16 + }) : (tensor<32x16xf16, #mma>) -> tensor<32xf16, #ttg.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<32xf16, #ttg.slice<{dim = 1, parent = #mma}>> } } @@ -116,41 +119,41 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.thr module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.threads-per-warp" = 16 : i32} { // CHECK-LABEL: tt.func @test_two_warps_red( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32xf32, #[[$ATTR_14]]>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_14]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<16x32xf32, #[[$ATTR_14]]> -> tensor<16x8x1x2x2x1x1xf32, #[[$ATTR_10]]> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<16x32xf16, #[[$ATTR_14]]>) -> tensor<16xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_14]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<16x32xf16, #[[$ATTR_14]]> -> tensor<16x8x1x2x2x1x1xf16, #[[$ATTR_10]]> // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ -// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): -// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 -// CHECK: tt.reduce.return %[[VAL_5]] : f32 -// CHECK: }) : (tensor<16x8x1x2x2x1x1xf32, #[[$ATTR_10]]>) -> tensor<16x8x2x2x1x1xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>> +// CHECK: ^bb0(%[[VAL_3:.*]]: f16, %[[VAL_4:.*]]: f16): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f16 +// CHECK: tt.reduce.return %[[VAL_5]] : f16 +// CHECK: }) : (tensor<16x8x1x2x2x1x1xf16, 
#[[$ATTR_10]]>) -> tensor<16x8x2x2x1x1xf16, #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>> // CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ -// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): -// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 -// CHECK: tt.reduce.return %[[VAL_9]] : f32 -// CHECK: }) : (tensor<16x8x2x2x1x1xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>>) -> tensor<16x8x2x2x1xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>}>> -// CHECK: %[[VAL_10:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<16x8x2x2x1xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>}>> -> tensor<16x8x2x2x1xf32, #[[$ATTR_11]]> -// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x2x1xf32, #[[$ATTR_11]]> -> tensor<16x16x2x1xf32, #[[$ATTR_12]]> +// CHECK: ^bb0(%[[VAL_7:.*]]: f16, %[[VAL_8:.*]]: f16): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f16 +// CHECK: tt.reduce.return %[[VAL_9]] : f16 +// CHECK: }) : (tensor<16x8x2x2x1x1xf16, #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>>) -> tensor<16x8x2x2x1xf16, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>}>> +// CHECK: %[[VAL_10:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<16x8x2x2x1xf16, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_10]]}>}>> -> tensor<16x8x2x2x1xf16, #[[$ATTR_11]]> +// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x2x1xf16, #[[$ATTR_11]]> -> tensor<16x16x2x1xf16, #[[$ATTR_12]]> // CHECK: %[[VAL_12:.*]] = "tt.reduce"(%[[VAL_11]]) <{axis = 0 : i32}> ({ -// CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32): -// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32 -// CHECK: tt.reduce.return %[[VAL_15]] : f32 -// CHECK: }) : (tensor<16x16x2x1xf32, #[[$ATTR_12]]>) -> tensor<16x2x1xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_12]]}>> +// CHECK: ^bb0(%[[VAL_13:.*]]: f16, %[[VAL_14:.*]]: f16): +// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f16 +// CHECK: tt.reduce.return %[[VAL_15]] : f16 +// CHECK: }) : (tensor<16x16x2x1xf16, #[[$ATTR_12]]>) -> tensor<16x2x1xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_12]]}>> // CHECK: %[[VAL_16:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ -// CHECK: ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32): -// CHECK: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : f32 -// CHECK: tt.reduce.return %[[VAL_19]] : f32 -// CHECK: }) : (tensor<16x2x1xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_12]]}>>) -> tensor<16x1xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_12]]}>}>> -// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x1xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_12]]}>}>> -> tensor<16xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_13]]}>> -// CHECK: %[[VAL_21:.*]] = ttg.convert_layout %[[VAL_20]] : tensor<16xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_13]]}>> -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_14]]}>> -// CHECK: tt.return %[[VAL_21]] : tensor<16xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_14]]}>> +// CHECK: ^bb0(%[[VAL_17:.*]]: f16, %[[VAL_18:.*]]: f16): +// CHECK: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : f16 +// CHECK: tt.reduce.return %[[VAL_19]] : f16 +// CHECK: }) : (tensor<16x2x1xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_12]]}>>) -> tensor<16x1xf16, #ttg.slice<{dim = 1, 
parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_12]]}>}>> +// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x1xf16, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_12]]}>}>> -> tensor<16xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_13]]}>> +// CHECK: %[[VAL_21:.*]] = ttg.convert_layout %[[VAL_20]] : tensor<16xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_13]]}>> -> tensor<16xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_14]]}>> +// CHECK: tt.return %[[VAL_21]] : tensor<16xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_14]]}>> // CHECK: } - tt.func @test_two_warps_red(%arg0: tensor<16x32xf32, #mma>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> { + tt.func @test_two_warps_red(%arg0: tensor<16x32xf16, #mma>) -> tensor<16xf16, #ttg.slice<{dim = 1, parent = #mma}>> { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - tt.reduce.return %1 : f32 - }) : (tensor<16x32xf32, #mma>) -> tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> - tt.return %0 : tensor<16xf32, #ttg.slice<{dim = 1, parent = #mma}>> + ^bb0(%arg1: f16, %arg2: f16): + %1 = arith.addf %arg1, %arg2 : f16 + tt.reduce.return %1 : f16 + }) : (tensor<16x32xf16, #mma>) -> tensor<16xf16, #ttg.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<16xf16, #ttg.slice<{dim = 1, parent = #mma}>> } } @@ -168,41 +171,41 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 2 : i32, "ttg.thr module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32} { // CHECK-LABEL: tt.func @test_two_warps( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x32xf32, #[[$ATTR_19]]>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_19]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<32x32xf32, #[[$ATTR_19]]> -> tensor<16x8x1x2x2x1x2xf32, #[[$ATTR_15]]> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<32x32xf16, #[[$ATTR_19]]>) -> tensor<32xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_19]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<32x32xf16, #[[$ATTR_19]]> -> tensor<16x8x1x2x2x1x2xf16, #[[$ATTR_15]]> // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ -// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): -// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f32 -// CHECK: tt.reduce.return %[[VAL_5]] : f32 -// CHECK: }) : (tensor<16x8x1x2x2x1x2xf32, #[[$ATTR_15]]>) -> tensor<16x8x2x2x1x2xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>> +// CHECK: ^bb0(%[[VAL_3:.*]]: f16, %[[VAL_4:.*]]: f16): +// CHECK: %[[VAL_5:.*]] = arith.addf %[[VAL_3]], %[[VAL_4]] : f16 +// CHECK: tt.reduce.return %[[VAL_5]] : f16 +// CHECK: }) : (tensor<16x8x1x2x2x1x2xf16, #[[$ATTR_15]]>) -> tensor<16x8x2x2x1x2xf16, #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>> // CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ -// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): -// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f32 -// CHECK: tt.reduce.return %[[VAL_9]] : f32 -// CHECK: }) : (tensor<16x8x2x2x1x2xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>>) -> tensor<16x8x2x2x2xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>}>> -// CHECK: %[[VAL_10:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<16x8x2x2x2xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>}>> -> tensor<16x8x2x2x2xf32, 
#[[$ATTR_16]]> -// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x2x2xf32, #[[$ATTR_16]]> -> tensor<16x16x2x2xf32, #[[$ATTR_17]]> +// CHECK: ^bb0(%[[VAL_7:.*]]: f16, %[[VAL_8:.*]]: f16): +// CHECK: %[[VAL_9:.*]] = arith.addf %[[VAL_7]], %[[VAL_8]] : f16 +// CHECK: tt.reduce.return %[[VAL_9]] : f16 +// CHECK: }) : (tensor<16x8x2x2x1x2xf16, #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>>) -> tensor<16x8x2x2x2xf16, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>}>> +// CHECK: %[[VAL_10:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<16x8x2x2x2xf16, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_15]]}>}>> -> tensor<16x8x2x2x2xf16, #[[$ATTR_16]]> +// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x2x2xf16, #[[$ATTR_16]]> -> tensor<16x16x2x2xf16, #[[$ATTR_17]]> // CHECK: %[[VAL_12:.*]] = "tt.reduce"(%[[VAL_11]]) <{axis = 0 : i32}> ({ -// CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32): -// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f32 -// CHECK: tt.reduce.return %[[VAL_15]] : f32 -// CHECK: }) : (tensor<16x16x2x2xf32, #[[$ATTR_17]]>) -> tensor<16x2x2xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_17]]}>> +// CHECK: ^bb0(%[[VAL_13:.*]]: f16, %[[VAL_14:.*]]: f16): +// CHECK: %[[VAL_15:.*]] = arith.addf %[[VAL_13]], %[[VAL_14]] : f16 +// CHECK: tt.reduce.return %[[VAL_15]] : f16 +// CHECK: }) : (tensor<16x16x2x2xf16, #[[$ATTR_17]]>) -> tensor<16x2x2xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_17]]}>> // CHECK: %[[VAL_16:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ -// CHECK: ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32): -// CHECK: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : f32 -// CHECK: tt.reduce.return %[[VAL_19]] : f32 -// CHECK: }) : (tensor<16x2x2xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_17]]}>>) -> tensor<16x2xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_17]]}>}>> -// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x2xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_17]]}>}>> -> tensor<32xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_18]]}>> -// CHECK: %[[VAL_21:.*]] = ttg.convert_layout %[[VAL_20]] : tensor<32xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_18]]}>> -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_19]]}>> -// CHECK: tt.return %[[VAL_21]] : tensor<32xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_19]]}>> +// CHECK: ^bb0(%[[VAL_17:.*]]: f16, %[[VAL_18:.*]]: f16): +// CHECK: %[[VAL_19:.*]] = arith.addf %[[VAL_17]], %[[VAL_18]] : f16 +// CHECK: tt.reduce.return %[[VAL_19]] : f16 +// CHECK: }) : (tensor<16x2x2xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_17]]}>>) -> tensor<16x2xf16, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_17]]}>}>> +// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x2xf16, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_17]]}>}>> -> tensor<32xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_18]]}>> +// CHECK: %[[VAL_21:.*]] = ttg.convert_layout %[[VAL_20]] : tensor<32xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_18]]}>> -> tensor<32xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_19]]}>> +// CHECK: tt.return %[[VAL_21]] : tensor<32xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_19]]}>> // CHECK: } - tt.func @test_two_warps(%arg0: tensor<32x32xf32, #mma>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> 
{ + tt.func @test_two_warps(%arg0: tensor<32x32xf16, #mma>) -> tensor<32xf16, #ttg.slice<{dim = 1, parent = #mma}>> { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.addf %arg1, %arg2 : f32 - tt.reduce.return %1 : f32 - }) : (tensor<32x32xf32, #mma>) -> tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> - tt.return %0 : tensor<32xf32, #ttg.slice<{dim = 1, parent = #mma}>> + ^bb0(%arg1: f16, %arg2: f16): + %1 = arith.addf %arg1, %arg2 : f16 + tt.reduce.return %1 : f16 + }) : (tensor<32x32xf16, #mma>) -> tensor<32xf16, #ttg.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<32xf16, #ttg.slice<{dim = 1, parent = #mma}>> } } @@ -219,41 +222,41 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.thr module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32} { // CHECK-LABEL: tt.func @test( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<64x64xf32, #[[$ATTR_24]]>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_24]]}>> { -// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<64x64xf32, #[[$ATTR_24]]> -> tensor<16x8x2x2x2x1x4xf32, #[[$ATTR_20]]> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<64x64xf16, #[[$ATTR_24]]>) -> tensor<64xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_24]]}>> { +// CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<64x64xf16, #[[$ATTR_24]]> -> tensor<16x8x2x2x2x1x4xf16, #[[$ATTR_20]]> // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ -// CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): -// CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32 -// CHECK: tt.reduce.return %[[VAL_5]] : f32 -// CHECK: }) : (tensor<16x8x2x2x2x1x4xf32, #[[$ATTR_20]]>) -> tensor<16x8x2x2x1x4xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>> +// CHECK: ^bb0(%[[VAL_3:.*]]: f16, %[[VAL_4:.*]]: f16): +// CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f16 +// CHECK: tt.reduce.return %[[VAL_5]] : f16 +// CHECK: }) : (tensor<16x8x2x2x2x1x4xf16, #[[$ATTR_20]]>) -> tensor<16x8x2x2x1x4xf16, #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>> // CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ -// CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): -// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f32 -// CHECK: tt.reduce.return %[[VAL_9]] : f32 -// CHECK: }) : (tensor<16x8x2x2x1x4xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>>) -> tensor<16x8x2x2x4xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>}>> -// CHECK: %[[VAL_10:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<16x8x2x2x4xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>}>> -> tensor<16x8x2x2x4xf32, #[[$ATTR_21]]> -// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x2x4xf32, #[[$ATTR_21]]> -> tensor<16x16x2x4xf32, #[[$ATTR_22]]> +// CHECK: ^bb0(%[[VAL_7:.*]]: f16, %[[VAL_8:.*]]: f16): +// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f16 +// CHECK: tt.reduce.return %[[VAL_9]] : f16 +// CHECK: }) : (tensor<16x8x2x2x1x4xf16, #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>>) -> tensor<16x8x2x2x4xf16, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>}>> +// CHECK: %[[VAL_10:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<16x8x2x2x4xf16, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_20]]}>}>> -> tensor<16x8x2x2x4xf16, 
#[[$ATTR_21]]> +// CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x2x2x4xf16, #[[$ATTR_21]]> -> tensor<16x16x2x4xf16, #[[$ATTR_22]]> // CHECK: %[[VAL_12:.*]] = "tt.reduce"(%[[VAL_11]]) <{axis = 0 : i32}> ({ -// CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32): -// CHECK: %[[VAL_15:.*]] = arith.maxnumf %[[VAL_13]], %[[VAL_14]] : f32 -// CHECK: tt.reduce.return %[[VAL_15]] : f32 -// CHECK: }) : (tensor<16x16x2x4xf32, #[[$ATTR_22]]>) -> tensor<16x2x4xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_22]]}>> +// CHECK: ^bb0(%[[VAL_13:.*]]: f16, %[[VAL_14:.*]]: f16): +// CHECK: %[[VAL_15:.*]] = arith.maxnumf %[[VAL_13]], %[[VAL_14]] : f16 +// CHECK: tt.reduce.return %[[VAL_15]] : f16 +// CHECK: }) : (tensor<16x16x2x4xf16, #[[$ATTR_22]]>) -> tensor<16x2x4xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_22]]}>> // CHECK: %[[VAL_16:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ -// CHECK: ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32): -// CHECK: %[[VAL_19:.*]] = arith.maxnumf %[[VAL_17]], %[[VAL_18]] : f32 -// CHECK: tt.reduce.return %[[VAL_19]] : f32 -// CHECK: }) : (tensor<16x2x4xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_22]]}>>) -> tensor<16x4xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_22]]}>}>> -// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x4xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_22]]}>}>> -> tensor<64xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_23]]}>> -// CHECK: %[[VAL_21:.*]] = ttg.convert_layout %[[VAL_20]] : tensor<64xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_23]]}>> -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_24]]}>> -// CHECK: tt.return %[[VAL_21]] : tensor<64xf32, #ttg.slice<{dim = 1, parent = #[[$ATTR_24]]}>> +// CHECK: ^bb0(%[[VAL_17:.*]]: f16, %[[VAL_18:.*]]: f16): +// CHECK: %[[VAL_19:.*]] = arith.maxnumf %[[VAL_17]], %[[VAL_18]] : f16 +// CHECK: tt.reduce.return %[[VAL_19]] : f16 +// CHECK: }) : (tensor<16x2x4xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_22]]}>>) -> tensor<16x4xf16, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_22]]}>}>> +// CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<16x4xf16, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_22]]}>}>> -> tensor<64xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_23]]}>> +// CHECK: %[[VAL_21:.*]] = ttg.convert_layout %[[VAL_20]] : tensor<64xf16, #ttg.slice<{dim = 0, parent = #[[$ATTR_23]]}>> -> tensor<64xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_24]]}>> +// CHECK: tt.return %[[VAL_21]] : tensor<64xf16, #ttg.slice<{dim = 1, parent = #[[$ATTR_24]]}>> // CHECK: } - tt.func @test(%arg0: tensor<64x64xf32, #mma>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #mma}>> { + tt.func @test(%arg0: tensor<64x64xf16, #mma>) -> tensor<64xf16, #ttg.slice<{dim = 1, parent = #mma}>> { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ - ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.maxnumf %arg1, %arg2 : f32 - tt.reduce.return %1 : f32 - }) : (tensor<64x64xf32, #mma>) -> tensor<64xf32, #ttg.slice<{dim = 1, parent = #mma}>> - tt.return %0 : tensor<64xf32, #ttg.slice<{dim = 1, parent = #mma}>> + ^bb0(%arg1: f16, %arg2: f16): + %1 = arith.maxnumf %arg1, %arg2 : f16 + tt.reduce.return %1 : f16 + }) : (tensor<64x64xf16, #mma>) -> tensor<64xf16, #ttg.slice<{dim = 1, parent = #mma}>> + tt.return %0 : tensor<64xf16, #ttg.slice<{dim = 1, parent = #mma}>> } } @@ -274,24 +277,24 @@ module attributes 
{"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr // CHECK: %[[VAL_1:.*]] = tt.reshape %[[VAL_0]] allow_reorder efficient_layout : tensor<128x64xf32, #[[$ATTR_29]]> -> tensor<16x8x2x4x2x1x4xf32, #[[$ATTR_25]]> // CHECK: %[[VAL_2:.*]] = "tt.reduce"(%[[VAL_1]]) <{axis = 2 : i32}> ({ // CHECK: ^bb0(%[[VAL_3:.*]]: f32, %[[VAL_4:.*]]: f32): -// CHECK: %[[VAL_5:.*]] = arith.maxnumf %[[VAL_3]], %[[VAL_4]] : f32 +// CHECK: %[[VAL_5:.*]] = arith.mulf %[[VAL_3]], %[[VAL_4]] : f32 // CHECK: tt.reduce.return %[[VAL_5]] : f32 // CHECK: }) : (tensor<16x8x2x4x2x1x4xf32, #[[$ATTR_25]]>) -> tensor<16x8x4x2x1x4xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_25]]}>> // CHECK: %[[VAL_6:.*]] = "tt.reduce"(%[[VAL_2]]) <{axis = 4 : i32}> ({ // CHECK: ^bb0(%[[VAL_7:.*]]: f32, %[[VAL_8:.*]]: f32): -// CHECK: %[[VAL_9:.*]] = arith.maxnumf %[[VAL_7]], %[[VAL_8]] : f32 +// CHECK: %[[VAL_9:.*]] = arith.mulf %[[VAL_7]], %[[VAL_8]] : f32 // CHECK: tt.reduce.return %[[VAL_9]] : f32 // CHECK: }) : (tensor<16x8x4x2x1x4xf32, #ttg.slice<{dim = 2, parent = #[[$ATTR_25]]}>>) -> tensor<16x8x4x2x4xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_25]]}>}>> // CHECK: %[[VAL_10:.*]] = ttg.convert_layout %[[VAL_6]] : tensor<16x8x4x2x4xf32, #ttg.slice<{dim = 4, parent = #ttg.slice<{dim = 2, parent = #[[$ATTR_25]]}>}>> -> tensor<16x8x4x2x4xf32, #[[$ATTR_26]]> // CHECK: %[[VAL_11:.*]] = tt.reshape %[[VAL_10]] allow_reorder efficient_layout : tensor<16x8x4x2x4xf32, #[[$ATTR_26]]> -> tensor<16x32x2x4xf32, #[[$ATTR_27]]> // CHECK: %[[VAL_12:.*]] = "tt.reduce"(%[[VAL_11]]) <{axis = 0 : i32}> ({ // CHECK: ^bb0(%[[VAL_13:.*]]: f32, %[[VAL_14:.*]]: f32): -// CHECK: %[[VAL_15:.*]] = arith.maxnumf %[[VAL_13]], %[[VAL_14]] : f32 +// CHECK: %[[VAL_15:.*]] = arith.mulf %[[VAL_13]], %[[VAL_14]] : f32 // CHECK: tt.reduce.return %[[VAL_15]] : f32 // CHECK: }) : (tensor<16x32x2x4xf32, #[[$ATTR_27]]>) -> tensor<32x2x4xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_27]]}>> // CHECK: %[[VAL_16:.*]] = "tt.reduce"(%[[VAL_12]]) <{axis = 1 : i32}> ({ // CHECK: ^bb0(%[[VAL_17:.*]]: f32, %[[VAL_18:.*]]: f32): -// CHECK: %[[VAL_19:.*]] = arith.maxnumf %[[VAL_17]], %[[VAL_18]] : f32 +// CHECK: %[[VAL_19:.*]] = arith.mulf %[[VAL_17]], %[[VAL_18]] : f32 // CHECK: tt.reduce.return %[[VAL_19]] : f32 // CHECK: }) : (tensor<32x2x4xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_27]]}>>) -> tensor<32x4xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_27]]}>}>> // CHECK: %[[VAL_20:.*]] = tt.reshape %[[VAL_16]] allow_reorder efficient_layout : tensor<32x4xf32, #ttg.slice<{dim = 1, parent = #ttg.slice<{dim = 0, parent = #[[$ATTR_27]]}>}>> -> tensor<128xf32, #ttg.slice<{dim = 0, parent = #[[$ATTR_28]]}>> @@ -301,7 +304,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr tt.func @test(%arg0: tensor<128x64xf32, #mma>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> { %0 = "tt.reduce"(%arg0) <{axis = 1 : i32}> ({ ^bb0(%arg1: f32, %arg2: f32): - %1 = arith.maxnumf %arg1, %arg2 : f32 + %1 = arith.mulf %arg1, %arg2 : f32 tt.reduce.return %1 : f32 }) : (tensor<128x64xf32, #mma>) -> tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> tt.return %0 : tensor<128xf32, #ttg.slice<{dim = 1, parent = #mma}>> diff --git a/test/TritonIntelGPU/tritonintelgpu.mlir b/test/TritonIntelGPU/tritonintelgpu.mlir index 3d486780b2..155bf8e19b 100644 --- a/test/TritonIntelGPU/tritonintelgpu.mlir +++ b/test/TritonIntelGPU/tritonintelgpu.mlir @@ -58,3 +58,17 @@ module attributes {"ttg.num-ctas" = 1 : i32, 
"ttg.num-warps" = 8 : i32, "ttg.thr tt.return %res : tensor<16x16xf16> } } + +// ----- + +#blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 16 : i32} { + tt.func @triton_intel_gpu.simd_reduce(%arg0: tensor<16x16xf32, #blocked>) -> tensor<16xf32, #blocked1> { + // CHECK-LABEL: @triton_intel_gpu.simd_reduce + // CHECK: triton_intel_gpu.simd_reduce add %{{.*}} axis = 0 : tensor<16x16xf32, #blocked> -> tensor<16xf32, #blocked1> + %0 = triton_intel_gpu.simd_reduce add %arg0 axis = 0 : tensor<16x16xf32, #blocked> -> tensor<16xf32, #blocked1> + tt.return %0 : tensor<16xf32, #blocked1> + } +} diff --git a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUOps.td b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUOps.td index 85d33ed5c5..811c33767f 100644 --- a/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUOps.td +++ b/third_party/intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUOps.td @@ -12,6 +12,7 @@ include "triton/Dialect/Triton/IR/TritonTypes.td" include "triton/Dialect/Triton/IR/TritonAttrDefs.td" include "triton/Dialect/TritonGPU/IR/TritonGPUTypes.td" +include "intel/include/Dialect/TritonGEN/IR/TritonGENAttrDefs.td" include "intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUAttrDefs.td" include "intel/include/Dialect/TritonIntelGPU/IR/TritonIntelGPUDialect.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -202,4 +203,66 @@ def TTIG_SubGroupTransposeOp let hasVerifier = 1; } +def TTIG_SIMDReduceOp : TTIG_Op<"simd_reduce", [Pure, SameOperandsAndResultElementType]> { + let summary = "SIMD reduction."; + let description = [{ + The `triton_intel_gpu.simd_reduce` operation performs a SIMD reduction. + Contrary to `tt.reduce`, when performing a warp reduction, the result is + non-uniform. + + The reduction axis must be in such a way that only a warp reduction is + performed, i.e., `sizePerThread[axis]`, `warpsPerCTA[axis]` and + `CTAsPerCGA[axis]` must be 1; and `shape[axis]` and `threadsPerWarp[axis]` + must be equal to the sub-group size. + + The output type must be compatible with the performed reduction. However, + ensuring this is up to the user. As a rule of thumb, the total number of + elements in the output tensor must be sub-group size smaller than in the + original one. 
Users should bear in mind a tensor like: + + ``` + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + ``` + + would be reduced to: + + ``` + t0 t1 t2 t3 t4 t5 t6 t7 t8 t9 t10 t11 t12 t13 t14 t15 + ``` + + Example: + ```mlir + #blocked = #ttg.blocked<{sizePerThread = [1, 16], threadsPerWarp = [16, 1], warpsPerCTA = [1, 1], order = [0, 1]}> + #blocked1 = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [16], warpsPerCTA = [1], order = [0]}> + triton_intel_gpu.simd_reduce add %0 axis = 0 : tensor<16x16xf32, #blocked> -> tensor<16xf32, #blocked1> + // # 3D reduction: + #blocked = #ttg.blocked<{sizePerThread = [1, 16, 1], threadsPerWarp = [16, 1, 1], warpsPerCTA = [1, 1, 2], order = [0, 1, 2]}> + #blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [16, 1], warpsPerCTA = [1, 2], order = [0, 1]}> + triton_intel_gpu.simd_reduce add %0 axis = 0 : tensor<16x16x2xf32, #blocked> -> tensor<16x2xf32, #blocked1> + ``` + }]; + let arguments = (ins TT_Tensor:$src, + TritonGEN_ReduceKindAttr: $op, + I32Attr:$axis); + let results = (outs TT_Tensor:$res); + let assemblyFormat = [{ + $op $src `axis` `=` $axis attr-dict `:` type($src) `->` type($res) + }]; +} + #endif diff --git a/third_party/intel/include/TritonIntelGPUToLLVM/VISAASMFormat.h b/third_party/intel/include/TritonIntelGPUToLLVM/VISAASMFormat.h new file mode 100644 index 0000000000..1f84514e50 --- /dev/null +++ b/third_party/intel/include/TritonIntelGPUToLLVM/VISAASMFormat.h @@ -0,0 +1,330 @@ +#ifndef TRITON_CONVERSION_TRITON_GPU_TO_LLVM_VISA_ASM_FORMAT_H_ +#define TRITON_CONVERSION_TRITON_GPU_TO_LLVM_VISA_ASM_FORMAT_H_ + +#include "mlir/IR/Value.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include +#include + +namespace mlir { +class ConversionPatternRewriter; +class Location; + +namespace triton { +using llvm::StringRef; + +struct VISAInstr; +struct VISAInstrCommon; +struct VISAInstrExecution; + +// VISABuilder helps to manage a VISA asm program. +// +// A helper for building an ASM program, the objective of VISABuilder is to give +// a thin encapsulation and make the ASM code for MLIR LLVM Dialect clearer. +// Currently, several factors are introduced to reduce the need for mixing +// string and C++ if-else code. 
+// +// Usage: +// To build: +// ``` +// @$3 asm("@%3 add.s32 %0, %1, %2;" : "=r"(i) : "r"(j), "r"(k), "b"(p)); +// ``` +// ```cpp +// VISABuilder builder; +// auto& add = builder.create<>(); +// add.predicate(pVal).o("lo").o("u32"); // add any suffix +// // predicate here binds %0 to pVal, pVal is a mlir::Value +// +// auto* iOpr = builder.newOperand(iVal, "r"); // %1 bind to iVal +// auto* jOpr = builder.newOperand(jVal, "r"); // %2 bind to jVal +// auto* kOpr = builder.newOperand(kVal, "r"); // %3 bind to kVal +// add(iOpr, jOpr, kOpr).predicate(predVal); // set operands and predicate +// +// To get the asm code: +// builder.dump() +// ``` +// To get all the `mlir::Value` used in the VISA code, +// ```cpp +// builder.getAllMlirArgs() // get {pVal, iVal, jVal, kVal} +// ``` +// To get the string containing all the constraints separated by a comma: +// ```cpp +// builder.getConstraints() // get "=r,r,k" +// ``` +// VISABuilder can build a VISA asm with multiple instructions, sample code: +// ```cpp +// VISABuilder builder; +// auto& mov = builder.create("mov"); +// auto& cp = builder.create("cp"); +// mov(...); +// cp(...); +// ``` +// This will get a VISA code with two instructions. +// +// Similar to a C function, a declared VISAInstr instance can be launched +// multiple times with different operands, e.g. +// ```cpp +// auto& mov = builder.create("mov"); +// mov(... some operands ...); +// mov(... some different operands ...); +// ``` +// Finally, we will get a VISA code with two mov instructions. +// +// There are several derived instruction type for typical instructions, for +// example, the PtxIOInstr for ld and st instructions. +struct VISABuilder { + struct Operand { + std::string constraint; + Value value; + int idx{-1}; + llvm::SmallVector list; + std::function repr; + + // for list + Operand() = default; + Operand(const Operation &) = delete; + Operand(Value value, StringRef constraint) + : constraint(constraint), value(value) {} + + bool isList() const { return !value && constraint.empty(); } + + Operand *listAppend(Operand *arg) { + list.push_back(arg); + return this; + } + + Operand *listGet(size_t nth) const { + assert(nth < list.size()); + return list[nth]; + } + + std::string dump() const; + }; + + template + INSTR *create(Args &&...args) { + instrs.emplace_back(std::make_unique(this, args...)); + return static_cast(instrs.back().get()); + } + + // Create a list of operands. + Operand *newListOperand() { return newOperand(); } + + Operand *newListOperand(ArrayRef> items) { + auto *list = newOperand(); + for (auto &item : items) { + list->listAppend(newOperand(item.first, item.second)); + } + return list; + } + + Operand *newListOperand(unsigned count, mlir::Value val, + const std::string &constraint) { + auto *list = newOperand(); + for (unsigned i = 0; i < count; ++i) { + list->listAppend(newOperand(val, constraint)); + } + return list; + } + + Operand *newListOperand(unsigned count, const std::string &constraint) { + auto *list = newOperand(); + for (unsigned i = 0; i < count; ++i) { + list->listAppend(newOperand(constraint)); + } + return list; + } + + // Create a new operand. It will not add to operand list. + // @value: the MLIR value bind to this operand. + // @constraint: ASM operand constraint, .e.g. "=r" + // @formatter: extra format to represent this operand in ASM code, default is + // "%{0}".format(operand.idx). 
+ Operand *newOperand(mlir::Value value, StringRef constraint, + std::function formatter = nullptr); + + // Create a new operand which is written to, that is, the constraint starts + // with "=", e.g. "=r". + // If the operand will be used in predicated execution, + // users may want to initialize it before use. + // Otherwise if the register is only used in the true branch or the false + // branch but not both, the register is undefined and ptxas can perform + // aggressive optimizations that may lead to incorrect results. + Operand *newOperand(StringRef constraint, bool init = false); + + // Create a new operand that is tied to a previous operand. In this case the + // asm would be permitted to write to an input register. Instead of providing + // constraint code for this operand, the constraint code of the tied operand + // is used. + Operand *newOperand(unsigned operandIndex); + + // Create a constant integer operand. + Operand *newConstantOperand(int64_t v); + // Create a constant operand with explicit code specified. + Operand *newConstantOperand(const std::string &v); + + Operand *newAddrOperand(mlir::Value addr, StringRef constraint, int off = 0); + + llvm::SmallVector getAllArgs() const; + + llvm::SmallVector getAllMLIRArgs() const; + + std::string getConstraints() const; + + std::string dump() const; + + mlir::Value launch(OpBuilder &rewriter, Location loc, Type resTy, + bool hasSideEffect = true, bool isAlignStack = false, + ArrayRef attrs = {}) const; + +private: + Operand *newOperand() { + argArchive.emplace_back(std::make_unique()); + return argArchive.back().get(); + } + + void initOperand(Operand *opr); + + // Make the operands in argArchive follow the provided \param order. + void reorderArgArchive(ArrayRef order) { + assert(order.size() == argArchive.size()); + // The order in argArchive is unnecessary when onlyAttachMLIRArgs=false, but + // it does necessary when onlyAttachMLIRArgs is true for the $0, $1... are + // determined by VISA code snippet passed from external. + sort(argArchive.begin(), argArchive.end(), + [&](std::unique_ptr &a, std::unique_ptr &b) { + auto ida = std::find(order.begin(), order.end(), a.get()); + auto idb = std::find(order.begin(), order.end(), b.get()); + assert(ida != order.end()); + assert(idb != order.end()); + return ida < idb; + }); + } + + friend struct VISAInstr; + friend struct VISAInstrCommon; + +protected: + llvm::SmallVector, 6> argArchive; + llvm::SmallVector, 2> instrs; + llvm::SmallVector, 4> executions; + int oprCounter{}; +}; + +// VISA instruction common interface. +// Put the generic logic for all the instructions here. +struct VISAInstrCommon { + explicit VISAInstrCommon(VISABuilder *builder) : builder(builder) {} + + using Operand = VISABuilder::Operand; + + // Set operands of this instruction. + VISAInstrExecution &operator()(llvm::ArrayRef oprs, + bool onlyAttachMLIRArgs = false); + +protected: + // "Call" the instruction with operands. + // \param oprs The operands of this instruction. + // \param onlyAttachMLIRArgs Indicate that it simply attach the MLIR Arguments + // to the inline Asm without generating the operand ids(such as $0, $1) in + // VISA code. 
+ VISAInstrExecution &call(llvm::ArrayRef oprs, + bool onlyAttachMLIRArgs = false); + + VISABuilder *builder{}; + llvm::SmallVector instrParts; + + friend struct VISAInstrExecution; +}; + +template struct VISAInstrBase : public VISAInstrCommon { + using Operand = VISABuilder::Operand; + + explicit VISAInstrBase(VISABuilder *builder, const std::string &name) + : VISAInstrCommon(builder) { + o(name); + } + + // Append a suffix to the instruction. + // e.g. VISAInstr("add").o("s32") get a add.s32. + // A predicate is used to tell whether to apply the suffix, so that no if-else + // code needed. e.g. `VISAInstr("add").o("s32", isS32).o("u32", !isS32);` will + // get a `add.s32` if isS32 is true. + ConcreteT &o(const std::string &suffix, bool predicate = true) { + if (predicate) + instrParts.push_back(suffix); + return *static_cast(this); + } +}; + +struct VISAInstr : public VISAInstrBase { + using VISAInstrBase::VISAInstrBase; + + // Append a ".global" to the instruction. + VISAInstr &global(); + + // Append a ".shared" to the instruction. + VISAInstr &shared(); + + // Append a ".v[0-9]+" to the instruction + VISAInstr &v(int vecWidth, bool predicate = true); + + // Append a".b[0-9]+" to the instruction + VISAInstr &b(int width); +}; + +// Record the operands and context for "launching" a PtxInstr. +struct VISAInstrExecution { + using Operand = VISABuilder::Operand; + + llvm::SmallVector argsInOrder; + + VISAInstrExecution() = default; + explicit VISAInstrExecution(VISAInstrCommon *instr, + llvm::ArrayRef oprs, + bool onlyAttachMLIRArgs) + : argsInOrder(oprs.begin(), oprs.end()), instr(instr), + onlyAttachMLIRArgs(onlyAttachMLIRArgs) {} + + // Prefix a predicate to the instruction. + VISAInstrExecution &predicate(mlir::Value value, StringRef constraint = "b") { + pred = instr->builder->newOperand(value, constraint); + return *this; + } + + // Prefix a !predicate to the instruction. + VISAInstrExecution &predicateNot(mlir::Value value, StringRef constraint) { + pred = instr->builder->newOperand(value, constraint); + pred->repr = [](int idx) { return "@!$" + std::to_string(idx); }; + return *this; + } + + std::string dump() const; + + SmallVector getArgList() const; + + VISAInstrCommon *instr{}; + Operand *pred{}; + bool onlyAttachMLIRArgs{}; +}; + +/// ====== Some instruction wrappers ====== +// We add the wrappers to make the usage more intuitive by avoiding mixing the +// VISA code with some trivial C++ code. 
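For instance, the `VISACpAsyncLoadInstr` wrapper declared just below could be driven as in this sketch. It assumes that `create<>` forwards its extra arguments to the wrapper's constructor (as the `std::make_unique`-based implementation suggests) and that `stringifyCacheModifier` returns the lowercase mnemonic ("ca" here); it is illustrative only, not a tested usage:

```cpp
// Roughly equivalent to chaining the suffixes by hand:
//   builder.create<>("cp.async")->o("ca").o("shared").o("global")
VISABuilder builder;
auto &cp = *builder.create<VISACpAsyncLoadInstr>(triton::CacheModifier::CA);
```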
+ +struct VISACpAsyncLoadInstr : VISAInstrBase { + explicit VISACpAsyncLoadInstr(VISABuilder *builder, + triton::CacheModifier modifier) + : VISAInstrBase(builder, "cp.async") { + o(triton::stringifyCacheModifier(modifier).str()); + o("shared"); + o("global"); + } +}; + +} // namespace triton +} // namespace mlir + +#endif diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt b/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt index 4e86cbd2f2..89659234a0 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/CMakeLists.txt @@ -26,6 +26,7 @@ add_triton_library(TritonIntelGPUToLLVM TypeConverter.cpp Utility.cpp ViewOpToLLVM.cpp + VISAASMFormat.cpp DEPENDS TritonIntelGPUConversionPassIncGen diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp index 4b7efe346f..ca24c6e80f 100644 --- a/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/ReduceOpToLLVM.cpp @@ -1,9 +1,13 @@ +#include "intel/include/TritonIntelGPUToLLVM/VISAASMFormat.h" + #include "PatternTritonGPUOpToLLVM.h" #include "ReduceScanCommon.h" #include "Utility.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" #include +#include "llvm/Support/FormatVariadic.h" + using namespace mlir; using namespace mlir::triton; @@ -490,10 +494,122 @@ struct ReduceOpConversion rewriter.replaceOp(op, results); } }; + +class SIMDReduceOpConversion final + : public ConvertTritonGPUOpToLLVMPattern< + mlir::triton::gpu::intel::SIMDReduceOp> { +public: + using OpTy = mlir::triton::gpu::intel::SIMDReduceOp; + + SIMDReduceOpConversion(LLVMTypeConverter &typeConverter, + const TargetInfoBase &targetInfo, + PatternBenefit benefit) + : ConvertTritonGPUOpToLLVMPattern(typeConverter, benefit), + targetInfo(targetInfo) {} + + static SmallVector> splitInBatches(ArrayRef srcValues, + size_t batchSize) { + SmallVector> batches; + for (; !srcValues.empty(); srcValues = srcValues.drop_front(batchSize)) + batches.push_back(srcValues.take_front(batchSize)); + return batches; + } + + LogicalResult + matchAndRewrite(OpTy simdReduce, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const final { + Location loc = simdReduce->getLoc(); + + SmallVector srcValues = + unpackLLElements(loc, adaptor.getSrc(), rewriter); + + // TODO: Implement relevant types. + if (!srcValues.front().getType().isF32()) + return failure(); + + // TODO: Implement mapping for all reduction kinds. + if (simdReduce.getOp() != TritonGEN::ReduceKind::ADD && + simdReduce.getOp() != TritonGEN::ReduceKind::MAX) + return failure(); + + auto mod = simdReduce->getParentOfType(); + constexpr unsigned supportedSIMDSize = 16; + unsigned subGroupSize = + triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); + // Only SIMD 16 is supported. 
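    // (The vISA template below is hard-coded for this width: it performs, for
    // each of the 16 register elements a lane holds, a pairwise tree reduction
    // across the 16 lanes with strides 8, 4, 2 and 1. Conceptually:
    //
    //   for (int stride = 8; stride >= 1; stride /= 2)
    //     for (int elem = 0; elem < 16; ++elem)
    //       for (int lane = 0; lane < stride; ++lane)
    //         partial[elem][lane] = op(partial[elem][lane],
    //                                  partial[elem][lane + stride]);
    //
    // so lane `c` ends up holding the reduction of element `c` across all
    // lanes, i.e. the non-uniform per-lane result this op is documented to
    // produce. This is a sketch of the intent, not of the exact register
    // regions the template uses.)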
+ if (subGroupSize != supportedSIMDSize) + return failure(); + assert(srcValues.size() % subGroupSize == 0 && + "Expecting splittable input"); + SmallVector results(srcValues.size() / subGroupSize); + constexpr StringLiteral asmTemplate = R"({ +.decl temp_result v_type=G type=f num_elts=128 align=wordx32 +{0} (M1_NM, 16) temp_result(0, 0)<1> $1(0, 0)<16;8,1> $1(0, 8)<16;8,1> +{0} (M1_NM, 16) temp_result(1, 0)<1> $1(2, 0)<16;8,1> $1(2, 8)<16;8,1> +{0} (M1_NM, 16) temp_result(2, 0)<1> $1(4, 0)<16;8,1> $1(4, 8)<16;8,1> +{0} (M1_NM, 16) temp_result(3, 0)<1> $1(6, 0)<16;8,1> $1(6, 8)<16;8,1> +{0} (M1_NM, 16) temp_result(4, 0)<1> $1(8, 0)<16;8,1> $1(8, 8)<16;8,1> +{0} (M1_NM, 16) temp_result(5, 0)<1> $1(10, 0)<16;8,1> $1(10, 8)<16;8,1> +{0} (M1_NM, 16) temp_result(6, 0)<1> $1(12, 0)<16;8,1> $1(12, 8)<16;8,1> +{0} (M1_NM, 16) temp_result(7, 0)<1> $1(14, 0)<16;8,1> $1(14, 8)<16;8,1> +{0} (M1_NM, 16) temp_result(0, 0)<1> temp_result(0, 0)<8;4,1> temp_result(0, 4)<8;4,1> +{0} (M1_NM, 16) temp_result(1, 0)<1> temp_result(2, 0)<8;4,1> temp_result(2, 4)<8;4,1> +{0} (M1_NM, 16) temp_result(2, 0)<1> temp_result(4, 0)<8;4,1> temp_result(4, 4)<8;4,1> +{0} (M1_NM, 16) temp_result(3, 0)<1> temp_result(6, 0)<8;4,1> temp_result(6, 4)<8;4,1> +{0} (M1_NM, 16) temp_result(0, 0)<1> temp_result(0, 0)<4;2,1> temp_result(0, 2)<4;2,1> +{0} (M1_NM, 16) temp_result(1, 0)<1> temp_result(2, 0)<4;2,1> temp_result(2, 2)<4;2,1> +{0} (M1_NM, 16) $0(0, 0)<1> temp_result(0, 0)<2;1,0> temp_result(0, 1)<2;1,0> +})"; + std::string batchedHorizontalReduce = + llvm::formatv(asmTemplate.data(), getASMOperation(simdReduce.getOp())) + .str(); + constexpr unsigned vecWidth = 16; + VectorType reduceTy = vec_ty(srcValues.front().getType(), vecWidth); + llvm::transform( + splitInBatches(srcValues, subGroupSize), std::begin(results), + [&](ArrayRef inputs) { + auto inputRange = llvm::enumerate(inputs); + Value batchedReduceVal = std::accumulate( + std::begin(inputRange), std::end(inputRange), + rewriter.create(loc, reduceTy).getRes(), + [reduceTy, loc, &rewriter](Value acc, auto entry) -> Value { + auto [index, src] = entry; + return insert_element(reduceTy, acc, src, i32_val(index)); + }); + VISABuilder vISABuilder; + VISAInstr &bReduceOp = *vISABuilder.create<>(batchedHorizontalReduce); + VISABuilder::Operand *res = vISABuilder.newOperand("=rw"); + VISABuilder::Operand *in = + vISABuilder.newOperand(batchedReduceVal, "rw"); + bReduceOp({res, in}, /*onlyAttachMLIRArgs=*/true); + Type resultTy = reduceTy.getElementType(); + return vISABuilder.launch(rewriter, loc, resultTy, true); + }); + Value packedRes = packLLElements(loc, getTypeConverter(), results, rewriter, + simdReduce.getRes().getType()); + rewriter.replaceOp(simdReduce, packedRes); + return success(); + } + +private: + static StringRef getASMOperation(TritonGEN::ReduceKind op) { + switch (op) { + case TritonGEN::ReduceKind::ADD: + return "add"; + case TritonGEN::ReduceKind::MAX: + return "max"; + default: + llvm_unreachable("Unhandled kind"); + } + } + + const TargetInfoBase &targetInfo; +}; } // namespace void mlir::triton::intel::populateReduceOpToLLVMPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, const TargetInfoBase &targetInfo, PatternBenefit benefit) { - patterns.add(typeConverter, targetInfo, benefit); + patterns.add(typeConverter, + targetInfo, benefit); } diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/VISAASMFormat.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/VISAASMFormat.cpp new file mode 100644 index 0000000000..52237cd6de --- 
/dev/null +++ b/third_party/intel/lib/TritonIntelGPUToLLVM/VISAASMFormat.cpp @@ -0,0 +1,237 @@ +#include "intel/include/TritonIntelGPUToLLVM/VISAASMFormat.h" + +#include "mlir/Dialect/LLVMIR/LLVMDialect.h" +#include "mlir/Transforms/DialectConversion.h" +#include "triton/Conversion/TritonGPUToLLVM/AsmFormat.h" +#include "llvm/Support/raw_ostream.h" +// TODO(Superjomn): unify to llvm::raw_string_ostream +#include + +namespace mlir { +namespace triton { + +VISAInstr::Operand * +VISABuilder::newOperand(mlir::Value value, StringRef constraint, + std::function formatter) { + argArchive.emplace_back(std::make_unique(value, constraint)); + auto *opr = argArchive.back().get(); + opr->repr = formatter; + opr->idx = oprCounter++; + return opr; +} + +void VISABuilder::initOperand(Operand *opr) { + auto numBits = 0; + // Derive numBits from the constraint. + if (opr->constraint[1] == 'c' || opr->constraint[1] == 'h') + numBits = 16; + else if (opr->constraint[1] == 'r') + numBits = 32; + else if (opr->constraint[1] == 'l') + numBits = 64; + else + llvm_unreachable(("Unknown constraint: " + opr->constraint).c_str()); + // If numBits is less than 16, we use 16 as default because VISA does not + // support 8-bit mov. + numBits = numBits < 16 ? 16 : numBits; + auto *zero = newConstantOperand(0); + auto &init = create<>("mov")->o("u" + std::to_string(numBits)); + init(opr, zero); +} + +VISABuilder::Operand *VISABuilder::newOperand(StringRef constraint, bool init) { + // Constraint should be something like "=rw" + assert(constraint[0] == '='); + auto *opr = newOperand(); + opr->idx = oprCounter++; + opr->constraint = constraint; + if (init) { + initOperand(opr); + } + return opr; +} + +VISABuilder::Operand *VISABuilder::newOperand(unsigned operandIndex) { + assert(operandIndex < oprCounter && "operand index out of range"); + auto *opr = newOperand(); + opr->idx = oprCounter++; + opr->constraint = std::to_string(operandIndex); + return opr; +} + +VISABuilder::Operand *VISABuilder::newConstantOperand(const std::string &v) { + argArchive.emplace_back(std::make_unique()); + argArchive.back()->repr = [v](int idx) { return v; }; + return argArchive.back().get(); +} + +VISABuilder::Operand *VISABuilder::newConstantOperand(int64_t v) { + std::stringstream ss; + ss << "0x" << std::hex << v; + return newConstantOperand(ss.str()); +} + +std::string VISABuilder::getConstraints() const { + auto args = getAllArgs(); + llvm::SmallVector argReprs; + for (auto arg : args) + argReprs.push_back(arg->constraint); + return strJoin(argReprs, ","); +} + +llvm::SmallVector VISABuilder::getAllMLIRArgs() const { + llvm::SmallVector res; + for (auto &arg : argArchive) { + if (!arg->isList() && arg->value) + res.push_back(arg->value); + } + return res; +} + +SmallVector VISABuilder::getAllArgs() const { + llvm::SmallVector res; + for (auto &x : argArchive) + if (!x->isList()) + res.push_back(x.get()); + return res; +} + +mlir::Value VISABuilder::launch(OpBuilder &rewriter, Location loc, Type resTy, + bool hasSideEffect, bool isAlignStack, + ArrayRef attrs) const { + auto *ctx = rewriter.getContext(); + auto inlineAsm = rewriter.create( + loc, resTy, getAllMLIRArgs(), // operands + dump(), // asm_string + getConstraints(), // constraints + hasSideEffect, // has_side_effects + isAlignStack, // is_align_stack + LLVM::AsmDialectAttr::get(ctx, + LLVM::AsmDialect::AD_ATT), // asm_dialect + ArrayAttr::get(ctx, attrs) // operand_attrs + ); + + return inlineAsm.getRes(); +} + +std::string VISAInstr::Operand::dump() const { + if (repr) + return 
repr(idx); + if (!isList()) + return "$" + std::to_string(idx); + + llvm::SmallVector oprs; + for (auto *opr : list) + oprs.push_back(opr->dump()); + return "{ " + strJoin(oprs, ", ") + " }"; +} + +VISAInstr::Operand *VISABuilder::newAddrOperand(mlir::Value addr, + StringRef constraint, int off) { + auto *opr = newOperand(addr, constraint); + opr->repr = [off](int idx) -> std::string { + std::stringstream ss; + ss << "[ $" << idx << " + " << off << " ]"; + return ss.str(); + }; + + return opr; +} + +std::string VISABuilder::dump() const { + llvm::SmallVector lines; + for (auto &exec : executions) { + lines.push_back(exec->dump()); + } + + return strJoin(lines, "\n\t"); +} + +VISAInstrExecution &VISAInstrCommon::call(ArrayRef oprs, + bool onlyAttachMLIRArgs) { + if (onlyAttachMLIRArgs) { + // Nearly impossible to make the $0,$1 in two VISA code snippets to point to + // the same MLIR values in onlyAttachMLIRArgs mode. + assert(builder->executions.empty() && + "builder can only hold a single execution when onlyAttachMIIRArgs " + "is true."); + builder->reorderArgArchive(oprs); + } + + builder->executions.emplace_back( + std::make_unique(this, oprs, onlyAttachMLIRArgs)); + + return *builder->executions.back(); +} + +VISAInstrExecution &VISAInstrCommon::operator()(ArrayRef oprs, + bool onlyAttachMLIRArgs) { + return call(oprs, onlyAttachMLIRArgs); +} + +std::string VISAInstrExecution::dump() const { + std::string osStr; + llvm::raw_string_ostream os(osStr); + + if (pred) { + if (!pred->repr) + os << "@" << pred->dump() << " "; + else + os << pred->repr(pred->idx) << " "; + } + + std::string instrRepr = strJoin(instr->instrParts, "."); + if (onlyAttachMLIRArgs) { + os << instrRepr; + os.flush(); + return osStr; + } + + llvm::SmallVector argReprs; + for (auto *arg : argsInOrder) { + argReprs.push_back(arg->dump()); + } + + std::string argsRepr = strJoin(argReprs, ", "); + + os << instrRepr << " " << argsRepr << ";"; + os.flush(); + return osStr; +} + +SmallVector +VISAInstrExecution::getArgList() const { + SmallVector args; + for (auto *arg : argsInOrder) { + if (arg->isList()) + args.insert(args.end(), arg->list.begin(), arg->list.end()); + else + args.push_back(arg); + } + return args; +} + +VISAInstr &VISAInstr::global() { + o("global"); + return *this; +} + +VISAInstr &VISAInstr::shared() { + o("shared"); + return *this; +} + +VISAInstr &VISAInstr::v(int vecWidth, bool predicate) { + if (vecWidth > 1) { + o("v" + std::to_string(vecWidth), predicate); + } + return *this; +} + +VISAInstr &VISAInstr::b(int width) { + o("b" + std::to_string(width)); + return *this; +} + +} // namespace triton +} // namespace mlir diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp index 610d82c453..abc3c0a20b 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp @@ -15,6 +15,8 @@ #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include "llvm/ADT/TypeSwitch.h" + #define DEBUG_TYPE "tritonintelgpu-optimize-reduction-locality" namespace mlir::triton::gpu::intel { @@ -183,6 +185,7 @@ namespace { /// - Dimensions 1 and 3 refer to the original dimension 0 /// - Dimensions 0, and 2 refer to the original dimension 1 /// - Order is preserved + /// /// And on with step 3, after reducing on dimensions 0 and 1 (2 - 1 as 0 is /// 
squashed), we'd get: /// ``` @@ -200,6 +203,81 @@ namespace { /// | t3 | /// ``` /// And untranspose with a layout conversion to the original layout. + /// + /// In case the element type is `f32` and the reduction operation is a simple + /// add or max operation, we can use an optimized SIMD transposed reduction in + /// registers so no SLM transpose is needed. This would replace the two steps + /// above, but leading to a result with the same encoding. + /// + /// In order to do so, the SIMD reduction result is: + /// - Shape: [executionSize, + /// repeatCount * repCluster[0] / executionSize, + /// warpsPerCTA[1], + /// warpsPerCTA[0]] + /// - Encoding: `#ttg.blocked<{ + /// sizePerThread = [1, repeatCount * repCluster[0] / executionSize, 1, 1], + /// threadsPerWarp = [executionSize, 1, 1, 1], + /// warpsPerCTA = [1, 1, warpsPerCTA[1], warpsPerCTA[0]], + /// order = [0, 1, 2, 3]}>`. + /// ``` + /// warpsPerCTA[3] + /// <------------------------------------> + /// threadsPerWarp[0] + /// <------------------> + /// ^ t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... ^ + /// | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... | + /// sizePerThread[1] | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... | warpsPerCTA[2] + /// | t0 t1 t2 t3 ... tn tn1 tn2 tn3 ... | + /// ``` + /// + /// Note how the SIMD transposition leads to a different order compared to the + /// original approach above (and an additional dimension). Before going on to + /// the next step, we need to get to the same layout as above via a layout + /// conversion: + /// - Shape: [executionSize, + /// repeatCount * repCluster[0] / executionSize, + /// warpsPerCTA[1], + /// warpsPerCTA[0]] + /// - Encoding: `#ttg.blocked<{ + /// sizePerThread = [repeatCount * repCluster[0] / executionSize, 1, 1, 1], + /// threadsPerWarp = [executionSize^2 / (repeatCount * repCluster[0]), + /// repeatCount * repCluster[0] / executionSize, + /// 1, 1] + /// warpsPerCTA = [1, 1, warpsPerCTA[1], warpsPerCTA[0]], + /// order = [0, 1, 2, 3]}>`. + /// ``` + /// warpsPerCTA[3] + /// <------------------------------------> + /// sizePerThread[0] + /// <------------------> + /// ^ t0 t0 t0 t0 ... t0 tn1 tn1 tn1 ... tn1 ^ + /// | t1 t1 t1 t1 ... t1 tn2 tn2 tn2 ... tn2 | + /// threadsPerWarp[0,1] | t2 t2 t2 t2 ... t2 tn3 tn3 tn3 ... tn3 | warpsPerCTA[2] + /// | t3 t3 t3 t3 ... t3 tn4 tn4 tn4 ... tn4 | + /// ``` + /// + /// Followed by a reshape to get rid of one of the dimensions: + /// + /// - Shape: [repeatCount * repCluster[0], + /// warpsPerCTA[1], + /// warpsPerCTA[0]] + /// - Encoding: `#ttg.blocked<{ + /// sizePerThread = [repeatCount * repCluster[0] / executionSize, 1, 1], + /// threadsPerWarp = [executionSize, 1, 1] + /// warpsPerCTA = [1, warpsPerCTA[1], warpsPerCTA[0]], + /// order = [0, 1, 2]}>`. + /// ``` + /// warpsPerCTA[3] + /// <------------------------------------> + /// sizePerThread[0] + /// <------------------> + /// ^ t0 t0 t0 t0 ... t0 tn1 tn1 tn1 ... tn1 ^ + /// | t1 t1 t1 t1 ... t1 tn2 tn2 tn2 ... tn2 | + /// threadsPerWarp[0] | t2 t2 t2 t2 ... t2 tn3 tn3 tn3 ... tn3 | warpsPerCTA[2] + /// | t3 t3 t3 t3 ... t3 tn4 tn4 tn4 ... tn4 | + /// ``` + /// After this, the cross the warp reduction is performed and the same steps + /// to go back to the expected tensor type are performed as above. 
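As a worked instance of the shapes above (assuming repeatCount = 8, repCluster[0] = 2, executionSize = 16 and warpsPerCTA = [W0, W1], so that repeatCount * repCluster[0] / executionSize = 1), the three types become:

```
SIMD reduction result: shape 16 x 1 x W1 x W0,
                       sizePerThread = [1, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1],
                       warpsPerCTA = [1, 1, W1, W0]
converted layout:      shape 16 x 1 x W1 x W0,
                       sizePerThread = [1, 1, 1, 1], threadsPerWarp = [16, 1, 1, 1],
                       warpsPerCTA = [1, 1, W1, W0]
after the reshape:     shape 16 x W1 x W0,
                       sizePerThread = [1, 1, 1], threadsPerWarp = [16, 1, 1],
                       warpsPerCTA = [1, W1, W0]
```

For this particular DPAS configuration the intermediate layout conversion happens to be the identity; it only changes the encoding when repeatCount * repCluster[0] differs from executionSize.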
// clang-format on struct DpasOperandPattern final : OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -209,6 +287,7 @@ struct DpasOperandPattern final : OpRewritePattern { static constexpr int preferredReductionAxis = 1; // Intermediate reductions + static constexpr int simdReductionAxis = 0; static constexpr int finalElementwiseReductionAxis = 0; static constexpr int finalWarpsReductionAxis = 1; static constexpr int innerElementwiseReductionAxis = 2; @@ -263,20 +342,39 @@ struct DpasOperandPattern final : OpRewritePattern { LLVM_DEBUG(llvm::dbgs() << "Performed initial elementwise reductions: " << operand << "\n"); - operand = convertLayoutForFinalReduction(op, rewriter, operand, encoding); + // Some combinations of operation + type can be represented as special SIMD + // reductions. + if (supportsSIMDReduction(op)) { + operand = performSIMDReduction(op, rewriter, operand, encoding); - LLVM_DEBUG(llvm::dbgs() - << "Converted layout for final reduction: " << operand << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "SIMD reduction performed: " << operand << "\n"); - operand = reshapeForFinalReduction(op, rewriter, operand, encoding); + operand = convertLayoutPostSIMDReduction(op, rewriter, operand, encoding); - LLVM_DEBUG(llvm::dbgs() - << "Reshaped for final reduction: " << operand << "\n"); + LLVM_DEBUG(llvm::dbgs() << "Converted layout for final reduction: " + << operand << "\n"); - operand = performFinalElementwiseReduction(op, rewriter, operand); + operand = reshapePostSIMDReduction(op, rewriter, operand, encoding); - LLVM_DEBUG(llvm::dbgs() - << "Final elementwise reduction performed: " << operand << "\n"); + LLVM_DEBUG(llvm::dbgs() + << "Reshaped for final reduction: " << operand << "\n"); + } else { + operand = convertLayoutForFinalReduction(op, rewriter, operand, encoding); + + LLVM_DEBUG(llvm::dbgs() << "Converted layout for final reduction: " + << operand << "\n"); + + operand = reshapeForFinalReduction(op, rewriter, operand, encoding); + + LLVM_DEBUG(llvm::dbgs() + << "Reshaped for final reduction: " << operand << "\n"); + + operand = performFinalElementwiseReduction(op, rewriter, operand); + + LLVM_DEBUG(llvm::dbgs() << "Final elementwise reduction performed: " + << operand << "\n"); + } operand = performFinalAcrossWarpsReduction(op, rewriter, operand); @@ -377,6 +475,157 @@ struct DpasOperandPattern final : OpRewritePattern { outerElementwiseReductionAxis); } + static bool supportsSIMDReduction(ReduceOp op) { + constexpr unsigned simd16 = 16; + auto mod = op->getParentOfType(); + if (triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod) != simd16) + return false; + + // TODO: Enable more element types when enabled in the lowering. + if (!cast(op.getSrcs().front().getType()) + .getElementType() + .isF32()) + return false; + + Region &combineOp = op.getCombineOp(); + if (combineOp.getBlocks().size() > 1) + return false; + Block &block = *combineOp.begin(); + Operation *yield = block.getTerminator(); + Operation *reduceOp = yield->getOperand(0).getDefiningOp(); + if (!reduceOp || reduceOp->getNumOperands() != 2 || + reduceOp->getNumResults() != 1) + return false; + if (reduceOp->getOperand(0) != block.getArgument(0) || + reduceOp->getOperand(1) != block.getArgument(1)) + return false; + + // TODO: Enable more operations when more are enabled in the lowering. 
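    // Accept only combine regions whose single operation adds or takes the
    // maximum of the two block arguments, i.e. the arith.addf / arith.maxnumf
    // cases that getSIMDReductionKind below maps to TritonGEN::ReduceKind::ADD
    // and MAX, and that the LLVM lowering currently handles.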
+ return TypeSwitch(reduceOp) + .Case([](auto) { return true; }) + .Default(false); + } + + static TritonGEN::ReduceKind getSIMDReductionKind(ReduceOp op) { + Region &combineOp = op.getCombineOp(); + assert(combineOp.getBlocks().size() <= 1 && "Unexpected number of blocks"); + Block &block = *combineOp.begin(); + Operation *yield = block.getTerminator(); + Operation *reduceOp = yield->getOperand(0).getDefiningOp(); + assert(reduceOp && reduceOp->getNumOperands() == 2 && + reduceOp->getNumResults() == 1 && + "Expecting sub-group reduction-like operation"); + assert(reduceOp->getOperand(0) == block.getArgument(0) && + reduceOp->getOperand(1) == block.getArgument(1) && + "Expecting sub-group reduction-like operation"); + + // TODO: Enable more operations when more are enabled in the lowering. + return TypeSwitch(reduceOp) + .Case([](arith::AddFOp) { return TritonGEN::ReduceKind::ADD; }) + .Case([](arith::MaxNumFOp) { return TritonGEN::ReduceKind::MAX; }); + } + + Value performSIMDReduction(ReduceOp op, PatternRewriter &rewriter, Value val, + DpasEncodingAttr dpasEncoding) const { + auto oldType = cast(val.getType()); + + constexpr size_t rank = 4; + std::array shape{ + dpasEncoding.getExecutionSize(), + dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] / + dpasEncoding.getExecutionSize(), + dpasEncoding.getWarpsPerCTA()[1], dpasEncoding.getWarpsPerCTA()[0]}; + std::array sizePerThread{ + 1, + dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] / + dpasEncoding.getExecutionSize(), + 1, 1}; + std::array threadsPerWarp{dpasEncoding.getExecutionSize(), + 1, 1, 1}; + std::array warpsPerCTA{1, 1, + dpasEncoding.getWarpsPerCTA()[1], + dpasEncoding.getWarpsPerCTA()[0]}; + constexpr std::array order{0, 1, 2, 3}; + CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank); + + auto encoding = rewriter.getAttr( + sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout); + + RankedTensorType::Builder type(oldType); + type.setShape(shape); + type.setEncoding(encoding); + + TritonGEN::ReduceKind redOp = getSIMDReductionKind(op); + return rewriter.create(op.getLoc(), + static_cast(type), + val, redOp, simdReductionAxis); + } + + Value convertLayoutPostSIMDReduction(ReduceOp op, PatternRewriter &rewriter, + Value val, + DpasEncodingAttr dpasEncoding) const { + auto oldType = cast(val.getType()); + + constexpr size_t rank = 4; + std::array sizePerThread{ + dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] / + dpasEncoding.getExecutionSize(), + 1, 1, 1}; + std::array threadsPerWarp{ + dpasEncoding.getExecutionSize() * dpasEncoding.getExecutionSize() / + (dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0]), + dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] / + dpasEncoding.getExecutionSize(), + 1, 1}; + std::array warpsPerCTA{1, 1, + dpasEncoding.getWarpsPerCTA()[1], + dpasEncoding.getWarpsPerCTA()[0]}; + constexpr std::array order{0, 1, 2, 3}; + CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank); + + auto encoding = rewriter.getAttr( + sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout); + + RankedTensorType::Builder type(oldType); + type.setEncoding(encoding); + + return rewriter.create( + op.getLoc(), static_cast(type), val); + } + + Value reshapePostSIMDReduction(ReduceOp op, PatternRewriter &rewriter, + Value val, + DpasEncodingAttr dpasEncoding) const { + auto oldType = cast(val.getType()); + + constexpr size_t rank = 3; + std::array shape{ + dpasEncoding.getRepeatCount() * 
dpasEncoding.getRepCluster()[0], + dpasEncoding.getWarpsPerCTA()[1], dpasEncoding.getWarpsPerCTA()[0]}; + std::array sizePerThread{ + dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] / + dpasEncoding.getExecutionSize(), + 1, 1}; + std::array threadsPerWarp{dpasEncoding.getExecutionSize(), + 1, 1}; + std::array warpsPerCTA{1, dpasEncoding.getWarpsPerCTA()[1], + dpasEncoding.getWarpsPerCTA()[0]}; + constexpr std::array order{0, 1, 2}; + CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank); + + auto encoding = rewriter.getAttr( + sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout); + + RankedTensorType::Builder type(oldType); + type.setShape(shape); + type.setEncoding(encoding); + + return rewriter.create(op.getLoc(), + static_cast(type), val, + /*allow_reorder=*/true, + /*efficient_layout=*/true); + } + Value convertLayoutForFinalReduction(ReduceOp op, PatternRewriter &rewriter, Value val, DpasEncodingAttr dpasEncoding) const { @@ -411,7 +660,6 @@ struct DpasOperandPattern final : OpRewritePattern { Value val, DpasEncodingAttr dpasEncoding) const { auto oldType = cast(val.getType()); - ArrayRef oldShape = oldType.getShape(); constexpr size_t rank = 4; std::array shape{