Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[RFC][CIR] Lower cir.bool to i1 #1158

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clang/lib/CIR/CodeGen/CIRGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -820,7 +820,7 @@ void CIRGenModule::replaceGlobal(cir::GlobalOp Old, cir::GlobalOp New) {
mlir::Type ptrTy = builder.getPointerTo(OldTy);
mlir::Value cast =
builder.createBitcast(GGO->getLoc(), UseOpResultValue, ptrTy);
UseOpResultValue.replaceAllUsesExcept(cast, {cast.getDefiningOp()});
UseOpResultValue.replaceAllUsesExcept(cast, cast.getDefiningOp());
}
}
}
Expand Down
267 changes: 166 additions & 101 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp

Large diffs are not rendered by default.

35 changes: 26 additions & 9 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,18 @@
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/Interfaces/DataLayoutInterfaces.h"
#include "mlir/Transforms/DialectConversion.h"

namespace cir {
namespace direct {

/// Convert a CIR attribute to an LLVM attribute. May use the datalayout for
/// lowering attributes to-be-stored in memory.
mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter);
const mlir::TypeConverter *converter,
mlir::DataLayout const &dataLayout);

mlir::LLVM::Linkage convertLinkage(cir::GlobalLinkageKind linkage);

Expand Down Expand Up @@ -302,12 +307,15 @@ class CIRToLLVMAllocaOpLowering

class CIRToLLVMLoadOpLowering : public mlir::OpConversionPattern<cir::LoadOp> {
cir::LowerModule *lowerMod;
mlir::DataLayout const &dataLayout;

public:
CIRToLLVMLoadOpLowering(const mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context,
cir::LowerModule *lowerModule)
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {}
cir::LowerModule *lowerModule,
mlir::DataLayout const &dataLayout)
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule),
dataLayout(dataLayout) {}

mlir::LogicalResult
matchAndRewrite(cir::LoadOp op, OpAdaptor,
Expand All @@ -317,12 +325,15 @@ class CIRToLLVMLoadOpLowering : public mlir::OpConversionPattern<cir::LoadOp> {
class CIRToLLVMStoreOpLowering
: public mlir::OpConversionPattern<cir::StoreOp> {
cir::LowerModule *lowerMod;
mlir::DataLayout const &dataLayout;

public:
CIRToLLVMStoreOpLowering(const mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context,
cir::LowerModule *lowerModule)
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {}
cir::LowerModule *lowerModule,
mlir::DataLayout const &dataLayout)
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule),
dataLayout(dataLayout) {}

mlir::LogicalResult
matchAndRewrite(cir::StoreOp op, OpAdaptor,
Expand All @@ -332,12 +343,15 @@ class CIRToLLVMStoreOpLowering
class CIRToLLVMConstantOpLowering
: public mlir::OpConversionPattern<cir::ConstantOp> {
cir::LowerModule *lowerMod;
mlir::DataLayout const &dataLayout;

public:
CIRToLLVMConstantOpLowering(const mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context,
cir::LowerModule *lowerModule)
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {
cir::LowerModule *lowerModule,
mlir::DataLayout const &dataLayout)
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule),
dataLayout(dataLayout) {
setHasBoundedRewriteRecursion();
}

Expand Down Expand Up @@ -538,12 +552,15 @@ class CIRToLLVMSwitchFlatOpLowering
class CIRToLLVMGlobalOpLowering
: public mlir::OpConversionPattern<cir::GlobalOp> {
cir::LowerModule *lowerMod;
mlir::DataLayout const &dataLayout;

public:
CIRToLLVMGlobalOpLowering(const mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context,
cir::LowerModule *lowerModule)
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule) {
cir::LowerModule *lowerModule,
mlir::DataLayout const &dataLayout)
: OpConversionPattern(typeConverter, context), lowerMod(lowerModule),
dataLayout(dataLayout) {
setHasBoundedRewriteRecursion();
}

Expand Down
5 changes: 1 addition & 4 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRLoopToSCF.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,11 +337,8 @@ class CIRConditionOpLowering
auto *parentOp = op->getParentOp();
return llvm::TypeSwitch<mlir::Operation *, mlir::LogicalResult>(parentOp)
.Case<mlir::scf::WhileOp>([&](auto) {
auto condition = adaptor.getCondition();
auto i1Condition = rewriter.create<mlir::arith::TruncIOp>(
op.getLoc(), rewriter.getI1Type(), condition);
rewriter.replaceOpWithNewOp<mlir::scf::ConditionOp>(
op, i1Condition, parentOp->getOperands());
op, adaptor.getCondition(), parentOp->getOperands());
return mlir::success();
})
.Default([](auto) { return mlir::failure(); });
Expand Down
123 changes: 80 additions & 43 deletions clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
#include "mlir/IR/Operation.h"
#include "mlir/IR/Region.h"
#include "mlir/IR/TypeRange.h"
#include "mlir/IR/Value.h"
#include "mlir/IR/ValueRange.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
Expand Down Expand Up @@ -105,15 +106,64 @@ class CIRCallOpLowering : public mlir::OpConversionPattern<cir::CallOp> {
}
};

