
Phambinh/rocm vmm allocator#767

Open
phambinhfin wants to merge 4 commits into main from phambinh/rocm-vmm-allocator

Conversation


@phambinhfin phambinhfin commented Mar 30, 2026

📝 Summary of Changes

Add VMM (Virtual Memory Management) support for ROCm in XLA, using HIP's native VMM APIs (hipMemCreate, hipMemAddressReserve, hipMemMap, hipMemSetAccess, hipStreamWriteValue64). This adds 12 new files and modifies 7 existing files across xla/stream_executor/rocm/ and xla/pjrt/gpu/.

Four layers following the existing XLA VMM architecture:

  • Layer 1 (RocmRawMemoryAllocation): RAII wrapper for hipMemCreate/hipMemRelease
  • Layer 2 (RocmMemoryReservation): RAII wrapper for hipMemAddressReserve/hipMemMap/hipMemSetAccess/hipMemUnmap
  • Layer 3 (RocmVmmAllocator): all-in-one allocator (create + reserve + map + setAccess)
  • Layer 4 (RocmDeviceAddressVmmAllocator): GPU timeline-based deferred deallocation via hipStreamWriteValue64 + coherent host memory

PJRT integration wires RocmDeviceAddressVmmAllocator into the GPU client so XLA_PYTHON_CLIENT_ALLOCATOR=vmm works on ROCm (previously returned "VMM allocator is only supported with CUDA").

🎯 Justification

The VMM allocator separates virtual address reservation from physical memory allocation. This is a prerequisite for update-free command buffer execution — stable virtual addresses allow command buffers (HIP graphs) to be replayed without re-recording when buffer allocation addresses change between executions (see upstream openxla#39672).

Additionally, VMM provides per-device access control via hipMemSetAccess, enabling fine-grained control over which GPUs can access a given allocation — replacing the all-or-nothing hipDeviceMallocFinegrained approach.

🚀 Kind of Contribution

✨ New Feature, 🧪 Tests

📊 Benchmark (for Performance Improvements)

Not applicable — this is infrastructure for future command buffer optimizations. Performance benchmarks will be provided when the VA remapping execution path is enabled on ROCm.

🧪 Unit Tests

4 test targets (26 tests total):

  • rocm_raw_memory_allocation_test (3 tests) — CreateAllocation, AddressReflectsHandle, SizeIsAtLeastRequested
  • rocm_memory_reservation_test (6 tests) — CreateReservation, MapToWrongType, MapToSingleAllocation, ScopedMappingUnmapsOnDestruction, MapToMultipleAllocations, TwoReservationsDifferentAddresses
  • rocm_vmm_allocator_test (3 tests) — AllocateAndFree, AllocateZeroBytes, MemcpyRoundTrip
  • rocm_device_address_vmm_allocator_test (14 tests) — Covers allocate/deallocate, memory read/write, stream accessors, deferred deallocation, VA reuse, destructor safety, and error handling

All tests use GTEST_SKIP() when a ROCm executor is not available, allowing graceful skip in sandboxed CI environments.

🧪 Execution Tests

  • VmmAllocatorCanBeSet — Verifies that the PJRT client correctly instantiates DeviceAddressVmmAllocator when GpuAllocatorConfig::Kind::kVmm is set (shared test for CUDA and ROCm)
  • VmmAllocatorE2ETest — End-to-end test: creates a VMM-backed PJRT client, compiles and executes an HLO add program, asserts correctness of the result
  • 2 multi-GPU tests in rocm_device_address_vmm_allocator_test — Verify cross-device allocation isolation (requires 2 GPUs)

All execution tests run with at most 2 GPUs.

@phambinhfin phambinhfin added the claude-review Request a Claude AI code review for this PR label Mar 30, 2026

claude bot commented Mar 30, 2026

Review Summary

This PR ports the CUDA VMM (Virtual Memory Management) allocator to ROCm/HIP, adding 12 new files with a clean four-layer design (physical memory, virtual address reservation, simple VMM allocator, timeline-based VMM allocator) plus PJRT integration changes. The implementation closely mirrors the CUDA equivalent and enables per-device access control and L2 cache utilization for P2P GPU memory on AMD hardware.

Overall: Well-structured port with good test coverage. A few items to address:

  • P2P access gating (main issue): The ROCm VMM allocator calls hipMemSetAccess for all peer devices unconditionally, unlike the CUDA version which checks CanEnablePeerAccessTo() first. This produces spurious warnings and unnecessary syscalls on systems without full P2P connectivity.
  • Signal memory semantics: timeline_dev_ptr stores a host pointer (not a device pointer) on ROCm due to HIP signal memory semantics — worth a clarifying comment.
  • Minor items: unused is_rdma_supported parameter, duplicated BuildAllocationProperties helper, deleted move ops on RocmMemoryReservation without explanation.

See inline comments for details.

@github-actions github-actions bot removed the claude-review Request a Claude AI code review for this PR label Mar 30, 2026
@phambinhfin phambinhfin force-pushed the phambinh/rocm-vmm-allocator branch from 86144ed to 9dd51da Compare March 30, 2026 12:30
@phambinhfin phambinhfin added the claude-review Request a Claude AI code review for this PR label Mar 30, 2026
absl::StrFormat("Failed to query available memory from device %i",
executor->device_ordinal()));
}
uint64_t pa_budget = total_memory * memory_fraction;

