diff --git a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp index d283a498ec94..bdf7f801fb72 100644 --- a/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/CPU/CPUMaterializeEncodingPass.cpp @@ -57,6 +57,12 @@ chooseMatmulTileParamsAArch64(EncodingUser user, TypeRange elementTypes, Type out = elementTypes[2]; if (out.isF32() || out.isF16() || out.isBF16()) { + if (lhs.isBF16() && rhs.isBF16() && (out.isBF16() || out.isF32())) { + if (hasFeature(target, "+bf16")) { + // Aim to use BFMMLA. + return MatmulTileParams{8, 4, 8}; + } + } // Note: 16-bit floating point types currently use the same tile size as // f32. This makes sense when either (1) the accumulator is f32, or (2) // the arithmetic will have to expand f16 to f32 in registers. We may @@ -94,6 +100,11 @@ chooseMatmulTileParamsX86_64(EncodingUser user, TypeRange elementTypes, Type out = elementTypes[2]; if (out.isF32() || out.isF16() || out.isBF16()) { + if (lhs.isBF16() && rhs.isBF16() && (out.isBF16() || out.isF32())) { + if (hasFeature(target, "+avx512bf16")) { + return MatmulTileParams{16, 2, 16}; + } + } // Note: 16-bit floating point types currently use the same tile size as // f32. This makes sense when either (1) the accumulator is f32, or (2) // the arithmetic will have to expand f16 to f32 in registers. We may diff --git a/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64_entry_point.h b/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64_entry_point.h index 9720e64c374c..5f5c11c8dcf3 100644 --- a/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64_entry_point.h +++ b/runtime/src/iree/builtins/ukernel/arch/x86_64/common_x86_64_entry_point.h @@ -15,6 +15,7 @@ #define IREE_UK_BUILD_X86_64_AVX2_FMA #define IREE_UK_BUILD_X86_64_AVX512_BASE #define IREE_UK_BUILD_X86_64_AVX512_VNNI +#define IREE_UK_BUILD_X86_64_AVX512_BF16 #else // IREE_DEVICE_STANDALONE // Compiling with the system toolchain. Include the configured header. #include "iree/builtins/ukernel/arch/x86_64/config_x86_64.h"