
[ROCm] Add Triton autotuning configs for MI300 and MI350 v0.9.0#765

Open
nurmukhametov wants to merge 1 commit into rocm-jaxlib-v0.9.0 from
anurmukh/mi300-mi350-triton-configs-0.9.0

Conversation

@nurmukhametov (Member) commented Mar 27, 2026

Motivation

Improve performance.

Technical Details

Extend the Triton GEMM autotuner with dedicated config sets for AMD MI300 (gfx942, 33 configs) and MI350 (gfx950, 40 configs), expanding beyond the generic 6-config ROCm default.

Submission Checklist

@nurmukhametov nurmukhametov added the claude-review Request a Claude AI code review for this PR label Mar 27, 2026
Comment on lines +143 to +177
constexpr absl::string_view kMI300TritonConfigs = R"(
config { block_m: 32 block_n: 32 block_k: 256 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 }
config { block_m: 64 block_n: 32 block_k: 32 split_k: 16 num_stages: 1 num_warps: 4 num_ctas: 1 }
config { block_m: 32 block_n: 64 block_k: 64 split_k: 4 num_stages: 1 num_warps: 4 num_ctas: 1 }
config { block_m: 128 block_n: 128 block_k: 64 split_k: 4 num_stages: 1 num_warps: 4 num_ctas: 1 }
config { block_m: 16 block_n: 16 block_k: 256 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 }
config { block_m: 16 block_n: 128 block_k: 32 split_k: 16 num_stages: 1 num_warps: 4 num_ctas: 1 }
config { block_m: 256 block_n: 256 block_k: 32 split_k: 1 num_stages: 2 num_warps: 8 num_ctas: 1 }
config { block_m: 128 block_n: 256 block_k: 64 split_k: 1 num_stages: 2 num_warps: 8 num_ctas: 1 }
config { block_m: 128 block_n: 256 block_k: 32 split_k: 1 num_stages: 2 num_warps: 4 num_ctas: 1 }
config { block_m: 256 block_n: 128 block_k: 64 split_k: 1 num_stages: 2 num_warps: 8 num_ctas: 1 }
config { block_m: 128 block_n: 128 block_k: 64 split_k: 1 num_stages: 2 num_warps: 4 num_ctas: 1 }
config { block_m: 32 block_n: 8 block_k: 16 split_k: 1 num_stages: 2 num_warps: 2 num_ctas: 1 }
config { block_m: 32 block_n: 8 block_k: 16 split_k: 1 num_stages: 4 num_warps: 2 num_ctas: 1 }
config { block_m: 64 block_n: 32 block_k: 16 split_k: 1 num_stages: 1 num_warps: 2 num_ctas: 1 }
config { block_m: 128 block_n: 32 block_k: 16 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 }
config { block_m: 128 block_n: 64 block_k: 128 split_k: 1 num_stages: 2 num_warps: 8 num_ctas: 1 }
config { block_m: 128 block_n: 128 block_k: 32 split_k: 1 num_stages: 3 num_warps: 4 num_ctas: 1 }
config { block_m: 128 block_n: 256 block_k: 64 split_k: 2 num_stages: 2 num_warps: 8 num_ctas: 1 }
config { block_m: 256 block_n: 128 block_k: 32 split_k: 1 num_stages: 2 num_warps: 4 num_ctas: 1 }
config { block_m: 256 block_n: 256 block_k: 32 split_k: 1 num_stages: 1 num_warps: 8 num_ctas: 1 }
config { block_m: 256 block_n: 256 block_k: 32 split_k: 4 num_stages: 2 num_warps: 8 num_ctas: 1 }
config { block_m: 128 block_n: 32 block_k: 32 split_k: 8 num_stages: 1 num_warps: 4 num_ctas: 1 }
config { block_m: 64 block_n: 32 block_k: 32 split_k: 8 num_stages: 3 num_warps: 2 num_ctas: 1 }
config { block_m: 64 block_n: 32 block_k: 32 split_k: 8 num_stages: 5 num_warps: 2 num_ctas: 1 }
config { block_m: 128 block_n: 32 block_k: 32 split_k: 32 num_stages: 2 num_warps: 4 num_ctas: 1 }
config { block_m: 32 block_n: 32 block_k: 32 split_k: 8 num_stages: 2 num_warps: 2 num_ctas: 1 }
config { block_m: 64 block_n: 32 block_k: 128 split_k: 2 num_stages: 2 num_warps: 2 num_ctas: 1 }
config { block_m: 256 block_n: 8 block_k: 32 split_k: 4 num_stages: 1 num_warps: 2 num_ctas: 1 }
config { block_m: 128 block_n: 16 block_k: 128 split_k: 1 num_stages: 2 num_warps: 8 num_ctas: 1 }
config { block_m: 32 block_n: 16 block_k: 128 split_k: 1 num_stages: 4 num_warps: 4 num_ctas: 1 }
config { block_m: 32 block_n: 16 block_k: 128 split_k: 2 num_stages: 5 num_warps: 2 num_ctas: 1 }
config { block_m: 32 block_n: 16 block_k: 128 split_k: 1 num_stages: 2 num_warps: 4 num_ctas: 1 }
config { block_m: 64 block_n: 8 block_k: 128 split_k: 1 num_stages: 1 num_warps: 2 num_ctas: 1 }
config { block_m: 32 block_n: 16 block_k: 256 split_k: 1 num_stages: 2 num_warps: 2 num_ctas: 1 }
optional: The first 6 configs of both kMI300TritonConfigs and kMI350TritonConfigs are identical to the 6 entries in kDefaultRocmTritonConfigs, and the first 13 configs of MI300 and MI350 are identical to each other. This is presumably intentional (superset design), but it does mean autotuning time increases from 6 to 36/43 configs. Just flagging in case any configs were accidentally duplicated or could be pruned.

