[ROCm] Add waves_per_eu support to Triton GEMM config#769
[ROCm] Add waves_per_eu support to Triton GEMM config#769nurmukhametov wants to merge 1 commit intomainfrom
Conversation
7ff0909 to
9c4694e
Compare
|
On upstream this may require https://github.com/openxla/xla/blob/main/xla/service/gpu/autotuning/autotune_cache_key.h#L37 increment too. |
| int64 num_ctas = 7; | ||
| bool is_tma_allowed = 8; | ||
| bool is_warp_specialization_allowed = 9; | ||
| int64 waves_per_eu = 10; |
There was a problem hiding this comment.
nit: waves_per_eu is declared as int64 here, but as int32 in backend_configs.proto (BlockLevelFusionConfig) and as int in the C++ structs (TritonGemmConfig, BlockLevelParameters). Since the max meaningful value is single digits ({0,1,2,4}), consider using int32 to match backend_configs.proto and the pattern used by num_ctas/num_stages there.
There was a problem hiding this comment.
It is fine, I exactly follow num_ctas/num_stages pattern.
There was a problem hiding this comment.
Resolved — you're right, num_ctas, num_stages, and num_warps all use int64 in this proto. Consistent with the existing pattern.
| proto.is_tma_allowed(), proto.is_warp_specialization_allowed(), | ||
| proto.waves_per_eu()); |
There was a problem hiding this comment.
All other numeric config fields validated above have TF_RET_CHECK(... >= 1) guards. waves_per_eu is missing a validation check — a negative value from a corrupt/malicious cache entry would be silently accepted and passed to the LLVM attribute.
Consider adding:
| proto.is_tma_allowed(), proto.is_warp_specialization_allowed(), | |
| proto.waves_per_eu()); | |
| proto.waves_per_eu()); | |
| TF_RET_CHECK(config.waves_per_eu >= 0); | |
| return config; |
There was a problem hiding this comment.
Resolved — TF_RET_CHECK(proto.waves_per_eu() >= 0) added in this revision.
| // All waves_per_eu values should be from the expected set {0,1,2,4}. | ||
| EXPECT_THAT(configs, Each(WavesPerEuIs(Le(4)))); |
There was a problem hiding this comment.
nit: The assertion Le(4) would also pass for value 3, which is not in the expected set. Consider using AnyOf(0, 1, 2, 4) to precisely match kWavesPerEuValues:
| // All waves_per_eu values should be from the expected set {0,1,2,4}. | |
| EXPECT_THAT(configs, Each(WavesPerEuIs(Le(4)))); | |
| // waves_per_eu should be from the expected set {0,1,2,4}. | |
| Each(WavesPerEuIs(AnyOf(0, 1, 2, 4))))); |
There was a problem hiding this comment.
Resolved — now using AnyOf(0, 1, 2, 4) for precise matching.
|
|
||
| ASSERT_NE(result.llvm_module, nullptr); | ||
| for (const auto& fn : result.llvm_module->functions()) { | ||
| if (fn.getName() == "test_fn") { | ||
| auto attr = fn.getFnAttribute("amdgpu-waves-per-eu"); | ||
| EXPECT_FALSE(attr.isStringAttribute()) | ||
| << "waves_per_eu=0 should not set amdgpu-waves-per-eu attribute"; | ||
| } | ||
| } |
There was a problem hiding this comment.
Unlike the non-zero test above which asserts EXPECT_TRUE(found_attr), this zero-value test doesn't verify the kernel function test_fn was actually found in the module. If test_fn is missing (e.g., due to a test infrastructure change), the test would pass vacuously. Consider adding an assertion that the function was located:
| ASSERT_NE(result.llvm_module, nullptr); | |
| for (const auto& fn : result.llvm_module->functions()) { | |
| if (fn.getName() == "test_fn") { | |
| auto attr = fn.getFnAttribute("amdgpu-waves-per-eu"); | |
| EXPECT_FALSE(attr.isStringAttribute()) | |
| << "waves_per_eu=0 should not set amdgpu-waves-per-eu attribute"; | |
| } | |
| } | |
| if (func.getName() == "test_fn") { | |
| EXPECT_FALSE(func.hasFnAttribute("amdgpu-waves-per-eu")) | |
| << "waves_per_eu=0 should not set amdgpu-waves-per-eu attribute"; | |
| found_attr = true; | |
| } | |
| } | |
| EXPECT_TRUE(found_attr) | |
| << "Kernel function 'test_fn' not found in LLVM module"; |
(reusing found_attr to mean "found the function" rather than "found the attribute")
There was a problem hiding this comment.
Resolved — the zero-value test now tracks found_fn and asserts the kernel function was located.
Review SummaryClean, well-structured PR that threads 4 inline comments posted:
No correctness issues found. The 4x search space expansion for ROCm is noted and appears intentional. 🤖 Generated with Claude Code |
Re-review SummaryAll 4 previous review comments have been addressed:
No new issues found. The |
I have incremented it here. I wonder if I need also to do something similar due to a backend_proto change? |
3dac08c to
b471c38
Compare
| void TritonDotFusionSearchSpace::AddWavesPerEuParameter( | ||
| const ConfigWithNotes& config, | ||
| std::vector<ConfigWithNotes>& updated_configs) const { | ||
| static constexpr int kWavesPerEuValues[] = {0, 1, 2, 4}; |
There was a problem hiding this comment.
could you add this one as the reference? https://rocm.docs.amd.com/en/latest/how-to/rocm-for-ai/inference-optimization/workload.html#auto-tunable-kernel-configurations
There was a problem hiding this comment.
There was a problem hiding this comment.
I have added the comment based on both links here. I think it is better to be self-consistent and I believe that code comment will outlive any URL.
| VerifyModule(*ll_triton_module); | ||
| } | ||
|
|
||
| // Apply ROCm-specific waves_per_eu attribute if set. |
There was a problem hiding this comment.
b8cfb81 to
22d19a8
Compare
Add the ROCm-specific `waves_per_eu` hint to the Triton GEMM config.
This parameter specifies the minimum number of wavefronts per execution
unit on AMD GPUs. The LLVM backend uses this to:
1) limit the number of SGPRs and VGPRs available per wave, which affects
register allocation;
2) set register pressure thresholds for the instruction scheduler.
The parameter flows through: TritonGemmKey proto -> TritonGemmConfig
struct -> BlockLevelFusionConfig proto -> BlockLevelParameters ->
xtile_compiler, where it is applied as the "amdgpu-waves-per-eu" LLVM
function attribute on the kernel, matching Triton's own AMD backend
behavior. Default value of 0 means no restriction.
The autotuner search space is extended with values {0,1,2,4} for ROCm
targets.
22d19a8 to
34bb84c
Compare
Add the ROCm-specific
waves_per_euoccupancy hint to the Triton GEMM autotuning and compilation pipeline. This parameter controls the number of wavefronts per execution unit on AMD GPUs, allowing the LLVM backend to optimize register allocation for target occupancy.The parameter flows through: TritonGemmKey proto -> TritonGemmConfig struct -> BlockLevelFusionConfig proto -> BlockLevelParameters -> xtile_compiler, where it is applied as the "amdgpu-waves-per-eu" LLVM function attribute on the kernel, matching Triton's own AMD backend behavior. Default value of 0 means no restriction.
The autotuner search space is extended with values {0,1,2,4} for ROCm targets.
Submission Checklist