Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
248 changes: 241 additions & 7 deletions xla/backends/gpu/tests/collective_ops_e2e_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,13 @@ class CollectiveOpsTestE2E : public CollectiveOpsE2ETestBase {
GetDebugOptionsForTest().xla_gpu_enable_cublaslt();
}

// Ref: https://rocm.docs.amd.com/projects/HIP/en/latest/reference/low_fp_types.html#id4
// Returns true when the device supports only the NANOO/FNUZ FP8 formats
// (f8e4m3fnuz/f8e5m2fnuz) and not the OCP formats (f8e4m3fn/f8e5m2), i.e.
// certain ROCm GPUs. Used to select between the OCP and FNUZ test variants.
bool IsNanooFp8Only() {
  if (!Capability().IsRocm()) {
    return false;
  }
  const auto* rocm_cc = Capability().rocm_compute_capability();
  return rocm_cc->has_nanoo_fp8_support() && !rocm_cc->has_ocp_fp8_support();
}

void CollectiveOpsVerifyF8Matmul(absl::string_view hlo_text,
const DebugOptions& options) {
if (!HasFp8Support()) {
Expand Down Expand Up @@ -1288,8 +1295,9 @@ class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E {
}

// Run with reference config.
TF_ASSERT_OK_AND_ASSIGN(auto ref_module,
ParseAndReturnVerifiedModule(hlo_text, config));
TF_ASSERT_OK_AND_ASSIGN(
auto ref_module,
ParseAndReturnVerifiedModule(hlo_text, config));
ASSERT_OK_AND_ASSIGN(auto ref_executable,
CreateExecutable(std::move(ref_module),
/*run_hlo_passes=*/true));
Expand All @@ -1314,8 +1322,9 @@ class CollectiveOpsTestE2EWindowedNonWindowed : public CollectiveOpsTestE2E {
debug_options.set_xla_gpu_multi_streamed_windowed_einsum(true);
debug_options.set_xla_gpu_experimental_enable_alltoall_windowed_einsum(
enable_a2a_rewrite);
TF_ASSERT_OK_AND_ASSIGN(auto module,
ParseAndReturnVerifiedModule(hlo_text, config));
TF_ASSERT_OK_AND_ASSIGN(
auto module,
ParseAndReturnVerifiedModule(hlo_text, config));

TF_ASSERT_OK_AND_ASSIGN(
ExecutionResult execution_result,
Expand Down Expand Up @@ -1357,6 +1366,9 @@ ENTRY main.12 {
}

TEST_F(CollectiveOpsTestE2EWindowedNonWindowed, WindowedEinsumE2EAllGatherF8) {
if (IsNanooFp8Only()) {
GTEST_SKIP() << "Test requires OCP FP8 (f8e4m3fn) support.";
}
absl::string_view kModuleReplicatedStr = R"(
HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[48,192]{1,0}, bf16[], bf16[])->bf16[2,16,192]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4

Expand Down Expand Up @@ -1394,6 +1406,9 @@ ENTRY main {

TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
WindowedEinsumE2EAllGatherReshapeF8) {
if (IsNanooFp8Only()) {
GTEST_SKIP() << "Test requires OCP FP8 (f8e4m3fn) support.";
}
absl::string_view kModuleReplicatedStr = R"(
HloModule windowed_einsum_e2e_all_gather_multi_consumer_f8, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[2,24,192]{2,1,0}, bf16[], bf16[])->bf16[2,16,192]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4

Expand Down Expand Up @@ -1432,6 +1447,9 @@ ENTRY main {

TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
WindowedEinsumE2EAllGatherMultiConsumerF8) {
if (IsNanooFp8Only()) {
GTEST_SKIP() << "Test requires OCP FP8 (f8e4m3fn) support.";
}
absl::string_view kModuleReplicatedStr = R"(
HloModule windowed_einsum_e2e_all_gather_multi_consumer_f8, entry_computation_layout={(f8e4m3fn[2,16,48]{2,1,0}, f8e4m3fn[48,192]{1,0}, f8e4m3fn[48,192]{1,0}, bf16[], bf16[], bf16[])->bf16[2,16,192]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4

Expand Down Expand Up @@ -1475,6 +1493,9 @@ ENTRY main {

TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
WindowedEinsumE2EReduceScatterF8) {
if (IsNanooFp8Only()) {
GTEST_SKIP() << "Test requires OCP FP8 (f8e4m3fn) support.";
}
absl::string_view kModuleReplicatedStr = R"(
HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fn[2,16,192]{2,1,0}, f8e4m3fn[192,48]{1,0}, bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4

Expand Down Expand Up @@ -1510,6 +1531,161 @@ ENTRY main {
CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts);
}

// FNUZ FP8 (f8e4m3fnuz/f8e5m2fnuz) variants of the windowed einsum tests
// for platforms that only support FNUZ/NANOO FP8, not OCP FP8.
// Ref: https://rocm.docs.amd.com/projects/HIP/en/latest/reference/low_fp_types.html#id4

TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
       WindowedEinsumE2EAllGatherFnuzF8) {
  // FNUZ/NANOO-only platforms cannot run the OCP f8e4m3fn variant of this
  // test, so exercise the all-gather windowed einsum with f8e4m3fnuz instead.
  if (!IsNanooFp8Only()) {
    GTEST_SKIP() << "Test requires NANOO FP8 (f8e4m3fnuz) support.";
  }
  constexpr absl::string_view kHloText = R"(
HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fnuz[2,16,48]{2,1,0}, f8e4m3fnuz[48,192]{1,0}, bf16[], bf16[])->bf16[2,16,192]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4

ENTRY main {
lhs = f8e4m3fnuz[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
rhs = f8e4m3fnuz[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
scale_lhs = bf16[] parameter(2)
scale_rhs = bf16[] parameter(3)
scale_lhs_bcast = bf16[2,16,48]{2,1,0} broadcast(scale_lhs), dimensions={}
scale_rhs_bcast = bf16[48,192]{1,0} broadcast(scale_rhs), dimensions={}
lhs_bf16 = bf16[2,16,48]{2,1,0} convert(lhs)
rhs_bf16 = bf16[48,192]{1,0} convert(rhs)
lhs_scaled = bf16[2,16,48]{2,1,0} multiply(scale_lhs_bcast, lhs_bf16)
rhs_scaled = bf16[48,192]{1,0} multiply(scale_rhs_bcast, rhs_bf16)
dot = bf16[2,16,192]{2,1,0} dot(lhs_scaled, rhs_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0}
ROOT custom-call = bf16[2,16,192]{2,1,0} custom-call(dot), custom_call_target="Sharding", sharding={devices=[1,1,4]<=[4]}
} // main
)";

  // Numerical check: windowed and non-windowed einsum must agree.
  CollectiveOpsCompareWindowedNonWindowed(kHloText,
                                          /*disable_dot_merger=*/true);

  // Pattern check: the GEMM rewriter must still produce an FP8 matmul when
  // windowed einsum is forced on (threshold 0, multi-streamed).
  DebugOptions debug_options = GetDebugOptionsForTest();
  debug_options.set_xla_gpu_threshold_for_windowed_einsum_mib(0);
  debug_options.set_xla_gpu_multi_streamed_windowed_einsum(true);
  debug_options.set_xla_gpu_graph_min_graph_size(200);
  // Triton GEMM and the dot-merger pass would both obscure the FP8
  // custom-call pattern being verified.
  debug_options.set_xla_gpu_enable_triton_gemm(false);
  debug_options.add_xla_disable_hlo_passes("dot-merger");
  CollectiveOpsVerifyF8Matmul(kHloText, debug_options);
}

TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
       WindowedEinsumE2EAllGatherReshapeFnuzF8) {
  // FNUZ/NANOO-only platforms cannot run the OCP f8e4m3fn variant of this
  // test; exercise the all-gather + reshape windowed einsum with f8e4m3fnuz.
  if (!IsNanooFp8Only()) {
    GTEST_SKIP() << "Test requires NANOO FP8 (f8e4m3fnuz) support.";
  }
  // Fix: the scale broadcasts previously had swapped operands
  // (scale_lhs_bcast broadcast scale_rhs and vice versa). Since both scales
  // are scalar bf16 parameters this did not invalidate the
  // windowed-vs-non-windowed comparison, but it was misleading and
  // inconsistent with the sibling tests; each side now uses its own scale.
  absl::string_view kModuleReplicatedStr = R"(
HloModule windowed_einsum_e2e_all_gather_reshape_fnuz_f8, entry_computation_layout={(f8e4m3fnuz[2,16,48]{2,1,0}, f8e4m3fnuz[2,24,192]{2,1,0}, bf16[], bf16[])->bf16[2,16,192]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4

ENTRY main {
lhs = f8e4m3fnuz[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
rhs = f8e4m3fnuz[2,24,192]{2,1,0} parameter(1), sharding={devices=[1,1,4]<=[4]}
scale_lhs = bf16[] parameter(2)
scale_rhs = bf16[] parameter(3)
scale_lhs_bcast = bf16[2,16,48]{2,1,0} broadcast(scale_lhs), dimensions={}
scale_rhs_bcast = bf16[2,24,192]{2,1,0} broadcast(scale_rhs), dimensions={}
lhs_bf16 = bf16[2,16,48]{2,1,0} convert(lhs)
rhs_bf16 = bf16[2,24,192]{2,1,0} convert(rhs)
lhs_scaled = bf16[2,16,48]{2,1,0} multiply(scale_lhs_bcast, lhs_bf16)
rhs_scaled = bf16[2,24,192]{2,1,0} multiply(scale_rhs_bcast, rhs_bf16)
rhs_reshaped = bf16[48,192]{1,0} reshape(rhs_scaled)
dot = bf16[2,16,192]{2,1,0} dot(lhs_scaled, rhs_reshaped), lhs_contracting_dims={2}, rhs_contracting_dims={0}
ROOT custom-call = bf16[2,16,192]{2,1,0} custom-call(dot), custom_call_target="Sharding", sharding={devices=[1,1,4]<=[4]}
} // main
)";

  // Numerical check: windowed and non-windowed einsum must agree.
  CollectiveOpsCompareWindowedNonWindowed(kModuleReplicatedStr,
                                          /*disable_dot_merger=*/true);

  // Pattern check: the GEMM rewriter must still produce an FP8 matmul when
  // windowed einsum is forced on (threshold 0, multi-streamed).
  DebugOptions opts = GetDebugOptionsForTest();
  opts.set_xla_gpu_threshold_for_windowed_einsum_mib(0);
  opts.set_xla_gpu_multi_streamed_windowed_einsum(true);
  opts.set_xla_gpu_graph_min_graph_size(200);
  // Triton GEMM and the dot-merger pass would both obscure the FP8
  // custom-call pattern being verified.
  opts.set_xla_gpu_enable_triton_gemm(false);
  opts.add_xla_disable_hlo_passes("dot-merger");
  CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts);
}

TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
       WindowedEinsumE2EAllGatherMultiConsumerFnuzF8) {
  // FNUZ/NANOO-only platforms cannot run the OCP f8e4m3fn variant; exercise
  // the case where one all-gathered operand feeds two dots, in f8e4m3fnuz.
  if (!IsNanooFp8Only()) {
    GTEST_SKIP() << "Test requires NANOO FP8 (f8e4m3fnuz) support.";
  }
  constexpr absl::string_view kHloText = R"(
HloModule windowed_einsum_e2e_all_gather_multi_consumer_fnuz_f8, entry_computation_layout={(f8e4m3fnuz[2,16,48]{2,1,0}, f8e4m3fnuz[48,192]{1,0}, f8e4m3fnuz[48,192]{1,0}, bf16[], bf16[], bf16[])->bf16[2,16,192]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4

ENTRY main {
lhs = f8e4m3fnuz[2,16,48]{2,1,0} parameter(0), sharding={devices=[1,4,1]<=[4]}
rhs0 = f8e4m3fnuz[48,192]{1,0} parameter(1), sharding={devices=[1,4]<=[4]}
scale_lhs = bf16[] parameter(3)
scale_rhs0 = bf16[] parameter(4)
scale_lhs_bcast = bf16[2,16,48]{2,1,0} broadcast(scale_lhs), dimensions={}
scale_rhs0_bcast = bf16[48,192]{1,0} broadcast(scale_rhs0), dimensions={}
lhs_bf16 = bf16[2,16,48]{2,1,0} convert(lhs)
rhs0_bf16 = bf16[48,192]{1,0} convert(rhs0)
lhs_scaled = bf16[2,16,48]{2,1,0} multiply(scale_lhs_bcast, lhs_bf16)
rhs0_scaled = bf16[48,192]{1,0} multiply(scale_rhs0_bcast, rhs0_bf16)
dot0 = bf16[2,16,192]{2,1,0} dot(lhs_scaled, rhs0_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0}
rhs1 = f8e4m3fnuz[48,192]{1,0} parameter(2), sharding={devices=[1,4]<=[4]}
scale_rhs1 = bf16[] parameter(5)
scale_rhs1_bcast = bf16[48,192]{1,0} broadcast(scale_rhs1), dimensions={}
rhs1_bf16 = bf16[48,192]{1,0} convert(rhs1)
rhs1_scaled = bf16[48,192]{1,0} multiply(scale_rhs1_bcast, rhs1_bf16)
dot1 = bf16[2,16,192]{2,1,0} dot(lhs_scaled, rhs1_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0}
ROOT add = bf16[2,16,192]{2,1,0} add(dot0, dot1)
} // main
)";

  // Numerical check: windowed and non-windowed einsum must agree.
  CollectiveOpsCompareWindowedNonWindowed(kHloText,
                                          /*disable_dot_merger=*/true);

  // Pattern check: the GEMM rewriter must still produce an FP8 matmul when
  // windowed einsum is forced on (threshold 0, multi-streamed).
  DebugOptions debug_options = GetDebugOptionsForTest();
  debug_options.set_xla_gpu_threshold_for_windowed_einsum_mib(0);
  debug_options.set_xla_gpu_multi_streamed_windowed_einsum(true);
  debug_options.set_xla_gpu_graph_min_graph_size(200);
  // Triton GEMM and the dot-merger pass would both obscure the FP8
  // custom-call pattern being verified.
  debug_options.set_xla_gpu_enable_triton_gemm(false);
  debug_options.add_xla_disable_hlo_passes("dot-merger");
  CollectiveOpsVerifyF8Matmul(kHloText, debug_options);
}

TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
       WindowedEinsumE2EReduceScatterFnuzF8) {
  // FNUZ/NANOO-only platforms cannot run the OCP f8e4m3fn variant; exercise
  // the reduce-scatter windowed einsum with f8e4m3fnuz instead.
  if (!IsNanooFp8Only()) {
    GTEST_SKIP() << "Test requires NANOO FP8 (f8e4m3fnuz) support.";
  }
  constexpr absl::string_view kHloText = R"(
HloModule pjit__unnamed_wrapped_function_, entry_computation_layout={(f8e4m3fnuz[2,16,192]{2,1,0}, f8e4m3fnuz[192,48]{1,0}, bf16[], bf16[])->bf16[2,16,48]{2,1,0}}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4

ENTRY main {
lhs = f8e4m3fnuz[2,16,192]{2,1,0} parameter(0), sharding={devices=[1,1,4]<=[4]}
rhs = f8e4m3fnuz[192,48]{1,0} parameter(1), sharding={devices=[4,1]<=[4]}
scale_lhs = bf16[] parameter(2)
scale_rhs = bf16[] parameter(3)
scale_lhs_bcast = bf16[2,16,192]{2,1,0} broadcast(scale_lhs), dimensions={}
scale_rhs_bcast = bf16[192,48]{1,0} broadcast(scale_rhs), dimensions={}
lhs_bf16 = bf16[2,16,192]{2,1,0} convert(lhs)
rhs_bf16 = bf16[192,48]{1,0} convert(rhs)
lhs_scaled = bf16[2,16,192]{2,1,0} multiply(scale_lhs_bcast, lhs_bf16)
rhs_scaled = bf16[192,48]{1,0} multiply(scale_rhs_bcast, rhs_bf16)
dot = bf16[2,16,48]{2,1,0} dot(lhs_scaled, rhs_scaled), lhs_contracting_dims={2}, rhs_contracting_dims={0}
ROOT custom-call = bf16[2,16,48]{2,1,0} custom-call(dot), custom_call_target="Sharding", sharding={devices=[1,4,1]<=[4]}
} // main
)";

  // Numerical check: windowed and non-windowed einsum must agree.
  CollectiveOpsCompareWindowedNonWindowed(kHloText,
                                          /*disable_dot_merger=*/true);

  // Pattern check: the GEMM rewriter must still produce an FP8 matmul when
  // windowed einsum is forced on (threshold 0, multi-streamed).
  DebugOptions debug_options = GetDebugOptionsForTest();
  debug_options.set_xla_gpu_threshold_for_windowed_einsum_mib(0);
  debug_options.set_xla_gpu_multi_streamed_windowed_einsum(true);
  debug_options.set_xla_gpu_graph_min_graph_size(200);
  // Triton GEMM and the dot-merger pass would both obscure the FP8
  // custom-call pattern being verified.
  debug_options.set_xla_gpu_enable_triton_gemm(false);
  debug_options.add_xla_disable_hlo_passes("dot-merger");
  CollectiveOpsVerifyF8Matmul(kHloText, debug_options);
}

TEST_F(CollectiveOpsTestE2EWindowedNonWindowed,
WindowedEinsumE2EAllToAllDecompose) {
absl::string_view kModuleReplicatedStr = R"(
Expand Down Expand Up @@ -1606,10 +1782,11 @@ ENTRY entry {
}

TEST_F(CollectiveOpsTestE2E, CollectivePipelinerF8) {
// Verify that FP8 patterns are preserved when collectives are pipelined so
// the GEMM rewriter can create FP8 matmuls.
if (!HasFp8Support()) {
GTEST_SKIP() << "Test requires Hopper or newer architecture.";
GTEST_SKIP() << "Test requires FP8 support.";
}
if (IsNanooFp8Only()) {
GTEST_SKIP() << "Test requires OCP FP8 (f8e4m3fn) support.";
}

absl::string_view kModuleReplicatedStr = R"(
Expand Down Expand Up @@ -1664,6 +1841,63 @@ ENTRY entry {
CollectiveOpsVerifyF8Matmul(kModuleReplicatedStr, opts);
}

TEST_F(CollectiveOpsTestE2E, CollectivePipelinerFnuzF8) {
  // FNUZ variant of CollectivePipelinerF8: verify that FNUZ FP8 patterns are
  // preserved when collectives are pipelined so the GEMM rewriter can still
  // create an FP8 matmul on NANOO-only platforms.
  if (!IsNanooFp8Only()) {
    GTEST_SKIP() << "Test requires NANOO FP8 (f8e4m3fnuz) support.";
  }

  constexpr absl::string_view kHloText = R"(
HloModule module, entry_computation_layout={(bf16[128,128], bf16[32,128], bf16[], bf16[])->bf16[512,128]}, allow_spmd_sharding_propagation_to_parameters={false,false,false,false}, num_partitions=4
while_cond {
input = (s32[], bf16[128,128], bf16[32,128], bf16[], bf16[], bf16[512,128]) parameter(0)
loop_counter = s32[] get-tuple-element(input), index=0
c4 = s32[] constant(4)
ROOT compare = pred[] compare(loop_counter, c4), direction=LT
}
while_body {
input = (s32[], bf16[128,128], bf16[32,128], bf16[], bf16[], bf16[512,128]) parameter(0)
loop_counter = s32[] get-tuple-element(input), index=0
lhs = bf16[128,128] get-tuple-element(input), index=1
rhs = bf16[32,128] get-tuple-element(input), index=2
partial_dot_output = bf16[512,128] get-tuple-element(input), index=5
lhs_f8 = f8e4m3fnuz[128,128] convert(lhs)
rhs_f8 = f8e4m3fnuz[32,128] convert(rhs)
lhs_bf16 = bf16[128,128] convert(lhs_f8)
rhs_bf16 = bf16[32,128] convert(rhs_f8)
scale_lhs = bf16[] get-tuple-element(input), index=3
scale_rhs = bf16[] get-tuple-element(input), index=4
scale_lhs_bcast = bf16[128,128] broadcast(scale_lhs), dimensions={}
scale_rhs_bcast = bf16[32,128] broadcast(scale_rhs), dimensions={}
lhs_scaled = bf16[128,128] multiply(lhs_bf16, scale_lhs_bcast)
rhs_scaled = bf16[32,128] multiply(rhs_bf16, scale_rhs_bcast)
rhs_scaled_all_gathered = bf16[128,128] all-gather(rhs_scaled), channel_id=1, use_global_device_ids=true, dimensions={0}, replica_groups={{0,1,2,3}}
dot = bf16[128,128] dot(lhs_scaled, rhs_scaled_all_gathered), lhs_contracting_dims={1}, rhs_contracting_dims={1}
c0 = s32[] constant(0)
size = s32[] constant(128)
iteration_offset = s32[] multiply(loop_counter, size)
updated_dot_output = bf16[512,128] dynamic-update-slice(partial_dot_output, dot, iteration_offset, c0)
c1 = s32[] constant(1)
loop_counter_plus_one = s32[] add(loop_counter, c1)
ROOT tuple = (s32[], bf16[128,128], bf16[32,128], bf16[], bf16[], bf16[512,128]) tuple(loop_counter_plus_one, lhs, rhs, scale_lhs, scale_rhs, updated_dot_output)
}
ENTRY entry {
c0 = s32[] constant(0)
lhs = bf16[128,128] parameter(0)
rhs = bf16[32,128] parameter(1)
scale_lhs = bf16[] parameter(2)
scale_rhs = bf16[] parameter(3)
result_buffer = bf16[512,128] constant(0.)
while_input = (s32[], bf16[128,128], bf16[32,128], bf16[], bf16[], bf16[512,128]) tuple(c0, lhs, rhs, scale_lhs, scale_rhs, result_buffer)
while = (s32[], bf16[128,128], bf16[32,128], bf16[], bf16[], bf16[512,128]) while(while_input), condition=while_cond, body=while_body
ROOT dot_output = bf16[512,128] get-tuple-element(while), index=5
}
)";

  // Triton GEMM would bypass the cuBLAS/hipBLASLt FP8 matmul pattern that
  // CollectiveOpsVerifyF8Matmul checks for.
  DebugOptions debug_options = GetDebugOptionsForTest();
  debug_options.set_xla_gpu_enable_triton_gemm(false);
  CollectiveOpsVerifyF8Matmul(kHloText, debug_options);
}

// E2E tests comparing the results with and without pipelining of collectives.
class CollectiveOpsTestE2EPipelinedNonPipelined : public CollectiveOpsTestE2E {
public:
Expand Down
Loading