Also: no unit tests are added for these new config paths. A simple test verifying GetTritonConfigsForPlatform(kMI300).size() == 36 and kMI350 == 43 would be a useful regression guard against config parsing errors.

Collaborator:
Since num_ctas is an NVIDIA feature and we always set it to 1, do we really need to keep it in our config? Would it be feasible to remove this field completely and just set a default of 1 in our initialization code?

Member Author:

num_ctas is not NVIDIA-specific. The ROCm pipeline also passes it to Triton here:

rocm_cc.threads_per_warp(), num_ctas}));

That said, I'm not 100% sure whether Triton actually does anything meaningful with num_ctas on the AMD/ROCm backend. It might just be passed through and effectively ignored when set to 1. But regardless, it cannot be omitted from the configs because the proto default is 0, and the validation rejects that here:
TF_RET_CHECK(proto.num_ctas() > 0);

So num_ctas: 1 must be explicitly set in all configs, even if it may be a no-op on ROCm.

Collaborator:

I mean, this is an NVIDIA feature after all: https://triton-lang.org/main/python-api/generated/triton.Config.html. AMD doesn't have this blocked-cluster concept and always sets it to 1, so maybe we don't need to put it in the config.

Member Author:

> I mean, this is an NVIDIA feature after all: https://triton-lang.org/main/python-api/generated/triton.Config.html. AMD doesn't have this blocked-cluster concept and always sets it to 1, so maybe we don't need to put it in the config.

If I remove them without changing anything else, it fails on the CHECK mentioned above:

