diff --git a/bin/triton-tensor-layout.cpp b/bin/triton-tensor-layout.cpp index 49a07681ed..7c635dafaa 100644 --- a/bin/triton-tensor-layout.cpp +++ b/bin/triton-tensor-layout.cpp @@ -80,9 +80,17 @@ static cl::opt TensorStr( //===--------------------------------------------------------------------===// LogicalResult layoutPrint(RankedTensorType tensorType, raw_ostream &os) { + StringRef dialectName = tensorType.getEncoding().getDialect().getNamespace(); + // Dispatch to the corresponding dialect helper function to print the layout. - os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView); - return success(); + if (dialectName == "ttg") { + os << triton::gpu::getLayoutStr(tensorType, UseHWPointOfView); + return success(); + } + + llvm::errs() << "Unsupported tensor layout attribute: " + << tensorType.getEncoding() << "\n"; + return failure(); } LogicalResult printLayoutFromFile(MLIRContext *context, StringRef filename, diff --git a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp index 1568341deb..d5afb6e2b1 100644 --- a/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp +++ b/lib/Conversion/TritonGPUToLLVM/DecomposeUnsupportedConversions.cpp @@ -24,8 +24,8 @@ void decomposeSplatOpToSharedLayoutConversion(ModuleOp module) { int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(module); module.walk([&](triton::SplatOp splatOp) -> void { auto dstType = cast(splatOp.getType()); - auto shared = dyn_cast_or_null( - dstType.getEncoding()); + auto shared = + dyn_cast(dstType.getEncoding()); if (shared) { OpBuilder builder(splatOp); SmallVector sizePerThread(dstType.getRank(), 1);