MI350 MLA ps mode support for nhead,seqlen = 128,1 128,2 128,3 128,4 64,4 64,2 32,4 through kernel hsa/gfx950/mla/mla_a16w16_qh32_qseqlen4_gqaratio32_ps.co #2727
Pull request overview
This PR extends MI350 (gfx950) MLA persistent-scheduling (ps) support to additional (nhead, seqlen) combinations by routing more bf16/bf16 cases through a new precompiled gfx950 kernel and updating the related kernel selection/metadata logic.
Changes:
- Add a new gfx950 bf16/bf16 ps kernel entry to `hsa/gfx950/mla/mla_asm.csv` and include the new `.co` binary.
- Broaden gfx950 "natively supported" gating in Python tests/runtime and in the C++ metadata generator using the `(num_heads * max_seqlen) % 128 == 0` rule for bf16/bf16.
- Adjust MLA ASM dispatch logic to select the new kernel configuration on gfx950 for eligible bf16/bf16 persistent cases.
Reviewed changes
Copilot reviewed 7 out of 8 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `op_tests/test_mla_persistent.py` | Expands gfx950 native-support conditions used by the reference test path; tweaks the gfx942-only exclusion for bf16 nhead=32. |
| `hsa/gfx950/mla/mla_asm.csv` | Registers the new gfx950 bf16/bf16 ps kernel in the asm config CSV. |
| `hsa/gfx950/mla/mla_a16w16_qh32_qseqlen4_gqaratio32_ps.co` | Adds the new precompiled kernel binary for gfx950. |
| `csrc/py_itfs_cu/asm_mla.cu` | Updates kernel selection logic for gfx950 bf16/bf16 persistent decode and tweaks gfx942-only behavior. |
| `csrc/kernels/mla/metadata/v1_2_device.cuh` | Broadens native-support and validation conditions for gfx950 bf16/bf16 based on head*seqlen divisibility. |
| `aiter/ops/flydsl/__init__.py` | Relaxes the FlyDSL version check to ignore local version suffixes (`+...`). |
| `aiter/ops/attention.py` | Broadens metadata sizing logic for gfx950 bf16/bf16 based on head*seqlen divisibility. |
| `aiter/mla.py` | Broadens native-support gating for gfx950 bf16/bf16 based on head*seqlen divisibility. |
```cpp
if (arch_id == "gfx950" && q_type == "bf16" && kv_type == "bf16" && persistent && (gqa_ratio * max_seqlen_q % 128 == 0)) {
    config_max_seqlen_q = 4;
    config_gqa_ratio = 32;
    args.s_Q_Bs = gqa_ratio;
```
In the gfx950 bf16/bf16 persistent fast-path, args.s_Q_Bs is overwritten with gqa_ratio. Everywhere else in this function s_Q_Bs is set to stride_Q (a byte stride), so changing it to a head ratio is very likely to break kernel addressing for this path. If the kernel needs the runtime gqa_ratio, pass it via a dedicated arg (or s_MQA/another existing scalar if that’s what the kernel expects) and keep s_Q_Bs as the Q batch stride.
```cpp
": only supports #heads in [16, 64, 128], or (#head, uni_seqlen_qo) = (16*N, 1) where "
"N is in [2, 8), or (#head, max_seqlen_qo) = (8, 4) where q and kv are fp8.")
```
The TORCH_CHECK error text is now outdated: this function also accepts gfx950 bf16/bf16 cases when (num_heads * max_seqlen_qo) % 128 == 0, but the message still says it "only supports #heads in [16, 64, 128]" and fp8-only special cases. Please update the message to include the new gfx950 bf16/bf16 rule so failures are actionable.
Suggested change:

```cpp
": only supports #heads in [16, 128], or (#head, uni_seqlen_qo) = (16*N, 1) where "
"N is in [2, 8) and q and kv are fp8, or (#head, max_seqlen_qo) = (8, 4) where q and kv "
"are fp8, or on gfx950 with non-fp8 q/kv when (#head * max_seqlen_qo) % 128 == 0.")
```
aiter/ops/attention.py
```python
and num_head_qo == 128
and kv_dtype == dtypes.fp8
and q_dtype == dtypes.fp8
)
```
There is trailing whitespace after the closing parenthesis on this line. This will cause black to reformat the file and fail the formatting check in CI; please remove the trailing spaces (or run black on the file).
Suggested change (trailing whitespace removed):

```python
)
```
Motivation
This PR extends MI350 (gfx950) MLA persistent-scheduling (ps) support to additional (nhead, seqlen) combinations by routing more bf16/bf16 cases through a new precompiled gfx950 kernel and updating the related kernel selection/metadata logic.
Technical Details
Add a new gfx950 bf16/bf16 ps kernel entry to hsa/gfx950/mla/mla_asm.csv and include the new .co binary.
Broaden gfx950 “natively supported” gating in Python tests/runtime and in the C++ metadata generator using the (num_heads * max_seqlen) % 128 == 0 rule for bf16/bf16.
Adjust MLA ASM dispatch logic to select the new kernel configuration on gfx950 for eligible bf16/bf16 persistent cases.
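As a sketch, the broadened gating rule above reduces to a single divisibility check; the function name below is hypothetical (not from the actual sources) and is shown only to illustrate why the (nhead, seqlen) combinations in the PR title all become eligible.

```python
def gfx950_bf16_ps_natively_supported(num_heads: int, max_seqlen_qo: int) -> bool:
    # Gating rule described in this PR: bf16/bf16 persistent cases on
    # gfx950 route to the new kernel when heads * query seqlen is a
    # multiple of 128.
    return (num_heads * max_seqlen_qo) % 128 == 0

# Every (nhead, seqlen) pair from the PR title satisfies the rule:
combos = [(128, 1), (128, 2), (128, 3), (128, 4), (64, 4), (64, 2), (32, 4)]
print(all(gfx950_bf16_ps_natively_supported(h, s) for h, s in combos))  # True
```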
Test Plan
```shell
python3 op_tests/test_mla_persistent.py -d bf16 -kvd bf16 -n 128,1 128,2 128,3 128,4 64,4 64,2 32,4
python3 op_tests/test_mla_persistent.py -d bf16 -kvd bf16 -n 128,1 128,2 128,3 128,4 64,4 64,2 32,4 -b 1 64 128 -c 16384
```
Test Result
Perf comparison:
(performance comparison screenshot not reproduced)
Submission Checklist