[ROCm] Fix FP8 collective ops E2E tests for MI300X (NANOO FP8)#776
[ROCm] Fix FP8 collective ops E2E tests for MI300X (NANOO FP8)#776phambinhfin wants to merge 3 commits intomainfrom
Conversation
The 5 FP8 tests in collective_ops_e2e_test use f8e4m3fn (OCP/IEEE FP8), but MI300X (gfx942) only supports f8e4m3fnuz (NANOO FP8). The GEMM rewriter correctly rejects the FP8 rewrite for mismatched types, causing the tests to produce __cublas$lt$matmul instead of __cublas$lt$matmul$f8. Add a ReplaceFp8Types() helper that substitutes f8e4m3fn -> f8e4m3fnuz and f8e5m2 -> f8e5m2fnuz when running on NANOO-only hardware (MI300X), following the same pattern used in gemm_rewriter_fp8_test.cc. Tests fixed: - CollectiveOpsTestE2E.CollectivePipelinerF8 - CollectiveOpsTestE2EWindowedNonWindowed.WindowedEinsumE2EAllGatherF8 - CollectiveOpsTestE2EWindowedNonWindowed.WindowedEinsumE2EAllGatherReshapeF8 - CollectiveOpsTestE2EWindowedNonWindowed.WindowedEinsumE2EReduceScatterF8 - CollectiveOpsTestE2EWindowedNonWindowed.WindowedEinsumE2EAllGatherMultiConsumerF8
| Capability().rocm_compute_capability()->has_nanoo_fp8_support() && | ||
| !Capability().rocm_compute_capability()->has_ocp_fp8_support()) { | ||
| return absl::StrReplaceAll( | ||
| hlo_text, {{"f8e4m3fn", "f8e4m3fnuz"}, {"f8e5m2", "f8e5m2fnuz"}}); |
There was a problem hiding this comment.
nit: The {"f8e4m3fn", "f8e4m3fnuz"} replacement pair has a subtle prefix-match risk: if an HLO string ever contained a pre-existing f8e4m3fnuz token, absl::StrReplaceAll would match the f8e4m3fn prefix first and produce f8e4m3fnuzuz.
This is safe today because no HLO strings in this file contain fnuz types, and the replacement is gated on NANOO-only hardware. However, the gemm_rewriter_fp8_test.cc reference implementation avoids this by using placeholder tokens (<<F8E4M3>>, <<F8E5M2>>) instead of substring replacement.
Consider adding a brief comment noting this constraint, or switching to the placeholder approach if more FP8 tests are added to this file.
There was a problem hiding this comment.
Resolved — addressed in this revision. A comment noting the prefix-match constraint has been added.
| std::string ReplaceFp8Types(absl::string_view hlo_text) { | ||
| if (Capability().IsRocm() && | ||
| Capability().rocm_compute_capability()->has_nanoo_fp8_support() && | ||
| !Capability().rocm_compute_capability()->has_ocp_fp8_support()) { |
There was a problem hiding this comment.
ultra-nit: Capability() is called three times in succession. A local reference would read slightly cleaner:
| std::string ReplaceFp8Types(absl::string_view hlo_text) { | |
| if (Capability().IsRocm() && | |
| Capability().rocm_compute_capability()->has_nanoo_fp8_support() && | |
| !Capability().rocm_compute_capability()->has_ocp_fp8_support()) { | |
| const auto& cap = Capability(); | |
| if (cap.IsRocm() && | |
| cap.rocm_compute_capability()->has_nanoo_fp8_support() && | |
| !cap.rocm_compute_capability()->has_ocp_fp8_support()) { |
Negligible perf-wise since this runs once per test, just a readability suggestion.
There was a problem hiding this comment.
Resolved — addressed in this revision. Capability() is now cached in a local reference.
Review SummarySolid fix. The Two minor inline suggestions posted (prefix-match risk note, repeated 🤖 Generated with Claude Code |
- Store Capability() in a local const reference to avoid repeated calls - Add comment noting the substring replacement constraint (input HLO must not already contain FNUZ type strings)
Re-review SummaryPrevious review feedback has been addressed — both issues are now resolved:
No new issues found. The fix is clean, well-scoped, and follows the established pattern from |
There was a problem hiding this comment.
IIUC, this test is checking ocp-fp8
xla/xla/backends/gpu/tests/collective_ops_e2e_test.cc
Lines 1386 to 1387 in 8298142
and
xla/xla/backends/gpu/tests/collective_ops_e2e_test.cc
Lines 1423 to 1424 in 8298142
and
xla/xla/backends/gpu/tests/collective_ops_e2e_test.cc
Lines 1504 to 1505 in 8298142
So I don't think this is a right thing that we just modify the test case as nanoo fp8 to get a PASS. The proper way should let gfx942 to skip these ocp-fp8 tests (as it's not supported, and gfx950 can no impact) and write related nanoo fp8 hlo for gfx942.
| // strings are written with OCP types; this replaces them with NANOO types | ||
| // when running on MI300X so the GEMM rewriter can produce FP8 custom calls. | ||
| // Note: the input HLO must not already contain FNUZ type strings, as the | ||
| // substring replacement of "f8e4m3fn" would also match inside "f8e4m3fnuz". |
There was a problem hiding this comment.
If you want to make a statement, please put the docs as the reference https://rocm.docs.amd.com/projects/HIP/en/latest/reference/low_fp_types.html#id4
e0b99d7 to
f390a1d
Compare
Instead of silently replacing OCP FP8 types (f8e4m3fn) with FNUZ types (f8e4m3fnuz) at runtime via string substitution, properly separate the tests by FP8 variant: - Skip OCP FP8 tests on NANOO-only platforms since they do not support OCP FP8 (ref: https://rocm.docs.amd.com/projects/HIP/en/latest/reference/low_fp_types.html#id4) - Add dedicated FNUZ FP8 test cases with native f8e4m3fnuz HLO for NANOO FP8 platforms - Remove ReplaceFp8Types() helper that changed test semantics - Add IsNanooFp8Only() helper (ROCm-only, no CUDA changes) New FNUZ FP8 tests: - CollectivePipelinerFnuzF8 - WindowedEinsumE2EAllGatherFnuzF8 - WindowedEinsumE2EAllGatherReshapeFnuzF8 - WindowedEinsumE2EAllGatherMultiConsumerFnuzF8 - WindowedEinsumE2EReduceScatterFnuzF8
f390a1d to
5cc99ea
Compare
Re-review SummaryPrevious Claude comments (prefix-match risk with The new approach is straightforward and correct: all five FNUZ HLO modules are faithful type substitutions ( |
|
@i-chaochen |
The 5 FP8 tests in collective_ops_e2e_test use f8e4m3fn (OCP/IEEE FP8), but MI300X (gfx942) only supports f8e4m3fnuz (NANOO FP8). The GEMM rewriter correctly rejects the FP8 rewrite for mismatched types, causing the tests to produce __cublas$lt$matmul instead of __cublas$lt$matmul$f8.
Add a ReplaceFp8Types() helper that substitutes f8e4m3fn -> f8e4m3fnuz and f8e5m2 -> f8e5m2fnuz when running on NANOO-only hardware (MI300X), following the same pattern used in gemm_rewriter_fp8_test.cc.
Input
Tests fixed: