diff --git a/lib/Dialect/Comb/CombFolds.cpp b/lib/Dialect/Comb/CombFolds.cpp index 92aa494dc789..691ae579c9cb 100644 --- a/lib/Dialect/Comb/CombFolds.cpp +++ b/lib/Dialect/Comb/CombFolds.cpp @@ -1178,107 +1178,6 @@ OpFoldResult OrOp::fold(FoldAdaptor adaptor) { return constFoldAssociativeOp(inputs, hw::PEO::Or); } -/// Simplify concat ops in an or op when a constant operand is present in either -/// concat. -/// -/// This will invert an or(concat, concat) into concat(or, or, ...), which can -/// often be further simplified due to the smaller or ops being easier to fold. -/// -/// For example: -/// -/// or(..., concat(x, 0), concat(0, y)) -/// ==> or(..., concat(x, 0, y)), when x and y don't overlap. -/// -/// or(..., concat(x: i2, cst1: i4), concat(cst2: i5, y: i1)) -/// ==> or(..., concat(or(x: i2, extract(cst2, 4..3)), -/// or(extract(cst1, 3..1), extract(cst2, 2..0)), -/// or(extract(cst1, 0..0), y: i1)) -static bool canonicalizeOrOfConcatsWithCstOperands(OrOp op, size_t concatIdx1, - size_t concatIdx2, - PatternRewriter &rewriter) { - assert(concatIdx1 < concatIdx2 && "concatIdx1 must be < concatIdx2"); - - auto inputs = op.getInputs(); - auto concat1 = inputs[concatIdx1].getDefiningOp(); - auto concat2 = inputs[concatIdx2].getDefiningOp(); - - assert(concat1 && concat2 && "expected indexes to point to ConcatOps"); - - // We can simplify as long as a constant is present in either concat. - bool hasConstantOp1 = - llvm::any_of(concat1->getOperands(), [&](Value operand) -> bool { - return operand.getDefiningOp(); - }); - if (!hasConstantOp1) { - bool hasConstantOp2 = - llvm::any_of(concat2->getOperands(), [&](Value operand) -> bool { - return operand.getDefiningOp(); - }); - if (!hasConstantOp2) - return false; - } - - SmallVector newConcatOperands; - - // Simultaneously iterate over the operands of both concat ops, from MSB to - // LSB, pushing out or's of overlapping ranges of the operands. When operands - // span different bit ranges, we extract only the maximum overlap. - auto operands1 = concat1->getOperands(); - auto operands2 = concat2->getOperands(); - // Number of bits already consumed from operands 1 and 2, respectively. - unsigned consumedWidth1 = 0; - unsigned consumedWidth2 = 0; - for (auto it1 = operands1.begin(), end1 = operands1.end(), - it2 = operands2.begin(), end2 = operands2.end(); - it1 != end1 && it2 != end2;) { - auto operand1 = *it1; - auto operand2 = *it2; - - unsigned remainingWidth1 = - hw::getBitWidth(operand1.getType()) - consumedWidth1; - unsigned remainingWidth2 = - hw::getBitWidth(operand2.getType()) - consumedWidth2; - unsigned widthToConsume = std::min(remainingWidth1, remainingWidth2); - auto narrowedType = rewriter.getIntegerType(widthToConsume); - - auto extract1 = rewriter.createOrFold( - op.getLoc(), narrowedType, operand1, remainingWidth1 - widthToConsume); - auto extract2 = rewriter.createOrFold( - op.getLoc(), narrowedType, operand2, remainingWidth2 - widthToConsume); - - newConcatOperands.push_back( - rewriter.createOrFold(op.getLoc(), extract1, extract2, false)); - - consumedWidth1 += widthToConsume; - consumedWidth2 += widthToConsume; - - if (widthToConsume == remainingWidth1) { - ++it1; - consumedWidth1 = 0; - } - if (widthToConsume == remainingWidth2) { - ++it2; - consumedWidth2 = 0; - } - } - - ConcatOp newOp = rewriter.create(op.getLoc(), newConcatOperands); - - // Copy the old operands except for concatIdx1 and concatIdx2, and append the - // new ConcatOp to the end. - SmallVector newOrOperands; - newOrOperands.append(inputs.begin(), inputs.begin() + concatIdx1); - newOrOperands.append(inputs.begin() + concatIdx1 + 1, - inputs.begin() + concatIdx2); - newOrOperands.append(inputs.begin() + concatIdx2 + 1, - inputs.begin() + inputs.size()); - newOrOperands.push_back(newOp); - - replaceOpWithNewOpAndCopyName(rewriter, op, op.getType(), - newOrOperands); - return true; -} - LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) { auto inputs = op.getInputs(); auto size = inputs.size(); @@ -1328,16 +1227,6 @@ LogicalResult OrOp::canonicalize(OrOp op, PatternRewriter &rewriter) { } } - // or(..., concat(x, cst1), concat(cst2, y) - // ==> or(..., concat(x, cst3, y)), when x and y don't overlap. - for (size_t i = 0; i < size - 1; ++i) { - if (auto concat = inputs[i].getDefiningOp()) - for (size_t j = i + 1; j < size; ++j) - if (auto concat = inputs[j].getDefiningOp()) - if (canonicalizeOrOfConcatsWithCstOperands(op, i, j, rewriter)) - return success(); - } - // extracts only of or(...) -> or(extract()...) if (narrowOperationWidth(op, true, rewriter)) return success(); diff --git a/test/Dialect/Comb/canonicalization.mlir b/test/Dialect/Comb/canonicalization.mlir index 8c1a3f4345dc..e7cda4bd24c4 100644 --- a/test/Dialect/Comb/canonicalization.mlir +++ b/test/Dialect/Comb/canonicalization.mlir @@ -181,87 +181,6 @@ hw.module @dedupLong(in %arg0 : i7, in %arg1 : i7, in %arg2: i7, out resAnd: i7, hw.output %0, %1 : i7, i7 } -// CHECK-LABEL: hw.module @orExclusiveConcats -hw.module @orExclusiveConcats(in %arg0 : i6, in %arg1 : i2, out o: i9) { - // CHECK-NEXT: %false = hw.constant false - // CHECK-NEXT: %0 = comb.concat %arg1, %false, %arg0 : i2, i1, i6 - // CHECK-NEXT: hw.output %0 : i9 - %c0 = hw.constant 0 : i3 - %0 = comb.concat %c0, %arg0 : i3, i6 - %c1 = hw.constant 0 : i7 - %1 = comb.concat %arg1, %c1 : i2, i7 - %2 = comb.or %0, %1 : i9 - hw.output %2 : i9 -} - -// When two concats are or'd together and have mutually-exclusive fields, they -// can be merged together into a single concat. -// concat0: 0aaa aaa0 0000 0bb0 -// concat1: 0000 0000 ccdd d000 -// merged: 0aaa aaa0 ccdd dbb0 -// CHECK-LABEL: hw.module @orExclusiveConcats2 -hw.module @orExclusiveConcats2(in %arg0 : i6, in %arg1 : i2, in %arg2: i2, in %arg3: i3, out o: i16) { - // CHECK-NEXT: %false = hw.constant false - // CHECK-NEXT: %0 = comb.concat %false, %arg0, %false, %arg2, %arg3, %arg1, %false : i1, i6, i1, i2, i3, i2, i1 - // CHECK-NEXT: hw.output %0 : i16 - %c0 = hw.constant 0 : i1 - %c1 = hw.constant 0 : i6 - %c2 = hw.constant 0 : i1 - %0 = comb.concat %c0, %arg0, %c1, %arg1, %c2: i1, i6, i6, i2, i1 - %c3 = hw.constant 0 : i8 - %c4 = hw.constant 0 : i3 - %1 = comb.concat %c3, %arg2, %arg3, %c4 : i8, i2, i3, i3 - %2 = comb.or %0, %1 : i16 - hw.output %2 : i16 -} - -// When two concats are or'd together and have mutually-exclusive fields, they -// can be merged together into a single concat. -// concat0: aaaa 1111 -// concat1: 1111 10bb -// merged: 1111 1111 -// CHECK-LABEL: hw.module @orExclusiveConcats3 -hw.module @orExclusiveConcats3(in %arg0 : i4, in %arg1 : i2, out o: i8) { - // CHECK-NEXT: [[RES:%[a-z0-9_-]+]] = hw.constant -1 : i8 - // CHECK-NEXT: hw.output [[RES]] : i8 - %c0 = hw.constant -1 : i4 - %0 = comb.concat %arg0, %c0: i4, i4 - %c1 = hw.constant -1 : i5 - %c2 = hw.constant 0 : i1 - %1 = comb.concat %c1, %c2, %arg1 : i5, i1, i2 - %2 = comb.or %0, %1 : i8 - hw.output %2 : i8 -} - -// CHECK-LABEL: hw.module @orMultipleExclusiveConcats -hw.module @orMultipleExclusiveConcats(in %arg0 : i2, in %arg1 : i2, in %arg2: i2, out o: i6) { - // CHECK-NEXT: %0 = comb.concat %arg0, %arg1, %arg2 : i2, i2, i2 - // CHECK-NEXT: hw.output %0 : i6 - %c2 = hw.constant 0 : i2 - %c4 = hw.constant 0 : i4 - %0 = comb.concat %arg0, %c4: i2, i4 - %1 = comb.concat %c2, %arg1, %c2: i2, i2, i2 - %2 = comb.concat %c4, %arg2: i4, i2 - %out = comb.or %0, %1, %2 : i6 - hw.output %out : i6 -} - -// CHECK-LABEL: hw.module @orConcatsWithMux -hw.module @orConcatsWithMux(in %bit: i1, in %cond: i1, out o: i6) { - // CHECK-NEXT: [[RES:%[a-z0-9_-]+]] = hw.constant 0 : i4 - // CHECK-NEXT: %0 = comb.concat [[RES]], %cond, %bit : i4, i1, i1 - // CHECK-NEXT: hw.output %0 : i6 - %c0 = hw.constant 0 : i5 - %0 = comb.concat %c0, %bit: i5, i1 - %c1 = hw.constant 0 : i4 - %c2 = hw.constant 2 : i2 - %c3 = hw.constant 0 : i2 - %1 = comb.mux %cond, %c2, %c3 : i2 - %2 = comb.concat %c1, %1 : i4, i2 - %3 = comb.or %0, %2 : i6 - hw.output %3 : i6 -} - // CHECK-LABEL: @extractNested hw.module @extractNested(in %0: i5, out o1 : i1) { // Multiple layers of nested extract is a weak evidence that the cannonicalization