Skip to content

[ROCm] Add waves_per_eu support to Triton GEMM config#769

Open
nurmukhametov wants to merge 1 commit intomainfrom
anurmukh/add-waves-per-eu-triton-gemm-config
Open

[ROCm] Add waves_per_eu support to Triton GEMM config#769
nurmukhametov wants to merge 1 commit intomainfrom
anurmukh/add-waves-per-eu-triton-gemm-config

Conversation

@nurmukhametov
Copy link
Copy Markdown
Member

Add the ROCm-specific waves_per_eu occupancy hint to the Triton GEMM autotuning and compilation pipeline. This parameter controls the number of wavefronts per execution unit on AMD GPUs, allowing the LLVM backend to optimize register allocation for target occupancy.

The parameter flows through: TritonGemmKey proto -> TritonGemmConfig struct -> BlockLevelFusionConfig proto -> BlockLevelParameters -> xtile_compiler, where it is applied as the "amdgpu-waves-per-eu" LLVM function attribute on the kernel, matching Triton's own AMD backend behavior. Default value of 0 means no restriction.

The autotuner search space is extended with values {0,1,2,4} for ROCm targets.

Submission Checklist

@nurmukhametov nurmukhametov force-pushed the anurmukh/add-waves-per-eu-triton-gemm-config branch from 7ff0909 to 9c4694e Compare April 1, 2026 11:33
@nurmukhametov nurmukhametov added the claude-review Request a Claude AI code review for this PR label Apr 1, 2026
@draganmladjenovic
Copy link
Copy Markdown

int64 num_ctas = 7;
bool is_tma_allowed = 8;
bool is_warp_specialization_allowed = 9;
int64 waves_per_eu = 10;
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: waves_per_eu is declared as int64 here, but as int32 in backend_configs.proto (BlockLevelFusionConfig) and as int in the C++ structs (TritonGemmConfig, BlockLevelParameters). Since the max meaningful value is single digits ({0,1,2,4}), consider using int32 to match backend_configs.proto and the pattern used by num_ctas/num_stages there.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is fine, I exactly follow num_ctas/num_stages pattern.

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved — you're right, num_ctas, num_stages, and num_warps all use int64 in this proto. Consistent with the existing pattern.

Comment on lines +1087 to +1088
proto.is_tma_allowed(), proto.is_warp_specialization_allowed(),
proto.waves_per_eu());
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

All other numeric config fields validated above have TF_RET_CHECK(... >= 1) guards. waves_per_eu is missing a validation check — a negative value from a corrupt/malicious cache entry would be silently accepted and passed to the LLVM attribute.

Consider adding:

Suggested change
proto.is_tma_allowed(), proto.is_warp_specialization_allowed(),
proto.waves_per_eu());
proto.waves_per_eu());
TF_RET_CHECK(config.waves_per_eu >= 0);
return config;

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved — TF_RET_CHECK(proto.waves_per_eu() >= 0) added in this revision.

Comment on lines +700 to +701
// All waves_per_eu values should be from the expected set {0,1,2,4}.
EXPECT_THAT(configs, Each(WavesPerEuIs(Le(4))));
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: The assertion Le(4) would also pass for value 3, which is not in the expected set. Consider using AnyOf(0, 1, 2, 4) to precisely match kWavesPerEuValues:

Suggested change
// All waves_per_eu values should be from the expected set {0,1,2,4}.
EXPECT_THAT(configs, Each(WavesPerEuIs(Le(4))));
// waves_per_eu should be from the expected set {0,1,2,4}.
Each(WavesPerEuIs(AnyOf(0, 1, 2, 4)))));

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved — now using AnyOf(0, 1, 2, 4) for precise matching.

Comment on lines +2153 to +2161

ASSERT_NE(result.llvm_module, nullptr);
for (const auto& fn : result.llvm_module->functions()) {
if (fn.getName() == "test_fn") {
auto attr = fn.getFnAttribute("amdgpu-waves-per-eu");
EXPECT_FALSE(attr.isStringAttribute())
<< "waves_per_eu=0 should not set amdgpu-waves-per-eu attribute";
}
}
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unlike the non-zero test above which asserts EXPECT_TRUE(found_attr), this zero-value test doesn't verify the kernel function test_fn was actually found in the module. If test_fn is missing (e.g., due to a test infrastructure change), the test would pass vacuously. Consider adding an assertion that the function was located:

