Skip to content

Commit

Permalink
[FIRRTL] FoldRegMems: insert new ops into same block as memory (#7868)
Browse files Browse the repository at this point in the history
Before this PR, FoldRegMems would construct new ops in the "body of the parent
FModuleOp". We need to place these ops in the same block as the memory.

This PR fixes a bug where, when a memory under a layerblock was canonicalized
to a register, the register would be placed at the original location of the
memory (under the layerblock), but its readers would be placed outside the
layerblock, resulting in a dominance checking error.
  • Loading branch information
rwy7 authored Nov 23, 2024
1 parent c259d1b commit b342d31
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 25 deletions.
18 changes: 12 additions & 6 deletions lib/Dialect/FIRRTL/FIRRTLFolds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2870,7 +2870,7 @@ struct FoldRegMems : public mlir::RewritePattern {
if (hasDontTouch(mem) || info.depth != 1)
return failure();

auto memModule = mem->getParentOfType<FModuleOp>();
auto *block = mem->getBlock();

// Find the clock of the register-to-be, all write ports should share it.
Value clock;
Expand Down Expand Up @@ -2926,11 +2926,17 @@ struct FoldRegMems : public mlir::RewritePattern {
}

// Create a new register to store the data.
auto clockWire = rewriter.create<WireOp>(mem.getLoc(), clock.getType());
auto ty = mem.getDataType();
rewriter.setInsertionPointAfterValue(clock);
auto reg = rewriter.create<RegOp>(mem.getLoc(), ty, clock, mem.getName())
auto reg = rewriter
.create<RegOp>(mem.getLoc(), ty, clockWire.getResult(),
mem.getName())
.getResult();

rewriter.setInsertionPointToEnd(block);
rewriter.create<MatchingConnectOp>(mem.getLoc(), clockWire.getResult(),
clock);

// Helper to insert a given number of pipeline stages through registers.
auto pipeline = [&](Value value, Value clock, const Twine &name,
unsigned latency) {
Expand Down Expand Up @@ -2964,7 +2970,7 @@ struct FoldRegMems : public mlir::RewritePattern {
auto portPipeline = [&, port = port](StringRef field, unsigned stages) {
Value value = getPortFieldValue(port, field);
assert(value);
rewriter.setInsertionPointAfterValue(value);
rewriter.setInsertionPointAfterValue(reg);
return pipeline(value, portClock, name + "_" + field, stages);
};

Expand Down Expand Up @@ -2998,7 +3004,7 @@ struct FoldRegMems : public mlir::RewritePattern {

Value en = getPortFieldValue(port, "en");
Value wmode = getPortFieldValue(port, "wmode");
rewriter.setInsertionPointToEnd(memModule.getBodyBlock());
rewriter.setInsertionPointToEnd(block);

auto wen = rewriter.create<AndPrimOp>(port.getLoc(), en, wmode);
auto wenPipelined =
Expand All @@ -3010,7 +3016,7 @@ struct FoldRegMems : public mlir::RewritePattern {
}

// Regardless of `writeUnderWrite`, always implement PortOrder.
rewriter.setInsertionPointToEnd(memModule.getBodyBlock());
rewriter.setInsertionPointToEnd(block);
Value next = reg;
for (auto &[data, en, mask] : writes) {
Value masked;
Expand Down
135 changes: 116 additions & 19 deletions test/Dialect/FIRRTL/simplify-mems.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -408,10 +408,21 @@ firrtl.circuit "OneAddressNoMask" {
in %in_rwen: !firrtl.uint<1>,
out %result_read: !firrtl.uint<32>,
out %result_rw: !firrtl.uint<32>) {
%c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>
%Memory_read, %Memory_rw, %Memory_write = firrtl.mem Undefined
{
depth = 1 : i64,
name = "Memory",
portNames = ["read", "rw", "write"],
readLatency = 2 : i32,
writeLatency = 4 : i32
} :
!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data flip: uint<32>>,
!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>,
!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: uint<32>, mask: uint<1>>

// Pipeline the inputs.
// TODO: It would be good to de-duplicate these either in the pass or in a canonicalizer.

// CHECK: %Memory = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>

// CHECK: %Memory_write_en_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_write_en_0, %in_wen : !firrtl.uint<1>
// CHECK: %Memory_write_en_1 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
Expand All @@ -433,22 +444,6 @@ firrtl.circuit "OneAddressNoMask" {
// CHECK: %Memory_rw_wdata_2 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_rw_wdata_2, %Memory_rw_wdata_1 : !firrtl.uint<32>

%c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>

%Memory_read, %Memory_rw, %Memory_write = firrtl.mem Undefined
{
depth = 1 : i64,
name = "Memory",
portNames = ["read", "rw", "write"],
readLatency = 2 : i32,
writeLatency = 4 : i32
} :
!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data flip: uint<32>>,
!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>,
!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: uint<32>, mask: uint<1>>

// CHECK: %Memory = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>

// CHECK: firrtl.matchingconnect %result_read, %Memory : !firrtl.uint<32>
%read_addr = firrtl.subfield %Memory_read[addr] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data flip: uint<32>>
firrtl.connect %read_addr, %addr : !firrtl.uint<1>, !firrtl.uint<1>
Expand Down Expand Up @@ -497,3 +492,105 @@ firrtl.circuit "OneAddressNoMask" {
firrtl.connect %write_mask, %c1_ui1 : !firrtl.uint<1>, !firrtl.uint<1>
}
}

// -----

// This test ensures that the FoldRegMems canonicalization correctly
// folds memories under layerblocks.
firrtl.circuit "Rewrite1ElementMemoryToRegisterUnderLayerblock" {
firrtl.layer @A bind {}

firrtl.module public @Rewrite1ElementMemoryToRegisterUnderLayerblock(
in %clock: !firrtl.clock,
in %addr: !firrtl.uint<1>,
in %in_data: !firrtl.uint<32>,
in %wmode_rw: !firrtl.uint<1>,
in %in_wen: !firrtl.uint<1>,
in %in_rwen: !firrtl.uint<1>) {

%c1_ui1 = firrtl.constant 1 : !firrtl.uint<1>

// CHECK firrtl.layerblock @A
firrtl.layerblock @A {
// CHECK: %result_read = firrtl.wire : !firrtl.uint<32>
// CHECK: %result_rw = firrtl.wire : !firrtl.uint<32>
%result_read = firrtl.wire : !firrtl.uint<32>
%result_rw = firrtl.wire : !firrtl.uint<32>

%Memory_read, %Memory_rw, %Memory_write = firrtl.mem Undefined
{
depth = 1 : i64,
name = "Memory",
portNames = ["read", "rw", "write"],
readLatency = 2 : i32,
writeLatency = 2 : i32
} :
!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data flip: uint<32>>,
!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>,
!firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: uint<32>, mask: uint<1>>

// CHECK: %Memory = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: %Memory_write_mask_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_write_mask_0, %c1_ui1 : !firrtl.uint<1>

// CHECK: %Memory_write_en_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_write_en_0, %in_wen : !firrtl.uint<1>

// CHECK: %Memory_write_data_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_write_data_0, %in_data : !firrtl.uint<32>

// CHECK: %Memory_rw_wmask_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_rw_wmask_0, %c1_ui1 : !firrtl.uint<1>

// CHECK: %Memory_rw_wdata_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory_rw_wdata_0, %in_data : !firrtl.uint<32>

// CHECK: firrtl.matchingconnect %result_read, %Memory : !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %result_rw, %Memory : !firrtl.uint<32>

// CHECK: %0 = firrtl.and %in_rwen, %wmode_rw : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK: %Memory_rw_wen_0 = firrtl.reg %clock : !firrtl.clock, !firrtl.uint<1>
// CHECK: firrtl.matchingconnect %Memory_rw_wen_0, %0 : !firrtl.uint<1>
// CHECK: %1 = firrtl.and %Memory_rw_wen_0, %Memory_rw_wmask_0 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK: %2 = firrtl.mux(%1, %Memory_rw_wdata_0, %Memory) : (!firrtl.uint<1>, !firrtl.uint<32>, !firrtl.uint<32>) -> !firrtl.uint<32>
// CHECK: %3 = firrtl.and %Memory_write_en_0, %Memory_write_mask_0 : (!firrtl.uint<1>, !firrtl.uint<1>) -> !firrtl.uint<1>
// CHECK: %4 = firrtl.mux(%3, %Memory_write_data_0, %2) : (!firrtl.uint<1>, !firrtl.uint<32>, !firrtl.uint<32>) -> !firrtl.uint<32>
// CHECK: firrtl.matchingconnect %Memory, %4 : !firrtl.uint<32>

%read_addr = firrtl.subfield %Memory_read[addr] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data flip: uint<32>>
firrtl.connect %read_addr, %addr : !firrtl.uint<1>, !firrtl.uint<1>
%read_en = firrtl.subfield %Memory_read[en] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data flip: uint<32>>
firrtl.connect %read_en, %c1_ui1 : !firrtl.uint<1>, !firrtl.uint<1>
%read_clk = firrtl.subfield %Memory_read[clk] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data flip: uint<32>>
firrtl.connect %read_clk, %clock : !firrtl.clock, !firrtl.clock
%read_data = firrtl.subfield %Memory_read[data] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data flip: uint<32>>
firrtl.connect %result_read, %read_data : !firrtl.uint<32>, !firrtl.uint<32>

%rw_addr = firrtl.subfield %Memory_rw[addr] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
firrtl.connect %rw_addr, %addr : !firrtl.uint<1>, !firrtl.uint<1>
%rw_en = firrtl.subfield %Memory_rw[en] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
firrtl.connect %rw_en, %in_rwen : !firrtl.uint<1>, !firrtl.uint<1>
%rw_clk = firrtl.subfield %Memory_rw[clk] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
firrtl.connect %rw_clk, %clock : !firrtl.clock, !firrtl.clock
%rw_rdata = firrtl.subfield %Memory_rw[rdata] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
firrtl.connect %result_rw, %rw_rdata : !firrtl.uint<32>, !firrtl.uint<32>
%rw_wmode = firrtl.subfield %Memory_rw[wmode] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
firrtl.connect %rw_wmode, %wmode_rw : !firrtl.uint<1>, !firrtl.uint<1>
%rw_wdata = firrtl.subfield %Memory_rw[wdata] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
firrtl.connect %rw_wdata, %in_data : !firrtl.uint<32>, !firrtl.uint<32>
%rw_wmask = firrtl.subfield %Memory_rw[wmask] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, rdata flip: uint<32>, wmode: uint<1>, wdata: uint<32>, wmask: uint<1>>
firrtl.connect %rw_wmask, %c1_ui1 : !firrtl.uint<1>, !firrtl.uint<1>

%write_addr = firrtl.subfield %Memory_write[addr] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: uint<32>, mask: uint<1>>
firrtl.connect %write_addr, %addr : !firrtl.uint<1>, !firrtl.uint<1>
%write_en = firrtl.subfield %Memory_write[en] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: uint<32>, mask: uint<1>>
firrtl.connect %write_en, %in_wen : !firrtl.uint<1>, !firrtl.uint<1>
%write_clk = firrtl.subfield %Memory_write[clk] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: uint<32>, mask: uint<1>>
firrtl.connect %write_clk, %clock : !firrtl.clock, !firrtl.clock
%write_data = firrtl.subfield %Memory_write[data] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: uint<32>, mask: uint<1>>
firrtl.connect %write_data, %in_data : !firrtl.uint<32>, !firrtl.uint<32>
%write_mask = firrtl.subfield %Memory_write[mask] : !firrtl.bundle<addr: uint<1>, en: uint<1>, clk: clock, data: uint<32>, mask: uint<1>>
firrtl.connect %write_mask, %c1_ui1 : !firrtl.uint<1>, !firrtl.uint<1>
}
}
}

0 comments on commit b342d31

Please sign in to comment.