-
Notifications
You must be signed in to change notification settings - Fork 306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Verif] LowerContractsPass #7870
base: main
Are you sure you want to change the base?
Changes from 9 commits
37f1d02
96cc931
2efbf69
b55cb86
6cfe656
a019081
1b97286
3f7baf8
adf8f3d
2a214e6
00d0963
53bc0c3
56d6f01
079d477
64f55cf
c0689d5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
//===- LowerContracts.cpp - Formal Preparations --*- C++ -*----------------===// | ||
leonardt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
//===----------------------------------------------------------------------===// | ||
leonardt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
// | ||
// Lower contracts into verif.formal tests. | ||
// | ||
//===----------------------------------------------------------------------===// | ||
#include "circt/Dialect/Verif/VerifOps.h" | ||
leonardt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
#include "circt/Dialect/Verif/VerifPasses.h" | ||
|
||
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" | ||
|
||
using namespace circt; | ||
|
||
namespace circt { | ||
namespace verif { | ||
#define GEN_PASS_DEF_LOWERCONTRACTSPASS | ||
#include "circt/Dialect/Verif/Passes.h.inc" | ||
} // namespace verif | ||
} // namespace circt | ||
|
||
using namespace mlir; | ||
using namespace verif; | ||
using namespace hw; | ||
|
||
namespace { | ||
struct LowerContractsPass | ||
: verif::impl::LowerContractsPassBase<LowerContractsPass> { | ||
void runOnOperation() override; | ||
}; | ||
|
||
template <typename FROM, typename TO> | ||
void replaceContractOp(PatternRewriter &rewriter, Block *body) { | ||
for (auto op : llvm::make_early_inc_range(body->getOps<FROM>())) { | ||
leonardt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
auto label = op.getLabel(); | ||
StringAttr labelAttr; | ||
if (label) { | ||
labelAttr = rewriter.getStringAttr(label.value()); | ||
} | ||
leonardt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
rewriter.replaceOpWithNewOp<TO>(op, op.getProperty(), op.getEnable(), | ||
labelAttr); | ||
} | ||
} | ||
|
||
struct HWOpRewritePattern : public OpRewritePattern<HWModuleOp> { | ||
using OpRewritePattern<HWModuleOp>::OpRewritePattern; | ||
|
||
LogicalResult matchAndRewrite(HWModuleOp op, | ||
PatternRewriter &rewriter) const override { | ||
auto formalOp = rewriter.create<verif::FormalOp>( | ||
op.getLoc(), op.getNameAttr(), rewriter.getDictionaryAttr({})); | ||
leonardt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
// Clone module body into fomal op body | ||
leonardt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
rewriter.cloneRegionBefore(op.getRegion(), formalOp.getBody(), | ||
formalOp.getBody().end()); | ||
|
||
auto *bodyBlock = &formalOp.getBody().front(); | ||
|
||
// Erase hw.output | ||
rewriter.eraseOp(bodyBlock->getTerminator()); | ||
|
||
// Convert block args to symbolic values | ||
rewriter.setInsertionPointToStart(bodyBlock); | ||
for (auto arg : llvm::make_early_inc_range(bodyBlock->getArguments())) { | ||
auto sym = | ||
rewriter.create<verif::SymbolicValueOp>(arg.getLoc(), arg.getType()); | ||
rewriter.replaceAllUsesWith(arg, sym); | ||
} | ||
bodyBlock->eraseArguments(0, bodyBlock->getNumArguments()); | ||
|
||
// Inline contract ops | ||
for (auto contractOp : | ||
llvm::make_early_inc_range(bodyBlock->getOps<verif::ContractOp>())) { | ||
|
||
// Convert ensure to assert, require to assume | ||
rewriter.setInsertionPointToEnd(&contractOp.getBody().front()); | ||
Block *contractBlock = &contractOp.getBody().front(); | ||
leonardt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
replaceContractOp<EnsureOp, AssertOp>(rewriter, contractBlock); | ||
replaceContractOp<RequireOp, AssumeOp>(rewriter, contractBlock); | ||
leonardt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
// Inline body | ||
rewriter.inlineBlockBefore(&contractOp.getBody().front(), | ||
leonardt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
&formalOp.getBody().front(), | ||
formalOp.getBody().front().end()); | ||
leonardt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
// Replace results with inputs and erase | ||
for (auto [input, result] : | ||
llvm::zip(contractOp.getResults(), contractOp.getInputs())) { | ||
rewriter.replaceAllUsesWith(input, result); | ||
} | ||
rewriter.eraseOp(contractOp); | ||
} | ||
|
||
rewriter.eraseOp(op); | ||
leonardt marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return success(); | ||
} | ||
}; | ||
} // namespace | ||
|
||
void LowerContractsPass::runOnOperation() { | ||
RewritePatternSet patterns(&getContext()); | ||
patterns.add<HWOpRewritePattern>(patterns.getContext()); | ||
|
||
if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) | ||
signalPassFailure(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Since you are only matching on modules, you can also just have a |
||
} |
leonardt marked this conversation as resolved.
Show resolved
Hide resolved
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This file contains tests that are great as integration tests, but for regression tests we want to check the corner cases and keep the tests small (they don't have to do something that makes sense from a design implementation/logics perspective). For example, there's no test of a module with more than one contract inside, with zero contracts inside. Deleting the module also only works because you have no other module instantiating it. |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
// RUN: circt-opt --lower-contracts %s | FileCheck %s | ||
|
||
|
||
// CHECK: verif.formal @Mul9 { | ||
// CHECK: %c9_i42 = hw.constant 9 : i42 | ||
// CHECK: %c3_i42 = hw.constant 3 : i42 | ||
// CHECK: %0 = verif.symbolic_value : i42 | ||
// CHECK: %1 = comb.shl %0, %c3_i42 : i42 | ||
// CHECK: %2 = comb.add %0, %1 : i42 | ||
// CHECK: %3 = comb.mul %0, %c9_i42 : i42 | ||
// CHECK: %4 = comb.icmp eq %2, %3 : i42 | ||
// CHECK: verif.assert %4 : i1 | ||
// CHECK: } | ||
|
||
hw.module @Mul9(in %a: i42, out z: i42) { | ||
%c3_i42 = hw.constant 3 : i42 | ||
%c9_i42 = hw.constant 9 : i42 | ||
%0 = comb.shl %a, %c3_i42 : i42 // 8*a | ||
%1 = comb.add %a, %0 : i42 // a + 8*a | ||
%2 = verif.contract %1 : i42 { | ||
%3 = comb.mul %a, %c9_i42 : i42 // 9*a | ||
%4 = comb.icmp eq %2, %3 : i42 // 9*a == a + 8*a | ||
verif.ensure %4 : i1 | ||
} | ||
hw.output %2 : i42 | ||
} | ||
|
||
// CHECK: verif.formal @CarrySaveCompress3to2 { | ||
// CHECK: %c1_i42 = hw.constant 1 : i42 | ||
// CHECK: %0 = verif.symbolic_value : i42 | ||
// CHECK: %1 = verif.symbolic_value : i42 | ||
// CHECK: %2 = verif.symbolic_value : i42 | ||
// CHECK: %3 = comb.xor %0, %1, %2 : i42 | ||
// CHECK: %4 = comb.and %0, %1 : i42 | ||
// CHECK: %5 = comb.or %0, %1 : i42 | ||
// CHECK: %6 = comb.and %5, %2 : i42 | ||
// CHECK: %7 = comb.or %4, %6 : i42 | ||
// CHECK: %8 = comb.shl %7, %c1_i42 : i42 | ||
// CHECK: %9 = comb.add %0, %1, %2 : i42 | ||
// CHECK: %10 = comb.add %3, %8 : i42 | ||
// CHECK: %11 = comb.icmp eq %9, %10 : i42 | ||
// CHECK: verif.assert %11 : i1 | ||
// CHECK: } | ||
|
||
hw.module @CarrySaveCompress3to2( | ||
in %a0: i42, in %a1: i42, in %a2: i42, | ||
out z0: i42, out z1: i42 | ||
) { | ||
%c1_i42 = hw.constant 1 : i42 | ||
%0 = comb.xor %a0, %a1, %a2 : i42 // sum bits of FA (a0^a1^a2) | ||
%1 = comb.and %a0, %a1 : i42 | ||
%2 = comb.or %a0, %a1 : i42 | ||
%3 = comb.and %2, %a2 : i42 | ||
%4 = comb.or %1, %3 : i42 // carry bits of FA (a0&a1 | a2&(a0|a1)) | ||
%5 = comb.shl %4, %c1_i42 : i42 // %5 = carry << 1 | ||
%z0, %z1 = verif.contract %0, %5 : i42, i42 { | ||
%inputSum = comb.add %a0, %a1, %a2 : i42 | ||
%outputSum = comb.add %z0, %z1 : i42 | ||
%6 = comb.icmp eq %inputSum, %outputSum : i42 | ||
verif.ensure %6 : i1 | ||
} | ||
hw.output %z0, %z1 : i42, i42 | ||
} | ||
|
||
// CHECK: verif.formal @ShiftLeft { | ||
// CHECK: %c1_i8 = hw.constant 1 : i8 | ||
// CHECK: %c2_i8 = hw.constant 2 : i8 | ||
// CHECK: %c4_i8 = hw.constant 4 : i8 | ||
// CHECK: %c8_i8 = hw.constant 8 : i8 | ||
// CHECK: %0 = verif.symbolic_value : i8 | ||
// CHECK: %1 = verif.symbolic_value : i8 | ||
// CHECK: %2 = comb.extract %1 from 2 : (i8) -> i1 | ||
// CHECK: %3 = comb.extract %1 from 1 : (i8) -> i1 | ||
// CHECK: %4 = comb.extract %1 from 0 : (i8) -> i1 | ||
// CHECK: %5 = comb.shl %0, %c4_i8 : i8 | ||
// CHECK: %6 = comb.mux %2, %5, %0 : i8 | ||
// CHECK: %7 = comb.shl %6, %c2_i8 : i8 | ||
// CHECK: %8 = comb.mux %3, %7, %6 : i8 | ||
// CHECK: %9 = comb.shl %8, %c1_i8 : i8 | ||
// CHECK: %10 = comb.mux %4, %9, %8 : i8 | ||
// CHECK: %11 = comb.icmp ult %1, %c8_i8 : i8 | ||
// CHECK: %12 = comb.shl %0, %1 : i8 | ||
// CHECK: %13 = comb.icmp eq %10, %12 : i8 | ||
// CHECK: verif.assert %13 : i1 | ||
// CHECK: verif.assume %11 : i1 | ||
// CHECK: } | ||
|
||
hw.module @ShiftLeft(in %a: i8, in %b: i8, out z: i8) { | ||
%c4_i8 = hw.constant 4 : i8 | ||
%c2_i8 = hw.constant 2 : i8 | ||
%c1_i8 = hw.constant 1 : i8 | ||
%b2 = comb.extract %b from 2 : (i8) -> i1 | ||
%b1 = comb.extract %b from 1 : (i8) -> i1 | ||
%b0 = comb.extract %b from 0 : (i8) -> i1 | ||
%0 = comb.shl %a, %c4_i8 : i8 | ||
%1 = comb.mux %b2, %0, %a : i8 | ||
%2 = comb.shl %1, %c2_i8 : i8 | ||
%3 = comb.mux %b1, %2, %1 : i8 | ||
%4 = comb.shl %3, %c1_i8 : i8 | ||
%5 = comb.mux %b0, %4, %3 : i8 | ||
|
||
// Contract to check that the multiplexers and constant shifts above indeed | ||
// produce the correct shift by 0 to 7 places, assuming the shift amount is | ||
// less than 8 (we can't shift a number out). | ||
%z = verif.contract %5 : i8 { | ||
// Shift amount must be less than 8. | ||
%c8_i8 = hw.constant 8 : i8 | ||
%blt8 = comb.icmp ult %b, %c8_i8 : i8 | ||
verif.require %blt8 : i1 | ||
|
||
// In that case the mux tree computes the correct left-shift. | ||
%ashl = comb.shl %a, %b : i8 | ||
%eq = comb.icmp eq %z, %ashl : i8 | ||
verif.ensure %eq : i1 | ||
} | ||
|
||
hw.output %z : i8 | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might be nice to have a slightly more detailed explanation (and maybe an example) here (it's essentially already in the rationale, but it might be nice to repeat the most important points here.