Skip command buffer re-tracing via in-place HIP graph node patching#785
phambinhfin wants to merge 9 commits into main
Conversation
Remove two overly-conservative checks from rocm_command_buffer.cc (ported from the phambinh/full_vmm_solution branch):

1. Trace(): remove the rejection of empty traced HIP graphs, which caused segfaults when custom calls don't launch GPU ops.
2. PrepareFinalization(): remove the insertion of an empty node into empty graphs, which could crash with conditional nodes.

Baseline for command buffer performance profiling on the main branch.
Flatten traced child graph nodes into the parent HIP graph as individual kernel/memcpy/memset nodes instead of embedding them as child graph nodes. This enables per-node parameter updates via hipGraphExecKernelNodeSetParams (~1 us/node) instead of full child graph re-tracing (~60-70 us per sub-graph), providing up to 14.6x faster update throughput.

Key changes:
- Add FlattenChildGraphNodes/UpdateFlattenedChildNodes virtual methods to GpuCommandBuffer with a ROCm implementation
- Add GpuFlattenedCommand type to track flattened node handles
- Modify RecordTracedCommand to use the flattening path when the xla_gpu_graph_enable_node_flattening flag is set
- Add HIP graph introspection wrappers (hipGraphGetEdges, hipGraphNodeGetDependencies, hipGraphKernelNodeGetParams, etc.)
- Move AppendCommand to protected for subclass access
- Add the xla_gpu_graph_enable_node_flattening proto flag (field 502)

Activate via: --xla_gpu_graph_enable_node_flattening=true
- Add GpuFlattenedCommand handling in ToGraphNodeDependencies to prevent a crash when flattened commands are used as dependencies
- Add a graceful fallback to the child graph path when kernel nodes use HIP's opaque `extra` arg-packing (hipGraphKernelNodeGetParams returns kernelParams=null, extra=non-null), which can't be re-packed for a new graph node
- Register the xla_gpu_graph_enable_node_flattening flag in debug_options_flags.cc
- Add the XLA_GPU_GRAPH_ENABLE_NODE_FLATTENING env var as an alternative activation mechanism when jaxlib doesn't have the proto flag

Known limitation: ROCm/HIP stores all kernel args internally via the opaque `extra` mechanism. hipGraphKernelNodeGetParams does not populate `kernelParams`, so individual kernel node flattening cannot re-pack args for a new parent graph. All traced commands currently fall back to the child graph path on ROCm.
…al format

hipGraphKernelNodeGetParams on ROCm 7.2 returns kernel arguments via the internal `extra` pointer (not as HIP_LAUNCH_PARAM_* arrays and not as kernelParams) for kernels captured via hipModuleLaunchKernel. This internal format cannot be used with hipGraphAddKernelNode to create equivalent nodes in a different graph context.

Changes:
- Detect kernels with opaque extra args (kernelParams=null, extra=non-null) and return InternalError to trigger a graceful fallback to the standard child-graph path
- Improve the fallback in RecordTracedCommand to catch any error status (not just kUnimplemented) and fall back to CreateChildCommand
- Add VLOG diagnostics for kernel node parameters during flattening

C++ benchmarks confirm the HIP graph node update APIs work correctly when kernelParams is properly populated (14.6x faster than re-tracing). The limitation is specifically in how hipGraphKernelNodeGetParams returns arguments for module-loaded kernels on ROCm 7.2.
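The bail-out condition this commit describes is simple to state. The sketch below models it with a stand-in struct (the real HIP type is hipKernelNodeParams; HasOpaqueExtraArgs is an illustrative name, not the actual code):

```cpp
#include <cstddef>

// Stand-in for the relevant fields of hipKernelNodeParams.
struct KernelNodeParams {
  void** kernelParams = nullptr;  // standard per-argument pointer array
  void** extra = nullptr;         // opaque packed-launch convention
};

// Returns true when the captured args are only available in HIP's internal
// `extra` format, meaning the node cannot be re-packed for a new graph and
// the caller should fall back to the child-graph path.
bool HasOpaqueExtraArgs(const KernelNodeParams& kp) {
  return kp.kernelParams == nullptr && kp.extra != nullptr;
}
```

In the actual change this predicate maps to returning InternalError so RecordTracedCommand falls through to CreateChildCommand.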
Status: Work in progress - functional and correct, performance optimization ongoing
== Problem ==
XLA's TracedCommandBuffer re-traces HIP graphs (via hipStreamBeginCapture/
hipStreamEndCapture) every time buffer addresses change due to BFC allocator
reassignment. This causes severe performance degradation, especially when
command buffers are enabled for GEMM/CublasLt operations.
== Approach ==
Instead of re-tracing, patch kernel nodes directly in the cached child graph
using hipGraphKernelNodeSetParams. This avoids the expensive stream capture
entirely when only buffer addresses change.
== How it works ==
1. TracedCommandBuffer now tracks slice-level addresses (not just allocation
indices) per cache entry
2. On cache miss, instead of re-tracing via TraceCommandBufferFactory::Create,
we call UpdateKernelNodes() on the existing cached graph
3. UpdateKernelNodes scans each kernel node's arguments:
- For 'extra' packed buffers (rocBLAS/hipBLASLt): decode HIP_LAUNCH_PARAM
buffer, scan at pointer-aligned offsets for old addresses, replace with new
- For kernelParams arrays: scan pointer values directly
4. hipGraphKernelNodeSetParams commits changes to the graph definition
5. The existing UpdateChildCommand path propagates to the executable graph
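The scan in step 3 can be sketched as follows. This is a minimal illustration, assuming the arguments live in a contiguous HIP_LAUNCH_PARAM-style packed buffer; PatchPackedArgs and the remap table are hypothetical names, not the actual XLA implementation:

```cpp
#include <cstdint>
#include <cstring>
#include <unordered_map>

// Scans a packed argument buffer at pointer-aligned offsets and rewrites any
// slot whose value matches a tracked old device address. Returns the number
// of slots rewritten. (Illustrative sketch only.)
int PatchPackedArgs(uint8_t* buf, size_t buf_size,
                    const std::unordered_map<uintptr_t, uintptr_t>& remap) {
  int patched = 0;
  for (size_t off = 0; off + sizeof(uintptr_t) <= buf_size;
       off += sizeof(uintptr_t)) {
    uintptr_t val;
    std::memcpy(&val, buf + off, sizeof(val));  // alignment-safe read
    auto it = remap.find(val);
    if (it != remap.end()) {
      std::memcpy(buf + off, &it->second, sizeof(uintptr_t));
      ++patched;
    }
  }
  return patched;
}
```

After patching the definition graph this way, step 4 commits the new params with hipGraphKernelNodeSetParams.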
== Enable ==
export XLA_GPU_GRAPH_SKIP_RETRACE=1
== Results (Llama FSDP 8-layer, 8x MI308X) ==
Without command buffers (CB off): 41.6 ms/step
With CB (no collectives), baseline: 97.4 ms/step (after warmup)
With CB (no collectives), skip-retrace: 97.2 ms/step (after warmup)
-> During warmup: 896ms baseline vs 864ms skip-retrace (3.5% improvement)
-> 728 re-traces reduced to initial creates only; 1276 in-place patches avoided re-traces
With CB + collectives, baseline: 53,862 ms/step
With CB + collectives, skip-retrace: 52,810 ms/step (~2% improvement)
-> Collectives mode now runs correctly (previously problematic)
-> Remaining cost is from collective command re-tracing (not GEMMs)
Correctness: verified element-wise gradient match across all configurations.
Loss values stable and matching baseline throughout.
== Files changed ==
- command_buffer_cmd.cc/h: TracedCommandBuffer tracks buffer_slices_ and
recorded_slice_addrs; GetOrTraceCommandBuffer attempts UpdateKernelNodes
before falling back to retrace
- gpu_command_buffer.h: Added UpdateKernelNodes virtual method and
DumpGraphKernelNodes for debugging
- rocm_command_buffer.cc/h: Implemented UpdateKernelNodes (scans extra/
kernelParams, patches addresses, calls hipGraphKernelNodeSetParams)
- rocm_driver_wrapper.h: Added hipGraphKernelNodeSetParams wrapper
== Next steps ==
- Extend UpdateKernelNodes approach to collective command types to reduce
the 53s/step cost with collectives enabled
- Investigate VA remapping as alternative to avoid address changes entirely
- Profile to identify which specific retrace operations dominate with collectives
#include "xla/status_macros.h"
#include "xla/stream_executor/command_buffer.h"
#include "xla/stream_executor/device_address.h"
#include "xla/stream_executor/gpu/gpu_command_buffer.h"
nit: Include order — gpu_command_buffer.h is inserted between device_address.h and device_address_handle.h, breaking alphabetical order within the xla/stream_executor/ group. It should come after device_address_handle.h (or be grouped with the gpu/ headers at line 106).
Suggested change:
#include "xla/stream_executor/device_address_handle.h"
#include "xla/stream_executor/gpu/gpu_command_buffer.h"
static const bool skip_retrace = [] {
  const char* env = std::getenv("XLA_GPU_GRAPH_SKIP_RETRACE");
  bool val = env != nullptr && std::string(env) == "1";
  if (val) {
    LOG(INFO) << "XLA_GPU_GRAPH_SKIP_RETRACE enabled: kernel node "
                 "params will be patched in-place to avoid re-tracing";
  }
  return val;
}();
The skip-retrace feature is gated via raw std::getenv("XLA_GPU_GRAPH_SKIP_RETRACE"), while the flattening feature in the same PR uses a proper DebugOptions proto flag (xla_gpu_graph_enable_node_flattening). Using two different configuration mechanisms within the same feature set is inconsistent and makes the feature harder to discover/manage. Consider using DebugOptions for both, keeping them aligned with XLA's standard configuration pattern.
  }
} else if (kp.kernelParams != nullptr) {
  for (int a = 0; a < 64 && kp.kernelParams[a] != nullptr; ++a) {
    uintptr_t val;
    memcpy(&val, kp.kernelParams[a], sizeof(uintptr_t));
    auto it = old_addr_map.find(val);
    if (it != old_addr_map.end()) {
      uintptr_t new_val = reinterpret_cast<uintptr_t>(
          new_addresses[it->second].opaque());
      memcpy(kp.kernelParams[a], &new_val, sizeof(uintptr_t));
      modified = true;
    }
  }
  if (modified) {
    TF_RETURN_IF_ERROR(ToStatus(
Correctness risk: The kernelParams patching reads sizeof(uintptr_t) bytes from each kernelParams[a], which works for pointer arguments but will misinterpret scalar arguments (int, float, etc.) as addresses. If a scalar value happens to have the same bit pattern as an old buffer address, it will be silently overwritten with a new address, causing data corruption.
The same risk exists in the extra/packed-buffer code path above (scanning at pointer-aligned offsets and matching raw values), though it's somewhat mitigated there by scanning a known packed buffer.
Additionally, the hardcoded limit of 64 here is inconsistent with the limit of 16 in DumpGraphKernelNodes at line 399 — a kernel with 17–63 params would be correctly patched but incompletely dumped, making debugging misleading. Consider using a named constant for both.
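The false-positive hazard can be made concrete with a toy scan. Everything here is illustrative (RawScanPatch is not the actual code): a 64-bit scalar argument whose bit pattern happens to equal a tracked old address is indistinguishable from a real pointer during a raw value scan.

```cpp
#include <cstdint>
#include <unordered_map>

// Patches every 8-byte slot whose value matches a tracked old address,
// whether it is really a pointer or just an unlucky scalar. Returns the
// number of slots rewritten. (Hypothetical sketch of the hazard.)
int RawScanPatch(uint64_t* slots, int n,
                 const std::unordered_map<uint64_t, uint64_t>& remap) {
  int patched = 0;
  for (int i = 0; i < n; ++i) {
    auto it = remap.find(slots[i]);
    if (it != remap.end()) {
      slots[i] = it->second;  // a scalar with the same bit pattern is
      ++patched;              // silently corrupted here
    }
  }
  return patched;
}
```

With slots = {pointer, scalar} both holding the old address's bit pattern, both get rewritten, which is the silent data corruption described above.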
auto* flat_cmd = const_cast<GpuFlattenedCommand*>(
    dynamic_cast<const GpuFlattenedCommand*>(command));
if (!flat_cmd) {
const_cast on a const Command* parameter is a code smell. Both BuildPatchTable and PatchFlattenedNodes take const Command* but mutate the object through const_cast. If mutation is intended, the virtual method signatures in gpu_command_buffer.h should take non-const Command*, rather than casting away const (which is UB if the object was originally declared const).
auto* gpu_cmd = static_cast<const GpuCommand*>(d);
dep_handles.push_back(gpu_cmd->handle);
}
Unsafe cast: static_cast<const GpuCommand*>(d) is incorrect when d could be a GpuFlattenedCommand, GpuChildCommand, GpuCaseCommand, or GpuWhileCommand — none of which inherit from GpuCommand. This is undefined behavior. The existing ToGraphNodeDependencies helper (gpu_command_buffer.cc:108-128) handles this correctly with dynamic_cast chains. Consider reusing that helper or at minimum using dynamic_cast with a null check here.
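A sketch of the suggested dynamic_cast-chain approach, using stand-in types (the real hierarchy lives in gpu_command_buffer.h; DependencyHandle and the int handles are illustrative):

```cpp
// Stand-in command hierarchy: siblings, not a single-inheritance chain,
// which is exactly why a blind static_cast to one sibling is UB.
struct Command { virtual ~Command() = default; };
struct GpuCommand : Command { int handle = 0; };
struct GpuChildCommand : Command { int child_handle = 0; };

// Resolves a dependency to its node handle via dynamic_cast chains,
// returning -1 for an unsupported type so the caller can report an error
// instead of reading through a mistyped pointer.
int DependencyHandle(const Command* d) {
  if (auto* c = dynamic_cast<const GpuCommand*>(d)) return c->handle;
  if (auto* c = dynamic_cast<const GpuChildCommand*>(d)) return c->child_handle;
  return -1;  // unknown subclass: do not static_cast blindly
}
```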
continue;
}
Silently skipping unsupported node types with continue can lead to incorrect dependency resolution. The flattened command's node_handles/node_infos will have fewer entries than the child graph, and if a downstream node depends on the skipped node, its dependencies won't be resolved in child_to_parent, causing it to fall back to dep_handles (external deps) instead of the correct internal predecessor. This could produce incorrect execution ordering or race conditions.
Consider either returning an error for unsupported types, or inserting an empty node to preserve the dependency chain.
// Deep-copied extra arg buffers for kernel nodes that use the
// HIP_LAUNCH_PARAM packed-buffer launch convention (e.g. rocBLAS).
// Kept alive for the lifetime of the parent graph.
std::vector<std::unique_ptr<uint8_t[]>> extra_arg_buffers;
std::vector<std::unique_ptr<size_t>> extra_arg_sizes;

// Patch table for skip-retrace updates: records where each known
// buffer address appears inside the packed args of flattened
// extra-style kernel nodes.
std::vector<ArgPatchEntry> patch_table;

// Per-node deep-copied arg buffer and its size (only for extra-style
// kernel nodes that have patch entries).
struct NodeArgBuffer {
  std::unique_ptr<uint8_t[]> data;
extra_arg_buffers (line 155) appears to be unused dead code — it is declared but never written to anywhere in this PR. Only node_arg_buffers (line 164) is populated in FlattenChildGraphNodes. Consider removing extra_arg_buffers if it's leftover from an earlier design iteration.
// allows per-node address updates via hipGraphExecKernelNodeSetParams
// instead of full child graph re-tracing on buffer address changes.
optional bool xla_gpu_graph_enable_node_flattening = 502;
This flag is registered in debug_options_flags.cc but is never actually read anywhere in the codebase — the flattening methods (FlattenChildGraphNodes, UpdateFlattenedChildNodes, BuildPatchTable, PatchFlattenedNodes) are defined but never called from any code path. This makes the entire flattening infrastructure dead code in the current state of the PR. Is this intentional for a WIP, with the caller integration coming in a follow-up?
size_t num_root_nodes = 0;
TF_RETURN_IF_ERROR(
    ToStatus(wrap::hipGraphGetRootNodes(graph_, nullptr, &num_root_nodes),
             "Failed to get HIP graph root node count"));

if (num_root_nodes == 0) {
  return absl::InternalError(
      "Traced HIP graph is empty. Traced function (custom call) did not "
      "launch any HIP operations on the captured HIP stream. Instantiating "
      "empty child nodes leads to crashes.");
}
This removes the safety check that guarded against empty traced graphs. The original code explicitly warned: "Instantiating empty child nodes leads to crashes." The corresponding PrepareFinalization empty-node insertion is also removed below. Could you explain why these removals are safe? If the underlying HIP issue has been fixed, it would be good to note that in a comment.
Review Summary

This WIP PR adds a skip-retrace optimization for HIP graph command buffers that patches kernel node arguments in-place when buffer addresses change, plus a node-flattening mechanism to extract child graph nodes into the parent graph for individual updates. ~700 lines of new ROCm-specific code. Key concerns (see inline comments):
🤖 Generated with Claude Code
Previously, UpdateKernelNodes only patched hipGraphNodeTypeKernel nodes. This extends the patching to also handle:
- hipGraphNodeTypeMemcpy: patch srcPtr.ptr and dstPtr.ptr in hipMemcpy3DParms
- hipGraphNodeTypeMemset: patch the dst pointer in hipMemsetParams

This is needed for traced sub-graphs that contain memory transfer operations alongside kernel launches (e.g., rocBLAS workspace copies, DNN scratch buffer fills). Without this, the skip-retrace mechanism would return a graph with stale memcpy/memset addresses, causing incorrect results or crashes.

Added HIP API wrappers:
- hipGraphMemcpyNodeSetParams
- hipGraphMemsetNodeSetParams

Verified correct loss values on the Llama FSDP benchmark (8x MI308X):
- 7 benchmark steps completed, all loss values 10.7748-10.7753
- Step times: 657-754 ms (consistent with baseline)
… patching

Switch ROCm CreateKernelNode and UpdateKernelNode from kernelParams-style to extra-style (HIP_LAUNCH_PARAM_BUFFER_POINTER) argument passing. This avoids the HIP bug (ROCm/clr#138) where hipGraphKernelNodeGetParams returns dangling kernelParams pointers, and enables the UpdateNodeAddresses fast-update path to bypass full command buffer re-recording when only buffer addresses change (BFC allocator recapture).

Key changes:
- CreateKernelNode packs arguments into an owned contiguous buffer and uses extra-style, caching the node handle and params for fast patching.
- UpdateKernelNode also uses extra-style for consistency.
- UpdateNodeAddresses patches owned arg buffers directly and calls hipGraphExecKernelNodeSetParams per modified node, avoiding the expensive hipGraphExecUpdate full-graph sync.
- has_kernelparams_nodes_ flag now stays false for XLA-recorded kernels, allowing SupportsNodeAddressUpdate() to return true.
- Added profiling instrumentation behind the XLA_PROFILE_CMD_BUFFER env var.
Resolved ToString() -> ToString(0) signature change from upstream. Kept profiling instrumentation behind XLA_PROFILE_CMD_BUFFER env var.
#include "xla/backends/gpu/runtime/command_buffer_thunk.h"

#include <algorithm>
#include <sys/mman.h>
<sys/mman.h> is a POSIX system header and should be placed after the C++ standard library headers (or in a separate group), per Google C++ style. It also breaks alphabetical order within the <c*> group. Additionally, this is a Linux-only header -- using mmap makes this code non-portable to Windows (relevant if XLA ever supports non-ROCm/non-CUDA targets). Consider using aligned_alloc or a platform-abstracted allocator instead.
if (!cmd_buffer->prev_allocs) {
  size_t alloc_bytes =
      ((n * sizeof(se::DeviceAddressBase)) + 4095) & ~4095uL;
  void* p = mmap(nullptr, alloc_bytes, PROT_READ | PROT_WRITE,
                 MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
  if (p != MAP_FAILED) {
    cmd_buffer->prev_allocs = static_cast<se::DeviceAddressBase*>(p);
    cmd_buffer->prev_allocs_capacity =
        alloc_bytes / sizeof(se::DeviceAddressBase);
  }
}
if (cmd_buffer->prev_allocs) {
  size_t copy_n = std::min(n, cmd_buffer->prev_allocs_capacity);
  memcpy(cmd_buffer->prev_allocs, cmd_buffer->recorded_allocs.data(),
         copy_n * sizeof(se::DeviceAddressBase));
  cmd_buffer->prev_allocs_size = copy_n;
}
}
Memory leak: prev_allocs is allocated via mmap but never freed. ExecutorCommandBuffer has no destructor, so this memory is leaked when the object is destroyed. Either add a destructor that calls munmap(prev_allocs, prev_allocs_capacity * sizeof(se::DeviceAddressBase)), or use a simpler allocation strategy (e.g., std::vector or std::unique_ptr<se::DeviceAddressBase[]>) that cleans up automatically.
The comment says mmap is used "to avoid system allocator (heap) interference with the HIP graph runtime," but this rationale should be explained more concretely -- what specific interference occurs? If it is a real concern, the leak still needs to be fixed.
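One way to fix the leak while keeping mmap, sketched here as a hypothetical RAII wrapper (the page-rounding mirrors the quoted snippet; MmapBuffer is an illustrative name, not proposed API):

```cpp
#include <sys/mman.h>
#include <cstddef>

// Owns an anonymous mmap'ed region and unmaps it on destruction, so the
// memory is reclaimed when the enclosing object (e.g. the command buffer
// state) is destroyed.
class MmapBuffer {
 public:
  explicit MmapBuffer(size_t bytes)
      : size_((bytes + 4095) & ~size_t{4095}) {  // round up to page size
    void* p = mmap(nullptr, size_, PROT_READ | PROT_WRITE,
                   MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
    data_ = (p == MAP_FAILED) ? nullptr : p;
  }
  ~MmapBuffer() {
    if (data_ != nullptr) munmap(data_, size_);  // fixes the leak
  }
  MmapBuffer(const MmapBuffer&) = delete;
  MmapBuffer& operator=(const MmapBuffer&) = delete;

  void* data() const { return data_; }
  size_t size() const { return size_; }

 private:
  void* data_ = nullptr;
  size_t size_ = 0;
};
```

If the allocator-interference rationale does not hold up, std::unique_ptr<se::DeviceAddressBase[]> achieves the same lifetime management with less code.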
bool did_fast_update = false;

if (can_fast_update && cmd_buffer->prev_allocs_size > 0) {
  auto* gpu_cmd_buf = static_cast<se::gpu::GpuCommandBuffer*>(
Unsafe downcast: static_cast<se::gpu::GpuCommandBuffer*> is used instead of dynamic_cast. If command_buffer is not actually a GpuCommandBuffer (e.g., on a different platform or future backend), this is undefined behavior. The null check on line 323 (gpu_cmd_buf &&) only guards against a null command_buffer.get(), not against an incorrect type. The same issue appears at line 412. Use dynamic_cast with a null check for safety.
@@ -277,49 +284,109 @@ absl::Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) {
  return absl::OkStatus();
}

static const bool fast_update = [] {
  const char* env = std::getenv("XLA_GPU_GRAPH_FAST_UPDATE");
  return env != nullptr && std::string(env) == "1";
}();

uint64_t t_alloc_start = 0, t_alloc_end = 0;
uint64_t t_record_start = 0, t_record_end = 0;
uint64_t t_submit_start = 0, t_submit_end = 0;

if (profile_steps) t_alloc_start = tsl::Env::Default()->NowMicros();

auto updated_allocs = cmd_buffer->UpdateBufferAllocations(commands_, params);

// Determine whether to (re-)record the command buffer and whether this is a
// first-time initialization recording (VA remapping path).
if (profile_steps) t_alloc_end = tsl::Env::Default()->NowMicros();
Two new std::getenv-based feature flags are introduced in this file (XLA_PROFILE_CMD_BUFFER, XLA_GPU_GRAPH_FAST_UPDATE), adding to the XLA_GPU_GRAPH_SKIP_RETRACE in command_buffer_cmd.cc. These should use XLA's DebugOptions proto mechanism for consistency with the rest of the codebase. Environment variables are harder to discover, undocumented, and cannot be set per-compilation (only per-process). The static const bool + std::getenv pattern also means the value is read once at first invocation and cannot be changed, unlike DebugOptions which can be set per-HLO module.
LOG(WARNING) << "CmdBufProfile dev=" << dev
             << " alloc_check=" << (t_alloc_end - t_alloc_start) << "us"
             << " record=" << (t_record_end - t_record_start) << "us"
             << " submit=" << (t_submit_end - t_submit_start) << "us"
             << " total=" << (t_submit_end - t_alloc_start) << "us"
             << " updated=" << needs_update
             << " fast_update=" << fast_update
             << " can_fast=" << can_fast_update
             << " prev_sz=" << cmd_buffer->prev_allocs_size
             << " num_cmds=" << commands_.size()
             << " num_allocs_changed=" << num_allocs_changed;
}
Profiling output uses LOG(WARNING) for what is purely informational/diagnostic data. This will pollute warning logs in production even when intentionally enabled. Consider using LOG(INFO) or VLOG(1) instead. The same pattern appears in command_buffer_cmd.cc (lines 306, 331, 347) and earlier in this file (line 324, 344, 358, 378).
        "Failed to set memset node params after patching"));
    }
  }
}

return absl::OkStatus();
}
The PatchGraphNodes helper function (used by UpdateNodeAddresses) lacks a recursion depth limit. It recurses into child graphs via hipGraphNodeTypeGraph nodes. If the graph structure were cyclic (which shouldn't happen but could due to a HIP bug), or very deeply nested, this could cause a stack overflow. Consider adding a maximum depth guard (e.g., if (depth > 32) return InternalError(...)).
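The suggested guard can be sketched with a stand-in node type (GraphNode is illustrative; the real code walks hipGraphNode_t handles and returns absl::Status rather than bool):

```cpp
#include <vector>

// Toy graph node with child-graph children to recurse into.
struct GraphNode {
  std::vector<GraphNode*> children;
};

constexpr int kMaxGraphDepth = 32;

// Recursively visits nested child graphs, aborting (returns false) once the
// nesting exceeds kMaxGraphDepth, so a pathologically deep or cyclic graph
// cannot overflow the stack.
bool PatchGraphNodes(const GraphNode* g, int depth = 0) {
  if (depth > kMaxGraphDepth) return false;  // guard against runaway nesting
  for (const GraphNode* child : g->children) {
    if (!PatchGraphNodes(child, depth + 1)) return false;
  }
  // ... patch kernel/memcpy/memset nodes at this level ...
  return true;
}
```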
bool RocmCommandBuffer::SupportsNodeAddressUpdate() const {
  return exec_ != nullptr && graph_ != nullptr && !has_kernelparams_nodes_;
}

absl::StatusOr<bool> RocmCommandBuffer::UpdateNodeAddresses(
    absl::Span<const DeviceAddressBase> old_addresses,
    absl::Span<const DeviceAddressBase> new_addresses) {
  if (exec_ == nullptr || graph_ == nullptr) {
    return false;
  }

  if (has_kernelparams_nodes_) return false;

  size_t common_size = std::min(old_addresses.size(), new_addresses.size());
  absl::flat_hash_map<uintptr_t, size_t> old_addr_map;
  for (size_t i = 0; i < common_size; ++i) {
    auto old_val = reinterpret_cast<uintptr_t>(old_addresses[i].opaque());
    auto new_val = reinterpret_cast<uintptr_t>(new_addresses[i].opaque());
    if (old_val != 0 && old_val != new_val) {
      old_addr_map[old_val] = i;
    }
  }

  if (old_addr_map.empty()) return true;
In UpdateNodeAddresses, memcpy/memset nodes are patched via the graph definition (hipGraphMemcpyNodeSetParams) and then pushed to the executable graph (hipGraphExecMemcpyNodeSetParams1D). However, the definition-graph update uses hipGraphMemcpyNodeSetParams with the full 3D params, while the exec update uses hipGraphExecMemcpyNodeSetParams1D which assumes 1D. If the original memcpy was actually a 2D/3D copy, the exec-level update will be incorrect -- it flattens height/depth into a single size, changing the memory layout.
Additionally, hipMemcpyDeviceToDevice is hardcoded on line 548, but the original copy kind could be host-to-device, device-to-host, etc.
// Pack all argument values into a contiguous buffer and use extra-style
// launch (HIP_LAUNCH_PARAM_BUFFER_POINTER) instead of kernelParams.
// This avoids the HIP bug (ROCm/clr#138) where hipGraphKernelNodeGetParams
// returns dangling kernelParams pointers, and enables UpdateNodeAddresses.
auto arg_addrs = args.argument_addresses();
size_t num_args = arg_addrs.size();

// Each argument is a pointer-sized value; pack them contiguously.
size_t buf_size = num_args * sizeof(void*);
auto packed_buf = std::make_unique<uint8_t[]>(buf_size);
for (size_t i = 0; i < num_args; ++i) {
  memcpy(packed_buf.get() + i * sizeof(void*), arg_addrs[i], sizeof(void*));
}
The CreateKernelNode change packs all kernel arguments as pointer-sized values (sizeof(void*) each), but not all kernel arguments are necessarily pointer-sized. For example, a kernel taking an int32_t argument would have sizeof(int32_t) == 4 bytes, but this code packs it as 8 bytes (on 64-bit), adding 4 bytes of uninitialized padding. The comment on line 1305 acknowledges the assumption ("Each argument is a pointer-sized value") but this is not guaranteed by the KernelArgs API.
The same assumption is made in UpdateKernelNode below. If XLA's kernel argument packing does indeed guarantee pointer-sized values, this should be documented/asserted at the KernelArgs level rather than assumed here.
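The suggested up-front check might look like this; argument sizes are taken as a plain vector here since the exact KernelArgs accessor is not shown in the diff (AllArgsPointerSized is an illustrative name):

```cpp
#include <cstddef>
#include <vector>

// Verifies the packing precondition before committing to a sizeof(void*)
// stride: every argument must be exactly pointer-sized. A scalar or struct
// argument of any other size means the extra-style packing would read past
// (or truncate) its storage, so the caller should reject or fall back.
bool AllArgsPointerSized(const std::vector<size_t>& arg_sizes) {
  for (size_t s : arg_sizes) {
    if (s != sizeof(void*)) return false;
  }
  return true;
}
```

In the real code this could be a DCHECK (if the invariant is guaranteed by XLA's kernel emitter) or a returned error that triggers the child-graph fallback.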
#include "xla/backends/gpu/runtime/command_buffer_thunk.h"

#include <algorithm>
#include <sys/mman.h>
issue (portability): <sys/mman.h> is a POSIX-specific header placed between C++ standard library headers, violating Google C++ include ordering (C system headers should be grouped separately). More importantly, this is platform-generic code shared across backends — this include and the mmap call at line 420 will fail to compile on Windows. Either guard with #ifdef __linux__ / platform ifdefs, or use a portable allocator (std::aligned_alloc, operator new).
if (!cmd_buffer->prev_allocs) {
  size_t alloc_bytes =
      ((n * sizeof(se::DeviceAddressBase)) + 4095) & ~4095uL;
  void* p = mmap(nullptr, alloc_bytes, PROT_READ | PROT_WRITE,
                 MAP_PRIVATE | MAP_ANONYMOUS, -1, 0);
  if (p != MAP_FAILED) {
    cmd_buffer->prev_allocs = static_cast<se::DeviceAddressBase*>(p);
    cmd_buffer->prev_allocs_capacity =
        alloc_bytes / sizeof(se::DeviceAddressBase);
  }
}
if (cmd_buffer->prev_allocs) {
  size_t copy_n = std::min(n, cmd_buffer->prev_allocs_capacity);
  memcpy(cmd_buffer->prev_allocs, cmd_buffer->recorded_allocs.data(),
         copy_n * sizeof(se::DeviceAddressBase));
  cmd_buffer->prev_allocs_size = copy_n;
}
}
bug (memory leak): prev_allocs is allocated via mmap(MAP_PRIVATE | MAP_ANONYMOUS) but is never freed with munmap. When the ExecutorCommandBuffer is destroyed, this memory leaks. The comment says this avoids "system allocator (heap) interference with the HIP graph runtime" but provides no justification for why std::vector or std::unique_ptr<T[]> would interfere. The mmap approach also skips the RAII pattern expected by XLA conventions. If mmap is truly required, the ExecutorCommandBuffer destructor should call munmap.
bool did_fast_update = false;

if (can_fast_update && cmd_buffer->prev_allocs_size > 0) {
  auto* gpu_cmd_buf = static_cast<se::gpu::GpuCommandBuffer*>(
nit (safety): static_cast<se::gpu::GpuCommandBuffer*> is undefined behavior if the actual runtime type is not GpuCommandBuffer. In command_buffer_cmd.cc (same PR), you use dynamic_cast for the same downcast (line 311). This should be dynamic_cast with a null check for consistency and safety — especially since the supports check on the next line already handles the null case.
absl::Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) {
  static const bool profile_steps = [] {
    const char* env = std::getenv("XLA_PROFILE_CMD_BUFFER");
    return env != nullptr && std::string(env) == "1";
  }();

  // We might end up with empty command sequence if all of the captured fusions
suggestion (consistency): This PR introduces three std::getenv-based feature flags (XLA_PROFILE_CMD_BUFFER, XLA_GPU_GRAPH_FAST_UPDATE here, plus XLA_GPU_GRAPH_SKIP_RETRACE in command_buffer_cmd.cc), while the flattening feature uses DebugOptions (xla_gpu_graph_enable_node_flattening). These should all use DebugOptions for consistency — env vars are harder to discover, not documented in the proto, and bypass XLA's standard configuration surface.
LOG(WARNING) << "CmdBufProfile dev=" << dev
             << " alloc_check=" << (t_alloc_end - t_alloc_start) << "us"
             << " record=" << (t_record_end - t_record_start) << "us"
             << " submit=" << (t_submit_end - t_submit_start) << "us"
             << " total=" << (t_submit_end - t_alloc_start) << "us"
             << " updated=" << needs_update
             << " fast_update=" << fast_update
             << " can_fast=" << can_fast_update
             << " prev_sz=" << cmd_buffer->prev_allocs_size
             << " num_cmds=" << commands_.size()
             << " num_allocs_changed=" << num_allocs_changed;
}
nit: All profiling messages use LOG(WARNING), which will pollute warning-level logs in production when XLA_PROFILE_CMD_BUFFER=1 is set. This can trigger warning-level log monitoring/alerting. Profiling output should use VLOG(1) or LOG(INFO) instead.
        "Failed to set memset node params after patching"));
    }
  }
}

return absl::OkStatus();
}
issue (robustness): PatchGraphNodes recurses into child graphs via hipGraphChildGraphNodeGetGraph but has no depth limit. A graph with deeply nested children (or a cycle due to HIP API bug) would cause a stack overflow. Consider adding a max-depth guard (e.g., if (depth > 32) return InternalError(...)).
bool RocmCommandBuffer::SupportsNodeAddressUpdate() const {
  return exec_ != nullptr && graph_ != nullptr && !has_kernelparams_nodes_;
}

absl::StatusOr<bool> RocmCommandBuffer::UpdateNodeAddresses(
    absl::Span<const DeviceAddressBase> old_addresses,
    absl::Span<const DeviceAddressBase> new_addresses) {
  if (exec_ == nullptr || graph_ == nullptr) {
    return false;
  }

  if (has_kernelparams_nodes_) return false;

  size_t common_size = std::min(old_addresses.size(), new_addresses.size());
  absl::flat_hash_map<uintptr_t, size_t> old_addr_map;
  for (size_t i = 0; i < common_size; ++i) {
    auto old_val = reinterpret_cast<uintptr_t>(old_addresses[i].opaque());
    auto new_val = reinterpret_cast<uintptr_t>(new_addresses[i].opaque());
    if (old_val != 0 && old_val != new_val) {
      old_addr_map[old_val] = i;
    }
  }

  if (old_addr_map.empty()) return true;
bug (correctness): The memcpy size computation width * height * depth doesn't account for element size or pitch, and the semantics of hipMemcpy3DParms.extent.width differ between 1D and 2D/3D copies. Additionally, hipMemcpyDeviceToDevice is always hardcoded via hipGraphExecMemcpyNodeSetParams1D, which ignores the original copy direction (the source memcpy could be D2H or H2D). This means host-to-device or device-to-host memcpy nodes will silently produce incorrect behavior after patching.
// Pack all argument values into a contiguous buffer and use extra-style
// launch (HIP_LAUNCH_PARAM_BUFFER_POINTER) instead of kernelParams.
// This avoids the HIP bug (ROCm/clr#138) where hipGraphKernelNodeGetParams
// returns dangling kernelParams pointers, and enables UpdateNodeAddresses.
auto arg_addrs = args.argument_addresses();
size_t num_args = arg_addrs.size();

// Each argument is a pointer-sized value; pack them contiguously.
size_t buf_size = num_args * sizeof(void*);
auto packed_buf = std::make_unique<uint8_t[]>(buf_size);
for (size_t i = 0; i < num_args; ++i) {
  memcpy(packed_buf.get() + i * sizeof(void*), arg_addrs[i], sizeof(void*));
}
issue (correctness assumption): The packing logic assumes every kernel argument is exactly `sizeof(void*)` bytes, copying `sizeof(void*)` from each `argument_addresses()` element. However, `KernelArgsPackedArrayBase::argument_addresses()` returns pointers to argument data of potentially different sizes (scalars, structs). If any argument is smaller than `sizeof(void*)`, this reads beyond its storage; if larger, it truncates. While XLA-generated GPU kernels likely only pass device pointers, this assumption is undocumented and could break with library kernels. Consider adding a `DCHECK` or static assertion that all argument sizes equal `sizeof(void*)`.
Re-review Summary

Reviewed the updated diff. The 9 previous inline findings remain unaddressed — no changes needed on those threads. 8 new findings posted inline.
Previous findings (empty-graph safety check removal, …)
Re-review Summary

Reviewed latest changes including commit 711e998 ("Add trace cache for CollectiveCmd to fix COLLECTIVES graph capture regression"). All 25 previously posted inline findings remain unaddressed — no code changes or replies have been made to resolve them. The key concerns are unchanged.
No new inline comments posted — all issues were already covered in previous review rounds.

🤖 Generated with Claude Code
When COLLECTIVES are graph-captured, RCCL operations produce child graph nodes (hipGraphNodeTypeGraph). Previously UpdateNodeAddresses only patched kernel/memcpy/memset nodes in the parent graph, skipping child graphs entirely. This extends the fast-update path to: 1. Detect child graph nodes in the parent graph walk 2. Use PatchGraphNodes to recursively patch kernel/memcpy/memset nodes inside the child graph, including kernelParams-style nodes from stream capture (RCCL collectives) 3. Push the patched child graph to the exec graph via hipGraphExecChildGraphNodeSetParams Also extends PatchGraphNodes to handle kernelParams-style kernel nodes (not just extra-style), since stream-captured RCCL kernels use kernelParams where HIP owns the parameter storage.
Force-pushed from 79728ba to 0182f04.
Problem
When XLA's BFC (Best-Fit with Coalescing) memory allocator reassigns device buffers between steps, command buffers must be updated with the new addresses. The current path re-records the entire command buffer from scratch.
This full re-record walks every command in the sequence, calls HIP APIs to recreate each graph node, finalizes the graph, and re-instantiates the executable. For large models with hundreds of graph nodes, this overhead is significant and happens on every training step because the BFC allocator frequently moves buffers.
Additionally, collectives (AllReduce, ReduceScatter, etc.) suffer a severe regression when captured into HIP graphs.
`CollectiveCmd::RecordTracedCommand` was calling `TraceCommandBufferFactory::Create` on every Record invocation — performing a full `hipStreamBeginCapture` → RCCL → `hipStreamEndCapture` cycle each time. For non-power-of-2 element counts, this stream capture takes ~250ms due to RCCL's internal polling protocol, making COLLECTIVES in command buffers 60-150x slower than direct execution.

Idea
1. Fast kernel node address patching
Instead of re-recording the entire graph, directly patch the device addresses inside the existing executable graph (`hipGraphExec_t`):

- **Switch kernel nodes from `kernelParams`-style to `extra`-style argument passing.** Instead of `hipKernelNodeParams.kernelParams` (an array of pointers to each argument), we pack all arguments into a single contiguous buffer and use `HIP_LAUNCH_PARAM_BUFFER_POINTER` via the `extra` field. This gives us an owned, mutable buffer containing all the device pointers for each kernel.
- **Patch addresses in-place.** When allocations change, scan the owned arg buffers for old device addresses and replace them with new ones. Then call `hipGraphExecKernelNodeSetParams` to push the updated args to the executable graph — no re-recording, no re-instantiation.
- **Cache node handles and params at creation time.** Each `CreateKernelNode` call stores the `hipGraphNode_t` handle, the `hipKernelNodeParams`, and the owned arg buffer. This eliminates the need to call `hipGraphGetNodes` + `hipGraphNodeGetType` + `hipGraphKernelNodeGetParams` during the update — we iterate our cached list directly.

Why `extra`-style?

The `kernelParams` approach has a known HIP bug (ROCm/clr#138): `hipGraphKernelNodeGetParams` returns dangling pointers for `kernelParams`-style nodes, making it impossible to read or patch argument values from the graph. By switching to `extra`-style with `HIP_LAUNCH_PARAM_BUFFER_POINTER`, we own the argument buffer and can safely read/modify it.

Update flow comparison
Before (full re-record, ~200us per command buffer):
After (in-place patch, ~165us and improving with optimization):
2. Collective trace cache (fixes COLLECTIVES regression)
`CollectiveCmd::RecordTracedCommand` was calling `TraceCommandBufferFactory::Create` on every record — doing a full `hipStreamBeginCapture` → RCCL `ncclAllReduce` → `hipStreamEndCapture` each time. Unlike `TracedCommandBufferCmd` (used by GEMM, CublasLt, etc.), which caches traced graphs via `TracedCommandBuffer`, collectives had no cache at all.

Fix: Use the same `TracedCommandBuffer` cache for `CollectiveCmd`. The cache stores traced command buffers keyed by buffer addresses. On cache hit, the previously-traced RCCL graph is reused directly. On miss, it traces once and caches.

This is the same caching mechanism that VMM/VA-remapping achieves implicitly (addresses never change → graph never re-traced), but works without requiring VMM support.
What this PR contains
Core changes
- `rocm_command_buffer.cc` / `CreateKernelNode`: Packs kernel arguments into an owned contiguous buffer using `extra`-style (`HIP_LAUNCH_PARAM_BUFFER_POINTER`) instead of `kernelParams`. Caches the node handle and params in `OwnedKernelNode`.
- `rocm_command_buffer.cc` / `UpdateKernelNode`: Also uses `extra`-style for consistency with `CreateKernelNode`.
- `rocm_command_buffer.cc` / `UpdateNodeAddresses`: New fast-update path. Iterates cached `kernel_nodes_`, patches owned arg buffers, calls `hipGraphExecKernelNodeSetParams` per modified node. Also handles memcpy/memset nodes via graph walk.
- `rocm_command_buffer.h`: Adds `OwnedKernelNode` struct (node handle + cached params + owned arg buffer), `has_kernelparams_nodes_` safety flag, `SupportsNodeAddressUpdate()`, `UpdateNodeAddresses()`.
- `command_buffer_thunk.cc`: Adds the fast-update decision logic in `ExecuteOnStream`. When `XLA_GPU_GRAPH_FAST_UPDATE=1`, attempts `UpdateNodeAddresses` before falling back to full `Record`. Manages `prev_allocs` buffer (mmap-backed) for tracking old addresses.
- `command_buffer_cmd.cc` / `CollectiveCmd::RecordTracedCommand`: Replaced uncached `TraceCommandBufferFactory::Create` with `TracedCommandBuffer` cache (same pattern as `TracedCommandBufferCmd`). Eliminates the 250ms per-record RCCL re-capture.

Supporting changes

- `FlattenChildGraphNodes`, `UpdateFlattenedChildNodes`, `BuildPatchTable`, `PatchFlattenedNodes` — flatten child graphs into the parent for better patching coverage.
- `UpdateKernelNodes`: In-place kernel node patching for traced command buffers.
- `DumpGraphKernelNodes`: Debug utility for inspecting graph node state.
- `XLA_PROFILE_CMD_BUFFER=1` env var enables detailed timing logs.

Profiling results
Hardware: 8x AMD Instinct MI308X (gfx942)
Workload: 8-layer MLP, 4096 input dim, 41.9M params/replica
A. Single-device (`jit`) — kernel node patching: `FAST_UPDATE=0` vs. `FAST_UPDATE=1`

B. Multi-device (`pmap`, 8 GPU) — without COLLECTIVES: `FAST_UPDATE=0` vs. `FAST_UPDATE=1`

C. Multi-device (`pmap`, 8 GPU) — with COLLECTIVES (the main fix): `FAST_UPDATE=0` vs. `FAST_UPDATE=1`

D. Standalone allreduce (non-power-of-2 elements, graph-captured)
Root cause analysis
The 250ms regression was caused by `hipStreamBeginCapture` → RCCL → `hipStreamEndCapture` being repeated on every graph record. For non-power-of-2 allreduce sizes, RCCL's kernel protocol involves internal polling that takes ~250ms during stream capture. With the trace cache, this capture happens only once; subsequent records reuse the cached graph.
Correctness
All configurations produce identical final loss values:
How to enable
The collective trace cache is always active (no flag needed). The fast-update path (
XLA_GPU_GRAPH_FAST_UPDATE=1) automatically falls back to full Record when:SupportsNodeAddressUpdate()returns falseUpdateNodeAddressesfails for any reason