No bounds check on memory_fraction. total_memory * memory_fraction uses an implicit double→uint64_t conversion. If memory_fraction exceeds 1.0 (e.g. through misconfiguration), the product could overflow or produce an unexpectedly large budget. Consider adding a DCHECK that memory_fraction is in [0, 1], or at minimum clamping the value. The CUDA version has the same gap.


Resolved — addressed in this revision. A DCHECK now validates that memory_fraction is in [0, 1].


claude bot commented Mar 30, 2026

Re-review Summary

All 6 previous review findings have been addressed — resolved threads posted as replies. Nice work on the updates (peer access gating, documentation comments, parameter naming).

4 new findings posted inline:

  1. Resource leak on error after hipMemMap (rocm_vmm_allocator.cc) — missing cleanup if post-map operations fail
  2. No bounds check on memory_fraction (rocm_device_address_vmm_allocator.cc) — implicit double→uint64_t conversion without validation
  3. No runtime capability check (rocm_device_address_vmm_allocator.cc) — unlike CUDA, no runtime probe for VMM/signal memory support on the specific GPU
  4. address() semantics (rocm_raw_memory_allocation.cc) — handle-as-address encoding deserves a clarifying comment

Finding #1 is the most actionable (potential resource leak). The others are defensive improvements and documentation nits.

@github-actions github-actions bot removed the claude-review Request a Claude AI code review for this PR label Mar 30, 2026
@phambinhfin phambinhfin force-pushed the phambinh/rocm-vmm-allocator branch from 9dd51da to e3ad555 Compare March 30, 2026 12:48
@phambinhfin phambinhfin added the claude-review Request a Claude AI code review for this PR label Mar 30, 2026

claude bot commented Mar 30, 2026

Re-review Summary

All 10 previously flagged issues have been addressed in this revision:

  • 4 newly resolved: resource leak after hipMemMap, DCHECK on memory_fraction, runtime VMM capability probe, address() documentation
  • 6 previously resolved: P2P access gating, is_rdma_supported docs, BuildAllocationProperties duplication docs, move-disabled rationale, timeline_dev_ptr naming, hipFree correctness

No new issues found. The ROCm VMM allocator implementation looks solid — clean four-layer architecture matching the CUDA counterpart with appropriate ROCm-specific adaptations.

🤖 Generated with Claude Code

@github-actions github-actions bot removed the claude-review Request a Claude AI code review for this PR label Mar 30, 2026
@draganmladjenovic draganmladjenovic self-requested a review March 30, 2026 13:29
@phambinhfin phambinhfin force-pushed the phambinh/rocm-vmm-allocator branch from e3ad555 to 1ee9570 Compare March 30, 2026 13:45
@phambinhfin phambinhfin added the claude-review Request a Claude AI code review for this PR label Mar 30, 2026
Comment on lines +49 to +50
RocmRawMemoryAllocation& operator=(const RocmRawMemoryAllocation&) = delete;


