Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion xla/backends/gpu/runtime/command_buffer_conversion_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -443,7 +443,8 @@ ConvertThunksToCommandBuffer(
const DebugOptions& debug_options) {
bool enable_loop_unroll = debug_options.xla_gpu_command_buffer_unroll_loops();
bool enable_va_remapping =
debug_options.xla_gpu_enable_command_buffer_va_remapping();
debug_options.xla_gpu_enable_command_buffer_va_remapping() ||
debug_options.xla_gpu_enable_circular_vmm_pool();
TF_ASSIGN_OR_RETURN(
CommandExecutor cmd_executor,
ConvertToCommands(
Expand Down
35 changes: 19 additions & 16 deletions xla/backends/gpu/runtime/command_buffer_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -198,20 +198,18 @@ absl::Status CommandBufferThunk::Initialize(const InitializeParams& params) {
/*additional_compute_streams=*/{}, params.execution_scoped_state,
/*mock_collectives=*/false);

// If command buffer is in `kCreate` state it means that command buffer
// sequence was never recorded into it. We initialize all command buffers
// before execution, because command buffers when instantiated will allocate
// memory on device and this might lead to deadlocks when we have concurrent
// NCCL operations in flight.
//
// If commands require initialization (and VA remapping is not enabled), we
// also record them into the command buffer before execution. This is required
// to guarantee that collective commands are recorded on all participating
// ranks to avoid deadlocks.
if (cmd_buffer->warmup_done && (cmd_buffer->command_buffer->state() ==
se::CommandBuffer::State::kCreate ||
(!enable_command_buffer_va_remapping_ &&
commands_.requires_initialization()))) {
bool warmup = cmd_buffer->warmup_done;
auto state = cmd_buffer->command_buffer->state();
bool will_record = warmup && (state == se::CommandBuffer::State::kCreate ||
(!enable_command_buffer_va_remapping_ &&
commands_.requires_initialization()));
VLOG(3) << absl::StrFormat(
"CommandBufferThunk::Initialize: warmup_done=%d state=%d "
"va_remapping=%d requires_init=%d will_record=%d",
warmup, static_cast<int>(state), enable_command_buffer_va_remapping_,
commands_.requires_initialization(), will_record);

if (will_record) {
VLOG(3) << "Initialize command buffer on device #"
<< params.executor->device_ordinal()
<< " by recoding command buffer cmd sequence"
Expand Down Expand Up @@ -279,14 +277,19 @@ absl::Status CommandBufferThunk::ExecuteOnStream(const ExecuteParams& params) {

auto updated_allocs = cmd_buffer->UpdateBufferAllocations(commands_, params);

// Determine whether to (re-)record the command buffer and whether this is a
// first-time initialization recording (VA remapping path).
bool is_first_record =
enable_command_buffer_va_remapping_ &&
cmd_buffer->command_buffer->state() == se::CommandBuffer::State::kCreate;
bool needs_update = !enable_command_buffer_va_remapping_ &&
(!updated_allocs.empty() || commands_.force_update());

VLOG(3) << absl::StrFormat(
"CommandBufferThunk::ExecuteOnStream: va_remapping=%d updated_allocs=%d "
"is_first_record=%d needs_update=%d num_executions=%d state=%d",
enable_command_buffer_va_remapping_, updated_allocs.size(),
is_first_record, needs_update, cmd_buffer->num_executions,
static_cast<int>(cmd_buffer->command_buffer->state()));

if (is_first_record || needs_update) {
XLA_VLOG_DEVICE(3, executor->device_ordinal())
<< "Create/Update command buffer"
Expand Down
17 changes: 17 additions & 0 deletions xla/debug_options_flags.cc
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,8 @@ DebugOptions DefaultDebugOptionsIgnoringFlags() {

opts.set_xla_gpu_enable_pdl(true);
opts.set_xla_gpu_enable_command_buffer_va_remapping(false);
opts.set_xla_gpu_enable_circular_vmm_pool(false);
opts.set_xla_gpu_circular_vmm_pool_slots(1);
return opts;
}

Expand Down Expand Up @@ -3012,6 +3014,21 @@ void MakeDebugOptionsFlags(std::vector<tsl::Flag>* flag_list,
"Enable VA remapping for command buffer thunks. When enabled, command "
"buffer thunks use fixed virtual addresses across executions, allowing "
"the command buffer to be recorded once and replayed without updates."));
flag_list->push_back(tsl::Flag(
"xla_gpu_enable_circular_vmm_pool",
bool_setter_for(
&DebugOptions::set_xla_gpu_enable_circular_vmm_pool),
debug_options->xla_gpu_enable_circular_vmm_pool(),
"Enable circular VMM pool for command buffer thunks. Pre-allocates N "
"physical memory slots with permanent VA mappings, using GPU timeline "
"signaling for safe slot reuse. Eliminates per-iteration map/unmap "
"overhead entirely after startup."));
flag_list->push_back(tsl::Flag(
"xla_gpu_circular_vmm_pool_slots",
int32_setter_for(
&DebugOptions::set_xla_gpu_circular_vmm_pool_slots),
debug_options->xla_gpu_circular_vmm_pool_slots(),
"Number of slots in the circular VMM pool (default 1)."));
} // NOLINT(readability/fn_size)

// Allocates flag_values and flag_objects; this function must not be called more
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Bug: default value / help text mismatch

The actual default is set to 1 (line 507: set_xla_gpu_circular_vmm_pool_slots(1)), but the flag help text here says "(default 2)".

With 1 slot, AcquireNextSlot will spin-wait for GPU completion on every iteration (no overlap), negating the primary benefit of the circular pool. Either fix the help text to say "default 1" or change the default to 2.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Partially resolved. The flag help text in debug_options_flags.cc now correctly says "(default 1)", but the proto comment in xla.proto still says "(default 2)" -- see xla_gpu_circular_vmm_pool_slots field comment. Please update the proto comment to match the actual default of 1.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved -- fully addressed in this revision. Both the flag help text in debug_options_flags.cc and the proto comment in xla.proto now correctly say "(default 1)".

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved: both the flag help text (line 3034) and proto comment now correctly say "(default 1)", matching the actual default value.

Expand Down
5 changes: 4 additions & 1 deletion xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,7 @@ cc_library(
name = "gpu_executable",
srcs = ["gpu_executable.cc"],
hdrs = ["gpu_executable.h"],
local_defines = if_rocm_is_configured(["TENSORFLOW_USE_ROCM=1"]),
deps = [
":alias_info",
":backend_configs_cc",
Expand Down Expand Up @@ -804,7 +805,9 @@ cc_library(
"@tsl//tsl/platform:random",
"@tsl//tsl/profiler/lib:scoped_annotation",
"@tsl//tsl/profiler/lib:traceme",
],
] + if_rocm_is_configured([
"//xla/stream_executor/rocm:circular_vmm_pool",
]),
)

tf_proto_library(
Expand Down
184 changes: 179 additions & 5 deletions xla/service/gpu/gpu_executable.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ limitations under the License.
#include "xla/stream_executor/stream_executor.h"
#include "xla/stream_executor/sycl/sycl_platform_id.h"
#include "xla/stream_executor/vmm_device_address_allocator.h"
#if TENSORFLOW_USE_ROCM
#include "xla/stream_executor/rocm/circular_vmm_pool.h"
#endif
#include "xla/tsl/platform/env.h"
#include "xla/tsl/platform/env_time.h"
#include "xla/tsl/platform/logging.h"
Expand Down Expand Up @@ -1515,6 +1518,166 @@ absl::Status GpuExecutable::ExecuteThunksWithVaRemapping(
return absl::OkStatus();
}

// Executes the thunk sequence with all command-buffer allocations served from
// a circular VMM pool of pre-mapped slots (ROCm only). Each iteration acquires
// the next slot, copies parameter data from the BFC addresses into the pool
// VAs, runs the thunks against the remapped allocations, copies live-out
// results back to the BFC addresses, and enqueues a GPU-side timeline signal
// marking the slot reusable.
//
// Returns UnimplementedError on non-ROCm builds.
absl::Status GpuExecutable::ExecuteThunksWithCircularVmmPool(
    const BufferAllocations& buffer_allocations,
    const ServiceExecutableRunOptions* run_options,
    se::StreamExecutor* executor, int64_t unique_id,
    Thunk::ExecutableSource executable_source, bool block_host_until_done) {
#if TENSORFLOW_USE_ROCM
  // Look up (or default-construct) the per-executor pool state. The map mutex
  // is held only for the lookup; node_hash_map guarantees pointer stability,
  // and the state itself is guarded by its own mutex + `initialized` atomic.
  CircularPoolState* pool_state = nullptr;
  {
    absl::MutexLock lock(&circular_pool_mutex_);
    pool_state = &circular_pools_[executor];
  }

  // One-time initialization: classify command-buffer allocations and create
  // the pool. Uses double-checked locking on the `initialized` atomic so the
  // hot path pays only an acquire load after the first call.
  if (!pool_state->initialized.load(std::memory_order_acquire)) {
    absl::MutexLock lock(&pool_state->mu);
    if (!pool_state->initialized.load(std::memory_order_relaxed)) {
      if (buffer_assignment_) {
        for (BufferAllocation::Index idx : command_buffer_allocation_indexes_) {
          const auto& alloc = buffer_assignment_->GetAllocation(idx);
          // Constants and zero-sized allocations never go through the pool.
          if (alloc.is_constant() || alloc.size() == 0) continue;
          pool_state->pool_indexes.insert(idx);
          // Parameters and (maybe-)live-out buffers need D2D copies between
          // the BFC address (where callers see the data) and the pool VA.
          if (alloc.is_entry_computation_parameter() ||
              alloc.maybe_live_out()) {
            pool_state->copy_indexes.insert(idx);
          }
        }
      }

      if (!pool_state->pool_indexes.empty()) {
        int num_slots = has_module()
            ? module_config().debug_options().xla_gpu_circular_vmm_pool_slots()
            : 1;

        std::vector<uint64_t> buffer_sizes;
        buffer_sizes.reserve(pool_state->pool_indexes.size());
        for (BufferAllocation::Index idx : pool_state->pool_indexes) {
          buffer_sizes.push_back(
              buffer_allocations.GetDeviceAddress(idx).size());
        }

        TF_ASSIGN_OR_RETURN(
            auto pool,
            se::gpu::CircularVmmPool::Create(executor, buffer_sizes,
                                             num_slots));

        LOG(INFO) << absl::StrFormat(
            "CircularVmmPool: created %d slots for module %s on device %d "
            "(%d command buffer allocations)",
            num_slots, module_name_, executor->device_ordinal(),
            command_buffer_allocation_indexes_.size());

        pool_state->pool = std::move(pool);
      }

      pool_state->initialized.store(true, std::memory_order_release);
    }
  }

  // After initialization, pool_indexes/copy_indexes are immutable — safe to
  // read without lock.
  const auto& pool_indexes = pool_state->pool_indexes;
  const auto& copy_indexes = pool_state->copy_indexes;

  // No poolable allocations (e.g. all constants): fall back to the plain
  // execution path.
  if (pool_state->pool == nullptr) {
    return ExecuteThunksImpl(
        has_module() ? &module_config().debug_options() : nullptr,
        module_name_, unique_id, *thunk_executor_, executable_source,
        run_options, buffer_allocations, block_host_until_done,
        execution_stream_ids_, collective_memory_cache_);
  }

  auto* pool = pool_state->pool.get();
  uint64_t iteration = pool_state->iteration_count.fetch_add(1);

  // Acquire next slot — non-blocking check of GPU timeline counter.
  TF_ASSIGN_OR_RETURN(auto slot_addresses, pool->AcquireNextSlot(iteration));

  // Slot index is computed inline so no local is left unused when VLOG(3)
  // compiles out.
  VLOG(3) << absl::StrFormat(
      "CircularVmmPool iter=%d slot=%d/%d: %d pool addrs", iteration,
      iteration % pool->num_slots(), pool->num_slots(), slot_addresses.size());

  // Build remapped buffer allocations: all pooled buffers use pool VA
  // addresses; constants and non-command-buffer buffers keep BFC addresses.
  // For params/live-out, copy data from BFC into pool VA before execution.
  std::vector<se::DeviceAddressBase> mapped_buffers;
  mapped_buffers.reserve(buffer_allocations.size());
  int slot_addr_idx = 0;
  for (BufferAllocation::Index i = 0;
       i < static_cast<BufferAllocation::Index>(buffer_allocations.size());
       ++i) {
    if (pool_indexes.contains(i)) {
      auto pool_addr = slot_addresses[slot_addr_idx++];
      auto bfc_addr = buffer_allocations.GetDeviceAddress(i);

      // Copy param data from BFC into pool before execution. This is needed
      // because the graph uses stable pool VA addresses, but the actual data
      // lives at BFC addresses which may change. For params, the data itself
      // may also change (e.g., optimizer weight updates in training).
      if (copy_indexes.contains(i) && !bfc_addr.is_null() &&
          bfc_addr.size() > 0) {
        se::DeviceAddressBase pool_dst(pool_addr.opaque(), bfc_addr.size());
        TF_RETURN_IF_ERROR(run_options->stream()->MemcpyD2D(
            &pool_dst, bfc_addr, bfc_addr.size()));
      }
      mapped_buffers.push_back(pool_addr);
    } else {
      mapped_buffers.push_back(buffer_allocations.GetDeviceAddress(i));
    }
  }

  BufferAllocations remapped_buffer_allocations(
      mapped_buffers, buffer_allocations.device_ordinal(),
      buffer_allocations.memory_allocator());

  TF_RETURN_IF_ERROR(ExecuteThunksImpl(
      has_module() ? &module_config().debug_options() : nullptr, module_name_,
      unique_id, *thunk_executor_, executable_source, run_options,
      remapped_buffer_allocations, block_host_until_done,
      execution_stream_ids_, collective_memory_cache_));

  // Copy live-out results back from pool to BFC so the output appears at
  // the expected BFC address for downstream consumers.
  slot_addr_idx = 0;
  for (BufferAllocation::Index i = 0;
       i < static_cast<BufferAllocation::Index>(buffer_allocations.size());
       ++i) {
    if (pool_indexes.contains(i)) {
      auto pool_addr = slot_addresses[slot_addr_idx++];
      if (copy_indexes.contains(i) && buffer_assignment_) {
        const auto& alloc = buffer_assignment_->GetAllocation(i);
        if (alloc.maybe_live_out()) {
          auto bfc_addr = buffer_allocations.GetDeviceAddress(i);
          if (!bfc_addr.is_null() && bfc_addr.size() > 0) {
            se::DeviceAddressBase bfc_dst(bfc_addr.opaque(), bfc_addr.size());
            se::DeviceAddressBase pool_src(pool_addr.opaque(), bfc_addr.size());
            TF_RETURN_IF_ERROR(run_options->stream()->MemcpyD2D(
                &bfc_dst, pool_src, bfc_addr.size()));
          }
        }
      }
    }
  }

  // GPU signals slot completion so the CPU knows when this slot is safe to
  // reuse (non-blocking write via hipStreamWriteValue64). Enqueue this
  // BEFORE the host sync below so the signal write is covered by the
  // synchronization fence — otherwise, with block_host_until_done=true, the
  // function could return (and the pool could even be destroyed) while the
  // timeline write is still pending on the stream.
  TF_RETURN_IF_ERROR(pool->ReleaseSlot(run_options->stream(), iteration));

  if (block_host_until_done) {
    TF_RETURN_IF_ERROR(run_options->stream()->BlockHostUntilDone());
  }

  return absl::OkStatus();
#else
  return absl::UnimplementedError(
      "Circular VMM pool is only supported on ROCm.");
#endif
}

absl::Status GpuExecutable::ExecuteThunks(
const BufferAllocations& buffer_allocations,
const ServiceExecutableRunOptions* run_options) {
Expand Down Expand Up @@ -1587,21 +1750,32 @@ absl::Status GpuExecutable::ExecuteThunks(

se::StreamExecutor* executor = run_options->stream()->parent();

bool has_cmd_buffer_allocs = !command_buffer_allocation_indexes_.empty();

// Check if circular VMM pool is enabled (takes priority over VA remapping).
bool enable_circular_vmm_pool =
has_cmd_buffer_allocs && has_module() &&
module_config().debug_options().xla_gpu_enable_circular_vmm_pool();

// Check if command buffer VA remapping is enabled.
bool enable_command_buffer_va_remapping =
(command_buffer_allocation_indexes_.size() > 0) && has_module() &&
!enable_circular_vmm_pool && has_cmd_buffer_allocs && has_module() &&
module_config()
.debug_options()
.xla_gpu_enable_command_buffer_va_remapping() &&
dynamic_cast<se::DeviceAddressVmmAllocator*>(memory_allocator) != nullptr;

XLA_VLOG_DEVICE(3, executor->device_ordinal()) << absl::StreamFormat(
"ExecuteThunks: command_buffer_allocation_indexes_.size()=%d "
"enable_command_buffer_va_remapping=%d",
command_buffer_allocation_indexes_.size(),
"ExecuteThunks: cmd_buffer_allocs=%d circular_vmm_pool=%d "
"va_remapping=%d",
command_buffer_allocation_indexes_.size(), enable_circular_vmm_pool,
enable_command_buffer_va_remapping);

if (enable_command_buffer_va_remapping) {
if (enable_circular_vmm_pool) {
TF_RETURN_IF_ERROR(ExecuteThunksWithCircularVmmPool(
buffer_allocations, run_options, executor, unique_id, executable_source,
block_host_until_done));
} else if (enable_command_buffer_va_remapping) {
TF_RETURN_IF_ERROR(ExecuteThunksWithVaRemapping(
buffer_allocations, run_options, executor, unique_id, executable_source,
block_host_until_done));
Expand Down
28 changes: 28 additions & 0 deletions xla/service/gpu/gpu_executable.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.
#ifndef XLA_SERVICE_GPU_GPU_EXECUTABLE_H_
#define XLA_SERVICE_GPU_GPU_EXECUTABLE_H_

#include <atomic>
#include <cstdint>
#include <deque>
#include <memory>
Expand Down Expand Up @@ -68,6 +69,10 @@ limitations under the License.
#include "xla/stream_executor/stream_executor.h"
#include "xla/xla.pb.h"

namespace stream_executor::gpu {
class CircularVmmPool;
} // namespace stream_executor::gpu

namespace xla {
namespace gpu {

Expand Down Expand Up @@ -430,6 +435,29 @@ class GpuExecutable : public Executable {
absl::node_hash_map<stream_executor::StreamExecutor*, VaRanges>
module_va_ranges_ ABSL_GUARDED_BY(va_ranges_mutex_);

// Circular VMM pool: pre-allocated slots with permanent VA mappings and
// GPU timeline signaling for safe slot reuse. ROCm-only.
absl::Status ExecuteThunksWithCircularVmmPool(
const BufferAllocations& buffer_allocations,
const ServiceExecutableRunOptions* run_options,
stream_executor::StreamExecutor* executor, int64_t unique_id,
Thunk::ExecutableSource executable_source, bool block_host_until_done);

struct CircularPoolState {
absl::Mutex mu;
// Typed pool pointer; destructor handled by shared_ptr deleter captured
// from the original unique_ptr<CircularVmmPool>.
std::shared_ptr<stream_executor::gpu::CircularVmmPool> pool
ABSL_GUARDED_BY(mu);
std::atomic<uint64_t> iteration_count{0};
absl::btree_set<BufferAllocation::Index> pool_indexes ABSL_GUARDED_BY(mu);
absl::btree_set<BufferAllocation::Index> copy_indexes ABSL_GUARDED_BY(mu);
Comment on lines +450 to +454
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: ABSL_GUARDED_BY(mu) annotations are inaccurate after initialization

pool, pool_indexes, and copy_indexes are annotated ABSL_GUARDED_BY(mu), but after the one-time initialization they are read without holding mu in ExecuteThunksWithCircularVmmPool (lines 1582-1585, 1593 of gpu_executable.cc). This is safe because the atomic double-checked locking on initialized guarantees the fields are immutable post-init, but the annotations will produce false positives with Clang's thread safety analysis (-Wthread-safety).

Consider either:

  1. Removing the ABSL_GUARDED_BY(mu) annotations and adding a comment that these fields are write-once under mu and immutable thereafter (reads protected by the initialized acquire fence), or
  2. Keeping the lock held for reads as well (unnecessary overhead for the hot path).

Comment on lines +450 to +454
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Low] pool, pool_indexes, and copy_indexes are annotated ABSL_GUARDED_BY(mu), but after initialization they are read without holding mu in the hot path (lines 1582-1593 of gpu_executable.cc). This is functionally safe because these are write-once fields protected by the initialized atomic, but the annotations are inconsistent and will produce false positives if Clang's thread safety analysis is enabled.

Consider either:

  1. Removing ABSL_GUARDED_BY(mu) (since the fields are effectively immutable post-init and the atomic<bool> initialized provides the barrier), or
  2. Adding ABSL_NO_THREAD_SAFETY_ANALYSIS to the hot-path reader function.

std::atomic<bool> initialized{false};
};
absl::Mutex circular_pool_mutex_;
absl::node_hash_map<stream_executor::StreamExecutor*, CircularPoolState>
circular_pools_ ABSL_GUARDED_BY(circular_pool_mutex_);

GpuExecutable(const GpuExecutable&) = delete;
GpuExecutable& operator=(const GpuExecutable&) = delete;

Expand Down
7 changes: 5 additions & 2 deletions xla/service/gpu/thunk_emitter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -399,8 +399,11 @@ absl::StatusOr<ThunkSequence> ThunkEmitter::EmitCommandBufferThunk(

bool enable_loop_unroll = ir_emitter_context_->debug_options()
.xla_gpu_command_buffer_unroll_loops();
bool enable_va_remapping = ir_emitter_context_->debug_options()
.xla_gpu_enable_command_buffer_va_remapping();
bool enable_va_remapping =
ir_emitter_context_->debug_options()
.xla_gpu_enable_command_buffer_va_remapping() ||
ir_emitter_context_->debug_options()
.xla_gpu_enable_circular_vmm_pool();
TF_ASSIGN_OR_RETURN(
CommandExecutor cmd_executor,
ConvertToCommands(
Expand Down
Loading
Loading