Skip to content

Commit

Permalink
[hip][cuda] Update event allocation and collection. (iree-org#17603)
Browse files Browse the repository at this point in the history
The existing system was not sufficient for graphs, as they can be run
out of order and have different behavior for event recording.

This does not entirely solve the problem for re-use, if we ever want to
simultaneously submit more than one graph at a time, but is much closer.

---------

Signed-off-by: Andrew Woloszyn <andrew.woloszyn@gmail.com>
  • Loading branch information
AWoloszyn authored Jul 15, 2024
1 parent 7dafb0e commit c322d28
Show file tree
Hide file tree
Showing 18 changed files with 823 additions and 298 deletions.
35 changes: 27 additions & 8 deletions runtime/src/iree/hal/drivers/cuda/graph_command_buffer.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ typedef struct iree_hal_cuda_graph_command_buffer_t {

// Per-stream CUDA tracing context.
iree_hal_cuda_tracing_context_t* tracing_context;
iree_hal_cuda_tracing_context_event_list_t tracing_event_list;

// A resource set to maintain references to all resources used within the
// command buffer.
Expand Down Expand Up @@ -96,10 +97,11 @@ static void iree_cuda_graph_command_buffer_trace_zone_begin_external(
&command_buffer->cu_graph_nodes[command_buffer->graph_node_count++];
size_t dependency_count = command_buffer->cu_barrier_node ? 1 : 0;
IREE_CUDA_GRAPH_TRACE_ZONE_BEGIN_EXTERNAL(
command_buffer->tracing_context, tracing_event_node,
command_buffer->cu_graph, &command_buffer->cu_barrier_node,
dependency_count, file_name, file_name_length, line, function_name,
function_name_length, name, name_length);
command_buffer->tracing_context, &command_buffer->tracing_event_list,
tracing_event_node, command_buffer->cu_graph,
&command_buffer->cu_barrier_node, dependency_count, file_name,
file_name_length, line, function_name, function_name_length, name,
name_length);

// Move the barrier forward to make sure that the tracing event is recorded
// before work starts.
Expand All @@ -121,10 +123,10 @@ static void iree_cuda_graph_command_buffer_trace_zone_end(
size_t dependency_count = command_buffer->cu_barrier_node ? 1 : 0;
IREE_ASSERT_GT(dependency_count, 0,
"ending a zone should at least depend on the beginning");
IREE_CUDA_GRAPH_TRACE_ZONE_END(command_buffer->tracing_context,
tracing_event_node, command_buffer->cu_graph,
&command_buffer->cu_barrier_node,
dependency_count);
IREE_CUDA_GRAPH_TRACE_ZONE_END(
command_buffer->tracing_context, &command_buffer->tracing_event_list,
tracing_event_node, command_buffer->cu_graph,
&command_buffer->cu_barrier_node, dependency_count);

// We need to wait on the tracing end before other work starts.
// GPU tracing zones are first-in, last-out.
Expand Down Expand Up @@ -191,6 +193,8 @@ iree_status_t iree_hal_cuda_graph_command_buffer_create(
command_buffer->host_allocator = host_allocator;
command_buffer->symbols = cuda_symbols;
command_buffer->tracing_context = tracing_context;
command_buffer->tracing_event_list.head = NULL;
command_buffer->tracing_event_list.tail = NULL;
iree_arena_initialize(block_pool, &command_buffer->arena);
command_buffer->cu_context = context;
command_buffer->cu_graph = NULL;
Expand Down Expand Up @@ -224,6 +228,9 @@ static void iree_hal_cuda_graph_command_buffer_destroy(
iree_allocator_t host_allocator = command_buffer->host_allocator;
IREE_TRACE_ZONE_BEGIN(z0);

iree_hal_cuda_tracing_free(command_buffer->tracing_context,
&command_buffer->tracing_event_list);

// Drop any pending collective batches before we tear things down.
iree_hal_collective_batch_clear(&command_buffer->collective_batch);

Expand Down Expand Up @@ -261,6 +268,18 @@ CUgraphExec iree_hal_cuda_graph_command_buffer_handle(
return command_buffer->cu_graph_exec;
}

void iree_hal_cuda_graph_tracing_notify_submitted_commands(
iree_hal_command_buffer_t* base_command_buffer) {
iree_hal_cuda_graph_command_buffer_t* command_buffer =
iree_hal_cuda_graph_command_buffer_cast(base_command_buffer);
if (!command_buffer->tracing_context) {
return;
}

iree_hal_cuda_tracing_notify_submitted(command_buffer->tracing_context,
&command_buffer->tracing_event_list);
}

// Flushes any pending batched collective operations.
// Must be called before any other non-collective nodes are added to the graph
// or a barrier is encountered.
Expand Down
5 changes: 5 additions & 0 deletions runtime/src/iree/hal/drivers/cuda/graph_command_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,11 @@ bool iree_hal_cuda_graph_command_buffer_isa(
CUgraphExec iree_hal_cuda_graph_command_buffer_handle(
iree_hal_command_buffer_t* command_buffer);

// This is to be called after the given |command_buffer| has been submitted
// in order to notify the tracing system that there are events to collect.
void iree_hal_cuda_graph_tracing_notify_submitted_commands(
iree_hal_command_buffer_t* command_buffer);

#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
Expand Down
10 changes: 6 additions & 4 deletions runtime/src/iree/hal/drivers/cuda/nccl_channel.c
Original file line number Diff line number Diff line change
Expand Up @@ -542,6 +542,7 @@ static iree_status_t iree_hal_cuda_nccl_submit_batch_entry(
iree_status_t iree_hal_cuda_nccl_submit_batch(
const iree_hal_cuda_nccl_dynamic_symbols_t* symbols,
iree_hal_cuda_tracing_context_t* tracing_context,
iree_hal_cuda_tracing_context_event_list_t* tracing_event_list,
const iree_hal_collective_batch_t* batch, CUstream stream) {
IREE_ASSERT_ARGUMENT(symbols);
IREE_ASSERT_ARGUMENT(batch);
Expand All @@ -558,9 +559,9 @@ iree_status_t iree_hal_cuda_nccl_submit_batch(
iree_string_view_t collective_str =
iree_hal_collective_op_format(&entry->op, &string_temp);
IREE_CUDA_STREAM_TRACE_ZONE_BEGIN_EXTERNAL(
tracing_context, stream, __FILE__, strlen(__FILE__), (uint32_t)__LINE__,
__FUNCTION__, strlen(__FUNCTION__), collective_str.data,
collective_str.size);
tracing_context, tracing_event_list, stream, __FILE__, strlen(__FILE__),
(uint32_t)__LINE__, __FUNCTION__, strlen(__FUNCTION__),
collective_str.data, collective_str.size);
}
#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION_DEVICE

Expand All @@ -577,7 +578,8 @@ iree_status_t iree_hal_cuda_nccl_submit_batch(
// End all zones we began above - note that these are just simply nested so
// order doesn't matter so long as we end the right number of zones.
for (iree_host_size_t i = 0; i < batch->count; ++i) {
IREE_CUDA_STREAM_TRACE_ZONE_END(tracing_context, stream);
IREE_CUDA_STREAM_TRACE_ZONE_END(tracing_context, tracing_event_list,
stream);
}
#endif // IREE_TRACING_FEATURES & IREE_TRACING_FEATURE_INSTRUMENTATION_DEVICE

Expand Down
1 change: 1 addition & 0 deletions runtime/src/iree/hal/drivers/cuda/nccl_channel.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ iree_status_t iree_hal_cuda_nccl_channel_create(
iree_status_t iree_hal_cuda_nccl_submit_batch(
const iree_hal_cuda_nccl_dynamic_symbols_t* nccl_symbols,
iree_hal_cuda_tracing_context_t* tracing_context,
iree_hal_cuda_tracing_context_event_list_t* tracing_event_list,
const iree_hal_collective_batch_t* batch, CUstream stream);

#ifdef __cplusplus
Expand Down
12 changes: 11 additions & 1 deletion runtime/src/iree/hal/drivers/cuda/pending_queue_actions.c
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "iree/hal/drivers/cuda/event_pool.h"
#include "iree/hal/drivers/cuda/event_semaphore.h"
#include "iree/hal/drivers/cuda/graph_command_buffer.h"
#include "iree/hal/drivers/cuda/stream_command_buffer.h"
#include "iree/hal/drivers/utils/semaphore.h"
#include "iree/hal/utils/deferred_command_buffer.h"
#include "iree/hal/utils/resource_set.h"
Expand Down Expand Up @@ -729,12 +730,17 @@ static iree_status_t iree_hal_cuda_pending_queue_actions_issue_execution(
action->payload.execution.binding_tables
? action->payload.execution.binding_tables[i]
: iree_hal_buffer_binding_table_empty();
if (iree_hal_cuda_graph_command_buffer_isa(command_buffer)) {
if (iree_hal_cuda_stream_command_buffer_isa(command_buffer)) {
// Notify that the commands were "submitted" so we can
// make sure to clean up our trace events.
iree_hal_cuda_stream_notify_submitted_commands(command_buffer);
} else if (iree_hal_cuda_graph_command_buffer_isa(command_buffer)) {
CUgraphExec exec =
iree_hal_cuda_graph_command_buffer_handle(command_buffer);
IREE_CUDA_RETURN_AND_END_ZONE_IF_ERROR(
z0, symbols, cuGraphLaunch(exec, action->dispatch_cu_stream),
"cuGraphLaunch");
iree_hal_cuda_graph_tracing_notify_submitted_commands(command_buffer);
} else {
iree_hal_command_buffer_t* stream_command_buffer = NULL;
iree_hal_command_buffer_mode_t mode =
Expand All @@ -753,6 +759,10 @@ static iree_status_t iree_hal_cuda_pending_queue_actions_issue_execution(
IREE_RETURN_AND_END_ZONE_IF_ERROR(
z0, iree_hal_deferred_command_buffer_apply(
command_buffer, stream_command_buffer, binding_table));
iree_hal_cuda_stream_notify_submitted_commands(stream_command_buffer);
// The stream_command_buffer is going to be retained by
// the action->resource_set and deleted after the action
// completes.
iree_hal_resource_release(stream_command_buffer);
}
}
Expand Down
43 changes: 33 additions & 10 deletions runtime/src/iree/hal/drivers/cuda/stream_command_buffer.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ typedef struct iree_hal_cuda_stream_command_buffer_t {

// Per-stream CUDA tracing context.
iree_hal_cuda_tracing_context_t* tracing_context;
iree_hal_cuda_tracing_context_event_list_t tracing_event_list;

CUstream cu_stream;

Expand Down Expand Up @@ -98,6 +99,8 @@ iree_status_t iree_hal_cuda_stream_command_buffer_create(
command_buffer->cuda_symbols = cuda_symbols;
command_buffer->nccl_symbols = nccl_symbols;
command_buffer->tracing_context = tracing_context;
command_buffer->tracing_event_list.head = NULL;
command_buffer->tracing_event_list.tail = NULL;
command_buffer->cu_stream = stream;
iree_arena_initialize(block_pool, &command_buffer->arena);

Expand All @@ -122,6 +125,9 @@ static void iree_hal_cuda_stream_command_buffer_destroy(
iree_allocator_t host_allocator = command_buffer->host_allocator;
IREE_TRACE_ZONE_BEGIN(z0);

iree_hal_cuda_tracing_free(command_buffer->tracing_context,
&command_buffer->tracing_event_list);

iree_hal_collective_batch_deinitialize(&command_buffer->collective_batch);
iree_hal_resource_set_free(command_buffer->resource_set);
iree_arena_deinitialize(&command_buffer->arena);
Expand All @@ -136,6 +142,18 @@ bool iree_hal_cuda_stream_command_buffer_isa(
&iree_hal_cuda_stream_command_buffer_vtable);
}

void iree_hal_cuda_stream_notify_submitted_commands(
iree_hal_command_buffer_t* base_command_buffer) {
iree_hal_cuda_stream_command_buffer_t* command_buffer =
iree_hal_cuda_stream_command_buffer_cast(base_command_buffer);
if (!command_buffer->tracing_context) {
return;
}

iree_hal_cuda_tracing_notify_submitted(command_buffer->tracing_context,
&command_buffer->tracing_event_list);
}

// Flushes any pending batched collective operations.
// Must be called before any other non-collective nodes are added to the graph
// or a barrier is encountered.
Expand All @@ -151,7 +169,8 @@ static iree_status_t iree_hal_cuda_stream_command_buffer_flush_collectives(
IREE_TRACE_ZONE_BEGIN(z0);
iree_status_t status = iree_hal_cuda_nccl_submit_batch(
command_buffer->nccl_symbols, command_buffer->tracing_context,
&command_buffer->collective_batch, command_buffer->cu_stream);
&command_buffer->tracing_event_list, &command_buffer->collective_batch,
command_buffer->cu_stream);
iree_hal_collective_batch_clear(&command_buffer->collective_batch);
IREE_TRACE_ZONE_END(z0);
return status;
Expand All @@ -164,7 +183,8 @@ static iree_status_t iree_hal_cuda_stream_command_buffer_begin(
(void)command_buffer;

IREE_CUDA_STREAM_TRACE_ZONE_BEGIN_EXTERNAL(
command_buffer->tracing_context, command_buffer->cu_stream,
command_buffer->tracing_context, &command_buffer->tracing_event_list,
command_buffer->cu_stream,
/*file_name=*/NULL, 0, /*line=*/0, "iree_hal_cuda_stream_command_buffer",
strlen("iree_hal_cuda_stream_command_buffer"), /*name=*/NULL, 0);

Expand Down Expand Up @@ -200,6 +220,7 @@ static iree_status_t iree_hal_cuda_stream_command_buffer_end(
&command_buffer->collective_batch);

IREE_CUDA_STREAM_TRACE_ZONE_END(command_buffer->tracing_context,
&command_buffer->tracing_event_list,
command_buffer->cu_stream);

IREE_TRACE_ZONE_END(z0);
Expand All @@ -215,10 +236,10 @@ static void iree_hal_cuda_stream_command_buffer_begin_debug_group(
(void)command_buffer;

IREE_CUDA_STREAM_TRACE_ZONE_BEGIN_EXTERNAL(
command_buffer->tracing_context, command_buffer->cu_stream,
location ? location->file.data : NULL, location ? location->file.size : 0,
location ? location->line : 0, /*func_name=*/NULL, 0, label.data,
label.size);
command_buffer->tracing_context, &command_buffer->tracing_event_list,
command_buffer->cu_stream, location ? location->file.data : NULL,
location ? location->file.size : 0, location ? location->line : 0,
/*func_name=*/NULL, 0, label.data, label.size);

// TODO: pass along to CUPTI if available.
}
Expand All @@ -232,6 +253,7 @@ static void iree_hal_cuda_stream_command_buffer_end_debug_group(
// TODO: pass along to CUPTI if available.

IREE_CUDA_STREAM_TRACE_ZONE_END(command_buffer->tracing_context,
&command_buffer->tracing_event_list,
command_buffer->cu_stream);
}

Expand Down Expand Up @@ -528,10 +550,10 @@ static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch(
executable, entry_point, &kernel_info));

IREE_CUDA_STREAM_TRACE_ZONE_BEGIN_EXTERNAL(
command_buffer->tracing_context, command_buffer->cu_stream,
kernel_info.source_filename.data, kernel_info.source_filename.size,
kernel_info.source_line, kernel_info.function_name.data,
kernel_info.function_name.size,
command_buffer->tracing_context, &command_buffer->tracing_event_list,
command_buffer->cu_stream, kernel_info.source_filename.data,
kernel_info.source_filename.size, kernel_info.source_line,
kernel_info.function_name.data, kernel_info.function_name.size,
/*name=*/NULL, 0);

IREE_RETURN_AND_END_ZONE_IF_ERROR(
Expand Down Expand Up @@ -614,6 +636,7 @@ static iree_status_t iree_hal_cuda_stream_command_buffer_dispatch(
"cuLaunchKernel");

IREE_CUDA_STREAM_TRACE_ZONE_END(command_buffer->tracing_context,
&command_buffer->tracing_event_list,
command_buffer->cu_stream);

IREE_TRACE_ZONE_END(z0);
Expand Down
6 changes: 6 additions & 0 deletions runtime/src/iree/hal/drivers/cuda/stream_command_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ iree_status_t iree_hal_cuda_stream_command_buffer_create(
bool iree_hal_cuda_stream_command_buffer_isa(
iree_hal_command_buffer_t* command_buffer);

// This is to be called after a command buffer has been submitted
// in order to notify the tracing system that there are events
// to collect.
void iree_hal_cuda_stream_notify_submitted_commands(
iree_hal_command_buffer_t* base_command_buffer);

#ifdef __cplusplus
} // extern "C"
#endif // __cplusplus
Expand Down
Loading

0 comments on commit c322d28

Please sign in to comment.