nit: RocmRawMemoryAllocation deletes copy operations but does not explicitly delete move operations. Since a user-declared destructor exists, the implicit move constructor/assignment are not generated, so this is safe. However, RocmMemoryReservation explicitly deletes moves with a comment explaining why. For consistency and clarity, consider adding = delete for the move operations here as well.


Resolved — addressed in this revision. Move operations are now explicitly deleted with = delete.


if (executor == nullptr) {
// Skips remote devices.
continue;
}

nit: The // Skips remote devices. comment was removed here. This comment explained why executor == nullptr is expected (remote devices have no local executor). Removing it reduces context for future readers. Was this intentional, or should the comment be preserved?

pjrt_se_client->allocator()),
nullptr);
}
#endif // GOOGLE_CUDA

nit: This #endif closes a block that includes both #if GOOGLE_CUDA and #elif TENSORFLOW_USE_ROCM (line 3623), so the comment should be // GOOGLE_CUDA || TENSORFLOW_USE_ROCM to match the one at line 3677.

Suggested change
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

Resolved — addressed in this revision. The #endif comment now reads // GOOGLE_CUDA || TENSORFLOW_USE_ROCM.


#include "xla/stream_executor/cuda/cuda_device_address_vmm_allocator.h"
#elif TENSORFLOW_USE_ROCM
#include "xla/stream_executor/rocm/rocm_device_address_vmm_allocator.h"
#endif // GOOGLE_CUDA

nit: This #endif now closes a block that includes #elif TENSORFLOW_USE_ROCM (line 101), so the comment should be updated for consistency with the fix at line 3642.

Suggested change
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

#include "xla/stream_executor/cuda/cuda_device_address_vmm_allocator.h"
#elif TENSORFLOW_USE_ROCM
#include "xla/stream_executor/rocm/rocm_device_address_vmm_allocator.h"
#endif // GOOGLE_CUDA

nit: This #endif now closes a block that includes #elif TENSORFLOW_USE_ROCM (line 112), so the comment should be updated.

Suggested change
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

return absl::UnimplementedError(
"VMM allocator is only supported with CUDA.");
"VMM allocator is only supported with CUDA or ROCm.");
#endif // GOOGLE_CUDA

nit: This #endif now closes a block that includes #elif TENSORFLOW_USE_ROCM (line 682), so the comment should be updated.

Suggested change
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

return absl::UnimplementedError(
"VMM allocator is only supported with CUDA.");
"VMM allocator is only supported with CUDA or ROCm.");
#endif // GOOGLE_CUDA

nit: This #endif now closes a block that includes #elif TENSORFLOW_USE_ROCM (line 1546), so the comment should be updated.

Suggested change
#endif // GOOGLE_CUDA
#endif // GOOGLE_CUDA || TENSORFLOW_USE_ROCM


claude bot commented Mar 30, 2026

Re-review Summary

Most of the 14 issues from the initial review have been addressed — thanks for the thorough fixes.

