Skip to content

Commit

Permalink
[tuner]: add property functions to lowering config python binding (ir…
Browse files Browse the repository at this point in the history
…ee-org#19376)

This PR introduces additional property functions to the LoweringConfig
Python binding. These new functions enable direct extraction of the
following attributes: `workgroup`, `reduction`, `subgroup_m_count`,
`subgroup_n_count`, and `mma_kind` directly from the lowering config
python binding.

This PR is relevant to the task in
nod-ai/shark-ai#453: use IREE bindings for
compilation info (incl., lowering_config and translation_info).

---------

Signed-off-by: Bangtian Liu <liubangtian@gmail.com>
  • Loading branch information
bangtianliu authored Dec 7, 2024
1 parent cb59389 commit 62903cc
Show file tree
Hide file tree
Showing 10 changed files with 203 additions and 5 deletions.
19 changes: 19 additions & 0 deletions compiler/bindings/c/iree/compiler/dialects/iree_gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,25 @@ MLIR_CAPI_EXPORTED MlirAttribute ireeGPULoweringConfigAttrGet(
MLIR_CAPI_EXPORTED MlirAttribute
ireeGPULoweringConfigAttrGetAttributes(MlirAttribute attr);

struct ireeGPUTileSizes {
MlirAttribute workgroupAttr;
MlirAttribute reductionAttr;
};

MLIR_CAPI_EXPORTED ireeGPUTileSizes
ireeGPULoweringConfigAttrGetTileSizes(MlirAttribute attr);

struct ireeGPUSubgroupCountInfo {
MlirAttribute subgroupMCountAttr;
MlirAttribute subgroupNCountAttr;
};

MLIR_CAPI_EXPORTED ireeGPUSubgroupCountInfo
ireeGPULoweringConfigAttrGetSubgroupCount(MlirAttribute attr);

MLIR_CAPI_EXPORTED MlirAttribute
ireeGPULoweringConfigAttrGetMmaKind(MlirAttribute attr);

#ifdef __cplusplus
}
#endif
Expand Down
61 changes: 60 additions & 1 deletion compiler/bindings/python/IREECompilerDialectsModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,66 @@ PYBIND11_MODULE(_ireeCompilerDialects, m) {
"cls"_a, "value"_a, "ctx"_a = py::none(),
"Gets an #iree_gpu.lowering_config from parameters.")
.def_property_readonly("attributes",
ireeGPULoweringConfigAttrGetAttributes);
ireeGPULoweringConfigAttrGetAttributes)
.def_property_readonly(
"workgroup_tile_sizes",
[](MlirAttribute self) -> std::vector<int64_t> {
auto tilesizes = ireeGPULoweringConfigAttrGetTileSizes(self);
MlirAttribute workgroupAttr = tilesizes.workgroupAttr;
if (mlirAttributeIsNull(workgroupAttr)) {
return {};
}

size_t len = mlirArrayAttrGetNumElements(workgroupAttr);
std::vector<int64_t> workgroup(len);
for (size_t i = 0; i < len; ++i) {
MlirAttribute attr = mlirArrayAttrGetElement(workgroupAttr, i);
workgroup[i] = mlirIntegerAttrGetValueInt(attr);
}
return workgroup;
})
.def_property_readonly(
"reduction_tile_sizes",
[](MlirAttribute self) -> std::vector<int64_t> {
auto tilesizes = ireeGPULoweringConfigAttrGetTileSizes(self);
MlirAttribute reductionAttr = tilesizes.reductionAttr;
if (mlirAttributeIsNull(reductionAttr)) {
return {};
}

size_t len = mlirArrayAttrGetNumElements(reductionAttr);
std::vector<int64_t> reduction(len);
for (size_t i = 0; i < len; ++i) {
MlirAttribute attr = mlirArrayAttrGetElement(reductionAttr, i);
reduction[i] = mlirIntegerAttrGetValueInt(attr);
}
return reduction;
})
.def_property_readonly(
"subgroup_count_mn",
[](MlirAttribute self) -> py::tuple {
ireeGPUSubgroupCountInfo info =
ireeGPULoweringConfigAttrGetSubgroupCount(self);
MlirAttribute mCountAttr = info.subgroupMCountAttr;
MlirAttribute nCountAttr = info.subgroupNCountAttr;
std::optional<int64_t> mCount;
if (!mlirAttributeIsNull(mCountAttr)) {
mCount = mlirIntegerAttrGetValueInt(mCountAttr);
}

std::optional<int64_t> nCount;
if (!mlirAttributeIsNull(nCountAttr)) {
nCount = mlirIntegerAttrGetValueInt(nCountAttr);
}
return py::make_tuple(mCount, nCount);
})
.def_property_readonly(
"mma_kind", [](MlirAttribute self) -> std::optional<MlirAttribute> {
auto attr = ireeGPULoweringConfigAttrGetMmaKind(self);
if (!mlirAttributeIsNull(attr))
return attr;
return std::nullopt;
});