F0000 00:00:1774953714.758075  267522 triton_configs.cc:348] Check failed: config is OK (INTERNAL: RET_CHECK failure (xla/service/gpu/matmul_utils.cc:811) proto.num_ctas() > 0 )
*** Check failure stack trace: ***
    @     0x55a642804e34  absl::lts_20250814::log_internal::LogMessage::SendToLog()
    @     0x55a642804db6  absl::lts_20250814::log_internal::LogMessage::Flush()
    @     0x55a6362117df  xla::gpu::(anonymous namespace)::LoadTritonConfigs()
    @     0x55a6362113a9  xla::gpu::GetTritonConfigsForPlatform()
    @     0x55a636208d6f  xla::gpu::TritonBackend::GetSupportedConfigsForDot()
    @     0x55a636208a0b  xla::gpu::TritonBackend::GetSupportedConfigs()
    @     0x55a6361f66a2  xla::Autotuner::GetSupportedConfigs()
    ...

As the code currently stands, num_ctas: 1 has to be provided because the proto default is 0 and at least one place in the code rejects that value.

Collaborator:

Of course, directly removing it causes the protobuf error, which is why I said at the beginning that maybe we could set a default of 1 in our initialization code only, even though the default is 0 on the NVIDIA side.

Member Author:

I don't think diverging the proto's default value only for the ROCm platform is a good idea. We might want to change the default altogether, since 0 is meaningless, but that is better done in the upstream branch.

claude bot commented Mar 27, 2026

Review Summary

Overall: Clean, data-focused change that adds MI300 (gfx942) and MI350 (gfx950) Triton GEMM autotuning configs. The enum/map/config plumbing is correct and consistent. A few minor suggestions posted inline:

  1. Defensive null guard — gemm_fusion_autotuner_rocm.cc dereferences rocm_compute_capability() without a null/ROCm check, unlike the guarded pattern in triton.cc.
  2. Duplicated dispatch logic — MI300/MI350 routing is duplicated across two files; a shared helper would reduce maintenance burden for future GPU targets.
  3. Test coverage — No unit tests for the new config paths; a simple size-check test would guard against config parsing regressions.

No blocking issues found.


Reviewed with Claude

@github-actions github-actions bot removed the claude-review Request a Claude AI code review for this PR label Mar 27, 2026
@nurmukhametov nurmukhametov force-pushed the anurmukh/mi300-mi350-triton-configs-0.9.0 branch 2 times, most recently from bc56c16 to 36f58c2 Compare March 27, 2026 17:23
@nurmukhametov nurmukhametov added the claude-review Request a Claude AI code review for this PR label Mar 27, 2026
claude bot commented Mar 27, 2026

Re-review Summary

Checked the latest revision against previous review feedback. The CHECK(rocm_cc != nullptr) has been added as requested — resolved. The dispatch duplication across triton.cc and gemm_fusion_autotuner_rocm.cc is intentional to align with other backends. No new issues found in this revision.

@github-actions github-actions bot removed the claude-review Request a Claude AI code review for this PR label Mar 27, 2026
@nurmukhametov nurmukhametov force-pushed the anurmukh/mi300-mi350-triton-configs-0.9.0 branch from 36f58c2 to 75a45e1 Compare April 1, 2026 11:17
Extend the Triton GEMM autotuner with dedicated config sets for
AMD MI300 (gfx942, 33 configs) and MI350 (gfx950, 40 configs),
expanding beyond the generic 6-config ROCm default.
@nurmukhametov nurmukhametov force-pushed the anurmukh/mi300-mi350-triton-configs-0.9.0 branch from 75a45e1 to 85db421 Compare April 1, 2026 11:32
@nurmukhametov nurmukhametov changed the title [ROCm] Add Triton autotuning configs for MI300 and MI350 [ROCm] Add Triton autotuning configs for MI300 and MI350 v0.9.0 Apr 1, 2026
@nurmukhametov (Member Author) commented:

On MI300, Llama 3.2 1B with these updated configs improved Triton-only GEMM performance from 2193ms to 1796ms, an 18.1% speedup. On MI350, the same model improved from 1295ms to 1121ms, a 13.4% speedup. For Gemma 2 2B and Gemma 3 1B, the updated configs perform within a few percent of exhaustive search on both platforms.

