Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xla/backends/gpu/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ cc_library(
"//xla/stream_executor:stream_executor_h",
"//xla/stream_executor:tensor_map",
"//xla/stream_executor:trace_command_buffer_factory",
"//xla/stream_executor/gpu:gpu_command_buffer",
"//xla/stream_executor/gpu:multi_gpu_barrier_kernel",
"//xla/stream_executor/gpu:tma_metadata",
"//xla/tsl/platform:env",
Expand Down
80 changes: 68 additions & 12 deletions xla/backends/gpu/runtime/command_buffer_cmd.cc
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ limitations under the License.
#include "xla/status_macros.h"
#include "xla/stream_executor/command_buffer.h"
#include "xla/stream_executor/device_address.h"
#include "xla/stream_executor/gpu/gpu_command_buffer.h"
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Include order — gpu_command_buffer.h is inserted between device_address.h and device_address_handle.h, breaking alphabetical order within the xla/stream_executor/ group. It should come after device_address_handle.h (or be grouped with the gpu/ headers at line 106).

Suggested change
-#include "xla/stream_executor/gpu/gpu_command_buffer.h"
+#include "xla/stream_executor/device_address_handle.h"
+#include "xla/stream_executor/gpu/gpu_command_buffer.h"

#include "xla/stream_executor/device_address_handle.h"
#include "xla/stream_executor/dnn.h"
#include "xla/stream_executor/gpu/multi_gpu_barrier_kernel.h"
Expand Down Expand Up @@ -241,6 +242,7 @@ TracedCommandBuffer::TracedCommandBuffer(const Command* trace_cmd,
absl::flat_hash_set<BufferAllocation::Index> allocs_indices;
for (auto& buffer : buffers) {
allocs_indices.insert(buffer.slice().index());
buffer_slices_.push_back(buffer.slice());
}
allocs_indices_.assign(allocs_indices.begin(), allocs_indices.end());
}
Expand All @@ -249,13 +251,20 @@ absl::StatusOr<se::CommandBuffer*> TracedCommandBuffer::GetOrTraceCommandBuffer(
const BufferAllocations* buffer_allocation, se::StreamExecutor* executor,
se::Stream* stream, absl::FunctionRef<absl::Status(se::Stream*)> trace,
se::StreamPriority priority) {
// Collect memory addresses for relevant allocations.
// Collect memory addresses for relevant allocations (for cache comparison).
absl::InlinedVector<se::DeviceAddressBase, 4> allocs;
allocs.reserve(allocs_indices_.size());
for (auto& index : allocs_indices_) {
allocs.emplace_back(buffer_allocation->GetDeviceAddress(index));
}

// Collect slice-level addresses (for kernel node patching).
absl::InlinedVector<se::DeviceAddressBase, 4> slice_addrs;
slice_addrs.reserve(buffer_slices_.size());
for (auto& slice : buffer_slices_) {
slice_addrs.emplace_back(buffer_allocation->GetDeviceAddress(slice));
}

// Moves entry at `i` position to front and moves entries in `[0, i)` range
// one element to the right. Returns reference to the first entry.
auto shift_right = [&](size_t i) -> Entry& {
Expand All @@ -271,23 +280,53 @@ absl::StatusOr<se::CommandBuffer*> TracedCommandBuffer::GetOrTraceCommandBuffer(
return entries_[0] = std::move(entry);
};

static const bool skip_retrace = [] {
const char* env = std::getenv("XLA_GPU_GRAPH_SKIP_RETRACE");
bool val = env != nullptr && std::string(env) == "1";
if (val) {
LOG(INFO) << "XLA_GPU_GRAPH_SKIP_RETRACE enabled: kernel node "
"params will be patched in-place to avoid re-tracing";
}
return val;
}();
Comment on lines +288 to +296
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The skip-retrace feature is gated via raw std::getenv("XLA_GPU_GRAPH_SKIP_RETRACE"), while the flattening feature in the same PR uses a proper DebugOptions proto flag (xla_gpu_graph_enable_node_flattening). Using two different configuration mechanisms within the same feature set is inconsistent and makes the feature harder to discover/manage. Consider using DebugOptions for both, keeping them aligned with XLA's standard configuration pattern.


for (size_t i = 0; i < capacity_; ++i) {
// Found entry for a given allocations, move it to front and return a
// pointer to cached command buffer.
if (ABSL_PREDICT_TRUE(absl::c_equal(entries_[i].recorded_allocs, allocs) &&
entries_[i].command_buffer)) {
VLOG(6) << "Command buffer trace cache hit for command "
<< trace_cmd_->ToString();
return shift_right(i).command_buffer.get();
}

// Create a new entry by calling a user-provided tracing function, move it
// to front and return a pointer to cached command buffer.
// When skip_retrace is enabled, patch kernel nodes in the existing
// cached graph directly instead of re-tracing via stream capture.
if (skip_retrace && entries_[i].command_buffer != nullptr &&
!entries_[i].recorded_slice_addrs.empty()) {
auto* gpu_cmd = dynamic_cast<se::gpu::GpuCommandBuffer*>(
entries_[i].command_buffer.get());
if (gpu_cmd) {
auto status = gpu_cmd->UpdateKernelNodes(
entries_[i].recorded_slice_addrs, slice_addrs);
if (status.ok()) {
entries_[i].recorded_allocs.assign(allocs.begin(), allocs.end());
entries_[i].recorded_slice_addrs.assign(slice_addrs.begin(),
slice_addrs.end());
VLOG(6) << "Patched kernel nodes in-place for command "
<< trace_cmd_->ToString();
return shift_right(i).command_buffer.get();
}
VLOG(3) << "Kernel node patch failed for " << trace_cmd_->ToString()
<< ": " << status << ", falling back to retrace";
}
}

if (entries_[i].command_buffer == nullptr) {
TF_ASSIGN_OR_RETURN(
entries_[i].command_buffer,
se::TraceCommandBufferFactory::Create(executor, stream, trace));
entries_[i].recorded_allocs.assign(allocs.begin(), allocs.end());
entries_[i].recorded_slice_addrs.assign(slice_addrs.begin(),
slice_addrs.end());
if (priority != se::StreamPriority::Default) {
TF_RETURN_IF_ERROR(entries_[i].command_buffer->SetPriority(priority));
}
Expand All @@ -297,15 +336,33 @@ absl::StatusOr<se::CommandBuffer*> TracedCommandBuffer::GetOrTraceCommandBuffer(
}
}

// Create a new entry by calling a user-provided tracing function, replace
// the last entry with it, move it to front and return a pointer to cached
// command buffer.
// All slots occupied and no match. Try in-place update of the last slot.
if (skip_retrace && !entries_[capacity_ - 1].recorded_slice_addrs.empty()) {
auto* gpu_cmd = dynamic_cast<se::gpu::GpuCommandBuffer*>(
entries_[capacity_ - 1].command_buffer.get());
if (gpu_cmd) {
auto status = gpu_cmd->UpdateKernelNodes(
entries_[capacity_ - 1].recorded_slice_addrs, slice_addrs);
if (status.ok()) {
entries_[capacity_ - 1].recorded_allocs.assign(allocs.begin(),
allocs.end());
entries_[capacity_ - 1].recorded_slice_addrs.assign(
slice_addrs.begin(), slice_addrs.end());
VLOG(6) << "Patched kernel nodes (evict slot) for command "
<< trace_cmd_->ToString();
return shift_right(capacity_ - 1).command_buffer.get();
}
}
}

VLOG(6) << "Command buffer trace cache does replacement for command "
<< trace_cmd_->ToString();
TF_ASSIGN_OR_RETURN(
entries_[capacity_ - 1].command_buffer,
se::TraceCommandBufferFactory::Create(executor, stream, trace));
entries_[capacity_ - 1].recorded_allocs.assign(allocs.begin(), allocs.end());
VLOG(6) << "Command buffer trace cache does replacement for command "
<< trace_cmd_->ToString();
entries_[capacity_ - 1].recorded_slice_addrs.assign(slice_addrs.begin(),
slice_addrs.end());
return shift_right(capacity_ - 1).command_buffer.get();
}

Expand All @@ -322,6 +379,7 @@ TracedCommandBufferCmd::RecordTracedCommand(
const RecordParams& record_params, RecordAction record_action,
se::CommandBuffer* command_buffer,
absl::FunctionRef<absl::Status(se::Stream*)> trace) {

auto traced_cmd = record_params.state.GetOrCreate<TracedCommandBuffer>(
this, command_buffer, [&] {
const auto& debug_options = xla::GetDebugOptionsFromFlags();
Expand All @@ -335,8 +393,6 @@ TracedCommandBufferCmd::RecordTracedCommand(
traced_cmd->GetOrTraceCommandBuffer(
execute_params.buffer_allocations, execute_params.stream->parent(),
execute_params.command_buffer_trace_stream, trace, priority()));

VLOG(5) << "Record traced command into command buffer: " << command_buffer;
return Handle(
std::move(record_action),
[&](absl::Span<const se::CommandBuffer::Command* const> dependencies) {
Expand Down
6 changes: 6 additions & 0 deletions xla/backends/gpu/runtime/command_buffer_cmd.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,16 @@ class TracedCommandBuffer : public CommandState {
se::StreamPriority priority = se::StreamPriority::Default);

private:
// Unique allocation indices for cache-level comparison.
std::vector<BufferAllocation::Index> allocs_indices_;

// Full slice descriptors for resolving slice-level addresses during patching.
std::vector<BufferAllocation::Slice> buffer_slices_;

struct Entry {
std::vector<se::DeviceAddressBase> recorded_allocs;
// Slice-level addresses for kernel node patching.
std::vector<se::DeviceAddressBase> recorded_slice_addrs;
std::unique_ptr<se::CommandBuffer> command_buffer;
};
const Command* trace_cmd_;
Expand Down
8 changes: 8 additions & 0 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3000,6 +3000,14 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
bool_setter_for(&DebugOptions::set_xla_dump_buffer_assignment_analysis),
debug_options->xla_dump_buffer_assignment_analysis(),
"Dump BufferAssignment analysis."));
flag_list->push_back(tsl::Flag(
"xla_gpu_graph_enable_node_flattening",
bool_setter_for(
&DebugOptions::set_xla_gpu_graph_enable_node_flattening),
debug_options->xla_gpu_graph_enable_node_flattening(),
"Flatten traced child graph nodes into the parent HIP graph as "
"individual nodes, enabling per-node address updates instead of "
"full child graph re-tracing."));
} // NOLINT(readability/fn_size)

// Allocates flag_values and flag_objects; this function must not be called more
Expand Down
4 changes: 4 additions & 0 deletions xla/stream_executor/gpu/gpu_command_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,10 @@ GpuCommandBuffer::ToGraphNodeDependencies(
handles.push_back(gpu_command->conditional_node.handle);
} else if (auto* gpu_command = dynamic_cast<const GpuChildCommand*>(dep)) {
handles.push_back(gpu_command->handle);
} else if (auto* flat = dynamic_cast<const GpuFlattenedCommand*>(dep)) {
if (!flat->node_handles.empty()) {
handles.push_back(flat->node_handles.back());
}
} else {
LOG(FATAL) << "Unsupported command type"; // Crash OK
}
Expand Down
119 changes: 110 additions & 9 deletions xla/stream_executor/gpu/gpu_command_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,56 @@ class GpuCommandBuffer : public CommandBuffer {
std::unique_ptr<CommandBuffer> command_buffer;
};

// A command representing a set of nodes that were extracted ("flattened")
// from a traced child graph into the parent graph. Stores all flattened
// node handles so they can be updated individually.
enum class FlatNodeKind : uint8_t {
kKernel,
kKernelExtra,
kMemcpy,
kMemset,
kEmpty,
};

struct FlatNodeInfo {
FlatNodeKind kind;
void* func = nullptr; // kernel function pointer (for kKernel/kKernelExtra)
};

// Records where a known device pointer appears inside a kernel node's
// packed argument buffer (the `extra` HIP_LAUNCH_PARAM blob).
struct ArgPatchEntry {
size_t node_index; // index into node_handles/node_infos
size_t byte_offset; // offset within the packed arg buffer
int buffer_use_index; // index into the buffer_uses() vector
};

struct GpuFlattenedCommand : public CommandBuffer::Command {
std::vector<GraphNodeHandle> node_handles;
std::vector<FlatNodeInfo> node_infos;

// Deep-copied extra arg buffers for kernel nodes that use the
// HIP_LAUNCH_PARAM packed-buffer launch convention (e.g. rocBLAS).
// Kept alive for the lifetime of the parent graph.
std::vector<std::unique_ptr<uint8_t[]>> extra_arg_buffers;
std::vector<std::unique_ptr<size_t>> extra_arg_sizes;

// Patch table for skip-retrace updates: records where each known
// buffer address appears inside the packed args of flattened
// extra-style kernel nodes.
std::vector<ArgPatchEntry> patch_table;

// Per-node deep-copied arg buffer and its size (only for extra-style
// kernel nodes that have patch entries).
struct NodeArgBuffer {
std::unique_ptr<uint8_t[]> data;
Comment on lines +148 to +162
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

extra_arg_buffers (line 155) appears to be unused dead code — it is declared but never written to anywhere in this PR. Only node_arg_buffers (line 164) is populated in FlattenChildGraphNodes. Consider removing extra_arg_buffers if it's leftover from an earlier design iteration. (The adjacent extra_arg_sizes vector looks similarly unreferenced in this diff — worth checking whether it can be removed too.)

size_t size = 0;
};
std::vector<NodeArgBuffer> node_arg_buffers;

bool has_patch_table = false;
};

GpuCommandBuffer(Mode mode, StreamExecutor* executor);

// Bring CreateLaunch and UpdateLaunch template functions into scope.
Expand Down Expand Up @@ -207,6 +257,57 @@ class GpuCommandBuffer : public CommandBuffer {

absl::Span<const std::unique_ptr<Command>> commands() const;

// Dumps kernel node info from this command buffer's graph for debugging.
virtual void DumpGraphKernelNodes(absl::string_view label) {}

// Updates kernel nodes inside this command buffer's graph in place,
// replacing old_addresses with new_addresses at the corresponding positions.
// This avoids re-tracing the graph when only buffer addresses change.
virtual absl::Status UpdateKernelNodes(
absl::Span<const DeviceAddressBase> old_addresses,
absl::Span<const DeviceAddressBase> new_addresses) {
return absl::UnimplementedError(
"UpdateKernelNodes not supported on this platform");
}

// Extracts nodes from a nested (traced) command buffer and adds them
// directly to this parent graph as individual kernel/memcpy/memset nodes,
// avoiding child graph re-tracing. Returns a single GpuFlattenedCommand
// that owns all the flattened node handles.
virtual absl::StatusOr<const Command*> FlattenChildGraphNodes(
const CommandBuffer& nested,
absl::Span<const Command* const> dependencies) {
return absl::UnimplementedError(
"FlattenChildGraphNodes not supported on this platform");
}

// Updates all kernel nodes inside a previously flattened command by scanning
// their arguments and replacing old device pointers with new ones.
virtual absl::Status UpdateFlattenedChildNodes(
const Command* command, const CommandBuffer& nested) {
return absl::UnimplementedError(
"UpdateFlattenedChildNodes not supported on this platform");
}

// Builds a patch table on a flattened command: scans each extra-style
// kernel node's packed argument buffer to find byte offsets of known
// buffer addresses. Must be called once after FlattenChildGraphNodes.
virtual absl::Status BuildPatchTable(
const Command* command,
absl::Span<const DeviceAddressBase> known_addresses) {
return absl::UnimplementedError(
"BuildPatchTable not supported on this platform");
}

// Directly patches buffer addresses in flattened nodes using the
// previously built patch table. Skips re-tracing entirely.
virtual absl::Status PatchFlattenedNodes(
const Command* command,
absl::Span<const DeviceAddressBase> new_addresses) {
return absl::UnimplementedError(
"PatchFlattenedNodes not supported on this platform");
}

protected:
// We track the total number of allocated and alive executable graphs in the
// process to track the command buffers resource usage. Executable graph
Expand Down Expand Up @@ -280,6 +381,15 @@ class GpuCommandBuffer : public CommandBuffer {
// Returns OK status if the command buffer can be updated.
virtual absl::Status CheckCanBeUpdated() = 0;

// Appends a new command to the command buffer.
template <typename T>
const Command* AppendCommand(T command) {
commands_.push_back(std::make_unique<T>(std::move(command)));
VLOG(5) << "AppendCommand: "
<< reinterpret_cast<const void*>(commands_.back().get());
return commands_.back().get();
}

private:
absl::StatusOr<const Command*> CreateCase(
DeviceAddress<uint8_t> index, bool index_is_bool,
Expand All @@ -290,15 +400,6 @@ class GpuCommandBuffer : public CommandBuffer {
bool index_is_bool,
std::vector<UpdateCommands> update_branches);

// Appends a new command to the command buffer.
template <typename T>
const Command* AppendCommand(T command) {
commands_.push_back(std::make_unique<T>(std::move(command)));
VLOG(5) << "AppendCommand: "
<< reinterpret_cast<const void*>(commands_.back().get());
return commands_.back().get();
}

// Converts a list of command dependencies to a list of graph node handles.
std::vector<GraphNodeHandle> ToGraphNodeDependencies(
absl::Span<const Command* const> dependencies);
Expand Down
1 change: 1 addition & 0 deletions xla/stream_executor/rocm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1063,6 +1063,7 @@ cc_library(
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/functional:any_invocable",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
Expand Down
Loading
Loading