Add bf16 MLA decode kernel for gqa_ratio=64, qseqlen=1 (non-persistent) #2729
fangche123 wants to merge 5 commits into main
Conversation
Pull request overview
This PR adds new bf16 MLA decode ASM kernels for gfx950 targeting GQA ratio=64 with qseqlen=1 in non-persistent decode mode, including an LSE-producing variant, and wires up dispatch + tests to validate correctness (including LSE).
Changes:
- Added new gfx950 `.co` kernels for bf16/bf16 decode with `gqa_ratio=64, qseqlen=1`, with and without final-LSE output, and registered them in `mla_asm.csv`.
- Updated `asm_mla.cu` kernel-arg initialization and dispatch to differentiate persistent vs non-persistent setup and to select the new gqa64/qseqlen1 path.
- Updated `mla.py` and `test_mla.py` to ensure reduction/conversion runs when needed and to optionally request/validate final LSE output.
Reviewed changes
Copilot reviewed 4 out of 6 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `op_tests/test_mla.py` | Updates the torch reference to compute/return LSE and adds CLI/test coverage for validating the returned LSE. |
| `hsa/gfx950/mla/mla_asm.csv` | Registers the new gfx950 gqa64/qseqlen1 bf16 decode kernels (LSE and non-LSE variants). |
| `hsa/gfx950/mla/mla_a16w16_qh64_qseqlen1_gqaratio64_v3.co` | Adds the new non-persistent decode kernel binary for bf16/bf16 gqa64 qseqlen1. |
| `hsa/gfx950/mla/mla_a16w16_qh64_qseqlen1_gqaratio64_lse_v3.co` | Adds a new non-persistent decode kernel binary variant that can output the final LSE. |
| `csrc/py_itfs_cu/asm_mla.cu` | Zero-inits kernel args, updates arg wiring for persistent vs non-persistent, and adds gqa64/qseqlen1 dispatch behavior. |
| `aiter/mla.py` | Removes an incorrect early-return bypass for certain bf16 cases and passes `final_lse` into stage1 to dispatch the LSE kernel variant. |
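The dispatch behavior described above can be sketched as follows. This is an illustrative Python sketch, not the PR's actual C++ dispatch code; the kernel file names come from the table, but the function name and fallback string are hypothetical.

```python
# Illustrative sketch of the gqa64/qseqlen1 selection the changes describe:
# that shape picks one of the new .co kernels, with the LSE variant chosen
# when the caller requests a final LSE. All names besides the .co files
# are placeholders, not the PR's real identifiers.
def select_decode_kernel(gqa_ratio: int, qseqlen: int, want_final_lse: bool) -> str:
    if gqa_ratio == 64 and qseqlen == 1:
        if want_final_lse:
            return "mla_a16w16_qh64_qseqlen1_gqaratio64_lse_v3.co"
        return "mla_a16w16_qh64_qseqlen1_gqaratio64_v3.co"
    # Every other shape keeps the pre-existing dispatch (placeholder here).
    return "existing_path"
```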
```python
def ref_masked_attention(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    scale: float,
    dtype,
    is_causal=True,
) -> torch.Tensor:
    attn_weights = torch.einsum("qhd,khd->hqk", query.float(), key.float()) * scale
    if is_causal:
        s_q = query.shape[0]
        s_k = key.shape[0]
        attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype)
        temp_mask = torch.ones(s_q, s_k, dtype=torch.bool).tril(diagonal=s_k - s_q)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)
        attn_weights += attn_bias
    lse = attn_weights.logsumexp(dim=-1)
    attn_weights = torch.softmax(attn_weights, dim=-1)
    out = torch.einsum("hqk,khd->qhd", attn_weights.float(), value.float())
    return out.to(dtype), lse
```
`ref_masked_attention` now returns a tuple `(out, lse)`, but the return type annotation is still `-> torch.Tensor`. This makes the function signature misleading and can break static type checking / IDE inference. Update the annotation to reflect the tuple return (or drop the annotation).
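One way to apply this suggestion, sketched with a stand-in `Tensor` class so the snippet is self-contained (the real function annotates `torch.Tensor`):

```python
from typing import Tuple


class Tensor:  # stand-in for torch.Tensor, purely to keep the sketch runnable
    pass


def ref_masked_attention(
    query: Tensor,
    key: Tensor,
    value: Tensor,
    scale: float,
    dtype,
    is_causal: bool = True,
) -> Tuple[Tensor, Tensor]:  # annotation now matches the (out, lse) tuple return
    ...
```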
```python
            and nhead in [32, 64]
        )
    ):
        return logits.view(total_s, nhead, v_head_dim), attn_lse
```
In the num_kv_splits == 1 early-return path, the function returns attn_lse even when return_lse=True, and it ignores the final_lse buffer that is now passed to mla_decode_stage1_asm_fwd. This makes the return value inconsistent with the main return path (return logits, final_lse) and prevents callers from getting the requested final LSE for these cases. Consider returning final_lse when return_lse is set (and optionally reshaping/squeezing it to the documented layout), otherwise return attn_lse as before.
Suggested change:

```diff
-        return logits.view(total_s, nhead, v_head_dim), attn_lse
+        lse = final_lse if return_lse else attn_lse
+        return logits.view(total_s, nhead, v_head_dim), lse
```
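For context on why returning the right LSE matters: the final LSE is the per-row log-sum-exp of the attention logits, and a split-KV reduction recovers the global value by taking a log-sum-exp over the per-split partials. A minimal numerically stable sketch in pure Python, independent of the PR's actual code:

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x))) over a list of floats."""
    m = max(xs)
    if m == float("-inf"):  # fully masked row: every logit is -inf
        return float("-inf")
    return m + math.log(sum(math.exp(x - m) for x in xs))


def combine_split_lse(partial_lses):
    # Merging per-kv-split LSEs is itself a log-sum-exp over the partials,
    # since exp(lse_i) is the per-split softmax denominator.
    return logsumexp(partial_lses)
```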
and add a gqa_ratio=64, qseqlen=1 dispatch path.

Test commands:

```shell
python3 op_tests/test_mla.py -b 256 -c 10000 -n 64,1 -d bf16 -kvd bf16 -lse
python3 op_tests/test_mla.py -d bf16 -kvd bf16 -n 64,1 -c 16384 -b 1 64 128
```

Performance:

