-
Notifications
You must be signed in to change notification settings - Fork 281
Add bf16 MLA decode kernel for gqa_ratio=64, qseqlen=1 (non-persistent) #2729
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
b32667f
9819d09
32a0542
dd36164
c64f2cb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||
|---|---|---|---|---|---|---|---|---|
|
|
@@ -21,6 +21,7 @@ def _fwd_kernel_stage2_asm( | |||||||
| Mid_O, | ||||||||
| Mid_lse, | ||||||||
| O, | ||||||||
| Final_lse, | ||||||||
| qo_indptr, | ||||||||
| kv_indptr, | ||||||||
| num_kv_splits_indptr, | ||||||||
|
|
@@ -29,7 +30,9 @@ def _fwd_kernel_stage2_asm( | |||||||
| stride_mid_os: tl.int64, | ||||||||
| stride_obs: tl.int64, | ||||||||
| stride_oh: tl.int64, | ||||||||
| stride_lse_bs: tl.int64, | ||||||||
| MAYBE_FINAL_OUT: tl.constexpr, | ||||||||
| HAS_FINAL_LSE: tl.constexpr, | ||||||||
| BATCH_NUM: tl.constexpr, | ||||||||
| BLOCK_DV: tl.constexpr, | ||||||||
| Lv: tl.constexpr, | ||||||||
|
|
@@ -57,7 +60,6 @@ def _fwd_kernel_stage2_asm( | |||||||
| if FINAL_OUT: | ||||||||
| input_ptr = Mid_O.to(tl.pointer_type(O.type.element_ty)) | ||||||||
| out = tl.load( | ||||||||
| # input_ptr + offs_v + stride_mid_ob * Lv, | ||||||||
| input_ptr | ||||||||
| + Lv * (cur_qo * stride_mid_os + cur_head * stride_mid_oh) | ||||||||
| + offs_d, | ||||||||
|
|
@@ -96,6 +98,11 @@ def _fwd_kernel_stage2_asm( | |||||||
| acc / e_sum, | ||||||||
| mask=mask_d, | ||||||||
| ) | ||||||||
| if HAS_FINAL_LSE: | ||||||||
| tl.store( | ||||||||
| Final_lse + cur_qo * stride_lse_bs + cur_head, | ||||||||
| e_max + tl.log(e_sum), | ||||||||
| ) | ||||||||
|
|
||||||||
|
|
||||||||
| @functools.lru_cache() | ||||||||
|
|
@@ -205,6 +212,12 @@ def mla_decode_fwd( | |||||||
| if ( | ||||||||
| nhead == 128 and q.dtype == dtypes.fp8 and kv_buffer.dtype == dtypes.fp8 | ||||||||
| ) | ||||||||
| or ( | ||||||||
| nhead == 64 | ||||||||
| and q.dtype == dtypes.bf16 | ||||||||
| and kv_buffer.dtype == dtypes.bf16 | ||||||||
| and max_seqlen_q == 1 | ||||||||
| ) | ||||||||
| else mgc | ||||||||
| ) | ||||||||
|
|
||||||||
|
|
@@ -232,7 +245,11 @@ def mla_decode_fwd( | |||||||
| attn_lse = torch.empty( | ||||||||
| (total_s, num_kv_splits, nhead, 1), dtype=dtypes.fp32, device=device | ||||||||
| ) | ||||||||
| final_lse = torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device) | ||||||||
| final_lse = ( | ||||||||
| torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device) | ||||||||
| if return_lse | ||||||||
| else None | ||||||||
| ) | ||||||||
|
|
||||||||
| aiter.mla_decode_stage1_asm_fwd( | ||||||||
| q, | ||||||||
|
|
@@ -252,19 +269,13 @@ def mla_decode_fwd( | |||||||
| logits, | ||||||||
| attn_lse, | ||||||||
| o, | ||||||||
| None, | ||||||||
| final_lse, | ||||||||
| q_scale, | ||||||||
| kv_scale, | ||||||||
| ) | ||||||||
|
|
||||||||
| if num_kv_splits == 1 and ( | ||||||||
| q.dtype == dtypes.fp8 | ||||||||
| or (q.dtype == dtypes.bf16 and max_seqlen_q == 4) | ||||||||
| or ( | ||||||||
| q.dtype == dtypes.bf16 | ||||||||
| and kv_buffer.dtype == dtypes.bf16 | ||||||||
| and nhead in [32, 64] | ||||||||
| ) | ||||||||
| q.dtype == dtypes.fp8 or (q.dtype == dtypes.bf16 and max_seqlen_q == 4) | ||||||||
| ): | ||||||||
| return logits.view(total_s, nhead, v_head_dim), attn_lse | ||||||||
|
||||||||
| return logits.view(total_s, nhead, v_head_dim), attn_lse | |
| lse = final_lse if return_lse else attn_lse | |
| return logits.view(total_s, nhead, v_head_dim), lse |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -56,10 +56,11 @@ def ref_masked_attention( | |
| attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) | ||
| attn_bias.to(query.dtype) | ||
| attn_weights += attn_bias | ||
| lse = attn_weights.logsumexp(dim=-1) | ||
| attn_weights = torch.softmax(attn_weights, dim=-1) | ||
|
|
||
| out = torch.einsum("hqk,khd->qhd", attn_weights.float(), value.float()) | ||
| return out.to(dtype) | ||
| return out.to(dtype), lse | ||
|
Comment on lines
42
to
+63
|
||
|
|
||
|
|
||
| def torch_mha_extend( | ||
|
|
@@ -82,7 +83,7 @@ def torch_mha_extend( | |
| q = qs[i] | ||
| k = ks[i] | ||
| v = vs[i] | ||
| o = ref_masked_attention(q, k, v, sm_scale, dtype) | ||
| o, _ = ref_masked_attention(q, k, v, sm_scale, dtype) | ||
| os.append(o) | ||
| o = torch.concat(os) | ||
| return o | ||
|
|
@@ -106,15 +107,18 @@ def torch_mla_extend( | |
| bs = qo_indptr.shape[0] - 1 | ||
|
|
||
| os = [] | ||
| lses = [] | ||
| for i in range(bs): | ||
| kvc = kvs[i] | ||
| q = qs[i] | ||
| k = kvc | ||
| v, _ = torch.split(kvc, [kv_lora_rank, qk_rope_head_dim], dim=-1) | ||
| o = ref_masked_attention(q, k, v, sm_scale, dtype, is_causal=is_causal) | ||
| o, lse = ref_masked_attention(q, k, v, sm_scale, dtype, is_causal=is_causal) | ||
| os.append(o) | ||
| lses.append(lse) | ||
| o = torch.concat(os) | ||
| return o | ||
| lse = torch.concat(lses, dim=1).transpose(0, 1) | ||
| return o, lse | ||
|
|
||
|
|
||
| @benchmark() | ||
|
|
@@ -132,6 +136,7 @@ def test_mla( | |
| varlen, | ||
| decode_qlen, | ||
| split_per_batch=None, | ||
| return_lse=False, | ||
| ): | ||
| ret = {} | ||
|
|
||
|
|
@@ -236,7 +241,7 @@ def test_normal_prefill(): | |
| def test_absorb_prefill(): | ||
| q = torch.randn((total_qo, nhead, qk_head_dim), dtype=torch.bfloat16) | ||
|
|
||
| out_ref = torch_mla_extend( | ||
| out_ref, _ = torch_mla_extend( | ||
| q, | ||
| kv_buffer, | ||
| qo_indptr, | ||
|
|
@@ -326,7 +331,7 @@ def test_absorb_prefill(): | |
| q = torch.randn((total_q, nhead, qk_head_dim), dtype=torch.bfloat16) | ||
|
|
||
| # troch implementation | ||
| out_ref = torch_mla_extend( | ||
| out_ref, lse_ref = torch_mla_extend( | ||
| q, | ||
| kv_buffer, | ||
| qo_indptr, | ||
|
|
@@ -390,19 +395,20 @@ def test_absorb_decode_bf16(): | |
| nhead_kv, | ||
| sm_scale, | ||
| num_kv_splits=split_per_batch, | ||
| return_lse=return_lse, | ||
| ) | ||
|
|
||
| # print(f"{out_ref.view(total_q, -1)=}") | ||
| # print(f"{out_asm.view(total_q, -1)=}") | ||
| # checkAllclose(logits_ref, attn_logits, | ||
| # msg=f'attn_logits [golden vs aiter_asm]') | ||
| # checkAllclose(lse_ref, attn_lse, | ||
| # msg=f'attn_lse [golden vs aiter_asm]') | ||
| err = checkAllclose( | ||
| out_ref, | ||
| out_asm, | ||
| msg=f"mla_decode-absorb [golden vs aiter_asm]: {us_asm_decode:>8.2f} us......", | ||
| ) | ||
| if return_lse and attn_lse is not None: | ||
| checkAllclose( | ||
| lse_ref, | ||
| attn_lse.reshape(total_q, nhead), | ||
| msg=f"mla_decode-absorb [lse_ref vs attn_lse]: {us_asm_decode:>8.2f} us......", | ||
| ) | ||
| return err, us_asm_decode | ||
|
|
||
| def test_absorb_decode_fp8(): | ||
|
|
@@ -573,7 +579,7 @@ def test_absorb_decode_fp8(): | |
| "-n", | ||
| "--nhead", | ||
| type=dtypes.str2tuple, | ||
| choices=[(16, 1), (16, 2), (16, 4), (128, 1), (128, 2), (128, 4)], | ||
| choices=[(16, 1), (16, 2), (16, 4), (64, 1), (128, 1), (128, 2), (128, 4)], | ||
| nargs="*", | ||
| const=None, | ||
| default=[(16, 1), (16, 2), (16, 4), (128, 1), (128, 2)], | ||
|
|
@@ -595,6 +601,13 @@ def test_absorb_decode_fp8(): | |
| help="""variable kv seqlens per batch. Default: False. | ||
| --varlen # True""", | ||
| ) | ||
| parser.add_argument( | ||
| "-lse", | ||
| "--return_lse", | ||
| action="store_true", | ||
| help="""return lse. Default: False. | ||
| --lse # True""", | ||
| ) | ||
|
|
||
|
|
||
| args = parser.parse_args() | ||
|
|
@@ -619,6 +632,7 @@ def test_absorb_decode_fp8(): | |
| varlen=args.varlen, | ||
| decode_qlen=decode_qlen, | ||
| split_per_batch=split_per_batch, | ||
| return_lse=args.return_lse, | ||
| ) | ||
| df.append(ret) | ||
| df = pd.DataFrame(df) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ambiguous variable name:
O