diff --git a/test/Conversion/intel/tritongpu_to_gen.mlir b/test/Conversion/intel/tritongpu_to_gen.mlir
index 1a529484d9..8f53bc2b41 100644
--- a/test/Conversion/intel/tritongpu_to_gen.mlir
+++ b/test/Conversion/intel/tritongpu_to_gen.mlir
@@ -1500,63 +1500,63 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "triton_
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @reduce_all(%arg: tensor<256x1xi32, #blocked>, %arg_0: tensor<256x1xf32, #blocked>) {
-    // CHECK: @_Z32sub_group_non_uniform_reduce_addf
+    // CHECK: @_Z27__spirv_GroupNonUniformFAddiif
     %0 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
     ^bb0(%arg4: f32, %arg5: f32):
       %48 = arith.addf %arg4, %arg5 : f32
       tt.reduce.return %48 : f32
     }) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>
-    // CHECK: @_Z32sub_group_non_uniform_reduce_addi
+    // CHECK: @_Z27__spirv_GroupNonUniformIAddiij
     %1 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
     ^bb0(%arg4: i32, %arg5: i32):
       %48 = arith.addi %arg4, %arg5 : i32
       tt.reduce.return %48 : i32
     }) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>
-    // CHECK: @_Z32sub_group_non_uniform_reduce_mulf
+    // CHECK: @_Z27__spirv_GroupNonUniformFMuliif
     %2 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
     ^bb0(%arg4: f32, %arg5: f32):
      %48 = arith.mulf %arg4, %arg5 : f32
      tt.reduce.return %48 : f32
    }) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>
-    // CHECK: @_Z32sub_group_non_uniform_reduce_muli
+    // CHECK: @_Z27__spirv_GroupNonUniformIMuliij
     %3 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
     ^bb0(%arg4: i32, %arg5: i32):
       %48 = arith.muli %arg4, %arg5 : i32
       tt.reduce.return %48 : i32
     }) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>
-    // CHECK: @_Z32sub_group_non_uniform_reduce_maxf
+    // CHECK: @_Z27__spirv_GroupNonUniformFMaxiif
     %4 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
     ^bb0(%arg4: f32, %arg5: f32):
       %48 = arith.maxnumf %arg4, %arg5 : f32
       tt.reduce.return %48 : f32
     }) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>
-    // CHECK: @_Z32sub_group_non_uniform_reduce_minf
+    // CHECK: @_Z27__spirv_GroupNonUniformFMiniif
     %5 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
     ^bb0(%arg4: f32, %arg5: f32):
       %48 = arith.minnumf %arg4, %arg5 : f32
       tt.reduce.return %48 : f32
     }) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>
-    // CHECK: @_Z32sub_group_non_uniform_reduce_andi
+    // CHECK: @_Z33__spirv_GroupNonUniformBitwiseAndiij
     %6 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
     ^bb0(%arg4: i32, %arg5: i32):
       %48 = arith.andi %arg4, %arg5 : i32
       tt.reduce.return %48 : i32
     }) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>
-    // CHECK: @_Z31sub_group_non_uniform_reduce_ori
+    // CHECK: @_Z32__spirv_GroupNonUniformBitwiseOriij
     %7 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
     ^bb0(%arg4: i32, %arg5: i32):
       %48 = arith.ori %arg4, %arg5 : i32
       tt.reduce.return %48 : i32
     }) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>
-    // CHECK: @_Z32sub_group_non_uniform_reduce_xori
+    // CHECK: @_Z33__spirv_GroupNonUniformBitwiseXoriij
     %8 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
     ^bb0(%arg4: i32, %arg5: i32):
       %48 = arith.xori %arg4, %arg5 : i32
@@ -1575,63 +1575,63 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.threads-per-warp" = 32 : i32} {
   tt.func public @reduce_cluster(%arg: tensor<256x1xi32, #blocked>, %arg_0: tensor<256x1xf32, #blocked>) {
-    // CHECK: @_Z30sub_group_clustered_reduce_addfj
+    // CHECK: @_Z27__spirv_GroupNonUniformFAddiif
     %0 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
     ^bb0(%arg4: f32, %arg5: f32):
       %48 = arith.addf %arg4, %arg5 : f32
       tt.reduce.return %48 : f32
     }) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>
-    // CHECK: @_Z30sub_group_clustered_reduce_addij
+    // CHECK: @_Z27__spirv_GroupNonUniformIAddiij
     %1 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
     ^bb0(%arg4: i32, %arg5: i32):
       %48 = arith.addi %arg4, %arg5 : i32
       tt.reduce.return %48 : i32
     }) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>
-    // CHECK: @_Z30sub_group_clustered_reduce_mulfj
+    // CHECK: @_Z27__spirv_GroupNonUniformFMuliif
     %2 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
     ^bb0(%arg4: f32, %arg5: f32):
       %48 = arith.mulf %arg4, %arg5 : f32
       tt.reduce.return %48 : f32
     }) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>
-    // CHECK: @_Z30sub_group_clustered_reduce_mulij
+    // CHECK: @_Z27__spirv_GroupNonUniformIMuliij
     %3 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
     ^bb0(%arg4: i32, %arg5: i32):
       %48 = arith.muli %arg4, %arg5 : i32
       tt.reduce.return %48 : i32
     }) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>
-    // CHECK: @_Z30sub_group_clustered_reduce_maxfj
+    // CHECK: @_Z27__spirv_GroupNonUniformFMaxiif
     %4 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
     ^bb0(%arg4: f32, %arg5: f32):
       %48 = arith.maxnumf %arg4, %arg5 : f32
       tt.reduce.return %48 : f32
     }) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>
-    // CHECK: @_Z30sub_group_clustered_reduce_minfj
+    // CHECK: @_Z27__spirv_GroupNonUniformFMiniif
     %5 = "tt.reduce"(%arg_0) <{axis = 0 : i32}> ({
     ^bb0(%arg4: f32, %arg5: f32):
       %48 = arith.minnumf %arg4, %arg5 : f32
       tt.reduce.return %48 : f32
     }) : (tensor<256x1xf32, #blocked>) -> tensor<1xf32, #slice>
-    // CHECK: @_Z30sub_group_clustered_reduce_andij
+    // CHECK: @_Z33__spirv_GroupNonUniformBitwiseAndiij
     %6 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
     ^bb0(%arg4: i32, %arg5: i32):
       %48 = arith.andi %arg4, %arg5 : i32
       tt.reduce.return %48 : i32
     }) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>
