Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions xla/backends/gpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ cc_library(
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor:tensor_map",
"//xla/stream_executor:trace_command_buffer_factory",
"//xla/stream_executor/gpu:gpu_command_buffer",
"//xla/stream_executor/gpu:multi_gpu_barrier_kernel",
"//xla/stream_executor/gpu:tma_metadata",
"//xla/tsl/platform:env",
Expand Down Expand Up @@ -600,6 +601,7 @@ cc_library(
"//xla/stream_executor:command_buffer",
"//xla/stream_executor:device_address",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor/gpu:gpu_command_buffer",
"//xla/tsl/platform:env",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:logging",
Expand Down
121 changes: 105 additions & 16 deletions xla/backends/gpu/runtime/command_buffer_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ limitations under the License.
#include "xla/status_macros.h"
#include "xla/stream_executor/command_buffer.h"
#include "xla/stream_executor/device_address.h"
#include "xla/stream_executor/gpu/gpu_command_buffer.h"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Include order — gpu_command_buffer.h is inserted between device_address.h and device_address_handle.h, breaking alphabetical order within the xla/stream_executor/ group. It should come after device_address_handle.h (or be grouped with the gpu/ headers at line 106).

Suggested change
#include "xla/stream_executor/gpu/gpu_command_buffer.h"
#include "xla/stream_executor/device_address_handle.h"
#include "xla/stream_executor/gpu/gpu_command_buffer.h"

#include "xla/stream_executor/device_address_handle.h"
#include "xla/stream_executor/dnn.h"
#include "xla/stream_executor/gpu/multi_gpu_barrier_kernel.h"
Expand Down Expand Up @@ -241,6 +242,7 @@ TracedCommandBuffer::TracedCommandBuffer(const Command* trace_cmd,
absl::flat_hash_set<BufferAllocation::Index> allocs_indices;
for (auto& buffer : buffers) {
allocs_indices.insert(buffer.slice().index());
buffer_slices_.push_back(buffer.slice());
}
allocs_indices_.assign(allocs_indices.begin(), allocs_indices.end());
}
Expand All @@ -249,13 +251,25 @@ absl::StatusOr<se::CommandBuffer*> TracedCommandBuffer::GetOrTraceCommandBuffer(
const BufferAllocations* buffer_allocation, se::StreamExecutor* executor,
se::Stream* stream, absl::FunctionRef<absl::Status(se::Stream*)> trace,
se::StreamPriority priority) {
// Collect memory addresses for relevant allocations.
static const bool profile_trace = [] {
const char* env = std::getenv("XLA_PROFILE_CMD_BUFFER");
return env != nullptr && std::string(env) == "1";
}();

// Collect memory addresses for relevant allocations (for cache comparison).
absl::InlinedVector<se::DeviceAddressBase, 4> allocs;
allocs.reserve(allocs_indices_.size());
for (auto& index : allocs_indices_) {
allocs.emplace_back(buffer_allocation->GetDeviceAddress(index));
}

// Collect slice-level addresses (for kernel node patching).
absl::InlinedVector<se::DeviceAddressBase, 4> slice_addrs;
slice_addrs.reserve(buffer_slices_.size());
for (auto& slice : buffer_slices_) {
slice_addrs.emplace_back(buffer_allocation->GetDeviceAddress(slice));
}

// Moves entry at `i` position to front and moves entries in `[0, i)` range
// one element to the right. Returns reference to the first entry.
auto shift_right = [&](size_t i) -> Entry& {
Expand All @@ -271,41 +285,117 @@ absl::StatusOr<se::CommandBuffer*> TracedCommandBuffer::GetOrTraceCommandBuffer(
return entries_[0] = std::move(entry);
};

static const bool skip_retrace = [] {
const char* env = std::getenv("XLA_GPU_GRAPH_SKIP_RETRACE");
bool val = env != nullptr && std::string(env) == "1";
if (val) {
LOG(INFO) << "XLA_GPU_GRAPH_SKIP_RETRACE enabled: kernel node "
"params will be patched in-place to avoid re-tracing";
}
return val;
}();
Comment on lines +288 to +296
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The skip-retrace feature is gated via raw std::getenv("XLA_GPU_GRAPH_SKIP_RETRACE"), while the flattening feature in the same PR uses a proper DebugOptions proto flag (xla_gpu_graph_enable_node_flattening). Using two different configuration mechanisms within the same feature set is inconsistent and makes the feature harder to discover/manage. Consider using DebugOptions for both, keeping them aligned with XLA's standard configuration pattern.


for (size_t i = 0; i < capacity_; ++i) {
// Found entry for a given allocations, move it to front and return a
// pointer to cached command buffer.
if (ABSL_PREDICT_TRUE(absl::c_equal(entries_[i].recorded_allocs, allocs) &&
entries_[i].command_buffer)) {
VLOG(6) << "Command buffer trace cache hit for command "
<< trace_cmd_->ToString(0);
if (profile_trace) {
LOG(WARNING) << "TraceCache HIT cmd=" << trace_cmd_->ToString(0);
}
return shift_right(i).command_buffer.get();
}

// Create a new entry by calling a user-provided tracing function, move it
// to front and return a pointer to cached command buffer.
// When skip_retrace is enabled, patch kernel nodes in the existing
// cached graph directly instead of re-tracing via stream capture.
if (skip_retrace && entries_[i].command_buffer != nullptr &&
!entries_[i].recorded_slice_addrs.empty()) {
auto* gpu_cmd = dynamic_cast<se::gpu::GpuCommandBuffer*>(
entries_[i].command_buffer.get());
if (gpu_cmd) {
uint64_t patch_start = 0;
if (profile_trace) patch_start = tsl::Env::Default()->NowMicros();

auto status = gpu_cmd->UpdateKernelNodes(
entries_[i].recorded_slice_addrs, slice_addrs);
if (status.ok()) {
entries_[i].recorded_allocs.assign(allocs.begin(), allocs.end());
entries_[i].recorded_slice_addrs.assign(slice_addrs.begin(),
slice_addrs.end());
if (profile_trace) {
uint64_t patch_end = tsl::Env::Default()->NowMicros();
LOG(WARNING) << "TraceCache PATCH cmd=" << trace_cmd_->ToString(0)
<< " time=" << (patch_end - patch_start) << "us";
}
return shift_right(i).command_buffer.get();
}
VLOG(3) << "Kernel node patch failed for " << trace_cmd_->ToString(0)
<< ": " << status << ", falling back to retrace";
}
}

if (entries_[i].command_buffer == nullptr) {
uint64_t trace_start = 0;
if (profile_trace) trace_start = tsl::Env::Default()->NowMicros();

TF_ASSIGN_OR_RETURN(
entries_[i].command_buffer,
se::TraceCommandBufferFactory::Create(executor, stream, trace));
entries_[i].recorded_allocs.assign(allocs.begin(), allocs.end());
entries_[i].recorded_slice_addrs.assign(slice_addrs.begin(),
slice_addrs.end());
if (priority != se::StreamPriority::Default) {
TF_RETURN_IF_ERROR(entries_[i].command_buffer->SetPriority(priority));
}
VLOG(6) << "Command buffer trace cache create new item for command "
<< trace_cmd_->ToString(0);
if (profile_trace) {
uint64_t trace_end = tsl::Env::Default()->NowMicros();
LOG(WARNING) << "TraceCache NEW cmd=" << trace_cmd_->ToString(0)
<< " time=" << (trace_end - trace_start) << "us";
}
return shift_right(i).command_buffer.get();
}
}

// Create a new entry by calling a user-provided tracing function, replace
// the last entry with it, move it to front and return a pointer to cached
// command buffer.
// All slots occupied and no match. Try in-place update of the last slot.
if (skip_retrace && !entries_[capacity_ - 1].recorded_slice_addrs.empty()) {
auto* gpu_cmd = dynamic_cast<se::gpu::GpuCommandBuffer*>(
entries_[capacity_ - 1].command_buffer.get());
if (gpu_cmd) {
uint64_t patch_start = 0;
if (profile_trace) patch_start = tsl::Env::Default()->NowMicros();

auto status = gpu_cmd->UpdateKernelNodes(
entries_[capacity_ - 1].recorded_slice_addrs, slice_addrs);
if (status.ok()) {
entries_[capacity_ - 1].recorded_allocs.assign(allocs.begin(),
allocs.end());
entries_[capacity_ - 1].recorded_slice_addrs.assign(
slice_addrs.begin(), slice_addrs.end());
if (profile_trace) {
uint64_t patch_end = tsl::Env::Default()->NowMicros();
LOG(WARNING) << "TraceCache PATCH(evict) cmd="
<< trace_cmd_->ToString(0)
<< " time=" << (patch_end - patch_start) << "us";
}
return shift_right(capacity_ - 1).command_buffer.get();
}
}
}

uint64_t trace_start = 0;
if (profile_trace) trace_start = tsl::Env::Default()->NowMicros();

TF_ASSIGN_OR_RETURN(
entries_[capacity_ - 1].command_buffer,
se::TraceCommandBufferFactory::Create(executor, stream, trace));
entries_[capacity_ - 1].recorded_allocs.assign(allocs.begin(), allocs.end());
VLOG(6) << "Command buffer trace cache does replacement for command "
<< trace_cmd_->ToString(0);
entries_[capacity_ - 1].recorded_slice_addrs.assign(slice_addrs.begin(),
slice_addrs.end());

if (profile_trace) {
uint64_t trace_end = tsl::Env::Default()->NowMicros();
LOG(WARNING) << "TraceCache RETRACE cmd=" << trace_cmd_->ToString(0)
<< " time=" << (trace_end - trace_start) << "us";
}

return shift_right(capacity_ - 1).command_buffer.get();
}

Expand All @@ -322,6 +412,7 @@ TracedCommandBufferCmd::RecordTracedCommand(
const RecordParams& record_params, RecordAction record_action,
se::CommandBuffer* command_buffer,
absl::FunctionRef<absl::Status(se::Stream*)> trace) {

auto traced_cmd = record_params.state.GetOrCreate<TracedCommandBuffer>(
this, command_buffer, [&] {
const auto& debug_options = xla::GetDebugOptionsFromFlags();
Expand All @@ -335,8 +426,6 @@ TracedCommandBufferCmd::RecordTracedCommand(
traced_cmd->GetOrTraceCommandBuffer(
execute_params.buffer_allocations, execute_params.stream->parent(),
execute_params.command_buffer_trace_stream, trace, priority()));

VLOG(5) << "Record traced command into command buffer: " << command_buffer;
return Handle(
std::move(record_action),
[&](absl::Span<const se::CommandBuffer::Command* const> dependencies) {
Expand Down
6 changes: 6 additions & 0 deletions xla/backends/gpu/runtime/command_buffer_cmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,16 @@ class TracedCommandBuffer : public CommandState {
se::StreamPriority priority = se::StreamPriority::Default);

private:
// Unique allocation indices for cache-level comparison.
std::vector<BufferAllocation::Index> allocs_indices_;

// Full slice descriptors for resolving slice-level addresses during patching.
std::vector<BufferAllocation::Slice> buffer_slices_;

struct Entry {
std::vector<se::DeviceAddressBase> recorded_allocs;
// Slice-level addresses for kernel node patching.
std::vector<se::DeviceAddressBase> recorded_slice_addrs;
std::unique_ptr<se::CommandBuffer> command_buffer;
};
const Command* trace_cmd_;
Expand Down
Loading
Loading