Skip to content

Commit

Permalink
[MLIR] Fixes arith.sub folder crash on dynamically shaped tensors (#1…
Browse files Browse the repository at this point in the history
…18908)

We can't create a constant for a value with dynamic shape.

Fixes #118772
  • Loading branch information
joker-eph authored Dec 6, 2024
1 parent 92376c3 commit 1801fb4
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 2 deletions.
8 changes: 6 additions & 2 deletions mlir/lib/Dialect/Arith/IR/ArithOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,12 @@ void arith::AddUIExtendedOp::getCanonicalizationPatterns(

OpFoldResult arith::SubIOp::fold(FoldAdaptor adaptor) {
// subi(x,x) -> 0
if (getOperand(0) == getOperand(1))
return Builder(getContext()).getZeroAttr(getType());
if (getOperand(0) == getOperand(1)) {
auto shapedType = dyn_cast<ShapedType>(getType());
// We can't generate a constant with a dynamic shaped tensor.
if (!shapedType || shapedType.hasStaticShape())
return Builder(getContext()).getZeroAttr(getType());
}
// subi(x,0) -> x
if (matchPattern(adaptor.getRhs(), m_Zero()))
return getLhs();
Expand Down
21 changes: 21 additions & 0 deletions mlir/test/Dialect/Arith/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,27 @@ func.func @tripleAddAddOvf2(%arg0: index) -> index {
return %add2 : index
}


// CHECK-LABEL: @foldSubXX_tensor
// CHECK: %[[c0:.+]] = arith.constant dense<0> : tensor<10xi32>
// CHECK: %[[sub:.+]] = arith.subi
// CHECK: return %[[c0]], %[[sub]]
func.func @foldSubXX_tensor(%static : tensor<10xi32>, %dyn : tensor<?x?xi32>) -> (tensor<10xi32>, tensor<?x?xi32>) {
%static_sub = arith.subi %static, %static : tensor<10xi32>
%dyn_sub = arith.subi %dyn, %dyn : tensor<?x?xi32>
return %static_sub, %dyn_sub : tensor<10xi32>, tensor<?x?xi32>
}

// CHECK-LABEL: @foldSubXX_vector
// CHECK-DAG: %[[c0:.+]] = arith.constant dense<0> : vector<8xi32>
// CHECK-DAG: %[[c0_scalable:.+]] = arith.constant dense<0> : vector<[4]xi32>
// CHECK: return %[[c0]], %[[c0_scalable]]
func.func @foldSubXX_vector(%static : vector<8xi32>, %dyn : vector<[4]xi32>) -> (vector<8xi32>, vector<[4]xi32>) {
%static_sub = arith.subi %static, %static : vector<8xi32>
%dyn_sub = arith.subi %dyn, %dyn : vector<[4]xi32>
return %static_sub, %dyn_sub : vector<8xi32>, vector<[4]xi32>
}

// CHECK-LABEL: @tripleAddSub0
// CHECK: %[[cres:.+]] = arith.constant 59 : index
// CHECK: %[[add:.+]] = arith.subi %[[cres]], %arg0 : index
Expand Down

0 comments on commit 1801fb4

Please sign in to comment.