From 626b0ca6d63bd22b632085f5cecfef29a42eef52 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Thu, 18 Jul 2024 07:36:35 -0700 Subject: [PATCH 1/5] Adding `hal.command_buffer.update_buffer`. --- .../HALToVM/ConvertCommandBufferOps.cpp | 40 +++++++++++++ .../HALToVM/test/command_buffer_ops.mlir | 28 +++++++++ .../compiler/Dialect/HAL/IR/HALOpFolders.cpp | 60 +++++++++++++++---- .../iree/compiler/Dialect/HAL/IR/HALOps.cpp | 26 ++++++++ .../iree/compiler/Dialect/HAL/IR/HALOps.td | 44 +++++++++++++- .../HAL/IR/test/command_buffer_folding.mlir | 29 +++++++++ .../HAL/IR/test/command_buffer_ops.mlir | 24 ++++++++ .../compiler/Dialect/HAL/hal.imports.mlir | 10 ++++ runtime/src/iree/modules/hal/exports.inl | 1 + runtime/src/iree/modules/hal/module.c | 24 ++++++++ 10 files changed, 275 insertions(+), 11 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp index 4de18677c832..607161458069 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp @@ -127,6 +127,43 @@ class CommandBufferFillBufferOpConversion mutable IREE::VM::ImportOp importOp; }; +class CommandBufferUpdateBufferOpConversion + : public OpConversionPattern { +public: + CommandBufferUpdateBufferOpConversion(MLIRContext *context, + SymbolTable &importSymbols, + TypeConverter &typeConverter, + StringRef importName) + : OpConversionPattern(typeConverter, context) { + importOp = importSymbols.lookup(importName); + assert(importOp); + } + + LogicalResult + matchAndRewrite(IREE::HAL::CommandBufferUpdateBufferOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto importType = importOp.getFunctionType(); + SmallVector callOperands = { + adaptor.getCommandBuffer(), + adaptor.getSourceBuffer(), + 
castToImportType(adaptor.getSourceOffset(), rewriter.getI64Type(), + rewriter), + adaptor.getTargetBuffer(), + castToImportType(adaptor.getTargetOffset(), rewriter.getI64Type(), + rewriter), + castToImportType(adaptor.getLength(), rewriter.getI64Type(), rewriter), + }; + auto callOp = rewriter.replaceOpWithNewOp( + op, SymbolRefAttr::get(importOp), importType.getResults(), + callOperands); + copyImportAttrs(importOp, callOp); + return success(); + } + +private: + mutable IREE::VM::ImportOp importOp; +}; + class CommandBufferCollectiveOpConversion : public OpConversionPattern { public: @@ -329,6 +366,9 @@ void populateHALCommandBufferToVMPatterns(MLIRContext *context, "hal.command_buffer.execution_barrier"); patterns.insert( context, importSymbols, typeConverter, "hal.command_buffer.fill_buffer"); + patterns.insert( + context, importSymbols, typeConverter, + "hal.command_buffer.update_buffer"); patterns.insert>( context, importSymbols, typeConverter, "hal.command_buffer.copy_buffer"); patterns.insert( diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir index 711e86022487..9a3e26d728a3 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir @@ -98,6 +98,34 @@ util.func public @command_buffer_fill_buffer_i32( // ----- +// CHECK-LABEL: @command_buffer_update_buffer +// CHECK-SAME: (%[[CMD:.+]]: !vm.ref, +// CHECK-SAME: %[[HOST_BUFFER:[a-z0-9]+]]: !vm.buffer, %[[HOST_BUFFER_SIZE:[a-z0-9]+]]: i32, %[[SRC_OFFSET:[a-z0-9]+]]: i32, +// CHECK-SAME: %[[DEVICE_BUFFER:[a-z0-9]+]]: !vm.ref, %[[DST_OFFSET:[a-z0-9]+]]: i32, +// CHECK-SAME: %[[LENGTH:[a-z0-9]+]]: i32) +util.func public @command_buffer_update_buffer( + %cmd: !hal.command_buffer, + %host_buffer: !util.buffer, %host_buffer_size: index, %src_offset: 
index, + %device_buffer: !hal.buffer, %dst_offset: index, + %length: index + ) { + // CHECK-DAG: %[[SRC_OFFSET_I64:.+]] = vm.ext.i32.i64.s %[[SRC_OFFSET]] + // CHECK-DAG: %[[DST_OFFSET_I64:.+]] = vm.ext.i32.i64.s %[[DST_OFFSET]] + // CHECK-DAG: %[[LENGTH_I64:.+]] = vm.ext.i32.i64.s %[[LENGTH]] + // CHECK: vm.call @hal.command_buffer.update_buffer + // CHECK-SAME: (%[[CMD]], + // CHECK-SAME: %[[HOST_BUFFER]], %[[SRC_OFFSET_I64]], + // CHECK-SAME: %[[DEVICE_BUFFER]], %[[DST_OFFSET_I64]], + // CHECK-SAME: %[[LENGTH_I64]]) + hal.command_buffer.update_buffer<%cmd : !hal.command_buffer> + source(%host_buffer : !util.buffer{%host_buffer_size})[%src_offset] + target(%device_buffer : !hal.buffer)[%dst_offset] + length(%length) + util.return +} + +// ----- + // CHECK-LABEL: @command_buffer_copy_buffer util.func public @command_buffer_copy_buffer( %arg0: !hal.command_buffer, diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp index 784ade32807b..d082847334ef 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOpFolders.cpp @@ -132,10 +132,10 @@ struct FoldBufferViewCreateSubspan bool needsUpdate = false; auto newSourceBuffer = op.getSourceBuffer(); auto newSourceOffset = llvm::cast(op.getSourceOffset()); - if (auto subspanOp = dyn_cast_or_null( + if (auto subspanOp = dyn_cast_or_null( op.getSourceBuffer().getDefiningOp())) { newSourceBuffer = subspanOp.getSourceBuffer(); - newSourceOffset = rewriter.createOrFold( + newSourceOffset = rewriter.createOrFold( subspanOp.getLoc(), subspanOp.getSourceOffset(), op.getSourceOffset()); needsUpdate = true; @@ -220,10 +220,10 @@ struct FoldCommandBufferFillBufferSubspans bool needsUpdate = false; auto newTargetBuffer = op.getTargetBuffer(); auto newTargetOffset = llvm::cast(op.getTargetOffset()); - if (auto subspanOp = dyn_cast_or_null( + if (auto subspanOp = dyn_cast_or_null( 
op.getTargetBuffer().getDefiningOp())) { newTargetBuffer = subspanOp.getSourceBuffer(); - newTargetOffset = rewriter.createOrFold( + newTargetOffset = rewriter.createOrFold( subspanOp.getLoc(), subspanOp.getSourceOffset(), op.getTargetOffset()); needsUpdate = true; @@ -248,6 +248,46 @@ void CommandBufferFillBufferOp::getCanonicalizationPatterns( namespace { +/// Folds hal.buffer.subspans into buffer update offsets. +struct FoldCommandBufferUpdateBufferSubspans + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(CommandBufferUpdateBufferOp op, + PatternRewriter &rewriter) const override { + auto ip = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(op); + bool needsUpdate = false; + auto newTargetBuffer = op.getTargetBuffer(); + auto newTargetOffset = llvm::cast(op.getTargetOffset()); + if (auto subspanOp = dyn_cast_or_null( + op.getTargetBuffer().getDefiningOp())) { + newTargetBuffer = subspanOp.getSourceBuffer(); + newTargetOffset = rewriter.createOrFold( + subspanOp.getLoc(), subspanOp.getSourceOffset(), + op.getTargetOffset()); + needsUpdate = true; + } + rewriter.restoreInsertionPoint(ip); + if (!needsUpdate) + return failure(); + rewriter.modifyOpInPlace(op, [&]() { + op.getTargetBufferMutable().assign(newTargetBuffer); + op.getTargetOffsetMutable().assign(newTargetOffset); + }); + return success(); + } +}; + +} // namespace + +void CommandBufferUpdateBufferOp::getCanonicalizationPatterns( + RewritePatternSet &results, MLIRContext *context) { + results.insert(context); +} + +namespace { + /// Folds hal.buffer.subspans into buffer copy offsets. 
struct FoldCommandBufferCopyBufferSubspans : public OpRewritePattern { @@ -260,20 +300,20 @@ struct FoldCommandBufferCopyBufferSubspans bool needsUpdate = false; auto newSourceBuffer = op.getSourceBuffer(); auto newSourceOffset = llvm::cast(op.getSourceOffset()); - if (auto subspanOp = dyn_cast_or_null( + if (auto subspanOp = dyn_cast_or_null( op.getSourceBuffer().getDefiningOp())) { newSourceBuffer = subspanOp.getSourceBuffer(); - newSourceOffset = rewriter.createOrFold( + newSourceOffset = rewriter.createOrFold( subspanOp.getLoc(), subspanOp.getSourceOffset(), op.getSourceOffset()); needsUpdate = true; } auto newTargetBuffer = op.getTargetBuffer(); auto newTargetOffset = llvm::cast(op.getTargetOffset()); - if (auto subspanOp = dyn_cast_or_null( + if (auto subspanOp = dyn_cast_or_null( op.getTargetBuffer().getDefiningOp())) { newTargetBuffer = subspanOp.getSourceBuffer(); - newTargetOffset = rewriter.createOrFold( + newTargetOffset = rewriter.createOrFold( subspanOp.getLoc(), subspanOp.getSourceOffset(), op.getTargetOffset()); needsUpdate = true; @@ -317,10 +357,10 @@ struct FoldCommandBufferPushDescriptorSetBufferSubspan auto *definingOp = bindingBuffers[i].getDefiningOp(); if (!definingOp) continue; - if (auto subspanOp = dyn_cast(definingOp)) { + if (auto subspanOp = dyn_cast(definingOp)) { needsUpdate = true; bindingBuffers[i] = subspanOp.getSourceBuffer(); - bindingOffsets[i] = rewriter.createOrFold( + bindingOffsets[i] = rewriter.createOrFold( subspanOp.getLoc(), subspanOp.getSourceOffset(), bindingOffsets[i]); } } diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp index cc4de30bcf3d..b8c0bab8c8d0 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.cpp @@ -990,6 +990,32 @@ void CommandBufferCreateOp::getAsmResultNames( setNameFn(getResult(), "cmd"); } 
+//===----------------------------------------------------------------------===// +// hal.command_buffer.update_buffer +//===----------------------------------------------------------------------===// + +IREE::Util::SubrangeOperand +CommandBufferUpdateBufferOp::getSubrangeOperand(unsigned operandIndex) { + if (operandIndex == 1) { + return IREE::Util::SubrangeOperand{getSourceBuffer(), getSourceSize(), + getSourceOffset(), getLength()}; + } else { + assert(false && "only source is a subrange"); + return {}; + } +} + +void CommandBufferUpdateBufferOp::setSubrangeOperand( + unsigned operandIndex, IREE::Util::SubrangeOperand operand) { + if (operandIndex == 1) { + getSourceBufferMutable().assign(operand.resource); + getSourceSizeMutable().assign(operand.resourceSize); + getSourceOffsetMutable().assign(operand.offset); + } else { + assert(false && "only source is a subrange"); + } +} + //===----------------------------------------------------------------------===// // hal.command_buffer.push_descriptor_set //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index 1889f597708f..74d0b560fa75 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -1306,7 +1306,49 @@ def HAL_CommandBufferFillBufferOp : HAL_Op<"command_buffer.fill_buffer"> { let hasCanonicalizer = 1; } -// TODO(benvanik): update buffer op. +def HAL_CommandBufferUpdateBufferOp : HAL_Op<"command_buffer.update_buffer", [ + // TODO(benvanik): figure out the right way to model host effects - this is + // a host read but a device write; if we make it just MemRead then it gets + // DCEd because it has no result. For now we report both to keep analysis + // appeased even if incorrect. 
+ MemoryEffects<[MemRead, MemWrite]>, + Util_SizeAwareOp, + DeclareOpInterfaceMethods, +]> { + let summary = [{command buffer buffer update recording operation}]; + let description = [{ + Copies a range of a host buffer into a device buffer. The host buffer + contents will be captured at the time of the call and embedded in the + command buffer. + }]; + + let arguments = (ins + HAL_CommandBuffer:$command_buffer, + Util_BufferType:$source_buffer, + Util_Size:$source_size, + Util_Size:$source_offset, + AnyTypeOf<[Index, HAL_BufferType]>:$target_buffer, + HAL_DeviceSize:$target_offset, + HAL_DeviceSize:$length + ); + + let assemblyFormat = [{ + `<` $command_buffer `:` type($command_buffer) `>` + `source` `(` $source_buffer `:` type($source_buffer) `{` $source_size `}` `)` + `` `[` $source_offset `]` + `target` `(` $target_buffer `:` type($target_buffer) `)` + `` `[` $target_offset `]` + `length` `(` $length `)` + attr-dict-with-keyword + }]; + + let extraClassDeclaration = [{ + Value getOperandSize(unsigned idx) { return idx == 1 ? 
getSourceSize() : Value{}; } + Value getResultSize(unsigned idx) { return {}; } + }]; + + let hasCanonicalizer = 1; +} def HAL_CommandBufferCopyBufferOp : HAL_Op<"command_buffer.copy_buffer"> { let summary = [{command buffer buffer copy recording operation}]; diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir index 5ced86abfb4e..a61ea4f3b687 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir @@ -42,6 +42,35 @@ util.func public @fold_buffer_subspan_into_fill_buffer( // ----- +// CHECK-LABEL: @fold_buffer_subspans_into_update_buffer +// CHECK-SAME: %[[CMD:.+]]: !hal.command_buffer, +// CHECK-SAME: %[[SOURCE_BUFFER:.+]]: !util.buffer, %[[SOURCE_BUFFER_SIZE:.+]]: index, +// CHECK-SAME: %[[TARGET_BUFFER:.+]]: !hal.buffer +util.func public @fold_buffer_subspans_into_update_buffer( + %cmd: !hal.command_buffer, + %source_buffer: !util.buffer, %source_buffer_size: index, + %target_buffer: !hal.buffer + ) { + %c0 = arith.constant 0 : index + %c4096 = arith.constant 4096 : index + %c8192 = arith.constant 8192 : index + %c100000 = arith.constant 100000 : index + %c262144 = arith.constant 262144 : index + %source_subspan = util.buffer.subspan %source_buffer[%c4096] : !util.buffer{%source_buffer_size} -> !util.buffer{%c262144} + %target_subspan = hal.buffer.subspan<%target_buffer : !hal.buffer>[%c8192, %c262144] : !hal.buffer + // CHECK: hal.command_buffer.update_buffer + hal.command_buffer.update_buffer<%cmd : !hal.command_buffer> + // CHECK-SAME: source(%[[SOURCE_BUFFER]] : !util.buffer{%[[SOURCE_BUFFER_SIZE]]})[%c4096] + source(%source_subspan : !util.buffer{%c262144})[%c0] + // CHECK-SAME: target(%[[TARGET_BUFFER]] : !hal.buffer)[%c108192] + target(%target_subspan : !hal.buffer)[%c100000] + // CHECK-SAME: length(%c8192) + length(%c8192) + 
util.return +} + +// ----- + // CHECK-LABEL: @fold_buffer_subspan_into_copy_buffer // CHECK-SAME: %[[CMD:.+]]: !hal.command_buffer, // CHECK-SAME: %[[BASE_BUFFER:.+]]: !hal.buffer diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir index 766b39a3d4ed..dc16d454c859 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir @@ -73,6 +73,30 @@ util.func public @command_buffer_fill_buffer( // ----- +// CHECK-LABEL: @command_buffer_update_buffer +// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer, +// CHECK-SAME: %[[HOST_BUFFER:[a-z0-9]+]]: !util.buffer, %[[HOST_BUFFER_SIZE:[a-z0-9]+]]: index, %[[SRC_OFFSET:[a-z0-9]+]]: index, +// CHECK-SAME: %[[DEVICE_BUFFER:[a-z0-9]+]]: !hal.buffer, %[[DST_OFFSET:[a-z0-9]+]]: index, +// CHECK-SAME: %[[LENGTH:[a-z0-9]+]]: index) +util.func public @command_buffer_update_buffer( + %cmd: !hal.command_buffer, + %host_buffer: !util.buffer, %host_buffer_size: index, %src_offset: index, + %device_buffer: !hal.buffer, %dst_offset: index, + %length: index + ) { + // CHECK: hal.command_buffer.update_buffer<%[[CMD]] : !hal.command_buffer> + // CHECK-SAME: source(%[[HOST_BUFFER]] : !util.buffer{%[[HOST_BUFFER_SIZE]]})[%[[SRC_OFFSET]]] + // CHECK-SAME: target(%[[DEVICE_BUFFER]] : !hal.buffer)[%[[DST_OFFSET]]] + // CHECK-SAME: length(%[[LENGTH]]) + hal.command_buffer.update_buffer<%cmd : !hal.command_buffer> + source(%host_buffer : !util.buffer{%host_buffer_size})[%src_offset] + target(%device_buffer : !hal.buffer)[%dst_offset] + length(%length) + util.return +} + +// ----- + // CHECK-LABEL: @command_buffer_copy_buffer // CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer, // CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer, diff --git a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir index 
1d21fb38ce6b..67319429db2f 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir @@ -239,6 +239,16 @@ vm.import private @command_buffer.fill_buffer( %pattern_length: i32 ) +// Updates a device buffer with the captured contents of a host buffer. +vm.import private @command_buffer.update_buffer( + %command_buffer : !vm.ref, + %source_buffer : !vm.buffer, + %source_offset : i64, + %target_buffer : !vm.ref, + %target_offset : i64, + %length : i64 +) + // Copies a range of one buffer to another. vm.import private @command_buffer.copy_buffer( %command_buffer : !vm.ref, diff --git a/runtime/src/iree/modules/hal/exports.inl b/runtime/src/iree/modules/hal/exports.inl index 13f9d09f6bc7..b808785a0d2b 100644 --- a/runtime/src/iree/modules/hal/exports.inl +++ b/runtime/src/iree/modules/hal/exports.inl @@ -58,6 +58,7 @@ EXPORT_FN("command_buffer.fill_buffer", iree_hal_module_command_buffer_fill_buff EXPORT_FN("command_buffer.finalize", iree_hal_module_command_buffer_finalize, r, v) EXPORT_FN("command_buffer.push_constants", iree_hal_module_command_buffer_push_constants, rriCiD, v) EXPORT_FN("command_buffer.push_descriptor_set", iree_hal_module_command_buffer_push_descriptor_set, rriCiirIID, v) +EXPORT_FN("command_buffer.update_buffer", iree_hal_module_command_buffer_update_buffer, rrIrII, v) EXPORT_FN("descriptor_set_layout.create", iree_hal_module_descriptor_set_layout_create, riCiiiD, r) diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c index 1b9f8dfb3616..4b84671205b2 100644 --- a/runtime/src/iree/modules/hal/module.c +++ b/runtime/src/iree/modules/hal/module.c @@ -774,6 +774,30 @@ IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_fill_buffer, // &pattern, pattern_length); } +IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_update_buffer, // + iree_hal_module_state_t, // + rrIrII, v) { + iree_hal_command_buffer_t* command_buffer = NULL; + 
IREE_RETURN_IF_ERROR( + iree_hal_command_buffer_check_deref(args->r0, &command_buffer)); + iree_vm_buffer_t* source_buffer = NULL; + IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r1, &source_buffer)); + iree_host_size_t source_offset = iree_hal_cast_host_size(args->i2); + iree_hal_buffer_t* target_buffer = NULL; + IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r3, &target_buffer)); + iree_device_size_t target_offset = iree_hal_cast_device_size(args->i4); + iree_device_size_t length = iree_hal_cast_device_size(args->i5); + + iree_const_byte_span_t source_span = iree_const_byte_span_empty(); + IREE_RETURN_IF_ERROR(iree_vm_buffer_map_ro( + source_buffer, source_offset, (iree_host_size_t)length, 1, &source_span)); + + iree_hal_buffer_ref_t target_ref = + iree_hal_make_buffer_ref(target_buffer, target_offset, length); + return iree_hal_command_buffer_update_buffer(command_buffer, source_span.data, + /*source_offset=*/0, target_ref); +} + IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_copy_buffer, // iree_hal_module_state_t, // rrIrII, v) { From cbff7317bf3c6034267bbe98d927cb13d725cb8b Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Thu, 18 Jul 2024 09:13:01 -0700 Subject: [PATCH 2/5] Adding queue affinity arg to `hal.command_buffer.create`. The C API has taken this for a while as a hint to implementations as to which queues a command buffer may be executed on. It's legal for this to always be "any" but we may want to scope things more tightly in the future to e.g. replicate command buffers on multiple logical devices represented as queues or manage NUMA hinting. 
--- .../Codegen/SPIRV/test/link_executables.mlir | 9 +++++--- .../Codegen/VMVX/test/link_executables.mlir | 6 +++-- .../HALToVM/ConvertCommandBufferOps.cpp | 1 + .../HALToVM/test/command_buffer_ops.mlir | 14 +++++++----- .../HAL/Conversion/StreamToHAL/Patterns.cpp | 3 ++- .../iree/compiler/Dialect/HAL/IR/HALOps.td | 2 ++ .../HAL/IR/test/command_buffer_folding.mlir | 7 +++--- .../HAL/IR/test/command_buffer_ops.mlir | 10 +++++---- .../Transforms/DumpExecutableBenchmarks.cpp | 3 ++- .../HAL/Transforms/test/convert_to_hal.mlir | 2 +- .../Transforms/test/fixup_legacy_sync.mlir | 22 +++++++++---------- .../compiler/Dialect/HAL/hal.imports.mlir | 1 + runtime/src/iree/modules/hal/exports.inl | 2 +- runtime/src/iree/modules/hal/module.c | 10 +++++---- runtime/src/iree/vm/shims.c | 1 + runtime/src/iree/vm/shims.h | 9 ++++++++ 16 files changed, 65 insertions(+), 37 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir index cda0bbbc66d3..7d2977e517b6 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir @@ -81,7 +81,8 @@ func.func @basic_linking() -> () attributes { } { %c0 = arith.constant 0 : index %device = hal.devices.get %c0 : !hal.device - %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer attributes { + %affinity = arith.constant -1 : i64 + %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer attributes { testing.op.a = @dispatch_0, testing.op.b = @dispatch_0::@spirv, testing.op.c = @dispatch_0::@spirv::@dispatch_0 @@ -101,7 +102,8 @@ func.func @basic_linking() -> () attributes { util.initializer { %c0 = arith.constant 0 : index %device = hal.devices.get %c0 : !hal.device - %cmd = 
hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer + %affinity = arith.constant -1 : i64 + %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer %c1 = arith.constant 1 : index %dispatch_0_exe = hal.executable.lookup device(%device : !hal.device) executable(@dispatch_0) : !hal.executable %dispatch_1_exe = hal.executable.lookup device(%device : !hal.device) executable(@dispatch_1) : !hal.executable @@ -291,7 +293,8 @@ hal.executable private @dispatch_3 { func.func @two_target_environments() -> () { %c0 = arith.constant 0 : index %device = hal.devices.get %c0 : !hal.device - %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer + %affinity = arith.constant -1 : i64 + %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer %c1 = arith.constant 1 : index %dispatch_0_exe = hal.executable.lookup device(%device : !hal.device) executable(@dispatch_0) : !hal.executable %dispatch_1_exe = hal.executable.lookup device(%device : !hal.device) executable(@dispatch_1) : !hal.executable diff --git a/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir b/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir index d6d6d151dc72..af83f1ba0733 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir +++ b/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir @@ -74,7 +74,8 @@ func.func @basic_linking() -> () attributes { } { %c0 = arith.constant 0 : index %device = hal.devices.get %c0 : !hal.device - %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer attributes { + %affinity = arith.constant -1 : i64 + 
%cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer attributes { testing.op.a = @dispatch_0, testing.op.b = @dispatch_0::@vmvx, testing.op.c = @dispatch_0::@vmvx::@dispatch_0 @@ -94,7 +95,8 @@ func.func @basic_linking() -> () attributes { util.initializer { %c0 = arith.constant 0 : index %device = hal.devices.get %c0 : !hal.device - %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer + %affinity = arith.constant -1 : i64 + %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer %c1 = arith.constant 1 : index %dispatch_0_exe = hal.executable.lookup device(%device : !hal.device) executable(@dispatch_0) : !hal.executable %dispatch_1_exe = hal.executable.lookup device(%device : !hal.device) executable(@dispatch_1) : !hal.executable diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp index 607161458069..8299939e0729 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp @@ -51,6 +51,7 @@ class CommandBufferCreateOpConversion if (!categoriesValue.has_value()) return failure(); callOperands.append(categoriesValue.value()); + callOperands.push_back(adaptor.getQueueAffinity()); if (adaptor.getBindingCapacity()) { callOperands.push_back(castToImportType(adaptor.getBindingCapacity(), rewriter.getI32Type(), rewriter)); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir index 9a3e26d728a3..7402f62324c6 
100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir @@ -1,18 +1,20 @@ // RUN: iree-opt --split-input-file --iree-vm-conversion --canonicalize --iree-vm-target-index-bits=32 %s | FileCheck %s // CHECK-LABEL: @command_buffer_create -util.func public @command_buffer_create(%arg0: !hal.device) { - // CHECK: %ref = vm.call @hal.command_buffer.create(%arg0, %c1, %c3, %zero) : (!vm.ref, i32, i32, i32) -> !vm.ref - %cmd = hal.command_buffer.create device(%arg0 : !hal.device) mode("OneShot") categories("Transfer|Dispatch") : !hal.command_buffer +// CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref, %[[AFFINITY:.+]]: i64) +util.func public @command_buffer_create(%device: !hal.device, %affinity: i64) { + // CHECK: = vm.call @hal.command_buffer.create(%[[DEVICE]], %c1, %c3, %[[AFFINITY]], %zero) : (!vm.ref, i32, i32, i64, i32) -> !vm.ref + %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer util.return } // ----- // CHECK-LABEL: @command_buffer_create_bindings -util.func public @command_buffer_create_bindings(%arg0: !hal.device, %arg1: index) { - // CHECK: %ref = vm.call @hal.command_buffer.create(%arg0, %c1, %c3, %arg1) : (!vm.ref, i32, i32, i32) -> !vm.ref - %cmd = hal.command_buffer.create device(%arg0 : !hal.device) mode("OneShot") categories("Transfer|Dispatch") bindings(%arg1) : !hal.command_buffer +// CHECK-SAME: (%[[DEVICE:.+]]: !vm.ref, %[[AFFINITY:.+]]: i64, %[[CAPACITY:.+]]: i32) +util.func public @command_buffer_create_bindings(%device: !hal.device, %affinity: i64, %capacity: index) { + // CHECK: = vm.call @hal.command_buffer.create(%[[DEVICE]], %c1, %c3, %[[AFFINITY]], %[[CAPACITY]]) : (!vm.ref, i32, i32, i64, i32) -> !vm.ref + %cmd = hal.command_buffer.create device(%device : !hal.device) mode("OneShot") categories("Transfer|Dispatch") 
affinity(%affinity) bindings(%capacity) : !hal.command_buffer util.return } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp index 771e7d4786a0..274b348f753b 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp @@ -939,7 +939,8 @@ struct CmdExecuteOpPattern rewriter .create( loc, rewriter.getType(), device, - modes, commandCategories, /*binding_capacity=*/Value{}) + modes, commandCategories, queueAffinity, + /*binding_capacity=*/Value{}) .getResult(); mapping->mapCommandBuffer(executeOp, commandBuffer); diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index 74d0b560fa75..68a4c5fbdf11 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -1163,6 +1163,7 @@ def HAL_CommandBufferCreateOp : HAL_Op<"command_buffer.create", [ HAL_Device:$device, HAL_CommandBufferModeBitfieldAttr:$modes, HAL_CommandCategoryBitfieldAttr:$command_categories, + HAL_DeviceQueueAffinity:$queue_affinity, Optional:$binding_capacity ); let results = (outs @@ -1173,6 +1174,7 @@ def HAL_CommandBufferCreateOp : HAL_Op<"command_buffer.create", [ `device` `(` $device `:` type($device) `)` `mode` `(` $modes `)` `categories` `(` $command_categories `)` + `affinity` `(` $queue_affinity `)` (`bindings` `(` $binding_capacity^ `)`)? 
`:` type($result) attr-dict-with-keyword diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir index a61ea4f3b687..3adbce807fb7 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_folding.mlir @@ -1,11 +1,12 @@ // RUN: iree-opt --split-input-file --canonicalize %s | iree-opt --split-input-file | FileCheck %s // CHECK-LABEL: @skip_command_buffer_device -// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device) -util.func public @skip_command_buffer_device(%device: !hal.device) -> !hal.executable { +// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[AFFINITY:.+]]: i64) +util.func public @skip_command_buffer_device(%device: !hal.device, %affinity: i64) -> !hal.executable { %cmd = hal.command_buffer.create device(%device : !hal.device) mode(OneShot) - categories("Transfer|Dispatch") : !hal.command_buffer + categories("Transfer|Dispatch") + affinity(%affinity) : !hal.command_buffer // CHECK-NOT: hal.command_buffer.device // CHECK: = hal.executable.lookup device(%[[DEVICE]] : !hal.device) diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir index dc16d454c859..c3348d969750 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir @@ -1,15 +1,17 @@ // RUN: iree-opt --split-input-file %s | FileCheck %s // CHECK-LABEL: @command_buffer_create -// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device) -util.func public @command_buffer_create(%device: !hal.device) { +// CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[AFFINITY:.+]]: i64) +util.func public @command_buffer_create(%device: !hal.device, %affinity: i64) { // CHECK: %cmd = hal.command_buffer.create // CHECK-SAME: device(%[[DEVICE]] : 
!hal.device) // CHECK-SAME: mode(OneShot) - // CHECK-SAME: categories("Transfer|Dispatch") : !hal.command_buffer + // CHECK-SAME: categories("Transfer|Dispatch") + // CHECK-SAME: affinity(%[[AFFINITY]]) : !hal.command_buffer %cmd = hal.command_buffer.create device(%device : !hal.device) mode(OneShot) - categories("Transfer|Dispatch") : !hal.command_buffer + categories("Transfer|Dispatch") + affinity(%affinity) : !hal.command_buffer util.return } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp index 588f560b469a..4ec14bfd86f4 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp @@ -256,6 +256,7 @@ static void appendDispatchBenchmark(IREE::Stream::AffinityAttr affinityAttr, // TODO(multi-device): support multiple devices in benchmark generation. // For now we should just use the affinityAttr to resolve the device. Value device = IREE::HAL::DeviceType::resolveAny(loc, funcBuilder); + Value queueAffinity = funcBuilder.create(loc, -1, 64); // Create and begin command buffer. // TODO(benvanik): reuse the command buffer (initialize once and store). @@ -267,6 +268,7 @@ static void appendDispatchBenchmark(IREE::Stream::AffinityAttr affinityAttr, .create( loc, funcBuilder.getType(), device, commandBufferModes, IREE::HAL::CommandCategoryBitfield::Dispatch, + queueAffinity, /*binding_capacity=*/Value{}) .getResult(); @@ -379,7 +381,6 @@ static void appendDispatchBenchmark(IREE::Stream::AffinityAttr affinityAttr, IREE::HAL::FenceFlagBitfield::None); // Queue execution. 
- auto queueAffinity = funcBuilder.create(loc, -1, 64); funcBuilder.create( loc, device, queueAffinity, waitFence, signalFence, ValueRange{commandBuffer}); diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir index 3bf0a7becebd..d504c5e4c87c 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/convert_to_hal.mlir @@ -109,7 +109,7 @@ util.func public @simpleDispatch(%arg0: !hal.buffer_view, %arg1: !hal.buffer_vie // CHECK: %[[CMD:.+]] = hal.command_buffer.create // CHECK-SAME: device(%[[DEVICE]] : !hal.device) // CHECK-SAME: mode("OneShot|AllowInlineExecution") - // CHECK-SAME: categories("Transfer|Dispatch") : !hal.command_buffer + // CHECK-SAME: categories("Transfer|Dispatch") %timepoint = stream.cmd.execute with(%arg0_resource as %arg0_capture: !stream.resource{%c16}, %arg1_resource as %arg1_capture: !stream.resource{%c16}, diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir index f47bcd2c112d..29de091b2df4 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/fixup_legacy_sync.mlir @@ -5,9 +5,9 @@ module attributes {hal.device.targets = [#hal.device.target<"vulkan", {legacy_sync}>]} { // CHECK-LABEL: @command_buffer_reusable -util.func public @command_buffer_reusable(%arg0: !hal.device) { - // CHECK: hal.command_buffer.create device(%arg0 : !hal.device) mode("None") - %cmd = hal.command_buffer.create device(%arg0 : !hal.device) mode("None") categories("Transfer|Dispatch") : !hal.command_buffer +util.func public @command_buffer_reusable(%device: !hal.device, %affinity: i64) { + // CHECK: hal.command_buffer.create device(%{{.+}} : !hal.device) 
mode("None") + %cmd = hal.command_buffer.create device(%device : !hal.device) mode("None") categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer util.return } } // module @@ -18,9 +18,9 @@ util.func public @command_buffer_reusable(%arg0: !hal.device) { module attributes {hal.device.targets = [#hal.device.target<"vulkan", {legacy_sync}>]} { // CHECK-LABEL: @command_buffer_oneshot -util.func public @command_buffer_oneshot(%arg0: !hal.device) { - // CHECK: hal.command_buffer.create device(%arg0 : !hal.device) mode("OneShot|AllowInlineExecution") - %cmd = hal.command_buffer.create device(%arg0 : !hal.device) mode(OneShot) categories("Transfer|Dispatch") : !hal.command_buffer +util.func public @command_buffer_oneshot(%device: !hal.device, %affinity: i64) { + // CHECK: hal.command_buffer.create device(%{{.+}} : !hal.device) mode("OneShot|AllowInlineExecution") + %cmd = hal.command_buffer.create device(%device : !hal.device) mode(OneShot) categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer util.return } } // module @@ -34,9 +34,9 @@ module attributes {hal.device.targets = [ #hal.device.target<"vulkan", {}> ]} { // CHECK-LABEL: @legacy_mode_not_required -util.func public @legacy_mode_not_required(%arg0: !hal.device) { - // CHECK: hal.command_buffer.create device(%arg0 : !hal.device) mode(OneShot) - %cmd = hal.command_buffer.create device(%arg0 : !hal.device) mode(OneShot) categories("Transfer|Dispatch") : !hal.command_buffer +util.func public @legacy_mode_not_required(%device: !hal.device, %affinity: i64) { + // CHECK: hal.command_buffer.create device(%{{.+}} : !hal.device) mode(OneShot) + %cmd = hal.command_buffer.create device(%device : !hal.device) mode(OneShot) categories("Transfer|Dispatch") affinity(%affinity) : !hal.command_buffer util.return } } // module @@ -51,7 +51,7 @@ module attributes {hal.device.targets = [ ]} { // CHECK-LABEL: @mixed_legacy_mode_required util.func public @mixed_legacy_mode_required(%device: 
!hal.device, %wait: !hal.fence, %cmd: !hal.command_buffer, %signal: !hal.fence) { - %affinity = arith.constant 0 : i64 + %affinity = arith.constant 1 : i64 // CHECK: hal.fence.await // CHECK: hal.device.queue.execute // CHECK: hal.fence.await @@ -71,7 +71,7 @@ module attributes {hal.device.targets = [#hal.device.target<"vulkan", {legacy_sy // CHECK-LABEL: @blocking_execute // CHECK-SAME: (%[[DEVICE:.+]]: !hal.device, %[[WAIT:.+]]: !hal.fence, %[[CMD:.+]]: !hal.command_buffer, %[[SIGNAL:.+]]: !hal.fence) util.func public @blocking_execute(%device: !hal.device, %wait: !hal.fence, %cmd: !hal.command_buffer, %signal: !hal.fence) { - %affinity = arith.constant 0 : i64 + %affinity = arith.constant 1 : i64 // CHECK-DAG: %[[NULL:.+]] = util.null : !hal.fence // CHECK-DAG: hal.fence.await until([%[[WAIT]]]) // CHECK-NEXT: hal.device.queue.execute<%[[DEVICE]] : !hal.device> diff --git a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir index 67319429db2f..de1ed185306a 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir @@ -199,6 +199,7 @@ vm.import private @command_buffer.create( %device : !vm.ref, %modes : i32, %command_categories : i32, + %queue_affinity : i64, %binding_capacity : i32 ) -> !vm.ref diff --git a/runtime/src/iree/modules/hal/exports.inl b/runtime/src/iree/modules/hal/exports.inl index b808785a0d2b..bd5adee0b625 100644 --- a/runtime/src/iree/modules/hal/exports.inl +++ b/runtime/src/iree/modules/hal/exports.inl @@ -49,7 +49,7 @@ EXPORT_FN("channel.split", iree_hal_module_channel_split, riii, r) EXPORT_FN("command_buffer.begin_debug_group", iree_hal_module_command_buffer_begin_debug_group, rr, v) EXPORT_FN("command_buffer.collective", iree_hal_module_command_buffer_collective, rriirIIrIII, v) EXPORT_FN("command_buffer.copy_buffer", iree_hal_module_command_buffer_copy_buffer, rrIrII, v) -EXPORT_FN("command_buffer.create", 
iree_hal_module_command_buffer_create, riii, r) +EXPORT_FN("command_buffer.create", iree_hal_module_command_buffer_create, riiIi, r) EXPORT_FN("command_buffer.dispatch", iree_hal_module_command_buffer_dispatch, rriiii, v) EXPORT_FN("command_buffer.dispatch.indirect", iree_hal_module_command_buffer_dispatch_indirect, rrirI, v) EXPORT_FN("command_buffer.end_debug_group", iree_hal_module_command_buffer_end_debug_group, r, v) diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c index 4b84671205b2..52b3fe386991 100644 --- a/runtime/src/iree/modules/hal/module.c +++ b/runtime/src/iree/modules/hal/module.c @@ -670,19 +670,21 @@ IREE_VM_ABI_EXPORT(iree_hal_module_channel_rank_and_count, // IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_create, // iree_hal_module_state_t, // - riii, r) { + riiIi, r) { iree_hal_device_t* device = NULL; IREE_RETURN_IF_ERROR(iree_hal_device_check_deref(args->r0, &device)); iree_hal_command_buffer_mode_t modes = (iree_hal_command_buffer_mode_t)args->i1; iree_hal_command_category_t command_categories = (iree_hal_command_category_t)args->i2; - iree_host_size_t binding_capacity = (iree_host_size_t)args->i3; + iree_hal_queue_affinity_t queue_affinity = + (iree_hal_queue_affinity_t)args->i3; + iree_host_size_t binding_capacity = (iree_host_size_t)args->i4; iree_hal_command_buffer_t* command_buffer = NULL; IREE_RETURN_IF_ERROR(iree_hal_command_buffer_create( - device, modes, command_categories, IREE_HAL_QUEUE_AFFINITY_ANY, - binding_capacity, &command_buffer)); + device, modes, command_categories, queue_affinity, binding_capacity, + &command_buffer)); iree_status_t status = iree_hal_command_buffer_begin(command_buffer); if (iree_status_is_ok(status)) { diff --git a/runtime/src/iree/vm/shims.c b/runtime/src/iree/vm/shims.c index 9c8b68eb213d..1d5744da868b 100644 --- a/runtime/src/iree/vm/shims.c +++ b/runtime/src/iree/vm/shims.c @@ -42,6 +42,7 @@ IREE_VM_ABI_DEFINE_SHIM(rif, v); IREE_VM_ABI_DEFINE_SHIM(riii, 
r); IREE_VM_ABI_DEFINE_SHIM(riiI, r); IREE_VM_ABI_DEFINE_SHIM(riii, v); +IREE_VM_ABI_DEFINE_SHIM(riiIi, r); IREE_VM_ABI_DEFINE_SHIM(rIiiI, r); IREE_VM_ABI_DEFINE_SHIM(riIiirII, r); IREE_VM_ABI_DEFINE_SHIM(rriirIIrIII, v); diff --git a/runtime/src/iree/vm/shims.h b/runtime/src/iree/vm/shims.h index ab2f012f461c..195daf8e66a0 100644 --- a/runtime/src/iree/vm/shims.h +++ b/runtime/src/iree/vm/shims.h @@ -313,6 +313,14 @@ IREE_VM_ABI_FIXED_STRUCT(riiii, { int32_t i4; }); +IREE_VM_ABI_FIXED_STRUCT(riiIi, { + iree_vm_ref_t r0; + int32_t i1; + int32_t i2; + int64_t i3; + int32_t i4; +}); + IREE_VM_ABI_FIXED_STRUCT(riiI, { iree_vm_ref_t r0; int32_t i1; @@ -659,6 +667,7 @@ IREE_VM_ABI_DECLARE_SHIM(rif, v); IREE_VM_ABI_DECLARE_SHIM(riii, r); IREE_VM_ABI_DECLARE_SHIM(riiI, r); IREE_VM_ABI_DECLARE_SHIM(riii, v); +IREE_VM_ABI_DECLARE_SHIM(riiIi, r); IREE_VM_ABI_DECLARE_SHIM(rIiiI, r); IREE_VM_ABI_DECLARE_SHIM(riIiirII, r); IREE_VM_ABI_DECLARE_SHIM(rriirIIrIII, v); From 77bf3d0d0627b4f6d1043bad06610d19d4b31179 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Thu, 18 Jul 2024 12:29:38 -0700 Subject: [PATCH 3/5] Updating HAL VM ABI to pass binding table slots. 
--- .../HALToVM/ConvertCommandBufferOps.cpp | 211 ++++++++++++------ .../HALToVM/test/command_buffer_ops.mlir | 174 ++++++++++++--- .../iree/compiler/Dialect/HAL/IR/HALOps.td | 12 +- .../HAL/IR/test/command_buffer_ops.mlir | 74 +++++- .../compiler/Dialect/HAL/hal.imports.mlir | 14 +- .../Dialect/VM/Conversion/ImportUtils.cpp | 47 ++-- .../Dialect/VM/Conversion/ImportUtils.h | 30 ++- runtime/src/iree/modules/hal/exports.inl | 10 +- runtime/src/iree/modules/hal/module.c | 84 +++---- runtime/src/iree/vm/shims.c | 8 +- runtime/src/iree/vm/shims.h | 44 +++- 11 files changed, 493 insertions(+), 215 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp index 8299939e0729..71be493be24e 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp @@ -16,6 +16,31 @@ namespace mlir::iree_compiler { namespace { +// Returns a slot value and a buffer ref value. +// |bufferOrSlot| is intended to be a `AnyTypeOf<[Index, HAL_BufferType]>` in +// the op definition. +static std::tuple +splitBufferSlot(Location loc, Value bufferOrSlot, OpBuilder &builder) { + if (!bufferOrSlot) { + return std::make_tuple( + builder.create(loc), + builder.create( + loc, + IREE::VM::RefType::get(builder.getType()))); + } else if (isa(bufferOrSlot.getType())) { + // Direct buffer binding; pass 0 for table slot. + return std::make_tuple(builder.create(loc), + bufferOrSlot); + } else { + // Indirect binding table reference; pass null for the buffer. + return std::make_tuple( + castToImportType(bufferOrSlot, builder.getI32Type(), builder), + builder.create( + loc, + IREE::VM::RefType::get(builder.getType()))); + } +} + // TODO(benvanik): import op handling of optional values. 
// It'd be nice if the std::optional:$binding_capacity could be emitted // as 0 when not present; today it'll be omitted entirely (as it's not in the @@ -89,12 +114,15 @@ class CommandBufferFillBufferOpConversion ConversionPatternRewriter &rewriter) const override { auto importType = importOp.getFunctionType(); + auto [targetBufferSlot, targetBuffer] = + splitBufferSlot(op.getLoc(), adaptor.getTargetBuffer(), rewriter); SmallVector callOperands = { adaptor.getCommandBuffer(), - adaptor.getTargetBuffer(), + targetBuffer, castToImportType(adaptor.getTargetOffset(), rewriter.getI64Type(), rewriter), castToImportType(adaptor.getLength(), rewriter.getI64Type(), rewriter), + targetBufferSlot, }; // Record the original pattern length then extend it to a 32 bit integer. @@ -144,12 +172,57 @@ class CommandBufferUpdateBufferOpConversion matchAndRewrite(IREE::HAL::CommandBufferUpdateBufferOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto importType = importOp.getFunctionType(); + auto [targetBufferSlot, targetBuffer] = + splitBufferSlot(op.getLoc(), adaptor.getTargetBuffer(), rewriter); SmallVector callOperands = { adaptor.getCommandBuffer(), adaptor.getSourceBuffer(), castToImportType(adaptor.getSourceOffset(), rewriter.getI64Type(), rewriter), - adaptor.getTargetBuffer(), + targetBuffer, + castToImportType(adaptor.getTargetOffset(), rewriter.getI64Type(), + rewriter), + castToImportType(adaptor.getLength(), rewriter.getI64Type(), rewriter), + targetBufferSlot}; + auto callOp = rewriter.replaceOpWithNewOp( + op, SymbolRefAttr::get(importOp), importType.getResults(), + callOperands); + copyImportAttrs(importOp, callOp); + return success(); + } + +private: + mutable IREE::VM::ImportOp importOp; +}; + +class CommandBufferCopyBufferOpConversion + : public OpConversionPattern { +public: + CommandBufferCopyBufferOpConversion(MLIRContext *context, + SymbolTable &importSymbols, + TypeConverter &typeConverter, + StringRef importName) + : 
OpConversionPattern(typeConverter, context) { + importOp = importSymbols.lookup(importName); + assert(importOp); + } + + LogicalResult + matchAndRewrite(IREE::HAL::CommandBufferCopyBufferOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto importType = importOp.getFunctionType(); + auto [sourceBufferSlot, sourceBuffer] = + splitBufferSlot(op.getLoc(), adaptor.getSourceBuffer(), rewriter); + auto [targetBufferSlot, targetBuffer] = + splitBufferSlot(op.getLoc(), adaptor.getTargetBuffer(), rewriter); + SmallVector callOperands = { + adaptor.getCommandBuffer(), + sourceBufferSlot, + targetBufferSlot, + sourceBuffer, + castToImportType(adaptor.getSourceOffset(), rewriter.getI64Type(), + rewriter), + targetBuffer, castToImportType(adaptor.getTargetOffset(), rewriter.getI64Type(), rewriter), castToImportType(adaptor.getLength(), rewriter.getI64Type(), rewriter), @@ -182,15 +255,6 @@ class CommandBufferCollectiveOpConversion ConversionPatternRewriter &rewriter) const override { auto importType = importOp.getFunctionType(); - Value nullBuffer; - auto getNullBuffer = [&]() { - if (!nullBuffer) { - nullBuffer = rewriter.create( - op.getLoc(), - IREE::VM::RefType::get(rewriter.getType())); - } - return nullBuffer; - }; Value zeroI64; auto getZeroI64 = [&]() { if (!zeroI64) { @@ -203,10 +267,12 @@ class CommandBufferCollectiveOpConversion // %channel : !vm.ref, // %op : i32, // %param : i32, + // %send_buffer_slot : i32, + // %recv_buffer_slot : i32, // %send_buffer : !vm.ref, + // %recv_buffer : !vm.ref, // %send_offset : i64, // %send_length : i64, - // %recv_buffer : !vm.ref, // %recv_offset : i64, // %recv_length : i64, // %element_count : i64 @@ -222,25 +288,22 @@ class CommandBufferCollectiveOpConversion rewriter.create(op.getLoc())); } - if (adaptor.getSendBuffer()) { - callOperands.push_back(adaptor.getSendBuffer()); - callOperands.push_back(adaptor.getSendOffset()); - callOperands.push_back(adaptor.getSendLength()); - } else { - 
callOperands.push_back(getNullBuffer()); - callOperands.push_back(getZeroI64()); - callOperands.push_back(getZeroI64()); - } - - if (adaptor.getRecvBuffer()) { - callOperands.push_back(adaptor.getRecvBuffer()); - callOperands.push_back(adaptor.getRecvOffset()); - callOperands.push_back(adaptor.getRecvLength()); - } else { - callOperands.push_back(getNullBuffer()); - callOperands.push_back(getZeroI64()); - callOperands.push_back(getZeroI64()); - } + auto [sendBufferSlot, sendBuffer] = + splitBufferSlot(op.getLoc(), adaptor.getSendBuffer(), rewriter); + auto [recvBufferSlot, recvBuffer] = + splitBufferSlot(op.getLoc(), adaptor.getRecvBuffer(), rewriter); + callOperands.push_back(sendBufferSlot); + callOperands.push_back(recvBufferSlot); + callOperands.push_back(sendBuffer); + callOperands.push_back(recvBuffer); + callOperands.push_back(adaptor.getSendOffset() ? adaptor.getSendOffset() + : getZeroI64()); + callOperands.push_back(adaptor.getSendLength() ? adaptor.getSendLength() + : getZeroI64()); + callOperands.push_back(adaptor.getRecvOffset() ? adaptor.getRecvOffset() + : getZeroI64()); + callOperands.push_back(adaptor.getRecvLength() ? adaptor.getRecvLength() + : getZeroI64()); callOperands.push_back(castToImportType(adaptor.getElementCount(), rewriter.getI64Type(), rewriter)); @@ -275,29 +338,6 @@ class CommandBufferPushDescriptorSetOpConversion ConversionPatternRewriter &rewriter) const override { auto importType = importOp.getFunctionType(); - // Memoize zeros/nulls ala IndexSet. - // Since there are usually hundreds to thousands of these push ops and each - // one can have 5-10 of these this saves us a tremendous amount of time - // creating/verifying/pattern matching/folding/CSE'ing. - // We could extend IndexSet into a ConstantSet that could use these custom - // VM ops instead of just arith.constant in order to make this more - // reusable. 
- Value zero; - auto getI32Zero = [&]() { - if (!zero) { - zero = rewriter.create(op.getLoc()); - } - return zero; - }; - Value null; - auto getNull = [&]() { - if (!null) { - null = rewriter.create( - op.getLoc(), - IREE::VM::RefType::get(rewriter.getType())); - } - return null; - }; auto i32Type = rewriter.getI32Type(); auto i64Type = rewriter.getI64Type(); @@ -316,16 +356,10 @@ class CommandBufferPushDescriptorSetOpConversion for (size_t i = 0; i < adaptor.getBindingOrdinals().size(); ++i) { callOperands.push_back( castToImportType(adaptor.getBindingOrdinals()[i], i32Type, rewriter)); - auto bindingBuffer = adaptor.getBindingBuffers()[i]; - if (llvm::isa(bindingBuffer.getType())) { - // Buffer binding; pass 0 for table slot. - callOperands.push_back(getI32Zero()); - callOperands.push_back(bindingBuffer); - } else { - // Binding table reference; pass null for the buffer. - callOperands.push_back(bindingBuffer); - callOperands.push_back(getNull()); - } + auto [bindingBufferSlot, bindingBuffer] = splitBufferSlot( + op.getLoc(), adaptor.getBindingBuffers()[i], rewriter); + callOperands.push_back(bindingBufferSlot); + callOperands.push_back(bindingBuffer); callOperands.push_back( castToImportType(adaptor.getBindingOffsets()[i], i64Type, rewriter)); callOperands.push_back( @@ -343,6 +377,46 @@ class CommandBufferPushDescriptorSetOpConversion mutable IREE::VM::ImportOp importOp; }; +class CommandBufferDispatchIndirectOpConversion + : public OpConversionPattern { +public: + CommandBufferDispatchIndirectOpConversion(MLIRContext *context, + SymbolTable &importSymbols, + TypeConverter &typeConverter, + StringRef importName) + : OpConversionPattern(typeConverter, context) { + importOp = importSymbols.lookup(importName); + assert(importOp); + } + + LogicalResult + matchAndRewrite(IREE::HAL::CommandBufferDispatchIndirectOp op, + OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto importType = importOp.getFunctionType(); + auto 
[workgroupsBufferSlot, workgroupsBuffer] = + splitBufferSlot(op.getLoc(), adaptor.getWorkgroupsBuffer(), rewriter); + SmallVector callOperands = { + adaptor.getCommandBuffer(), + adaptor.getExecutable(), + castToImportType(adaptor.getEntryPoint(), rewriter.getI32Type(), + rewriter), + workgroupsBufferSlot, + workgroupsBuffer, + castToImportType(adaptor.getWorkgroupsOffset(), rewriter.getI64Type(), + rewriter), + }; + auto callOp = rewriter.replaceOpWithNewOp( + op, SymbolRefAttr::get(importOp), importType.getResults(), + callOperands); + copyImportAttrs(importOp, callOp); + return success(); + } + +private: + mutable IREE::VM::ImportOp importOp; +}; + } // namespace void populateHALCommandBufferToVMPatterns(MLIRContext *context, @@ -370,7 +444,7 @@ void populateHALCommandBufferToVMPatterns(MLIRContext *context, patterns.insert( context, importSymbols, typeConverter, "hal.command_buffer.update_buffer"); - patterns.insert>( + patterns.insert( context, importSymbols, typeConverter, "hal.command_buffer.copy_buffer"); patterns.insert( context, importSymbols, typeConverter, "hal.command_buffer.collective"); @@ -383,10 +457,9 @@ void populateHALCommandBufferToVMPatterns(MLIRContext *context, "hal.command_buffer.push_descriptor_set"); patterns.insert>( context, importSymbols, typeConverter, "hal.command_buffer.dispatch"); - patterns - .insert>( - context, importSymbols, typeConverter, - "hal.command_buffer.dispatch.indirect"); + patterns.insert( + context, importSymbols, typeConverter, + "hal.command_buffer.dispatch.indirect"); } } // namespace mlir::iree_compiler diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir index 7402f62324c6..61ca923bb164 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir 
@@ -52,9 +52,10 @@ util.func public @command_buffer_fill_buffer_i8( ) { %c100 = arith.constant 100 : index %c200 = arith.constant 200 : index + // CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero // CHECK-DAG: %[[PATTERN_LENGTH:.+]] = vm.const.i32 1 // CHECK-DAG: %[[EXTEND:.+]] = vm.ext.i8.i32.u %arg2 : i32 -> i32 - // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %arg1, %c100, %c200, %[[EXTEND]], %[[PATTERN_LENGTH]]) : (!vm.ref, !vm.ref, i64, i64, i32, i32) -> () + // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %arg1, %c100, %c200, %[[UNUSED_SLOT]], %[[EXTEND]], %[[PATTERN_LENGTH]]) : (!vm.ref, !vm.ref, i64, i64, i32, i32, i32) -> () hal.command_buffer.fill_buffer<%arg0 : !hal.command_buffer> target(%arg1 : !hal.buffer)[%c100, %c200] pattern(%arg2 : i8) @@ -71,9 +72,10 @@ util.func public @command_buffer_fill_buffer_i16( ) { %c100 = arith.constant 100 : index %c200 = arith.constant 200 : index + // CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero // CHECK-DAG: %[[PATTERN_LENGTH:.+]] = vm.const.i32 2 // CHECK-DAG: %[[EXTEND:.+]] = vm.ext.i16.i32.u %arg2 : i32 -> i32 - // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %arg1, %c100, %c200, %[[EXTEND]], %[[PATTERN_LENGTH]]) : (!vm.ref, !vm.ref, i64, i64, i32, i32) -> () + // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %arg1, %c100, %c200, %[[UNUSED_SLOT]], %[[EXTEND]], %[[PATTERN_LENGTH]]) hal.command_buffer.fill_buffer<%arg0 : !hal.command_buffer> target(%arg1 : !hal.buffer)[%c100, %c200] pattern(%arg2 : i16) @@ -90,8 +92,9 @@ util.func public @command_buffer_fill_buffer_i32( ) { %c100 = arith.constant 100 : index %c200 = arith.constant 200 : index + // CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero // CHECK-DAG: %[[PATTERN_LENGTH:.+]] = vm.const.i32 4 - // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %arg1, %c100, %c200, %arg2, %[[PATTERN_LENGTH]]) : (!vm.ref, !vm.ref, i64, i64, i32, i32) -> () + // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %arg1, 
%c100, %c200, %[[UNUSED_SLOT]], %arg2, %[[PATTERN_LENGTH]]) hal.command_buffer.fill_buffer<%arg0 : !hal.command_buffer> target(%arg1 : !hal.buffer)[%c100, %c200] pattern(%arg2 : i32) @@ -100,6 +103,25 @@ util.func public @command_buffer_fill_buffer_i32( // ----- +// CHECK-LABEL: @command_buffer_fill_buffer_i32_indirect +util.func public @command_buffer_fill_buffer_i32_indirect( + %arg0: !hal.command_buffer, + %arg1: index, + %arg2: i32 +) { + %c100 = arith.constant 100 : index + %c200 = arith.constant 200 : index + // CHECK-DAG: %[[PATTERN_LENGTH:.+]] = vm.const.i32 4 + // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref + // CHECK: vm.call @hal.command_buffer.fill_buffer(%arg0, %[[NULL_BUFFER]], %c100, %c200, %arg1, %arg2, %[[PATTERN_LENGTH]]) + hal.command_buffer.fill_buffer<%arg0 : !hal.command_buffer> + target(%arg1 : index)[%c100, %c200] + pattern(%arg2 : i32) + util.return +} + +// ----- + // CHECK-LABEL: @command_buffer_update_buffer // CHECK-SAME: (%[[CMD:.+]]: !vm.ref, // CHECK-SAME: %[[HOST_BUFFER:[a-z0-9]+]]: !vm.buffer, %[[HOST_BUFFER_SIZE:[a-z0-9]+]]: i32, %[[SRC_OFFSET:[a-z0-9]+]]: i32, @@ -111,6 +133,7 @@ util.func public @command_buffer_update_buffer( %device_buffer: !hal.buffer, %dst_offset: index, %length: index ) { + // CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero // CHECK-DAG: %[[SRC_OFFSET_I64:.+]] = vm.ext.i32.i64.s %[[SRC_OFFSET]] // CHECK-DAG: %[[DST_OFFSET_I64:.+]] = vm.ext.i32.i64.s %[[DST_OFFSET]] // CHECK-DAG: %[[LENGTH_I64:.+]] = vm.ext.i32.i64.s %[[LENGTH]] @@ -118,7 +141,7 @@ util.func public @command_buffer_update_buffer( // CHECK-SAME: (%[[CMD]], // CHECK-SAME: %[[HOST_BUFFER]], %[[SRC_OFFSET_I64]], // CHECK-SAME: %[[DEVICE_BUFFER]], %[[DST_OFFSET_I64]], - // CHECK-SAME: %[[LENGTH_I64]]) + // CHECK-SAME: %[[LENGTH_I64]], %[[UNUSED_SLOT]]) hal.command_buffer.update_buffer<%cmd : !hal.command_buffer> source(%host_buffer : !util.buffer{%host_buffer_size})[%src_offset] target(%device_buffer : !hal.buffer)[%dst_offset] 
@@ -128,18 +151,69 @@ util.func public @command_buffer_update_buffer( // ----- +// CHECK-LABEL: @command_buffer_update_buffer_indirect +// CHECK-SAME: (%[[CMD:.+]]: !vm.ref, +// CHECK-SAME: %[[HOST_BUFFER:[a-z0-9]+]]: !vm.buffer, %[[HOST_BUFFER_SIZE:[a-z0-9]+]]: i32, %[[SRC_OFFSET:[a-z0-9]+]]: i32, +// CHECK-SAME: %[[DEVICE_BUFFER_SLOT:[a-z0-9]+]]: i32, %[[DST_OFFSET:[a-z0-9]+]]: i32, +// CHECK-SAME: %[[LENGTH:[a-z0-9]+]]: i32) +util.func public @command_buffer_update_buffer_indirect( + %cmd: !hal.command_buffer, + %host_buffer: !util.buffer, %host_buffer_size: index, %src_offset: index, + %device_buffer: index, %dst_offset: index, + %length: index + ) { + // CHECK-DAG: %[[SRC_OFFSET_I64:.+]] = vm.ext.i32.i64.s %[[SRC_OFFSET]] + // CHECK-DAG: %[[DST_OFFSET_I64:.+]] = vm.ext.i32.i64.s %[[DST_OFFSET]] + // CHECK-DAG: %[[LENGTH_I64:.+]] = vm.ext.i32.i64.s %[[LENGTH]] + // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref + // CHECK: vm.call @hal.command_buffer.update_buffer + // CHECK-SAME: (%[[CMD]], + // CHECK-SAME: %[[HOST_BUFFER]], %[[SRC_OFFSET_I64]], + // CHECK-SAME: %[[NULL_BUFFER]], %[[DST_OFFSET_I64]], + // CHECK-SAME: %[[LENGTH_I64]], %[[DEVICE_BUFFER_SLOT]]) + hal.command_buffer.update_buffer<%cmd : !hal.command_buffer> + source(%host_buffer : !util.buffer{%host_buffer_size})[%src_offset] + target(%device_buffer : index)[%dst_offset] + length(%length) + util.return +} + +// ----- + // CHECK-LABEL: @command_buffer_copy_buffer +// CHECK-SAME: (%[[CMD:.+]]: !vm.ref, %[[BUFFER:.+]]: !vm.ref) util.func public @command_buffer_copy_buffer( - %arg0: !hal.command_buffer, - %arg1: !hal.buffer + %cmd: !hal.command_buffer, + %buffer: !hal.buffer ) { %c100 = arith.constant 100 : index %c200 = arith.constant 200 : index %c300 = arith.constant 300 : index - // CHECK: vm.call @hal.command_buffer.copy_buffer(%arg0, %arg1, %c100, %arg1, %c200, %c300) : (!vm.ref, !vm.ref, i64, !vm.ref, i64, i64) -> () - hal.command_buffer.copy_buffer<%arg0 : !hal.command_buffer> - 
source(%arg1 : !hal.buffer)[%c100] - target(%arg1 : !hal.buffer)[%c200] + // CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero + // CHECK: vm.call @hal.command_buffer.copy_buffer(%[[CMD]], %[[UNUSED_SLOT]], %[[UNUSED_SLOT]], %[[BUFFER]], %c100, %[[BUFFER]], %c200, %c300) + hal.command_buffer.copy_buffer<%cmd : !hal.command_buffer> + source(%buffer : !hal.buffer)[%c100] + target(%buffer : !hal.buffer)[%c200] + length(%c300) + util.return +} + +// ----- + +// CHECK-LABEL: @command_buffer_copy_buffer_indirect +// CHECK-SAME: (%[[CMD:.+]]: !vm.ref, %[[BUFFER_SLOT:.+]]: i32) +util.func public @command_buffer_copy_buffer_indirect( + %cmd: !hal.command_buffer, + %buffer_slot: index +) { + %c100 = arith.constant 100 : index + %c200 = arith.constant 200 : index + %c300 = arith.constant 300 : index + // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref + // CHECK: vm.call @hal.command_buffer.copy_buffer(%[[CMD]], %[[BUFFER_SLOT]], %[[BUFFER_SLOT]], %[[NULL_BUFFER]], %c100, %[[NULL_BUFFER]], %c200, %c300) + hal.command_buffer.copy_buffer<%cmd : !hal.command_buffer> + source(%buffer_slot : index)[%c100] + target(%buffer_slot : index)[%c200] length(%c300) util.return } @@ -159,16 +233,17 @@ util.func public @command_buffer_collective_all_reduce_sum( %send_buffer: !hal.buffer, %recv_buffer: !hal.buffer, %count: index) { // CHECK-DAG: %[[OP_BITS:.+]] = vm.const.i32 590081 - // CHECK-DAG: %[[PARAM:.+]] = vm.const.i32.zero + // CHECK-DAG: %[[ZERO_I32:.+]] = vm.const.i32.zero %c10 = arith.constant 10 : index %c20 = arith.constant 20 : index %c128 = arith.constant 128 : index %c256 = arith.constant 256 : index // CHECK-DAG: %[[COUNT_I64:.+]] = vm.ext.i32.i64.s %[[COUNT]] // CHECK: vm.call @hal.command_buffer.collective - // CHECK-SAME: (%[[CMD]], %[[CHANNEL]], %[[OP_BITS]], %[[PARAM]] - // CHECK-SAME: %[[SEND_BUFFER]], %c10, %c128, - // CHECK-SAME: %[[RECV_BUFFER]], %c20, %c256, + // CHECK-SAME: (%[[CMD]], %[[CHANNEL]], %[[OP_BITS]], %[[ZERO_I32]] + // CHECK-SAME: 
%[[ZERO_I32]], %[[ZERO_I32]], + // CHECK-SAME: %[[SEND_BUFFER]], %[[RECV_BUFFER]], + // CHECK-SAME: %c10, %c128, %c20, %c256, // CHECK-SAME: %[[COUNT_I64]]) hal.command_buffer.collective<%cmd : !hal.command_buffer> channel(%channel : !hal.channel) @@ -193,15 +268,18 @@ util.func public @command_buffer_collective_send( %param: i32, %send_buffer: !hal.buffer, %count: index) { - // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref // CHECK-DAG: %[[OP_BITS:.+]] = vm.const.i32 262150 %c10 = arith.constant 10 : index %c128 = arith.constant 128 : index // CHECK-DAG: %[[COUNT_I64:.+]] = vm.ext.i32.i64.s %[[COUNT]] + // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref + // CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero + // CHECK-DAG: %[[ZERO_I64:.+]] = vm.const.i64.zero // CHECK: vm.call @hal.command_buffer.collective // CHECK-SAME: (%[[CMD]], %[[CHANNEL]], %[[OP_BITS]], %[[PARAM]], - // CHECK-SAME: %[[SEND_BUFFER]], %c10, %c128, - // CHECK-SAME: %[[NULL_BUFFER]], %zero, %zero, + // CHECK-SAME: %[[UNUSED_SLOT]], %[[UNUSED_SLOT]], + // CHECK-SAME: %[[SEND_BUFFER]], %[[NULL_BUFFER]], + // CHECK-SAME: %c10, %c128, %[[ZERO_I64]], %[[ZERO_I64]], // CHECK-SAME: %[[COUNT_I64]]) hal.command_buffer.collective<%cmd : !hal.command_buffer> channel(%channel : !hal.channel) @@ -215,10 +293,10 @@ util.func public @command_buffer_collective_send( // ----- // CHECK-LABEL: @command_buffer_push_descriptor_set -// CHECK-SAME: %[[CMD:.+]]: !vm.ref, -// CHECK-SAME: %[[LAYOUT:.+]]: !vm.ref, -// CHECK-SAME: %[[BUFFER:.+]]: !vm.ref, -// CHECK-SAME: %[[SLOT:.+]]: i32 +// CHECK-SAME: (%[[CMD:.+]]: !vm.ref, +// CHECK-SAME: %[[LAYOUT:.+]]: !vm.ref, +// CHECK-SAME: %[[BUFFER:.+]]: !vm.ref, +// CHECK-SAME: %[[SLOT:.+]]: i32) util.func public @command_buffer_push_descriptor_set( %cmd: !hal.command_buffer, %layout: !hal.pipeline_layout, @@ -250,18 +328,20 @@ util.func public @command_buffer_push_descriptor_set( // ----- // CHECK-LABEL: @command_buffer_dispatch +// CHECK-SAME: 
(%[[CMD:.+]]: !vm.ref, +// CHECK-SAME: %[[EXECUTABLE:.+]]: !vm.ref) util.func public @command_buffer_dispatch( - %arg0: !hal.command_buffer, - %arg1: !hal.executable + %cmd: !hal.command_buffer, + %executable: !hal.executable ) { // CHECK: %[[ORDINAL:.+]] = vm.const.i32 123 %ordinal = arith.constant 123 : index %c100 = arith.constant 100 : index %c200 = arith.constant 200 : index %c300 = arith.constant 300 : index - // CHECK: vm.call @hal.command_buffer.dispatch(%arg0, %arg1, %[[ORDINAL]], %c100, %c200, %c300) : (!vm.ref, !vm.ref, i32, i32, i32, i32) -> () - hal.command_buffer.dispatch<%arg0 : !hal.command_buffer> - target(%arg1 : !hal.executable)[%ordinal] + // CHECK: vm.call @hal.command_buffer.dispatch(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %c100, %c200, %c300) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> + target(%executable : !hal.executable)[%ordinal] workgroups([%c100, %c200, %c300]) util.return } @@ -269,17 +349,43 @@ util.func public @command_buffer_dispatch( // ----- // CHECK-LABEL: @command_buffer_dispatch_indirect +// CHECK-SAME: (%[[CMD:.+]]: !vm.ref, +// CHECK-SAME: %[[EXECUTABLE:.+]]: !vm.ref, +// CHECK-SAME: %[[BUFFER:.+]]: !vm.ref) util.func public @command_buffer_dispatch_indirect( - %arg0: !hal.command_buffer, - %arg1: !hal.executable, - %arg2: !hal.buffer + %cmd: !hal.command_buffer, + %executable: !hal.executable, + %buffer: !hal.buffer ) { - // CHECK: %[[ORDINAL:.+]] = vm.const.i32 123 + // CHECK-DAG: %[[ORDINAL:.+]] = vm.const.i32 123 + %ordinal = arith.constant 123 : index + %c100 = arith.constant 100 : index + // CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero + // CHECK: vm.call @hal.command_buffer.dispatch.indirect(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %[[UNUSED_SLOT]], %[[BUFFER]], %c100) + hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer> + target(%executable : !hal.executable)[%ordinal] + workgroups(%buffer : !hal.buffer)[%c100] + util.return +} + +// ----- + +// CHECK-LABEL: 
@command_buffer_dispatch_indirect_indirect +// CHECK-SAME: (%[[CMD:.+]]: !vm.ref, +// CHECK-SAME: %[[EXECUTABLE:.+]]: !vm.ref, +// CHECK-SAME: %[[BUFFER_SLOT:.+]]: i32) +util.func public @command_buffer_dispatch_indirect_indirect( + %cmd: !hal.command_buffer, + %executable: !hal.executable, + %buffer_slot: index +) { + // CHECK-DAG: %[[ORDINAL:.+]] = vm.const.i32 123 %ordinal = arith.constant 123 : index %c100 = arith.constant 100 : index - // CHECK: vm.call @hal.command_buffer.dispatch.indirect(%arg0, %arg1, %[[ORDINAL]], %arg2, %c100) : (!vm.ref, !vm.ref, i32, !vm.ref, i64) -> () - hal.command_buffer.dispatch.indirect<%arg0 : !hal.command_buffer> - target(%arg1 : !hal.executable)[%ordinal] - workgroups(%arg2 : !hal.buffer)[%c100] + // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref + // CHECK: vm.call @hal.command_buffer.dispatch.indirect(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %[[BUFFER_SLOT]], %[[NULL_BUFFER]], %c100) + hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer> + target(%executable : !hal.executable)[%ordinal] + workgroups(%buffer_slot : index)[%c100] util.return } diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index 68a4c5fbdf11..d0abb7104ab3 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -1291,7 +1291,7 @@ def HAL_CommandBufferFillBufferOp : HAL_Op<"command_buffer.fill_buffer"> { let arguments = (ins HAL_CommandBuffer:$command_buffer, - HAL_BufferType:$target_buffer, + AnyTypeOf<[Index, HAL_BufferType]>:$target_buffer, HAL_DeviceSize:$target_offset, HAL_DeviceSize:$length, HAL_FillPatternType:$pattern @@ -1360,9 +1360,9 @@ def HAL_CommandBufferCopyBufferOp : HAL_Op<"command_buffer.copy_buffer"> { let arguments = (ins HAL_CommandBuffer:$command_buffer, - HAL_BufferType:$source_buffer, + AnyTypeOf<[Index, HAL_BufferType]>:$source_buffer, HAL_DeviceSize:$source_offset, - 
HAL_BufferType:$target_buffer, + AnyTypeOf<[Index, HAL_BufferType]>:$target_buffer, HAL_DeviceSize:$target_offset, HAL_DeviceSize:$length ); @@ -1396,10 +1396,10 @@ def HAL_CommandBufferCollectiveOp : HAL_Op<"command_buffer.collective", [ Optional:$param, // TODO(benvanik): change this to take descriptor set + binding instead. // This would let us use indirect bindings. - Optional:$send_buffer, + Optional>:$send_buffer, Optional:$send_offset, Optional:$send_length, - Optional:$recv_buffer, + Optional>:$recv_buffer, Optional:$recv_offset, Optional:$recv_length ); @@ -1529,7 +1529,7 @@ def HAL_CommandBufferDispatchIndirectOp : HAL_Op<"command_buffer.dispatch.indire HAL_CommandBuffer:$command_buffer, HAL_Executable:$executable, HAL_Ordinal:$entry_point, - HAL_BufferType:$workgroups_buffer, + AnyTypeOf<[Index, HAL_BufferType]>:$workgroups_buffer, HAL_DeviceSize:$workgroups_offset ); diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir index c3348d969750..5598e39a1a7e 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir @@ -101,9 +101,10 @@ util.func public @command_buffer_update_buffer( // CHECK-LABEL: @command_buffer_copy_buffer // CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer, -// CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer, -// CHECK-SAME: %[[SRC_OFFSET:.+]]: index, %[[DST_OFFSET:.+]]: index, -// CHECK-SAME: %[[LENGTH:.+]]: index) +// CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer, +// CHECK-SAME: %[[SRC_OFFSET:[a-z0-9]+]]: index, +// CHECK-SAME: %[[DST_OFFSET:[a-z0-9]+]]: index, +// CHECK-SAME: %[[LENGTH:[a-z0-9]+]]: index) util.func public @command_buffer_copy_buffer( %cmd: !hal.command_buffer, %buffer: !hal.buffer, @@ -124,11 +125,38 @@ util.func public @command_buffer_copy_buffer( // ----- +// CHECK-LABEL: @command_buffer_copy_buffer_indirect +// CHECK-SAME: 
(%[[CMD:.+]]: !hal.command_buffer, +// CHECK-SAME: %[[BUFFER_SLOT:[a-z0-9]+]]: index, +// CHECK-SAME: %[[SRC_OFFSET:[a-z0-9]+]]: index, +// CHECK-SAME: %[[DST_OFFSET:[a-z0-9]+]]: index, +// CHECK-SAME: %[[LENGTH:[a-z0-9]+]]: index) +util.func public @command_buffer_copy_buffer_indirect( + %cmd: !hal.command_buffer, + %buffer_slot: index, + %src_offset: index, + %dst_offset: index, + %length: index + ) { + // CHECK: hal.command_buffer.copy_buffer<%[[CMD]] : !hal.command_buffer> + // CHECK-SAME: source(%[[BUFFER_SLOT]] : index)[%[[SRC_OFFSET]]] + // CHECK-SAME: target(%[[BUFFER_SLOT]] : index)[%[[DST_OFFSET]]] + // CHECK-SAME: length(%[[LENGTH]]) + hal.command_buffer.copy_buffer<%cmd : !hal.command_buffer> + source(%buffer_slot : index)[%src_offset] + target(%buffer_slot : index)[%dst_offset] + length(%length) + util.return +} + +// ----- + // CHECK-LABEL: @command_buffer_collective // CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer, // CHECK-SAME: %[[CHANNEL:.+]]: !hal.channel, // CHECK-SAME: %[[PARAM:.+]]: i32, -// CHECK-SAME: %[[SEND_BUFFER:.+]]: !hal.buffer, %[[RECV_BUFFER:.+]]: !hal.buffer, +// CHECK-SAME: %[[SEND_BUFFER:[a-z0-9]+]]: !hal.buffer, +// CHECK-SAME: %[[RECV_BUFFER:[a-z0-9]+]]: !hal.buffer, // CHECK-SAME: %[[COUNT:.+]]: index) util.func public @command_buffer_collective( %cmd: !hal.command_buffer, @@ -186,10 +214,10 @@ util.func public @command_buffer_collective( // ----- // CHECK-LABEL: @command_buffer_push_descriptor_set -// CHECK-SAME: %[[CMD:.+]]: !hal.command_buffer, -// CHECK-SAME: %[[LAYOUT:.+]]: !hal.pipeline_layout, -// CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer, -// CHECK-SAME: %[[SLOT:.+]]: index +// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer, +// CHECK-SAME: %[[LAYOUT:.+]]: !hal.pipeline_layout, +// CHECK-SAME: %[[BUFFER:.+]]: !hal.buffer, +// CHECK-SAME: %[[SLOT:.+]]: index) util.func public @command_buffer_push_descriptor_set( %cmd: !hal.command_buffer, %layout: !hal.pipeline_layout, @@ -273,3 +301,33 @@ util.func public 
@command_buffer_dispatch_indirect( workgroups(%buffer : !hal.buffer)[%offset] util.return } + +// ----- + +hal.executable @ex { + hal.executable.variant @backend target(<"backend", "format">) { + hal.executable.export @entry0 ordinal(0) layout(#hal.pipeline.layout, + #hal.descriptor_set.binding<1, storage_buffer> + ]> + ]>) + } +} + +// CHECK-LABEL: @command_buffer_dispatch_indirect_indirect +// CHECK-SAME: (%[[CMD:.+]]: !hal.command_buffer, +// CHECK-SAME: %[[EXECUTABLE:[a-z0-9]+]]: !hal.executable, %[[ORDINAL:[a-z0-9]+]]: index, +// CHECK-SAME: %[[BUFFER_SLOT:[a-z0-9]+]]: index, %[[OFFSET:[a-z0-9]+]]: index) +util.func public @command_buffer_dispatch_indirect_indirect( + %cmd: !hal.command_buffer, + %executable: !hal.executable, %ordinal: index, + %buffer_slot: index, %offset: index) { + // CHECK: hal.command_buffer.dispatch.indirect<%[[CMD]] : !hal.command_buffer> + // CHECK-SAME: target(%[[EXECUTABLE]] : !hal.executable)[%[[ORDINAL]] + // CHECK-SAME: workgroups(%[[BUFFER_SLOT]] : index)[%[[OFFSET]]] + hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer> + target(%executable: !hal.executable)[%ordinal] + workgroups(%buffer_slot : index)[%offset] + util.return +} diff --git a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir index de1ed185306a..0a923c93ee0d 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir @@ -231,28 +231,35 @@ vm.import private @command_buffer.execution_barrier( ) // Fills the target buffer with the given repeating value. +// NOTE: order slightly differs from op in order to get better arg alignment. vm.import private @command_buffer.fill_buffer( %command_buffer : !vm.ref, %target_buffer : !vm.ref, %target_offset : i64, %length : i64, + %target_buffer_slot : i32, %pattern : i32, %pattern_length: i32 ) // Updates a device buffer with the captured contents of a host buffer. 
+// NOTE: order slightly differs from op in order to get better arg alignment. vm.import private @command_buffer.update_buffer( %command_buffer : !vm.ref, %source_buffer : !vm.buffer, %source_offset : i64, %target_buffer : !vm.ref, %target_offset : i64, - %length : i64 + %length : i64, + %target_buffer_slot : i32 ) // Copies a range of one buffer to another. +// NOTE: order slightly differs from op in order to get better arg alignment. vm.import private @command_buffer.copy_buffer( %command_buffer : !vm.ref, + %source_buffer_slot : i32, + %target_buffer_slot : i32, %source_buffer : !vm.ref, %source_offset : i64, %target_buffer : !vm.ref, @@ -267,10 +274,12 @@ vm.import private @command_buffer.collective( %channel : !vm.ref, %op : i32, %param : i32, + %send_buffer_slot : i32, + %recv_buffer_slot : i32, %send_buffer : !vm.ref, + %recv_buffer : !vm.ref, %send_offset : i64, %send_length : i64, - %recv_buffer : !vm.ref, %recv_offset : i64, %recv_length : i64, %element_count : i64 @@ -309,6 +318,7 @@ vm.import private @command_buffer.dispatch.indirect( %command_buffer : !vm.ref, %executable : !vm.ref, %entry_point : i32, + %workgroups_buffer_slot : i32, %workgroups_buffer : !vm.ref, %workgroups_offset : i64 ) diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp index 3ba86d90e10d..8c3c502531f0 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp @@ -58,8 +58,7 @@ LogicalResult appendImportModule(StringRef importModuleSrc, return success(); } -Value castToImportType(Value value, Type targetType, - ConversionPatternRewriter &rewriter) { +Value castToImportType(Value value, Type targetType, OpBuilder &builder) { auto sourceType = value.getType(); if (sourceType == targetType) return value; @@ -70,36 +69,35 @@ Value castToImportType(Value value, Type targetType, if (llvm::isa(sourceType) && 
llvm::isa(targetType) && sourceType.getIntOrFloatBitWidth() == targetType.getIntOrFloatBitWidth()) { - return rewriter.create(value.getLoc(), targetType, - value); + return builder.create(value.getLoc(), targetType, + value); } else if (sourceIsInteger && (targetType.isSignedInteger() || targetType.isSignlessInteger())) { if (targetType.getIntOrFloatBitWidth() > sourceType.getIntOrFloatBitWidth()) { - return rewriter.create(value.getLoc(), targetType, - value); + return builder.create(value.getLoc(), targetType, + value); } else { - return rewriter.create(value.getLoc(), targetType, - value); + return builder.create(value.getLoc(), targetType, + value); } } else if (sourceIsInteger && targetType.isUnsignedInteger()) { if (targetType.getIntOrFloatBitWidth() > sourceType.getIntOrFloatBitWidth()) { - return rewriter.create(value.getLoc(), targetType, - value); + return builder.create(value.getLoc(), targetType, + value); } else { - return rewriter.create(value.getLoc(), targetType, - value); + return builder.create(value.getLoc(), targetType, + value); } } else { return value; } } -Value castFromImportType(Value value, Type targetType, - ConversionPatternRewriter &rewriter) { +Value castFromImportType(Value value, Type targetType, OpBuilder &builder) { // Right now the to-import and from-import types are the same. - return castToImportType(value, targetType, rewriter); + return castToImportType(value, targetType, builder); } void copyImportAttrs(IREE::VM::ImportOp importOp, Operation *callOp) { @@ -118,15 +116,16 @@ size_t getSegmentSpanSize(Type spanType) { } } -std::optional> -rewriteAttrToOperands(Location loc, Attribute attrValue, Type inputType, - ConversionPatternRewriter &rewriter) { +std::optional> rewriteAttrToOperands(Location loc, + Attribute attrValue, + Type inputType, + OpBuilder &builder) { if (auto intAttr = llvm::dyn_cast(attrValue)) { // NOTE: we intentionally go to std.constant ops so that the standard // conversions can do their job. 
If we want to remove the dependency // from standard ops in the future we could instead go directly to // one of the vm constant ops. - auto constValue = rewriter.createOrFold( + auto constValue = builder.create( loc, inputType, IntegerAttr::get(inputType, APInt(32, static_cast(intAttr.getInt())))); @@ -136,7 +135,7 @@ rewriteAttrToOperands(Location loc, Attribute attrValue, Type inputType, SmallVector elementValues; elementValues.reserve(elementsAttr.getNumElements()); for (auto intAttr : elementsAttr.getValues()) { - elementValues.push_back(rewriter.createOrFold( + elementValues.push_back(builder.create( loc, elementsAttr.getType().getElementType(), cast(intAttr))); } @@ -146,7 +145,7 @@ rewriteAttrToOperands(Location loc, Attribute attrValue, Type inputType, SmallVector allValues; for (auto elementAttr : arrayAttr) { auto flattenedValues = - rewriteAttrToOperands(loc, elementAttr, inputType, rewriter); + rewriteAttrToOperands(loc, elementAttr, inputType, builder); if (!flattenedValues) return std::nullopt; allValues.append(flattenedValues->begin(), flattenedValues->end()); @@ -154,7 +153,7 @@ rewriteAttrToOperands(Location loc, Attribute attrValue, Type inputType, return allValues; } if (auto strAttr = llvm::dyn_cast(attrValue)) { - return {{rewriter.create(loc, strAttr)}}; + return {{builder.create(loc, strAttr)}}; } // This may be a custom dialect type. 
As we can't trivially access the storage @@ -176,7 +175,7 @@ rewriteAttrToOperands(Location loc, Attribute attrValue, Type inputType, return; auto elementType = tupleTypes[ordinal++]; auto flattenedValues = - rewriteAttrToOperands(loc, elementAttr, elementType, rewriter); + rewriteAttrToOperands(loc, elementAttr, elementType, builder); if (!flattenedValues) { anyFailed = true; return; @@ -192,7 +191,7 @@ rewriteAttrToOperands(Location loc, Attribute attrValue, Type inputType, if (anyFailed) return; auto flattenedValues = - rewriteAttrToOperands(loc, elementAttr, inputType, rewriter); + rewriteAttrToOperands(loc, elementAttr, inputType, builder); if (!flattenedValues) { anyFailed = true; return; diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h index 3e809062b0d0..b2f0a8f74c60 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.h @@ -30,20 +30,19 @@ LogicalResult appendImportModule(StringRef importModuleSrc, namespace detail { size_t getSegmentSpanSize(Type spanType); -std::optional> -rewriteAttrToOperands(Location loc, Attribute attrValue, Type inputType, - ConversionPatternRewriter &rewriter); +std::optional> rewriteAttrToOperands(Location loc, + Attribute attrValue, + Type inputType, + OpBuilder &builder); } // namespace detail // Casts |value| to |targetType| ala static_cast for when the declared type // differs from the type provided by the input dialect. -Value castToImportType(Value value, Type targetType, - ConversionPatternRewriter &rewriter); +Value castToImportType(Value value, Type targetType, OpBuilder &builder); // Casts |value| to |targetType| ala static_cast for when the declared return // type of an import does not match the required output type. 
-Value castFromImportType(Value value, Type targetType, - ConversionPatternRewriter &rewriter); +Value castFromImportType(Value value, Type targetType, OpBuilder &builder); // Copies known attributes from the |importOp| to the |callOp|. // This allows for passes to quickly query the properties of the import such as @@ -56,8 +55,7 @@ void copyImportAttrs(IREE::VM::ImportOp importOp, Operation *callOp); template std::optional> rewriteToCall(T op, Adaptor adaptor, IREE::VM::ImportOp importOp, - const TypeConverter &typeConverter, - ConversionPatternRewriter &rewriter) { + const TypeConverter &typeConverter, OpBuilder &builder) { auto *operation = op.getOperation(); bool isOpVariadic = importOp.isVariadic(); OperationState state{ @@ -76,7 +74,7 @@ rewriteToCall(T op, Adaptor adaptor, IREE::VM::ImportOp importOp, auto inputName = importOp.getFuncArgumentName(input.index()); if (auto attrValue = op->getAttr(inputName)) { auto flattenedAttrs = detail::rewriteAttrToOperands( - op.getLoc(), attrValue, inputType, rewriter); + op.getLoc(), attrValue, inputType, builder); if (!flattenedAttrs) return std::nullopt; state.addOperands(*flattenedAttrs); @@ -101,11 +99,11 @@ rewriteToCall(T op, Adaptor adaptor, IREE::VM::ImportOp importOp, } for (auto [newOperand, inputType] : llvm::zip_equal(newOperands, inputTupleType.getTypes())) { - state.addOperands(castToImportType(newOperand, inputType, rewriter)); + state.addOperands(castToImportType(newOperand, inputType, builder)); } } else { for (auto &operand : newOperands) { - state.addOperands(castToImportType(operand, inputType, rewriter)); + state.addOperands(castToImportType(operand, inputType, builder)); } } @@ -121,16 +119,16 @@ rewriteToCall(T op, Adaptor adaptor, IREE::VM::ImportOp importOp, "segment_sizes", DenseIntElementsAttr::get( VectorType::get({static_cast(segmentSizes.size())}, - rewriter.getIntegerType(16)), + builder.getIntegerType(16)), segmentSizes)); state.addAttribute("segment_types", - 
rewriter.getArrayAttr(llvm::map_to_vector( + builder.getArrayAttr(llvm::map_to_vector( importType.getInputs(), [&](Type type) { return cast(TypeAttr::get(type)); }))); } - auto *callOp = rewriter.create(state); + auto *callOp = builder.create(state); copyImportAttrs(importOp, callOp); SmallVector results; @@ -139,7 +137,7 @@ rewriteToCall(T op, Adaptor adaptor, IREE::VM::ImportOp importOp, targetType = typeConverter.convertType(targetType); if (!targetType) return std::nullopt; - results.push_back(castFromImportType(result, targetType, rewriter)); + results.push_back(castFromImportType(result, targetType, builder)); } return results; } diff --git a/runtime/src/iree/modules/hal/exports.inl b/runtime/src/iree/modules/hal/exports.inl index bd5adee0b625..b87a8cbd09e0 100644 --- a/runtime/src/iree/modules/hal/exports.inl +++ b/runtime/src/iree/modules/hal/exports.inl @@ -47,18 +47,18 @@ EXPORT_FN("channel.rank_and_count", iree_hal_module_channel_rank_and_count, r, i EXPORT_FN("channel.split", iree_hal_module_channel_split, riii, r) EXPORT_FN("command_buffer.begin_debug_group", iree_hal_module_command_buffer_begin_debug_group, rr, v) -EXPORT_FN("command_buffer.collective", iree_hal_module_command_buffer_collective, rriirIIrIII, v) -EXPORT_FN("command_buffer.copy_buffer", iree_hal_module_command_buffer_copy_buffer, rrIrII, v) +EXPORT_FN("command_buffer.collective", iree_hal_module_command_buffer_collective, rriiiirrIIIII, v) +EXPORT_FN("command_buffer.copy_buffer", iree_hal_module_command_buffer_copy_buffer, riirIrII, v) EXPORT_FN("command_buffer.create", iree_hal_module_command_buffer_create, riiIi, r) EXPORT_FN("command_buffer.dispatch", iree_hal_module_command_buffer_dispatch, rriiii, v) -EXPORT_FN("command_buffer.dispatch.indirect", iree_hal_module_command_buffer_dispatch_indirect, rrirI, v) +EXPORT_FN("command_buffer.dispatch.indirect", iree_hal_module_command_buffer_dispatch_indirect, rriirI, v) EXPORT_FN("command_buffer.end_debug_group", 
iree_hal_module_command_buffer_end_debug_group, r, v) EXPORT_FN("command_buffer.execution_barrier", iree_hal_module_command_buffer_execution_barrier, riii, v) -EXPORT_FN("command_buffer.fill_buffer", iree_hal_module_command_buffer_fill_buffer, rrIIii, v) +EXPORT_FN("command_buffer.fill_buffer", iree_hal_module_command_buffer_fill_buffer, rrIIiii, v) EXPORT_FN("command_buffer.finalize", iree_hal_module_command_buffer_finalize, r, v) EXPORT_FN("command_buffer.push_constants", iree_hal_module_command_buffer_push_constants, rriCiD, v) EXPORT_FN("command_buffer.push_descriptor_set", iree_hal_module_command_buffer_push_descriptor_set, rriCiirIID, v) -EXPORT_FN("command_buffer.update_buffer", iree_hal_module_command_buffer_update_buffer, rrIrII, v) +EXPORT_FN("command_buffer.update_buffer", iree_hal_module_command_buffer_update_buffer, rrIrIIi, v) EXPORT_FN("descriptor_set_layout.create", iree_hal_module_descriptor_set_layout_create, riCiiiD, r) diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c index 52b3fe386991..c0db04bde782 100644 --- a/runtime/src/iree/modules/hal/module.c +++ b/runtime/src/iree/modules/hal/module.c @@ -759,72 +759,76 @@ IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_execution_barrier, // IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_fill_buffer, // iree_hal_module_state_t, // - rrIIii, v) { + rrIIiii, v) { iree_hal_command_buffer_t* command_buffer = NULL; IREE_RETURN_IF_ERROR( iree_hal_command_buffer_check_deref(args->r0, &command_buffer)); - iree_hal_buffer_t* target_buffer = NULL; - IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r1, &target_buffer)); iree_device_size_t target_offset = iree_hal_cast_device_size(args->i2); iree_device_size_t length = iree_hal_cast_device_size(args->i3); - uint32_t pattern = (uint32_t)args->i4; - uint32_t pattern_length = (uint32_t)args->i5; + uint32_t target_buffer_slot = (uint32_t)args->i4; + iree_hal_buffer_ref_t target_ref = iree_hal_make_indirect_buffer_ref( 
+ target_buffer_slot, target_offset, length); + IREE_RETURN_IF_ERROR( + iree_hal_buffer_check_deref_or_null(args->r1, &target_ref.buffer)); + uint32_t pattern = (uint32_t)args->i5; + uint32_t pattern_length = (uint32_t)args->i6; - iree_hal_buffer_ref_t target_ref = - iree_hal_make_buffer_ref(target_buffer, target_offset, length); return iree_hal_command_buffer_fill_buffer(command_buffer, target_ref, &pattern, pattern_length); } IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_update_buffer, // iree_hal_module_state_t, // - rrIrII, v) { + rrIrIIi, v) { iree_hal_command_buffer_t* command_buffer = NULL; IREE_RETURN_IF_ERROR( iree_hal_command_buffer_check_deref(args->r0, &command_buffer)); iree_vm_buffer_t* source_buffer = NULL; IREE_RETURN_IF_ERROR(iree_vm_buffer_check_deref(args->r1, &source_buffer)); iree_host_size_t source_offset = iree_hal_cast_host_size(args->i2); - iree_hal_buffer_t* target_buffer = NULL; - IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r3, &target_buffer)); iree_device_size_t target_offset = iree_hal_cast_device_size(args->i4); iree_device_size_t length = iree_hal_cast_device_size(args->i5); + uint32_t target_buffer_slot = (uint32_t)args->i6; + iree_hal_buffer_ref_t target_ref = iree_hal_make_indirect_buffer_ref( + target_buffer_slot, target_offset, length); + IREE_RETURN_IF_ERROR( + iree_hal_buffer_check_deref_or_null(args->r3, &target_ref.buffer)); iree_const_byte_span_t source_span = iree_const_byte_span_empty(); IREE_RETURN_IF_ERROR(iree_vm_buffer_map_ro( source_buffer, source_offset, (iree_host_size_t)length, 1, &source_span)); - iree_hal_buffer_ref_t target_ref = - iree_hal_make_buffer_ref(target_buffer, target_offset, length); return iree_hal_command_buffer_update_buffer(command_buffer, source_span.data, /*source_offset=*/0, target_ref); } IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_copy_buffer, // iree_hal_module_state_t, // - rrIrII, v) { + riirIrII, v) { iree_hal_command_buffer_t* command_buffer = NULL; 
IREE_RETURN_IF_ERROR( iree_hal_command_buffer_check_deref(args->r0, &command_buffer)); - iree_hal_buffer_t* source_buffer = NULL; - IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r1, &source_buffer)); - iree_device_size_t source_offset = iree_hal_cast_device_size(args->i2); - iree_hal_buffer_t* target_buffer = NULL; - IREE_RETURN_IF_ERROR(iree_hal_buffer_check_deref(args->r3, &target_buffer)); - iree_device_size_t target_offset = iree_hal_cast_device_size(args->i4); - iree_device_size_t length = iree_hal_cast_device_size(args->i5); + uint32_t source_buffer_slot = (uint32_t)args->i1; + uint32_t target_buffer_slot = (uint32_t)args->i2; + iree_device_size_t source_offset = iree_hal_cast_device_size(args->i4); + iree_device_size_t target_offset = iree_hal_cast_device_size(args->i6); + iree_device_size_t length = iree_hal_cast_device_size(args->i7); + iree_hal_buffer_ref_t source_ref = iree_hal_make_indirect_buffer_ref( + source_buffer_slot, source_offset, length); + iree_hal_buffer_ref_t target_ref = iree_hal_make_indirect_buffer_ref( + target_buffer_slot, target_offset, length); + IREE_RETURN_IF_ERROR( + iree_hal_buffer_check_deref_or_null(args->r3, &source_ref.buffer)); + IREE_RETURN_IF_ERROR( + iree_hal_buffer_check_deref_or_null(args->r5, &target_ref.buffer)); - iree_hal_buffer_ref_t source_ref = - iree_hal_make_buffer_ref(source_buffer, source_offset, length); - iree_hal_buffer_ref_t target_ref = - iree_hal_make_buffer_ref(target_buffer, target_offset, length); return iree_hal_command_buffer_copy_buffer(command_buffer, source_ref, target_ref); } IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_collective, // iree_hal_module_state_t, // - rriirIIrIII, v) { + rriiiirrIIIII, v) { iree_hal_command_buffer_t* command_buffer = NULL; IREE_RETURN_IF_ERROR( iree_hal_command_buffer_check_deref(args->r0, &command_buffer)); @@ -832,17 +836,19 @@ IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_collective, // IREE_RETURN_IF_ERROR(iree_hal_channel_check_deref(args->r1, 
&channel)); iree_hal_collective_op_t op = {.packed = args->i2}; uint32_t param = args->i3; - iree_hal_buffer_ref_t send_ref = - iree_hal_make_buffer_ref(NULL, iree_hal_cast_device_size(args->i5), - iree_hal_cast_device_size(args->i6)); + uint32_t send_buffer_slot = (uint32_t)args->i4; + uint32_t recv_buffer_slot = (uint32_t)args->i5; + iree_hal_buffer_ref_t send_ref = iree_hal_make_indirect_buffer_ref( + send_buffer_slot, iree_hal_cast_device_size(args->i8), + iree_hal_cast_device_size(args->i9)); IREE_RETURN_IF_ERROR( - iree_hal_buffer_check_deref_or_null(args->r4, &send_ref.buffer)); - iree_hal_buffer_ref_t recv_ref = - iree_hal_make_buffer_ref(NULL, iree_hal_cast_device_size(args->i8), - iree_hal_cast_device_size(args->i9)); + iree_hal_buffer_check_deref_or_null(args->r6, &send_ref.buffer)); + iree_hal_buffer_ref_t recv_ref = iree_hal_make_indirect_buffer_ref( + recv_buffer_slot, iree_hal_cast_device_size(args->i10), + iree_hal_cast_device_size(args->i11)); IREE_RETURN_IF_ERROR( iree_hal_buffer_check_deref_or_null(args->r7, &recv_ref.buffer)); - iree_device_size_t element_count = iree_hal_cast_device_size(args->i10); + iree_device_size_t element_count = iree_hal_cast_device_size(args->i12); return iree_hal_command_buffer_collective(command_buffer, channel, op, param, send_ref, recv_ref, element_count); @@ -919,20 +925,20 @@ IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_dispatch, // IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_dispatch_indirect, // iree_hal_module_state_t, // - rrirI, v) { + rriirI, v) { iree_hal_command_buffer_t* command_buffer = NULL; IREE_RETURN_IF_ERROR( iree_hal_command_buffer_check_deref(args->r0, &command_buffer)); iree_hal_executable_t* executable = NULL; IREE_RETURN_IF_ERROR(iree_hal_executable_check_deref(args->r1, &executable)); uint32_t entry_point = (uint32_t)args->i2; - iree_hal_buffer_t* workgroups_buffer = NULL; + uint32_t workgroups_buffer_slot = (uint32_t)args->i3; + iree_device_size_t workgroups_offset = 
iree_hal_cast_device_size(args->i5); + iree_hal_buffer_ref_t workgroups_ref = iree_hal_make_indirect_buffer_ref( + workgroups_buffer_slot, workgroups_offset, 3 * sizeof(uint32_t)); IREE_RETURN_IF_ERROR( - iree_hal_buffer_check_deref(args->r3, &workgroups_buffer)); - iree_device_size_t workgroups_offset = iree_hal_cast_device_size(args->i4); + iree_hal_buffer_check_deref_or_null(args->r4, &workgroups_ref.buffer)); - iree_hal_buffer_ref_t workgroups_ref = iree_hal_make_buffer_ref( - workgroups_buffer, workgroups_offset, 3 * sizeof(uint32_t)); return iree_hal_command_buffer_dispatch_indirect(command_buffer, executable, entry_point, workgroups_ref); } diff --git a/runtime/src/iree/vm/shims.c b/runtime/src/iree/vm/shims.c index 1d5744da868b..c89a8b74fb8b 100644 --- a/runtime/src/iree/vm/shims.c +++ b/runtime/src/iree/vm/shims.c @@ -45,7 +45,7 @@ IREE_VM_ABI_DEFINE_SHIM(riii, v); IREE_VM_ABI_DEFINE_SHIM(riiIi, r); IREE_VM_ABI_DEFINE_SHIM(rIiiI, r); IREE_VM_ABI_DEFINE_SHIM(riIiirII, r); -IREE_VM_ABI_DEFINE_SHIM(rriirIIrIII, v); +IREE_VM_ABI_DEFINE_SHIM(rriiiirrIIIII, v); IREE_VM_ABI_DEFINE_SHIM(rrrrCrD, r); IREE_VM_ABI_DEFINE_SHIM(ririi, v); IREE_VM_ABI_DEFINE_SHIM(rr, i); @@ -60,10 +60,12 @@ IREE_VM_ABI_DEFINE_SHIM(rriCiD, v); IREE_VM_ABI_DEFINE_SHIM(rriiCID, v); IREE_VM_ABI_DEFINE_SHIM(rriCiirIID, v); IREE_VM_ABI_DEFINE_SHIM(rriiii, v); -IREE_VM_ABI_DEFINE_SHIM(rrIIii, v); +IREE_VM_ABI_DEFINE_SHIM(rrIIiii, v); IREE_VM_ABI_DEFINE_SHIM(rrirCID, v); IREE_VM_ABI_DEFINE_SHIM(rrirI, v); -IREE_VM_ABI_DEFINE_SHIM(rrIrII, v); +IREE_VM_ABI_DEFINE_SHIM(rriirI, v); +IREE_VM_ABI_DEFINE_SHIM(rrIrIIi, v); +IREE_VM_ABI_DEFINE_SHIM(riirIrII, v); IREE_VM_ABI_DEFINE_SHIM(rrIii, v); IREE_VM_ABI_DEFINE_SHIM(rrrIii, v); IREE_VM_ABI_DEFINE_SHIM(rIrriiiI, r); diff --git a/runtime/src/iree/vm/shims.h b/runtime/src/iree/vm/shims.h index 195daf8e66a0..1d5a3aa4b9e2 100644 --- a/runtime/src/iree/vm/shims.h +++ b/runtime/src/iree/vm/shims.h @@ -355,18 +355,20 @@ IREE_VM_ABI_FIXED_STRUCT(riIiirII, { 
int64_t i7; }); -IREE_VM_ABI_FIXED_STRUCT(rriirIIrIII, { +IREE_VM_ABI_FIXED_STRUCT(rriiiirrIIIII, { iree_vm_ref_t r0; iree_vm_ref_t r1; int32_t i2; int32_t i3; - iree_vm_ref_t r4; - int64_t i5; - int64_t i6; + int32_t i4; + int32_t i5; + iree_vm_ref_t r6; iree_vm_ref_t r7; int64_t i8; int64_t i9; int64_t i10; + int64_t i11; + int64_t i12; }); IREE_VM_ABI_FIXED_STRUCT(rriiii, { @@ -378,13 +380,14 @@ IREE_VM_ABI_FIXED_STRUCT(rriiii, { int32_t i5; }); -IREE_VM_ABI_FIXED_STRUCT(rrIIii, { +IREE_VM_ABI_FIXED_STRUCT(rrIIiii, { iree_vm_ref_t r0; iree_vm_ref_t r1; int64_t i2; int64_t i3; int32_t i4; int32_t i5; + int32_t i6; }); IREE_VM_ABI_FIXED_STRUCT(rrirI, { @@ -395,13 +398,34 @@ IREE_VM_ABI_FIXED_STRUCT(rrirI, { int64_t i4; }); -IREE_VM_ABI_FIXED_STRUCT(rrIrII, { +IREE_VM_ABI_FIXED_STRUCT(rriirI, { + iree_vm_ref_t r0; + iree_vm_ref_t r1; + int32_t i2; + int32_t i3; + iree_vm_ref_t r4; + int64_t i5; +}); + +IREE_VM_ABI_FIXED_STRUCT(rrIrIIi, { iree_vm_ref_t r0; iree_vm_ref_t r1; int64_t i2; iree_vm_ref_t r3; int64_t i4; int64_t i5; + int32_t i6; +}); + +IREE_VM_ABI_FIXED_STRUCT(riirIrII, { + iree_vm_ref_t r0; + int32_t i1; + int32_t i2; + iree_vm_ref_t r3; + int64_t i4; + iree_vm_ref_t r5; + int64_t i6; + int64_t i7; }); IREE_VM_ABI_FIXED_STRUCT(rrIii, { @@ -670,7 +694,7 @@ IREE_VM_ABI_DECLARE_SHIM(riii, v); IREE_VM_ABI_DECLARE_SHIM(riiIi, r); IREE_VM_ABI_DECLARE_SHIM(rIiiI, r); IREE_VM_ABI_DECLARE_SHIM(riIiirII, r); -IREE_VM_ABI_DECLARE_SHIM(rriirIIrIII, v); +IREE_VM_ABI_DECLARE_SHIM(rriiiirrIIIII, v); IREE_VM_ABI_DECLARE_SHIM(rrrrCrD, r); IREE_VM_ABI_DECLARE_SHIM(ririi, v); IREE_VM_ABI_DECLARE_SHIM(rr, i); @@ -685,10 +709,12 @@ IREE_VM_ABI_DECLARE_SHIM(rriCiD, v); IREE_VM_ABI_DECLARE_SHIM(rriiCID, v); IREE_VM_ABI_DECLARE_SHIM(rriCiirIID, v); IREE_VM_ABI_DECLARE_SHIM(rriiii, v); -IREE_VM_ABI_DECLARE_SHIM(rrIIii, v); +IREE_VM_ABI_DECLARE_SHIM(rrIIiii, v); IREE_VM_ABI_DECLARE_SHIM(rrirCID, v); IREE_VM_ABI_DECLARE_SHIM(rrirI, v); -IREE_VM_ABI_DECLARE_SHIM(rrIrII, v); 
+IREE_VM_ABI_DECLARE_SHIM(rriirI, v); +IREE_VM_ABI_DECLARE_SHIM(rrIrIIi, v); +IREE_VM_ABI_DECLARE_SHIM(riirIrII, v); IREE_VM_ABI_DECLARE_SHIM(rrIii, v); IREE_VM_ABI_DECLARE_SHIM(rrrIii, v); IREE_VM_ABI_DECLARE_SHIM(rIrriiiI, r); From 3ea135701d1c1882b767def6964518ddaf5850da Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Wed, 24 Jul 2024 12:07:45 -0700 Subject: [PATCH 4/5] Adding `iree_hal_dispatch_flags_t` to dispatch operations. This is currently unused but may be useful for specifying scheduling behavior or something else. --- .../Codegen/SPIRV/test/link_executables.mlir | 20 +++++++++---------- .../Codegen/VMVX/test/link_executables.mlir | 12 +++++------ .../HALToVM/ConvertCommandBufferOps.cpp | 8 ++++++++ .../HALToVM/test/command_buffer_ops.mlir | 14 +++++++++---- .../HAL/Conversion/StreamToHAL/Patterns.cpp | 4 +++- .../iree/compiler/Dialect/HAL/IR/HALAttrs.td | 8 ++++++++ .../iree/compiler/Dialect/HAL/IR/HALOps.td | 8 ++++++-- .../HAL/IR/test/command_buffer_ops.mlir | 6 ++++++ .../Transforms/DumpExecutableBenchmarks.cpp | 4 +++- .../Transforms/test/repeat_dispatches.mlir | 10 +++++----- .../compiler/Dialect/HAL/hal.imports.mlir | 6 ++++-- .../Dialect/VM/Conversion/ImportUtils.cpp | 14 ++++++------- experimental/rocm/direct_command_buffer.c | 5 +++-- experimental/webgpu/command_buffer.c | 5 +++-- runtime/src/iree/hal/command_buffer.c | 13 ++++++------ runtime/src/iree/hal/command_buffer.h | 16 +++++++++++---- .../src/iree/hal/command_buffer_validation.c | 5 +++-- .../src/iree/hal/command_buffer_validation.h | 5 +++-- .../hal/cts/command_buffer_dispatch_test.h | 3 ++- .../cts/command_buffer_push_constants_test.h | 3 ++- .../hal/drivers/cuda/graph_command_buffer.c | 5 +++-- .../hal/drivers/cuda/stream_command_buffer.c | 5 +++-- .../hal/drivers/hip/graph_command_buffer.c | 5 +++-- .../hal/drivers/hip/stream_command_buffer.c | 5 +++-- .../drivers/local_task/task_command_buffer.c | 5 +++-- .../hal/drivers/metal/direct_command_buffer.m | 4 ++-- 
.../drivers/vulkan/direct_command_buffer.cc | 5 +++-- .../iree/hal/local/inline_command_buffer.c | 7 ++++--- .../iree/hal/utils/deferred_command_buffer.c | 14 +++++++++---- runtime/src/iree/modules/hal/exports.inl | 4 ++-- runtime/src/iree/modules/hal/module.c | 12 ++++++----- runtime/src/iree/vm/shims.c | 4 ++-- runtime/src/iree/vm/shims.h | 10 ++++++---- tools/iree-benchmark-executable-main.c | 2 +- 34 files changed, 162 insertions(+), 94 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir index 7d2977e517b6..bee65573c994 100644 --- a/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir +++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/link_executables.mlir @@ -94,9 +94,9 @@ func.func @basic_linking() -> () attributes { %dispatch_0_ordinal = hal.executable.export.ordinal target(@dispatch_0::@spirv::@dispatch_0) : index %dispatch_1_ordinal = hal.executable.export.ordinal target(@dispatch_1::@spirv::@dispatch_1) : index %dispatch_2_ordinal = hal.executable.export.ordinal target(@dispatch_2::@spirv::@dispatch_2) : index - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) flags(None) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) flags(None) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> 
target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) flags(None) return } util.initializer { @@ -111,9 +111,9 @@ util.initializer { %dispatch_0_ordinal = hal.executable.export.ordinal target(@dispatch_0::@spirv::@dispatch_0) : index %dispatch_1_ordinal = hal.executable.export.ordinal target(@dispatch_1::@spirv::@dispatch_1) : index %dispatch_2_ordinal = hal.executable.export.ordinal target(@dispatch_2::@spirv::@dispatch_2) : index - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) flags(None) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) flags(None) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) flags(None) util.return } @@ -304,10 +304,10 @@ func.func @two_target_environments() -> () { %dispatch_1_ordinal = hal.executable.export.ordinal target(@dispatch_1::@spirv::@dispatch_1) : index %dispatch_2_ordinal = hal.executable.export.ordinal target(@dispatch_2::@spirv::@dispatch_2) : index %dispatch_3_ordinal = hal.executable.export.ordinal target(@dispatch_3::@spirv::@dispatch_3) : index - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> 
target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_3_exe : !hal.executable)[%dispatch_3_ordinal] workgroups([%c1, %c1, %c1]) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) flags(None) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) flags(None) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) flags(None) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_3_exe : !hal.executable)[%dispatch_3_ordinal] workgroups([%c1, %c1, %c1]) flags(None) return } diff --git a/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir b/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir index af83f1ba0733..2baedf6a305b 100644 --- a/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir +++ b/compiler/src/iree/compiler/Codegen/VMVX/test/link_executables.mlir @@ -87,9 +87,9 @@ func.func @basic_linking() -> () attributes { %dispatch_0_ordinal = hal.executable.export.ordinal target(@dispatch_0::@vmvx::@dispatch_0) : index %dispatch_1_ordinal = hal.executable.export.ordinal target(@dispatch_1::@vmvx::@dispatch_1) : index %dispatch_2_ordinal = hal.executable.export.ordinal target(@dispatch_2::@vmvx::@dispatch_2) : index - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : 
!hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) flags(None) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) flags(None) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) flags(None) return } util.initializer { @@ -104,9 +104,9 @@ util.initializer { %dispatch_0_ordinal = hal.executable.export.ordinal target(@dispatch_0::@vmvx::@dispatch_0) : index %dispatch_1_ordinal = hal.executable.export.ordinal target(@dispatch_1::@vmvx::@dispatch_1) : index %dispatch_2_ordinal = hal.executable.export.ordinal target(@dispatch_2::@vmvx::@dispatch_2) : index - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_0_exe : !hal.executable)[%dispatch_0_ordinal] workgroups([%c1, %c1, %c1]) flags(None) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_1_exe : !hal.executable)[%dispatch_1_ordinal] workgroups([%c1, %c1, %c1]) flags(None) + hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%dispatch_2_exe : !hal.executable)[%dispatch_2_ordinal] workgroups([%c1, %c1, %c1]) flags(None) 
util.return } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp index 71be493be24e..cb4179f60ed2 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/ConvertCommandBufferOps.cpp @@ -396,6 +396,13 @@ class CommandBufferDispatchIndirectOpConversion auto importType = importOp.getFunctionType(); auto [workgroupsBufferSlot, workgroupsBuffer] = splitBufferSlot(op.getLoc(), adaptor.getWorkgroupsBuffer(), rewriter); + auto flags = adaptor.getFlagsAttr() + ? rewriter + .create( + op.getLoc(), adaptor.getFlagsAttr().getInt()) + .getResult() + : rewriter.create(op.getLoc()) + .getResult(); SmallVector callOperands = { adaptor.getCommandBuffer(), adaptor.getExecutable(), @@ -405,6 +412,7 @@ class CommandBufferDispatchIndirectOpConversion workgroupsBuffer, castToImportType(adaptor.getWorkgroupsOffset(), rewriter.getI64Type(), rewriter), + flags, }; auto callOp = rewriter.replaceOpWithNewOp( op, SymbolRefAttr::get(importOp), importType.getResults(), diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir index 61ca923bb164..2df69590c1c8 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/HALToVM/test/command_buffer_ops.mlir @@ -334,15 +334,17 @@ util.func public @command_buffer_dispatch( %cmd: !hal.command_buffer, %executable: !hal.executable ) { - // CHECK: %[[ORDINAL:.+]] = vm.const.i32 123 + // CHECK-DAG: %[[ORDINAL:.+]] = vm.const.i32 123 %ordinal = arith.constant 123 : index %c100 = arith.constant 100 : index %c200 = arith.constant 200 : index %c300 = arith.constant 300 : index - // CHECK: vm.call 
@hal.command_buffer.dispatch(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %c100, %c200, %c300) + // CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero + // CHECK: vm.call @hal.command_buffer.dispatch(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %c100, %c200, %c300, %[[FLAGS]]) hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%executable : !hal.executable)[%ordinal] workgroups([%c100, %c200, %c300]) + flags(None) util.return } @@ -361,10 +363,12 @@ util.func public @command_buffer_dispatch_indirect( %ordinal = arith.constant 123 : index %c100 = arith.constant 100 : index // CHECK-DAG: %[[UNUSED_SLOT:.+]] = vm.const.i32.zero - // CHECK: vm.call @hal.command_buffer.dispatch.indirect(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %[[UNUSED_SLOT]], %[[BUFFER]], %c100) + // CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero + // CHECK: vm.call @hal.command_buffer.dispatch.indirect(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %[[UNUSED_SLOT]], %[[BUFFER]], %c100, %[[FLAGS]]) hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer> target(%executable : !hal.executable)[%ordinal] workgroups(%buffer : !hal.buffer)[%c100] + flags(None) util.return } @@ -383,9 +387,11 @@ util.func public @command_buffer_dispatch_indirect_indirect( %ordinal = arith.constant 123 : index %c100 = arith.constant 100 : index // CHECK-DAG: %[[NULL_BUFFER:.+]] = vm.const.ref.zero : !vm.ref - // CHECK: vm.call @hal.command_buffer.dispatch.indirect(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %[[BUFFER_SLOT]], %[[NULL_BUFFER]], %c100) + // CHECK-DAG: %[[FLAGS:.+]] = vm.const.i64.zero + // CHECK: vm.call @hal.command_buffer.dispatch.indirect(%[[CMD]], %[[EXECUTABLE]], %[[ORDINAL]], %[[BUFFER_SLOT]], %[[NULL_BUFFER]], %c100, %[[FLAGS]]) hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer> target(%executable : !hal.executable)[%ordinal] workgroups(%buffer_slot : index)[%c100] + flags(None) util.return } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp 
b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp index 274b348f753b..74c0fc99afc8 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Conversion/StreamToHAL/Patterns.cpp @@ -731,9 +731,11 @@ struct CmdDispatchOpPattern entryPointAttr.getRootReference().getValue()); Value ordinal = caseBuilder.create( loc, caseBuilder.getIndexType(), entryPointAttr); + auto flags = caseBuilder.getAttr( + IREE::HAL::DispatchFlags::None); caseBuilder.create( loc, commandBuffer, executable, ordinal, caseWorkgroupCount[0], - caseWorkgroupCount[1], caseWorkgroupCount[2]); + caseWorkgroupCount[1], caseWorkgroupCount[2], flags); caseBuilder.create(loc); } diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td index e2d56d213729..9d85020b2302 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALAttrs.td @@ -186,6 +186,14 @@ def HAL_DescriptorSetLayoutFlagsAttr : let cppNamespace = "::mlir::iree_compiler::IREE::HAL"; } +def HAL_DispatchFlags_None : I64BitEnumAttrCase<"None", 0x0000>; +def HAL_DispatchFlagsAttr : + I64BitEnumAttr<"DispatchFlags", "valid dispatch flags", [ + HAL_DispatchFlags_None, + ]> { + let cppNamespace = "::mlir::iree_compiler::IREE::HAL"; +} + def HAL_ExecutionStage_None : I32BitEnumAttrCase<"None", 0x0000>; def HAL_ExecutionStage_CommandIssue : I32BitEnumAttrCase<"CommandIssue", 0x0001>; def HAL_ExecutionStage_CommandProcess : I32BitEnumAttrCase<"CommandProcess", 0x0002>; diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td index d0abb7104ab3..599c1ffecdc5 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/HALOps.td @@ -1502,7 +1502,8 @@ def HAL_CommandBufferDispatchOp : HAL_Op<"command_buffer.dispatch"> { 
HAL_Ordinal:$entry_point, HAL_Dim:$workgroup_x, HAL_Dim:$workgroup_y, - HAL_Dim:$workgroup_z + HAL_Dim:$workgroup_z, + HAL_DispatchFlagsAttr:$flags ); let assemblyFormat = [{ @@ -1514,6 +1515,7 @@ def HAL_CommandBufferDispatchOp : HAL_Op<"command_buffer.dispatch"> { $workgroup_y `,` $workgroup_z `]` `)` + `flags` `(` $flags `)` attr-dict-with-keyword }]; } @@ -1530,7 +1532,8 @@ def HAL_CommandBufferDispatchIndirectOp : HAL_Op<"command_buffer.dispatch.indire HAL_Executable:$executable, HAL_Ordinal:$entry_point, AnyTypeOf<[Index, HAL_BufferType]>:$workgroups_buffer, - HAL_DeviceSize:$workgroups_offset + HAL_DeviceSize:$workgroups_offset, + HAL_DispatchFlagsAttr:$flags ); let assemblyFormat = [{ @@ -1539,6 +1542,7 @@ def HAL_CommandBufferDispatchIndirectOp : HAL_Op<"command_buffer.dispatch.indire `` `[` $entry_point `]` `workgroups` `(` $workgroups_buffer `:` type($workgroups_buffer) `)` `` `[` $workgroups_offset `]` + `flags` `(` $flags `)` attr-dict-with-keyword }]; } diff --git a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir index 5598e39a1a7e..77d56e362d36 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/IR/test/command_buffer_ops.mlir @@ -266,9 +266,11 @@ util.func public @command_buffer_dispatch( // CHECK: hal.command_buffer.dispatch<%[[CMD]] : !hal.command_buffer> // CHECK-SAME: target(%[[EXECUTABLE]] : !hal.executable)[%[[ORDINAL]] // CHECK-SAME: workgroups([%[[X]], %[[Y]], %[[Z]]]) + // CHECK-SAME: flags("None") hal.command_buffer.dispatch<%cmd : !hal.command_buffer> target(%executable: !hal.executable)[%ordinal] workgroups([%x, %y, %z]) + flags("None") util.return } @@ -296,9 +298,11 @@ util.func public @command_buffer_dispatch_indirect( // CHECK: hal.command_buffer.dispatch.indirect<%[[CMD]] : !hal.command_buffer> // CHECK-SAME: target(%[[EXECUTABLE]] : !hal.executable)[%[[ORDINAL]] 
// CHECK-SAME: workgroups(%[[BUFFER]] : !hal.buffer)[%[[OFFSET]]] + // CHECK-SAME: flags("None") hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer> target(%executable: !hal.executable)[%ordinal] workgroups(%buffer : !hal.buffer)[%offset] + flags("None") util.return } @@ -326,8 +330,10 @@ util.func public @command_buffer_dispatch_indirect_indirect( // CHECK: hal.command_buffer.dispatch.indirect<%[[CMD]] : !hal.command_buffer> // CHECK-SAME: target(%[[EXECUTABLE]] : !hal.executable)[%[[ORDINAL]] // CHECK-SAME: workgroups(%[[BUFFER_SLOT]] : index)[%[[OFFSET]]] + // CHECK-SAME: flags("None") hal.command_buffer.dispatch.indirect<%cmd : !hal.command_buffer> target(%executable: !hal.executable)[%ordinal] workgroups(%buffer_slot : index)[%offset] + flags("None") util.return } diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp index 4ec14bfd86f4..c2487b84ca2b 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/DumpExecutableBenchmarks.cpp @@ -353,10 +353,12 @@ static void appendDispatchBenchmark(IREE::Stream::AffinityAttr affinityAttr, loc, indexSet.get(0), batchSizeArg, indexSet.get(1), ValueRange{}, [&](OpBuilder &forBuilder, Location loc, Value iv, ValueRange iters) { // Dispatch. + auto flags = forBuilder.getAttr( + IREE::HAL::DispatchFlags::None); forBuilder.create( loc, commandBuffer, executable, ordinal, workgroupCountOp.getWorkgroupX(), workgroupCountOp.getWorkgroupY(), - workgroupCountOp.getWorkgroupZ()); + workgroupCountOp.getWorkgroupZ(), flags); // Barrier following the dispatch to block the next dispatch. 
auto sourceStage = IREE::HAL::ExecutionStageBitfield::CommandRetire | diff --git a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/repeat_dispatches.mlir b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/repeat_dispatches.mlir index e7b80e11a3f8..a139ecedef03 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/repeat_dispatches.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/Transforms/test/repeat_dispatches.mlir @@ -13,12 +13,12 @@ util.func public @duplicate_dispatches(%cmd1 : !hal.command_buffer, %cmd2 : !hal %c1 = arith.constant 1 : index %c2 = arith.constant 2 : index %c3 = arith.constant 3 : index - hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c0] workgroups([%c1, %c1, %c1]) + hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c0] workgroups([%c1, %c1, %c1]) flags(None) hal.command_buffer.execution_barrier<%cmd1 : !hal.command_buffer> source("Dispatch|CommandRetire") target("CommandIssue|Dispatch") flags("None") - hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c1] workgroups([%c2, %c2, %c2]) + hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c1] workgroups([%c2, %c2, %c2]) flags(None) - hal.command_buffer.dispatch<%cmd2 : !hal.command_buffer> target(%exe : !hal.executable)[%c2] workgroups([%c1, %c1, %c1]) - hal.command_buffer.dispatch<%cmd2 : !hal.command_buffer> target(%exe : !hal.executable)[%c3] workgroups([%c2, %c2, %c2]) + hal.command_buffer.dispatch<%cmd2 : !hal.command_buffer> target(%exe : !hal.executable)[%c2] workgroups([%c1, %c1, %c1]) flags(None) + hal.command_buffer.dispatch<%cmd2 : !hal.command_buffer> target(%exe : !hal.executable)[%c3] workgroups([%c2, %c2, %c2]) flags(None) hal.command_buffer.execution_barrier<%cmd2 : !hal.command_buffer> source("Dispatch|CommandRetire") target("CommandIssue|Dispatch") flags("None") util.return @@ -59,7 +59,7 @@ 
util.func public @nested_dispatch(%cmd1 : !hal.command_buffer, %idx : index) { %c1 = arith.constant 1 : index scf.index_switch %idx case 0 { - hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c0] workgroups([%c1, %c1, %c1]) + hal.command_buffer.dispatch<%cmd1 : !hal.command_buffer> target(%exe : !hal.executable)[%c0] workgroups([%c1, %c1, %c1]) flags(None) scf.yield } default { diff --git a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir index 0a923c93ee0d..d68b86523bb6 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir @@ -309,7 +309,8 @@ vm.import private @command_buffer.dispatch( %entry_point : i32, %workgroup_x : i32, %workgroup_y : i32, - %workgroup_z : i32 + %workgroup_z : i32, + %flags : i64 ) // Dispatches an execution request with the dispatch parameters loaded from the @@ -320,7 +321,8 @@ vm.import private @command_buffer.dispatch.indirect( %entry_point : i32, %workgroups_buffer_slot : i32, %workgroups_buffer : !vm.ref, - %workgroups_offset : i64 + %workgroups_offset : i64, + %flags : i64 ) //===----------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp index 8c3c502531f0..fa837d41c992 100644 --- a/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp +++ b/compiler/src/iree/compiler/Dialect/VM/Conversion/ImportUtils.cpp @@ -127,11 +127,11 @@ std::optional> rewriteAttrToOperands(Location loc, // one of the vm constant ops. 
auto constValue = builder.create( loc, inputType, - IntegerAttr::get(inputType, - APInt(32, static_cast(intAttr.getInt())))); + IntegerAttr::get(inputType, APInt(inputType.getIntOrFloatBitWidth(), + intAttr.getValue().getSExtValue()))); return {{constValue}}; - } - if (auto elementsAttr = llvm::dyn_cast(attrValue)) { + } else if (auto elementsAttr = + llvm::dyn_cast(attrValue)) { SmallVector elementValues; elementValues.reserve(elementsAttr.getNumElements()); for (auto intAttr : elementsAttr.getValues()) { @@ -140,8 +140,7 @@ std::optional> rewriteAttrToOperands(Location loc, cast(intAttr))); } return elementValues; - } - if (auto arrayAttr = llvm::dyn_cast(attrValue)) { + } else if (auto arrayAttr = llvm::dyn_cast(attrValue)) { SmallVector allValues; for (auto elementAttr : arrayAttr) { auto flattenedValues = @@ -151,8 +150,7 @@ std::optional> rewriteAttrToOperands(Location loc, allValues.append(flattenedValues->begin(), flattenedValues->end()); } return allValues; - } - if (auto strAttr = llvm::dyn_cast(attrValue)) { + } else if (auto strAttr = llvm::dyn_cast(attrValue)) { return {{builder.create(loc, strAttr)}}; } diff --git a/experimental/rocm/direct_command_buffer.c b/experimental/rocm/direct_command_buffer.c index fe165e5c39de..42b476baa815 100644 --- a/experimental/rocm/direct_command_buffer.c +++ b/experimental/rocm/direct_command_buffer.c @@ -412,7 +412,8 @@ static iree_status_t iree_hal_rocm_direct_command_buffer_push_descriptor_set( static iree_status_t iree_hal_rocm_direct_command_buffer_dispatch( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { iree_hal_rocm_direct_command_buffer_t* command_buffer = iree_hal_rocm_direct_command_buffer_cast(base_command_buffer); // Lookup kernel parameters used for side-channeling additional 
launch @@ -463,7 +464,7 @@ static iree_status_t iree_hal_rocm_direct_command_buffer_dispatch( static iree_status_t iree_hal_rocm_direct_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "need rocm implementation"); } diff --git a/experimental/webgpu/command_buffer.c b/experimental/webgpu/command_buffer.c index d57dee4020b3..de89e4feabbe 100644 --- a/experimental/webgpu/command_buffer.c +++ b/experimental/webgpu/command_buffer.c @@ -884,7 +884,8 @@ static iree_status_t iree_hal_webgpu_command_buffer_prepare_dispatch( static iree_status_t iree_hal_webgpu_command_buffer_dispatch( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { iree_hal_webgpu_command_buffer_t* command_buffer = iree_hal_webgpu_command_buffer_cast(base_command_buffer); @@ -900,7 +901,7 @@ static iree_status_t iree_hal_webgpu_command_buffer_dispatch( static iree_status_t iree_hal_webgpu_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { iree_hal_webgpu_command_buffer_t* command_buffer = iree_hal_webgpu_command_buffer_cast(base_command_buffer); diff --git a/runtime/src/iree/hal/command_buffer.c b/runtime/src/iree/hal/command_buffer.c index 38619a1d42f8..7f3785ddbd8c 100644 --- a/runtime/src/iree/hal/command_buffer.c +++ b/runtime/src/iree/hal/command_buffer.c @@ -558,7 +558,8 @@ IREE_API_EXPORT iree_status_t 
iree_hal_command_buffer_push_descriptor_set( IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch( iree_hal_command_buffer_t* command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { IREE_ASSERT_ARGUMENT(command_buffer); IREE_ASSERT_ARGUMENT(executable); if ((workgroup_x | workgroup_y | workgroup_z) == 0) { @@ -574,7 +575,7 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch( IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_command_buffer_dispatch_validation( command_buffer, VALIDATION_STATE(command_buffer), executable, - entry_point, workgroup_x, workgroup_y, workgroup_z)); + entry_point, workgroup_x, workgroup_y, workgroup_z, flags)); }); #if IREE_HAL_VERBOSE_TRACING_ENABLE // TODO(benvanik): add a tracing.h helper that does the snprintf directly @@ -594,7 +595,7 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch( #endif // IREE_HAL_VERBOSE_TRACING_ENABLE iree_status_t status = _VTABLE_DISPATCH(command_buffer, dispatch)( command_buffer, executable, entry_point, workgroup_x, workgroup_y, - workgroup_z); + workgroup_z, flags); IREE_TRACE_ZONE_END(z0); return status; } @@ -602,7 +603,7 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch( IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { IREE_ASSERT_ARGUMENT(command_buffer); IREE_ASSERT_ARGUMENT(executable); IREE_TRACE_ZONE_BEGIN(z0); @@ -610,10 +611,10 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch_indirect( IREE_RETURN_AND_END_ZONE_IF_ERROR( z0, iree_hal_command_buffer_dispatch_indirect_validation( command_buffer, 
VALIDATION_STATE(command_buffer), executable, - entry_point, workgroups_ref)); + entry_point, workgroups_ref, flags)); }); iree_status_t status = _VTABLE_DISPATCH(command_buffer, dispatch_indirect)( - command_buffer, executable, entry_point, workgroups_ref); + command_buffer, executable, entry_point, workgroups_ref, flags); IREE_TRACE_ZONE_END(z0); return status; } diff --git a/runtime/src/iree/hal/command_buffer.h b/runtime/src/iree/hal/command_buffer.h index 38ac03d05f9b..c9c6037eb746 100644 --- a/runtime/src/iree/hal/command_buffer.h +++ b/runtime/src/iree/hal/command_buffer.h @@ -384,6 +384,12 @@ IREE_API_EXPORT iree_string_view_t iree_hal_collective_op_format( IREE_API_EXPORT iree_device_size_t iree_hal_collective_element_byte_count( iree_hal_collective_element_type_t element_type); +// Bitfield specifying flags controlling a dispatch operation. +enum iree_hal_dispatch_flag_bits_t { + IREE_HAL_DISPATCH_FLAG_NONE = 0, +}; +typedef uint64_t iree_hal_dispatch_flags_t; + // An RGBA color. typedef struct iree_hal_label_color_t { uint8_t r; @@ -751,7 +757,8 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_push_descriptor_set( IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch( iree_hal_command_buffer_t* command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z); + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags); // Dispatches an execution request with deferred workgroup counts. 
// This is the same as iree_hal_command_buffer_dispatch but the workgroup counts @@ -765,7 +772,7 @@ IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch( IREE_API_EXPORT iree_status_t iree_hal_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref); + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags); //===----------------------------------------------------------------------===// // Validation support @@ -922,12 +929,13 @@ typedef struct iree_hal_command_buffer_vtable_t { iree_status_t(IREE_API_PTR* dispatch)( iree_hal_command_buffer_t* command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z); + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags); iree_status_t(IREE_API_PTR* dispatch_indirect)( iree_hal_command_buffer_t* command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref); + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags); } iree_hal_command_buffer_vtable_t; IREE_HAL_ASSERT_VTABLE_LAYOUT(iree_hal_command_buffer_vtable_t); diff --git a/runtime/src/iree/hal/command_buffer_validation.c b/runtime/src/iree/hal/command_buffer_validation.c index 87f3a4bb9e67..b27433c3f35d 100644 --- a/runtime/src/iree/hal/command_buffer_validation.c +++ b/runtime/src/iree/hal/command_buffer_validation.c @@ -603,7 +603,8 @@ iree_status_t iree_hal_command_buffer_dispatch_validation( iree_hal_command_buffer_t* command_buffer, iree_hal_command_buffer_validation_state_t* validation_state, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { 
IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_DISPATCH)); IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_dispatch_bindings( @@ -615,7 +616,7 @@ iree_status_t iree_hal_command_buffer_dispatch_indirect_validation( iree_hal_command_buffer_t* command_buffer, iree_hal_command_buffer_validation_state_t* validation_state, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { IREE_RETURN_IF_ERROR(iree_hal_command_buffer_validate_categories( command_buffer, validation_state, IREE_HAL_COMMAND_CATEGORY_DISPATCH)); diff --git a/runtime/src/iree/hal/command_buffer_validation.h b/runtime/src/iree/hal/command_buffer_validation.h index 036d66658131..82ab1c5c7ad6 100644 --- a/runtime/src/iree/hal/command_buffer_validation.h +++ b/runtime/src/iree/hal/command_buffer_validation.h @@ -142,13 +142,14 @@ iree_status_t iree_hal_command_buffer_dispatch_validation( iree_hal_command_buffer_t* command_buffer, iree_hal_command_buffer_validation_state_t* validation_state, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z); + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags); iree_status_t iree_hal_command_buffer_dispatch_indirect_validation( iree_hal_command_buffer_t* command_buffer, iree_hal_command_buffer_validation_state_t* validation_state, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref); + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags); iree_status_t iree_hal_command_buffer_binding_table_validation( iree_hal_command_buffer_t* command_buffer, diff --git a/runtime/src/iree/hal/cts/command_buffer_dispatch_test.h b/runtime/src/iree/hal/cts/command_buffer_dispatch_test.h index 
4d712073fe7f..6d19793af8ec 100644 --- a/runtime/src/iree/hal/cts/command_buffer_dispatch_test.h +++ b/runtime/src/iree/hal/cts/command_buffer_dispatch_test.h @@ -154,7 +154,8 @@ TEST_P(CommandBufferDispatchTest, DispatchAbs) { IREE_ASSERT_OK(iree_hal_command_buffer_dispatch( command_buffer, executable_, /*entry_point=*/0, - /*workgroup_x=*/1, /*workgroup_y=*/1, /*workgroup_z=*/1)); + /*workgroup_x=*/1, /*workgroup_y=*/1, /*workgroup_z=*/1, + IREE_HAL_DISPATCH_FLAG_NONE)); IREE_ASSERT_OK(iree_hal_command_buffer_execution_barrier( command_buffer, /*source_stage_mask=*/IREE_HAL_EXECUTION_STAGE_DISPATCH | diff --git a/runtime/src/iree/hal/cts/command_buffer_push_constants_test.h b/runtime/src/iree/hal/cts/command_buffer_push_constants_test.h index af99ee1842db..06fa7470e378 100644 --- a/runtime/src/iree/hal/cts/command_buffer_push_constants_test.h +++ b/runtime/src/iree/hal/cts/command_buffer_push_constants_test.h @@ -120,7 +120,8 @@ TEST_F(CommandBufferPushConstantsTest, DispatchWithPushConstants) { IREE_ASSERT_OK(iree_hal_command_buffer_dispatch( command_buffer, executable_, /*entry_point=*/0, - /*workgroup_x=*/1, /*workgroup_y=*/1, /*workgroup_z=*/1)); + /*workgroup_x=*/1, /*workgroup_y=*/1, /*workgroup_z=*/1, + IREE_HAL_DISPATCH_FLAG_NONE)); IREE_ASSERT_OK(iree_hal_command_buffer_execution_barrier( command_buffer, /*source_stage_mask=*/IREE_HAL_EXECUTION_STAGE_DISPATCH | diff --git a/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c b/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c index c747c3c699b9..c53428af1ce6 100644 --- a/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c +++ b/runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c @@ -747,7 +747,8 @@ static iree_status_t iree_hal_cuda_graph_command_buffer_push_descriptor_set( static iree_status_t iree_hal_cuda_graph_command_buffer_dispatch( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t 
workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { iree_hal_cuda_graph_command_buffer_t* command_buffer = iree_hal_cuda_graph_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); @@ -873,7 +874,7 @@ static iree_status_t iree_hal_cuda_graph_command_buffer_dispatch( static iree_status_t iree_hal_cuda_graph_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "indirect dispatch not yet implemented"); } diff --git a/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c b/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c index 8f9cd2d074e4..3369f3b405cb 100644 --- a/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c +++ b/runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c @@ -533,7 +533,8 @@ static iree_status_t iree_hal_cuda_stream_command_buffer_push_descriptor_set( static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { iree_hal_cuda_stream_command_buffer_t* command_buffer = iree_hal_cuda_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); @@ -646,7 +647,7 @@ static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch( static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t 
workgroups_ref, iree_hal_dispatch_flags_t flags) { return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "need cuda implementation of dispatch indirect"); } diff --git a/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c b/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c index 65b2c4c9be28..ae66cfd2110a 100644 --- a/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c +++ b/runtime/src/iree/hal/drivers/hip/graph_command_buffer.c @@ -771,7 +771,8 @@ static iree_status_t iree_hal_hip_graph_command_buffer_push_descriptor_set( static iree_status_t iree_hal_hip_graph_command_buffer_dispatch( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { iree_hal_hip_graph_command_buffer_t* command_buffer = iree_hal_hip_graph_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); @@ -882,7 +883,7 @@ static iree_status_t iree_hal_hip_graph_command_buffer_dispatch( static iree_status_t iree_hal_hip_graph_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "indirect dispatch not yet implemented"); } diff --git a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c index a250299ca62b..0f087275524d 100644 --- a/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c +++ b/runtime/src/iree/hal/drivers/hip/stream_command_buffer.c @@ -525,7 +525,8 @@ static iree_status_t iree_hal_hip_stream_command_buffer_push_descriptor_set( static iree_status_t iree_hal_hip_stream_command_buffer_dispatch( iree_hal_command_buffer_t* 
base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { iree_hal_hip_stream_command_buffer_t* command_buffer = iree_hal_hip_stream_command_buffer_cast(base_command_buffer); IREE_TRACE_ZONE_BEGIN(z0); @@ -626,7 +627,7 @@ static iree_status_t iree_hal_hip_stream_command_buffer_dispatch( static iree_status_t iree_hal_hip_stream_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { return iree_make_status(IREE_STATUS_UNIMPLEMENTED, "need hip implementation of dispatch indirect"); } diff --git a/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c b/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c index 95c503e6df0c..50d56c8d010e 100644 --- a/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c +++ b/runtime/src/iree/hal/drivers/local_task/task_command_buffer.c @@ -973,7 +973,8 @@ static iree_status_t iree_hal_task_command_buffer_build_dispatch( static iree_status_t iree_hal_task_command_buffer_dispatch( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { iree_hal_task_command_buffer_t* command_buffer = iree_hal_task_command_buffer_cast(base_command_buffer); IREE_RETURN_IF_ERROR(iree_hal_resource_set_insert( @@ -987,7 +988,7 @@ static iree_status_t iree_hal_task_command_buffer_dispatch( static iree_status_t iree_hal_task_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* 
executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { iree_hal_task_command_buffer_t* command_buffer = iree_hal_task_command_buffer_cast(base_command_buffer); diff --git a/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m b/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m index 50d01e7fc177..eaed4f539652 100644 --- a/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m +++ b/runtime/src/iree/hal/drivers/metal/direct_command_buffer.m @@ -1073,7 +1073,7 @@ static iree_status_t iree_hal_metal_command_segment_record_dispatch( static iree_status_t iree_hal_metal_command_buffer_prepare_dispatch( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, uint32_t workgroup_count_x, uint32_t workgroup_count_y, - uint32_t workgroup_count_z) { + uint32_t workgroup_count_z, iree_hal_dispatch_flags_t flags) { IREE_TRACE_ZONE_BEGIN(z0); iree_hal_metal_dispatch_segment_t* segment = NULL; @@ -1090,7 +1090,7 @@ static iree_status_t iree_hal_metal_command_buffer_prepare_dispatch( static iree_status_t iree_hal_metal_command_buffer_prepare_dispatch_indirect( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, - int32_t entry_point, iree_hal_buffer_ref_t workgroups_ref) { + int32_t entry_point, iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { IREE_TRACE_ZONE_BEGIN(z0); iree_hal_metal_dispatch_segment_t* segment = NULL; diff --git a/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc b/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc index b66be80befcb..8bb94139f3fa 100644 --- a/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc +++ b/runtime/src/iree/hal/drivers/vulkan/direct_command_buffer.cc @@ -725,7 +725,8 @@ static iree_status_t iree_hal_vulkan_direct_command_buffer_push_descriptor_set( static iree_status_t 
iree_hal_vulkan_direct_command_buffer_dispatch( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { iree_hal_vulkan_direct_command_buffer_t* command_buffer = iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); @@ -764,7 +765,7 @@ static iree_status_t iree_hal_vulkan_direct_command_buffer_dispatch( static iree_status_t iree_hal_vulkan_direct_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { iree_hal_vulkan_direct_command_buffer_t* command_buffer = iree_hal_vulkan_direct_command_buffer_cast(base_command_buffer); diff --git a/runtime/src/iree/hal/local/inline_command_buffer.c b/runtime/src/iree/hal/local/inline_command_buffer.c index f69f9c2973dc..3de7c601e312 100644 --- a/runtime/src/iree/hal/local/inline_command_buffer.c +++ b/runtime/src/iree/hal/local/inline_command_buffer.c @@ -442,7 +442,8 @@ static iree_status_t iree_hal_inline_command_buffer_push_descriptor_set( static iree_status_t iree_hal_inline_command_buffer_dispatch( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { iree_hal_inline_command_buffer_t* command_buffer = iree_hal_inline_command_buffer_cast(base_command_buffer); @@ -559,7 +560,7 @@ typedef union iree_hal_vec3_t { static iree_status_t iree_hal_inline_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t 
entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { // TODO(benvanik): track mapping so we can properly map/unmap/flush/etc. iree_hal_buffer_mapping_t buffer_mapping = {{0}}; IREE_RETURN_IF_ERROR(iree_hal_buffer_map_range( @@ -570,7 +571,7 @@ static iree_status_t iree_hal_inline_command_buffer_dispatch_indirect( *(const iree_hal_vec3_t*)buffer_mapping.contents.data; return iree_hal_inline_command_buffer_dispatch( base_command_buffer, executable, entry_point, workgroup_count.x, - workgroup_count.y, workgroup_count.z); + workgroup_count.y, workgroup_count.z, flags); } //===----------------------------------------------------------------------===// diff --git a/runtime/src/iree/hal/utils/deferred_command_buffer.c b/runtime/src/iree/hal/utils/deferred_command_buffer.c index a4b805a38ad0..49ec3347a584 100644 --- a/runtime/src/iree/hal/utils/deferred_command_buffer.c +++ b/runtime/src/iree/hal/utils/deferred_command_buffer.c @@ -771,12 +771,14 @@ typedef struct iree_hal_cmd_dispatch_t { uint32_t workgroup_x; uint32_t workgroup_y; uint32_t workgroup_z; + iree_hal_dispatch_flags_t flags; } iree_hal_cmd_dispatch_t; static iree_status_t iree_hal_deferred_command_buffer_dispatch( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z) { + uint32_t workgroup_x, uint32_t workgroup_y, uint32_t workgroup_z, + iree_hal_dispatch_flags_t flags) { iree_hal_deferred_command_buffer_t* command_buffer = iree_hal_deferred_command_buffer_cast(base_command_buffer); iree_hal_cmd_list_t* cmd_list = &command_buffer->cmd_list; @@ -790,6 +792,7 @@ static iree_status_t iree_hal_deferred_command_buffer_dispatch( cmd->workgroup_x = workgroup_x; cmd->workgroup_y = workgroup_y; cmd->workgroup_z = workgroup_z; + cmd->flags = flags; return iree_ok_status(); } @@ -799,7 +802,7 @@ static iree_status_t 
iree_hal_deferred_command_buffer_apply_dispatch( const iree_hal_cmd_dispatch_t* cmd) { return iree_hal_command_buffer_dispatch( target_command_buffer, cmd->executable, cmd->entry_point, - cmd->workgroup_x, cmd->workgroup_y, cmd->workgroup_z); + cmd->workgroup_x, cmd->workgroup_y, cmd->workgroup_z, cmd->flags); } //===----------------------------------------------------------------------===// @@ -811,12 +814,13 @@ typedef struct iree_hal_cmd_dispatch_indirect_t { iree_hal_executable_t* executable; int32_t entry_point; iree_hal_buffer_ref_t workgroups_ref; + iree_hal_dispatch_flags_t flags; } iree_hal_cmd_dispatch_indirect_t; static iree_status_t iree_hal_deferred_command_buffer_dispatch_indirect( iree_hal_command_buffer_t* base_command_buffer, iree_hal_executable_t* executable, int32_t entry_point, - iree_hal_buffer_ref_t workgroups_ref) { + iree_hal_buffer_ref_t workgroups_ref, iree_hal_dispatch_flags_t flags) { iree_hal_deferred_command_buffer_t* command_buffer = iree_hal_deferred_command_buffer_cast(base_command_buffer); iree_hal_cmd_list_t* cmd_list = &command_buffer->cmd_list; @@ -834,6 +838,7 @@ static iree_status_t iree_hal_deferred_command_buffer_dispatch_indirect( cmd->executable = executable; cmd->entry_point = entry_point; cmd->workgroups_ref = workgroups_ref; + cmd->flags = flags; return iree_ok_status(); } @@ -845,7 +850,8 @@ static iree_status_t iree_hal_deferred_command_buffer_apply_dispatch_indirect( IREE_RETURN_IF_ERROR(iree_hal_buffer_binding_table_resolve_ref( binding_table, cmd->workgroups_ref, &workgroups_ref)); return iree_hal_command_buffer_dispatch_indirect( - target_command_buffer, cmd->executable, cmd->entry_point, workgroups_ref); + target_command_buffer, cmd->executable, cmd->entry_point, workgroups_ref, + cmd->flags); } //===----------------------------------------------------------------------===// diff --git a/runtime/src/iree/modules/hal/exports.inl b/runtime/src/iree/modules/hal/exports.inl index b87a8cbd09e0..f6f96f2b810d 100644 --- 
a/runtime/src/iree/modules/hal/exports.inl +++ b/runtime/src/iree/modules/hal/exports.inl @@ -50,8 +50,8 @@ EXPORT_FN("command_buffer.begin_debug_group", iree_hal_module_command_buffer_beg EXPORT_FN("command_buffer.collective", iree_hal_module_command_buffer_collective, rriiiirrIIIII, v) EXPORT_FN("command_buffer.copy_buffer", iree_hal_module_command_buffer_copy_buffer, riirIrII, v) EXPORT_FN("command_buffer.create", iree_hal_module_command_buffer_create, riiIi, r) -EXPORT_FN("command_buffer.dispatch", iree_hal_module_command_buffer_dispatch, rriiii, v) -EXPORT_FN("command_buffer.dispatch.indirect", iree_hal_module_command_buffer_dispatch_indirect, rriirI, v) +EXPORT_FN("command_buffer.dispatch", iree_hal_module_command_buffer_dispatch, rriiiiI, v) +EXPORT_FN("command_buffer.dispatch.indirect", iree_hal_module_command_buffer_dispatch_indirect, rriirII, v) EXPORT_FN("command_buffer.end_debug_group", iree_hal_module_command_buffer_end_debug_group, r, v) EXPORT_FN("command_buffer.execution_barrier", iree_hal_module_command_buffer_execution_barrier, riii, v) EXPORT_FN("command_buffer.fill_buffer", iree_hal_module_command_buffer_fill_buffer, rrIIiii, v) diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c index c0db04bde782..777bcb3fbfbe 100644 --- a/runtime/src/iree/modules/hal/module.c +++ b/runtime/src/iree/modules/hal/module.c @@ -907,7 +907,7 @@ IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_push_descriptor_set, // IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_dispatch, // iree_hal_module_state_t, // - rriiii, v) { + rriiiiI, v) { iree_hal_command_buffer_t* command_buffer = NULL; IREE_RETURN_IF_ERROR( iree_hal_command_buffer_check_deref(args->r0, &command_buffer)); @@ -917,15 +917,16 @@ IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_dispatch, // uint32_t workgroup_x = (uint32_t)args->i3; uint32_t workgroup_y = (uint32_t)args->i4; uint32_t workgroup_z = (uint32_t)args->i5; + iree_hal_dispatch_flags_t flags = 
(iree_hal_dispatch_flags_t)args->i6; return iree_hal_command_buffer_dispatch(command_buffer, executable, entry_point, workgroup_x, workgroup_y, - workgroup_z); + workgroup_z, flags); } IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_dispatch_indirect, // iree_hal_module_state_t, // - rriirI, v) { + rriirII, v) { iree_hal_command_buffer_t* command_buffer = NULL; IREE_RETURN_IF_ERROR( iree_hal_command_buffer_check_deref(args->r0, &command_buffer)); @@ -938,9 +939,10 @@ IREE_VM_ABI_EXPORT(iree_hal_module_command_buffer_dispatch_indirect, // workgroups_buffer_slot, workgroups_offset, 3 * sizeof(uint32_t)); IREE_RETURN_IF_ERROR( iree_hal_buffer_check_deref_or_null(args->r4, &workgroups_ref.buffer)); + iree_hal_dispatch_flags_t flags = (iree_hal_dispatch_flags_t)args->i6; - return iree_hal_command_buffer_dispatch_indirect(command_buffer, executable, - entry_point, workgroups_ref); + return iree_hal_command_buffer_dispatch_indirect( + command_buffer, executable, entry_point, workgroups_ref, flags); } //===----------------------------------------------------------------------===// diff --git a/runtime/src/iree/vm/shims.c b/runtime/src/iree/vm/shims.c index c89a8b74fb8b..5bd69a7dd5d4 100644 --- a/runtime/src/iree/vm/shims.c +++ b/runtime/src/iree/vm/shims.c @@ -59,11 +59,11 @@ IREE_VM_ABI_DEFINE_SHIM(rrCrIID, v); IREE_VM_ABI_DEFINE_SHIM(rriCiD, v); IREE_VM_ABI_DEFINE_SHIM(rriiCID, v); IREE_VM_ABI_DEFINE_SHIM(rriCiirIID, v); -IREE_VM_ABI_DEFINE_SHIM(rriiii, v); +IREE_VM_ABI_DEFINE_SHIM(rriiiiI, v); IREE_VM_ABI_DEFINE_SHIM(rrIIiii, v); IREE_VM_ABI_DEFINE_SHIM(rrirCID, v); IREE_VM_ABI_DEFINE_SHIM(rrirI, v); -IREE_VM_ABI_DEFINE_SHIM(rriirI, v); +IREE_VM_ABI_DEFINE_SHIM(rriirII, v); IREE_VM_ABI_DEFINE_SHIM(rrIrIIi, v); IREE_VM_ABI_DEFINE_SHIM(riirIrII, v); IREE_VM_ABI_DEFINE_SHIM(rrIii, v); diff --git a/runtime/src/iree/vm/shims.h b/runtime/src/iree/vm/shims.h index 1d5a3aa4b9e2..b47428ced911 100644 --- a/runtime/src/iree/vm/shims.h +++ b/runtime/src/iree/vm/shims.h @@ -371,13 
+371,14 @@ IREE_VM_ABI_FIXED_STRUCT(rriiiirrIIIII, { int64_t i12; }); -IREE_VM_ABI_FIXED_STRUCT(rriiii, { +IREE_VM_ABI_FIXED_STRUCT(rriiiiI, { iree_vm_ref_t r0; iree_vm_ref_t r1; int32_t i2; int32_t i3; int32_t i4; int32_t i5; + int64_t i6; }); IREE_VM_ABI_FIXED_STRUCT(rrIIiii, { @@ -398,13 +399,14 @@ IREE_VM_ABI_FIXED_STRUCT(rrirI, { int64_t i4; }); -IREE_VM_ABI_FIXED_STRUCT(rriirI, { +IREE_VM_ABI_FIXED_STRUCT(rriirII, { iree_vm_ref_t r0; iree_vm_ref_t r1; int32_t i2; int32_t i3; iree_vm_ref_t r4; int64_t i5; + int64_t i6; }); IREE_VM_ABI_FIXED_STRUCT(rrIrIIi, { @@ -708,11 +710,11 @@ IREE_VM_ABI_DECLARE_SHIM(rrCrIID, v); IREE_VM_ABI_DECLARE_SHIM(rriCiD, v); IREE_VM_ABI_DECLARE_SHIM(rriiCID, v); IREE_VM_ABI_DECLARE_SHIM(rriCiirIID, v); -IREE_VM_ABI_DECLARE_SHIM(rriiii, v); +IREE_VM_ABI_DECLARE_SHIM(rriiiiI, v); IREE_VM_ABI_DECLARE_SHIM(rrIIiii, v); IREE_VM_ABI_DECLARE_SHIM(rrirCID, v); IREE_VM_ABI_DECLARE_SHIM(rrirI, v); -IREE_VM_ABI_DECLARE_SHIM(rriirI, v); +IREE_VM_ABI_DECLARE_SHIM(rriirII, v); IREE_VM_ABI_DECLARE_SHIM(rrIrIIi, v); IREE_VM_ABI_DECLARE_SHIM(riirIrII, v); IREE_VM_ABI_DECLARE_SHIM(rrIii, v); diff --git a/tools/iree-benchmark-executable-main.c b/tools/iree-benchmark-executable-main.c index f3a0b1449d33..c603cb852949 100644 --- a/tools/iree-benchmark-executable-main.c +++ b/tools/iree-benchmark-executable-main.c @@ -277,7 +277,7 @@ static iree_status_t iree_benchmark_executable_run( IREE_RETURN_IF_ERROR(iree_hal_command_buffer_dispatch( command_buffer, args->executable, FLAG_entry_point, args->workgroup_count[0], args->workgroup_count[1], - args->workgroup_count[2])); + args->workgroup_count[2], IREE_HAL_DISPATCH_FLAG_NONE)); IREE_RETURN_IF_ERROR(iree_hal_command_buffer_execution_barrier( command_buffer, IREE_HAL_EXECUTION_STAGE_COMMAND_RETIRE, IREE_HAL_EXECUTION_STAGE_COMMAND_ISSUE, From ad9634e36b2e7fff349bf95db051c69691dab678 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Thu, 18 Jul 2024 12:36:25 -0700 Subject: [PATCH 5/5] Bumping HAL module 
version to 0.3. This is a breaking change between compiler and runtime and versions will not be compatible. Compiled artifacts using the HAL must be recompiled. --- .../iree/compiler/Dialect/HAL/hal.imports.mlir | 15 +++++---------- runtime/src/iree/modules/hal/module.c | 4 ++-- 2 files changed, 7 insertions(+), 12 deletions(-) diff --git a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir index d68b86523bb6..66f8dd7af602 100644 --- a/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir +++ b/compiler/src/iree/compiler/Dialect/HAL/hal.imports.mlir @@ -33,7 +33,6 @@ vm.import private @allocator.allocate( %buffer_usage : i32, %allocation_size : i64 ) -> !vm.ref -attributes {minimum_version = 1 : i32} // Imports a host byte buffer into a device visible buffer. // If try!=0 then returns null if the given memory type cannot be mapped. @@ -48,7 +47,6 @@ vm.import private @allocator.import( %offset : i64, %length : i64 ) -> !vm.ref -attributes {minimum_version = 1 : i32} //===----------------------------------------------------------------------===// // iree_hal_buffer_t @@ -202,6 +200,9 @@ vm.import private @command_buffer.create( %queue_affinity : i64, %binding_capacity : i32 ) -> !vm.ref +attributes { + minimum_version = 3 : i32 // command buffer API version +} // Finalizes recording into the command buffer and prepares it for submission. // No more commands can be recorded afterward. 
@@ -448,16 +449,10 @@ vm.import private @device.queue.flush( //===----------------------------------------------------------------------===// vm.import private @devices.count() -> i32 -attributes { - minimum_version = 2 : i32, - nosideeffects -} +attributes {nosideeffects} vm.import private @devices.get(%index : i32) -> !vm.ref -attributes { - minimum_version = 2 : i32, - nosideeffects -} +attributes {nosideeffects} //===----------------------------------------------------------------------===// // iree_hal_executable_t diff --git a/runtime/src/iree/modules/hal/module.c b/runtime/src/iree/modules/hal/module.c index 777bcb3fbfbe..fad75d067d66 100644 --- a/runtime/src/iree/modules/hal/module.c +++ b/runtime/src/iree/modules/hal/module.c @@ -32,8 +32,8 @@ // Module type definitions //===----------------------------------------------------------------------===// -#define IREE_HAL_MODULE_VERSION_0_2 0x00000002u -#define IREE_HAL_MODULE_VERSION_LATEST IREE_HAL_MODULE_VERSION_0_2 +#define IREE_HAL_MODULE_VERSION_0_3 0x00000003u +#define IREE_HAL_MODULE_VERSION_LATEST IREE_HAL_MODULE_VERSION_0_3 typedef struct iree_hal_module_t { iree_allocator_t host_allocator;