Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion xla/backends/gpu/runtime/command_buffer_conversion_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,8 @@ ConvertThunksToCommandBuffer(
const DebugOptions& debug_options) {
bool enable_loop_unroll = debug_options.xla_gpu_command_buffer_unroll_loops();
bool enable_va_remapping =
debug_options.xla_gpu_enable_command_buffer_va_remapping();
debug_options.xla_gpu_enable_command_buffer_va_remapping() ||
debug_options.xla_gpu_enable_circular_vmm_pool();
TF_ASSIGN_OR_RETURN(
CommandExecutor cmd_executor,
ConvertToCommands(
Expand Down
35 changes: 19 additions & 16 deletions xla/backends/gpu/runtime/command_buffer_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,20 +198,18 @@ absl::Status CommandBufferThunk::Initialize(const InitializeParams& params) {
/*additional_compute_streams=*/{}, params.execution_scoped_state,
/*mock_collectives=*/false);

// If command buffer is in `kCreate` state it means that command buffer
// sequence was never recorded into it. We initialize all command buffers
// before execution, because command buffers when instantiated will allocate
// memory on device and this might lead to deadlocks when we have concurrent
// NCCL operations in flight.
//
// If commands require initialization (and VA remapping is not enabled), we
// also record them into the command buffer before execution. This is required
// to guarantee that collective commands are recorded on all participating
// ranks to avoid deadlocks.
if (cmd_buffer->warmup_done && (cmd_buffer->command_buffer->state() ==
se::CommandBuffer::State::kCreate ||
(!enable_command_buffer_va_remapping_ &&
commands_.requires_initialization()))) {
bool warmup = cmd_buffer->warmup_done;
auto state = cmd_buffer->command_buffer->state();
bool will_record = warmup && (state == se::CommandBuffer::State::kCreate ||
(!enable_command_buffer_va_remapping_ &&
commands_.requires_initialization()));
VLOG(3) << absl::StrFormat(
"CommandBufferThunk::Initialize: warmup_done=%d state=%d "
"va_remapping=%d requires_init=%d will_record=%d",
warmup, static_cast<int>(state), enable_command_buffer_va_remapping_,
commands_.requires_initialization(), will_record);

if (will_record) {
VLOG(3) << "Initialize command buffer on device #"
<< params.executor->device_ordinal()
<< " by recoding command buffer cmd sequence"
Expand Down Expand Up @@ -279,14 +277,19 @@ absl::Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) {

auto updated_allocs = cmd_buffer->UpdateBufferAllocations(commands_, params);

// Determine whether to (re-)record the command buffer and whether this is a
// first-time initialization recording (VA remapping path).
bool is_first_record =
enable_command_buffer_va_remapping_ &&
cmd_buffer->command_buffer->state() == se::CommandBuffer::State::kCreate;
bool needs_update = !enable_command_buffer_va_remapping_ &&
(!updated_allocs.empty() || commands_.force_update());

VLOG(3) << absl::StrFormat(
"CommandBufferThunk::ExecuteOnStream: va_remapping=%d updated_allocs=%d "
"is_first_record=%d needs_update=%d num_executions=%d state=%d",
enable_command_buffer_va_remapping_, updated_allocs.size(),
is_first_record, needs_update, cmd_buffer->num_executions,
static_cast<int>(cmd_buffer->command_buffer->state()));

if (is_first_record || needs_update) {
XLA_VLOG_DEVICE(3, executor->device_ordinal())
<< "Create/Update command buffer"
Expand Down
17 changes: 17 additions & 0 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {

opts.set_xla_gpu_enable_pdl(true);
opts.set_xla_gpu_enable_command_buffer_va_remapping(false);
opts.set_xla_gpu_enable_circular_vmm_pool(false);
opts.set_xla_gpu_circular_vmm_pool_slots(1);
return opts;
}

Expand Down Expand Up @@ -3012,6 +3014,21 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"Enable VA remapping for command buffer thunks. When enabled, command "
"buffer thunks use fixed virtual addresses across executions, allowing "
"the command buffer to be recorded once and replayed without updates."));
flag_list->push_back(tsl::Flag(
"xla_gpu_enable_circular_vmm_pool",
bool_setter_for(
&DebugOptions::set_xla_gpu_enable_circular_vmm_pool),
debug_options->xla_gpu_enable_circular_vmm_pool(),
"Enable circular VMM pool for command buffer thunks. Pre-allocates N "
"physical memory slots with permanent VA mappings, using GPU timeline "
"signaling for safe slot reuse. Eliminates per-iteration map/unmap "
"overhead entirely after startup."));
flag_list->push_back(tsl::Flag(
"xla_gpu_circular_vmm_pool_slots",
int32_setter_for(
&DebugOptions::set_xla_gpu_circular_vmm_pool_slots),
debug_options->xla_gpu_circular_vmm_pool_slots(),
"Number of slots in the circular VMM pool (default 1)."));
} // NOLINT(readability/fn_size)

// Allocates flag_values and flag_objects; this function must not be called more
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: default value / help text mismatch

The actual default is set to 1 (line 507: set_xla_gpu_circular_vmm_pool_slots(1)), but the flag help text here says "(default 2)".

With 1 slot, AcquireNextSlot will spin-wait for GPU completion on every iteration (no overlap), negating the primary benefit of the circular pool. Either fix the help text to say "default 1" or change the default to 2.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Partially resolved. The flag help text in debug_options_flags.cc now correctly says "(default 1)", but the proto comment in xla.proto still says "(default 2)" -- see xla_gpu_circular_vmm_pool_slots field comment. Please update the proto comment to match the actual default of 1.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved -- fully addressed in this revision. Both the flag help text in debug_options_flags.cc and the proto comment in xla.proto now correctly say "(default 1)".

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved: both the flag help text (line 3034) and proto comment now correctly say "(default 1)", matching the actual default value.

Expand Down
5 changes: 4 additions & 1 deletion xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,7 @@ cc_library(
name = "gpu_executable",
srcs = ["gpu_executable.cc"],
hdrs = ["gpu_executable.h"],
local_defines = if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]),
deps = [
":alias_info",
":backend_configs_cc",
Expand Down Expand Up @@ -804,7 +805,9 @@ cc_library(
"@tsl//tsl/platform:random",
"@tsl//tsl/profiler/lib:scoped_annotation",
"@tsl//tsl/profiler/lib:traceme",
],
] + if_rocm_is_configured([
"//xla/stream_executor/rocm:circular_vmm_pool",
]),
)

tf_proto_library(
Expand Down
178 changes: 173 additions & 5 deletions xla/service/gpu/gpu_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ limitations under the License.
#include "xla/stream_executor/stream_executor.h"
#include "xla/stream_executor/sycl/sycl_platform_id.h"
#include "xla/stream_executor/vmm_device_address_allocator.h"
#if TENSORFLOW_USE_ROCM
#include "xla/stream_executor/rocm/circular_vmm_pool.h"
#endif
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/env_time.h"
#include "xla/tsl/platform/logging.h"
Expand Down Expand Up @@ -1515,6 +1518,160 @@ absl::Status GpuExecutable::ExecuteThunksWithVaRemapping(
return absl::OkStatus();
}

absl::Status GpuExecutable::ExecuteThunksWithCircularVmmPool(
const BufferAllocations& buffer_allocations,
const ServiceExecutableRunOptions* run_options,
se::StreamExecutor* executor, int64_t unique_id,
Thunk::ExecutableSource executable_source, bool block_host_until_done) {
#if TENSORFLOW_USE_ROCM
CircularPoolState* pool_state = nullptr;
{
absl::MutexLock lock(&circular_pool_mutex_);
pool_state = &circular_pools_[executor];
}
Comment on lines +1529 to +1531
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: thread safety — pool_state accessed without lock

The mutex circular_pool_mutex_ is held only while looking up/inserting the CircularPoolState entry. After that, pool_state->pool, pool_state->iteration_count++, etc. are accessed without the lock.

If two threads call ExecuteThunksWithCircularVmmPool for the same executor concurrently (e.g., multi-threaded inference), they can race on iteration_count++ (non-atomic, no lock) — two iterations could use the same slot concurrently, corrupting data.

Consider either:

  1. Holding the mutex for the entire function body (but watch for deadlocks with GPU calls), or
  2. Making iteration_count an std::atomic<uint64_t> and adding a per-pool mutex for the initialization path.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved -- addressed in this revision. iteration_count is now std::atomic<uint64_t> with fetch_add(1), and pool initialization uses double-checked locking with a per-pool mutex.


// Cache buffer classification on first call (doesn't change between iters).
if (!pool_state->indexes_cached) {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: data race in double-checked locking on indexes_cached

indexes_cached is a plain bool (declared at gpu_executable.h:449), but it is read here outside the lock. Another thread may be writing true to it (at line 1548) under the lock. Reading a non-atomic variable concurrently with a write is undefined behavior in the C++ memory model.

Either:

  1. Change indexes_cached to std::atomic<bool> and use load(std::memory_order_acquire) / store(true, std::memory_order_release) for the double-checked locking pattern, or
  2. Remove the outer check and always acquire the lock (the lock is uncontended after the first call, so the overhead is negligible).

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved -- addressed in this revision. indexes_cached has been replaced by a std::atomic<bool> initialized flag that is read and written with acquire/release memory ordering, making the double-checked locking safe.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved: indexes_cached has been replaced by std::atomic<bool> initialized with proper acquire/release ordering, fixing the data race.

absl::MutexLock lock(&pool_state->mu);
if (!pool_state->indexes_cached) {
if (buffer_assignment_) {
for (BufferAllocation::Index idx : command_buffer_allocation_indexes_) {
const auto& alloc = buffer_assignment_->GetAllocation(idx);
if (alloc.is_constant() || alloc.size() == 0) continue;
pool_state->pool_indexes.insert(idx);
if (alloc.is_entry_computation_parameter() ||
alloc.maybe_live_out()) {
pool_state->copy_indexes.insert(idx);
}
}
}
pool_state->indexes_cached = true;
}
}
const auto& pool_indexes = pool_state->pool_indexes;
const auto& copy_indexes = pool_state->copy_indexes;

// Initialize pool on first use.
if (pool_state->pool == nullptr) {
int num_slots = has_module()
? module_config().debug_options().xla_gpu_circular_vmm_pool_slots()
: 1;

if (pool_indexes.empty()) {
return ExecuteThunksImpl(
has_module() ? &module_config().debug_options() : nullptr,
module_name_, unique_id, *thunk_executor_, executable_source,
run_options, buffer_allocations, block_host_until_done,
execution_stream_ids_, collective_memory_cache_);
}

std::vector<uint64_t> buffer_sizes;
buffer_sizes.reserve(pool_indexes.size());
for (BufferAllocation::Index idx : pool_indexes) {
buffer_sizes.push_back(buffer_allocations.GetDeviceAddress(idx).size());
}

TF_ASSIGN_OR_RETURN(
auto pool,
se::gpu::CircularVmmPool::Create(executor, buffer_sizes, num_slots));

LOG(INFO) << absl::StrFormat(
"CircularVmmPool: created %d slots for module %s on device %d "
"(%d command buffer allocations)",
num_slots, module_name_, executor->device_ordinal(),
command_buffer_allocation_indexes_.size());

pool_state->pool = std::move(pool);
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: pool initialization race -- pool == nullptr checked without lock

If two threads call ExecuteThunksWithCircularVmmPool concurrently for the same executor before the pool is initialized, both will see pool_state->pool == nullptr and both will create a pool. The second std::move(pool) assignment overwrites the first, and any iteration already in-flight on the first pool will use stale state.

This should use the same double-checked locking pattern as indexes_cached (once that is fixed to use atomics), or simply hold pool_state->mu during initialization:

if (pool_state->pool == nullptr) {
  absl::MutexLock lock(&pool_state->mu);
  if (pool_state->pool == nullptr) {
    // ... create pool ...
  }
}

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved -- addressed in this revision. Pool initialization is now guarded by atomic initialized + per-pool mutex double-checked locking. The pool creation, index computation, and initialized flag are all protected within the same critical section.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved: pool initialization is now inside an atomic double-checked locking pattern with a per-pool mutex, eliminating the race.

}

auto* pool = static_cast<se::gpu::CircularVmmPool*>(pool_state->pool.get());
uint64_t iteration = pool_state->iteration_count.fetch_add(1);
int slot_idx = iteration % pool->num_slots();
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Nit] slot_idx is only used in the VLOG(3) on line 1601. In builds where VLOG(3) compiles out, this may trigger an unused-variable warning. Consider moving this computation inside the VLOG statement or guarding it with VLOG_IS_ON(3).


// Acquire next slot — non-blocking check of GPU timeline counter.
TF_ASSIGN_OR_RETURN(auto slot_addresses, pool->AcquireNextSlot(iteration));

VLOG(3) << absl::StrFormat(
"CircularVmmPool iter=%d slot=%d/%d: %d pool addrs",
iteration, slot_idx, pool->num_slots(), slot_addresses.size());

// Build remapped buffer allocations: all pooled buffers use pool VA
// addresses; constants and non-command-buffer buffers keep BFC addresses.
// For params/live-out, copy data from BFC into pool VA before execution.
std::vector<se::DeviceAddressBase> mapped_buffers;
mapped_buffers.reserve(buffer_allocations.size());
int slot_addr_idx = 0;
for (BufferAllocation::Index i = 0;
i < static_cast<BufferAllocation::Index>(buffer_allocations.size());
++i) {
if (pool_indexes.contains(i)) {
auto pool_addr = slot_addresses[slot_addr_idx++];
auto bfc_addr = buffer_allocations.GetDeviceAddress(i);

// Copy param data from BFC into pool before execution. This is needed
// because the graph uses stable pool VA addresses, but the actual data
// lives at BFC addresses which may change. For params, the data itself
// may also change (e.g., optimizer weight updates in training).
if (copy_indexes.contains(i) && !bfc_addr.is_null() &&
bfc_addr.size() > 0) {
se::DeviceAddressBase pool_dst(pool_addr.opaque(), bfc_addr.size());
TF_RETURN_IF_ERROR(run_options->stream()->MemcpyD2D(
&pool_dst, bfc_addr, bfc_addr.size()));
}
mapped_buffers.push_back(pool_addr);
} else {
mapped_buffers.push_back(buffer_allocations.GetDeviceAddress(i));
}
}

BufferAllocations remapped_buffer_allocations(
mapped_buffers, buffer_allocations.device_ordinal(),
buffer_allocations.memory_allocator());

TF_RETURN_IF_ERROR(ExecuteThunksImpl(
has_module() ? &module_config().debug_options() : nullptr, module_name_,
unique_id, *thunk_executor_, executable_source, run_options,
remapped_buffer_allocations, block_host_until_done,
execution_stream_ids_, collective_memory_cache_));

// Copy live-out results back from pool to BFC so the output appears at
// the expected BFC address for downstream consumers.
slot_addr_idx = 0;
for (BufferAllocation::Index i = 0;
i < static_cast<BufferAllocation::Index>(buffer_allocations.size());
++i) {
if (pool_indexes.contains(i)) {
auto pool_addr = slot_addresses[slot_addr_idx++];
if (copy_indexes.contains(i) && buffer_assignment_) {
const auto& alloc = buffer_assignment_->GetAllocation(i);
if (alloc.maybe_live_out()) {
auto bfc_addr = buffer_allocations.GetDeviceAddress(i);
if (!bfc_addr.is_null() && bfc_addr.size() > 0) {
se::DeviceAddressBase bfc_dst(bfc_addr.opaque(), bfc_addr.size());
se::DeviceAddressBase pool_src(pool_addr.opaque(), bfc_addr.size());
TF_RETURN_IF_ERROR(run_options->stream()->MemcpyD2D(
&bfc_dst, pool_src, bfc_addr.size()));
}
}
}
}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Potential bug: live-out copy-back not synchronized when block_host_until_done is true

If block_host_until_done is true, ExecuteThunksImpl synchronizes the stream before returning. These D2D copy-back operations are enqueued after that sync, meaning the caller may read stale data at the BFC address.

Consider adding a stream synchronization after the copy-back when block_host_until_done is true:

if (block_host_until_done) {
  TF_RETURN_IF_ERROR(run_options->stream()->BlockHostUntilDone());
}

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved -- addressed in this revision. The copy-back D2D memcpy calls now happen before BlockHostUntilDone, so the sync correctly covers both the thunk execution and the copy-back.

}

if (block_host_until_done) {
TF_RETURN_IF_ERROR(run_options->stream()->BlockHostUntilDone());
}

// GPU signals slot completion so the CPU knows when this slot is safe to
// reuse (non-blocking write via hipStreamWriteValue64).
TF_RETURN_IF_ERROR(pool->ReleaseSlot(run_options->stream(), iteration));
Comment on lines +1666 to +1672
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Medium] ReleaseSlot (which enqueues hipStreamWriteValue64) is called after BlockHostUntilDone. This means the GPU-side timeline signal write is not covered by the host synchronization — if the caller destroys the pool (or the process exits) immediately after this function returns with block_host_until_done=true, the signal write may not have executed yet.

Consider swapping the order so ReleaseSlot is enqueued before the sync:

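Suggested change (the first six lines below are the current code; the six lines after them are the proposed replacement)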
if (block_host_until_done) {
TF_RETURN_IF_ERROR(run_options->stream()->BlockHostUntilDone());
}
// GPU signals slot completion so the CPU knows when this slot is safe to
// reuse (non-blocking write via hipStreamWriteValue64).
TF_RETURN_IF_ERROR(pool->ReleaseSlot(run_options->stream(), iteration));
// GPU signals slot completion so the CPU knows when this slot is safe to
// reuse (non-blocking write via hipStreamWriteValue64).
TF_RETURN_IF_ERROR(pool->ReleaseSlot(run_options->stream(), iteration));
if (block_host_until_done) {
TF_RETURN_IF_ERROR(run_options->stream()->BlockHostUntilDone());
}

This ensures the timeline signal is included in the synchronization fence.


return absl::OkStatus();
#else
return absl::UnimplementedError(
"Circular VMM pool is only supported on ROCm.");
#endif
}

absl::Status GpuExecutable::ExecuteThunks(
const BufferAllocations& buffer_allocations,
const ServiceExecutableRunOptions* run_options) {
Expand Down Expand Up @@ -1587,21 +1744,32 @@ absl::Status GpuExecutable::ExecuteThunks(

se::StreamExecutor* executor = run_options->stream()->parent();

bool has_cmd_buffer_allocs = !command_buffer_allocation_indexes_.empty();

// Check if circular VMM pool is enabled (takes priority over VA remapping).
bool enable_circular_vmm_pool =
has_cmd_buffer_allocs && has_module() &&
module_config().debug_options().xla_gpu_enable_circular_vmm_pool();

// Check if command buffer VA remapping is enabled.
bool enable_command_buffer_va_remapping =
(command_buffer_allocation_indexes_.size() > 0) && has_module() &&
!enable_circular_vmm_pool && has_cmd_buffer_allocs && has_module() &&
module_config()
.debug_options()
.xla_gpu_enable_command_buffer_va_remapping() &&
dynamic_cast<se::DeviceAddressVmmAllocator*>(memory_allocator) != nullptr;

XLA_VLOG_DEVICE(3, executor->device_ordinal()) << absl::StreamFormat(
"ExecuteThunks: command_buffer_allocation_indexes_.size()=%d "
"enable_command_buffer_va_remapping=%d",
command_buffer_allocation_indexes_.size(),
"ExecuteThunks: cmd_buffer_allocs=%d circular_vmm_pool=%d "
"va_remapping=%d",
command_buffer_allocation_indexes_.size(), enable_circular_vmm_pool,
enable_command_buffer_va_remapping);

if (enable_command_buffer_va_remapping) {
if (enable_circular_vmm_pool) {
TF_RETURN_IF_ERROR(ExecuteThunksWithCircularVmmPool(
buffer_allocations, run_options, executor, unique_id, executable_source,
block_host_until_done));
} else if (enable_command_buffer_va_remapping) {
TF_RETURN_IF_ERROR(ExecuteThunksWithVaRemapping(
buffer_allocations, run_options, executor, unique_id, executable_source,
block_host_until_done));
Expand Down
22 changes: 22 additions & 0 deletions xla/service/gpu/gpu_executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#ifndef XLA_SERVICE_GPU_GPU_EXECUTABLE_H_
#define XLA_SERVICE_GPU_GPU_EXECUTABLE_H_

#include <atomic>
#include <cstdint>
#include <deque>
#include <memory>
Expand Down Expand Up @@ -430,6 +431,27 @@ class GpuExecutable : public Executable {
absl::node_hash_map<stream_executor::StreamExecutor*, VaRanges>
module_va_ranges_ ABSL_GUARDED_BY(va_ranges_mutex_);

// Circular VMM pool: pre-allocated slots with permanent VA mappings and
// GPU timeline signaling for safe slot reuse. ROCm-only.
absl::Status ExecuteThunksWithCircularVmmPool(
const BufferAllocations& buffer_allocations,
const ServiceExecutableRunOptions* run_options,
stream_executor::StreamExecutor* executor, int64_t unique_id,
Thunk::ExecutableSource executable_source, bool block_host_until_done);

struct CircularPoolState {
std::shared_ptr<void> pool;
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: std::shared_ptr<void> for type erasure is fragile

This works correctly because unique_ptr<CircularVmmPool> converting to shared_ptr<void> captures the right deleter. However, it sacrifices type safety and requires static_cast<CircularVmmPool*> at every use site.

Since gpu_executable.cc already includes the CircularVmmPool header under #if TENSORFLOW_USE_ROCM, consider forward-declaring CircularVmmPool in this header and using std::unique_ptr<se::gpu::CircularVmmPool> directly (with a custom deleter, or with the owning type's destructor declared here and defined in the .cc file where CircularVmmPool is complete). This eliminates the type-erased casts.

std::atomic<uint64_t> iteration_count{0};
absl::Mutex mu;
// Cached buffer classification (computed once at init, reused every iter).
absl::btree_set<BufferAllocation::Index> pool_indexes;
absl::btree_set<BufferAllocation::Index> copy_indexes;
bool indexes_cached = false;
};
absl::Mutex circular_pool_mutex_;
absl::node_hash_map<stream_executor::StreamExecutor*, CircularPoolState>
circular_pools_ ABSL_GUARDED_BY(circular_pool_mutex_);

GpuExecutable(const GpuExecutable&) = delete;
GpuExecutable& operator=(const GpuExecutable&) = delete;

Expand Down
7 changes: 5 additions & 2 deletions xla/service/gpu/thunk_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,11 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitCommandBufferThunk(

bool enable_loop_unroll = ir_emitter_context_->debug_options()
.xla_gpu_command_buffer_unroll_loops();
bool enable_va_remapping = ir_emitter_context_->debug_options()
.xla_gpu_enable_command_buffer_va_remapping();
bool enable_va_remapping =
ir_emitter_context_->debug_options()
.xla_gpu_enable_command_buffer_va_remapping() ||
ir_emitter_context_->debug_options()
.xla_gpu_enable_circular_vmm_pool();
TF_ASSIGN_OR_RETURN(
CommandExecutor cmd_executor,
ConvertToCommands(
Expand Down
Loading
Loading