Skip to content

Commit

Permalink
[NFC] Move convertFp32ToFp16 to utility
Browse files Browse the repository at this point in the history
Signed-off-by: Whitney Tsang <whitney.tsang@intel.com>
  • Loading branch information
whitneywhtsang committed Jan 12, 2025
1 parent 16a54d6 commit b109c05
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 30 deletions.
34 changes: 4 additions & 30 deletions third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -881,33 +881,6 @@ struct FpToFpOpConversion
return rewriter.create<LLVM::FPExtOp>(loc, f32_ty, v);
}

/// Maps a Triton rounding mode onto the corresponding LLVM dialect rounding
/// mode.
///
/// RTNE maps to NearestTiesToEven and RTZ maps to TowardZero. Any other mode
/// is reported on stderr and then treated as unreachable.
static LLVM::RoundingMode
convertTritonRoundingModeToLLVM(const RoundingMode rounding) {
  switch (rounding) {
  case RoundingMode::RTNE:
    return LLVM::RoundingMode::NearestTiesToEven;
  case RoundingMode::RTZ:
    return LLVM::RoundingMode::TowardZero;
  default:
    // Print the offending mode first: llvm_unreachable only accepts a fixed
    // message, so the stringified mode would otherwise be lost.
    llvm::errs() << "WARNING: unsupported rounding mode for f32->f16 "
                    "conversion: "
                 << stringifyRoundingMode(rounding) << "\n";
    llvm_unreachable("unsupported rounding mode for f32->f16 conversion");
  }
}

/// Truncates `v` to f16 by emitting an LLVM constrained fptrunc intrinsic that
/// uses the LLVM equivalent of `rounding` and the default FP exception
/// behavior for the current context.
static Value convertFp32ToFp16(Location loc,
                               ConversionPatternRewriter &rewriter,
                               const Value &v, const RoundingMode rounding) {
  MLIRContext *context = rewriter.getContext();
  // Build both attributes up front so the intrinsic creation reads linearly.
  auto roundingAttr = LLVM::RoundingModeAttr::get(
      context, convertTritonRoundingModeToLLVM(rounding));
  auto exceptionAttr = arith::getLLVMDefaultFPExceptionBehavior(*context);
  return rewriter.create<LLVM::ConstrainedFPTruncIntr>(loc, f16_ty, v,
                                                       roundingAttr,
                                                       exceptionAttr);
}

std::pair<ConverterT, size_t>
getConversionFunc(Type srcTy, Type dstTy,
std::optional<RoundingMode> roundingMode) const {
Expand Down Expand Up @@ -1005,8 +978,8 @@ struct FpToFpOpConversion
"rounding mode must be specified for fp32->fp16 conversion");
SmallVector<Value> outVals;
for (Value v : operands[0]) {
outVals.push_back(
convertFp32ToFp16(loc, rewriter, v, roundingMode.value()));
outVals.push_back(LLVM::intel::convertFp32ToFp16(loc, rewriter, v,
roundingMode.value()));
}
return outVals;
}
Expand Down Expand Up @@ -1035,7 +1008,8 @@ struct FpToFpOpConversion
}
if (useFP16IntermediateSrc)
for (Value &v : inVals)
v = convertFp32ToFp16(loc, rewriter, v, roundingMode.value());
v = LLVM::intel::convertFp32ToFp16(loc, rewriter, v,
roundingMode.value());
inVals.resize(numElements, undef(typeConverter->convertType(srcType)));
SmallVector<Value> outVals = cvtFunc(loc, rewriter, inVals);
assert(outVals.size() == inVals.size());
Expand Down
27 changes: 27 additions & 0 deletions third_party/intel/lib/TritonIntelGPUToLLVM/Utility.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
//===----------------------------------------------------------------------===//

#include "Utility.h"
#include "mlir/Conversion/ArithCommon/AttrToLLVMConverter.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"

using namespace mlir;
Expand Down Expand Up @@ -143,4 +144,30 @@ LLVM::LLVMFuncOp getSpirvPrintfDeclaration(RewriterBase &rewriter) {
return printFunc;
}

/// Maps a Triton rounding mode onto the corresponding LLVM dialect rounding
/// mode.
///
/// RTNE maps to NearestTiesToEven and RTZ maps to TowardZero. Any other mode
/// is reported on stderr and then treated as unreachable.
static LLVM::RoundingMode
convertTritonRoundingModeToLLVM(const triton::RoundingMode rounding) {
  switch (rounding) {
  case triton::RoundingMode::RTNE:
    return LLVM::RoundingMode::NearestTiesToEven;
  case triton::RoundingMode::RTZ:
    return LLVM::RoundingMode::TowardZero;
  default:
    // Print the offending mode first: llvm_unreachable only accepts a fixed
    // message, so the stringified mode would otherwise be lost.
    llvm::errs() << "WARNING: unsupported rounding mode for f32->f16 "
                    "conversion: "
                 << stringifyRoundingMode(rounding) << "\n";
    llvm_unreachable("unsupported rounding mode for f32->f16 conversion");
  }
}

/// Truncates `v` to f16 by emitting an LLVM constrained fptrunc intrinsic that
/// uses the LLVM equivalent of `rounding` and the default FP exception
/// behavior for the current context.
Value convertFp32ToFp16(Location loc, ConversionPatternRewriter &rewriter,
                        const Value &v, const triton::RoundingMode rounding) {
  MLIRContext *context = rewriter.getContext();
  // Build both attributes up front so the intrinsic creation reads linearly.
  auto roundingAttr = LLVM::RoundingModeAttr::get(
      context, convertTritonRoundingModeToLLVM(rounding));
  auto exceptionAttr = arith::getLLVMDefaultFPExceptionBehavior(*context);
  return rewriter.create<LLVM::ConstrainedFPTruncIntr>(loc, f16_ty, v,
                                                       roundingAttr,
                                                       exceptionAttr);
}

} // namespace mlir::LLVM::intel
3 changes: 3 additions & 0 deletions third_party/intel/lib/TritonIntelGPUToLLVM/Utility.h
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ static Value getModuleWarpSize(RewriterBase &rewriter, Location loc) {
return i32_val(triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod));
}

/// Emits, at `loc` through `rewriter`, a constrained floating-point truncation
/// of `v` to f16 and returns the resulting value.
///
/// `rounding` selects the rounding mode of the truncation; the definition
/// handles RTNE and RTZ and treats every other mode as unreachable.
/// NOTE(review): `v` is presumably an f32 value, per the function name —
/// confirm against callers.
Value convertFp32ToFp16(Location loc, ConversionPatternRewriter &rewriter,
const Value &v, triton::RoundingMode rounding);

} // namespace mlir::LLVM::intel

using mlir::triton::gpu::intel::DpasEncodingAttr;
Expand Down

0 comments on commit b109c05

Please sign in to comment.