Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion xla/backends/gpu/runtime/command_buffer_conversion_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,8 @@ ConvertThunksToCommandBuffer(
const DebugOptions& debug_options) {
bool enable_loop_unroll = debug_options.xla_gpu_command_buffer_unroll_loops();
bool enable_va_remapping =
debug_options.xla_gpu_enable_command_buffer_va_remapping();
debug_options.xla_gpu_enable_command_buffer_va_remapping() ||
debug_options.xla_gpu_enable_circular_vmm_pool();
TF_ASSIGN_OR_RETURN(
CommandExecutor cmd_executor,
ConvertToCommands(
Expand Down
61 changes: 44 additions & 17 deletions xla/backends/gpu/runtime/command_buffer_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,20 +198,28 @@ absl::Status CommandBufferThunk::Initialize(const InitializeParams& params) {
/*additional_compute_streams=*/{}, params.execution_scoped_state,
/*mock_collectives=*/false);

// If command buffer is in `kCreate` state it means that command buffer
// sequence was never recorded into it. We initialize all command buffers
// before execution, because command buffers when instantiated will allocate
// memory on device and this might lead to deadlocks when we have concurrent
// NCCL operations in flight.
//
// If commands require initialization (and VA remapping is not enabled), we
// also record them into the command buffer before execution. This is required
// to guarantee that collective commands are recorded on all participating
// ranks to avoid deadlocks.
if (cmd_buffer->warmup_done && (cmd_buffer->command_buffer->state() ==
se::CommandBuffer::State::kCreate ||
(!enable_command_buffer_va_remapping_ &&
commands_.requires_initialization()))) {
bool warmup = cmd_buffer->warmup_done;
auto state = cmd_buffer->command_buffer->state();
bool will_record = warmup && (state == se::CommandBuffer::State::kCreate ||
(!enable_command_buffer_va_remapping_ &&
commands_.requires_initialization()));
LOG(INFO) << absl::StrFormat(
"CommandBufferThunk::Initialize: warmup_done=%d state=%d "
"va_remapping=%d requires_init=%d will_record=%d",
warmup, static_cast<int>(state), enable_command_buffer_va_remapping_,
commands_.requires_initialization(), will_record);

// Log the addresses that will be used for recording
if (will_record) {
for (auto idx : commands_.allocs_indices()) {
auto addr = execute_params.buffer_allocations->GetDeviceAddress(idx);
LOG(INFO) << absl::StrFormat(
" Initialize record addr[%d]: %p size=%d", idx, addr.opaque(),
addr.size());
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cleanup: LOG(INFO) in hot path should be VLOG

Multiple LOG(INFO) statements are added to Initialize and ExecuteOnStream that fire on every iteration. These will produce massive log output in production and add measurable overhead. The existing code already uses VLOG(2)/VLOG(3) for this purpose.

These should be converted to VLOG(2) or VLOG(3) before merging, or removed if they were only needed during development.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved -- addressed in this revision. The LOG(INFO) statements in Initialize and ExecuteOnStream have been converted to VLOG(3).

}
}

if (will_record) {
VLOG(3) << "Initialize command buffer on device #"
<< params.executor->device_ordinal()
<< " by recoding command buffer cmd sequence"
Expand Down Expand Up @@ -271,22 +279,41 @@ absl::Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) {

// warm up iteration, run through thunks if they are present.
if (!cmd_buffer->warmup_done && thunks_) {
VLOG(2) << "Executing warm up iteration of command buffer thunk";
LOG(INFO) << "CommandBufferThunk: WARMUP - running sequential thunks";
for (auto idx : commands_.allocs_indices()) {
auto addr = params.buffer_allocations->GetDeviceAddress(idx);
LOG(INFO) << absl::StrFormat(
" Warmup addr[%d]: %p size=%d", idx, addr.opaque(), addr.size());
}
TF_RETURN_IF_ERROR(thunks_->ExecuteOnStream(params));
cmd_buffer->warmup_done = true;
return absl::OkStatus();
}

auto updated_allocs = cmd_buffer->UpdateBufferAllocations(commands_, params);

// Determine whether to (re-)record the command buffer and whether this is a
// first-time initialization recording (VA remapping path).
bool is_first_record =
enable_command_buffer_va_remapping_ &&
cmd_buffer->command_buffer->state() == se::CommandBuffer::State::kCreate;
bool needs_update = !enable_command_buffer_va_remapping_ &&
(!updated_allocs.empty() || commands_.force_update());

LOG(INFO) << absl::StrFormat(
"CommandBufferThunk::ExecuteOnStream: va_remapping=%d updated_allocs=%d "
"is_first_record=%d needs_update=%d num_executions=%d state=%d",
enable_command_buffer_va_remapping_, updated_allocs.size(),
is_first_record, needs_update, cmd_buffer->num_executions,
static_cast<int>(cmd_buffer->command_buffer->state()));

// Log addresses on first few executions
if (cmd_buffer->num_executions < 3) {
for (auto idx : commands_.allocs_indices()) {
auto addr = params.buffer_allocations->GetDeviceAddress(idx);
LOG(INFO) << absl::StrFormat(
" Execute addr[%d]: %p size=%d", idx, addr.opaque(), addr.size());
}
}

if (is_first_record || needs_update) {
XLA_VLOG_DEVICE(3, executor->device_ordinal())
<< "Create/Update command buffer"
Expand Down
17 changes: 17 additions & 0 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {

opts.set_xla_gpu_enable_pdl(true);
opts.set_xla_gpu_enable_command_buffer_va_remapping(false);
opts.set_xla_gpu_enable_circular_vmm_pool(false);
opts.set_xla_gpu_circular_vmm_pool_slots(1);
return opts;
}

Expand Down Expand Up @@ -3012,6 +3014,21 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"Enable VA remapping for command buffer thunks. When enabled, command "
"buffer thunks use fixed virtual addresses across executions, allowing "
"the command buffer to be recorded once and replayed without updates."));
flag_list->push_back(tsl::Flag(
    "xla_gpu_enable_circular_vmm_pool",
    bool_setter_for(
        &DebugOptions::set_xla_gpu_enable_circular_vmm_pool),
    debug_options->xla_gpu_enable_circular_vmm_pool(),
    "Enable circular VMM pool for command buffer thunks. Pre-allocates N "
    "physical memory slots with permanent VA mappings, using GPU timeline "
    "signaling for safe slot reuse. Eliminates per-iteration map/unmap "
    "overhead entirely after startup."));
flag_list->push_back(tsl::Flag(
    "xla_gpu_circular_vmm_pool_slots",
    int32_setter_for(
        &DebugOptions::set_xla_gpu_circular_vmm_pool_slots),
    debug_options->xla_gpu_circular_vmm_pool_slots(),
    // Keep in sync with DefaultDebugOptionsIgnoringFlags(), which sets
    // xla_gpu_circular_vmm_pool_slots to 1. Note: with a single slot,
    // AcquireNextSlot must wait for GPU completion every iteration (no
    // overlap between consecutive iterations).
    "Number of slots in the circular VMM pool (default 1)."));
} // NOLINT(readability/fn_size)

// Allocates flag_values and flag_objects; this function must not be called more
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: default value / help text mismatch

The actual default is set to 1 (line 507: set_xla_gpu_circular_vmm_pool_slots(1)), but the flag help text here says "(default 2)".

With 1 slot, AcquireNextSlot will spin-wait for GPU completion on every iteration (no overlap), negating the primary benefit of the circular pool. Either fix the help text to say "default 1" or change the default to 2.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Partially resolved. The flag help text in debug_options_flags.cc now correctly says "(default 1)", but the proto comment in xla.proto still says "(default 2)" -- see xla_gpu_circular_vmm_pool_slots field comment. Please update the proto comment to match the actual default of 1.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved -- fully addressed in this revision. Both the flag help text in debug_options_flags.cc and the proto comment in xla.proto now correctly say "(default 1)".

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved: both the flag help text (line 3034) and proto comment now correctly say "(default 1)", matching the actual default value.

Expand Down
5 changes: 4 additions & 1 deletion xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,7 @@ cc_library(
name = "gpu_executable",
srcs = ["gpu_executable.cc"],
hdrs = ["gpu_executable.h"],
local_defines = if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]),
deps = [
":alias_info",
":backend_configs_cc",
Expand Down Expand Up @@ -804,7 +805,9 @@ cc_library(
"@tsl//tsl/platform:random",
"@tsl//tsl/profiler/lib:scoped_annotation",
"@tsl//tsl/profiler/lib:traceme",
],
] + if_rocm_is_configured([
"//xla/stream_executor/rocm:circular_vmm_pool",
]),
)

tf_proto_library(
Expand Down
171 changes: 165 additions & 6 deletions xla/service/gpu/gpu_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ limitations under the License.
#include "xla/stream_executor/stream_executor.h"
#include "xla/stream_executor/sycl/sycl_platform_id.h"
#include "xla/stream_executor/vmm_device_address_allocator.h"
#if TENSORFLOW_USE_ROCM
#include "xla/stream_executor/rocm/circular_vmm_pool.h"
#endif
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/env_time.h"
#include "xla/tsl/platform/logging.h"
Expand Down Expand Up @@ -1515,6 +1518,151 @@ absl::Status GpuExecutable::ExecuteThunksWithVaRemapping(
return absl::OkStatus();
}

absl::Status GpuExecutable::ExecuteThunksWithCircularVmmPool(
    const BufferAllocations& buffer_allocations,
    const ServiceExecutableRunOptions* run_options,
    se::StreamExecutor* executor, int64_t unique_id,
    Thunk::ExecutableSource executable_source, bool block_host_until_done) {
#if TENSORFLOW_USE_ROCM
  // Build sets: which buffers go into the pool, which need data copying.
  // Pool ALL non-constant command buffer allocations (params + temps).
  // Constants keep BFC addresses (loaded from module globals, stable).
  absl::btree_set<BufferAllocation::Index> pool_indexes;
  absl::btree_set<BufferAllocation::Index> copy_indexes;
  if (buffer_assignment_) {
    for (BufferAllocation::Index idx : command_buffer_allocation_indexes_) {
      const auto& alloc = buffer_assignment_->GetAllocation(idx);
      if (alloc.is_constant() || alloc.size() == 0) continue;
      pool_indexes.insert(idx);
      if (alloc.is_entry_computation_parameter() || alloc.maybe_live_out()) {
        copy_indexes.insert(idx);
      }
    }
  }

  // Nothing to pool: execute on the plain path. (With no poolable buffers a
  // pool is never created, so every call takes this branch — same behavior
  // as creating none.)
  if (pool_indexes.empty()) {
    return ExecuteThunksImpl(
        has_module() ? &module_config().debug_options() : nullptr,
        module_name_, unique_id, *thunk_executor_, executable_source,
        run_options, buffer_allocations, block_host_until_done,
        execution_stream_ids_, collective_memory_cache_);
  }

  // Look up per-executor pool state, create the pool on first use, and bump
  // the iteration counter — all under `circular_pool_mutex_`. Holding the
  // lock across init prevents two threads from both seeing a null pool and
  // double-creating it (the second std::move would orphan in-flight
  // iterations on the first pool); bumping the counter under the lock
  // guarantees two concurrent iterations never observe the same value, which
  // would make them share a slot and corrupt data. Pool creation allocates
  // GPU memory under the lock, but only once per (module, executor).
  CircularPoolState* pool_state = nullptr;
  uint64_t iteration = 0;
  {
    absl::MutexLock lock(&circular_pool_mutex_);
    pool_state = &circular_pools_[executor];

    if (pool_state->pool == nullptr) {
      int num_slots =
          has_module()
              ? module_config().debug_options().xla_gpu_circular_vmm_pool_slots()
              : 1;

      // Slot layout mirrors the (sorted) pool_indexes order; the mapping
      // loops below rely on this ordering when walking slot_addresses.
      std::vector<uint64_t> buffer_sizes;
      buffer_sizes.reserve(pool_indexes.size());
      for (BufferAllocation::Index idx : pool_indexes) {
        buffer_sizes.push_back(buffer_allocations.GetDeviceAddress(idx).size());
      }

      TF_ASSIGN_OR_RETURN(
          auto pool,
          se::gpu::CircularVmmPool::Create(executor, buffer_sizes, num_slots));

      // One-time creation message; per-iteration logging below is VLOG(3).
      LOG(INFO) << absl::StrFormat(
          "CircularVmmPool: created %d slots for module %s on device %d "
          "(%d command buffer allocations)",
          num_slots, module_name_, executor->device_ordinal(),
          command_buffer_allocation_indexes_.size());

      pool_state->pool = std::move(pool);
    }

    iteration = pool_state->iteration_count++;
  }

  auto* pool = static_cast<se::gpu::CircularVmmPool*>(pool_state->pool.get());

  // Acquire next slot — non-blocking check of GPU timeline counter.
  TF_ASSIGN_OR_RETURN(auto slot_addresses, pool->AcquireNextSlot(iteration));

  VLOG(3) << absl::StrFormat(
      "CircularVmmPool iter=%d slot=%d/%d: %d pool addrs", iteration,
      static_cast<int>(iteration % pool->num_slots()), pool->num_slots(),
      slot_addresses.size());

  // Build remapped buffer allocations: all pooled buffers use pool VA
  // addresses; constants and non-command-buffer buffers keep BFC addresses.
  // For params/live-out, copy data from BFC into pool VA before execution.
  std::vector<se::DeviceAddressBase> mapped_buffers;
  mapped_buffers.reserve(buffer_allocations.size());
  int slot_addr_idx = 0;
  for (BufferAllocation::Index i = 0;
       i < static_cast<BufferAllocation::Index>(buffer_allocations.size());
       ++i) {
    if (pool_indexes.contains(i)) {
      auto pool_addr = slot_addresses[slot_addr_idx++];
      auto bfc_addr = buffer_allocations.GetDeviceAddress(i);

      // Copy param data from BFC into pool before execution. This is needed
      // because the graph uses stable pool VA addresses, but the actual data
      // lives at BFC addresses which may change. For params, the data itself
      // may also change (e.g., optimizer weight updates in training).
      if (copy_indexes.contains(i) && !bfc_addr.is_null() &&
          bfc_addr.size() > 0) {
        se::DeviceAddressBase pool_dst(pool_addr.opaque(), bfc_addr.size());
        TF_RETURN_IF_ERROR(run_options->stream()->MemcpyD2D(
            &pool_dst, bfc_addr, bfc_addr.size()));
      }
      mapped_buffers.push_back(pool_addr);
    } else {
      mapped_buffers.push_back(buffer_allocations.GetDeviceAddress(i));
    }
  }

  BufferAllocations remapped_buffer_allocations(
      mapped_buffers, buffer_allocations.device_ordinal(),
      buffer_allocations.memory_allocator());

  TF_RETURN_IF_ERROR(ExecuteThunksImpl(
      has_module() ? &module_config().debug_options() : nullptr, module_name_,
      unique_id, *thunk_executor_, executable_source, run_options,
      remapped_buffer_allocations, block_host_until_done,
      execution_stream_ids_, collective_memory_cache_));

  // Copy live-out results back from pool to BFC so the output appears at
  // the expected BFC address for downstream consumers.
  slot_addr_idx = 0;
  for (BufferAllocation::Index i = 0;
       i < static_cast<BufferAllocation::Index>(buffer_allocations.size());
       ++i) {
    if (pool_indexes.contains(i)) {
      auto pool_addr = slot_addresses[slot_addr_idx++];
      if (copy_indexes.contains(i) && buffer_assignment_) {
        const auto& alloc = buffer_assignment_->GetAllocation(i);
        if (alloc.maybe_live_out()) {
          auto bfc_addr = buffer_allocations.GetDeviceAddress(i);
          if (!bfc_addr.is_null() && bfc_addr.size() > 0) {
            se::DeviceAddressBase bfc_dst(bfc_addr.opaque(), bfc_addr.size());
            se::DeviceAddressBase pool_src(pool_addr.opaque(), bfc_addr.size());
            TF_RETURN_IF_ERROR(run_options->stream()->MemcpyD2D(
                &bfc_dst, pool_src, bfc_addr.size()));
          }
        }
      }
    }
  }

  // GPU signals slot completion so the CPU knows when this slot is safe to
  // reuse (non-blocking write via hipStreamWriteValue64).
  TF_RETURN_IF_ERROR(pool->ReleaseSlot(run_options->stream(), iteration));

  // Any host sync performed inside ExecuteThunksImpl happened before the
  // copy-back D2D memcpys and the slot-release signal above were enqueued.
  // Sync again here so callers that requested blocking semantics observe the
  // live-out data at the BFC address and a fully-signaled slot on return.
  if (block_host_until_done) {
    TF_RETURN_IF_ERROR(run_options->stream()->BlockHostUntilDone());
  }

  return absl::OkStatus();
#else
  return absl::UnimplementedError(
      "Circular VMM pool is only supported on ROCm.");
#endif
}

absl::Status GpuExecutable::ExecuteThunks(
const BufferAllocations& buffer_allocations,
const ServiceExecutableRunOptions* run_options) {
Expand Down Expand Up @@ -1587,21 +1735,32 @@ absl::Status GpuExecutable::ExecuteThunks(

se::StreamExecutor* executor = run_options->stream()->parent();

bool has_cmd_buffer_allocs = !command_buffer_allocation_indexes_.empty();

// Check if circular VMM pool is enabled (takes priority over VA remapping).
bool enable_circular_vmm_pool =
has_cmd_buffer_allocs && has_module() &&
module_config().debug_options().xla_gpu_enable_circular_vmm_pool();

// Check if command buffer VA remapping is enabled.
bool enable_command_buffer_va_remapping =
(command_buffer_allocation_indexes_.size() > 0) && has_module() &&
!enable_circular_vmm_pool && has_cmd_buffer_allocs && has_module() &&
module_config()
.debug_options()
.xla_gpu_enable_command_buffer_va_remapping() &&
dynamic_cast<se::DeviceAddressVmmAllocator*>(memory_allocator) != nullptr;

XLA_VLOG_DEVICE(3, executor->device_ordinal()) << absl::StreamFormat(
"ExecuteThunks: command_buffer_allocation_indexes_.size()=%d "
"enable_command_buffer_va_remapping=%d",
command_buffer_allocation_indexes_.size(),
LOG(INFO) << absl::StreamFormat(
"ExecuteThunks: cmd_buffer_allocs=%d circular_vmm_pool=%d "
"va_remapping=%d",
command_buffer_allocation_indexes_.size(), enable_circular_vmm_pool,
enable_command_buffer_va_remapping);
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cleanup: LOG(INFO) in hot path — should be VLOG

This fires on every invocation of ExecuteThunks. The original code used XLA_VLOG_DEVICE(3, ...). Should be restored to VLOG(3) or XLA_VLOG_DEVICE to avoid flooding production logs.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved -- addressed in this revision. The ExecuteThunks logging now uses XLA_VLOG_DEVICE(3, ...) instead of LOG(INFO).


if (enable_command_buffer_va_remapping) {
if (enable_circular_vmm_pool) {
TF_RETURN_IF_ERROR(ExecuteThunksWithCircularVmmPool(
buffer_allocations, run_options, executor, unique_id, executable_source,
block_host_until_done));
} else if (enable_command_buffer_va_remapping) {
TF_RETURN_IF_ERROR(ExecuteThunksWithVaRemapping(
buffer_allocations, run_options, executor, unique_id, executable_source,
block_host_until_done));
Expand Down
18 changes: 18 additions & 0 deletions xla/service/gpu/gpu_executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -430,6 +430,24 @@ class GpuExecutable : public Executable {
absl::node_hash_map<stream_executor::StreamExecutor*, VaRanges>
module_va_ranges_ ABSL_GUARDED_BY(va_ranges_mutex_);

// Circular VMM pool: pre-allocated slots with permanent VA mappings and
// GPU timeline signaling for safe slot reuse. ROCm-only.
absl::Status ExecuteThunksWithCircularVmmPool(
const BufferAllocations& buffer_allocations,
const ServiceExecutableRunOptions* run_options,
stream_executor::StreamExecutor* executor, int64_t unique_id,
Thunk::ExecutableSource executable_source, bool block_host_until_done);

struct CircularPoolState {
std::shared_ptr<void> pool;
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: std::shared_ptr<void> for type erasure is fragile

This works correctly because unique_ptr<CircularVmmPool> converting to shared_ptr<void> captures the right deleter. However, it sacrifices type safety and requires static_cast<CircularVmmPool*> at every use site.

Since the header already has a #if TENSORFLOW_USE_ROCM include for CircularVmmPool in the .cc file, consider forward-declaring CircularVmmPool here and using std::unique_ptr<se::gpu::CircularVmmPool> directly (with a custom deleter or forward-declared destructor). This eliminates the type-erased casts.

uint64_t iteration_count = 0;
// Track last-seen BFC addresses to skip redundant D2D memcpy.
absl::flat_hash_map<BufferAllocation::Index, void*> last_param_addrs;
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unused field: last_param_addrs is never read or written

This map is declared with a comment about skipping redundant D2D memcpy, but it's never used anywhere in the implementation. Either remove it or implement the optimization it's intended for.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved -- addressed in this revision. The last_param_addrs field has been removed.

};
absl::Mutex circular_pool_mutex_;
absl::node_hash_map<stream_executor::StreamExecutor*, CircularPoolState>
circular_pools_ ABSL_GUARDED_BY(circular_pool_mutex_);

GpuExecutable(const GpuExecutable&) = delete;
GpuExecutable& operator=(const GpuExecutable&) = delete;

Expand Down
7 changes: 5 additions & 2 deletions xla/service/gpu/thunk_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,11 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitCommandBufferThunk(

bool enable_loop_unroll = ir_emitter_context_->debug_options()
.xla_gpu_command_buffer_unroll_loops();
bool enable_va_remapping = ir_emitter_context_->debug_options()
.xla_gpu_enable_command_buffer_va_remapping();
bool enable_va_remapping =
ir_emitter_context_->debug_options()
.xla_gpu_enable_command_buffer_va_remapping() ||
ir_emitter_context_->debug_options()
.xla_gpu_enable_circular_vmm_pool();
TF_ASSIGN_OR_RETURN(
CommandExecutor cmd_executor,
ConvertToCommands(
Expand Down
Loading
Loading