/// Given a type convertor and a data layout, convert the given type to a type
/// that is suitable for memory operations. For example, this can be used to
/// lower cir.bool accesses to i8.
static mlir::Type convertTypeForMemory(const mlir::TypeConverter &converter,
mlir::Type type) {
// TODO(cir): Handle other types similarly to clang's codegen
// convertTypeForMemory
if (isa<cir::BoolType>(type)) {
// TODO: Use datalayout to get the size of bool
return mlir::IntegerType::get(type.getContext(), 8);
}

return converter.convertType(type);
}

/// Emits the value from memory as expected by its users. Should be called when
/// the memory represetnation of a CIR type is not equal to its scalar
/// representation.
static mlir::Value emitFromMemory(mlir::ConversionPatternRewriter &rewriter,
cir::LoadOp op, mlir::Value value) {

// TODO(cir): Handle other types similarly to clang's codegen EmitFromMemory
if (isa<cir::BoolType>(op.getResult().getType())) {
// Create trunc of value from i8 to i1
// TODO: Use datalayout to get the size of bool
assert(value.getType().isInteger(8));
return createIntCast(rewriter, value, rewriter.getI1Type());
}

return value;
}

/// Emits a value to memory with the expected scalar type. Should be called when
/// the memory represetnation of a CIR type is not equal to its scalar
/// representation.
static mlir::Value emitToMemory(mlir::ConversionPatternRewriter &rewriter,
cir::StoreOp op, mlir::Value value) {

// TODO(cir): Handle other types similarly to clang's codegen EmitToMemory
if (isa<cir::BoolType>(op.getValue().getType())) {
// Create zext of value from i1 to i8
// TODO: Use datalayout to get the size of bool
return createIntCast(rewriter, value, rewriter.getI8Type());
}

return value;
}

