-
Notifications
You must be signed in to change notification settings - Fork 8
Skip command buffer re-tracing via in-place HIP graph node patching #785
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 5 commits
4be2e4a
da7699f
a981e72
0a0fd1b
2acd22f
3700911
bc2b27f
a75943c
0182f04
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -99,6 +99,7 @@ limitations under the License. | |
| #include "xla/status_macros.h" | ||
| #include "xla/stream_executor/command_buffer.h" | ||
| #include "xla/stream_executor/device_address.h" | ||
| #include "xla/stream_executor/gpu/gpu_command_buffer.h" | ||
| #include "xla/stream_executor/device_address_handle.h" | ||
| #include "xla/stream_executor/dnn.h" | ||
| #include "xla/stream_executor/gpu/multi_gpu_barrier_kernel.h" | ||
|
|
@@ -241,6 +242,7 @@ TracedCommandBuffer::TracedCommandBuffer(const Command* trace_cmd, | |
| absl::flat_hash_set<BufferAllocation::Index> allocs_indices; | ||
| for (auto& buffer : buffers) { | ||
| allocs_indices.insert(buffer.slice().index()); | ||
| buffer_slices_.push_back(buffer.slice()); | ||
| } | ||
| allocs_indices_.assign(allocs_indices.begin(), allocs_indices.end()); | ||
| } | ||
|
|
@@ -249,13 +251,20 @@ absl::StatusOr<se::CommandBuffer*> TracedCommandBuffer::GetOrTraceCommandBuffer( | |
| const BufferAllocations* buffer_allocation, se::StreamExecutor* executor, | ||
| se::Stream* stream, absl::FunctionRef<absl::Status(se::Stream*)> trace, | ||
| se::StreamPriority priority) { | ||
| // Collect memory addresses for relevant allocations. | ||
| // Collect memory addresses for relevant allocations (for cache comparison). | ||
| absl::InlinedVector<se::DeviceAddressBase, 4> allocs; | ||
| allocs.reserve(allocs_indices_.size()); | ||
| for (auto& index : allocs_indices_) { | ||
| allocs.emplace_back(buffer_allocation->GetDeviceAddress(index)); | ||
| } | ||
|
|
||
| // Collect slice-level addresses (for kernel node patching). | ||
| absl::InlinedVector<se::DeviceAddressBase, 4> slice_addrs; | ||
| slice_addrs.reserve(buffer_slices_.size()); | ||
| for (auto& slice : buffer_slices_) { | ||
| slice_addrs.emplace_back(buffer_allocation->GetDeviceAddress(slice)); | ||
| } | ||
|
|
||
| // Moves entry at `i` position to front and moves entries in `[0, i)` range | ||
| // one element to the right. Returns reference to the first entry. | ||
| auto shift_right = [&](size_t i) -> Entry& { | ||
|
|
@@ -271,23 +280,53 @@ absl::StatusOr<se::CommandBuffer*> TracedCommandBuffer::GetOrTraceCommandBuffer( | |
| return entries_[0] = std::move(entry); | ||
| }; | ||
|
|
||
| static const bool skip_retrace = [] { | ||
| const char* env = std::getenv("XLA_GPU_GRAPH_SKIP_RETRACE"); | ||
| bool val = env != nullptr && std::string(env) == "1"; | ||
| if (val) { | ||
| LOG(INFO) << "XLA_GPU_GRAPH_SKIP_RETRACE enabled: kernel node " | ||
| "params will be patched in-place to avoid re-tracing"; | ||
| } | ||
| return val; | ||
| }(); | ||
|
Comment on lines
+288
to
+296
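There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The skip-retrace feature is gated via a raw `std::getenv("XLA_GPU_GRAPH_SKIP_RETRACE")` check (the remainder of this review comment is truncated in the source).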
||
|
|
||
| for (size_t i = 0; i < capacity_; ++i) { | ||
| // Found an entry for the given allocations: move it to the front and return | ||
| // a pointer to the cached command buffer. | ||
| if (ABSL_PREDICT_TRUE(absl::c_equal(entries_[i].recorded_allocs, allocs) && | ||
| entries_[i].command_buffer)) { | ||
| VLOG(6) << "Command buffer trace cache hit for command " | ||
| << trace_cmd_->ToString(); | ||
| return shift_right(i).command_buffer.get(); | ||
| } | ||
|
|
||
| // Create a new entry by calling a user-provided tracing function, move it | ||
| // to front and return a pointer to cached command buffer. | ||
| // When skip_retrace is enabled, patch kernel nodes in the existing | ||
| // cached graph directly instead of re-tracing via stream capture. | ||
| if (skip_retrace && entries_[i].command_buffer != nullptr && | ||
| !entries_[i].recorded_slice_addrs.empty()) { | ||
| auto* gpu_cmd = dynamic_cast<se::gpu::GpuCommandBuffer*>( | ||
| entries_[i].command_buffer.get()); | ||
| if (gpu_cmd) { | ||
| auto status = gpu_cmd->UpdateKernelNodes( | ||
| entries_[i].recorded_slice_addrs, slice_addrs); | ||
| if (status.ok()) { | ||
| entries_[i].recorded_allocs.assign(allocs.begin(), allocs.end()); | ||
| entries_[i].recorded_slice_addrs.assign(slice_addrs.begin(), | ||
| slice_addrs.end()); | ||
| VLOG(6) << "Patched kernel nodes in-place for command " | ||
| << trace_cmd_->ToString(); | ||
| return shift_right(i).command_buffer.get(); | ||
| } | ||
| VLOG(3) << "Kernel node patch failed for " << trace_cmd_->ToString() | ||
| << ": " << status << ", falling back to retrace"; | ||
| } | ||
| } | ||
|
|
||
| if (entries_[i].command_buffer == nullptr) { | ||
| TF_ASSIGN_OR_RETURN( | ||
| entries_[i].command_buffer, | ||
| se::TraceCommandBufferFactory::Create(executor, stream, trace)); | ||
| entries_[i].recorded_allocs.assign(allocs.begin(), allocs.end()); | ||
| entries_[i].recorded_slice_addrs.assign(slice_addrs.begin(), | ||
| slice_addrs.end()); | ||
| if (priority != se::StreamPriority::Default) { | ||
| TF_RETURN_IF_ERROR(entries_[i].command_buffer->SetPriority(priority)); | ||
| } | ||
|
|
@@ -297,15 +336,33 @@ absl::StatusOr<se::CommandBuffer*> TracedCommandBuffer::GetOrTraceCommandBuffer( | |
| } | ||
| } | ||
|
|
||
| // Create a new entry by calling a user-provided tracing function, replace | ||
| // the last entry with it, move it to front and return a pointer to cached | ||
| // command buffer. | ||
| // All slots occupied and no match. Try in-place update of the last slot. | ||
| if (skip_retrace && !entries_[capacity_ - 1].recorded_slice_addrs.empty()) { | ||
| auto* gpu_cmd = dynamic_cast<se::gpu::GpuCommandBuffer*>( | ||
| entries_[capacity_ - 1].command_buffer.get()); | ||
| if (gpu_cmd) { | ||
| auto status = gpu_cmd->UpdateKernelNodes( | ||
| entries_[capacity_ - 1].recorded_slice_addrs, slice_addrs); | ||
| if (status.ok()) { | ||
| entries_[capacity_ - 1].recorded_allocs.assign(allocs.begin(), | ||
| allocs.end()); | ||
| entries_[capacity_ - 1].recorded_slice_addrs.assign( | ||
| slice_addrs.begin(), slice_addrs.end()); | ||
| VLOG(6) << "Patched kernel nodes (evict slot) for command " | ||
| << trace_cmd_->ToString(); | ||
| return shift_right(capacity_ - 1).command_buffer.get(); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| VLOG(6) << "Command buffer trace cache does replacement for command " | ||
| << trace_cmd_->ToString(); | ||
| TF_ASSIGN_OR_RETURN( | ||
| entries_[capacity_ - 1].command_buffer, | ||
| se::TraceCommandBufferFactory::Create(executor, stream, trace)); | ||
| entries_[capacity_ - 1].recorded_allocs.assign(allocs.begin(), allocs.end()); | ||
| VLOG(6) << "Command buffer trace cache does replacement for command " | ||
| << trace_cmd_->ToString(); | ||
| entries_[capacity_ - 1].recorded_slice_addrs.assign(slice_addrs.begin(), | ||
| slice_addrs.end()); | ||
| return shift_right(capacity_ - 1).command_buffer.get(); | ||
| } | ||
|
|
||
|
|
@@ -322,6 +379,7 @@ TracedCommandBufferCmd::RecordTracedCommand( | |
| const RecordParams& record_params, RecordAction record_action, | ||
| se::CommandBuffer* command_buffer, | ||
| absl::FunctionRef<absl::Status(se::Stream*)> trace) { | ||
|
|
||
| auto traced_cmd = record_params.state.GetOrCreate<TracedCommandBuffer>( | ||
| this, command_buffer, [&] { | ||
| const auto& debug_options = xla::GetDebugOptionsFromFlags(); | ||
|
|
@@ -335,8 +393,6 @@ TracedCommandBufferCmd::RecordTracedCommand( | |
| traced_cmd->GetOrTraceCommandBuffer( | ||
| execute_params.buffer_allocations, execute_params.stream->parent(), | ||
| execute_params.command_buffer_trace_stream, trace, priority())); | ||
|
|
||
| VLOG(5) << "Record traced command into command buffer: " << command_buffer; | ||
| return Handle( | ||
| std::move(record_action), | ||
| [&](absl::Span<const se::CommandBuffer::Command* const> dependencies) { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -117,6 +117,56 @@ class GpuCommandBuffer : public CommandBuffer { | |
| std::unique_ptr<CommandBuffer> command_buffer; | ||
| }; | ||
|
|
||
| // Kinds of nodes that can be extracted ("flattened") from a traced child | ||
| // graph into the parent graph. GpuFlattenedCommand (below) stores all the | ||
| // flattened node handles so they can be updated individually. | ||
| enum class FlatNodeKind : uint8_t { | ||
| kKernel, | ||
| kKernelExtra, | ||
| kMemcpy, | ||
| kMemset, | ||
| kEmpty, | ||
| }; | ||
|
|
||
| struct FlatNodeInfo { | ||
| FlatNodeKind kind; | ||
| void* func = nullptr; // kernel function pointer (for kKernel/kKernelExtra) | ||
| }; | ||
|
|
||
| // Records where a known device pointer appears inside a kernel node's | ||
| // packed argument buffer (the `extra` HIP_LAUNCH_PARAM blob). | ||
| struct ArgPatchEntry { | ||
| size_t node_index; // index into node_handles/node_infos | ||
| size_t byte_offset; // offset within the packed arg buffer | ||
| int buffer_use_index; // index into the buffer_uses() vector | ||
| }; | ||
|
|
||
| struct GpuFlattenedCommand : public CommandBuffer::Command { | ||
| std::vector<GraphNodeHandle> node_handles; | ||
| std::vector<FlatNodeInfo> node_infos; | ||
|
|
||
| // Deep-copied extra arg buffers for kernel nodes that use the | ||
| // HIP_LAUNCH_PARAM packed-buffer launch convention (e.g. rocBLAS). | ||
| // Kept alive for the lifetime of the parent graph. | ||
| std::vector<std::unique_ptr<uint8_t[]>> extra_arg_buffers; | ||
| std::vector<std::unique_ptr<size_t>> extra_arg_sizes; | ||
|
|
||
| // Patch table for skip-retrace updates: records where each known | ||
| // buffer address appears inside the packed args of flattened | ||
| // extra-style kernel nodes. | ||
| std::vector<ArgPatchEntry> patch_table; | ||
|
|
||
| // Per-node deep-copied arg buffer and its size (only for extra-style | ||
| // kernel nodes that have patch entries). | ||
| struct NodeArgBuffer { | ||
| std::unique_ptr<uint8_t[]> data; | ||
|
Comment on lines
+148
to
+162
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
| size_t size = 0; | ||
| }; | ||
| std::vector<NodeArgBuffer> node_arg_buffers; | ||
|
|
||
| bool has_patch_table = false; | ||
| }; | ||
|
|
||
| GpuCommandBuffer(Mode mode, StreamExecutor* executor); | ||
|
|
||
| // Bring CreateLaunch and UpdateLaunch template functions into scope. | ||
|
|
@@ -207,6 +257,57 @@ class GpuCommandBuffer : public CommandBuffer { | |
|
|
||
| absl::Span<const std::unique_ptr<Command>> commands() const; | ||
|
|
||
| // Dumps kernel node info from this command buffer's graph for debugging. | ||
| virtual void DumpGraphKernelNodes(absl::string_view label) {} | ||
|
|
||
| // Updates kernel nodes inside this command buffer's graph in place, | ||
| // replacing old_addresses with new_addresses at the corresponding positions. | ||
| // This avoids re-tracing the graph when only buffer addresses change. | ||
| virtual absl::Status UpdateKernelNodes( | ||
| absl::Span<const DeviceAddressBase> old_addresses, | ||
| absl::Span<const DeviceAddressBase> new_addresses) { | ||
| return absl::UnimplementedError( | ||
| "UpdateKernelNodes not supported on this platform"); | ||
| } | ||
|
|
||
| // Extracts nodes from a nested (traced) command buffer and adds them | ||
| // directly to this parent graph as individual kernel/memcpy/memset nodes, | ||
| // avoiding child graph re-tracing. Returns a single GpuFlattenedCommand | ||
| // that owns all the flattened node handles. | ||
| virtual absl::StatusOr<const Command*> FlattenChildGraphNodes( | ||
| const CommandBuffer& nested, | ||
| absl::Span<const Command* const> dependencies) { | ||
| return absl::UnimplementedError( | ||
| "FlattenChildGraphNodes not supported on this platform"); | ||
| } | ||
|
|
||
| // Updates all kernel nodes inside a previously flattened command by scanning | ||
| // their arguments and replacing old device pointers with new ones. | ||
| virtual absl::Status UpdateFlattenedChildNodes( | ||
| const Command* command, const CommandBuffer& nested) { | ||
| return absl::UnimplementedError( | ||
| "UpdateFlattenedChildNodes not supported on this platform"); | ||
| } | ||
|
|
||
| // Builds a patch table on a flattened command: scans each extra-style | ||
| // kernel node's packed argument buffer to find byte offsets of known | ||
| // buffer addresses. Must be called once after FlattenChildGraphNodes. | ||
| virtual absl::Status BuildPatchTable( | ||
| const Command* command, | ||
| absl::Span<const DeviceAddressBase> known_addresses) { | ||
| return absl::UnimplementedError( | ||
| "BuildPatchTable not supported on this platform"); | ||
| } | ||
|
|
||
| // Directly patches buffer addresses in flattened nodes using the | ||
| // previously built patch table. Skips re-tracing entirely. | ||
| virtual absl::Status PatchFlattenedNodes( | ||
| const Command* command, | ||
| absl::Span<const DeviceAddressBase> new_addresses) { | ||
| return absl::UnimplementedError( | ||
| "PatchFlattenedNodes not supported on this platform"); | ||
| } | ||
|
|
||
| protected: | ||
| // We track the total number of allocated and alive executable graphs in the | ||
| // process to track the command buffers resource usage. Executable graph | ||
|
|
@@ -280,6 +381,15 @@ class GpuCommandBuffer : public CommandBuffer { | |
| // Returns OK status if the command buffer can be updated. | ||
| virtual absl::Status CheckCanBeUpdated() = 0; | ||
|
|
||
| // Appends a new command to the command buffer. | ||
| template <typename T> | ||
| const Command* AppendCommand(T command) { | ||
| commands_.push_back(std::make_unique<T>(std::move(command))); | ||
| VLOG(5) << "AppendCommand: " | ||
| << reinterpret_cast<const void*>(commands_.back().get()); | ||
| return commands_.back().get(); | ||
| } | ||
|
|
||
| private: | ||
| absl::StatusOr<const Command*> CreateCase( | ||
| DeviceAddress<uint8_t> index, bool index_is_bool, | ||
|
|
@@ -290,15 +400,6 @@ class GpuCommandBuffer : public CommandBuffer { | |
| bool index_is_bool, | ||
| std::vector<UpdateCommands> update_branches); | ||
|
|
||
| // Appends a new command to the command buffer. | ||
| template <typename T> | ||
| const Command* AppendCommand(T command) { | ||
| commands_.push_back(std::make_unique<T>(std::move(command))); | ||
| VLOG(5) << "AppendCommand: " | ||
| << reinterpret_cast<const void*>(commands_.back().get()); | ||
| return commands_.back().get(); | ||
| } | ||
|
|
||
| // Converts a list of command dependencies to a list of graph node handles. | ||
| std::vector<GraphNodeHandle> ToGraphNodeDependencies( | ||
| absl::Span<const Command* const> dependencies); | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Include order —
`gpu_command_buffer.h` is inserted between `device_address.h` and `device_address_handle.h`, breaking alphabetical order within the `xla/stream_executor/` group. It should come after `device_address_handle.h` (or be grouped with the `gpu/` headers at line 106).