
Add bf16 MLA decode kernel for gqa_ratio=64, qseqlen=1 (non-persistent)#2729

Open
fangche123 wants to merge 5 commits into main from chefang_mlaA16W16_nhead64_qseqlen1
Conversation

Contributor

@fangche123 fangche123 commented Apr 14, 2026

  • Add new bf16/bf16 gqa64 qseqlen1 ASM kernels (.co) for non-persistent decode mode, with and without LSE output
  • Register new kernels in mla_asm.csv for gfx950
  • Fix asm_mla.cu: zero-initialize the KernelArgs struct, differentiate persistent vs non-persistent arg setup (out_16_nosplit, ptr_RP), and add a gqa_ratio=64, qseqlen=1 dispatch path
  • Fix mla.py: remove early-return bypass for bf16 nhead=[32,64] when num_kv_splits=1, ensuring reduce kernel always runs for correct fp32-to-bf16 conversion; pass final_lse to stage1 kernel when return_lse=True so the LSE kernel variant is dispatched
  • Update test_mla.py: add nhead=(64,1) test config, add -lse CLI flag for LSE validation, compute golden LSE reference in ref_masked_attention
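The fp32-to-bf16 reduce path mentioned above merges per-split partial attention results using per-split LSE values. A minimal, dependency-free sketch of that split-KV reduction (hypothetical names, plain Python rather than the actual Triton/ASM stage-2 kernel):

```python
import math

def combine_kv_splits(partial_outs, partial_lses):
    """Merge per-split attention outputs using their log-sum-exp values.

    partial_outs: list of per-split output vectors (lists of floats), each
                  computed over a disjoint slice of the KV cache.
    partial_lses: list of per-split log-sum-exp scalars.
    Returns (combined_out, final_lse).
    """
    # Global LSE via the numerically stable max-shift trick.
    m = max(partial_lses)
    final_lse = m + math.log(sum(math.exp(l - m) for l in partial_lses))
    # Each split's output is reweighted by exp(lse_i - final_lse),
    # i.e. by its share of the global softmax denominator.
    dim = len(partial_outs[0])
    combined = [0.0] * dim
    for out, lse in zip(partial_outs, partial_lses):
        w = math.exp(lse - final_lse)
        for i in range(dim):
            combined[i] += w * out[i]
    return combined, final_lse
```

Because each split's partial output is already softmax-normalized within the split, reweighting by exp(lse_i − final_lse) reproduces the exact full-sequence softmax output, which is why the reduce must run even when a kernel writes fp32 partials for a single split.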

Test commands:
python3 op_tests/test_mla.py -b 256 -c 10000 -n 64,1 -d bf16 -kvd bf16 -lse
python3 op_tests/test_mla.py -d bf16 -kvd bf16 -n 64,1 -c 16384 -b 1 64 128

Performance: benchmark screenshots are attached to the PR (not reproduced here).

Copilot AI review requested due to automatic review settings April 14, 2026 06:03
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

  Label          Tests
  ci:triton-355  Run Triton tests on MI355 in addition to MI325
  ci:sglang      SGLang integration tests
  ci:atom        ATOM benchmark (DeepSeek-R1 + GPT-OSS)
  ci:vllm        vLLM benchmark
  ci:all         All of the above

Add labels via the sidebar or gh pr edit 2729 --add-label <label>

Contributor

Copilot AI left a comment


Pull request overview

This PR adds new bf16 MLA decode ASM kernels for gfx950 targeting GQA ratio=64 with qseqlen=1 in non-persistent decode mode, including an LSE-producing variant, and wires up dispatch + tests to validate correctness (including LSE).

Changes:

  • Added new gfx950 .co kernels for bf16/bf16 decode with gqa_ratio=64, qseqlen=1, with and without final-LSE output, and registered them in mla_asm.csv.
  • Updated asm_mla.cu kernel-arg initialization and dispatch to differentiate persistent vs non-persistent setup and to select the new gqa64/qseqlen1 path.
  • Updated mla.py and test_mla.py to ensure reduction/conversion runs when needed and to optionally request/validate final LSE output.

Reviewed changes

Copilot reviewed 4 out of 6 changed files in this pull request and generated 2 comments.

Show a summary per file
File Description
op_tests/test_mla.py Updates torch reference to compute/return LSE and adds CLI/test coverage for validating returned LSE.
hsa/gfx950/mla/mla_asm.csv Registers the new gfx950 gqa64/qseqlen1 bf16 decode kernels (LSE + non-LSE variants).
hsa/gfx950/mla/mla_a16w16_qh64_qseqlen1_gqaratio64_v3.co Adds new non-persistent decode kernel binary for bf16/bf16 gqa64 qseqlen1.
hsa/gfx950/mla/mla_a16w16_qh64_qseqlen1_gqaratio64_lse_v3.co Adds new non-persistent decode kernel binary variant that can output final LSE.
csrc/py_itfs_cu/asm_mla.cu Zero-inits kernel args, updates arg wiring for persistent vs non-persistent, and adds gqa64/qseqlen1 dispatch behavior.
aiter/mla.py Removes an incorrect early-return bypass for certain bf16 cases and passes final_lse into stage1 to dispatch LSE kernel variant.


Comment on lines 42 to +63
def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype,
    is_causal=True,
) -> torch.Tensor:
    attn_weights = torch.einsum("qhd,khd->hqk", query.float(), key.float()) * scale
    if is_causal:
        s_q = query.shape[0]
        s_k = key.shape[0]
        attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
        temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)
        attn_weights += attn_bias
    lse = attn_weights.logsumexp(dim=-1)
    attn_weights = torch.softmax(attn_weights, dim=-1)

    out = torch.einsum("hqk,khd->qhd", attn_weights.float(), value.float())
    return out.to(dtype), lse
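The golden LSE reference above is produced by attn_weights.logsumexp(dim=-1). For clarity, a dependency-free sketch (plain Python, not the test's torch code) of the numerically stable computation it performs per (head, query) row:

```python
import math

def logsumexp(xs):
    """Numerically stable log(sum(exp(x))), the per-row quantity that
    serves as the golden LSE reference in the attention test."""
    m = max(xs)
    if math.isinf(m) and m < 0:  # fully masked row: all scores are -inf
        return float("-inf")
    # Shifting by the max keeps every exp() argument <= 0, avoiding overflow.
    return m + math.log(sum(math.exp(x - m) for x in xs))
```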

Copilot AI Apr 14, 2026


ref_masked_attention now returns a tuple (out, lse) but the return type annotation is still -> torch.Tensor. This makes the function signature misleading and can break static type checking / IDE inference. Update the annotation to reflect the tuple return (or drop the annotation).

            and nhead in [32, 64]
        )
    ):
        return logits.view(total_s, nhead, v_head_dim), attn_lse

Copilot AI Apr 14, 2026


In the num_kv_splits == 1 early-return path, the function returns attn_lse even when return_lse=True, and it ignores the final_lse buffer that is now passed to mla_decode_stage1_asm_fwd. This makes the return value inconsistent with the main return path (return logits, final_lse) and prevents callers from getting the requested final LSE for these cases. Consider returning final_lse when return_lse is set (and optionally reshaping/squeezing it to the documented layout), otherwise return attn_lse as before.

Suggested change:
-        return logits.view(total_s, nhead, v_head_dim), attn_lse
+        lse = final_lse if return_lse else attn_lse
+        return logits.view(total_s, nhead, v_head_dim), lse

@fangche123 fangche123 requested a review from a team April 14, 2026 06:13
@@ -21,6 +21,7 @@ def _fwd_kernel_stage2_asm(
    Mid_O,
    Mid_lse,
    O,
Contributor


⚠️ [ruff] <E741> reported by reviewdog 🐶
Ambiguous variable name: O