//===-------------------------------------------------------------------===//
// Binding to utility function getExecutableVariantOps
Expand Down
40 changes: 38 additions & 2 deletions compiler/bindings/python/test/ir/dialects_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@
from iree.compiler.dialects import flow, hal, stream, vm, util, iree_codegen, iree_gpu


def get_index_attr(val: int) -> ir.IntegerAttr:
return ir.IntegerAttr.get(ir.IndexType.get(), val)


def get_index_array_attr(vals: list[int]) -> ir.ArrayAttr:
return ir.ArrayAttr.get([get_index_attr(val) for val in vals])


def run(fn):
with ir.Context(), ir.Location.unknown():
module = ir.Module.create()
Expand Down Expand Up @@ -210,16 +218,44 @@ def mma_intrinsic_attr():

@run
def lowering_config_attr():
attributes = ir.DictAttr.get({"reduction": ir.ArrayAttr.get([])})
attributes = ir.DictAttr.get(
{
"reduction": get_index_array_attr([]),
}
)
lowering_config = iree_gpu.LoweringConfigAttr.get(attributes)
assert lowering_config is not None

assert lowering_config.attributes == attributes
assert lowering_config.workgroup_tile_sizes == []
assert lowering_config.reduction_tile_sizes == []
assert lowering_config.subgroup_count_mn == (None, None)
assert lowering_config.mma_kind == None

mma_intrinsic = iree_gpu.MMAIntrinsic.MFMA_F32_16x16x16_F16
mma_attr = iree_gpu.MMAAttr.get(mma_intrinsic)
attributes = ir.DictAttr.get(
{
"reduction": get_index_array_attr([1]),
"workgroup": get_index_array_attr([2, 3]),
"subgroup_m_count": get_index_attr(1),
"subgroup_n_count": get_index_attr(2),
"mma_kind": mma_attr,
}
)
lowering_config = iree_gpu.LoweringConfigAttr.get(attributes)
assert lowering_config.workgroup_tile_sizes == [2, 3]
assert lowering_config.reduction_tile_sizes == [1]
assert lowering_config.subgroup_count_mn == (1, 2)
assert lowering_config.mma_kind == mma_attr
assert (
str(lowering_config.mma_kind) == "#iree_gpu.mma_layout<MFMA_F32_16x16x16_F16>"
)


@run
def compilation_info():
attributes = ir.DictAttr.get({"reduction": ir.ArrayAttr.get([])})
attributes = ir.DictAttr.get({"reduction": get_index_array_attr([])})
lowering_config = iree_gpu.LoweringConfigAttr.get(attributes)
pipeline_attr = iree_codegen.DispatchLoweringPassPipelineAttr.get(
iree_codegen.DispatchLoweringPassPipeline.None_
Expand Down
65 changes: 65 additions & 0 deletions compiler/src/iree/compiler/API/Internal/IREEGPUDialectCAPI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include <cassert>
#include <cstdint>
#include <type_traits>
#include "iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
#include "iree/compiler/dialects/iree_gpu.h"
Expand Down Expand Up @@ -213,3 +214,67 @@ MlirAttribute ireeGPULoweringConfigAttrGetAttributes(MlirAttribute attr) {
unwrap(attr))
.getAttributes());
}

