Skip to content

Commit

Permalink
[CIR] Add cir.global_addr attribute
Browse files Browse the repository at this point in the history
This patch adds a new attribute `#cir.global_addr`` to the CIR dialect. This
attribute is quite similar to `#cir.global_view`, except that the new attribute
represents the address of the global variable as an integer instead of a
pointer. And the new attribute does not have the "indecies" stuff.

CIRGen would not generate this attribute. I'm adding this new attribute because
it could be useful during ABI lowering. For example, when doing ABI lowering for
a member function pointer constant, ItaniumABI needs to lower the constant into
a `#cir.const_struct` with two fields, first of which is an integer that
represents the address of a function. This is where we need this attribute.
  • Loading branch information
Lancern committed Dec 19, 2024
1 parent 2ec1a24 commit 61d5469
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 48 deletions.
34 changes: 34 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIRAttrs.td
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,40 @@ def GlobalViewAttr : CIR_Attr<"GlobalView", "global_view", [TypedAttrInterface]>
}];
}

//===----------------------------------------------------------------------===//
// GlobalAddrAttr
//===----------------------------------------------------------------------===//

def GlobalAddrAttr
: CIR_Attr<"GlobalAddr", "global_addr", [TypedAttrInterface]> {
let summary = "Get access to a constant integral address of a global";
let description = [{
Get constant address of a global `symbol` as an integer value. The type of
the `#cir.global_addr` attribute must be an integer type.

Example:

```
cir.global external @str = @"hello": !cir.ptr<i8>
cir.global external @str_addr = #cir.global_addr<@str> : !s64i
```
}];

let parameters = (ins AttributeSelfTypeParameter<"", "cir::IntType">:$type,
"mlir::FlatSymbolRefAttr":$symbol);

let builders = [
AttrBuilderWithInferredContext<(ins "mlir::Type":$type,
"mlir::FlatSymbolRefAttr":$symbol), [{
return $_get(type.getContext(), mlir::cast<cir::IntType>(type), symbol);
}]>
];

let assemblyFormat = [{
`<` $symbol `>`
}];
}

//===----------------------------------------------------------------------===//
// TypeInfoAttr
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -420,6 +420,7 @@ static LogicalResult checkConstantTypes(mlir::Operation *op, mlir::Type opType,
}

if (mlir::isa<cir::GlobalViewAttr>(attrType) ||
mlir::isa<cir::GlobalAddrAttr>(attrType) ||
mlir::isa<cir::TypeInfoAttr>(attrType) ||
mlir::isa<cir::ConstArrayAttr>(attrType) ||
mlir::isa<cir::ConstVectorAttr>(attrType) ||
Expand Down
135 changes: 87 additions & 48 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -566,6 +566,36 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::ConstVectorAttr constVec,
mlirValues));
}

static void lookupGlobalSymbolInfo(mlir::ModuleOp module,
mlir::FlatSymbolRefAttr symbolRef,
mlir::Type *sourceType,
unsigned *sourceAddrSpace,
llvm::StringRef *symName,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter &converter) {
auto *sourceSymbol = mlir::SymbolTable::lookupSymbolIn(module, symbolRef);
if (auto llvmSymbol = dyn_cast<mlir::LLVM::GlobalOp>(sourceSymbol)) {
*sourceType = llvmSymbol.getType();
*symName = llvmSymbol.getSymName();
*sourceAddrSpace = llvmSymbol.getAddrSpace();
} else if (auto cirSymbol = dyn_cast<cir::GlobalOp>(sourceSymbol)) {
*sourceType = converter.convertType(cirSymbol.getSymType());
*symName = cirSymbol.getSymName();
*sourceAddrSpace =
getGlobalOpTargetAddrSpace(rewriter, &converter, cirSymbol);
} else if (auto llvmFun = dyn_cast<mlir::LLVM::LLVMFuncOp>(sourceSymbol)) {
*sourceType = llvmFun.getFunctionType();
*symName = llvmFun.getSymName();
*sourceAddrSpace = 0;
} else if (auto fun = dyn_cast<cir::FuncOp>(sourceSymbol)) {
*sourceType = converter.convertType(fun.getFunctionType());
*symName = fun.getSymName();
*sourceAddrSpace = 0;
} else {
llvm_unreachable("Unexpected GlobalOp type");
}
}

