
MI350 mla ps mode support nhead 128,1 128,2 128,3 128,4 64,4 64,2 32,4 through kernel hsa/gfx950/mla/mla_a16w16_qh32_qseqlen4_gqaratio32_ps.co #2727

Open
minmengdie wants to merge 3 commits into main from mmd/dev/mla_ps_bf16_64_128

Conversation

Contributor

@minmengdie minmengdie commented Apr 14, 2026

Motivation

This PR extends MI350 (gfx950) MLA persistent-scheduling (ps) support to additional (nhead, seqlen) combinations by routing more bf16/bf16 cases through a new precompiled gfx950 kernel and updating the related kernel selection/metadata logic.

Technical Details

Add a new gfx950 bf16/bf16 ps kernel entry to hsa/gfx950/mla/mla_asm.csv and include the new .co binary.
Broaden gfx950 “natively supported” gating in Python tests/runtime and in the C++ metadata generator using the (num_heads * max_seqlen) % 128 == 0 rule for bf16/bf16.
Adjust MLA ASM dispatch logic to select the new kernel configuration on gfx950 for eligible bf16/bf16 persistent cases.
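The divisibility gating described above can be sketched as a small predicate; this is a minimal illustration, and the function name is an assumption, not the actual aiter API:

```python
# Minimal sketch of the gfx950 bf16/bf16 persistent gating rule described
# above. The function name is illustrative, not the actual aiter API.
def is_gfx950_bf16_ps_supported(num_heads: int, max_seqlen: int) -> bool:
    """True when (num_heads * max_seqlen) is a multiple of 128."""
    return (num_heads * max_seqlen) % 128 == 0

# Every (nhead, seqlen) pair exercised in the Test Plan satisfies the rule:
pairs = [(128, 1), (128, 2), (128, 3), (128, 4), (64, 4), (64, 2), (32, 4)]
assert all(is_gfx950_bf16_ps_supported(n, s) for n, s in pairs)
```

Pairs such as (16, 3) fall outside the rule, since 48 is not a multiple of 128.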

Test Plan

python3 op_tests/test_mla_persistent.py -d bf16 -kvd bf16 -n 128,1 128,2 128,3 128,4 64,4 64,2 32,4
python3 op_tests/test_mla_persistent.py -d bf16 -kvd bf16 -n 128,1 128,2 128,3 128,4 64,4 64,2 32,4 -b 1 64 128 -c 16384

Test Result

[image: test result screenshot]

Perf comparison:
[image: perf comparison screenshot]

Submission Checklist

…4 through kernel hsa/gfx950/mla/mla_a16w16_qh32_qseqlen4_gqaratio32_ps.co
@minmengdie minmengdie requested review from a team and Copilot April 14, 2026 02:56
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label           Tests
ci:triton-355   Run Triton tests on MI355 in addition to MI325
ci:sglang       SGLang integration tests
ci:atom         ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm         vLLM benchmark
ci:all          All of the above

Add labels via the sidebar or gh pr edit 2727 --add-label <label>

Contributor

Copilot AI left a comment


Pull request overview

This PR extends MI350 (gfx950) MLA persistent-scheduling (ps) support to additional (nhead, seqlen) combinations by routing more bf16/bf16 cases through a new precompiled gfx950 kernel and updating the related kernel selection/metadata logic.

Changes:

  • Add a new gfx950 bf16/bf16 ps kernel entry to hsa/gfx950/mla/mla_asm.csv and include the new .co binary.
  • Broaden gfx950 “natively supported” gating in Python tests/runtime and in the C++ metadata generator using the (num_heads * max_seqlen) % 128 == 0 rule for bf16/bf16.
  • Adjust MLA ASM dispatch logic to select the new kernel configuration on gfx950 for eligible bf16/bf16 persistent cases.

Reviewed changes

Copilot reviewed 7 out of 8 changed files in this pull request and generated 3 comments.

Summary per file:

  • op_tests/test_mla_persistent.py — Expands gfx950 native-support conditions used by the reference test path; tweaks gfx942-only exclusion for bf16 nhead=32.
  • hsa/gfx950/mla/mla_asm.csv — Registers the new gfx950 bf16/bf16 ps kernel in the asm config CSV.
  • hsa/gfx950/mla/mla_a16w16_qh32_qseqlen4_gqaratio32_ps.co — Adds the new precompiled kernel binary for gfx950.
  • csrc/py_itfs_cu/asm_mla.cu — Updates kernel selection logic for gfx950 bf16/bf16 persistent decode and tweaks gfx942-only behavior.
  • csrc/kernels/mla/metadata/v1_2_device.cuh — Broadens native-support + validation conditions for gfx950 bf16/bf16 based on head*seqlen divisibility.
  • aiter/ops/flydsl/__init__.py — Relaxes FlyDSL version check to ignore local version suffixes (+...).
  • aiter/ops/attention.py — Broadens metadata sizing logic for gfx950 bf16/bf16 based on head*seqlen divisibility.
  • aiter/mla.py — Broadens native-support gating for gfx950 bf16/bf16 based on head*seqlen divisibility.
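The FlyDSL version-check relaxation listed above can be illustrated with a minimal sketch; the function name and version strings here are hypothetical examples, not the actual aiter code:

```python
# Minimal sketch of ignoring a PEP 440 local-version suffix ("+...") when
# comparing versions, as the aiter/ops/flydsl/__init__.py change describes.
# Function name and version strings are hypothetical.
def base_version(version: str) -> str:
    """Strip the local version segment, e.g. '1.2.3+rocm6.4' -> '1.2.3'."""
    return version.split("+", 1)[0]
```

With this, a build tagged with a local suffix still passes an equality check against the expected base version.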


if (arch_id == "gfx950" && q_type == "bf16" && kv_type == "bf16" && persistent && (gqa_ratio* max_seqlen_q % 128 == 0)){
config_max_seqlen_q = 4;
config_gqa_ratio = 32;
args.s_Q_Bs = gqa_ratio;

Copilot AI Apr 14, 2026


In the gfx950 bf16/bf16 persistent fast-path, args.s_Q_Bs is overwritten with gqa_ratio. Everywhere else in this function s_Q_Bs is set to stride_Q (a byte stride), so changing it to a head ratio is very likely to break kernel addressing for this path. If the kernel needs the runtime gqa_ratio, pass it via a dedicated arg (or s_MQA/another existing scalar if that’s what the kernel expects) and keep s_Q_Bs as the Q batch stride.

Suggested change
args.s_Q_Bs = gqa_ratio;

Copilot uses AI. Check for mistakes.
Comment on lines 500 to 501
": only supports #heads in [16, 64, 128], or (#head, uni_seqlen_qo) = (16*N, 1) where "
"N is in [2, 8), or (#head, max_seqlen_qo) = (8, 4) where q and kv are fp8.")

Copilot AI Apr 14, 2026


The TORCH_CHECK error text is now outdated: this function also accepts gfx950 bf16/bf16 cases when (num_heads * max_seqlen_qo) % 128 == 0, but the message still says it "only supports #heads in [16, 64, 128]" and fp8-only special cases. Please update the message to include the new gfx950 bf16/bf16 rule so failures are actionable.

Suggested change
": only supports #heads in [16, 64, 128], or (#head, uni_seqlen_qo) = (16*N, 1) where "
"N is in [2, 8), or (#head, max_seqlen_qo) = (8, 4) where q and kv are fp8.")
": only supports #heads in [16, 128], or (#head, uni_seqlen_qo) = (16*N, 1) where "
"N is in [2, 8) and q and kv are fp8, or (#head, max_seqlen_qo) = (8, 4) where q and kv "
"are fp8, or on gfx950 with non-fp8 q/kv when (#head * max_seqlen_qo) % 128 == 0.")

and num_head_qo == 128
and kv_dtype == dtypes.fp8
and q_dtype == dtypes.fp8
)

Copilot AI Apr 14, 2026


There is trailing whitespace after the closing parenthesis on this line. This will cause black to reformat the file and fail the formatting check in CI; please remove the trailing spaces (or run black on the file).

Suggested change
)
)

