Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Verif] LowerContractsPass #7870

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
8 changes: 8 additions & 0 deletions include/circt/Dialect/Verif/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -42,4 +42,12 @@ def LowerFormalToHWPass : Pass<"lower-formal-to-hw", "mlir::ModuleOp"> {
}];
}

def LowerContractsPass : Pass<"lower-contracts", "mlir::ModuleOp"> {
let summary = "Lower contracts into formal tests";
let description = [{
Converts `hw.module` ops containing a `verif.contract` into a
`verif.formal` op.
Comment on lines +48 to +49
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be nice to have a slightly more detailed explanation (and maybe an example) here (it's essentially already in the rationale, but it might be nice to repeat the most important points here.

}];
}

#endif // CIRCT_DIALECT_VERIF_PASSES_TD
1 change: 1 addition & 0 deletions lib/Dialect/Verif/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ add_circt_dialect_library(CIRCTVerifTransforms
VerifyClockedAssertLike.cpp
PrepareForFormal.cpp
LowerFormalToHW.cpp
LowerContracts.cpp

DEPENDS
CIRCTVerifTransformsIncGen
Expand Down
109 changes: 109 additions & 0 deletions lib/Dialect/Verif/Transforms/LowerContracts.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
//===- LowerContracts.cpp - Formal Preparations --*- C++ -*----------------===//
leonardt marked this conversation as resolved.
Show resolved Hide resolved
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//===----------------------------------------------------------------------===//
leonardt marked this conversation as resolved.
Show resolved Hide resolved
//
// Lower contracts into verif.formal tests.
//
//===----------------------------------------------------------------------===//
#include "circt/Dialect/Verif/VerifOps.h"
leonardt marked this conversation as resolved.
Show resolved Hide resolved
#include "circt/Dialect/Verif/VerifPasses.h"

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace circt;

namespace circt {
namespace verif {
#define GEN_PASS_DEF_LOWERCONTRACTSPASS
#include "circt/Dialect/Verif/Passes.h.inc"
} // namespace verif
} // namespace circt

using namespace mlir;
using namespace verif;
using namespace hw;

namespace {
struct LowerContractsPass
: verif::impl::LowerContractsPassBase<LowerContractsPass> {
void runOnOperation() override;
};

template <typename FROM, typename TO>
void replaceContractOp(PatternRewriter &rewriter, Block *body) {
for (auto op : llvm::make_early_inc_range(body->getOps<FROM>())) {
leonardt marked this conversation as resolved.
Show resolved Hide resolved
auto label = op.getLabel();
StringAttr labelAttr;
if (label) {
labelAttr = rewriter.getStringAttr(label.value());
}
leonardt marked this conversation as resolved.
Show resolved Hide resolved
rewriter.replaceOpWithNewOp<TO>(op, op.getProperty(), op.getEnable(),
labelAttr);
}
}

struct HWOpRewritePattern : public OpRewritePattern<HWModuleOp> {
using OpRewritePattern<HWModuleOp>::OpRewritePattern;

LogicalResult matchAndRewrite(HWModuleOp op,
PatternRewriter &rewriter) const override {
auto formalOp = rewriter.create<verif::FormalOp>(
op.getLoc(), op.getNameAttr(), rewriter.getDictionaryAttr({}));
leonardt marked this conversation as resolved.
Show resolved Hide resolved

// Clone module body into fomal op body
leonardt marked this conversation as resolved.
Show resolved Hide resolved
rewriter.cloneRegionBefore(op.getRegion(), formalOp.getBody(),
formalOp.getBody().end());

auto *bodyBlock = &formalOp.getBody().front();

// Erase hw.output
rewriter.eraseOp(bodyBlock->getTerminator());

// Convert block args to symbolic values
rewriter.setInsertionPointToStart(bodyBlock);
for (auto arg : llvm::make_early_inc_range(bodyBlock->getArguments())) {
auto sym =
rewriter.create<verif::SymbolicValueOp>(arg.getLoc(), arg.getType());
rewriter.replaceAllUsesWith(arg, sym);
}
bodyBlock->eraseArguments(0, bodyBlock->getNumArguments());

// Inline contract ops
for (auto contractOp :
llvm::make_early_inc_range(bodyBlock->getOps<verif::ContractOp>())) {

// Convert ensure to assert, require to assume
rewriter.setInsertionPointToEnd(&contractOp.getBody().front());
Block *contractBlock = &contractOp.getBody().front();
leonardt marked this conversation as resolved.
Show resolved Hide resolved
replaceContractOp<EnsureOp, AssertOp>(rewriter, contractBlock);
replaceContractOp<RequireOp, AssumeOp>(rewriter, contractBlock);
leonardt marked this conversation as resolved.
Show resolved Hide resolved

// Inline body
rewriter.inlineBlockBefore(&contractOp.getBody().front(),
leonardt marked this conversation as resolved.
Show resolved Hide resolved
&formalOp.getBody().front(),
formalOp.getBody().front().end());
leonardt marked this conversation as resolved.
Show resolved Hide resolved

// Replace results with inputs and erase
for (auto [input, result] :
llvm::zip(contractOp.getResults(), contractOp.getInputs())) {
rewriter.replaceAllUsesWith(input, result);
}
rewriter.eraseOp(contractOp);
}

rewriter.eraseOp(op);
leonardt marked this conversation as resolved.
Show resolved Hide resolved
return success();
}
};
} // namespace

void LowerContractsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
patterns.add<HWOpRewritePattern>(patterns.getContext());

if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since you are only matching on modules, you can also just have a for loop iterating over the modules, which would be a bit more efficient.

}
2 changes: 1 addition & 1 deletion llvm
leonardt marked this conversation as resolved.
Show resolved Hide resolved
Submodule llvm updated 8433 files
118 changes: 118 additions & 0 deletions test/Dialect/Verif/lower-contracts.mlir
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file contains tests that are great as integration tests, but for regression tests we want to check the corner cases and keep the tests small (they don't have to do something that makes sense from a design implementation/logics perspective). For example, there's no test of a module with more than one contract inside, with zero contracts inside. Deleting the module also only works because you have no other module instantiating it.

Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
// RUN: circt-opt --lower-contracts %s | FileCheck %s


// CHECK: verif.formal @Mul9 {
// CHECK: %c9_i42 = hw.constant 9 : i42
// CHECK: %c3_i42 = hw.constant 3 : i42
// CHECK: %0 = verif.symbolic_value : i42
// CHECK: %1 = comb.shl %0, %c3_i42 : i42
// CHECK: %2 = comb.add %0, %1 : i42
// CHECK: %3 = comb.mul %0, %c9_i42 : i42
// CHECK: %4 = comb.icmp eq %2, %3 : i42
// CHECK: verif.assert %4 : i1
// CHECK: }

hw.module @Mul9(in %a: i42, out z: i42) {
%c3_i42 = hw.constant 3 : i42
%c9_i42 = hw.constant 9 : i42
%0 = comb.shl %a, %c3_i42 : i42 // 8*a
%1 = comb.add %a, %0 : i42 // a + 8*a
%2 = verif.contract %1 : i42 {
%3 = comb.mul %a, %c9_i42 : i42 // 9*a
%4 = comb.icmp eq %2, %3 : i42 // 9*a == a + 8*a
verif.ensure %4 : i1
}
hw.output %2 : i42
}

// CHECK: verif.formal @CarrySaveCompress3to2 {
// CHECK: %c1_i42 = hw.constant 1 : i42
// CHECK: %0 = verif.symbolic_value : i42
// CHECK: %1 = verif.symbolic_value : i42
// CHECK: %2 = verif.symbolic_value : i42
// CHECK: %3 = comb.xor %0, %1, %2 : i42
// CHECK: %4 = comb.and %0, %1 : i42
// CHECK: %5 = comb.or %0, %1 : i42
// CHECK: %6 = comb.and %5, %2 : i42
// CHECK: %7 = comb.or %4, %6 : i42
// CHECK: %8 = comb.shl %7, %c1_i42 : i42
// CHECK: %9 = comb.add %0, %1, %2 : i42
// CHECK: %10 = comb.add %3, %8 : i42
// CHECK: %11 = comb.icmp eq %9, %10 : i42
// CHECK: verif.assert %11 : i1
// CHECK: }

hw.module @CarrySaveCompress3to2(
in %a0: i42, in %a1: i42, in %a2: i42,
out z0: i42, out z1: i42
) {
%c1_i42 = hw.constant 1 : i42
%0 = comb.xor %a0, %a1, %a2 : i42 // sum bits of FA (a0^a1^a2)
%1 = comb.and %a0, %a1 : i42
%2 = comb.or %a0, %a1 : i42
%3 = comb.and %2, %a2 : i42
%4 = comb.or %1, %3 : i42 // carry bits of FA (a0&a1 | a2&(a0|a1))
%5 = comb.shl %4, %c1_i42 : i42 // %5 = carry << 1
%z0, %z1 = verif.contract %0, %5 : i42, i42 {
%inputSum = comb.add %a0, %a1, %a2 : i42
%outputSum = comb.add %z0, %z1 : i42
%6 = comb.icmp eq %inputSum, %outputSum : i42
verif.ensure %6 : i1
}
hw.output %z0, %z1 : i42, i42
}

// CHECK: verif.formal @ShiftLeft {
// CHECK: %c1_i8 = hw.constant 1 : i8
// CHECK: %c2_i8 = hw.constant 2 : i8
// CHECK: %c4_i8 = hw.constant 4 : i8
// CHECK: %c8_i8 = hw.constant 8 : i8
// CHECK: %0 = verif.symbolic_value : i8
// CHECK: %1 = verif.symbolic_value : i8
// CHECK: %2 = comb.extract %1 from 2 : (i8) -> i1
// CHECK: %3 = comb.extract %1 from 1 : (i8) -> i1
// CHECK: %4 = comb.extract %1 from 0 : (i8) -> i1
// CHECK: %5 = comb.shl %0, %c4_i8 : i8
// CHECK: %6 = comb.mux %2, %5, %0 : i8
// CHECK: %7 = comb.shl %6, %c2_i8 : i8
// CHECK: %8 = comb.mux %3, %7, %6 : i8
// CHECK: %9 = comb.shl %8, %c1_i8 : i8
// CHECK: %10 = comb.mux %4, %9, %8 : i8
// CHECK: %11 = comb.icmp ult %1, %c8_i8 : i8
// CHECK: %12 = comb.shl %0, %1 : i8
// CHECK: %13 = comb.icmp eq %10, %12 : i8
// CHECK: verif.assert %13 : i1
// CHECK: verif.assume %11 : i1
// CHECK: }

hw.module @ShiftLeft(in %a: i8, in %b: i8, out z: i8) {
%c4_i8 = hw.constant 4 : i8
%c2_i8 = hw.constant 2 : i8
%c1_i8 = hw.constant 1 : i8
%b2 = comb.extract %b from 2 : (i8) -> i1
%b1 = comb.extract %b from 1 : (i8) -> i1
%b0 = comb.extract %b from 0 : (i8) -> i1
%0 = comb.shl %a, %c4_i8 : i8
%1 = comb.mux %b2, %0, %a : i8
%2 = comb.shl %1, %c2_i8 : i8
%3 = comb.mux %b1, %2, %1 : i8
%4 = comb.shl %3, %c1_i8 : i8
%5 = comb.mux %b0, %4, %3 : i8

// Contract to check that the multiplexers and constant shifts above indeed
// produce the correct shift by 0 to 7 places, assuming the shift amount is
// less than 8 (we can't shift a number out).
%z = verif.contract %5 : i8 {
// Shift amount must be less than 8.
%c8_i8 = hw.constant 8 : i8
%blt8 = comb.icmp ult %b, %c8_i8 : i8
verif.require %blt8 : i1

// In that case the mux tree computes the correct left-shift.
%ashl = comb.shl %a, %b : i8
%eq = comb.icmp eq %z, %ashl : i8
verif.ensure %eq : i1
}

hw.output %z : i8
}
Loading