Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 24 additions & 5 deletions xla/backends/gpu/tests/collective_ops_e2e_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ limitations under the License.
#include "absl/status/statusor.h"
#include "absl/strings/match.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_replace.h"
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "xla/array.h"
Expand Down Expand Up @@ -87,6 +88,20 @@ class CollectiveOpsTestE2E : public CollectiveOpsE2ETestBase {
GetDebugOptionsForTest().xla_gpu_enable_cublaslt();
}

// MI300X (gfx942) uses NANOO FP8 types (f8e4m3fnuz/f8e5m2fnuz), while
// CUDA and MI350+ use OCP/IEEE types (f8e4m3fn/f8e5m2). The HLO test
// strings are written with OCP types; this replaces them with NANOO types
// when running on MI300X so the GEMM rewriter can produce FP8 custom calls.
std::string ReplaceFp8Types(absl::string_view hlo_text) {
if (Capability().IsRocm() &&
Capability().rocm_compute_capability()->has_nanoo_fp8_support() &&
!Capability().rocm_compute_capability()->has_ocp_fp8_support()) {
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ultra-nit: Capability() is called three times in succession. A local reference would read slightly cleaner:

Suggested change
std::string ReplaceFp8Types(absl::string_view hlo_text) {
if (Capability().IsRocm() &&
Capability().rocm_compute_capability()->has_nanoo_fp8_support() &&
!Capability().rocm_compute_capability()->has_ocp_fp8_support()) {
const auto& cap = Capability();
if (cap.IsRocm() &&
cap.rocm_compute_capability()->has_nanoo_fp8_support() &&
!cap.rocm_compute_capability()->has_ocp_fp8_support()) {

This is negligible perf-wise since it runs once per test; it is just a readability suggestion.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved — addressed in this revision. Capability() is now cached in a local reference.

return absl::StrReplaceAll(
hlo_text, {{"f8e4m3fn", "f8e4m3fnuz"}, {"f8e5m2", "f8e5m2fnuz"}});
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

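nit: The {"f8e4m3fn", "f8e4m3fnuz"} replacement pair has a subtle prefix-match risk: if an HLO string ever contained a pre-existing f8e4m3fnuz token, absl::StrReplaceAll would match its f8e4m3fn prefix, substitute it, and leave the original trailing "uz" in place, producing f8e4m3fnuzuz. The {"f8e5m2", "f8e5m2fnuz"} pair has the same problem: a pre-existing f8e5m2fnuz token would become f8e5m2fnuzfnuz.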

This is safe today because no HLO strings in this file contain fnuz types, and the replacement is gated on NANOO-only hardware. However, the gemm_rewriter_fp8_test.cc reference implementation avoids this by using placeholder tokens (<<F8E4M3>>, <<F8E5M2>>) instead of substring replacement.

Consider adding a brief comment noting this constraint, or switching to the placeholder approach if more FP8 tests are added to this file.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved — addressed in this revision. A comment noting the prefix-match constraint has been added.

}
return std::string(hlo_text);
}

void CollectiveOpsVerifyF8Matmul(absl::string_view hlo_text,
const DebugOptions& options) {
if (!HasFp8Support()) {
Expand All @@ -99,11 +114,12 @@ class CollectiveOpsTestE2E : public CollectiveOpsE2ETestBase {
<< " devices (" << device_count() << " available)";
}

std::string replaced_hlo = ReplaceFp8Types(hlo_text);
HloModuleConfig config = GetModuleConfigForTest(
/*replica_count=*/kNumReplicas, /*num_partitions=*/kNumPartitions);
config.set_debug_options(options);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_text, config));
ParseAndReturnVerifiedModule(replaced_hlo, config));

TF_ASSERT_OK_AND_ASSIGN(auto executable,
CreateExecutable(std::move(module),
Expand Down Expand Up @@ -1277,6 +1293,7 @@ class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E {
<< " devices (" << device_count() << " available)";
}

std::string replaced_hlo = ReplaceFp8Types(hlo_text);
HloModuleConfig config = GetModuleConfigForTest(
/*replica_count=*/kNumReplicas, /*num_partitions=*/kNumPartitions);

Expand All @@ -1288,8 +1305,9 @@ class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E {
}

// Run with reference config.
TF_ASSERT_OK_AND_ASSIGN(auto ref_module,
ParseAndReturnVerifiedModule(hlo_text, config));
TF_ASSERT_OK_AND_ASSIGN(
auto ref_module,
ParseAndReturnVerifiedModule(replaced_hlo, config));
ASSERT_OK_AND_ASSIGN(auto ref_executable,
CreateExecutable(std::move(ref_module),
/*run_hlo_passes=*/true));
Expand All @@ -1314,8 +1332,9 @@ class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E {
debug_options.set_xla_gpu_multi_streamed_windowed_einsum(true);
debug_options.set_xla_gpu_experimental_enable_alltoall_windowed_einsum(
enable_a2a_rewrite);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_text, config));
TF_ASSERT_OK_AND_ASSIGN(
auto module,
ParseAndReturnVerifiedModule(replaced_hlo, config));

TF_ASSERT_OK_AND_ASSIGN(
ExecutionResult execution_result,
Expand Down
Loading