ireeGPUTileSizes ireeGPULoweringConfigAttrGetTileSizes(MlirAttribute attr) {
assert(ireeAttributeIsAGPULoweringConfigAttr(attr));
ireeGPUTileSizes tilesizes = {};
mlir::DictionaryAttr dict =
llvm::cast<mlir::iree_compiler::IREE::GPU::LoweringConfigAttr>(
unwrap(attr))
.getAttributes();

llvm::StringRef workgroupName =
mlir::iree_compiler::IREE::GPU::getTilingLevelName(
mlir::iree_compiler::IREE::GPU::TilingLevel::Workgroup);

if (auto workgroupArray = dict.getAs<mlir::ArrayAttr>(workgroupName)) {
tilesizes.workgroupAttr = wrap(workgroupArray);
}

llvm::StringRef reductionName =
mlir::iree_compiler::IREE::GPU::getTilingLevelName(
mlir::iree_compiler::IREE::GPU::TilingLevel::Reduction);
if (auto reductionArray = dict.getAs<mlir::ArrayAttr>(reductionName)) {
tilesizes.reductionAttr = wrap(reductionArray);
}
return tilesizes;
}

ireeGPUSubgroupCountInfo
ireeGPULoweringConfigAttrGetSubgroupCount(MlirAttribute attr) {
assert(ireeAttributeIsAGPULoweringConfigAttr(attr));
auto loweringConfigAttr =
llvm::cast<mlir::iree_compiler::IREE::GPU::LoweringConfigAttr>(
unwrap(attr));
std::optional<int64_t> subgroupMCount =
mlir::iree_compiler::IREE::GPU::getSubgroupMCount(loweringConfigAttr);
std::optional<int64_t> subgroupNCount =
mlir::iree_compiler::IREE::GPU::getSubgroupNCount(loweringConfigAttr);

ireeGPUSubgroupCountInfo info = {};

if (subgroupMCount) {
info.subgroupMCountAttr = wrap(mlir::IntegerAttr::get(
mlir::IndexType::get(loweringConfigAttr.getContext()),
*subgroupMCount));
}

if (subgroupNCount) {
info.subgroupNCountAttr = wrap(mlir::IntegerAttr::get(
mlir::IndexType::get(loweringConfigAttr.getContext()),
*subgroupNCount));
}
return info;
}

