-
Notifications
You must be signed in to change notification settings - Fork 8
[ROCm] Fix FP8 collective ops E2E tests for MI300X (NANOO FP8) #776
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -27,6 +27,7 @@ limitations under the License. | |
| #include "absl/status/statusor.h" | ||
| #include "absl/strings/match.h" | ||
| #include "absl/strings/str_cat.h" | ||
| #include "absl/strings/str_replace.h" | ||
| #include "absl/strings/string_view.h" | ||
| #include "absl/types/span.h" | ||
| #include "xla/array.h" | ||
|
|
@@ -87,6 +88,23 @@ class CollectiveOpsTestE2E : public CollectiveOpsE2ETestBase { | |
| GetDebugOptionsForTest().xla_gpu_enable_cublaslt(); | ||
| } | ||
|
|
||
| // MI300X (gfx942) uses NANOO FP8 types (f8e4m3fnuz/f8e5m2fnuz), while | ||
| // CUDA and MI350+ use OCP/IEEE types (f8e4m3fn/f8e5m2). The HLO test | ||
| // strings are written with OCP types; this replaces them with NANOO types | ||
| // when running on MI300X so the GEMM rewriter can produce FP8 custom calls. | ||
| // Note: the input HLO must not already contain FNUZ type strings, as the | ||
| // substring replacement of "f8e4m3fn" would also match inside "f8e4m3fnuz". | ||
| std::string ReplaceFp8Types(absl::string_view hlo_text) { | ||
| const auto& cap = Capability(); | ||
| if (cap.IsRocm() && | ||
| cap.rocm_compute_capability()->has_nanoo_fp8_support() && | ||
| !cap.rocm_compute_capability()->has_ocp_fp8_support()) { | ||
| return absl::StrReplaceAll( | ||
| hlo_text, {{"f8e4m3fn", "f8e4m3fnuz"}, {"f8e5m2", "f8e5m2fnuz"}}); | ||
|
||
| } | ||
| return std::string(hlo_text); | ||
| } | ||
|
|
||
| void CollectiveOpsVerifyF8Matmul(absl::string_view hlo_text, | ||
| const DebugOptions& options) { | ||
| if (!HasFp8Support()) { | ||
|
|
@@ -99,11 +117,12 @@ class CollectiveOpsTestE2E : public CollectiveOpsE2ETestBase { | |
| << " devices (" << device_count() << " available)"; | ||
| } | ||
|
|
||
| std::string replaced_hlo = ReplaceFp8Types(hlo_text); | ||
| HloModuleConfig config = GetModuleConfigForTest( | ||
| /*replica_count=*/kNumReplicas, /*num_partitions=*/kNumPartitions); | ||
| config.set_debug_options(options); | ||
| TF_ASSERT_OK_AND_ASSIGN(auto module, | ||
| ParseAndReturnVerifiedModule(hlo_text, config)); | ||
| ParseAndReturnVerifiedModule(replaced_hlo, config)); | ||
|
|
||
| TF_ASSERT_OK_AND_ASSIGN(auto executable, | ||
| CreateExecutable(std::move(module), | ||
|
|
@@ -1277,6 +1296,7 @@ class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E { | |
| << " devices (" << device_count() << " available)"; | ||
| } | ||
|
|
||
| std::string replaced_hlo = ReplaceFp8Types(hlo_text); | ||
| HloModuleConfig config = GetModuleConfigForTest( | ||
| /*replica_count=*/kNumReplicas, /*num_partitions=*/kNumPartitions); | ||
|
|
||
|
|
@@ -1288,8 +1308,9 @@ class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E { | |
| } | ||
|
|
||
| // Run with reference config. | ||
| TF_ASSERT_OK_AND_ASSIGN(auto ref_module, | ||
| ParseAndReturnVerifiedModule(hlo_text, config)); | ||
| TF_ASSERT_OK_AND_ASSIGN( | ||
| auto ref_module, | ||
| ParseAndReturnVerifiedModule(replaced_hlo, config)); | ||
| ASSERT_OK_AND_ASSIGN(auto ref_executable, | ||
| CreateExecutable(std::move(ref_module), | ||
| /*run_hlo_passes=*/true)); | ||
|
|
@@ -1314,8 +1335,9 @@ class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E { | |
| debug_options.set_xla_gpu_multi_streamed_windowed_einsum(true); | ||
| debug_options.set_xla_gpu_experimental_enable_alltoall_windowed_einsum( | ||
| enable_a2a_rewrite); | ||
| TF_ASSERT_OK_AND_ASSIGN(auto module, | ||
| ParseAndReturnVerifiedModule(hlo_text, config)); | ||
| TF_ASSERT_OK_AND_ASSIGN( | ||
| auto module, | ||
| ParseAndReturnVerifiedModule(replaced_hlo, config)); | ||
|
|
||
| TF_ASSERT_OK_AND_ASSIGN( | ||
| ExecutionResult execution_result, | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If you want to make this claim in the comment, please cite the documentation as a reference: https://rocm.docs.amd.com/projects/HIP/en/latest/reference/low_fp_types.html#id4