-
Notifications
You must be signed in to change notification settings - Fork 281
Add bf16 MLA decode kernel for gqa_ratio=64, qseqlen=1 (non-persistent) #2729
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 1 commit
b32667f
9819d09
32a0542
dd36164
c64f2cb
4957b5a
5e07396
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -56,10 +56,11 @@ def ref_masked_attention( | |
| attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) | ||
| attn_bias.to(query.dtype) | ||
| attn_weights += attn_bias | ||
| lse = attn_weights.logsumexp(dim=-1) | ||
| attn_weights = torch.softmax(attn_weights, dim=-1) | ||
|
|
||
| out = torch.einsum("hqk,khd->qhd", attn_weights.float(), value.float()) | ||
| return out.to(dtype) | ||
| return out.to(dtype), lse | ||
|
Comment on lines +42 to +63
|
||
|
|
||
|
|
||
| def torch_mha_extend( | ||
|
|
@@ -82,7 +83,7 @@ def torch_mha_extend( | |
| q = qs[i] | ||
| k = ks[i] | ||
| v = vs[i] | ||
| o = ref_masked_attention(q, k, v, sm_scale, dtype) | ||
| o, _ = ref_masked_attention(q, k, v, sm_scale, dtype) | ||
| os.append(o) | ||
| o = torch.concat(os) | ||
| return o | ||
|
|
@@ -106,15 +107,18 @@ def torch_mla_extend( | |
| bs = qo_indptr.shape[0] - 1 | ||
|
|
||
| os = [] | ||
| lses = [] | ||
| for i in range(bs): | ||
| kvc = kvs[i] | ||
| q = qs[i] | ||
| k = kvc | ||
| v, _ = torch.split(kvc, [kv_lora_rank, qk_rope_head_dim], dim=-1) | ||
| o = ref_masked_attention(q, k, v, sm_scale, dtype, is_causal=is_causal) | ||
| o, lse = ref_masked_attention(q, k, v, sm_scale, dtype, is_causal=is_causal) | ||
| os.append(o) | ||
| lses.append(lse) | ||
| o = torch.concat(os) | ||
| return o | ||
| lse = torch.concat(lses, dim=1).transpose(0, 1) | ||
| return o, lse | ||
|
|
||
|
|
||
| @benchmark() | ||
|
|
@@ -132,6 +136,7 @@ def test_mla( | |
| varlen, | ||
| decode_qlen, | ||
| split_per_batch=None, | ||
| return_lse=False, | ||
| ): | ||
| ret = {} | ||
|
|
||
|
|
@@ -236,7 +241,7 @@ def test_normal_prefill(): | |
| def test_absorb_prefill(): | ||
| q = torch.randn((total_qo, nhead, qk_head_dim), dtype=torch.bfloat16) | ||
|
|
||
| out_ref = torch_mla_extend( | ||
| out_ref, _ = torch_mla_extend( | ||
| q, | ||
| kv_buffer, | ||
| qo_indptr, | ||
|
|
@@ -326,7 +331,7 @@ def test_absorb_prefill(): | |
| q = torch.randn((total_q, nhead, qk_head_dim), dtype=torch.bfloat16) | ||
|
|
||
| # troch implementation | ||
| out_ref = torch_mla_extend( | ||
| out_ref, lse_ref = torch_mla_extend( | ||
| q, | ||
| kv_buffer, | ||
| qo_indptr, | ||
|
|
@@ -390,19 +395,20 @@ def test_absorb_decode_bf16(): | |
| nhead_kv, | ||
| sm_scale, | ||
| num_kv_splits=split_per_batch, | ||
| return_lse=return_lse, | ||
| ) | ||
|
|
||
| # print(f"{out_ref.view(total_q, -1)=}") | ||
| # print(f"{out_asm.view(total_q, -1)=}") | ||
| # checkAllclose(logits_ref, attn_logits, | ||
| # msg=f'attn_logits [golden vs aiter_asm]') | ||
| # checkAllclose(lse_ref, attn_lse, | ||
| # msg=f'attn_lse [golden vs aiter_asm]') | ||
| err = checkAllclose( | ||
| out_ref, | ||
| out_asm, | ||
| msg=f"mla_decode-absorb [golden vs aiter_asm]: {us_asm_decode:>8.2f} us......", | ||
| ) | ||
| if return_lse and attn_lse is not None: | ||
| checkAllclose( | ||
| lse_ref, | ||
| attn_lse.reshape(total_q, nhead), | ||
| msg=f"mla_decode-absorb [lse_ref vs attn_lse]: {us_asm_decode:>8.2f} us......", | ||
| ) | ||
| return err, us_asm_decode | ||
|
|
||
| def test_absorb_decode_fp8(): | ||
|
|
@@ -573,7 +579,7 @@ def test_absorb_decode_fp8(): | |
| "-n", | ||
| "--nhead", | ||
| type=dtypes.str2tuple, | ||
| choices=[(16, 1), (16, 2), (16, 4), (128, 1), (128, 2), (128, 4)], | ||
| choices=[(16, 1), (16, 2), (16, 4), (64, 1), (128, 1), (128, 2), (128, 4)], | ||
| nargs="*", | ||
| const=None, | ||
| default=[(16, 1), (16, 2), (16, 4), (128, 1), (128, 2)], | ||
|
|
@@ -595,6 +601,13 @@ def test_absorb_decode_fp8(): | |
| help="""variable kv seqlens per batch. Default: False. | ||
| --varlen # True""", | ||
| ) | ||
| parser.add_argument( | ||
| "-lse", | ||
| "--return_lse", | ||
| action="store_true", | ||
| help="""return lse. Default: False. | ||
| --lse # True""", | ||
| ) | ||
|
|
||
|
|
||
| args = parser.parse_args() | ||
|
|
@@ -619,6 +632,7 @@ def test_absorb_decode_fp8(): | |
| varlen=args.varlen, | ||
| decode_qlen=decode_qlen, | ||
| split_per_batch=split_per_batch, | ||
| return_lse=args.return_lse, | ||
| ) | ||
| df.append(ret) | ||
| df = pd.DataFrame(df) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
In the `num_kv_splits == 1` early-return path, the function returns `attn_lse` even when `return_lse=True`, and it ignores the `final_lse` buffer that is now passed to `mla_decode_stage1_asm_fwd`. This makes the return value inconsistent with the main return path (`return logits, final_lse`) and prevents callers from getting the requested final LSE for these cases. Consider returning `final_lse` when `return_lse` is set (and optionally reshaping/squeezing it to the documented layout), otherwise return `attn_lse` as before.