[ROCm] Add Triton autotuning configs for MI300 and MI350 v0.9.0#765
nurmukhametov wants to merge 1 commit into rocm-jaxlib-v0.9.0
Conversation
constexpr absl::string_view kMI300TritonConfigs = R"(
config { block_m: 32 block_n: 32 block_k: 256 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 }
config { block_m: 64 block_n: 32 block_k: 32 split_k: 16 num_stages: 1 num_warps: 4 num_ctas: 1 }
config { block_m: 32 block_n: 64 block_k: 64 split_k: 4 num_stages: 1 num_warps: 4 num_ctas: 1 }
config { block_m: 128 block_n: 128 block_k: 64 split_k: 4 num_stages: 1 num_warps: 4 num_ctas: 1 }
config { block_m: 16 block_n: 16 block_k: 256 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 }
config { block_m: 16 block_n: 128 block_k: 32 split_k: 16 num_stages: 1 num_warps: 4 num_ctas: 1 }
config { block_m: 256 block_n: 256 block_k: 32 split_k: 1 num_stages: 2 num_warps: 8 num_ctas: 1 }
config { block_m: 128 block_n: 256 block_k: 64 split_k: 1 num_stages: 2 num_warps: 8 num_ctas: 1 }
config { block_m: 128 block_n: 256 block_k: 32 split_k: 1 num_stages: 2 num_warps: 4 num_ctas: 1 }
config { block_m: 256 block_n: 128 block_k: 64 split_k: 1 num_stages: 2 num_warps: 8 num_ctas: 1 }
config { block_m: 128 block_n: 128 block_k: 64 split_k: 1 num_stages: 2 num_warps: 4 num_ctas: 1 }
config { block_m: 32 block_n: 8 block_k: 16 split_k: 1 num_stages: 2 num_warps: 2 num_ctas: 1 }
config { block_m: 32 block_n: 8 block_k: 16 split_k: 1 num_stages: 4 num_warps: 2 num_ctas: 1 }
config { block_m: 64 block_n: 32 block_k: 16 split_k: 1 num_stages: 1 num_warps: 2 num_ctas: 1 }
config { block_m: 128 block_n: 32 block_k: 16 split_k: 1 num_stages: 1 num_warps: 4 num_ctas: 1 }
config { block_m: 128 block_n: 64 block_k: 128 split_k: 1 num_stages: 2 num_warps: 8 num_ctas: 1 }
config { block_m: 128 block_n: 128 block_k: 32 split_k: 1 num_stages: 3 num_warps: 4 num_ctas: 1 }
config { block_m: 128 block_n: 256 block_k: 64 split_k: 2 num_stages: 2 num_warps: 8 num_ctas: 1 }
config { block_m: 256 block_n: 128 block_k: 32 split_k: 1 num_stages: 2 num_warps: 4 num_ctas: 1 }
config { block_m: 256 block_n: 256 block_k: 32 split_k: 1 num_stages: 1 num_warps: 8 num_ctas: 1 }
config { block_m: 256 block_n: 256 block_k: 32 split_k: 4 num_stages: 2 num_warps: 8 num_ctas: 1 }
config { block_m: 128 block_n: 32 block_k: 32 split_k: 8 num_stages: 1 num_warps: 4 num_ctas: 1 }
config { block_m: 64 block_n: 32 block_k: 32 split_k: 8 num_stages: 3 num_warps: 2 num_ctas: 1 }
config { block_m: 64 block_n: 32 block_k: 32 split_k: 8 num_stages: 5 num_warps: 2 num_ctas: 1 }
config { block_m: 128 block_n: 32 block_k: 32 split_k: 32 num_stages: 2 num_warps: 4 num_ctas: 1 }
config { block_m: 32 block_n: 32 block_k: 32 split_k: 8 num_stages: 2 num_warps: 2 num_ctas: 1 }
config { block_m: 64 block_n: 32 block_k: 128 split_k: 2 num_stages: 2 num_warps: 2 num_ctas: 1 }
config { block_m: 256 block_n: 8 block_k: 32 split_k: 4 num_stages: 1 num_warps: 2 num_ctas: 1 }
config { block_m: 128 block_n: 16 block_k: 128 split_k: 1 num_stages: 2 num_warps: 8 num_ctas: 1 }
config { block_m: 32 block_n: 16 block_k: 128 split_k: 1 num_stages: 4 num_warps: 4 num_ctas: 1 }
config { block_m: 32 block_n: 16 block_k: 128 split_k: 2 num_stages: 5 num_warps: 2 num_ctas: 1 }
config { block_m: 32 block_n: 16 block_k: 128 split_k: 1 num_stages: 2 num_warps: 4 num_ctas: 1 }
config { block_m: 64 block_n: 8 block_k: 128 split_k: 1 num_stages: 1 num_warps: 2 num_ctas: 1 }
config { block_m: 32 block_n: 16 block_k: 256 split_k: 1 num_stages: 2 num_warps: 2 num_ctas: 1 }
optional: The first 6 configs of both kMI300TritonConfigs and kMI350TritonConfigs are identical to the 6 entries in kDefaultRocmTritonConfigs, and the first 13 configs of MI300 and MI350 are identical to each other. This is presumably intentional (superset design), but it does mean autotuning time increases from 6 to 36/43 configs. Just flagging in case any configs were accidentally duplicated or could be pruned.
Also: no unit tests are added for these new config paths. A simple test verifying GetTritonConfigsForPlatform(kMI300).size() == 36 and kMI350 == 43 would be a useful regression guard against config parsing errors.
Since num_ctas is an NV feature and we always set it to 1, do we really need to keep it in our config? Would it be feasible to remove this part completely and just set a default of 1 in our initialization code?
num_ctas is not NVIDIA-specific; the ROCm pipeline also passes it to Triton.