Still open (unchanged, no action needed):

  • const_cast in rocm_vmm_allocator.cc:127 (known nit, mirrored from CUDA)
  • Cosmetic divergence in tfrt/utils.cc:672 (// Skips remote devices. comment removed from CUDA path but kept in ROCm path)

New nits (4 inline comments posted):

  • Four #endif // GOOGLE_CUDA comments that should read // GOOGLE_CUDA || TENSORFLOW_USE_ROCM now that the blocks include #elif TENSORFLOW_USE_ROCM:
    • se_gpu_pjrt_client_test.cc:103
    • tfrt/utils.cc:114
    • tfrt/utils.cc:701
    • se_gpu_pjrt_client.cc:1559

All four have inline code suggestions attached.

@github-actions github-actions bot removed the claude-review Request a Claude AI code review for this PR label Mar 30, 2026
@draganmladjenovic

OK, let's start with the premise:

ROCm currently uses hipExtMallocWithFlags(hipDeviceMallocFinegrained) for P2P GPU memory. This approach has a limitation:

No per-device access control — all GPUs get read/write access with no way to restrict individual devices
CUDA already solves this with VMM (Virtual Memory Management) APIs (cuMemCreate, cuMemAddressReserve, cuMemMap, cuMemSetAccess) which allow per-device access permissions. HIP has equivalent VMM APIs since ROCm 6.0 but they were not wired into XLA.
This PR ports the CUDA VMM allocator to ROCm/HIP, closing a critical gap between CUDA and ROCm for multi-GPU P2P memory management.

The above makes no sense whatsoever. We use P2P in the handwritten all-reduce kernel, and it should be used there. It serves a particular purpose: allowing coherent memory access between GPUs on MI200. Now I'm not sure whether it is uncached on MI300 or other parts (a cache level beyond L2 might handle coherency), and it does not matter in the end. It is not something to be replaced. I believe the upstream VMM allocator is part of the effort to make command buffers work without retracing on pointer change (openxla#39672). That being said, this should not go upstream until we see the bigger picture and the change lands on the CUDA side.

int ordinal = state.executor->device_ordinal();

// Runtime probe + query allocation granularity for this device.
// The compile-time gate (TF_ROCM_VERSION >= 60000) ensures the symbols exist,

There is no compile-time gate here, nor do we need one. We make no promise that the code will work on ROCm 6.0.

Author


Done — removed the compile-time gate comment. The code now just says // Query allocation granularity for this device.

github-actions bot pushed a commit that referenced this pull request Mar 30, 2026

// hipDeviceptr_t is void* on HIP, so pointer arithmetic requires casting.
hipDeviceptr_t PtrAdd(hipDeviceptr_t base, size_t offset) {
return static_cast<char*>(base) + offset;

Maybe just keep ptr_ as char* ?

Author


Done — changed ptr_ from hipDeviceptr_t to char* and removed the PtrAdd helper. Pointer arithmetic now works directly: ptr_ + offset.

// is_rdma_supported is accepted for API compatibility with CudaVmmAllocator
// but not used: hipMemAllocationProp does not expose a gpuDirectRDMACapable
// flag like CUDA's CUmemAllocationProp::allocFlags.
static hipMemAllocationProp GetVmmAllocationProperties(

Inline it.

Author


Done — inlined in all files.

@phambinhfin phambinhfin force-pushed the phambinh/rocm-vmm-allocator branch from 5e716ef to 3cd5657 Compare March 31, 2026 14:45
@phambinhfin
Author

@draganmladjenovic Could you take a look again

@phambinhfin
Author

phambinhfin commented Mar 31, 2026

From "this is a prerequisite for command buffer support — stable virtual addresses allow command buffers to be replayed without retracing when memory is reallocated (see https://github.com/openxla/xla/pull/39672)":

@draganmladjenovic, do you mean that this would benefit collective buffers as well, since they would not need an expensive retrace on each iteration?

@draganmladjenovic

@phambinhfin I guess so. To be honest, I'm not sure how it would interact with, let's say, the new symmetric memory for NCCL. I would refrain from pushing this upstream until we see the final result of this effort.

@phambinhfin phambinhfin requested a review from i-chaochen April 1, 2026 11:20
@phambinhfin phambinhfin force-pushed the phambinh/rocm-vmm-allocator branch from 3cd5657 to 401eefb Compare April 1, 2026 13:13
@phambinhfin
Author

phambinhfin commented Apr 1, 2026

Since the upstream VMM architecture is now finalized and merged (openxla#39393), I pushed our ROCm VMM implementation so the Google reviewers can review it in advance while I study the next step (enabling VA remapping on ROCm using HIP graphs). I tagged it as "Performance Improvement + New Feature" but haven't published benchmarks yet — the upstream PR says "Benchmark results pending — to be provided by workload owners who enable the flag in their pipelines."

https://github.com/openxla/xla/pull/40236/changes

Port CUDA's VMM allocator to ROCm/HIP using HIP VMM APIs (hipMemCreate,
hipMemAddressReserve, hipMemMap, hipMemSetAccess, hipStreamWriteValue64).
This enables per-GPU memory access control on AMD GPUs, replacing the
all-or-nothing hipDeviceMallocFinegrained approach that grants every GPU
access and disables L2 cache.

Four layers matching the CUDA structure:
- RocmRawMemoryAllocation: RAII wrapper for hipMemCreate/hipMemRelease
- RocmMemoryReservation: RAII wrapper for virtual address reservation,
  mapping, access control, and unmapping
- RocmVmmAllocator: simple all-in-one allocator (create + reserve + map
  + setAccess in a single Allocate call)
- RocmDeviceAddressVmmAllocator: advanced allocator with GPU timeline-
  based deferred deallocation via hipStreamWriteValue64 and signal memory

Key differences from the CUDA implementation:
- hipMemGenericAllocationHandle_t is a pointer type (not integer)
- hipStreamWriteValue64 requires signal memory (hipMallocSignalMemory)
  instead of pinned host memory (cuMemHostAlloc)
- hipDeviceptr_t is void* requiring PtrAdd() helper for offset arithmetic
- All wrap:: calls use nullptr/0ULL for proper template type deduction

Add four test files mirroring the CUDA VMM test structure:

- rocm_raw_memory_allocation_test: CreateAllocation, AddressReflectsHandle,
  SizeIsAtLeastRequested
- rocm_memory_reservation_test: CreateReservation, MapToWrongType,
  MapToSingleAllocation, ScopedMappingUnmapsOnDestruction,
  MapToMultipleAllocations, TwoReservationsDifferentAddresses
- rocm_vmm_allocator_test: AllocateAndFree, AllocateZeroBytes,
  MemcpyRoundTrip (parameterized with RdmaEnabled/RdmaDisabled)
- rocm_device_address_vmm_allocator_test: 13 single-GPU tests covering
  allocate/deallocate, memory read/write, stream accessors, deferred
  deallocation, VA reuse, destructor safety, and error handling;
  2 multi-GPU tests for cross-device allocation isolation

Wire RocmDeviceAddressVmmAllocator into the PJRT GPU client so that
GpuAllocatorConfig::Kind::kVmm works on ROCm. Previously this returned
"VMM allocator is only supported with CUDA".

Users can now enable VMM on ROCm via:
  XLA_PYTHON_CLIENT_ALLOCATOR=vmm

Changes:
- se_gpu_pjrt_client.cc: add #elif TENSORFLOW_USE_ROCM branch in kVmm
  case and include section
- tfrt/utils.cc: same pattern for the TFRT client path
- se_gpu_pjrt_client_test.cc: add ROCm VmmAllocatorCanBeSet test and
  make VmmAllocatorE2ETest shared across CUDA and ROCm
- BUILD files: add rocm_device_address_vmm_allocator to if_rocm() deps

Two critical fixes discovered during VA remapping testing:

1. RocmEvent::Synchronize() — Added missing override. The base class
   Event::Synchronize() returns Unimplemented, causing any code that
   calls event->Synchronize() (including VA remapping's unmap event)
   to fail. Implements hipEventSynchronize matching CudaEvent.

2. RocmMemoryReservation::SetAccess() — Now grants read/write access
   to all P2P-capable peer devices, not just the owning device. Without
   this, any multi-GPU pmap/collective workload using VMM crashes with
   "Memory access fault by GPU node-N". The Layer 3 RocmVmmAllocator
   already had the correct P2P pattern; Layer 2 was missing it.

Tested: pmap matmul+allreduce on 2x MI300X with VMM allocator passes.
@phambinhfin phambinhfin force-pushed the phambinh/rocm-vmm-allocator branch from ec4d7d8 to 1794b3d Compare April 1, 2026 16:47