class CIRAllocaOpLowering : public mlir::OpConversionPattern<cir::AllocaOp> {
public:
using OpConversionPattern<cir::AllocaOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(cir::AllocaOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto type = adaptor.getAllocaType();
auto mlirType = getTypeConverter()->convertType(type);

mlir::Type mlirType =
convertTypeForMemory(*getTypeConverter(), adaptor.getAllocaType());

// FIXME: Some types can not be converted yet (e.g. struct)
if (!mlirType)
Expand Down Expand Up @@ -174,12 +224,20 @@ class CIRLoadOpLowering : public mlir::OpConversionPattern<cir::LoadOp> {
mlir::Value base;
SmallVector<mlir::Value> indices;
SmallVector<mlir::Operation *> eraseList;
mlir::memref::LoadOp newLoad;
if (findBaseAndIndices(adaptor.getAddr(), base, indices, eraseList,
rewriter)) {
rewriter.replaceOpWithNewOp<mlir::memref::LoadOp>(op, base, indices);
newLoad =
rewriter.create<mlir::memref::LoadOp>(op.getLoc(), base, indices);
// rewriter.replaceOpWithNewOp<mlir::memref::LoadOp>(op, base, indices);
eraseIfSafe(op.getAddr(), base, eraseList, rewriter);
} else
rewriter.replaceOpWithNewOp<mlir::memref::LoadOp>(op, adaptor.getAddr());
newLoad =
rewriter.create<mlir::memref::LoadOp>(op.getLoc(), adaptor.getAddr());

// Convert adapted result to its original type if needed.
mlir::Value result = emitFromMemory(rewriter, op, newLoad.getResult());
rewriter.replaceOp(op, result);
return mlir::LogicalResult::success();
}
};
Expand All @@ -194,13 +252,16 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
mlir::Value base;
SmallVector<mlir::Value> indices;
SmallVector<mlir::Operation *> eraseList;

// Convert adapted value to its memory type if needed.
mlir::Value value = emitToMemory(rewriter, op, adaptor.getValue());
if (findBaseAndIndices(adaptor.getAddr(), base, indices, eraseList,
rewriter)) {
rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(op, adaptor.getValue(),
base, indices);
rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(op, value, base,
indices);
eraseIfSafe(op.getAddr(), base, eraseList, rewriter);
} else
rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(op, adaptor.getValue(),
rewriter.replaceOpWithNewOp<mlir::memref::StoreOp>(op, value,
adaptor.getAddr());
return mlir::LogicalResult::success();
}
Expand Down Expand Up @@ -744,29 +805,20 @@ class CIRCmpOpLowering : public mlir::OpConversionPattern<cir::CmpOp> {
mlir::ConversionPatternRewriter &rewriter) const override {
auto type = op.getLhs().getType();

mlir::Value mlirResult;

if (auto ty = mlir::dyn_cast<cir::IntType>(type)) {
auto kind = convertCmpKindToCmpIPredicate(op.getKind(), ty.isSigned());
mlirResult = rewriter.create<mlir::arith::CmpIOp>(
op.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOpWithNewOp<mlir::arith::CmpIOp>(
op, kind, adaptor.getLhs(), adaptor.getRhs());
} else if (auto ty = mlir::dyn_cast<cir::CIRFPTypeInterface>(type)) {
auto kind = convertCmpKindToCmpFPredicate(op.getKind());
mlirResult = rewriter.create<mlir::arith::CmpFOp>(
op.getLoc(), kind, adaptor.getLhs(), adaptor.getRhs());
rewriter.replaceOpWithNewOp<mlir::arith::CmpFOp>(
op, kind, adaptor.getLhs(), adaptor.getRhs());
} else if (auto ty = mlir::dyn_cast<cir::PointerType>(type)) {
llvm_unreachable("pointer comparison not supported yet");
} else {
return op.emitError() << "unsupported type for CmpOp: " << type;
}

// MLIR comparison ops return i1, but cir::CmpOp returns the same type as
// the LHS value. Since this return value can be used later, we need to
// restore the type with the extension below.
auto mlirResultTy = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<mlir::arith::ExtUIOp>(op, mlirResultTy,
mlirResult);

return mlir::LogicalResult::success();
}
};
Expand Down Expand Up @@ -826,12 +878,8 @@ struct CIRBrCondOpLowering : public mlir::OpConversionPattern<cir::BrCondOp> {
mlir::LogicalResult
matchAndRewrite(cir::BrCondOp brOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {

auto condition = adaptor.getCond();
auto i1Condition = rewriter.create<mlir::arith::TruncIOp>(
brOp.getLoc(), rewriter.getI1Type(), condition);
rewriter.replaceOpWithNewOp<mlir::cf::CondBranchOp>(
brOp, i1Condition.getResult(), brOp.getDestTrue(),
brOp, adaptor.getCond(), brOp.getDestTrue(),
adaptor.getDestOperandsTrue(), brOp.getDestFalse(),
adaptor.getDestOperandsFalse());

Expand All @@ -847,16 +895,13 @@ class CIRTernaryOpLowering : public mlir::OpConversionPattern<cir::TernaryOp> {
matchAndRewrite(cir::TernaryOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
rewriter.setInsertionPoint(op);
auto condition = adaptor.getCond();
auto i1Condition = rewriter.create<mlir::arith::TruncIOp>(
op.getLoc(), rewriter.getI1Type(), condition);
SmallVector<mlir::Type> resultTypes;
if (mlir::failed(getTypeConverter()->convertTypes(op->getResultTypes(),
resultTypes)))
return mlir::failure();

auto ifOp = rewriter.create<mlir::scf::IfOp>(op.getLoc(), resultTypes,
i1Condition.getResult(), true);
adaptor.getCond(), true);
auto *thenBlock = &ifOp.getThenRegion().front();
auto *elseBlock = &ifOp.getElseRegion().front();
rewriter.inlineBlockBefore(&op.getTrueRegion().front(), thenBlock,
Expand Down Expand Up @@ -893,11 +938,8 @@ class CIRIfOpLowering : public mlir::OpConversionPattern<cir::IfOp> {
mlir::LogicalResult
matchAndRewrite(cir::IfOp ifop, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const override {
auto condition = adaptor.getCondition();
auto i1Condition = rewriter.create<mlir::arith::TruncIOp>(
ifop->getLoc(), rewriter.getI1Type(), condition);
auto newIfOp = rewriter.create<mlir::scf::IfOp>(
ifop->getLoc(), ifop->getResultTypes(), i1Condition);
ifop->getLoc(), ifop->getResultTypes(), adaptor.getCondition());
auto *thenBlock = rewriter.createBlock(&newIfOp.getThenRegion());
rewriter.inlineBlockBefore(&ifop.getThenRegion().front(), thenBlock,
thenBlock->end());
Expand All @@ -924,7 +966,7 @@ class CIRGlobalOpLowering : public mlir::OpConversionPattern<cir::GlobalOp> {
mlir::OpBuilder b(moduleOp.getContext());

const auto CIRSymType = op.getSymType();
auto convertedType = getTypeConverter()->convertType(CIRSymType);
auto convertedType = convertTypeForMemory(*getTypeConverter(), CIRSymType);
if (!convertedType)
return mlir::failure();
auto memrefType = dyn_cast<mlir::MemRefType>(convertedType);
Expand Down Expand Up @@ -1170,19 +1212,14 @@ class CIRCastOpLowering : public mlir::OpConversionPattern<cir::CastOp> {
return mlir::success();
}
case CIR::float_to_bool: {
auto dstTy = mlir::cast<cir::BoolType>(op.getType());
auto newDstType = convertTy(dstTy);
auto kind = mlir::arith::CmpFPredicate::UNE;

// Check if float is not equal to zero.
auto zeroFloat = rewriter.create<mlir::arith::ConstantOp>(
op.getLoc(), src.getType(), mlir::FloatAttr::get(src.getType(), 0.0));

// Extend comparison result to either bool (C++) or int (C).
mlir::Value cmpResult = rewriter.create<mlir::arith::CmpFOp>(
op.getLoc(), kind, src, zeroFloat);
rewriter.replaceOpWithNewOp<mlir::arith::ExtUIOp>(op, newDstType,
cmpResult);
rewriter.replaceOpWithNewOp<mlir::arith::CmpFOp>(op, kind, src,
zeroFloat);
return mlir::success();
}
case CIR::bool_to_int: {
Expand Down Expand Up @@ -1330,7 +1367,7 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
static mlir::TypeConverter prepareTypeConverter() {
mlir::TypeConverter converter;
converter.addConversion([&](cir::PointerType type) -> mlir::Type {
auto ty = converter.convertType(type.getPointee());
auto ty = convertTypeForMemory(converter, type.getPointee());
// FIXME: The pointee type might not be converted (e.g. struct)
if (!ty)
return nullptr;
Expand All @@ -1350,7 +1387,7 @@ static mlir::TypeConverter prepareTypeConverter() {
mlir::IntegerType::SignednessSemantics::Signless);
});
converter.addConversion([&](cir::BoolType type) -> mlir::Type {
return mlir::IntegerType::get(type.getContext(), 8);
return mlir::IntegerType::get(type.getContext(), 1);
});
converter.addConversion([&](cir::SingleType type) -> mlir::Type {
return mlir::FloatType::getF32(type.getContext());
Expand Down
8 changes: 3 additions & 5 deletions clang/test/CIR/CodeGen/atomic-xchg-field.c
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,14 @@ void structAtomicExchange(unsigned referenceCount, wPtr item) {
// LLVM: %[[RES:.*]] = cmpxchg weak ptr %9, i32 %[[EXP]], i32 %[[DES]] seq_cst seq_cst
// LLVM: %[[OLD:.*]] = extractvalue { i32, i1 } %[[RES]], 0
// LLVM: %[[CMP:.*]] = extractvalue { i32, i1 } %[[RES]], 1
// LLVM: %[[Z:.*]] = zext i1 %[[CMP]] to i8
// LLVM: %[[X:.*]] = xor i8 %[[Z]], 1
// LLVM: %[[FAIL:.*]] = trunc i8 %[[X]] to i1

// LLVM: br i1 %[[FAIL:.*]], label %[[STORE_OLD:.*]], label %[[CONTINUE:.*]]
// LLVM: %[[FAIL:.*]] = xor i1 %[[CMP]], true
// LLVM: br i1 %[[FAIL]], label %[[STORE_OLD:.*]], label %[[CONTINUE:.*]]
// LLVM: [[STORE_OLD]]:
// LLVM: store i32 %[[OLD]], ptr
// LLVM: br label %[[CONTINUE]]

// LLVM: [[CONTINUE]]:
// LLVM: %[[Z:.*]] = zext i1 %[[CMP]] to i8
// LLVM: store i8 %[[Z]], ptr {{.*}}, align 1
// LLVM: ret void

Expand Down
10 changes: 4 additions & 6 deletions clang/test/CIR/CodeGen/bf16-ops.c
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,12 @@ void foo(void) {
// NATIVE-NEXT: %{{.+}} = cir.cast(integral, %[[#C]] : !s32i), !u32i

// NONATIVE-LLVM: %[[#A:]] = fcmp une bfloat %{{.+}}, 0xR0000
// NONATIVE-LLVM-NEXT: %[[#B:]] = zext i1 %[[#A]] to i8
// NONATIVE-LLVM-NEXT: %[[#C:]] = xor i8 %[[#B]], 1
// NONATIVE-LLVM-NEXT: %{{.+}} = zext i8 %[[#C]] to i32
// NONATIVE-LLVM-NEXT: %[[#C:]] = xor i1 %[[#A]], true
// NONATIVE-LLVM-NEXT: %{{.+}} = zext i1 %[[#C]] to i32

// NATIVE-LLVM: %[[#A:]] = fcmp une bfloat %{{.+}}, 0xR0000
// NATIVE-LLVM-NEXT: %[[#B:]] = zext i1 %[[#A]] to i8
// NATIVE-LLVM-NEXT: %[[#C:]] = xor i8 %[[#B]], 1
// NATIVE-LLVM-NEXT: %{{.+}} = zext i8 %[[#C]] to i32
// NATIVE-LLVM-NEXT: %[[#C:]] = xor i1 %[[#A]], true
// NATIVE-LLVM-NEXT: %{{.+}} = zext i1 %[[#C]] to i32

h1 = -h1;
// NONATIVE: %[[#A:]] = cir.cast(floating, %{{.+}} : !cir.bf16), !cir.float
Expand Down
Loading
Loading