That said, I'm not 100% sure whether Triton actually does anything meaningful with num_ctas on the AMD/ROCm backend. It might just be passed through and effectively ignored when set to 1. But regardless, it cannot be omitted from the configs, because the proto default is 0 and the validation rejects that (xla/xla/service/gpu/matmul_utils.cc, line 811 in 7d76601).

So num_ctas: 1 must be explicitly set in all configs, even if it may be a no-op on ROCm.
I mean, this is an NV feature after all: https://triton-lang.org/main/python-api/generated/triton.Config.html. AMD doesn't have this blocked-cluster concept and always sets it to 1, so maybe we don't need to put it into the config.
If I remove them without changing anything else, it fails on the CHECK mentioned above:
F0000 00:00:1774953714.758075 267522 triton_configs.cc:348] Check failed: config is OK (INTERNAL: RET_CHECK failure (xla/service/gpu/matmul_utils.cc:811) proto.num_ctas() > 0 )
*** Check failure stack trace: ***
@ 0x55a642804e34 absl::lts_20250814::log_internal::LogMessage::SendToLog()
@ 0x55a642804db6 absl::lts_20250814::log_internal::LogMessage::Flush()
@ 0x55a6362117df xla::gpu::(anonymous namespace)::LoadTritonConfigs()
@ 0x55a6362113a9 xla::gpu::GetTritonConfigsForPlatform()
@ 0x55a636208d6f xla::gpu::TritonBackend::GetSupportedConfigsForDot()
@ 0x55a636208a0b xla::gpu::TritonBackend::GetSupportedConfigs()
@ 0x55a6361f66a2 xla::Autotuner::GetSupportedConfigs()
...

As it stands, num_ctas: 1 has to be provided, because the default is 0 and there is at least one place in the code that doesn't allow 0.
Of course, directly removing it causes a protobuf error, which is why I said at the beginning that maybe we could set a default of 1 in our initialization code only, even though the default is 0 on the NV side.
I don't think diverging the proto's default value only for the ROCm platform is a good idea. We might want to change the default altogether, since 0 is meaningless, but it's better to do that in the upstream branch.
Review Summary

Overall: Clean, data-focused change that adds MI300 (gfx942) and MI350 (gfx950) Triton GEMM autotuning configs. The enum/map/config plumbing is correct and consistent. A few minor suggestions posted inline.

No blocking issues found. Reviewed with Claude
Force-pushed bc56c16 to 36f58c2
Re-review Summary

Checked the latest revision against previous review feedback.
Force-pushed 36f58c2 to 75a45e1
Extend the Triton GEMM autotuner with dedicated config sets for AMD MI300 (gfx942, 33 configs) and MI350 (gfx950, 40 configs), expanding beyond the generic 6-config ROCm default.
Force-pushed 75a45e1 to 85db421
On MI300, Llama 3.2 1B with these updated configs improved Triton-only GEMM performance from 2193 ms to 1796 ms, an 18.1% speedup. On MI350, the same model improved from 1295 ms to 1121 ms, a 13.4% speedup. For Gemma 2 2B and Gemma 3 1B, the updated configs perform within a few percent of exhaustive search on both platforms.
Motivation
Improve performance.
Technical Details
Extend the Triton GEMM autotuner with dedicated config sets for AMD MI300 (gfx942, 33 configs) and MI350 (gfx950, 40 configs), expanding beyond the generic 6-config ROCm default.
Submission Checklist