// GlobalViewAttr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::GlobalViewAttr globalAttr,
Expand All @@ -575,28 +605,8 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::GlobalViewAttr globalAttr,
mlir::Type sourceType;
unsigned sourceAddrSpace = 0;
llvm::StringRef symName;
auto *sourceSymbol =
mlir::SymbolTable::lookupSymbolIn(module, globalAttr.getSymbol());
if (auto llvmSymbol = dyn_cast<mlir::LLVM::GlobalOp>(sourceSymbol)) {
sourceType = llvmSymbol.getType();
symName = llvmSymbol.getSymName();
sourceAddrSpace = llvmSymbol.getAddrSpace();
} else if (auto cirSymbol = dyn_cast<cir::GlobalOp>(sourceSymbol)) {
sourceType = converter->convertType(cirSymbol.getSymType());
symName = cirSymbol.getSymName();
sourceAddrSpace =
getGlobalOpTargetAddrSpace(rewriter, converter, cirSymbol);
} else if (auto llvmFun = dyn_cast<mlir::LLVM::LLVMFuncOp>(sourceSymbol)) {
sourceType = llvmFun.getFunctionType();
symName = llvmFun.getSymName();
sourceAddrSpace = 0;
} else if (auto fun = dyn_cast<cir::FuncOp>(sourceSymbol)) {
sourceType = converter->convertType(fun.getFunctionType());
symName = fun.getSymName();
sourceAddrSpace = 0;
} else {
llvm_unreachable("Unexpected GlobalOp type");
}
lookupGlobalSymbolInfo(module, globalAttr.getSymbol(), &sourceType,
&sourceAddrSpace, &symName, rewriter, *converter);

auto loc = parentOp->getLoc();
mlir::Value addrOp = rewriter.create<mlir::LLVM::AddressOfOp>(
Expand Down Expand Up @@ -637,36 +647,53 @@ lowerCirAttrAsValue(mlir::Operation *parentOp, cir::GlobalViewAttr globalAttr,
addrOp);
}

// GlobalViewAddr visitor.
static mlir::Value
lowerCirAttrAsValue(mlir::Operation *parentOp, cir::GlobalAddrAttr globalAttr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
auto module = parentOp->getParentOfType<mlir::ModuleOp>();
mlir::Type sourceType;
unsigned sourceAddrSpace = 0;
llvm::StringRef symName;
lookupGlobalSymbolInfo(module, globalAttr.getSymbol(), &sourceType,
&sourceAddrSpace, &symName, rewriter, *converter);

auto loc = parentOp->getLoc();
auto addrTy =
mlir::LLVM::LLVMPointerType::get(rewriter.getContext(), sourceAddrSpace);
mlir::Value addrOp =
rewriter.create<mlir::LLVM::AddressOfOp>(loc, addrTy, symName);

auto llvmDstTy = converter->convertType(globalAttr.getType());
return rewriter.create<mlir::LLVM::PtrToIntOp>(parentOp->getLoc(), llvmDstTy,
addrOp);
}

