Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions xla/backends/gpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,7 @@ cc_library(
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor:tensor_map",
"//xla/stream_executor:trace_command_buffer_factory",
"//xla/stream_executor/gpu:gpu_command_buffer",
"//xla/stream_executor/gpu:multi_gpu_barrier_kernel",
"//xla/stream_executor/gpu:tma_metadata",
"//xla/tsl/platform:env",
Expand Down Expand Up @@ -600,6 +601,7 @@ cc_library(
"//xla/stream_executor:command_buffer",
"//xla/stream_executor:device_address",
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor/gpu:gpu_command_buffer",
"//xla/tsl/platform:env",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:logging",
Expand Down
121 changes: 105 additions & 16 deletions xla/backends/gpu/runtime/command_buffer_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ limitations under the License.
#include "xla/status_macros.h"
#include "xla/stream_executor/command_buffer.h"
#include "xla/stream_executor/device_address.h"
#include "xla/stream_executor/gpu/gpu_command_buffer.h"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Include order — gpu_command_buffer.h is inserted between device_address.h and device_address_handle.h, breaking alphabetical order within the xla/stream_executor/ group. It should come after device_address_handle.h (or be grouped with the gpu/ headers at line 106).

Suggested change
-#include "xla/stream_executor/gpu/gpu_command_buffer.h"
-#include "xla/stream_executor/device_address_handle.h"
+#include "xla/stream_executor/device_address_handle.h"
+#include "xla/stream_executor/gpu/gpu_command_buffer.h"

#include "xla/stream_executor/device_address_handle.h"
#include "xla/stream_executor/dnn.h"
#include "xla/stream_executor/gpu/multi_gpu_barrier_kernel.h"
Expand Down Expand Up @@ -241,6 +242,7 @@ TracedCommandBuffer::TracedCommandBuffer(const Command* trace_cmd,
absl::flat_hash_set<BufferAllocation::Index> allocs_indices;
for (auto& buffer : buffers) {
allocs_indices.insert(buffer.slice().index());
buffer_slices_.push_back(buffer.slice());
}
allocs_indices_.assign(allocs_indices.begin(), allocs_indices.end());
}
Expand All @@ -249,13 +251,25 @@ absl::StatusOr<se::CommandBuffer*> TracedCommandBuffer::GetOrTraceCommandBuffer(
const BufferAllocations* buffer_allocation, se::StreamExecutor* executor,
se::Stream* stream, absl::FunctionRef<absl::Status(se::Stream*)> trace,
se::StreamPriority priority) {
// Collect memory addresses for relevant allocations.
static const bool profile_trace = [] {
const char* env = std::getenv("XLA_PROFILE_CMD_BUFFER");
return env != nullptr && std::string(env) == "1";
}();

// Collect memory addresses for relevant allocations (for cache comparison).
absl::InlinedVector<se::DeviceAddressBase, 4> allocs;
allocs.reserve(allocs_indices_.size());
for (auto& index : allocs_indices_) {
allocs.emplace_back(buffer_allocation->GetDeviceAddress(index));
}

// Collect slice-level addresses (for kernel node patching).
absl::InlinedVector<se::DeviceAddressBase, 4> slice_addrs;
slice_addrs.reserve(buffer_slices_.size());
for (auto& slice : buffer_slices_) {
slice_addrs.emplace_back(buffer_allocation->GetDeviceAddress(slice));
}

// Moves entry at `i` position to front and moves entries in `[0, i)` range
// one element to the right. Returns reference to the first entry.
auto shift_right = [&](size_t i) -> Entry& {
Expand All @@ -271,41 +285,117 @@ absl::StatusOr<se::CommandBuffer*> TracedCommandBuffer::GetOrTraceCommandBuffer(
return entries_[0] = std::move(entry);
};

static const bool skip_retrace = [] {
const char* env = std::getenv("XLA_GPU_GRAPH_SKIP_RETRACE");
bool val = env != nullptr && std::string(env) == "1";
if (val) {
LOG(INFO) << "XLA_GPU_GRAPH_SKIP_RETRACE enabled: kernel node "
"params will be patched in-place to avoid re-tracing";
}
return val;
}();
Comment on lines +288 to +296
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The skip-retrace feature is gated via raw std::getenv("XLA_GPU_GRAPH_SKIP_RETRACE"), while the flattening feature in the same PR uses a proper DebugOptions proto flag (xla_gpu_graph_enable_node_flattening). Using two different configuration mechanisms within the same feature set is inconsistent and makes the feature harder to discover/manage. Consider using DebugOptions for both, keeping them aligned with XLA's standard configuration pattern.


for (size_t i = 0; i < capacity_; ++i) {
// Found entry for a given allocations, move it to front and return a
// pointer to cached command buffer.
if (ABSL_PREDICT_TRUE(absl::c_equal(entries_[i].recorded_allocs, allocs) &&
entries_[i].command_buffer)) {
VLOG(6) << "Command buffer trace cache hit for command "
<< trace_cmd_->ToString(0);
if (profile_trace) {
LOG(WARNING) << "TraceCache HIT cmd=" << trace_cmd_->ToString(0);
}
return shift_right(i).command_buffer.get();
}

// Create a new entry by calling a user-provided tracing function, move it
// to front and return a pointer to cached command buffer.
// When skip_retrace is enabled, patch kernel nodes in the existing
// cached graph directly instead of re-tracing via stream capture.
if (skip_retrace && entries_[i].command_buffer != nullptr &&
!entries_[i].recorded_slice_addrs.empty()) {
auto* gpu_cmd = dynamic_cast<se::gpu::GpuCommandBuffer*>(
entries_[i].command_buffer.get());
if (gpu_cmd) {
uint64_t patch_start = 0;
if (profile_trace) patch_start = tsl::Env::Default()->NowMicros();

auto status = gpu_cmd->UpdateKernelNodes(
entries_[i].recorded_slice_addrs, slice_addrs);
if (status.ok()) {
entries_[i].recorded_allocs.assign(allocs.begin(), allocs.end());
entries_[i].recorded_slice_addrs.assign(slice_addrs.begin(),
slice_addrs.end());
if (profile_trace) {
uint64_t patch_end = tsl::Env::Default()->NowMicros();
LOG(WARNING) << "TraceCache PATCH cmd=" << trace_cmd_->ToString(0)
<< " time=" << (patch_end - patch_start) << "us";
}
return shift_right(i).command_buffer.get();
}
VLOG(3) << "Kernel node patch failed for " << trace_cmd_->ToString(0)
<< ": " << status << ", falling back to retrace";
}
}

if (entries_[i].command_buffer == nullptr) {
uint64_t trace_start = 0;
if (profile_trace) trace_start = tsl::Env::Default()->NowMicros();

TF_ASSIGN_OR_RETURN(
entries_[i].command_buffer,
se::TraceCommandBufferFactory::Create(executor, stream, trace));
entries_[i].recorded_allocs.assign(allocs.begin(), allocs.end());
entries_[i].recorded_slice_addrs.assign(slice_addrs.begin(),
slice_addrs.end());
if (priority != se::StreamPriority::Default) {
TF_RETURN_IF_ERROR(entries_[i].command_buffer->SetPriority(priority));
}
VLOG(6) << "Command buffer trace cache create new item for command "
<< trace_cmd_->ToString(0);
if (profile_trace) {
uint64_t trace_end = tsl::Env::Default()->NowMicros();
LOG(WARNING) << "TraceCache NEW cmd=" << trace_cmd_->ToString(0)
<< " time=" << (trace_end - trace_start) << "us";
}
return shift_right(i).command_buffer.get();
}
}

// Create a new entry by calling a user-provided tracing function, replace
// the last entry with it, move it to front and return a pointer to cached
// command buffer.
// All slots occupied and no match. Try in-place update of the last slot.
if (skip_retrace && !entries_[capacity_ - 1].recorded_slice_addrs.empty()) {
auto* gpu_cmd = dynamic_cast<se::gpu::GpuCommandBuffer*>(
entries_[capacity_ - 1].command_buffer.get());
if (gpu_cmd) {
uint64_t patch_start = 0;
if (profile_trace) patch_start = tsl::Env::Default()->NowMicros();

auto status = gpu_cmd->UpdateKernelNodes(
entries_[capacity_ - 1].recorded_slice_addrs, slice_addrs);
if (status.ok()) {
entries_[capacity_ - 1].recorded_allocs.assign(allocs.begin(),
allocs.end());
entries_[capacity_ - 1].recorded_slice_addrs.assign(
slice_addrs.begin(), slice_addrs.end());
if (profile_trace) {
uint64_t patch_end = tsl::Env::Default()->NowMicros();
LOG(WARNING) << "TraceCache PATCH(evict) cmd="
<< trace_cmd_->ToString(0)
<< " time=" << (patch_end - patch_start) << "us";
}
return shift_right(capacity_ - 1).command_buffer.get();
}
}
}

uint64_t trace_start = 0;
if (profile_trace) trace_start = tsl::Env::Default()->NowMicros();

TF_ASSIGN_OR_RETURN(
entries_[capacity_ - 1].command_buffer,
se::TraceCommandBufferFactory::Create(executor, stream, trace));
entries_[capacity_ - 1].recorded_allocs.assign(allocs.begin(), allocs.end());
VLOG(6) << "Command buffer trace cache does replacement for command "
<< trace_cmd_->ToString(0);
entries_[capacity_ - 1].recorded_slice_addrs.assign(slice_addrs.begin(),
slice_addrs.end());

if (profile_trace) {
uint64_t trace_end = tsl::Env::Default()->NowMicros();
LOG(WARNING) << "TraceCache RETRACE cmd=" << trace_cmd_->ToString(0)
<< " time=" << (trace_end - trace_start) << "us";
}

return shift_right(capacity_ - 1).command_buffer.get();
}

Expand All @@ -322,6 +412,7 @@ TracedCommandBufferCmd::RecordTracedCommand(
const RecordParams& record_params, RecordAction record_action,
se::CommandBuffer* command_buffer,
absl::FunctionRef<absl::Status(se::Stream*)> trace) {

auto traced_cmd = record_params.state.GetOrCreate<TracedCommandBuffer>(
this, command_buffer, [&] {
const auto& debug_options = xla::GetDebugOptionsFromFlags();
Expand All @@ -335,8 +426,6 @@ TracedCommandBufferCmd::RecordTracedCommand(
traced_cmd->GetOrTraceCommandBuffer(
execute_params.buffer_allocations, execute_params.stream->parent(),
execute_params.command_buffer_trace_stream, trace, priority()));

VLOG(5) << "Record traced command into command buffer: " << command_buffer;
return Handle(
std::move(record_action),
[&](absl::Span<const se::CommandBuffer::Command* const> dependencies) {
Expand Down
6 changes: 6 additions & 0 deletions xla/backends/gpu/runtime/command_buffer_cmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,16 @@ class TracedCommandBuffer : public CommandState {
se::StreamPriority priority = se::StreamPriority::Default);

private:
// Unique allocation indices for cache-level comparison.
std::vector<BufferAllocation::Index> allocs_indices_;

// Full slice descriptors for resolving slice-level addresses during patching.
std::vector<BufferAllocation::Slice> buffer_slices_;

struct Entry {
std::vector<se::DeviceAddressBase> recorded_allocs;
// Slice-level addresses for kernel node patching.
std::vector<se::DeviceAddressBase> recorded_slice_addrs;
std::unique_ptr<se::CommandBuffer> command_buffer;
};
const Command* trace_cmd_;
Expand Down
Loading
Loading