Rewrites for I2 to I8 signed and unsigned extension #121298
base: main
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-vector

Author: None (ziereis)

Changes

Adds rewrites for i2 to i8 signed and unsigned extension, similar to the ones that already exist for i4 to i8 conversion. I use this for i6 quantized models, and this gives me roughly a 2x speedup for an i6 4096x4096 dequantization-matmul on an AMD 5950x. I didn't add the rewrite for i8 to i2 truncation because I currently don't use it, but if this is needed, I can add it as well.

Patch is 20.48 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/121298.diff

2 Files Affected:
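For orientation before the diff, here is a minimal sketch of the IR the signed i2 -> i8 pattern produces for a vector<8xi2> input; it mirrors the aligned_extsi_i2_to_i8 test case added below, and the SSA names are illustrative only:

  %cst2 = arith.constant dense<2> : vector<2xi8>
  %cst4 = arith.constant dense<4> : vector<2xi8>
  %cst6 = arith.constant dense<6> : vector<2xi8>
  // Reinterpret the packed i2 data as bytes; each byte now holds four i2 elements.
  %bc   = vector.bitcast %a : vector<8xi2> to vector<2xi8>
  // Shift element k into the top two bits, then arithmetic-shift right by 6 to sign-extend it.
  %shl6 = arith.shli %bc, %cst6 : vector<2xi8>
  %e0   = arith.shrsi %shl6, %cst6 : vector<2xi8>  // bits 0-1
  %shl4 = arith.shli %bc, %cst4 : vector<2xi8>
  %e1   = arith.shrsi %shl4, %cst6 : vector<2xi8>  // bits 2-3
  %shl2 = arith.shli %bc, %cst2 : vector<2xi8>
  %e2   = arith.shrsi %shl2, %cst6 : vector<2xi8>  // bits 4-5
  %e3   = arith.shrsi %bc, %cst6 : vector<2xi8>    // bits 6-7
  // Restore the original element order: interleave (0,2) and (1,3), then the two results.
  %i02  = vector.interleave %e0, %e2 : vector<2xi8>
  %i13  = vector.interleave %e1, %e3 : vector<2xi8>
  %res  = vector.interleave %i02, %i13 : vector<4xi8>

The unsigned path has the same shape but replaces each shli/shrsi pair with an arith.shrui followed by an arith.andi against a dense<3> mask.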
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 181c394edc1d20..323da627de7bc2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -1084,8 +1084,8 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
- // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
- if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
+ // Only {s}i4/i2 -> (size_of({{s}i/f}) >= 8) are supported for now.
+ if ((srcElemBitwidth != 4 && srcElemBitwidth != 2) || dstElemBitwidth < 8 ||
(dstElemBitwidth % srcElemBitwidth) != 0)
return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
@@ -1233,6 +1233,117 @@ static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
+/// Rewrite the i2 -> i8 signed extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI2ToI8SignedExt(PatternRewriter &rewriter, Location loc,
+ Value srcValue) {
+ VectorType srcVecType = cast<VectorType>(srcValue.getType());
+ assert(srcVecType.getElementType().isSignlessInteger(2) &&
+ "Expected i2 type");
+
+ // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+ SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+ constexpr int64_t i2Toi8BitwidthFactor = 4;
+ i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
+ auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+ Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+ // Element 0 (bits 0-1)
+ constexpr int8_t shiftConst6 = 6;
+ auto shiftAttr6 = DenseElementsAttr::get(i8VecType, shiftConst6);
+ auto shiftValues6 = rewriter.create<arith::ConstantOp>(loc, shiftAttr6);
+ Value shl0 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues6);
+ Value elem0 = rewriter.create<arith::ShRSIOp>(loc, shl0, shiftValues6);
+
+ // Element 1 (bits 2-3)
+ constexpr int8_t shiftConst4 = 4;
+ auto shiftAttr4 = DenseElementsAttr::get(i8VecType, shiftConst4);
+ auto shiftValues4 = rewriter.create<arith::ConstantOp>(loc, shiftAttr4);
+ Value shl1 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues4);
+ Value elem1 = rewriter.create<arith::ShRSIOp>(loc, shl1, shiftValues6);
+
+ // Element 2 (bits 4-5)
+ constexpr int8_t shiftConst2 = 2;
+ auto shiftAttr2 = DenseElementsAttr::get(i8VecType, shiftConst2);
+ auto shiftValues2 = rewriter.create<arith::ConstantOp>(loc, shiftAttr2);
+ Value shl2 = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues2);
+ Value elem2 = rewriter.create<arith::ShRSIOp>(loc, shl2, shiftValues6);
+
+ // Element 3 (bits 6-7)
+ Value elem3 = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues6);
+
+ // Interleave all 4 elements: first the even-indexed ones (0,2) and the odd-indexed ones (1,3), then the two results.
+ // elem0 = [0,0,0,0]
+ // elem1 = [1,1,1,1]
+ // elem2 = [2,2,2,2]
+ // elem3 = [3,3,3,3]
+ // 02 = [0,2,0,2]
+ // 13 = [1,3,1,3]
+ // 0213 = [0,1,2,3]
+ Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, elem0, elem2);
+ Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, elem1, elem3);
+ return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
+}
+
+/// Rewrite the i2 -> i8 unsigned extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI2ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
+ Value srcValue) {
+ VectorType srcVecType = cast<VectorType>(srcValue.getType());
+ assert(srcVecType.getElementType().isSignlessInteger(2) &&
+ "Expected i2 type");
+
+ // 1. Generate a bitcast vector<Xxi2> -> vector<X/4xi8>.
+ SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+ constexpr int64_t i2Toi8BitwidthFactor = 4;
+ i8VecShape.back() = i8VecShape.back() / i2Toi8BitwidthFactor;
+ auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+ Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
+
+ // 2. Extract each i2 element using shifts and masks
+ constexpr uint8_t mask = 3; // Mask for 2 bits: [0000 0011]
+ auto maskAttr = DenseElementsAttr::get(i8VecType, mask);
+ auto maskValues = rewriter.create<arith::ConstantOp>(loc, maskAttr);
+
+ // Element 0 (bits 0-1)
+ Value elem0 = rewriter.create<arith::AndIOp>(loc, i8Vector, maskValues);
+
+ // Element 1 (bits 2-3)
+ constexpr int8_t shift1 = 2;
+ auto shiftAttr1 = DenseElementsAttr::get(i8VecType, shift1);
+ auto shiftValues1 = rewriter.create<arith::ConstantOp>(loc, shiftAttr1);
+ Value shifted1 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues1);
+ Value elem1 = rewriter.create<arith::AndIOp>(loc, shifted1, maskValues);
+
+ // Element 2 (bits 4-5)
+ constexpr int8_t shift2 = 4;
+ auto shiftAttr2 = DenseElementsAttr::get(i8VecType, shift2);
+ auto shiftValues2 = rewriter.create<arith::ConstantOp>(loc, shiftAttr2);
+ Value shifted2 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues2);
+ Value elem2 = rewriter.create<arith::AndIOp>(loc, shifted2, maskValues);
+
+ // Element 3 (bits 6-7)
+ constexpr int8_t shift3 = 6;
+ auto shiftAttr3 = DenseElementsAttr::get(i8VecType, shift3);
+ auto shiftValues3 = rewriter.create<arith::ConstantOp>(loc, shiftAttr3);
+ Value shifted3 = rewriter.create<arith::ShRUIOp>(loc, i8Vector, shiftValues3);
+ Value elem3 = rewriter.create<arith::AndIOp>(loc, shifted3, maskValues);
+
+ // Interleave all 4 elements: first the even-indexed ones (0,2) and the odd-indexed ones (1,3), then the two results.
+ // elem0 = [0,0,0,0]
+ // elem1 = [1,1,1,1]
+ // elem2 = [2,2,2,2]
+ // elem3 = [3,3,3,3]
+ // 02 = [0,2,0,2]
+ // 13 = [1,3,1,3]
+ // 0213 = [0,1,2,3]
+ Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, elem0, elem2);
+ Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, elem1, elem3);
+ return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
+}
+
/// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
/// ops that take advantage of high-level information to avoid leaving LLVM to
/// scramble with peephole optimizations.
@@ -1438,11 +1549,21 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
// Perform the rewrite.
Value subByteExt;
if (isSigned) {
- subByteExt =
- rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ if (srcVecType.getElementType().getIntOrFloatBitWidth() == 2) {
+ subByteExt =
+ rewriteI2ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ } else {
+ subByteExt =
+ rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ }
} else {
- subByteExt =
- rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ if (srcVecType.getElementType().getIntOrFloatBitWidth() == 2) {
+ subByteExt =
+ rewriteI2ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ } else {
+ subByteExt =
+ rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
+ }
}
// Finalize the rewrite.
@@ -1489,6 +1610,10 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
truncOp)))
return failure();
+ // Truncation to i2 is not supported currently.
+ if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
+ return failure();
+
// Create a new iX -> i8 truncation op.
Location loc = truncOp.getLoc();
auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 210025e30d7db5..0b469066f290c2 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -220,8 +220,8 @@ func.func @aligned_extsi_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
return %0 : vector<8xi32>
}
-// CHECK-LABEL: func.func @aligned_extsi_2d(
-func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-LABEL: func.func @aligned_extsi_i4_2d(
+func.func @aligned_extsi_i4_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
// CHECK-SAME: %[[IN:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi4> to vector<8x16xi8>
@@ -234,6 +234,72 @@ func.func @aligned_extsi_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
return %0 : vector<8x32xi32>
}
+// CHECK-LABEL: func.func @aligned_extsi_i2_to_i8(
+func.func @aligned_extsi_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK: %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+ %0 = arith.extsi %a : vector<8xi2> to vector<8xi8>
+ return %0 : vector<8xi8>
+}
+
+
+// CHECK-LABEL: func.func @aligned_extsi_i2_to_i32(
+func.func @aligned_extsi_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+// CHECK: %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+ %0 = arith.extsi %a : vector<8xi2> to vector<8xi32>
+ return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extsi_i2_2d(
+func.func @aligned_extsi_i2_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8>
+// CHECK: %[[SHL_6:.*]] = arith.shli %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[ELEM0:.*]] = arith.shrsi %[[SHL_6]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[SHL_4:.*]] = arith.shli %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
+// CHECK: %[[ELEM1:.*]] = arith.shrsi %[[SHL_4]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[SHL_2:.*]] = arith.shli %[[BITCAST]], %[[CST_2]] : vector<8x8xi8>
+// CHECK: %[[ELEM2:.*]] = arith.shrsi %[[SHL_2]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[ELEM3:.*]] = arith.shrsi %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<8x16xi8>
+// CHECK: %[[RESULT:.*]] = arith.extsi %[[INTERLEAVE]] : vector<8x32xi8> to vector<8x32xi32>
+ %0 = arith.extsi %a : vector<8x32xi2> to vector<8x32xi32>
+ return %0 : vector<8x32xi32>
+}
+
// CHECK-LABEL: func.func @aligned_trunci_i8_to_i4(
func.func @aligned_trunci_i8_to_i4(%a: vector<8xi8>) -> vector<8xi4> {
@@ -292,6 +358,13 @@ func.func @aligned_trunci_nd(%a: vector<3x8x32xi32>) -> vector<3x8x32xi4> {
return %0 : vector<3x8x32xi4>
}
+func.func @aligned_trunci_i8_to_i2_no_match(%a: vector<8xi8>) -> vector<8xi2> {
+ // CHECK-NOT: vector.bitcast
+ // CHECK: arith.trunci %[[IN:.*]] : vector<8xi8> to vector<8xi2>
+ %0 = arith.trunci %a : vector<8xi8> to vector<8xi2>
+ return %0 : vector<8xi2>
+}
+
// CHECK-LABEL: func.func @aligned_extui_i4_to_i8(
func.func @aligned_extui_i4_to_i8(%a: vector<8xi4>) -> vector<8xi8> {
// CHECK-SAME: %[[IN:.*]]: vector<8xi4>) -> vector<8xi8> {
@@ -319,8 +392,8 @@ func.func @aligned_extui_i4_to_i32(%a: vector<8xi4>) -> vector<8xi32> {
return %0 : vector<8xi32>
}
-// CHECK-LABEL: func.func @aligned_extui_2d(
-func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
+// CHECK-LABEL: func.func @aligned_extui_i4_2d(
+func.func @aligned_extui_i4_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
// CHECK-SAME: %[[VAL_0:.*]]: vector<8x32xi4>) -> vector<8x32xi32> {
// CHECK: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<8x16xi8>
// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<15> : vector<8x16xi8>
@@ -333,6 +406,74 @@ func.func @aligned_extui_2d(%a: vector<8x32xi4>) -> vector<8x32xi32> {
return %0 : vector<8x32xi32>
}
+// CHECK-LABEL: func.func @aligned_extui_i2_to_i8(
+func.func @aligned_extui_i2_to_i8(%a: vector<8xi2>) -> vector<8xi8> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi8> {
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK: %[[RESULT:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+ %0 = arith.extui %a : vector<8xi2> to vector<8xi8>
+ return %0 : vector<8xi8>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i2_to_i32(
+func.func @aligned_extui_i2_to_i32(%a: vector<8xi2>) -> vector<8xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi2>) -> vector<8xi32> {
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<2xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<2xi8>
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<2xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<2xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8xi2> to vector<2xi8>
+// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<2xi8>
+// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<2xi8>
+// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<2xi8>
+// CHECK: %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<2xi8>
+// CHECK: %[[INTERLEAVE:.*]] = vector.interleave %[[INTERLEAVE02]], %[[INTERLEAVE13]] : vector<4xi8>
+// CHECK: %[[RESULT:.*]] = arith.extui %[[INTERLEAVE]] : vector<8xi8> to vector<8xi32>
+ %0 = arith.extui %a : vector<8xi2> to vector<8xi32>
+ return %0 : vector<8xi32>
+}
+
+// CHECK-LABEL: func.func @aligned_extui_i2_2d(
+func.func @aligned_extui_i2_2d(%a: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK-SAME: %[[IN:.*]]: vector<8x32xi2>) -> vector<8x32xi32> {
+// CHECK: %[[CST_6:.*]] = arith.constant dense<6> : vector<8x8xi8>
+// CHECK: %[[CST_4:.*]] = arith.constant dense<4> : vector<8x8xi8>
+// CHECK: %[[CST_2:.*]] = arith.constant dense<2> : vector<8x8xi8>
+// CHECK: %[[LOWBITS_MASK:.*]] = arith.constant dense<3> : vector<8x8xi8>
+// CHECK: %[[BITCAST:.*]] = vector.bitcast %[[IN]] : vector<8x32xi2> to vector<8x8xi8>
+// CHECK: %[[ELEM0:.*]] = arith.andi %[[BITCAST]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK: %[[SHR_2:.*]] = arith.shrui %[[BITCAST]], %[[CST_2]] : vector<8x8xi8>
+// CHECK: %[[ELEM1:.*]] = arith.andi %[[SHR_2]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK: %[[SHR_4:.*]] = arith.shrui %[[BITCAST]], %[[CST_4]] : vector<8x8xi8>
+// CHECK: %[[ELEM2:.*]] = arith.andi %[[SHR_4]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK: %[[SHR_6:.*]] = arith.shrui %[[BITCAST]], %[[CST_6]] : vector<8x8xi8>
+// CHECK: %[[ELEM3:.*]] = arith.andi %[[SHR_6]], %[[LOWBITS_MASK]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE02:.*]] = vector.interleave %[[ELEM0]], %[[ELEM2]] : vector<8x8xi8>
+// CHECK: %[[INTERLEAVE13:.*]] = vector.interleave %[[ELEM1]], %[[ELEM3]] : vector<8x8xi...
[truncated]