From 34bb84ce92eb4ba955d33e19677114cfeda5dbf5 Mon Sep 17 00:00:00 2001 From: Aleksei Nurmukhametov Date: Mon, 30 Mar 2026 14:46:03 +0100 Subject: [PATCH] [ROCm] Add waves_per_eu support to Triton GEMM config Add the ROCm-specific `waves_per_eu` hint to the Triton GEMM config. This parameter specifies the minimum number of wavefronts per execution unit on AMD GPUs. The LLVM backend uses this to: 1) limit the number of SGPRs and VGPRs available per wave, which affects register allocation; 2) set register pressure thresholds for the instruction scheduler. The parameter flows through: TritonGemmKey proto -> TritonGemmConfig struct -> BlockLevelFusionConfig proto -> BlockLevelParameters -> xtile_compiler, where it is applied as the "amdgpu-waves-per-eu" LLVM function attribute on the kernel, matching Triton's own AMD backend behavior. Default value of 0 means no restriction. The autotuner search space is extended with values {0,1,2,4} for ROCm targets. --- xla/autotuning.proto | 1 + xla/backends/gpu/autotuner/triton/BUILD | 1 + .../gpu/autotuner/triton/dot_search_space.cc | 37 +++++++++ .../gpu/autotuner/triton/dot_search_space.h | 5 ++ .../autotuner/triton/dot_search_space_test.cc | 39 +++++++++ .../tests/fusion_emitter_device_test.cc | 81 +++++++++++++++++++ .../gpu/codegen/triton/xtile_compiler.cc | 10 +++ .../transforms/convert_triton_gemm_config.cc | 1 + .../convert_triton_gemm_config_test.cc | 38 +++++++++ xla/hlo/ir/backend_config_test.cc | 2 +- xla/service/gpu/BUILD | 1 + .../gpu/autotuning/autotune_cache_key.h | 2 +- xla/service/gpu/backend_configs.proto | 3 + xla/service/gpu/matmul_utils.cc | 8 +- xla/service/gpu/matmul_utils.h | 10 ++- xla/service/gpu/matmul_utils_test.cc | 36 +++++++++ .../gpu/model/block_level_parameters.h | 3 + .../gpu/model/block_level_parameters_test.cc | 4 + 18 files changed, 275 insertions(+), 7 deletions(-) diff --git a/xla/autotuning.proto b/xla/autotuning.proto index ceac0079baaaa..7a93181a40b9a 100644 --- a/xla/autotuning.proto +++ 
b/xla/autotuning.proto @@ -85,6 +85,7 @@ message AutotuneResult { int64 num_ctas = 7; bool is_tma_allowed = 8; bool is_warp_specialization_allowed = 9; + int64 waves_per_eu = 10; // LINT.ThenChange(//tensorflow/compiler/xla/service/gpu/matmul_utils.h) } diff --git a/xla/backends/gpu/autotuner/triton/BUILD b/xla/backends/gpu/autotuner/triton/BUILD index 7390891dc8e0b..41774b1a099bc 100644 --- a/xla/backends/gpu/autotuner/triton/BUILD +++ b/xla/backends/gpu/autotuner/triton/BUILD @@ -59,6 +59,7 @@ xla_cc_test( "//xla/stream_executor:device_description", "//xla/stream_executor:device_description_proto_cc", "//xla/stream_executor/cuda:cuda_compute_capability", + "//xla/stream_executor/rocm:rocm_compute_capability", "//xla/tsl/platform:statusor", "@com_google_absl//absl/log:check", "@com_google_absl//absl/status:statusor", diff --git a/xla/backends/gpu/autotuner/triton/dot_search_space.cc b/xla/backends/gpu/autotuner/triton/dot_search_space.cc index 43f3a9e91c392..fc9d663cc6b03 100644 --- a/xla/backends/gpu/autotuner/triton/dot_search_space.cc +++ b/xla/backends/gpu/autotuner/triton/dot_search_space.cc @@ -161,6 +161,9 @@ std::vector TritonDotFusionSearchSpace::GenerateConfigs( ExtendConfigs(configs, &TritonDotFusionSearchSpace::AddWarpSpecializationParameter); } + if (device_description_.gpu_compute_capability().IsRocm()) { + ExtendConfigs(configs, &TritonDotFusionSearchSpace::AddWavesPerEuParameter); + } std::vector result; result.reserve(configs.size()); @@ -738,4 +741,38 @@ void TritonDotFusionSearchSpace::EliminateLowOccupancyConfigs( VLOG(10) << "Eliminated " << num_configs - configs.size() << " configs."; } +void TritonDotFusionSearchSpace::AddWavesPerEuParameter( + const ConfigWithNotes& config, + std::vector& updated_configs) const { + // Hints the LLVM backend to reduce VGPR usage so that the given number of + // wavefronts can run concurrently on each Execution Unit (EU). 
On MI300X
+  // each EU has 512 VGPRs allocated in blocks of 8, giving:
+  //
+  //   Num VGPRs    Waves/EU    Waves/CU
+  //   <= 64        8           32
+  //   <= 96        5           20
+  //   <= 128       4           16
+  //   <= 168       3           12
+  //   <= 256       2           8
+  //   > 256        1           4
+  //
+  // 0 means no restriction (the default). Higher values force tighter register
+  // budgets, potentially increasing occupancy at the cost of more spilling.
+  //
+  // Actual occupancy also depends on LDS (shared memory) and num_warps:
+  //   occ = min(floor(occ_vgpr * 4 / num_warps),
+  //             floor(65536 / lds_bytes)) * num_warps / 4
+  // where occ_vgpr is from the table above. Neither lds_bytes nor
+  // the VGPR count are known at search space construction time, so we
+  // enumerate a few values and let the autotuner pick the best one.
+  static constexpr int kWavesPerEuValues[] = {0, 1, 2, 4};
+  for (int waves : kWavesPerEuValues) {
+    ConfigWithNotes new_config = config;
+    new_config.config.waves_per_eu = waves;
+    VLOG(10) << "Adding waves_per_eu parameter: config = "
+             << new_config.ToString();
+    updated_configs.push_back(new_config);
+  }
+}
+
 } // namespace xla::gpu diff --git a/xla/backends/gpu/autotuner/triton/dot_search_space.h b/xla/backends/gpu/autotuner/triton/dot_search_space.h index ed46ec5b15679..b80d54cf0c977 100644 --- a/xla/backends/gpu/autotuner/triton/dot_search_space.h +++ b/xla/backends/gpu/autotuner/triton/dot_search_space.h @@ -221,6 +221,11 @@ class TritonDotFusionSearchSpace { const ConfigWithNotes& config, std::vector<ConfigWithNotes>& updated_configs) const; + // Extend the passed configs with waves_per_eu values. + void AddWavesPerEuParameter( + const ConfigWithNotes& config, + std::vector<ConfigWithNotes>& updated_configs) const; + // The order of these fields is important: the values of those defined earlier // are used to compute the values of later ones.