-    // CHECK: @_Z29sub_group_clustered_reduce_orij
+    // CHECK: @_Z32__spirv_GroupNonUniformBitwiseOriij
     %7 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
     ^bb0(%arg4: i32, %arg5: i32):
       %48 = arith.ori %arg4, %arg5 : i32
       tt.reduce.return %48 : i32
     }) : (tensor<256x1xi32, #blocked>) -> tensor<1xi32, #slice>
-    // CHECK: @_Z30sub_group_clustered_reduce_xorij
+    // CHECK: @_Z33__spirv_GroupNonUniformBitwiseXoriij
     %8 = "tt.reduce"(%arg) <{axis = 0 : i32}> ({
     ^bb0(%arg4: i32, %arg5: i32):
       %48 = arith.xori %arg4, %arg5 : i32
@@ -1645,9 +1645,9 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32, "ttg.thr
 // -----
 
 // CHECK-LABEL: sum_reduction
-// CHECK: llvm.call spir_funccc @_Z32sub_group_non_uniform_reduce_addi(%{{.*}}) {{.*}} : (i32) -> i32
+// CHECK: llvm.call spir_funccc @_Z27__spirv_GroupNonUniformIAddiij(%{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, i32, i32) -> i32
 // CHECK: llvm.call spir_funccc @_Z7barrierj({{.*}}) {{.*}} : (i32) -> ()
-// CHECK: llvm.call spir_funccc @_Z30sub_group_clustered_reduce_addij(%{{.*}}, %{{.*}}) {{.*}}convergent{{.*}}no_unwind{{.*}}will_return{{.*}} : (i32, i32) -> i32
+// CHECK: llvm.call spir_funccc @_Z27__spirv_GroupNonUniformIAddiijj(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) {{.*}} : (i32, i32, i32, i32) -> i32
 // CHECK: llvm.call spir_funccc @_Z7barrierj({{.*}}) {{.*}} : (i32) -> ()
 
 #blocked = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [1, 32], warpsPerCTA = [1, 4], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}>
diff --git a/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir b/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir
index 9155698be5..b530fe47ba 100644
--- a/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir
+++ b/test/Conversion/intel/tritongpu_to_llvm_intel_advanced_path.mlir
@@ -153,15 +153,12 @@ module attributes {"triton_intel_gpu.support_sg_2d_block", "triton_intel_gpu.sup
 
 #warp = #triton_intel_gpu.warp<{sizePerThread = [16, 64], threadsPerWarp = [1, 1], order = [1, 0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 16 : i32, triton_intel_gpu.min_sg_size = 16 : i32, triton_intel_gpu.support_dpas, triton_intel_gpu.support_sg_2d_block} {
-  // CHECK-DAG: llvm.func spir_funccc @_Z32sub_group_non_uniform_reduce_maxf(f32) -> f32
-  // CHECK-DAG: llvm.func spir_funccc @_Z32sub_group_non_uniform_reduce_addf(f32) -> f32
-
   // CHECK-LABEL: llvm.func spir_kernelcc @reduce_sum(
   // CHECK-SAME: [[VAL_0:%.*]]: vector<8xf32>) -> f32 attributes {intel_reqd_sub_group_size = 16 : i32, triton_gen.max_work_group_size = array<i32: 128, 1, 1>}
   tt.func public @reduce_sum(%arg0: tensor<8x16xf32>) -> f32 {
     // CHECK: [[VAL_1:%.*]] = llvm.mlir.constant(0 : i32) : i32
     // CHECK: [[VAL_2:%.*]] = llvm.extractelement [[VAL_0]][[[VAL_1]] : i32] : vector<8xf32>
-    // CHECK: [[VAL_3:%.*]] = llvm.call spir_funccc @_Z32sub_group_non_uniform_reduce_addf([[VAL_2]]) {{.*}} : (f32) -> f32
+    // CHECK: [[VAL_3:%.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformFAddiif(%{{.*}}, %{{.*}}, [[VAL_2]]) {{.*}} : (i32, i32, f32) -> f32
     %0 = triton_intel_gpu.extract %arg0[0] : tensor<8x16xf32> -> tensor<16xf32>
     %1 = "tt.reduce"(%0) <{axis = 0 : i32}> ({
     ^bb0(%arg1: f32, %arg2: f32):
@@ -176,7 +173,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.thr
   tt.func public @reduce_max(%arg0: tensor<8x16xf32>) -> f32 {
     // CHECK: [[VAL_1:%.*]] = llvm.mlir.constant(0 : i32) : i32
     // CHECK: [[VAL_2:%.*]] = llvm.extractelement [[VAL_0]][[[VAL_1]] : i32] : vector<8xf32>
-    // CHECK: [[VAL_3:%.*]] = llvm.call spir_funccc @_Z32sub_group_non_uniform_reduce_maxf([[VAL_2]]) {{.*}} : (f32) -> f32
+    // CHECK: [[VAL_3:%.*]] = llvm.call spir_funccc @_Z27__spirv_GroupNonUniformFMaxiif(%{{.*}}, %{{.*}}, [[VAL_2]]) {{.*}} : (i32, i32, f32) -> f32
     %0 = triton_intel_gpu.extract %arg0[0] : tensor<8x16xf32> -> tensor<16xf32>
     %1 = "tt.reduce"(%0) <{axis = 0 : i32}> ({
     ^bb0(%arg1: f32, %arg2: f32):
diff --git a/test/TritonGEN/gpu-to-tritongen.mlir b/test/TritonGEN/gpu-to-tritongen.mlir
deleted file mode 100644
index 7945274763..0000000000
--- a/test/TritonGEN/gpu-to-tritongen.mlir
+++ /dev/null
@@ -1,40 +0,0 @@
-// RUN: triton-opt -convert-gpu-to-tritongen %s | FileCheck %s
-
-module attributes {
-  spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits>
-} {
-
-gpu.module @kernels {
-  llvm.func @triton_gen.sub_group_reduce() {
-    %0 = llvm.mlir.constant(0.0 : f32) : f32
-    %1 = llvm.mlir.constant(0 : i32) : i32
-    // CHECK: triton_gen.sub_group_reduce add %0 {size = 16} : f32
-    %2 = gpu.subgroup_reduce add %0 : (f32) -> (f32)
-    // CHECK: triton_gen.sub_group_reduce mul %0 {size = 16} : f32
-    %3 = gpu.subgroup_reduce mul %0 : (f32) -> (f32)
-    // CHECK: triton_gen.sub_group_reduce min %1 {size = 16} : i32
-    %4 = gpu.subgroup_reduce minui %1 : (i32) -> (i32)
-    // CHECK: triton_gen.sub_group_reduce min %1 {size = 16} : i32
-    %5 = gpu.subgroup_reduce minsi %1 : (i32) -> (i32)
-    // CHECK: triton_gen.sub_group_reduce min %0 {size = 16} : f32
-    %6 = gpu.subgroup_reduce minimumf %0 : (f32) -> (f32)
-    // CHECK: triton_gen.sub_group_reduce min %0 {size = 16} : f32
-    %7 = gpu.subgroup_reduce minnumf %0 : (f32) -> (f32)
-    // CHECK: triton_gen.sub_group_reduce max %1 {size = 16} : i32
-    %8 = gpu.subgroup_reduce maxui %1 : (i32) -> (i32)
-    // CHECK: triton_gen.sub_group_reduce max %1 {size = 16} : i32
-    %9 = gpu.subgroup_reduce maxsi %1 : (i32) -> (i32)
-    // CHECK: triton_gen.sub_group_reduce max %0 {size = 16} : f32
-    %10 = gpu.subgroup_reduce maximumf %0 : (f32) -> (f32)
-    // CHECK: triton_gen.sub_group_reduce max %0 {size = 16} : f32
-    %11 = gpu.subgroup_reduce maxnumf %0 : (f32) -> (f32)
-    // CHECK: triton_gen.sub_group_reduce and %1 {size = 16} : i32
-    %12 = gpu.subgroup_reduce and %1 : (i32) -> (i32)
-    // CHECK: triton_gen.sub_group_reduce or %1 {size = 16} : i32
-    %13 = gpu.subgroup_reduce or %1 : (i32) -> (i32)
-    // CHECK: triton_gen.sub_group_reduce xor %1 {size = 16} : i32
-    %14 = gpu.subgroup_reduce xor %1 : (i32) -> (i32)
-    llvm.return
-  }
-}
-}
diff --git a/test/TritonGEN/tritongen-invalid.mlir b/test/TritonGEN/tritongen-invalid.mlir
index bebd25643e..8dd751f008 100644
--- a/test/TritonGEN/tritongen-invalid.mlir
+++ b/test/TritonGEN/tritongen-invalid.mlir
@@ -16,54 +16,6 @@ llvm.func @triton_gen.illegal_cache_controls_attr(%arg0: !llvm.ptr) {
 
 // -----
 
-llvm.func @triton_gen.sub_group_reduce() {
-  // expected-error @+2 {{'triton_gen.sub_group_reduce' op expecting valid target env attribute}}
-  %0 = llvm.mlir.constant(0 : i32) : i32
-  %1 = triton_gen.sub_group_reduce add %0 {size = 16} : i32
-  llvm.return
-}
-
-// -----
-
-module attributes {
-  spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits>
-} {
-  llvm.func @triton_gen.sub_group_reduce() {
-    // expected-error @+2 {{'triton_gen.sub_group_reduce' op expecting size to be a power of 2 between 1 and subgroup size}}
-    %0 = llvm.mlir.constant(0 : i32) : i32
-    %1 = triton_gen.sub_group_reduce add %0 {size = 0} : i32
-    llvm.return
-  }
-}
-
-// -----
-
-module attributes {
-  spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits>
-} {
-  llvm.func @triton_gen.sub_group_reduce() {
-    // expected-error @+2 {{'triton_gen.sub_group_reduce' op expecting size to be a power of 2 between 1 and subgroup size}}
-    %0 = llvm.mlir.constant(0 : i32) : i32
-    %1 = triton_gen.sub_group_reduce add %0 {size = 32} : i32
-    llvm.return
-  }
-}
-
-// -----
-
-module attributes {
-  spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits>
-} {
-  llvm.func @triton_gen.sub_group_reduce() {
-    // expected-error @+2 {{'triton_gen.sub_group_reduce' op expecting size to be a power of 2 between 1 and subgroup size}}
-    %0 = llvm.mlir.constant(0 : i32) : i32
-    %1 = triton_gen.sub_group_reduce add %0 {size = 6} : i32
-    llvm.return
-  }
-}
-
-// -----
-
 llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8xi32>) {
   // expected-error @+1 {{'triton_gen.dpas' op expecting repeat count to be 1, 2, 4, or 8}}
   %0 = triton_gen.dpas %c, %a, %b {pa=i8, pb=i8, rc=16} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32>
diff --git a/test/TritonGEN/tritongen-to-llvm.mlir b/test/TritonGEN/tritongen-to-llvm.mlir
index 53cfdbf249..bcdcf626b2 100644
--- a/test/TritonGEN/tritongen-to-llvm.mlir
+++ b/test/TritonGEN/tritongen-to-llvm.mlir
@@ -1,138 +1,5 @@
 // RUN: triton-opt -convert-tritongen-to-llvm -split-input-file %s | FileCheck %s
 
-// CHECK-DAG: llvm.func spir_funccc @_Z30sub_group_clustered_reduce_addij(i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
-// CHECK-DAG: llvm.func spir_funccc @_Z30sub_group_clustered_reduce_mulij(i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
-// CHECK-DAG: llvm.func spir_funccc @_Z30sub_group_clustered_reduce_maxij(i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
-// CHECK-DAG: llvm.func spir_funccc @_Z30sub_group_clustered_reduce_minij(i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
-// CHECK-DAG: llvm.func spir_funccc @_Z30sub_group_clustered_reduce_andij(i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
-// CHECK-DAG: llvm.func spir_funccc @_Z29sub_group_clustered_reduce_orij(i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
-// CHECK-DAG: llvm.func spir_funccc @_Z30sub_group_clustered_reduce_xorij(i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
-
-module attributes {
-  spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits>
-} {
-  llvm.func @triton_gen.sub_group_reduce() {
-    %0 = llvm.mlir.constant(0 : i32) : i32
-    // CHECK: [[VAL:%.*]] = llvm.mlir.constant(0 : i32) : i32
-    // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
-    // CHECK: llvm.call spir_funccc @_Z30sub_group_clustered_reduce_addij([[VAL]], [[SIZE]]) {{.*}} : (i32, i32) -> i32
-    %1 = triton_gen.sub_group_reduce add %0 {size = 16} : i32
-    // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
-    // CHECK: llvm.call spir_funccc @_Z30sub_group_clustered_reduce_mulij([[VAL]], [[SIZE]]) {{.*}} : (i32, i32) -> i32
-    %2 = triton_gen.sub_group_reduce mul %0 {size = 16} : i32
-    // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
-    // CHECK: llvm.call spir_funccc @_Z30sub_group_clustered_reduce_minij([[VAL]], [[SIZE]]) {{.*}} : (i32, i32) -> i32
-    %3 = triton_gen.sub_group_reduce min %0 {size = 16} : i32
-    // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
-    // CHECK: llvm.call spir_funccc @_Z30sub_group_clustered_reduce_maxij([[VAL]], [[SIZE]]) {{.*}} : (i32, i32) -> i32
-    %4 = triton_gen.sub_group_reduce max %0 {size = 16} : i32
-    // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
-    // CHECK: llvm.call spir_funccc @_Z30sub_group_clustered_reduce_andij([[VAL]], [[SIZE]]) {{.*}} : (i32, i32) -> i32
-    %5 = triton_gen.sub_group_reduce and %0 {size = 16} : i32
-    // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
-    // CHECK: llvm.call spir_funccc @_Z29sub_group_clustered_reduce_orij([[VAL]], [[SIZE]]) {{.*}} : (i32, i32) -> i32
-    %6 = triton_gen.sub_group_reduce or %0 {size = 16} : i32
-    // CHECK: [[SIZE:%.*]] = llvm.mlir.constant(16 : i32) : i32
-    // CHECK: llvm.call spir_funccc @_Z30sub_group_clustered_reduce_xorij([[VAL]], [[SIZE]]) {{.*}} : (i32, i32) -> i32
-    %7 = triton_gen.sub_group_reduce xor %0 {size = 16} : i32
-    llvm.return
-  }
-}
-
-// -----
-
-// CHECK-DAG: llvm.func spir_funccc @_Z32sub_group_non_uniform_reduce_addi(i32) -> i32
-// CHECK-DAG: llvm.func spir_funccc @_Z32sub_group_non_uniform_reduce_muli(i32) -> i32
-// CHECK-DAG: llvm.func spir_funccc @_Z32sub_group_non_uniform_reduce_mini(i32) -> i32
-// CHECK-DAG: llvm.func spir_funccc @_Z32sub_group_non_uniform_reduce_maxi(i32) -> i32
-// CHECK-DAG: llvm.func spir_funccc @_Z32sub_group_non_uniform_reduce_andi(i32) -> i32
-// CHECK-DAG: llvm.func spir_funccc @_Z31sub_group_non_uniform_reduce_ori(i32) -> i32
-// CHECK-DAG: llvm.func spir_funccc @_Z32sub_group_non_uniform_reduce_xori(i32) -> i32
-
-module attributes {
-  spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits>
-} {
-  llvm.func @triton_gen.sub_group_reduce() {
-    %0 = llvm.mlir.constant(0 : i32) : i32
-    // CHECK: [[VAL:%.*]] = llvm.mlir.constant(0 : i32) : i32
-    // CHECK: llvm.call spir_funccc @_Z32sub_group_non_uniform_reduce_addi([[VAL]]) {{.*}} : (i32) -> i32
-    %1 = triton_gen.sub_group_reduce add %0 {size = 16} : i32
-    // CHECK: llvm.call spir_funccc @_Z32sub_group_non_uniform_reduce_muli([[VAL]]) {{.*}} : (i32) -> i32
-    %2 = triton_gen.sub_group_reduce mul %0 {size = 16} : i32
-    // CHECK: llvm.call spir_funccc @_Z32sub_group_non_uniform_reduce_mini([[VAL]]) {{.*}} : (i32) -> i32
-    %3 = triton_gen.sub_group_reduce min %0 {size = 16} : i32
-    // CHECK: llvm.call spir_funccc @_Z32sub_group_non_uniform_reduce_maxi([[VAL]]) {{.*}} : (i32) -> i32
-    %4 = triton_gen.sub_group_reduce max %0 {size = 16} : i32
-    // CHECK: llvm.call spir_funccc @_Z32sub_group_non_uniform_reduce_andi([[VAL]]) {{.*}} : (i32) -> i32
-    %5 = triton_gen.sub_group_reduce and %0 {size = 16} : i32
-    // CHECK: llvm.call spir_funccc @_Z31sub_group_non_uniform_reduce_ori([[VAL]]) {{.*}} : (i32) -> i32
-    %6 = triton_gen.sub_group_reduce or %0 {size = 16} : i32
-    // CHECK: llvm.call spir_funccc @_Z32sub_group_non_uniform_reduce_xori([[VAL]]) {{.*}} : (i32) -> i32
-    %7 = triton_gen.sub_group_reduce xor %0 {size = 16} : i32
-    llvm.return
-  }
-}
-
-// -----
-
-// CHECK-DAG: llvm.func spir_funccc @_Z40sub_group_non_uniform_scan_exclusive_addi(i32) -> i32 attributes {convergent, no_unwind, will_return}
-// CHECK-DAG: llvm.func spir_funccc @_Z40sub_group_non_uniform_scan_exclusive_muli(i32) -> i32 attributes {convergent, no_unwind, will_return}
-// CHECK-DAG: llvm.func spir_funccc @_Z40sub_group_non_uniform_scan_exclusive_maxi(i32) -> i32 attributes {convergent, no_unwind, will_return}
-// CHECK-DAG: llvm.func spir_funccc @_Z40sub_group_non_uniform_scan_exclusive_mini(i32) -> i32 attributes {convergent, no_unwind, will_return}
-// CHECK-DAG: llvm.func spir_funccc @_Z40sub_group_non_uniform_scan_exclusive_andi(i32) -> i32 attributes {convergent, no_unwind, will_return}
-// CHECK-DAG: llvm.func spir_funccc @_Z39sub_group_non_uniform_scan_exclusive_ori(i32) -> i32 attributes {convergent, no_unwind, will_return}
-// CHECK-DAG: llvm.func spir_funccc @_Z40sub_group_non_uniform_scan_exclusive_xori(i32) -> i32 attributes {convergent, no_unwind, will_return}
-
-// CHECK-DAG: llvm.func spir_funccc @_Z40sub_group_non_uniform_scan_inclusive_addi(i32) -> i32 attributes {convergent, no_unwind, will_return}
-// CHECK-DAG: llvm.func spir_funccc @_Z40sub_group_non_uniform_scan_inclusive_muli(i32) -> i32 attributes {convergent, no_unwind, will_return}
-// CHECK-DAG: llvm.func spir_funccc @_Z40sub_group_non_uniform_scan_inclusive_maxi(i32) -> i32 attributes {convergent, no_unwind, will_return}
-// CHECK-DAG: llvm.func spir_funccc @_Z40sub_group_non_uniform_scan_inclusive_mini(i32) -> i32 attributes {convergent, no_unwind, will_return}
-// CHECK-DAG: llvm.func spir_funccc @_Z40sub_group_non_uniform_scan_inclusive_andi(i32) -> i32 attributes {convergent, no_unwind, will_return}
-// CHECK-DAG: llvm.func spir_funccc @_Z39sub_group_non_uniform_scan_inclusive_ori(i32) -> i32 attributes {convergent, no_unwind, will_return}
-// CHECK-DAG: llvm.func spir_funccc @_Z40sub_group_non_uniform_scan_inclusive_xori(i32) -> i32 attributes {convergent, no_unwind, will_return}
-
-module attributes {
-  spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits>
-} {
-  llvm.func @triton_gen.sub_group_scan() {
-    %0 = llvm.mlir.constant(0 : i32) : i32
-    // CHECK: [[VAL:%.*]] = llvm.mlir.constant(0 : i32) : i32
-    // CHECK: llvm.call spir_funccc @_Z40sub_group_non_uniform_scan_exclusive_addi([[VAL]]) {{.*}} : (i32) -> i32
-    %1 = triton_gen.sub_group_scan add %0 {kind = exclusive} : i32
-    // CHECK: llvm.call spir_funccc @_Z40sub_group_non_uniform_scan_exclusive_muli([[VAL]]) {{.*}} : (i32) -> i32
-    %2 = triton_gen.sub_group_scan mul %0 {kind = exclusive} : i32
-    // CHECK: llvm.call spir_funccc @_Z40sub_group_non_uniform_scan_exclusive_mini([[VAL]]) {{.*}} : (i32) -> i32
-    %3 = triton_gen.sub_group_scan min %0 {kind = exclusive} : i32
-    // CHECK: llvm.call spir_funccc @_Z40sub_group_non_uniform_scan_exclusive_maxi([[VAL]]) {{.*}} : (i32) -> i32
-    %4 = triton_gen.sub_group_scan max %0 {kind = exclusive} : i32
-    // CHECK: llvm.call spir_funccc @_Z40sub_group_non_uniform_scan_exclusive_andi([[VAL]]) {{.*}} : (i32) -> i32
-    %5 = triton_gen.sub_group_scan and %0 {kind = exclusive} : i32
-    // CHECK: llvm.call spir_funccc @_Z39sub_group_non_uniform_scan_exclusive_ori([[VAL]]) {{.*}} : (i32) -> i32
-    %6 = triton_gen.sub_group_scan or %0 {kind = exclusive} : i32
-    // CHECK: llvm.call spir_funccc @_Z40sub_group_non_uniform_scan_exclusive_xori([[VAL]]) {{.*}} : (i32) -> i32
-    %7 = triton_gen.sub_group_scan xor %0 {kind = exclusive} : i32
-
-    // CHECK: llvm.call spir_funccc @_Z40sub_group_non_uniform_scan_inclusive_addi([[VAL]]) {{.*}} : (i32) -> i32
-    %8 = triton_gen.sub_group_scan add %0 {kind = inclusive} : i32
-    // CHECK: llvm.call spir_funccc @_Z40sub_group_non_uniform_scan_inclusive_muli([[VAL]]) {{.*}} : (i32) -> i32
-    %9 = triton_gen.sub_group_scan mul %0 {kind = inclusive} : i32
-    // CHECK: llvm.call spir_funccc @_Z40sub_group_non_uniform_scan_inclusive_mini([[VAL]]) {{.*}} : (i32) -> i32
-    %10 = triton_gen.sub_group_scan min %0 {kind = inclusive} : i32
-    // CHECK: llvm.call spir_funccc @_Z40sub_group_non_uniform_scan_inclusive_maxi([[VAL]]) {{.*}} : (i32) -> i32
-    %11 = triton_gen.sub_group_scan max %0 {kind = inclusive} : i32
-    // CHECK: llvm.call spir_funccc @_Z40sub_group_non_uniform_scan_inclusive_andi([[VAL]]) {{.*}} : (i32) -> i32
-    %12 = triton_gen.sub_group_scan and %0 {kind = inclusive} : i32
-    // CHECK: llvm.call spir_funccc @_Z39sub_group_non_uniform_scan_inclusive_ori([[VAL]]) {{.*}} : (i32) -> i32
-    %13 = triton_gen.sub_group_scan or %0 {kind = inclusive} : i32
-    // CHECK: llvm.call spir_funccc @_Z40sub_group_non_uniform_scan_inclusive_xori([[VAL]]) {{.*}} : (i32) -> i32
-    %14 = triton_gen.sub_group_scan xor %0 {kind = inclusive} : i32
-
-    llvm.return
-  }
-}
-
-// -----
-
 // CHECK: llvm.func spir_funccc @_Z36intel_sub_group_i8_i8_matrix_mad_k32Dv8_sDv8_iS0_(vector<8xi16>, vector<8xi32>, vector<8xi32>) -> vector<8xi32> attributes {convergent, memory_effects = #llvm.memory_effects, no_unwind, will_return}
 
 llvm.func @triton_gen.dpas.i8(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8xi32>) {
diff --git a/test/TritonGEN/tritongen.mlir b/test/TritonGEN/tritongen.mlir
index ed5118cd20..6f4448a0f6 100644
--- a/test/TritonGEN/tritongen.mlir
+++ b/test/TritonGEN/tritongen.mlir
@@ -1,73 +1,5 @@
 // RUN: triton-opt %s -split-input-file -verify-diagnostics | FileCheck %s
 
-module attributes {
-  spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits>
-} {
-  llvm.func @triton_gen.sub_group_reduce() {
-    // CHECK-LABEL: triton_gen.sub_group_reduce
-    %0 = llvm.mlir.constant(0 : i32) : i32
-    // CHECK: triton_gen.sub_group_reduce add %0 {size = 16} : i32
-    %1 = triton_gen.sub_group_reduce add %0 {size = 16} : i32
-    // CHECK: triton_gen.sub_group_reduce mul %0 {size = 16} : i32
-    %2 = triton_gen.sub_group_reduce mul %0 {size = 16} : i32
-    // CHECK: triton_gen.sub_group_reduce min %0 {size = 16} : i32
-    %3 = triton_gen.sub_group_reduce min %0 {size = 16} : i32
-    // CHECK: triton_gen.sub_group_reduce max %0 {size = 16} : i32
-    %4 = triton_gen.sub_group_reduce max %0 {size = 16} : i32
-    // CHECK: triton_gen.sub_group_reduce and %0 {size = 16} : i32
-    %5 = triton_gen.sub_group_reduce and %0 {size = 16} : i32
-    // CHECK: triton_gen.sub_group_reduce or %0 {size = 16} : i32
-    %6 = triton_gen.sub_group_reduce or %0 {size = 16} : i32
-    // CHECK: triton_gen.sub_group_reduce xor %0 {size = 16} : i32
-    %7 = triton_gen.sub_group_reduce xor %0 {size = 16} : i32
-    llvm.return
-  }
-}
-
-// -----
-
-module attributes {
-  spirv.target_env = #spirv.target_env<#spirv.vce, #spirv.resource_limits>
-} {
-  llvm.func @triton_gen.sub_group_scan() {
-    // CHECK-LABEL: triton_gen.sub_group_scan
-    %0 = llvm.mlir.constant(0 : i32) : i32
-    // CHECK: triton_gen.sub_group_scan add %0 {kind = exclusive} : i32
-    %1 = triton_gen.sub_group_scan add %0 {kind = exclusive} : i32
-    // CHECK: triton_gen.sub_group_scan mul %0 {kind = exclusive} : i32
-    %2 = triton_gen.sub_group_scan mul %0 {kind = exclusive} : i32
-    // CHECK: triton_gen.sub_group_scan min %0 {kind = exclusive} : i32
-    %3 = triton_gen.sub_group_scan min %0 {kind = exclusive} : i32
-    // CHECK: triton_gen.sub_group_scan max %0 {kind = exclusive} : i32
-    %4 = triton_gen.sub_group_scan max %0 {kind = exclusive} : i32
-    // CHECK: triton_gen.sub_group_scan and %0 {kind = exclusive} : i32
-    %5 = triton_gen.sub_group_scan and %0 {kind = exclusive} : i32
-    // CHECK: triton_gen.sub_group_scan or %0 {kind = exclusive} : i32
-    %6 = triton_gen.sub_group_scan or %0 {kind = exclusive} : i32
-    // CHECK: triton_gen.sub_group_scan xor %0 {kind = exclusive} : i32
-    %7 = triton_gen.sub_group_scan xor %0 {kind = exclusive} : i32
-
-    // CHECK: triton_gen.sub_group_scan add %0 {kind = inclusive} : i32
-    %8 = triton_gen.sub_group_scan add %0 {kind = inclusive} : i32
-    // CHECK: triton_gen.sub_group_scan mul %0 {kind = inclusive} : i32
-    %9 = triton_gen.sub_group_scan mul %0 {kind = inclusive} : i32
-    // CHECK: triton_gen.sub_group_scan min %0 {kind = inclusive} : i32
-    %10 = triton_gen.sub_group_scan min %0 {kind = inclusive} : i32
-    // CHECK: triton_gen.sub_group_scan max %0 {kind = inclusive} : i32
-    %11 = triton_gen.sub_group_scan max %0 {kind = inclusive} : i32
-    // CHECK: triton_gen.sub_group_scan and %0 {kind = inclusive} : i32
-    %12 = triton_gen.sub_group_scan and %0 {kind = inclusive} : i32
-    // CHECK: triton_gen.sub_group_scan or %0 {kind = inclusive} : i32
-    %13 = triton_gen.sub_group_scan or %0 {kind = inclusive} : i32
-    // CHECK: triton_gen.sub_group_scan xor %0 {kind = inclusive} : i32
-    %14 = triton_gen.sub_group_scan xor %0 {kind = inclusive} : i32
-
-    llvm.return
-  }
-}
-
-// -----
-
 llvm.func @triton_gen.dpas(%c : vector<8xi32>, %a : vector<8xi16>, %b : vector<8xi32>) {
   // CHECK: llvm.func @triton_gen.dpas(%arg0: vector<8xi32>, %arg1: vector<8xi16>, %arg2: vector<8xi32>) {
   // CHECK-NEXT: %0 = triton_gen.dpas %arg0, %arg1, %arg2 {pa = i8, pb = i8, rc = 8} : (vector<8xi32>, vector<8xi16>, vector<8xi32>) -> vector<8xi32>
diff --git a/test/TritonIntelGPU/tritongpu_reduce_op_lowering.mlir b/test/TritonIntelGPU/tritongpu_reduce_op_lowering.mlir
index 0574782972..b45aeecf95 100644
--- a/test/TritonIntelGPU/tritongpu_reduce_op_lowering.mlir
+++ b/test/TritonIntelGPU/tritongpu_reduce_op_lowering.mlir
@@ -2,28 +2,25 @@
 // COM: Tests reduction when threads_per_warp < num_warps.
 
-// CHECK-DAG: llvm.func spir_funccc @_Z32sub_group_non_uniform_reduce_addi(i32) -> i32
-// CHECK-DAG: llvm.func spir_funccc @_Z30sub_group_clustered_reduce_addij(i32, i32) -> i32 attributes {convergent, no_unwind, will_return}
-
 #blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [64], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 64 : i32, "ttg.threads-per-warp" = 32 : i32} {
   // CHECK-LABEL: reduce_problem_size_64_threads_per_warp_32
   tt.func @reduce_problem_size_64_threads_per_warp_32(%f : tensor<2048xi32, #blocked>) {
     // 1st round intra-warp reduce
-    // CHECK: llvm.call spir_funccc @_Z32sub_group_non_uniform_reduce_addi(%{{.*}})
+    // CHECK: llvm.call spir_funccc @_Z27__spirv_GroupNonUniformIAddiij(%{{.*}}) {{.*}} : (i32, i32, i32) -> i32
     // CHECK: llvm.store %{{.*}}, %{{.*}} : i32, !llvm.ptr<3>
 
     // 2nd round inter-warp reduce with problem size 64 with threads_per_warp 32
     // CHECK: llvm.call spir_funccc @_Z7barrierj(%{{.*}}) {{.*}} : (i32) -> ()
     // CHECK: [[PARTIAL_REDUCE_0:%.*]] = llvm.load %{{.*}} : !llvm.ptr<3> -> i32
-    // CHECK: llvm.call spir_funccc @_Z32sub_group_non_uniform_reduce_addi(%{{.*}})
+    // CHECK: llvm.call spir_funccc @_Z27__spirv_GroupNonUniformIAddiij(%{{.*}}) {{.*}} : (i32, i32, i32) -> i32
     // CHECK: llvm.store %{{.*}}, %{{.*}} : i32, !llvm.ptr<3>
 
     // 3rd round inter-warp reduce with problem size 2 with threads_per_warp 32
    // CHECK: llvm.call spir_funccc @_Z7barrierj(%{{.*}}) {{.*}} : (i32) -> ()
     // CHECK: [[PARTIAL_REDUCE_1:%.*]] = llvm.load %{{.*}} : !llvm.ptr<3> -> i32
-    // CHECK: llvm.call spir_funccc @_Z30sub_group_clustered_reduce_addij(%{{.*}})
+    // CHECK: llvm.call spir_funccc @_Z27__spirv_GroupNonUniformIAddiijj(%{{.*}}) {{.*}} : (i32, i32, i32, i32) -> i32
     // CHECK: llvm.store %{{.*}}, %{{.*}} : i32, !llvm.ptr<3>
 
     // get final result
diff --git a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td
index e69ab231f2..d637a87366 100644
--- a/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td
+++ b/third_party/intel/include/Dialect/TritonGEN/IR/TritonGENOps.td
@@ -28,56 +28,6 @@ include "mlir/IR/OpAsmInterface.td"
 class TritonGEN_Op<string mnemonic, list<Trait> traits = []> :
     Op<TritonGEN_Dialect, mnemonic, traits>;
 
-//===----------------------------------------------------------------------===//
-// Synchronization
-//===----------------------------------------------------------------------===//
-
-def TritonGEN_SubGroupReduceOp : TritonGEN_Op<"sub_group_reduce", [
-    AllTypesMatch<["res", "value"]>]>,
-  Results<(outs SignlessIntegerOrFloatLike:$res)>,
-  Arguments<(ins SignlessIntegerOrFloatLike:$value,
-             TritonGEN_ReduceKindAttr:$kind,
-             I32Attr:$size)> {
-  let summary = "Subgroup reduce";
-
-  let description = [{
-    The `triton_gen.sub_group_reduce` operation is invoked by all work items in
-    a subgroup, each of them providing a $value. The $size argument is used to
-    form groups of $size consecutive work items called clusters. Each cluster
-    performs the reduction operation identified by $kind. The result of the
-    cluster reduction is propagated to the work items belonging to that cluster.
-  }];
-
-  let assemblyFormat = [{
-    $kind $value ` ` `{` `size` `=` $size `}` attr-dict `:` type($value)
-  }];
-
-  let hasVerifier = 1;
-}
-
-def TritonGEN_SubGroupScanOp : TritonGEN_Op<"sub_group_scan", [
-    AllTypesMatch<["res", "value"]>]>,
-  Results<(outs SignlessIntegerOrFloatLike:$res)>,
-  Arguments<(ins SignlessIntegerOrFloatLike:$value,
-             TritonGEN_ReduceKindAttr:$reduce_kind,
-             TritonGEN_ScanKindAttr:$scan_kind)> {
-  let summary = "Subgroup scan";
-
-  let description = [{
-    The `triton_gen.sub_group_scan` operation is invoked by all work items in
-    a subgroup, each of them providing a $value. Each work item performs the
-    reduction operation identified by $reduce_kind. The $scan_kind attribute
-    indicates whether to perform an inclusive or exclusive scan. The result
-    of the scan operation is returned for each work item.
-    Note: The scan order is defined by increasing sub-group local ID within
-    the sub-group.
-  }];
-
-  let assemblyFormat = [{
-    $reduce_kind $value ` ` `{` `kind` `=` $scan_kind `}` attr-dict `:` type($value)
-  }];
-}
-
 //===----------------------------------------------------------------------===//
 // Matrix operations
 //===----------------------------------------------------------------------===//
diff --git a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp
index 17ae98733b..eb6808b845 100644
--- a/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp
+++ b/third_party/intel/lib/Dialect/TritonGEN/IR/TritonGENOps.cpp
@@ -48,23 +48,6 @@ template <typename Op> static LogicalResult verifyMatrixInput(Op op) {
   return success();
 }
 
-//===----------------------------------------------------------------------===//
-// gen.sub_group_reduce
-//===----------------------------------------------------------------------===//
-
-LogicalResult TritonGEN::SubGroupReduceOp::verify() {
-  spirv::TargetEnvAttr attr = spirv::lookupTargetEnv(*this);
-  if (!attr)
-    return this->emitOpError("expecting valid target env attribute");
-
-  if (getSize() < 1 || getSize() > TritonGEN::getSubgroupSize(*this) ||
-      !llvm::isPowerOf2_32(getSize()))
-    return this->emitOpError(
-        "expecting size to be a power of 2 between 1 and subgroup size");
-
-  return success();
-}
-
 //===----------------------------------------------------------------------===//
 // gen.matrix.dpas
 //===----------------------------------------------------------------------===//
diff --git a/third_party/intel/lib/GPUToTritonGEN/GPUToTritonGENPass.cpp b/third_party/intel/lib/GPUToTritonGEN/GPUToTritonGENPass.cpp
index 256e8146bc..25ac3585fe 100644
--- a/third_party/intel/lib/GPUToTritonGEN/GPUToTritonGENPass.cpp
+++ b/third_party/intel/lib/GPUToTritonGEN/GPUToTritonGENPass.cpp
@@ -57,50 +57,6 @@ namespace {
 /// Import the GPU Ops to TritonGEN Patterns.
 #include "GPUToTritonGEN.cpp.inc"
 
-struct GPUSubgroupReduceOpLowering
-    : public ConvertOpToLLVMPattern<mlir::gpu::SubgroupReduceOp> {
-  using ConvertOpToLLVMPattern<
-      mlir::gpu::SubgroupReduceOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(mlir::gpu::SubgroupReduceOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.replaceOpWithNewOp<TritonGEN::SubGroupReduceOp>(
-        op, op.getResult().getType(), op.getValue(),
-        convertReduceKind(op.getOp()), TritonGEN::getSubgroupSize(op));
-    return success();
-  }
-
-private:
-  static TritonGEN::ReduceKind
-  convertReduceKind(mlir::gpu::AllReduceOperation op) {
-    switch (op) {
-    case mlir::gpu::AllReduceOperation::ADD:
-      return TritonGEN::ReduceKind::ADD;
-    case mlir::gpu::AllReduceOperation::MUL:
-      return TritonGEN::ReduceKind::MUL;
-    case mlir::gpu::AllReduceOperation::MINUI:
-    case mlir::gpu::AllReduceOperation::MINSI:
-    case mlir::gpu::AllReduceOperation::MINIMUMF:
-    case mlir::gpu::AllReduceOperation::MINNUMF:
-      return TritonGEN::ReduceKind::MIN;
-    case mlir::gpu::AllReduceOperation::MAXUI:
-    case mlir::gpu::AllReduceOperation::MAXSI:
-    case mlir::gpu::AllReduceOperation::MAXIMUMF:
-    case mlir::gpu::AllReduceOperation::MAXNUMF:
-      return TritonGEN::ReduceKind::MAX;
-    case mlir::gpu::AllReduceOperation::AND:
-      return TritonGEN::ReduceKind::AND;
-    case mlir::gpu::AllReduceOperation::OR:
-      return TritonGEN::ReduceKind::OR;
-    case mlir::gpu::AllReduceOperation::XOR:
-      return TritonGEN::ReduceKind::XOR;
-    default:
-      llvm_unreachable("unsupported reduction mode");
-    }
-  }
-};
-
 // A pass that replaces all occurrences of GPU device operations with their
 // corresponding TritonGEN equivalent.
 //
@@ -196,7 +152,6 @@ static void populateOpPatterns(LLVMTypeConverter &converter,
 void mlir::triton::populateGPUToTritonGENConversionPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   populateWithGenerated(patterns);
-  patterns.add<GPUSubgroupReduceOpLowering>(converter);
  patterns.add(
      converter,
      /*allocaAddrSpace=*/TritonGEN::TritonGENMemorySpace::kFunction,
diff --git a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
index c407968cdc..e7a8e2d5fc 100644
--- a/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
+++ b/third_party/intel/lib/TritonGENToLLVM/TritonGENToLLVMPass.cpp
@@ -477,133 +477,6 @@ createGenISA2DBlockPrefetch(TritonGEN::Matrix2DBlockPrefetchOp op,
 
 namespace {
 
-//===----------------------------------------------------------------------===//
-// Synchronization Ops Lowerings
-//===----------------------------------------------------------------------===//
-
-struct TritonSubGroupBase {
-protected:
-  template <typename OpType,
-            typename = std::enable_if_t<llvm::is_one_of<
-                OpType, TritonGEN::SubGroupReduceOp,
-                TritonGEN::SubGroupScanOp>::value>>
-  static Value extend(OpType op, Value val, Type type,
-                      ConversionPatternRewriter &rewriter) {
-    Location loc = op.getLoc();
-    unsigned bitWidth = type.getIntOrFloatBitWidth();
-
-    if constexpr (llvm::is_one_of<OpType, TritonGEN::SubGroupReduceOp,
-                                  TritonGEN::SubGroupScanOp>::value) {
-      if (type.isInteger() && bitWidth < 8)
-        val = zext(i8_ty, val);
-    }
-
-    return val;
-  }
-
-  template <typename OpType,
-            typename = std::enable_if_t<llvm::is_one_of<
-                OpType, TritonGEN::SubGroupReduceOp,
-                TritonGEN::SubGroupScanOp>::value>>
-  static Value truncate(OpType op, Value val, Type type,
-                        ConversionPatternRewriter &rewriter) {
-    Location loc = op.getLoc();
-    unsigned bitWidth = type.getIntOrFloatBitWidth();
-
-    if constexpr (llvm::is_one_of<OpType, TritonGEN::SubGroupReduceOp,
-                                  TritonGEN::SubGroupScanOp>::value) {
-      if (type.isInteger() && bitWidth < 8)
-        val = trunc(type, val);
-      return val;
-    }
-
-    return val;
-  }
-};
-
-struct TritonSubGroupReduceLowering
-    : public ConvertOpToLLVMPattern<TritonGEN::SubGroupReduceOp>,
-      public TritonSubGroupBase {
-  using ConvertOpToLLVMPattern<
-      TritonGEN::SubGroupReduceOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(TritonGEN::SubGroupReduceOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Location loc = op.getLoc();
-    Value val = op.getValue();
-    Type origTy = val.getType();
-    val = TritonSubGroupBase::extend(op, val, origTy, rewriter);
-    Type valTy = val.getType();
-    SmallVector<Type> argTypes{valTy};
-    SmallVector<bool> argIsUnsigned{false};
-    SmallVector<Value> args{val};
-    bool useCluster = (getSubgroupSize(op) != op.getSize());
-
-    std::string fnName = "sub_group_";
-    fnName += useCluster ? "clustered_" : "non_uniform_";
-    fnName += "reduce_" + stringifyReduceKind(op.getKind()).str();
-    LLVMFuncAttributeOptions funcAttrs{};
-    if (useCluster) {
-      argTypes.push_back(i32_ty);
-      argIsUnsigned.push_back(true);
-      auto size = rewriter.create<LLVM::ConstantOp>(
-          loc, i32_ty, static_cast<int32_t>(op.getSize()));
-      args.push_back(size);
-      MLIRContext *ctx = rewriter.getContext();
-      funcAttrs = convergentNoUnwindWillReturnAttrs;
-    }
-    fnName = intel::mangle(fnName, argTypes, argIsUnsigned);
-
-    Value result = createDeviceFunctionCall(rewriter, fnName, valTy, argTypes,
-                                            args, {}, funcAttrs)
-                       .getResult();
-    result = TritonSubGroupBase::truncate(op, result, origTy, rewriter);
-    rewriter.replaceOp(op, result);
-    return success();
-  }
-};
-
-struct TritonSubGroupScanLowering
-    : public ConvertOpToLLVMPattern<TritonGEN::SubGroupScanOp>,
-      public TritonSubGroupBase {
-  using ConvertOpToLLVMPattern<
-      TritonGEN::SubGroupScanOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(TritonGEN::SubGroupScanOp op, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    Value val = op.getValue();
-    Type origTy = val.getType();
-    val = TritonSubGroupBase::extend(op, op.getValue(), origTy, rewriter);
-    Type valTy = val.getType();
-    SmallVector<Type> argTypes{valTy};
-    SmallVector<Value> args{val};
-
-    std::string fnName = "sub_group_non_uniform_scan_";
-    switch (op.getScanKind()) {
-    case TritonGEN::ScanKind::EXCLUSIVE:
-      fnName += "exclusive_";
-      break;
-    case TritonGEN::ScanKind::INCLUSIVE:
-      fnName += "inclusive_";
-      break;
-    default:
-      llvm_unreachable("unhandled scan kind");
-    };
-
-    fnName += stringifyReduceKind(op.getReduceKind()).str();
-    fnName = intel::mangle(fnName, valTy);
-
-    Value result =
-        createDeviceFunctionCall(rewriter, fnName, valTy, argTypes, args, {},
-                                 convergentNoUnwindWillReturnAttrs)
-            .getResult();
-    result = TritonSubGroupBase::truncate(op, result, origTy, rewriter);
-    rewriter.replaceOp(op, result);
-
-    return success();
-  }
-};
-
 //===----------------------------------------------------------------------===//
 // Matrix operations
 //===----------------------------------------------------------------------===//
@@ -1027,8 +900,7 @@ struct TritonGENToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
 void mlir::triton::populateTritonGENToLLVMConversionPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   patterns
-      .add<TritonSubGroupReduceLowering, TritonSubGroupScanLowering,
-           TritonMatrixDPASLowering, TritonMatrix2DBlockLoadLowering,
+      .add<TritonMatrixDPASLowering, TritonMatrix2DBlockLoadLowering,
            TritonMatrix2DBlockStoreLowering,
           TritonMatrix2DBlockPrefetchLowering>(converter);
 }
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/SPIRVSubgroupOps.h b/third_party/intel/lib/TritonIntelGPUToLLVM/SPIRVSubgroupOps.h
new file mode 100644
index 0000000000..372f777aaf
--- /dev/null
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/SPIRVSubgroupOps.h
@@ -0,0 +1,80 @@
+//===- SPIRVSubgroupOps.h - Mapping for SPIR-V Reduction --------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines mapping from operations in the 'arith' dialect to the
+// corresponding SPIR-V Subgroup Reduction Operation.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TRITONINTELGPUTOLLVM_SPIRVSUBGROUPOPS_H
+#define TRITONINTELGPUTOLLVM_SPIRVSUBGROUPOPS_H
+
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
+
+using namespace mlir;
+
+namespace mlir::triton::intel {
+
+template <typename OpTy> struct SPIRVArithmeticGroupOp {};
+
+template <> struct SPIRVArithmeticGroupOp<arith::AddFOp> {
+  using type = spirv::GroupNonUniformFAddOp;
+};
+template <> struct SPIRVArithmeticGroupOp<arith::AddIOp> {
+  using type = spirv::GroupNonUniformIAddOp;
+};
+template <> struct SPIRVArithmeticGroupOp<arith::MulFOp> {
+  using type = spirv::GroupNonUniformFMulOp;
+};
+template <> struct SPIRVArithmeticGroupOp<arith::MulIOp> {
+  using type = spirv::GroupNonUniformIMulOp;
+};
+template <> struct SPIRVArithmeticGroupOp<arith::MaxNumFOp> {
+  using type = spirv::GroupNonUniformFMaxOp;
+};
+template <> struct SPIRVArithmeticGroupOp<arith::MinNumFOp> {
+  using type = spirv::GroupNonUniformFMinOp;
+};
+
+template <typename OpTy>
+using SPIRVArithmeticGroupOpTy = typename SPIRVArithmeticGroupOp<OpTy>::type;
+
+template <typename OpTy> struct SPIRVBitwiseGroupOp {};
+
+template <> struct SPIRVBitwiseGroupOp<arith::AndIOp> {
+  using type = spirv::GroupNonUniformBitwiseAndOp;
+};
+template <> struct SPIRVBitwiseGroupOp<arith::OrIOp> {
+  using type = spirv::GroupNonUniformBitwiseOrOp;
+};
+template <> struct SPIRVBitwiseGroupOp<arith::XOrIOp> {
+  using type = spirv::GroupNonUniformBitwiseXorOp;
+};
+
+template <typename OpTy>
+using SPIRVBitwiseGroupOpTy = typename SPIRVBitwiseGroupOp<OpTy>::type;
+
+template <typename OpTy> struct SPIRVLogicalGroupOp {};
+
+template <> struct SPIRVLogicalGroupOp<arith::AndIOp> {
+  using type = spirv::GroupNonUniformLogicalAndOp;
+};
+template <> struct SPIRVLogicalGroupOp<arith::OrIOp> {
+  using type = spirv::GroupNonUniformLogicalOrOp;
+};
+template <> struct SPIRVLogicalGroupOp<arith::XOrIOp> {
+  using type = spirv::GroupNonUniformLogicalXorOp;
+};
+
+template <typename OpTy>
+using SPIRVLogicalGroupOpTy = typename SPIRVLogicalGroupOp<OpTy>::type;
+
+} // namespace mlir::triton::intel
+
+#endif // TRITONINTELGPUTOLLVM_SPIRVSUBGROUPOPS_H
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp
index 1eb4dada9d..75e10be3a7 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TargetInfo.cpp
@@ -7,7 +7,9 @@
 //===----------------------------------------------------------------------===//
 
 #include "TargetInfo.h"
+#include "SPIRVSubgroupOps.h"
 #include "Utility.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 using namespace mlir;
@@ -107,6 +109,52 @@ Value TargetInfo::programId(RewriterBase &rewriter, Location loc,
   return rewriter.create<arith::IndexCastOp>(loc, i32_ty, blockId);
 }
 
+namespace {
+
+template <typename GroupOp>
+Value createSPIRVGroupOp(RewriterBase &rewriter, Location loc, Type resultTy,
+                         Value acc, unsigned numLanesToReduce,
+                         unsigned warpSize) {
+  auto spvGroupOp = spirv::GroupOperation::Reduce;
+  Value clusterSize;
+  if (numLanesToReduce != warpSize) {
+    spvGroupOp = spirv::GroupOperation::ClusteredReduce;
+    clusterSize = rewriter.create<LLVM::ConstantOp>(
+        loc, rewriter.getI32Type(),
+        rewriter.getI32IntegerAttr(numLanesToReduce));
+  }
+
+  Value result = rewriter.create<GroupOp>(loc, resultTy, spirv::Scope::Subgroup,
+                                          spvGroupOp, acc, clusterSize);
+  return result;
+}
+
+Value warpReduceHelper(RewriterBase &rewriter, Location loc, Value acc,
+                       Operation *reduceOp, unsigned numLanesToReduce,
+                       unsigned warpSize) {
+  auto resultType = reduceOp->getResult(0).getType();
+  Value warpReduce =
+      TypeSwitch<Operation *, Value>(reduceOp)
+          .Case<arith::AddFOp, arith::AddIOp, arith::MulFOp, arith::MulIOp,
+                arith::MaxNumFOp, arith::MinNumFOp>([&](auto groupOp) {
+            return createSPIRVGroupOp<
+                SPIRVArithmeticGroupOpTy<decltype(groupOp)>>(
+                rewriter, loc, resultType, acc, numLanesToReduce, warpSize);
+          })
+          .Case<arith::AndIOp, arith::OrIOp, arith::XOrIOp>([&](auto groupOp) {
+            if (resultType.isInteger(1)) {
+              return createSPIRVGroupOp<
+                  SPIRVLogicalGroupOpTy<decltype(groupOp)>>(
+                  rewriter, loc, resultType, acc, numLanesToReduce, warpSize);
+            }
+            return createSPIRVGroupOp<SPIRVBitwiseGroupOpTy<decltype(groupOp)>>(
+                rewriter, loc, resultType, acc, numLanesToReduce, warpSize);
+          });
+  return warpReduce;
+}
+
+} // namespace
+
 bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
                             SmallVector<Value> &acc, triton::ReduceOp op,
                             unsigned numLaneToReduce,
@@ -134,28 +182,19 @@ bool TargetInfo::warpReduce(RewriterBase &rewriter, Location loc,
       reduceOp->getOperand(1) != block.getArgument(1))
     return false;
 
-  auto reduceKind =
-      llvm::TypeSwitch<Operation *, std::optional<TritonGEN::ReduceKind>>(
-          reduceOp)
-          .Case<arith::AddFOp, arith::AddIOp>(
-              [&](auto) { return TritonGEN::ReduceKind::ADD; })
-          .Case<arith::MulFOp, arith::MulIOp>(
-              [&](auto) { return TritonGEN::ReduceKind::MUL; })
-          .Case<arith::MaxNumFOp, arith::MaxSIOp>(
-              [&](auto) { return TritonGEN::ReduceKind::MAX; })
-          .Case<arith::MinNumFOp, arith::MinSIOp>(
-              [&](auto) { return TritonGEN::ReduceKind::MIN; })
-          .Case<arith::AndIOp>([&](auto) { return TritonGEN::ReduceKind::AND; })
-          .Case<arith::OrIOp>([&](auto) { return TritonGEN::ReduceKind::OR; })
-          .Case<arith::XOrIOp>([&](auto) { return TritonGEN::ReduceKind::XOR; })
-          .Default([](auto) { return std::nullopt; });
-  if (reduceKind == std::nullopt)
+  auto supportedOp = isa<arith::AddFOp, arith::AddIOp, arith::MulFOp,
+                         arith::MulIOp, arith::MaxNumFOp, arith::MinNumFOp,
+                         arith::AndIOp, arith::OrIOp, arith::XOrIOp>(reduceOp);
+
+  if (!supportedOp)
     return false;
 
+  auto mod = op->getParentOfType<ModuleOp>();
+  unsigned warpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
+
   for (unsigned i = 0; i < acc.size(); ++i) {
-    acc[i] = rewriter.create<TritonGEN::SubGroupReduceOp>(
-        loc, reduceOp->getResult(0).getType(), acc[i], *reduceKind,
-        numLaneToReduce);
+    acc[i] = warpReduceHelper(rewriter, loc, acc[i], reduceOp, numLaneToReduce,
+                              warpSize);
   }
 
   return true;
diff --git a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp
index eabc89531f..f6da706e3c 100644
--- a/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp
+++ b/third_party/intel/lib/TritonIntelGPUToLLVM/TritonOpsToLLVM.cpp
@@ -1,11 +1,13 @@
 #include "Dialect/TritonIntelGPU/IR/Utils.h"
 #include "PatternTritonGPUOpToLLVM.h"
+#include "SPIRVSubgroupOps.h"
 #include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h"
 #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
 
 #include "mlir/Conversion/LLVMCommon/TypeConverter.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "triton/Analysis/Utility.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
@@ -561,7 +563,7 @@ class ReduceOpConversion : public ConvertTritonGPUOpToLLVMPattern<ReduceOp> {
   using ConvertTritonGPUOpToLLVMPattern<
       ReduceOp>::ConvertTritonGPUOpToLLVMPattern;
   LogicalResult
-  matchAndRewrite(ReduceOp op, OpAdaptor adaptor,
+  matchAndRewrite(ReduceOp op, ReduceOpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto mod = op->getParentOfType<ModuleOp>();
    int subgroupSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
@@ -582,18 +584,15 @@ class ReduceOpConversion : public ConvertTritonGPUOpToLLVMPattern<ReduceOp> {
     Operation *combine = &*combineOp.front().getOperations().begin();
 
     // FIXME: support all possible reduction modes
-    using AllReduceOperation = mlir::gpu::AllReduceOperation;
-    AllReduceOperation redKind;
-    if (isa<arith::AddFOp>(combine))
-      redKind = AllReduceOperation::ADD;
-    else if (isa<arith::MaxNumFOp>(combine))
-      redKind = AllReduceOperation::MAXNUMF;
-    else
-      llvm_unreachable("Unhandled reduction kind");
+    TypeSwitch<Operation *>(combine).Case<arith::AddFOp, arith::MaxNumFOp>(
+        [&](auto reduce) {
+          rewriter.replaceOpWithNewOp<
+              intel::SPIRVArithmeticGroupOpTy<decltype(reduce)>>(
+              op, typeConverter->convertType(op.getType(0)),
+              spirv::Scope::Subgroup, spirv::GroupOperation::Reduce,
+              adaptor.getSrcs()[0], Value());
+        });
 
-    Value result = rewriter.create<mlir::gpu::SubgroupReduceOp>(
-        loc, adaptor.getSrcs()[0], redKind, true);
-    rewriter.replaceOp(op, result);
     return success();
   }
 };