[XPU][TritonGEN] Replace subgroup reduce and scan (#2893)
Replace usage of the TritonGEN dialect subgroup reduce and scan
operations with equivalent operations from the SPIR-V dialect.

This closes #2892.

---------

Signed-off-by: Lukas Sommer <lukas.sommer@codeplay.com>
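
For context, a minimal sketch of what the change means at the LLVM dialect level (illustrative only, not part of the commit): the TritonGEN triton_gen.sub_group_reduce op previously lowered to the OpenCL sub_group_non_uniform_reduce_* and sub_group_clustered_reduce_* built-ins, whereas the SPIR-V dialect path calls the __spirv_GroupNonUniform* built-ins with explicit scope and group-operation arguments. The mangled name and call shape below are taken from the updated tests; the scope and group-operation constants (Scope::Subgroup = 3, GroupOperation::Reduce = 0) follow the SPIR-V spec and are an assumption about the exact values the lowering emits.

// Standalone LLVM-dialect sketch of the SPIR-V-based subgroup add reduction.
llvm.func spir_funccc @_Z27__spirv_GroupNonUniformIAddiij(i32, i32, i32) -> i32

llvm.func @subgroup_sum(%x: i32) -> i32 {
  // Scope::Subgroup = 3, GroupOperation::Reduce = 0 (assumed SPIR-V enum encodings).
  %scope = llvm.mlir.constant(3 : i32) : i32
  %gop = llvm.mlir.constant(0 : i32) : i32
  // Replaces the former OpenCL-style call:
  //   llvm.call spir_funccc @_Z32sub_group_non_uniform_reduce_addi(%x) : (i32) -> i32
  %r = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformIAddiij(%scope, %gop, %x) : (i32, i32, i32) -> i32
  llvm.return %r : i32
}
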
sommerlukas authored Dec 3, 2024
1 parent 78c13a5 commit 6588f0d
Showing 14 changed files with 175 additions and 592 deletions.
40 changes: 20 additions & 20 deletions test/Conversion/intel/tritongpu_to_gen.mlir
@@ -1500,63 +1500,63 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "triton_
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func public @reduce_all(%arg: tensor<256x1xi32, #blocked>, %arg_0: tensor<256x1xf32, #blocked>) {

// CHECK: @_Z32sub_group_non_uniform_reduce_addf
// CHECK: @_Z27__spirv_GroupNonUniformFAddiif
%0 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
^bb0(%arg4: f32, %arg5: f32):
%48 = arith.addf %arg4, %arg5 : f32
tt.reduce.return %48 : f32
}) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>

// CHECK: @_Z32sub_group_non_uniform_reduce_addi
// CHECK: @_Z27__spirv_GroupNonUniformIAddiij
%1 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
^bb0(%arg4: i32, %arg5: i32):
%48 = arith.addi %arg4, %arg5 : i32
tt.reduce.return %48 : i32
}) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>

// CHECK: @_Z32sub_group_non_uniform_reduce_mulf
// CHECK: @_Z27__spirv_GroupNonUniformFMuliif
%2 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
^bb0(%arg4: f32, %arg5: f32):
%48 = arith.mulf %arg4, %arg5 : f32
tt.reduce.return %48 : f32
}) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>

// CHECK: @_Z32sub_group_non_uniform_reduce_muli
// CHECK: @_Z27__spirv_GroupNonUniformIMuliij
%3 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
^bb0(%arg4: i32, %arg5: i32):
%48 = arith.muli %arg4, %arg5 : i32
tt.reduce.return %48 : i32
}) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>

// CHECK: @_Z32sub_group_non_uniform_reduce_maxf
// CHECK: @_Z27__spirv_GroupNonUniformFMaxiif
%4 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
^bb0(%arg4: f32, %arg5: f32):
%48 = arith.maxnumf %arg4, %arg5 : f32
tt.reduce.return %48 : f32
}) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>

// CHECK: @_Z32sub_group_non_uniform_reduce_minf
// CHECK: @_Z27__spirv_GroupNonUniformFMiniif
%5 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
^bb0(%arg4: f32, %arg5: f32):
%48 = arith.minnumf %arg4, %arg5 : f32
tt.reduce.return %48 : f32
}) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>

// CHECK: @_Z32sub_group_non_uniform_reduce_andi
// CHECK: @_Z33__spirv_GroupNonUniformBitwiseAndiij
%6 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
^bb0(%arg4: i32, %arg5: i32):
%48 = arith.andi %arg4, %arg5 : i32
tt.reduce.return %48 : i32
}) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>

// CHECK: @_Z31sub_group_non_uniform_reduce_ori
// CHECK: @_Z32__spirv_GroupNonUniformBitwiseOriij
%7 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
^bb0(%arg4: i32, %arg5: i32):
%48 = arith.ori %arg4, %arg5 : i32
tt.reduce.return %48 : i32
}) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>

// CHECK: @_Z32sub_group_non_uniform_reduce_xori
// CHECK: @_Z33__spirv_GroupNonUniformBitwiseXoriij
%8 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
^bb0(%arg4: i32, %arg5: i32):
%48 = arith.xori %arg4, %arg5 : i32
@@ -1575,63 +1575,63 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
tt.func public @reduce_cluster(%arg: tensor<256x1xi32, #blocked>, %arg_0: tensor<256x1xf32, #blocked>) {

// CHECK: @_Z30sub_group_clustered_reduce_addfj
// CHECK: @_Z27__spirv_GroupNonUniformFAddiif
%0 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
^bb0(%arg4: f32, %arg5: f32):
%48 = arith.addf %arg4, %arg5 : f32
tt.reduce.return %48 : f32
}) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>

// CHECK: @_Z30sub_group_clustered_reduce_addij
// CHECK: @_Z27__spirv_GroupNonUniformIAddiij
%1 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
^bb0(%arg4: i32, %arg5: i32):
%48 = arith.addi %arg4, %arg5 : i32
tt.reduce.return %48 : i32
}) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>

// CHECK: @_Z30sub_group_clustered_reduce_mulfj
// CHECK: @_Z27__spirv_GroupNonUniformFMuliif
%2 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
^bb0(%arg4: f32, %arg5: f32):
%48 = arith.mulf %arg4, %arg5 : f32
tt.reduce.return %48 : f32
}) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>

// CHECK: @_Z30sub_group_clustered_reduce_mulij
// CHECK: @_Z27__spirv_GroupNonUniformIMuliij
%3 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
^bb0(%arg4: i32, %arg5: i32):
%48 = arith.muli %arg4, %arg5 : i32
tt.reduce.return %48 : i32
}) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>

// CHECK: @_Z30sub_group_clustered_reduce_maxfj
// CHECK: @_Z27__spirv_GroupNonUniformFMaxiif
%4 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
^bb0(%arg4: f32, %arg5: f32):
%48 = arith.maxnumf %arg4, %arg5 : f32
tt.reduce.return %48 : f32
}) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>

// CHECK: @_Z30sub_group_clustered_reduce_minfj
// CHECK: @_Z27__spirv_GroupNonUniformFMiniif
%5 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
^bb0(%arg4: f32, %arg5: f32):
%48 = arith.minnumf %arg4, %arg5 : f32
tt.reduce.return %48 : f32
}) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>

// CHECK: @_Z30sub_group_clustered_reduce_andij
// CHECK: @_Z33__spirv_GroupNonUniformBitwiseAndiij
%6 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
^bb0(%arg4: i32, %arg5: i32):
%48 = arith.andi %arg4, %arg5 : i32
tt.reduce.return %48 : i32
}) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>

// CHECK: @_Z29sub_group_clustered_reduce_orij
// CHECK: @_Z32__spirv_GroupNonUniformBitwiseOriij
%7 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
^bb0(%arg4: i32, %arg5: i32):
%48 = arith.ori %arg4, %arg5 : i32
tt.reduce.return %48 : i32
}) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>

// CHECK: @_Z30sub_group_clustered_reduce_xorij
// CHECK: @_Z33__spirv_GroupNonUniformBitwiseXoriij
%8 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
^bb0(%arg4: i32, %arg5: i32):
%48 = arith.xori %arg4, %arg5 : i32
@@ -1645,9 +1645,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
// -----

// CHECK-LABEL: sum_reduction
// CHECK: llvm.call spir_funccc @_Z32sub_group_non_uniform_reduce_addi(%{{.*}}) {{.*}} : (i32) -> i32
// CHECK: llvm.call spir_funccc @_Z27__spirv_GroupNonUniformIAddiij(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, i32, i32) -> i32
// CHECK: llvm.call spir_funccc @_Z7barrierj({{.*}}) {{.*}} : (i32) -> ()
// CHECK: llvm.call spir_funccc @_Z30sub_group_clustered_reduce_addij(%{{.*}}, %{{.*}}) {{.*}}convergent{{.*}}no_unwind{{.*}}will_return{{.*}} : (i32, i32) -> i32
// CHECK: llvm.call spir_funccc @_Z27__spirv_GroupNonUniformIAddiijj(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, i32, i32, i32) -> i32

// CHECK: llvm.call spir_funccc @_Z7barrierj({{.*}}) {{.*}} : (i32) -> ()
#blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
@@ -153,15 +153,12 @@ module attributes {"triton_intel_gpu.support_sg_2d_block", "triton_intel_gpu.sup
#warp = #triton_intel_gpu.warp<{sizePerThread = [16, 64], threadsPerWarp = [1, 1], order = [1, 0]}>
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {

// CHECK-DAG: llvm.func spir_funccc @_Z32sub_group_non_uniform_reduce_maxf(f32) -> f32
// CHECK-DAG: llvm.func spir_funccc @_Z32sub_group_non_uniform_reduce_addf(f32) -> f32

// CHECK-LABEL: llvm.func spir_kernelcc @reduce_sum(
// CHECK-SAME: [[VAL_0:%.*]]: vector<8xf32>) -> f32 attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array<i32: 128, 1, 1>}
tt.func public @reduce_sum(%arg0: tensor<8x16xf32>) -> f32 {
// CHECK: [[VAL_1:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[VAL_2:%.*]] = llvm.extractelement [[VAL_0]][[[VAL_1]] : i32] : vector<8xf32>
// CHECK: [[VAL_3:%.*]] = llvm.call spir_funccc @_Z32sub_group_non_uniform_reduce_addf([[VAL_2]]) {{.*}} : (f32) -> f32
// CHECK: [[VAL_3:%.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformFAddiif(%{{.*}}, %{{.*}}, [[VAL_2]]) {{.*}} : (i32, i32, f32) -> f32
%0 = triton_intel_gpu.extract %arg0[0] : tensor<8x16xf32> -> tensor<16xf32>
%1 = "tt.reduce"(%0) <{axis = 0 : i32}> ({
^bb0(%arg1: f32, %arg2: f32):
@@ -176,7 +173,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr
tt.func public @reduce_max(%arg0: tensor<8x16xf32>) -> f32 {
// CHECK: [[VAL_1:%.*]] = llvm.mlir.constant(0 : i32) : i32
// CHECK: [[VAL_2:%.*]] = llvm.extractelement [[VAL_0]][[[VAL_1]] : i32] : vector<8xf32>
// CHECK: [[VAL_3:%.*]] = llvm.call spir_funccc @_Z32sub_group_non_uniform_reduce_maxf([[VAL_2]]) {{.*}} : (f32) -> f32
// CHECK: [[VAL_3:%.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformFMaxiif(%{{.*}}, %{{.*}}, [[VAL_2]]) {{.*}} : (i32, i32, f32) -> f32
%0 = triton_intel_gpu.extract %arg0[0] : tensor<8x16xf32> -> tensor<16xf32>
%1 = "tt.reduce"(%0) <{axis = 0 : i32}> ({
^bb0(%arg1: f32, %arg2: f32):
40 changes: 0 additions & 40 deletions test/TritonGEN/gpu-to-tritongen.mlir

This file was deleted.

48 changes: 0 additions & 48 deletions test/TritonGEN/tritongen-invalid.mlir
@@ -16,54 +16,6 @@ llvm.func @triton_gen.illegal_cache_controls_attr(%arg0: !llvm.ptr) {

// -----

llvm.func @triton_gen.sub_group_reduce() {
// expected-error @+2 {{'triton_gen.sub_group_reduce' op expecting valid target env attribute}}
%0 = llvm.mlir.constant(0 : i32) : i32
%1 = triton_gen.sub_group_reduce add %0 {size = 16} : i32
llvm.return
}

// -----

module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Kernel, Addresses, GroupNonUniformShuffle, Int64], []>, #spirv.resource_limits<subgroup_size = 16>>
} {
llvm.func @triton_gen.sub_group_reduce() {
// expected-error @+2 {{'triton_gen.sub_group_reduce' op expecting size to be a power of 2 between 1 and subgroup size}}
%0 = llvm.mlir.constant(0 : i32) : i32
%1 = triton_gen.sub_group_reduce add %0 {size = 0} : i32
llvm.return
}
}

// -----

module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Kernel, Addresses, GroupNonUniformShuffle, Int64], []>, #spirv.resource_limits<subgroup_size = 16>>
} {
llvm.func @triton_gen.sub_group_reduce() {
// expected-error @+2 {{'triton_gen.sub_group_reduce' op expecting size to be a power of 2 between 1 and subgroup size}}
%0 = llvm.mlir.constant(0 : i32) : i32
%1 = triton_gen.sub_group_reduce add %0 {size = 32} : i32
llvm.return
}
}

// -----

module attributes {
spirv.target_env = #spirv.target_env<#spirv.vce<v1.4, [Kernel, Addresses, GroupNonUniformShuffle, Int64], []>, #spirv.resource_limits<subgroup_size = 16>>
} {
llvm.func @triton_gen.sub_group_reduce() {
// expected-error @+2 {{'triton_gen.sub_group_reduce' op expecting size to be a power of 2 between 1 and subgroup size}}
%0 = llvm.mlir.constant(0 : i32) : i32
%1 = triton_gen.sub_group_reduce add %0 {size = 6} : i32
llvm.return
}
}

// -----

llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8xi32>) {
// expected-error @+1 {{'triton_gen.dpas' op expecting repeat count to be 1, 2, 4, or 8}}
%0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=16} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32>