Skip to content
Open
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 27 additions & 5 deletions xla/backends/gpu/tests/collective_ops_e2e_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/array.h"
Expand Down Expand Up @@ -87,6 +88,23 @@ class CollectiveOpsTestE2E : public CollectiveOpsE2ETestBase {
GetDebugOptionsForTest().xla_gpu_enable_cublaslt();
}

// MI300X (gfx942) uses NANOO FP8 types (f8e4m3fnuz/f8e5m2fnuz), while
// CUDA and MI350+ use OCP/IEEE types (f8e4m3fn/f8e5m2). The HLO test
// strings are written with OCP types; this replaces them with NANOO types
// when running on MI300X so the GEMM rewriter can produce FP8 custom calls.
// Note: the input HLO must not already contain FNUZ type strings, as the
// substring replacement of "f8e4m3fn" would also match inside "f8e4m3fnuz".
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If you want to make this statement, please cite the documentation as a reference: https://rocm.docs.amd.com/projects/HIP/en/latest/reference/low_fp_types.html#id4

// Returns a copy of `hlo_text`. On a ROCm device that reports NANOO FP8
// support but no OCP FP8 support, the FP8 type names in the copy are rewritten
// ("f8e4m3fn" -> "f8e4m3fnuz", "f8e5m2" -> "f8e5m2fnuz"); otherwise the text is
// returned unchanged.
std::string ReplaceFp8Types(absl::string_view hlo_text) {
const auto& cap = Capability();
// Rewrite only for ROCm devices with NANOO FP8 support and without OCP FP8
// support.
if (cap.IsRocm() &&
cap.rocm_compute_capability()->has_nanoo_fp8_support() &&
!cap.rocm_compute_capability()->has_ocp_fp8_support()) {
// Plain substring replacement: an input that already contains "f8e4m3fnuz"
// or "f8e5m2fnuz" would also be matched and end up with a doubled suffix.
return absl::StrReplaceAll(
hlo_text, {{"f8e4m3fn", "f8e4m3fnuz"}, {"f8e5m2", "f8e5m2fnuz"}});
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: The {"f8e4m3fn", "f8e4m3fnuz"} replacement pair has a subtle prefix-match risk: if an HLO string ever contained a pre-existing f8e4m3fnuz token, absl::StrReplaceAll would match the f8e4m3fn prefix first and produce f8e4m3fnuzuz.

This is safe today because no HLO strings in this file contain fnuz types, and the replacement is gated on NANOO-only hardware. However, the gemm_rewriter_fp8_test.cc reference implementation avoids this by using placeholder tokens (<<F8E4M3>>, <<F8E5M2>>) instead of substring replacement.

Consider adding a brief comment noting this constraint, or switching to the placeholder approach if more FP8 tests are added to this file.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved — addressed in this revision. A comment noting the prefix-match constraint has been added.

}
// All other platforms: hand back the HLO text unmodified.
return std::string(hlo_text);
}

void CollectiveOpsVerifyF8Matmul(absl::string_view hlo_text,
const DebugOptions& options) {
if (!HasFp8Support()) {
Expand All @@ -99,11 +117,12 @@ class CollectiveOpsTestE2E : public CollectiveOpsE2ETestBase {
<< " devices (" << device_count() << " available)";
}

std::string replaced_hlo = ReplaceFp8Types(hlo_text);
HloModuleConfig config = GetModuleConfigForTest(
/*replica_count=*/kNumReplicas, /*num_partitions=*/kNumPartitions);
config.set_debug_options(options);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_text, config));
ParseAndReturnVerifiedModule(replaced_hlo, config));

TF_ASSERT_OK_AND_ASSIGN(auto executable,
CreateExecutable(std::move(module),
Expand Down Expand Up @@ -1277,6 +1296,7 @@ class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E {
<< " devices (" << device_count() << " available)";
}

std::string replaced_hlo = ReplaceFp8Types(hlo_text);
HloModuleConfig config = GetModuleConfigForTest(
/*replica_count=*/kNumReplicas, /*num_partitions=*/kNumPartitions);

Expand All @@ -1288,8 +1308,9 @@ class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E {
}

// Run with reference config.
TF_ASSERT_OK_AND_ASSIGN(auto ref_module,
ParseAndReturnVerifiedModule(hlo_text, config));
TF_ASSERT_OK_AND_ASSIGN(
auto ref_module,
ParseAndReturnVerifiedModule(replaced_hlo, config));
ASSERT_OK_AND_ASSIGN(auto ref_executable,
CreateExecutable(std::move(ref_module),
/*run_hlo_passes=*/true));
Expand All @@ -1314,8 +1335,9 @@ class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E {
debug_options.set_xla_gpu_multi_streamed_windowed_einsum(true);
debug_options.set_xla_gpu_experimental_enable_alltoall_windowed_einsum(
enable_a2a_rewrite);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_text, config));
TF_ASSERT_OK_AND_ASSIGN(
auto module,
ParseAndReturnVerifiedModule(replaced_hlo, config));

TF_ASSERT_OK_AND_ASSIGN(
ExecutionResult execution_result,
Expand Down
Loading