se::DeviceDescription device_description_; diff --git a/xla/backends/gpu/autotuner/triton/dot_search_space_test.cc b/xla/backends/gpu/autotuner/triton/dot_search_space_test.cc index df67d9c7a31e8..66a20a4ad55d5 100644 --- a/xla/backends/gpu/autotuner/triton/dot_search_space_test.cc +++ b/xla/backends/gpu/autotuner/triton/dot_search_space_test.cc @@ -33,6 +33,7 @@ limitations under the License. #include "xla/stream_executor/cuda/cuda_compute_capability.h" #include "xla/stream_executor/device_description.h" #include "xla/stream_executor/device_description.pb.h" +#include "xla/stream_executor/rocm/rocm_compute_capability.h" #include "xla/tsl/platform/statusor.h" namespace xla::gpu { @@ -45,6 +46,7 @@ void PrintTo(const TritonGemmConfig& config, std::ostream* os) { namespace { using ::testing::AllOf; +using ::testing::AnyOf; using ::testing::Contains; using ::testing::ElementsAre; using ::testing::ElementsAreArray; @@ -92,6 +94,10 @@ template auto NumCtasIs(MatcherType matcher) { return Field("num_ctas", &TritonGemmConfig::num_ctas, matcher); } +template +auto WavesPerEuIs(MatcherType matcher) { + return Field("waves_per_eu", &TritonGemmConfig::waves_per_eu, matcher); +} auto IsValidConfig() { return AllOf(BlockMIs(Ge(1)), BlockNIs(Ge(1)), BlockKIs(Ge(1)), @@ -660,5 +666,38 @@ TEST_F(DotSearchSpaceTest, EnsuresSplitKAndBlockKAreCompatible) { } } +TEST_F(DotSearchSpaceTest, CudaDoesNotGenerateWavesPerEuConfigs) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetDefaultDotModule()); + TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); + + EXPECT_THAT(search_space.GenerateConfigs(), + AllOf(Not(IsEmpty()), Each(WavesPerEuIs(Eq(0))))); +} + +class RocmDotSearchSpaceTest : public DefaultDeviceDotSearchSpaceTest { + protected: + RocmDotSearchSpaceTest() { + // MI300X-like parameters. 
+ device_description_.set_registers_per_block_limit(64 * 1024); + device_description_.set_core_count(304); + device_description_.set_threads_per_block_limit(1024); + device_description_.set_threads_per_warp(64); + device_description_.set_shared_memory_per_block_optin(64 * 1024); + device_description_.set_gpu_compute_capability( + se::GpuComputeCapability(se::RocmComputeCapability("gfx942"))); + } +}; + +TEST_F(RocmDotSearchSpaceTest, GeneratesWavesPerEuConfigs) { + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr module, + GetDefaultDotModule()); + TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get()); + std::vector configs = search_space.GenerateConfigs(); + + EXPECT_THAT(configs, AllOf(Not(IsEmpty()), Contains(WavesPerEuIs(Ge(1))), + Each(WavesPerEuIs(AnyOf(0, 1, 2, 4))))); +} + } // namespace } // namespace xla::gpu diff --git a/xla/backends/gpu/codegen/triton/tests/fusion_emitter_device_test.cc b/xla/backends/gpu/codegen/triton/tests/fusion_emitter_device_test.cc index ddc7dec2f8caf..b81e36438603b 100644 --- a/xla/backends/gpu/codegen/triton/tests/fusion_emitter_device_test.cc +++ b/xla/backends/gpu/codegen/triton/tests/fusion_emitter_device_test.cc @@ -2074,6 +2074,87 @@ TEST_F(TritonEmitterTest, RocmWarpSizeIsSetCorrectly) { EXPECT_THAT(RunFileCheck(triton_passes_log, kPattern_n), true); } +TEST_F(TritonEmitterTest, RocmWavesPerEuAttributeIsSet) { + if (GpuComputeCapability().IsCuda()) { + GTEST_SKIP() << "waves_per_eu is ROCm-specific"; + } + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, + ParseAndReturnVerifiedModule(GetDotAlgorithmHlo( + F16, F16, PrecisionConfig::ALG_UNSET))); + + const HloFusionInstruction* triton_fusion = Cast( + verified_module->entry_computation()->root_instruction()); + + llvm::LLVMContext llvm_ctx; + mlir::MLIRContext mlir_context; + llvm::Triple target_triple(amdgpu::TargetTriple()); + std::string data_layout(amdgpu::DataLayout()); + + BlockLevelParameters block_level_parameters; + 
block_level_parameters.output_tile_sizes = {{16, 64}}; + block_level_parameters.num_warps = 1; + block_level_parameters.waves_per_eu = 4; + + se::DeviceDescription dev_info = TestGpuDeviceInfo::AMDMI210DeviceInfo(); + + TF_ASSERT_OK_AND_ASSIGN( + TritonWrapperResult result, + TritonWrapper( + "test_fn", triton_fusion, + se::GpuComputeCapability{se::RocmComputeCapability("gfx90a")}, + dev_info, block_level_parameters, target_triple, data_layout, + llvm_ctx, mlir_context)); + + ASSERT_NE(result.llvm_module, nullptr); + auto* fn = result.llvm_module->getFunction("test_fn"); + ASSERT_NE(fn, nullptr) + << "Kernel function 'test_fn' not found in LLVM module"; + auto attr = fn->getFnAttribute("amdgpu-waves-per-eu"); + ASSERT_TRUE(attr.isStringAttribute()); + EXPECT_EQ(attr.getValueAsString().str(), "4, 4"); +} + +TEST_F(TritonEmitterTest, RocmWavesPerEuZeroOmitsAttribute) { + if (GpuComputeCapability().IsCuda()) { + GTEST_SKIP() << "waves_per_eu is ROCm-specific"; + } + + TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr verified_module, + ParseAndReturnVerifiedModule(GetDotAlgorithmHlo( + F16, F16, PrecisionConfig::ALG_UNSET))); + + const HloFusionInstruction* triton_fusion = Cast( + verified_module->entry_computation()->root_instruction()); + + llvm::LLVMContext llvm_ctx; + mlir::MLIRContext mlir_context; + llvm::Triple target_triple(amdgpu::TargetTriple()); + std::string data_layout(amdgpu::DataLayout()); + + BlockLevelParameters block_level_parameters; + block_level_parameters.output_tile_sizes = {{16, 64}}; + block_level_parameters.num_warps = 1; + block_level_parameters.waves_per_eu = 0; + + se::DeviceDescription dev_info = TestGpuDeviceInfo::AMDMI210DeviceInfo(); + + TF_ASSERT_OK_AND_ASSIGN( + TritonWrapperResult result, + TritonWrapper( + "test_fn", triton_fusion, + se::GpuComputeCapability{se::RocmComputeCapability("gfx90a")}, + dev_info, block_level_parameters, target_triple, data_layout, + llvm_ctx, mlir_context)); + + ASSERT_NE(result.llvm_module, nullptr); + auto* fn = 
result.llvm_module->getFunction("test_fn"); + ASSERT_NE(fn, nullptr) + << "Kernel function 'test_fn' not found in LLVM module"; + EXPECT_FALSE(fn->hasFnAttribute("amdgpu-waves-per-eu")) + << "waves_per_eu=0 should not set amdgpu-waves-per-eu attribute"; +} + struct ScaleDotTestParams { std::string lhs_type; std::string rhs_type; diff --git a/xla/backends/gpu/codegen/triton/xtile_compiler.cc b/xla/backends/gpu/codegen/triton/xtile_compiler.cc index 6803eb0530ffc..56be35375bdaa 100644 --- a/xla/backends/gpu/codegen/triton/xtile_compiler.cc +++ b/xla/backends/gpu/codegen/triton/xtile_compiler.cc @@ -525,6 +525,16 @@ absl::StatusOr CompileTritonToLLVM( VerifyModule(*ll_triton_module); } + // Apply ROCm-specific waves_per_eu attribute if set. + if (gpu_cc.IsRocm() && block_level_parameters.waves_per_eu > 0) { + if (auto* fn = ll_triton_module->getFunction(kernel_name)) { + std::string waves_attr = + absl::StrCat(block_level_parameters.waves_per_eu, ", ", + block_level_parameters.waves_per_eu); + fn->addFnAttr("amdgpu-waves-per-eu", waves_attr); + } + } + // Integrate LLVM matmul kernel into XLA's LLVM module. 
captured_nvvm_annotations = xgt::ExtractNvvmAnnotations(ll_triton_module.get()); diff --git a/xla/backends/gpu/transforms/convert_triton_gemm_config.cc b/xla/backends/gpu/transforms/convert_triton_gemm_config.cc index f5e4a336fb8d4..7b04afef6f2c5 100644 --- a/xla/backends/gpu/transforms/convert_triton_gemm_config.cc +++ b/xla/backends/gpu/transforms/convert_triton_gemm_config.cc @@ -282,6 +282,7 @@ absl::StatusOr FindBlockLevelParameters( params.is_tma_allowed = config.is_tma_allowed; params.is_warp_specialization_allowed = config.is_warp_specialization_allowed; + params.waves_per_eu = config.waves_per_eu; return params; } VLOG(4) << "mapped_dot_tile_sizes: " diff --git a/xla/backends/gpu/transforms/convert_triton_gemm_config_test.cc b/xla/backends/gpu/transforms/convert_triton_gemm_config_test.cc index e7e8e1cfd778a..c4ce7809613da 100644 --- a/xla/backends/gpu/transforms/convert_triton_gemm_config_test.cc +++ b/xla/backends/gpu/transforms/convert_triton_gemm_config_test.cc @@ -153,5 +153,43 @@ ENTRY entry { )")); } +TEST_F(ConvertTritonGemmConfigTest, WavesPerEuPassthrough) { + absl::string_view hlo = R"( +dot { + lhs = f32[8192,512] parameter(0) + rhs = f32[512,512] parameter(1) + ROOT dot = f32[8192,512] dot(lhs, rhs), + lhs_contracting_dims={1}, rhs_contracting_dims={0} +} + +ENTRY entry { + p0 = f32[8192,512] parameter(0) + p1 = f32[512,512] parameter(1) + ROOT fusion = f32[8192,512] fusion(p0, p1), + kind=kCustom, calls=dot, backend_config={ + "fusion_backend_config": { + "kind":"__triton_gemm", "triton_gemm_config": { + "block_m":"64", "block_n":"256", "block_k":"32", + "split_k":"1", "num_stages":"5", "num_warps":"4", "num_ctas":"1", + "waves_per_eu":"4" + } + } + } +})"; + + std::unique_ptr module = RunConvertTritonGemmConfig(hlo); + EXPECT_TRUE(*RunFileCheck(module->ToString(), R"( + CHECK: ROOT{{.*}}fusion( + CHECK-SAME: "block_level_fusion_config" + CHECK-SAME: "waves_per_eu":4 +)")); + const HloInstruction* fusion = nullptr; + 
ASSERT_THAT(module->entry_computation()->root_instruction(), + GmockMatch(match::Fusion(&fusion))); + EXPECT_FALSE(fusion->backend_config() + ->fusion_backend_config() + .has_triton_gemm_config()); +} + } // namespace } // namespace xla::gpu diff --git a/xla/hlo/ir/backend_config_test.cc b/xla/hlo/ir/backend_config_test.cc index 6e47a558c2d56..459234003a363 100644 --- a/xla/hlo/ir/backend_config_test.cc +++ b/xla/hlo/ir/backend_config_test.cc @@ -37,7 +37,7 @@ const int kNumRepetitions = 100; // since the == operator does not canonicalize the raw strings before comparing // them. constexpr absl::string_view kRawString = - R"({"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"32","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1","is_tma_allowed":false,"is_warp_specialization_allowed":false}},"force_earliest_schedule":false,"reification_cost":[],"device_type":"DEVICE_TYPE_INVALID"})"; + R"({"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"32","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1","is_tma_allowed":false,"is_warp_specialization_allowed":false,"waves_per_eu":"0"}},"force_earliest_schedule":false,"reification_cost":[],"device_type":"DEVICE_TYPE_INVALID"})"; template void RunThreaded(Input input, CheckFn check_fn) { diff --git a/xla/service/gpu/BUILD b/xla/service/gpu/BUILD index d4b2fb6991732..1701f7af86d06 100644 --- a/xla/service/gpu/BUILD +++ b/xla/service/gpu/BUILD @@ -1202,6 +1202,7 @@ xla_cc_test( srcs = ["matmul_utils_test.cc"], deps = [ ":matmul_utils", + "//xla:autotuning_proto_cc", "//xla:shape_util", "//xla/hlo/ir:hlo", "//xla/hlo/parser:hlo_parser", diff --git a/xla/service/gpu/autotuning/autotune_cache_key.h b/xla/service/gpu/autotuning/autotune_cache_key.h index 
b9e6a2bd5ab8d..af56f20e03285 100644 --- a/xla/service/gpu/autotuning/autotune_cache_key.h +++ b/xla/service/gpu/autotuning/autotune_cache_key.h @@ -34,7 +34,7 @@ class AutotuneCacheKey { // changes that may affect the autotuning results. // To prevent accidental merges of concurrent increments, update the comment // to explain why the version is bumped. - static constexpr int kCurrentVersion = 32; // Triton integration 1.23. + static constexpr int kCurrentVersion = 33; // Add waves_per_eu support. AutotuneCacheKey(const se::DeviceDescription& device_description, const HloInstruction& instruction, diff --git a/xla/service/gpu/backend_configs.proto b/xla/service/gpu/backend_configs.proto index d90f351c339e1..5d5600c4820cc 100644 --- a/xla/service/gpu/backend_configs.proto +++ b/xla/service/gpu/backend_configs.proto @@ -220,6 +220,9 @@ message BlockLevelFusionConfig { // Allow/disallow automatic warp specialization. bool is_warp_specialization_allowed = 7; + + // Number of waves per execution unit (0 = no restriction). 
+ int32 waves_per_eu = 8; } message DynamicMemcpyConfig { diff --git a/xla/service/gpu/matmul_utils.cc b/xla/service/gpu/matmul_utils.cc index fb77d332b63d2..361e88b25ef3f 100644 --- a/xla/service/gpu/matmul_utils.cc +++ b/xla/service/gpu/matmul_utils.cc @@ -1080,11 +1080,13 @@ absl::StatusOr AsBlasLtEpilogue( TF_RET_CHECK(proto.num_stages() > 0); TF_RET_CHECK(proto.num_warps() > 0); TF_RET_CHECK(proto.num_ctas() > 0); + TF_RET_CHECK(proto.waves_per_eu() >= 0); return TritonGemmConfig( proto.block_m(), proto.block_n(), proto.block_k(), proto.split_k(), proto.num_stages(), proto.num_warps(), proto.num_ctas(), - proto.is_tma_allowed(), proto.is_warp_specialization_allowed()); + proto.is_tma_allowed(), proto.is_warp_specialization_allowed(), + proto.waves_per_eu()); } AutotuneResult::TritonGemmKey TritonGemmConfig::ToProto() const { @@ -1098,6 +1100,7 @@ AutotuneResult::TritonGemmKey TritonGemmConfig::ToProto() const { key.set_num_ctas(num_ctas); key.set_is_tma_allowed(is_tma_allowed); key.set_is_warp_specialization_allowed(is_warp_specialization_allowed); + key.set_waves_per_eu(waves_per_eu); return key; } @@ -1107,7 +1110,8 @@ std::string TritonGemmConfig::ToString() const { ",split_k:", split_k, ",num_stages:", num_stages, ",num_warps:", num_warps, ",num_ctas:", num_ctas, ",is_tma_allowed:", is_tma_allowed, - ",is_warp_specialization_allowed:", is_warp_specialization_allowed, "}"); + ",is_warp_specialization_allowed:", is_warp_specialization_allowed, + ",waves_per_eu:", waves_per_eu, "}"); } absl::StatusOr IsMatrixMultiplicationTooSmallForRewriting( diff --git a/xla/service/gpu/matmul_utils.h b/xla/service/gpu/matmul_utils.h index 416f90b5f9574..573918a17b08d 100644 --- a/xla/service/gpu/matmul_utils.h +++ b/xla/service/gpu/matmul_utils.h @@ -219,7 +219,8 @@ struct TritonGemmConfig { constexpr TritonGemmConfig(int block_m, int block_n, int block_k, int split_k, int num_stages, int num_warps, int num_ctas = 1, bool is_tma_allowed = false, - bool 
is_warp_specialization_allowed = false) + bool is_warp_specialization_allowed = false, + int waves_per_eu = 0) : block_m(block_m), block_n(block_n), block_k(block_k), @@ -228,7 +229,8 @@ struct TritonGemmConfig { num_warps(num_warps), num_ctas(num_ctas), is_tma_allowed(is_tma_allowed), - is_warp_specialization_allowed(is_warp_specialization_allowed) {} + is_warp_specialization_allowed(is_warp_specialization_allowed), + waves_per_eu(waves_per_eu) {} // LINT.IfChange int block_m = 0; int block_n = 0; @@ -242,6 +244,8 @@ struct TritonGemmConfig { bool is_tma_allowed = false; // Allow/disallow automatic warp specialization. bool is_warp_specialization_allowed = false; + // Number of waves per execution unit (0 = no restriction). + int waves_per_eu = 0; // LINT.ThenChange(//tensorflow/compiler/xla/autotuning.proto) // When adding new members, please update all methods, such as ToTuple, @@ -255,7 +259,7 @@ struct TritonGemmConfig { auto ToTuple() const { return std::make_tuple(block_m, block_n, block_k, split_k, num_stages, num_warps, num_ctas, is_tma_allowed, - is_warp_specialization_allowed); + is_warp_specialization_allowed, waves_per_eu); } public: diff --git a/xla/service/gpu/matmul_utils_test.cc b/xla/service/gpu/matmul_utils_test.cc index d8c0655a194da..195bb40baf68f 100644 --- a/xla/service/gpu/matmul_utils_test.cc +++ b/xla/service/gpu/matmul_utils_test.cc @@ -17,11 +17,13 @@ limitations under the License. 
#include #include +#include #include #include #include "absl/status/status_matchers.h" #include "absl/strings/string_view.h" +#include "xla/autotuning.pb.h" #include "xla/hlo/ir/hlo_module.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/testlib/hlo_hardware_independent_test_base.h" @@ -336,6 +338,40 @@ ENTRY DotFunc { absl_testing::IsOkAndHolds(false)); } +TEST(TritonGemmConfigTest, ProtoRoundTripPreservesWavesPerEu) { + TritonGemmConfig config(/*block_m=*/64, /*block_n=*/128, /*block_k=*/32, + /*split_k=*/1, /*num_stages=*/2, /*num_warps=*/4, + /*num_ctas=*/1, /*is_tma_allowed=*/false, + /*is_warp_specialization_allowed=*/false, + /*waves_per_eu=*/4); + + AutotuneResult::TritonGemmKey key = config.ToProto(); + EXPECT_EQ(key.waves_per_eu(), 4); + + TF_ASSERT_OK_AND_ASSIGN(TritonGemmConfig restored, + TritonGemmConfig::FromProto(key)); + EXPECT_EQ(restored.block_m, 64); + EXPECT_EQ(restored.block_n, 128); + EXPECT_EQ(restored.block_k, 32); + EXPECT_EQ(restored.split_k, 1); + EXPECT_EQ(restored.num_stages, 2); + EXPECT_EQ(restored.num_warps, 4); + EXPECT_EQ(restored.num_ctas, 1); + EXPECT_EQ(restored.is_tma_allowed, false); + EXPECT_EQ(restored.is_warp_specialization_allowed, false); + EXPECT_EQ(restored.waves_per_eu, 4); +} + +TEST(TritonGemmConfigTest, ToStringIncludesWavesPerEu) { + TritonGemmConfig config(/*block_m=*/64, /*block_n=*/128, /*block_k=*/32, + /*split_k=*/1, /*num_stages=*/2, /*num_warps=*/4, + /*num_ctas=*/1, /*is_tma_allowed=*/false, + /*is_warp_specialization_allowed=*/false, + /*waves_per_eu=*/4); + std::string str = config.ToString(); + EXPECT_NE(str.find("waves_per_eu:4"), std::string::npos); +} + } // namespace } // namespace gpu } // namespace xla diff --git a/xla/service/gpu/model/block_level_parameters.h b/xla/service/gpu/model/block_level_parameters.h index f764b1c8b1823..4f18db4088fb2 100644 --- a/xla/service/gpu/model/block_level_parameters.h +++ b/xla/service/gpu/model/block_level_parameters.h @@ -38,6 +38,7 @@ struct 
BlockLevelParameters { int64_t global_scratch_memory_size = 0; bool is_tma_allowed = false; bool is_warp_specialization_allowed = false; + int waves_per_eu = 0; // Returns a BlockLevelParameters struct from a BlockLevelFusionConfig proto. static BlockLevelParameters FromBlockLevelFusionConfig( @@ -49,6 +50,7 @@ struct BlockLevelParameters { result.is_tma_allowed = config.is_tma_allowed(); result.is_warp_specialization_allowed = config.is_warp_specialization_allowed(); + result.waves_per_eu = config.waves_per_eu(); result.output_tile_sizes.reserve(config.output_tiles_size()); for (const auto& tile : config.output_tiles()) { result.output_tile_sizes.push_back( @@ -70,6 +72,7 @@ struct BlockLevelParameters { config.set_num_stages(num_stages); config.set_is_tma_allowed(is_tma_allowed); config.set_is_warp_specialization_allowed(is_warp_specialization_allowed); + config.set_waves_per_eu(waves_per_eu); return config; } }; diff --git a/xla/service/gpu/model/block_level_parameters_test.cc b/xla/service/gpu/model/block_level_parameters_test.cc index 9452e82afed4c..807a49e7287b2 100644 --- a/xla/service/gpu/model/block_level_parameters_test.cc +++ b/xla/service/gpu/model/block_level_parameters_test.cc @@ -37,6 +37,7 @@ TEST(BlockLevelParametersTest, block_level_fusion_config.set_num_stages(14); block_level_fusion_config.set_is_tma_allowed(true); block_level_fusion_config.set_is_warp_specialization_allowed(true); + block_level_fusion_config.set_waves_per_eu(4); BlockLevelParameters block_level_parameters = BlockLevelParameters::FromBlockLevelFusionConfig( @@ -48,6 +49,7 @@ TEST(BlockLevelParametersTest, EXPECT_THAT(block_level_parameters.num_stages, 14); EXPECT_THAT(block_level_parameters.is_tma_allowed, true); EXPECT_THAT(block_level_parameters.is_warp_specialization_allowed, true); + EXPECT_THAT(block_level_parameters.waves_per_eu, 4); } TEST(BlockLevelParametersTest, @@ -59,6 +61,7 @@ TEST(BlockLevelParametersTest, block_level_parameters.num_stages = 14; 
block_level_parameters.is_tma_allowed = true; block_level_parameters.is_warp_specialization_allowed = true; + block_level_parameters.waves_per_eu = 4; BlockLevelFusionConfig block_level_fusion_config = block_level_parameters.ToBlockLevelFusionConfig(); @@ -71,6 +74,7 @@ TEST(BlockLevelParametersTest, EXPECT_THAT(block_level_fusion_config.num_stages(), 14); EXPECT_THAT(block_level_fusion_config.is_tma_allowed(), true); EXPECT_THAT(block_level_fusion_config.is_warp_specialization_allowed(), true); + EXPECT_THAT(block_level_fusion_config.waves_per_eu(), 4); } } // namespace