/// Switches on the type of attribute and calls the appropriate conversion.
mlir::Value lowerCirAttrAsValue(mlir::Operation *parentOp, mlir::Attribute attr,
mlir::ConversionPatternRewriter &rewriter,
const mlir::TypeConverter *converter) {
if (const auto intAttr = mlir::dyn_cast<cir::IntAttr>(attr))
return lowerCirAttrAsValue(parentOp, intAttr, rewriter, converter);
if (const auto fltAttr = mlir::dyn_cast<cir::FPAttr>(attr))
return lowerCirAttrAsValue(parentOp, fltAttr, rewriter, converter);
if (const auto ptrAttr = mlir::dyn_cast<cir::ConstPtrAttr>(attr))
return lowerCirAttrAsValue(parentOp, ptrAttr, rewriter, converter);
if (const auto constStruct = mlir::dyn_cast<cir::ConstStructAttr>(attr))
return lowerCirAttrAsValue(parentOp, constStruct, rewriter, converter);
if (const auto constArr = mlir::dyn_cast<cir::ConstArrayAttr>(attr))
return lowerCirAttrAsValue(parentOp, constArr, rewriter, converter);
if (const auto constVec = mlir::dyn_cast<cir::ConstVectorAttr>(attr))
return lowerCirAttrAsValue(parentOp, constVec, rewriter, converter);
if (const auto boolAttr = mlir::dyn_cast<cir::BoolAttr>(attr))
return lowerCirAttrAsValue(parentOp, boolAttr, rewriter, converter);
if (const auto zeroAttr = mlir::dyn_cast<cir::ZeroAttr>(attr))
return lowerCirAttrAsValue(parentOp, zeroAttr, rewriter, converter);
if (const auto undefAttr = mlir::dyn_cast<cir::UndefAttr>(attr))
return lowerCirAttrAsValue(parentOp, undefAttr, rewriter, converter);
if (const auto poisonAttr = mlir::dyn_cast<cir::PoisonAttr>(attr))
return lowerCirAttrAsValue(parentOp, poisonAttr, rewriter, converter);
if (const auto globalAttr = mlir::dyn_cast<cir::GlobalViewAttr>(attr))
return lowerCirAttrAsValue(parentOp, globalAttr, rewriter, converter);
if (const auto vtableAttr = mlir::dyn_cast<cir::VTableAttr>(attr))
return lowerCirAttrAsValue(parentOp, vtableAttr, rewriter, converter);
if (const auto typeinfoAttr = mlir::dyn_cast<cir::TypeInfoAttr>(attr))
return lowerCirAttrAsValue(parentOp, typeinfoAttr, rewriter, converter);
#define LOWER_CIR_ATTR(type) \
if (const auto castedAttr = mlir::dyn_cast<type>(attr)) \
return lowerCirAttrAsValue(parentOp, castedAttr, rewriter, converter);

LOWER_CIR_ATTR(cir::BoolAttr)
LOWER_CIR_ATTR(cir::ConstArrayAttr)
LOWER_CIR_ATTR(cir::ConstPtrAttr)
LOWER_CIR_ATTR(cir::ConstStructAttr)
LOWER_CIR_ATTR(cir::ConstVectorAttr)
LOWER_CIR_ATTR(cir::FPAttr)
LOWER_CIR_ATTR(cir::GlobalAddrAttr)
LOWER_CIR_ATTR(cir::GlobalViewAttr)
LOWER_CIR_ATTR(cir::IntAttr)
LOWER_CIR_ATTR(cir::PoisonAttr)
LOWER_CIR_ATTR(cir::TypeInfoAttr)
LOWER_CIR_ATTR(cir::UndefAttr)
LOWER_CIR_ATTR(cir::VTableAttr)
LOWER_CIR_ATTR(cir::ZeroAttr)

#undef LOWER_CIR_ATTR

llvm_unreachable("unhandled attribute type");
}
Expand Down Expand Up @@ -1663,6 +1690,13 @@ mlir::LogicalResult CIRToLLVMConstantOpLowering::matchAndRewrite(
attr = rewriter.getIntegerAttr(typeConverter->convertType(op.getType()),
value);
} else if (mlir::isa<cir::IntType>(op.getType())) {
// Lower GlobalAddrAttr to llvm.mlir.addressof + llvm.mlir.ptrtoint
if (auto ga = mlir::dyn_cast<cir::GlobalAddrAttr>(op.getValue())) {
auto newOp = lowerCirAttrAsValue(op, ga, rewriter, getTypeConverter());
rewriter.replaceOp(op, newOp);
return mlir::success();
}

attr = rewriter.getIntegerAttr(
typeConverter->convertType(op.getType()),
mlir::cast<cir::IntAttr>(op.getValue()).getValue());
Expand Down Expand Up @@ -2348,6 +2382,11 @@ mlir::LogicalResult CIRToLLVMGlobalOpLowering::matchAndRewrite(
rewriter.create<mlir::LLVM::ReturnOp>(
loc, lowerCirAttrAsValue(op, attr, rewriter, typeConverter));
return mlir::success();
} else if (auto attr = mlir::dyn_cast<cir::GlobalAddrAttr>(init.value())) {
setupRegionInitializedLLVMGlobalOp(op, rewriter);
rewriter.create<mlir::LLVM::ReturnOp>(
loc, lowerCirAttrAsValue(op, attr, rewriter, typeConverter));
return mlir::success();
} else if (const auto vtableAttr =
mlir::dyn_cast<cir::VTableAttr>(init.value())) {
setupRegionInitializedLLVMGlobalOp(op, rewriter);
Expand Down
20 changes: 20 additions & 0 deletions clang/test/CIR/Lowering/globals.cir
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,22 @@ module {
cir.global external @alpha = #cir.const_array<[#cir.int<97> : !s8i, #cir.int<98> : !s8i, #cir.int<99> : !s8i, #cir.int<0> : !s8i]> : !cir.array<!s8i x 4>
cir.global "private" constant internal @".str" = #cir.const_array<"example\00" : !cir.array<!s8i x 8>> : !cir.array<!s8i x 8> {alignment = 1 : i64}
cir.global external @s = #cir.global_view<@".str"> : !cir.ptr<!s8i>
cir.global external @s_addr = #cir.global_addr<@".str"> : !u64i
// MLIR: llvm.mlir.global internal constant @".str"("example\00")
// MLIR-SAME: {addr_space = 0 : i32, alignment = 1 : i64}
// MLIR: llvm.mlir.global external @s() {addr_space = 0 : i32} : !llvm.ptr {
// MLIR: %0 = llvm.mlir.addressof @".str" : !llvm.ptr
// MLIR: %1 = llvm.bitcast %0 : !llvm.ptr to !llvm.ptr
// MLIR: llvm.return %1 : !llvm.ptr
// MLIR: }
// MLIR: llvm.mlir.global external @s_addr() {addr_space = 0 : i32} : i64 {
// MLIR: %0 = llvm.mlir.addressof @".str" : !llvm.ptr
// MLIR: %1 = llvm.ptrtoint %0 : !llvm.ptr to i64
// MLIR: llvm.return %1 : i64
// MLIR: }
// LLVM: @.str = internal constant [8 x i8] c"example\00"
// LLVM: @s = global ptr @.str
// LLVM: @s_addr = global i64 ptrtoint (ptr @.str to i64)
cir.global external @aPtr = #cir.global_view<@a> : !cir.ptr<!s32i>
// MLIR: llvm.mlir.global external @aPtr() {addr_space = 0 : i32} : !llvm.ptr {
// MLIR: %0 = llvm.mlir.addressof @a : !llvm.ptr
Expand Down Expand Up @@ -198,4 +205,17 @@ module {
}
// MLIR: %0 = llvm.mlir.addressof @zero_array

cir.func @const_global_addr() -> !u64i {
%0 = cir.const #cir.global_addr<@".str"> : !u64i
cir.return %0 : !u64i
}
// MLIR-LABEL: @const_global_addr
// MLIR-NEXT: %0 = llvm.mlir.addressof @".str" : !llvm.ptr
// MLIR-NEXT: %1 = llvm.ptrtoint %0 : !llvm.ptr to i64
// MLIR-NEXT: llvm.return %1 : i64
// MLIR-NEXT: }
// LLVM-LABEL: @const_global_addr
// LLVM-NEXT: ret i64 ptrtoint (ptr @.str to i64)
// LLVM-NEXT: }

}

0 comments on commit 61d5469

Please sign in to comment.