MlirAttribute ireeGPULoweringConfigAttrGetMmaKind(MlirAttribute attr) {
assert(ireeAttributeIsAGPULoweringConfigAttr(attr));
auto loweringConfigAttr =
llvm::cast<mlir::iree_compiler::IREE::GPU::LoweringConfigAttr>(
unwrap(attr));

mlir::iree_compiler::IREE::GPU::MmaInterfaceAttr mma_attr =
mlir::iree_compiler::IREE::GPU::getMmaKind(loweringConfigAttr);

return wrap(mma_attr);
}
6 changes: 6 additions & 0 deletions compiler/src/iree/compiler/API/api_exports.c
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,9 @@ extern void ireeCompilerSourceWrapBuffer();
extern void ireeGPULoweringConfigAttrGet();
extern void ireeGPULoweringConfigAttrGetAttributes();
extern void ireeGPULoweringConfigAttrGetTypeID();
extern void ireeGPULoweringConfigAttrGetTileSizes();
extern void ireeGPULoweringConfigAttrGetSubgroupCount();
extern void ireeGPULoweringConfigAttrGetMmaKind();
extern void ireeGPUMMAAttrGet();
extern void ireeGPUMMAAttrGetInfo();
extern void ireeGPUMMAAttrGetTypeID();
Expand Down Expand Up @@ -945,6 +948,9 @@ uintptr_t __iree_compiler_hidden_force_extern() {
x += (uintptr_t)&ireeGPULoweringConfigAttrGet;
x += (uintptr_t)&ireeGPULoweringConfigAttrGetAttributes;
x += (uintptr_t)&ireeGPULoweringConfigAttrGetTypeID;
x += (uintptr_t)&ireeGPULoweringConfigAttrGetTileSizes;
x += (uintptr_t)&ireeGPULoweringConfigAttrGetSubgroupCount;
x += (uintptr_t)&ireeGPULoweringConfigAttrGetMmaKind;
x += (uintptr_t)&ireeGPUMMAAttrGet;
x += (uintptr_t)&ireeGPUMMAAttrGetInfo;
x += (uintptr_t)&ireeGPUMMAAttrGetTypeID;
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/API/api_exports.def
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ EXPORTS
ireeGPULoweringConfigAttrGet
ireeGPULoweringConfigAttrGetAttributes
ireeGPULoweringConfigAttrGetTypeID
ireeGPULoweringConfigAttrGetTileSizes
ireeGPULoweringConfigAttrGetSubgroupCount
ireeGPULoweringConfigAttrGetMmaKind
ireeGPUMMAAttrGet
ireeGPUMMAAttrGetInfo
ireeGPUMMAAttrGetTypeID
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/API/api_exports.ld
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,9 @@ VER_0 {
ireeGPULoweringConfigAttrGet;
ireeGPULoweringConfigAttrGetAttributes;
ireeGPULoweringConfigAttrGetTypeID;
ireeGPULoweringConfigAttrGetTileSizes;
ireeGPULoweringConfigAttrGetSubgroupCount;
ireeGPULoweringConfigAttrGetMmaKind;
ireeGPUMMAAttrGet;
ireeGPUMMAAttrGetInfo;
ireeGPUMMAAttrGetTypeID;
Expand Down
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/API/api_exports.macos.lst
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ _ireeCompilerSourceWrapBuffer
_ireeGPULoweringConfigAttrGet
_ireeGPULoweringConfigAttrGetAttributes
_ireeGPULoweringConfigAttrGetTypeID
_ireeGPULoweringConfigAttrGetTileSizes
_ireeGPULoweringConfigAttrGetSubgroupCount
_ireeGPULoweringConfigAttrGetMmaKind
_ireeGPUMMAAttrGet
_ireeGPUMMAAttrGetInfo
_ireeGPUMMAAttrGetTypeID
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1104,7 +1104,7 @@ constexpr StringLiteral kThreadLevelName = "thread";
constexpr StringLiteral kSubgroupLevelName = "subgroup";
constexpr StringLiteral kLaneLevelName = "lane";

static StringRef getTilingLevelName(GPU::TilingLevel level) {
StringRef getTilingLevelName(GPU::TilingLevel level) {
switch (level) {
case GPU::TilingLevel::Workgroup:
return kWorkgroupLevelName;
Expand All @@ -1120,7 +1120,7 @@ static StringRef getTilingLevelName(GPU::TilingLevel level) {
return kLaneLevelName;
}
assert(false && "Unknown tiling level");
return StringAttr();
return StringRef();
}

static SmallVector<int64_t> getIntegerVector(ArrayAttr array) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic,
MMASingleSubgroupLayout getSingleSubgroupLayout(MmaInterfaceAttr mmaKind,
MMAFragment fragment);

/// Returns the name of the tilling `level`, as used in the `lowering_config`
/// attribute.
StringRef getTilingLevelName(GPU::TilingLevel level);

} // namespace mlir::iree_compiler::IREE::GPU

// clang-format off
Expand Down

0 comments on commit 62903cc

Please sign in to comment.