1 change: 1 addition & 0 deletions xla/autotuning.proto
@@ -85,6 +85,7 @@ message AutotuneResult {
int64 num_ctas = 7;
bool is_tma_allowed = 8;
bool is_warp_specialization_allowed = 9;
int64 waves_per_eu = 10;
nit: waves_per_eu is declared as int64 here, but as int32 in backend_configs.proto (BlockLevelFusionConfig) and as int in the C++ structs (TritonGemmConfig, BlockLevelParameters). Since the max meaningful value is single digits ({0,1,2,4}), consider using int32 to match backend_configs.proto and the pattern used by num_ctas/num_stages there.

Member Author

It is fine; I follow the num_ctas/num_stages pattern exactly.


Resolved — you're right, num_ctas, num_stages, and num_warps all use int64 in this proto. Consistent with the existing pattern.

// LINT.ThenChange(//tensorflow/compiler/xla/service/gpu/matmul_utils.h)
}

1 change: 1 addition & 0 deletions xla/backends/gpu/autotuner/triton/BUILD
@@ -59,6 +59,7 @@ xla_cc_test(
"//xla/stream_executor:device_description",
"//xla/stream_executor:device_description_proto_cc",
"//xla/stream_executor/cuda:cuda_compute_capability",
"//xla/stream_executor/rocm:rocm_compute_capability",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/status:statusor",
37 changes: 37 additions & 0 deletions xla/backends/gpu/autotuner/triton/dot_search_space.cc
@@ -161,6 +161,9 @@ std::vector<TritonGemmConfig> TritonDotFusionSearchSpace::GenerateConfigs(
ExtendConfigs(configs,
&TritonDotFusionSearchSpace::AddWarpSpecializationParameter);
}
if (device_description_.gpu_compute_capability().IsRocm()) {
ExtendConfigs(configs, &TritonDotFusionSearchSpace::AddWavesPerEuParameter);
}

std::vector<TritonGemmConfig> result;
result.reserve(configs.size());
@@ -738,4 +741,38 @@ void TritonDotFusionSearchSpace::EliminateLowOccupancyConfigs(
VLOG(10) << "Eliminated " << num_configs - configs.size() << " configs.";
}

void TritonDotFusionSearchSpace::AddWavesPerEuParameter(
const ConfigWithNotes& config,
std::vector<ConfigWithNotes>& updated_configs) const {
// Hints the LLVM backend to reduce VGPR usage so that the given number of
// wavefronts can run concurrently on each Execution Unit (EU). On MI300X
// each EU has 512 VGPRs allocated in blocks of 16, giving:
//
// Num VGPRs Waves/EU Waves/CU
// <= 64 8 32
// <= 96 5 20
// <= 128 4 16
// <= 168 3 12
// <= 256 2 8
// > 256 1 4
//
// 0 means no restriction (the default). Higher values force tighter register
// budgets, potentially increasing occupancy at the cost of more spilling.
//
// Actual occupancy also depends on LDS (shared memory) and num_warps:
// occ = min(floor(occ_vgpr * 4 / num_warps),
// floor(65536 / lds_bytes)) * num_warps / 4
// where occ_vgpr is from the table above. Neither lds_bytes nor
// the VGPR count are known at search space construction time, so we
// enumerate a few values and let the autotuner pick the best one.
static constexpr int kWavesPerEuValues[] = {0, 1, 2, 4};
Member Author

I have added the comment based on both links. I think it is better to be self-contained, and I believe a code comment will outlive any URL.

for (int waves : kWavesPerEuValues) {
ConfigWithNotes new_config = config;
new_config.config.waves_per_eu = waves;
VLOG(10) << "Adding waves_per_eu parameter: config = "
<< new_config.ToString();
updated_configs.push_back(new_config);
}
}

} // namespace xla::gpu
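For reference, the occupancy model described in the new code comment can be sketched as a pair of standalone helpers. This is only an illustration of the comment's table and formula, not XLA code; the function names are hypothetical, and the MI300X-like constants (512 VGPRs per EU, 64 KiB of LDS) come from the comment itself.

```cpp
#include <algorithm>

// Waves per EU achievable for a given per-thread VGPR count on an
// MI300X-like EU (512 VGPRs, allocated in blocks of 16), mirroring the
// table in the code comment above.
int OccVgpr(int vgprs) {
  if (vgprs <= 64) return 8;
  if (vgprs <= 96) return 5;
  if (vgprs <= 128) return 4;
  if (vgprs <= 168) return 3;
  if (vgprs <= 256) return 2;
  return 1;
}

// occ = min(floor(occ_vgpr * 4 / num_warps), floor(65536 / lds_bytes))
//       * num_warps / 4, as given in the comment.
int Occupancy(int vgprs, int lds_bytes, int num_warps) {
  int occ_lds = 65536 / lds_bytes;
  return std::min(OccVgpr(vgprs) * 4 / num_warps, occ_lds) * num_warps / 4;
}
```

For example, a kernel using 128 VGPRs, 16 KiB of LDS, and 4 warps lands at 4 waves per EU; raising LDS to 64 KiB drops it to 1, which is why neither bound alone determines occupancy and the autotuner has to search.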
5 changes: 5 additions & 0 deletions xla/backends/gpu/autotuner/triton/dot_search_space.h
@@ -221,6 +221,11 @@ class TritonDotFusionSearchSpace {
const ConfigWithNotes& config,
std::vector<ConfigWithNotes>& updated_configs) const;

// Extend the passed configs with waves_per_eu values.
void AddWavesPerEuParameter(
const ConfigWithNotes& config,
std::vector<ConfigWithNotes>& updated_configs) const;

// The order of these fields is important: the values of those defined earlier
// are used to compute the values of later ones.
se::DeviceDescription device_description_;
39 changes: 39 additions & 0 deletions xla/backends/gpu/autotuner/triton/dot_search_space_test.cc
@@ -33,6 +33,7 @@ limitations under the License.
#include "xla/stream_executor/cuda/cuda_compute_capability.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/device_description.pb.h"
#include "xla/stream_executor/rocm/rocm_compute_capability.h"
#include "xla/tsl/platform/statusor.h"

namespace xla::gpu {
@@ -45,6 +46,7 @@ void PrintTo(const TritonGemmConfig& config, std::ostream* os) {
namespace {

using ::testing::AllOf;
using ::testing::AnyOf;
using ::testing::Contains;
using ::testing::ElementsAre;
using ::testing::ElementsAreArray;
@@ -92,6 +94,10 @@ template <typename MatcherType>
auto NumCtasIs(MatcherType matcher) {
return Field("num_ctas", &TritonGemmConfig::num_ctas, matcher);
}
template <typename MatcherType>
auto WavesPerEuIs(MatcherType matcher) {
return Field("waves_per_eu", &TritonGemmConfig::waves_per_eu, matcher);
}

auto IsValidConfig() {
return AllOf(BlockMIs(Ge(1)), BlockNIs(Ge(1)), BlockKIs(Ge(1)),
@@ -660,5 +666,38 @@ TEST_F(DotSearchSpaceTest, EnsuresSplitKAndBlockKAreCompatible) {
}
}

TEST_F(DotSearchSpaceTest, CudaDoesNotGenerateWavesPerEuConfigs) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
GetDefaultDotModule());
TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get());

EXPECT_THAT(search_space.GenerateConfigs(),
AllOf(Not(IsEmpty()), Each(WavesPerEuIs(Eq(0)))));
}

class RocmDotSearchSpaceTest : public DefaultDeviceDotSearchSpaceTest {
protected:
RocmDotSearchSpaceTest() {
// MI300X-like parameters.
device_description_.set_registers_per_block_limit(64 * 1024);
device_description_.set_core_count(304);
device_description_.set_threads_per_block_limit(1024);
device_description_.set_threads_per_warp(64);
device_description_.set_shared_memory_per_block_optin(64 * 1024);
device_description_.set_gpu_compute_capability(
se::GpuComputeCapability(se::RocmComputeCapability("gfx942")));
}
};

TEST_F(RocmDotSearchSpaceTest, GeneratesWavesPerEuConfigs) {
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<VerifiedHloModule> module,
GetDefaultDotModule());
TritonDotFusionSearchSpace search_space = MakeSearchSpace(module.get());
std::vector<TritonGemmConfig> configs = search_space.GenerateConfigs();

EXPECT_THAT(configs, AllOf(Not(IsEmpty()), Contains(WavesPerEuIs(Ge(1))),
Each(WavesPerEuIs(AnyOf(0, 1, 2, 4)))));
}

} // namespace
} // namespace xla::gpu
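The fan-out that the ROCm test above expects can be reasoned about directly: AddWavesPerEuParameter emits one copy of each incoming config per value in {0, 1, 2, 4}, so the config count grows by exactly 4x. A toy sketch of that expansion, using a simplified struct and hypothetical names rather than the real ConfigWithNotes:

```cpp
#include <vector>

struct Config {
  int waves_per_eu = 0;
};

// Mirrors the shape of AddWavesPerEuParameter: each incoming config is
// duplicated once per candidate waves_per_eu value.
std::vector<Config> AddWavesPerEu(const std::vector<Config>& configs) {
  static constexpr int kValues[] = {0, 1, 2, 4};
  std::vector<Config> out;
  out.reserve(configs.size() * 4);
  for (const Config& config : configs) {
    for (int waves : kValues) {
      Config copy = config;
      copy.waves_per_eu = waves;
      out.push_back(copy);
    }
  }
  return out;
}
```

This is also why the test can assert both `Contains(WavesPerEuIs(Ge(1)))` and `Each(WavesPerEuIs(AnyOf(0, 1, 2, 4)))`: every base config contributes one unrestricted variant and three restricted ones.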
@@ -2074,6 +2074,87 @@ TEST_F(TritonEmitterTest, RocmWarpSizeIsSetCorrectly) {
EXPECT_THAT(RunFileCheck(triton_passes_log, kPattern_n), true);
}

TEST_F(TritonEmitterTest, RocmWavesPerEuAttributeIsSet) {
if (GpuComputeCapability().IsCuda()) {
GTEST_SKIP() << "waves_per_eu is ROCm-specific";
}

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> verified_module,
ParseAndReturnVerifiedModule(GetDotAlgorithmHlo(
F16, F16, PrecisionConfig::ALG_UNSET)));

const HloFusionInstruction* triton_fusion = Cast<HloFusionInstruction>(
verified_module->entry_computation()->root_instruction());

llvm::LLVMContext llvm_ctx;
mlir::MLIRContext mlir_context;
llvm::Triple target_triple(amdgpu::TargetTriple());
std::string data_layout(amdgpu::DataLayout());

BlockLevelParameters block_level_parameters;
block_level_parameters.output_tile_sizes = {{16, 64}};
block_level_parameters.num_warps = 1;
block_level_parameters.waves_per_eu = 4;

se::DeviceDescription dev_info = TestGpuDeviceInfo::AMDMI210DeviceInfo();

TF_ASSERT_OK_AND_ASSIGN(
TritonWrapperResult result,
TritonWrapper(
"test_fn", triton_fusion,
se::GpuComputeCapability{se::RocmComputeCapability("gfx90a")},
dev_info, block_level_parameters, target_triple, data_layout,
llvm_ctx, mlir_context));

ASSERT_NE(result.llvm_module, nullptr);
auto* fn = result.llvm_module->getFunction("test_fn");
ASSERT_NE(fn, nullptr)
<< "Kernel function 'test_fn' not found in LLVM module";
auto attr = fn->getFnAttribute("amdgpu-waves-per-eu");
ASSERT_TRUE(attr.isStringAttribute());
EXPECT_EQ(attr.getValueAsString().str(), "4, 4");
}

TEST_F(TritonEmitterTest, RocmWavesPerEuZeroOmitsAttribute) {
if (GpuComputeCapability().IsCuda()) {
GTEST_SKIP() << "waves_per_eu is ROCm-specific";
}

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> verified_module,
ParseAndReturnVerifiedModule(GetDotAlgorithmHlo(
F16, F16, PrecisionConfig::ALG_UNSET)));

const HloFusionInstruction* triton_fusion = Cast<HloFusionInstruction>(
verified_module->entry_computation()->root_instruction());

llvm::LLVMContext llvm_ctx;
mlir::MLIRContext mlir_context;
llvm::Triple target_triple(amdgpu::TargetTriple());
std::string data_layout(amdgpu::DataLayout());

BlockLevelParameters block_level_parameters;
block_level_parameters.output_tile_sizes = {{16, 64}};
block_level_parameters.num_warps = 1;
block_level_parameters.waves_per_eu = 0;

se::DeviceDescription dev_info = TestGpuDeviceInfo::AMDMI210DeviceInfo();

TF_ASSERT_OK_AND_ASSIGN(
TritonWrapperResult result,
TritonWrapper(
"test_fn", triton_fusion,
se::GpuComputeCapability{se::RocmComputeCapability("gfx90a")},
dev_info, block_level_parameters, target_triple, data_layout,
llvm_ctx, mlir_context));

ASSERT_NE(result.llvm_module, nullptr);
auto* fn = result.llvm_module->getFunction("test_fn");
ASSERT_NE(fn, nullptr)
<< "Kernel function 'test_fn' not found in LLVM module";
EXPECT_FALSE(fn->hasFnAttribute("amdgpu-waves-per-eu"))
<< "waves_per_eu=0 should not set amdgpu-waves-per-eu attribute";
}

struct ScaleDotTestParams {
std::string lhs_type;
std::string rhs_type;
10 changes: 10 additions & 0 deletions xla/backends/gpu/codegen/triton/xtile_compiler.cc
@@ -525,6 +525,16 @@ absl::StatusOr<TritonWrapperResult> CompileTritonToLLVM(
VerifyModule(*ll_triton_module);
}

// Apply ROCm-specific waves_per_eu attribute if set.

if (gpu_cc.IsRocm() && block_level_parameters.waves_per_eu > 0) {
if (auto* fn = ll_triton_module->getFunction(kernel_name)) {
std::string waves_attr =
absl::StrCat(block_level_parameters.waves_per_eu, ", ",
block_level_parameters.waves_per_eu);
fn->addFnAttr("amdgpu-waves-per-eu", waves_attr);
}
}

// Integrate LLVM matmul kernel into XLA's LLVM module.
captured_nvvm_annotations =
xgt::ExtractNvvmAnnotations(ll_triton_module.get());
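The attribute value built in the hunk above is a "min, max" pair with both bounds pinned to the same waves_per_eu value, which asks the AMDGPU backend to target exactly that occupancy. A minimal standalone sketch of the same formatting, with std::to_string swapped in for absl::StrCat:

```cpp
#include <string>

// Builds the value for the "amdgpu-waves-per-eu" function attribute.
// Pinning min and max to the same value requests exactly that many
// wavefronts per execution unit.
std::string WavesPerEuAttrValue(int waves_per_eu) {
  return std::to_string(waves_per_eu) + ", " + std::to_string(waves_per_eu);
}
```

The emitter tests earlier in this PR check both directions: waves_per_eu=4 yields the string "4, 4" on the kernel function, and waves_per_eu=0 omits the attribute entirely, leaving the backend unconstrained.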
1 change: 1 addition & 0 deletions xla/backends/gpu/transforms/convert_triton_gemm_config.cc
@@ -282,6 +282,7 @@ absl::StatusOr<BlockLevelParameters> FindBlockLevelParameters(
params.is_tma_allowed = config.is_tma_allowed;
params.is_warp_specialization_allowed =
config.is_warp_specialization_allowed;
params.waves_per_eu = config.waves_per_eu;
return params;
}
VLOG(4) << "mapped_dot_tile_sizes: "
38 changes: 38 additions & 0 deletions xla/backends/gpu/transforms/convert_triton_gemm_config_test.cc
@@ -153,5 +153,43 @@ ENTRY entry {
)"));
}

TEST_F(ConvertTritonGemmConfigTest, WavesPerEuPassthrough) {
absl::string_view hlo = R"(
dot {
lhs = f32[8192,512] parameter(0)
rhs = f32[512,512] parameter(1)
ROOT dot = f32[8192,512] dot(lhs, rhs),
lhs_contracting_dims={1}, rhs_contracting_dims={0}
}

ENTRY entry {
p0 = f32[8192,512] parameter(0)
p1 = f32[512,512] parameter(1)
ROOT fusion = f32[8192,512] fusion(p0, p1),
kind=kCustom, calls=dot, backend_config={
"fusion_backend_config": {
"kind":"__triton_gemm", "triton_gemm_config": {
"block_m":"64", "block_n":"256", "block_k":"32",
"split_k":"1", "num_stages":"5", "num_warps":"4", "num_ctas":"1",
"waves_per_eu":"4"
}
}
}
})";

std::unique_ptr<VerifiedHloModule> module = RunConvertTritonGemmConfig(hlo);
EXPECT_TRUE(*RunFileCheck(module->ToString(), R"(
CHECK: ROOT{{.*}}fusion(
CHECK-SAME: "block_level_fusion_config"
CHECK-SAME: "waves_per_eu":4
)"));
const HloInstruction* fusion = nullptr;
ASSERT_THAT(module->entry_computation()->root_instruction(),
GmockMatch(match::Fusion(&fusion)));
EXPECT_FALSE(fusion->backend_config<GpuBackendConfig>()
->fusion_backend_config()
.has_triton_gemm_config());
}

} // namespace
} // namespace xla::gpu
2 changes: 1 addition & 1 deletion xla/hlo/ir/backend_config_test.cc
@@ -37,7 +37,7 @@ const int kNumRepetitions = 100;
// since the == operator does not canonicalize the raw strings before comparing
// them.
constexpr absl::string_view kRawString =
R"({"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"32","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1","is_tma_allowed":false,"is_warp_specialization_allowed":false}},"force_earliest_schedule":false,"reification_cost":[],"device_type":"DEVICE_TYPE_INVALID"})";
R"({"operation_queue_id":"0","wait_on_operation_queues":[],"fusion_backend_config":{"kind":"__triton_gemm","triton_gemm_config":{"block_m":"256","block_n":"256","block_k":"32","split_k":"1","num_stages":"1","num_warps":"16","num_ctas":"1","is_tma_allowed":false,"is_warp_specialization_allowed":false,"waves_per_eu":"0"}},"force_earliest_schedule":false,"reification_cost":[],"device_type":"DEVICE_TYPE_INVALID"})";

template <typename Input, typename CheckFn>
void RunThreaded(Input input, CheckFn check_fn) {
1 change: 1 addition & 0 deletions xla/service/gpu/BUILD
@@ -1202,6 +1202,7 @@ xla_cc_test(
srcs = ["matmul_utils_test.cc"],
deps = [
":matmul_utils",
"//xla:autotuning_proto_cc",
"//xla:shape_util",
"//xla/hlo/ir:hlo",
"//xla/hlo/parser:hlo_parser",
2 changes: 1 addition & 1 deletion xla/service/gpu/autotuning/autotune_cache_key.h
@@ -34,7 +34,7 @@ class AutotuneCacheKey {
// changes that may affect the autotuning results.
// To prevent accidental merges of concurrent increments, update the comment
// to explain why the version is bumped.
static constexpr int kCurrentVersion = 32; // Triton integration 1.23.
static constexpr int kCurrentVersion = 33; // Add waves_per_eu support.

AutotuneCacheKey(const se::DeviceDescription& device_description,
const HloInstruction& instruction,
3 changes: 3 additions & 0 deletions xla/service/gpu/backend_configs.proto
@@ -220,6 +220,9 @@ message BlockLevelFusionConfig {

// Allow/disallow automatic warp specialization.
bool is_warp_specialization_allowed = 7;

// Number of waves per execution unit (0 = no restriction).
int32 waves_per_eu = 8;
}

message DynamicMemcpyConfig {
8 changes: 6 additions & 2 deletions xla/service/gpu/matmul_utils.cc
@@ -1080,11 +1080,13 @@ absl::StatusOr<se::gpu::BlasLt::Epilogue> AsBlasLtEpilogue(
TF_RET_CHECK(proto.num_stages() > 0);
TF_RET_CHECK(proto.num_warps() > 0);
TF_RET_CHECK(proto.num_ctas() > 0);
TF_RET_CHECK(proto.waves_per_eu() >= 0);

return TritonGemmConfig(
proto.block_m(), proto.block_n(), proto.block_k(), proto.split_k(),
proto.num_stages(), proto.num_warps(), proto.num_ctas(),
proto.is_tma_allowed(), proto.is_warp_specialization_allowed());
proto.is_tma_allowed(), proto.is_warp_specialization_allowed(),
proto.waves_per_eu());
Comment on lines +1088 to +1089

All other numeric config fields validated above have TF_RET_CHECK(... >= 1) guards. waves_per_eu is missing a validation check — a negative value from a corrupt/malicious cache entry would be silently accepted and passed to the LLVM attribute.

Consider adding:

Suggested change:
- proto.is_tma_allowed(), proto.is_warp_specialization_allowed(),
- proto.waves_per_eu());
+ proto.waves_per_eu());
+ TF_RET_CHECK(config.waves_per_eu >= 0);
+ return config;

Member Author

Done


Resolved — TF_RET_CHECK(proto.waves_per_eu() >= 0) added in this revision.

}

AutotuneResult::TritonGemmKey TritonGemmConfig::ToProto() const {
@@ -1098,6 +1100,7 @@ AutotuneResult::TritonGemmKey TritonGemmConfig::ToProto() const {
key.set_num_ctas(num_ctas);
key.set_is_tma_allowed(is_tma_allowed);
key.set_is_warp_specialization_allowed(is_warp_specialization_allowed);
key.set_waves_per_eu(waves_per_eu);
return key;
}

@@ -1107,7 +1110,8 @@ std::string TritonGemmConfig::ToString() const {
",split_k:", split_k, ",num_stages:", num_stages,
",num_warps:", num_warps, ",num_ctas:", num_ctas,
",is_tma_allowed:", is_tma_allowed,
",is_warp_specialization_allowed:", is_warp_specialization_allowed, "}");
",is_warp_specialization_allowed:", is_warp_specialization_allowed,
",waves_per_eu:", waves_per_eu, "}");
}

absl::StatusOr<bool> IsMatrixMultiplicationTooSmallForRewriting(