-
Notifications
You must be signed in to change notification settings - Fork 8
Skip command buffer re-tracing via in-place HIP graph node patching #785
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
4be2e4a
da7699f
a981e72
0a0fd1b
2acd22f
3700911
bc2b27f
a75943c
0182f04
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -99,6 +99,7 @@ limitations under the License. | |
| #include "xla/status_macros.h" | ||
| #include "xla/stream_executor/command_buffer.h" | ||
| #include "xla/stream_executor/device_address.h" | ||
| #include "xla/stream_executor/gpu/gpu_command_buffer.h" | ||
| #include "xla/stream_executor/device_address_handle.h" | ||
| #include "xla/stream_executor/dnn.h" | ||
| #include "xla/stream_executor/gpu/multi_gpu_barrier_kernel.h" | ||
|
|
@@ -241,6 +242,7 @@ TracedCommandBuffer::TracedCommandBuffer(const Command* trace_cmd, | |
| absl::flat_hash_set<BufferAllocation::Index> allocs_indices; | ||
| for (auto& buffer : buffers) { | ||
| allocs_indices.insert(buffer.slice().index()); | ||
| buffer_slices_.push_back(buffer.slice()); | ||
| } | ||
| allocs_indices_.assign(allocs_indices.begin(), allocs_indices.end()); | ||
| } | ||
|
|
@@ -249,13 +251,25 @@ absl::StatusOr<se::CommandBuffer*> TracedCommandBuffer::GetOrTraceCommandBuffer( | |
| const BufferAllocations* buffer_allocation, se::StreamExecutor* executor, | ||
| se::Stream* stream, absl::FunctionRef<absl::Status(se::Stream*)> trace, | ||
| se::StreamPriority priority) { | ||
| // Collect memory addresses for relevant allocations. | ||
| static const bool profile_trace = [] { | ||
| const char* env = std::getenv("XLA_PROFILE_CMD_BUFFER"); | ||
| return env != nullptr && std::string(env) == "1"; | ||
| }(); | ||
|
|
||
| // Collect memory addresses for relevant allocations (for cache comparison). | ||
| absl::InlinedVector<se::DeviceAddressBase, 4> allocs; | ||
| allocs.reserve(allocs_indices_.size()); | ||
| for (auto& index : allocs_indices_) { | ||
| allocs.emplace_back(buffer_allocation->GetDeviceAddress(index)); | ||
| } | ||
|
|
||
| // Collect slice-level addresses (for kernel node patching). | ||
| absl::InlinedVector<se::DeviceAddressBase, 4> slice_addrs; | ||
| slice_addrs.reserve(buffer_slices_.size()); | ||
| for (auto& slice : buffer_slices_) { | ||
| slice_addrs.emplace_back(buffer_allocation->GetDeviceAddress(slice)); | ||
| } | ||
|
|
||
| // Moves entry at `i` position to front and moves entries in `[0, i)` range | ||
| // one element to the right. Returns reference to the first entry. | ||
| auto shift_right = [&](size_t i) -> Entry& { | ||
|
|
@@ -271,41 +285,117 @@ absl::StatusOr<se::CommandBuffer*> TracedCommandBuffer::GetOrTraceCommandBuffer( | |
| return entries_[0] = std::move(entry); | ||
| }; | ||
|
|
||
| static const bool skip_retrace = [] { | ||
| const char* env = std::getenv("XLA_GPU_GRAPH_SKIP_RETRACE"); | ||
| bool val = env != nullptr && std::string(env) == "1"; | ||
| if (val) { | ||
| LOG(INFO) << "XLA_GPU_GRAPH_SKIP_RETRACE enabled: kernel node " | ||
| "params will be patched in-place to avoid re-tracing"; | ||
| } | ||
| return val; | ||
| }(); | ||
|
Comment on lines
+288
to
+296
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The skip-retrace feature is gated via raw |
||
|
|
||
| for (size_t i = 0; i < capacity_; ++i) { | ||
| // Found entry for a given allocations, move it to front and return a | ||
| // pointer to cached command buffer. | ||
| if (ABSL_PREDICT_TRUE(absl::c_equal(entries_[i].recorded_allocs, allocs) && | ||
| entries_[i].command_buffer)) { | ||
| VLOG(6) << "Command buffer trace cache hit for command " | ||
| << trace_cmd_->ToString(0); | ||
| if (profile_trace) { | ||
| LOG(WARNING) << "TraceCache HIT cmd=" << trace_cmd_->ToString(0); | ||
| } | ||
| return shift_right(i).command_buffer.get(); | ||
| } | ||
|
|
||
| // Create a new entry by calling a user-provided tracing function, move it | ||
| // to front and return a pointer to cached command buffer. | ||
| // When skip_retrace is enabled, patch kernel nodes in the existing | ||
| // cached graph directly instead of re-tracing via stream capture. | ||
| if (skip_retrace && entries_[i].command_buffer != nullptr && | ||
| !entries_[i].recorded_slice_addrs.empty()) { | ||
| auto* gpu_cmd = dynamic_cast<se::gpu::GpuCommandBuffer*>( | ||
| entries_[i].command_buffer.get()); | ||
| if (gpu_cmd) { | ||
| uint64_t patch_start = 0; | ||
| if (profile_trace) patch_start = tsl::Env::Default()->NowMicros(); | ||
|
|
||
| auto status = gpu_cmd->UpdateKernelNodes( | ||
| entries_[i].recorded_slice_addrs, slice_addrs); | ||
| if (status.ok()) { | ||
| entries_[i].recorded_allocs.assign(allocs.begin(), allocs.end()); | ||
| entries_[i].recorded_slice_addrs.assign(slice_addrs.begin(), | ||
| slice_addrs.end()); | ||
| if (profile_trace) { | ||
| uint64_t patch_end = tsl::Env::Default()->NowMicros(); | ||
| LOG(WARNING) << "TraceCache PATCH cmd=" << trace_cmd_->ToString(0) | ||
| << " time=" << (patch_end - patch_start) << "us"; | ||
| } | ||
| return shift_right(i).command_buffer.get(); | ||
| } | ||
| VLOG(3) << "Kernel node patch failed for " << trace_cmd_->ToString(0) | ||
| << ": " << status << ", falling back to retrace"; | ||
| } | ||
| } | ||
|
|
||
| if (entries_[i].command_buffer == nullptr) { | ||
| uint64_t trace_start = 0; | ||
| if (profile_trace) trace_start = tsl::Env::Default()->NowMicros(); | ||
|
|
||
| TF_ASSIGN_OR_RETURN( | ||
| entries_[i].command_buffer, | ||
| se::TraceCommandBufferFactory::Create(executor, stream, trace)); | ||
| entries_[i].recorded_allocs.assign(allocs.begin(), allocs.end()); | ||
| entries_[i].recorded_slice_addrs.assign(slice_addrs.begin(), | ||
| slice_addrs.end()); | ||
| if (priority != se::StreamPriority::Default) { | ||
| TF_RETURN_IF_ERROR(entries_[i].command_buffer->SetPriority(priority)); | ||
| } | ||
| VLOG(6) << "Command buffer trace cache create new item for command " | ||
| << trace_cmd_->ToString(0); | ||
| if (profile_trace) { | ||
| uint64_t trace_end = tsl::Env::Default()->NowMicros(); | ||
| LOG(WARNING) << "TraceCache NEW cmd=" << trace_cmd_->ToString(0) | ||
| << " time=" << (trace_end - trace_start) << "us"; | ||
| } | ||
| return shift_right(i).command_buffer.get(); | ||
| } | ||
| } | ||
|
|
||
| // Create a new entry by calling a user-provided tracing function, replace | ||
| // the last entry with it, move it to front and return a pointer to cached | ||
| // command buffer. | ||
| // All slots occupied and no match. Try in-place update of the last slot. | ||
| if (skip_retrace && !entries_[capacity_ - 1].recorded_slice_addrs.empty()) { | ||
| auto* gpu_cmd = dynamic_cast<se::gpu::GpuCommandBuffer*>( | ||
| entries_[capacity_ - 1].command_buffer.get()); | ||
| if (gpu_cmd) { | ||
| uint64_t patch_start = 0; | ||
| if (profile_trace) patch_start = tsl::Env::Default()->NowMicros(); | ||
|
|
||
| auto status = gpu_cmd->UpdateKernelNodes( | ||
| entries_[capacity_ - 1].recorded_slice_addrs, slice_addrs); | ||
| if (status.ok()) { | ||
| entries_[capacity_ - 1].recorded_allocs.assign(allocs.begin(), | ||
| allocs.end()); | ||
| entries_[capacity_ - 1].recorded_slice_addrs.assign( | ||
| slice_addrs.begin(), slice_addrs.end()); | ||
| if (profile_trace) { | ||
| uint64_t patch_end = tsl::Env::Default()->NowMicros(); | ||
| LOG(WARNING) << "TraceCache PATCH(evict) cmd=" | ||
| << trace_cmd_->ToString(0) | ||
| << " time=" << (patch_end - patch_start) << "us"; | ||
| } | ||
| return shift_right(capacity_ - 1).command_buffer.get(); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| uint64_t trace_start = 0; | ||
| if (profile_trace) trace_start = tsl::Env::Default()->NowMicros(); | ||
|
|
||
| TF_ASSIGN_OR_RETURN( | ||
| entries_[capacity_ - 1].command_buffer, | ||
| se::TraceCommandBufferFactory::Create(executor, stream, trace)); | ||
| entries_[capacity_ - 1].recorded_allocs.assign(allocs.begin(), allocs.end()); | ||
| VLOG(6) << "Command buffer trace cache does replacement for command " | ||
| << trace_cmd_->ToString(0); | ||
| entries_[capacity_ - 1].recorded_slice_addrs.assign(slice_addrs.begin(), | ||
| slice_addrs.end()); | ||
|
|
||
| if (profile_trace) { | ||
| uint64_t trace_end = tsl::Env::Default()->NowMicros(); | ||
| LOG(WARNING) << "TraceCache RETRACE cmd=" << trace_cmd_->ToString(0) | ||
| << " time=" << (trace_end - trace_start) << "us"; | ||
| } | ||
|
|
||
| return shift_right(capacity_ - 1).command_buffer.get(); | ||
| } | ||
|
|
||
|
|
@@ -322,6 +412,7 @@ TracedCommandBufferCmd::RecordTracedCommand( | |
| const RecordParams& record_params, RecordAction record_action, | ||
| se::CommandBuffer* command_buffer, | ||
| absl::FunctionRef<absl::Status(se::Stream*)> trace) { | ||
|
|
||
| auto traced_cmd = record_params.state.GetOrCreate<TracedCommandBuffer>( | ||
| this, command_buffer, [&] { | ||
| const auto& debug_options = xla::GetDebugOptionsFromFlags(); | ||
|
|
@@ -335,8 +426,6 @@ TracedCommandBufferCmd::RecordTracedCommand( | |
| traced_cmd->GetOrTraceCommandBuffer( | ||
| execute_params.buffer_allocations, execute_params.stream->parent(), | ||
| execute_params.command_buffer_trace_stream, trace, priority())); | ||
|
|
||
| VLOG(5) << "Record traced command into command buffer: " << command_buffer; | ||
| return Handle( | ||
| std::move(record_action), | ||
| [&](absl::Span<const se::CommandBuffer::Command* const> dependencies) { | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Include order —
`gpu_command_buffer.h` is inserted between `device_address.h` and `device_address_handle.h`, breaking alphabetical order within the `xla/stream_executor/` group. It should come after `device_address_handle.h` (or be grouped with the `gpu/` headers at line 106).