Suggested change
ASSERT_NE(result.llvm_module, nullptr);
for (const auto& fn : result.llvm_module->functions()) {
if (fn.getName() == "test_fn") {
auto attr = fn.getFnAttribute("amdgpu-waves-per-eu");
EXPECT_FALSE(attr.isStringAttribute())
<< "waves_per_eu=0 should not set amdgpu-waves-per-eu attribute";
}
}
if (func.getName() == "test_fn") {
EXPECT_FALSE(func.hasFnAttribute("amdgpu-waves-per-eu"))
<< "waves_per_eu=0 should not set amdgpu-waves-per-eu attribute";
found_attr = true;
}
}
EXPECT_TRUE(found_attr)
<< "Kernel function 'test_fn' not found in LLVM module";

(reusing found_attr to mean "found the function" rather than "found the attribute")

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Resolved — the zero-value test now tracks found_fn and asserts the kernel function was located.

@claude
Copy link
Copy Markdown

claude bot commented Apr 1, 2026

Review Summary

Clean, well-structured PR that threads waves_per_eu through the full Triton GEMM stack (proto → config → autotuner → LLVM IR). The propagation is complete across all required touchpoints and tests cover both the positive and zero-value cases.

4 inline comments posted:

  • Proto type inconsistency (int64 vs int32 elsewhere) — nit
  • Missing >= 0 validation in FromProto — suggested fix
  • Test assertion Le(4) is weaker than the intended set {0,1,2,4} — nit
  • Zero-value device test could pass vacuously if kernel function isn't found — suggested fix

No correctness issues found. The 4x search space expansion for ROCm is noted and appears intentional.

🤖 Generated with Claude Code

@github-actions github-actions bot removed the claude-review Request a Claude AI code review for this PR label Apr 1, 2026
@nurmukhametov nurmukhametov added the claude-review Request a Claude AI code review for this PR label Apr 1, 2026
@claude
Copy link
Copy Markdown

claude bot commented Apr 1, 2026

Re-review Summary

All 4 previous review comments have been addressed:

  • Proto type consistency (int64 in autotuning.proto) — confirmed consistent with sibling fields
  • Validation checkTF_RET_CHECK(proto.waves_per_eu() >= 0) added in matmul_utils.cc
  • Test precisionAnyOf(0, 1, 2, 4) matcher now used in search space test
  • Vacuous test guardfound_fn tracking added to device test for zero-value case

No new issues found. The waves_per_eu plumbing is complete and well-tested across proto definitions, config structs, autotuner search space, and the xtile compiler. LGTM.

@github-actions github-actions bot removed the claude-review Request a Claude AI code review for this PR label Apr 1, 2026
@nurmukhametov
Copy link
Copy Markdown
Member Author

On upstream this may require https://github.com/openxla/xla/blob/main/xla/service/gpu/autotuning/autotune_cache_key.h#L37 increment too.

I have incremented it here. I wonder if I need also to do something similar due to a backend_proto change?

@nurmukhametov nurmukhametov force-pushed the anurmukh/add-waves-per-eu-triton-gemm-config branch 2 times, most recently from 3dac08c to b471c38 Compare April 7, 2026 13:31
void TritonDotFusionSearchSpace::AddWavesPerEuParameter(
const ConfigWithNotes& config,
std::vector<ConfigWithNotes>& updated_configs) const {
static constexpr int kWavesPerEuValues[] = {0, 1, 2, 4};
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added the comment based on both links here. I think it is better to be self-consistent and I believe that code comment will outlive any URL.

VerifyModule(*ll_triton_module);
}

// Apply ROCm-specific waves_per_eu attribute if set.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@nurmukhametov nurmukhametov force-pushed the anurmukh/add-waves-per-eu-triton-gemm-config branch 4 times, most recently from b8cfb81 to 22d19a8 Compare April 7, 2026 16:29
@nurmukhametov nurmukhametov requested a review from i-chaochen April 8, 2026 08:30
Add the ROCm-specific `waves_per_eu` hint to the Triton GEMM config.
This parameter specifies the minimum number of wavefronts per execution
unit on AMD GPUs. The LLVM backend uses this to:
1) limit the number of SGPRs and VGPRs available per wave, which affects
   register allocation;
2) set register pressure thresholds for the instruction scheduler.

The parameter flows through: TritonGemmKey proto -> TritonGemmConfig
struct -> BlockLevelFusionConfig proto -> BlockLevelParameters ->
xtile_compiler, where it is applied as the "amdgpu-waves-per-eu" LLVM
function attribute on the kernel, matching Triton's own AMD backend
behavior. Default value of 0 means no restriction.

The autotuner search space is extended with values {0,1,2,4} for ROCm
targets.
@nurmukhametov nurmukhametov force-pushed the anurmukh/add-waves-per-eu-triton-gemm-config branch from 22d19a8 to 34bb84c Compare April 8, 2026 08:47
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants