diff --git a/aiter/mla.py b/aiter/mla.py index cef8181a7a..5acbdcf3ba 100644 --- a/aiter/mla.py +++ b/aiter/mla.py @@ -20,7 +20,8 @@ def _fwd_kernel_stage2_asm( Mid_O, Mid_lse, - O, + O, # noqa: E741 + Final_lse, qo_indptr, kv_indptr, num_kv_splits_indptr, @@ -29,7 +30,9 @@ def _fwd_kernel_stage2_asm( stride_mid_os: tl.int64, stride_obs: tl.int64, stride_oh: tl.int64, + stride_lse_bs: tl.int64, MAYBE_FINAL_OUT: tl.constexpr, + HAS_FINAL_LSE: tl.constexpr, BATCH_NUM: tl.constexpr, BLOCK_DV: tl.constexpr, Lv: tl.constexpr, @@ -57,7 +60,6 @@ def _fwd_kernel_stage2_asm( if FINAL_OUT: input_ptr = Mid_O.to(tl.pointer_type(O.type.element_ty)) out = tl.load( - # input_ptr + offs_v + stride_mid_ob * Lv, input_ptr + Lv * (cur_qo * stride_mid_os + cur_head * stride_mid_oh) + offs_d, @@ -96,6 +98,11 @@ def _fwd_kernel_stage2_asm( acc / e_sum, mask=mask_d, ) + if HAS_FINAL_LSE: + tl.store( + Final_lse + cur_qo * stride_lse_bs + cur_head, + e_max + tl.log(e_sum), + ) @functools.lru_cache() @@ -205,6 +212,12 @@ def mla_decode_fwd( if ( nhead == 128 and q.dtype == dtypes.fp8 and kv_buffer.dtype == dtypes.fp8 ) + or ( + nhead == 64 + and q.dtype == dtypes.bf16 + and kv_buffer.dtype == dtypes.bf16 + and max_seqlen_q == 1 + ) else mgc ) @@ -232,7 +245,11 @@ def mla_decode_fwd( attn_lse = torch.empty( (total_s, num_kv_splits, nhead, 1), dtype=dtypes.fp32, device=device ) - final_lse = torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device) + final_lse = ( + torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device) + if return_lse + else None + ) aiter.mla_decode_stage1_asm_fwd( q, @@ -252,31 +269,34 @@ def mla_decode_fwd( logits, attn_lse, o, - None, + final_lse, q_scale, kv_scale, ) if num_kv_splits == 1 and ( - q.dtype == dtypes.fp8 - or (q.dtype == dtypes.bf16 and max_seqlen_q == 4) - or ( - q.dtype == dtypes.bf16 - and kv_buffer.dtype == dtypes.bf16 - and nhead in [32, 64] - ) + q.dtype == dtypes.fp8 or (q.dtype == dtypes.bf16 and max_seqlen_q == 4) ): - return logits.view(total_s, nhead, v_head_dim), attn_lse + lse = final_lse if return_lse else attn_lse + return logits.view(total_s, nhead, v_head_dim), lse Lv = v_head_dim BLOCK_DV = triton.next_power_of_2(Lv) grid = (bs, nhead) extra_kargs = {"waves_per_eu": 4} + has_final_lse = final_lse is not None + final_lse_buf = ( + final_lse + if has_final_lse + else torch.empty((1,), dtype=dtypes.fp32, device=device) + ) + _fwd_kernel_stage2_asm[grid]( logits, attn_lse, o, + final_lse_buf, qo_indptr, kv_indptr, num_kv_splits_indptr, @@ -285,7 +305,9 @@ def mla_decode_fwd( attn_lse.stride(1), o.stride(0), o.stride(1), + final_lse_buf.stride(0) if has_final_lse else 0, MAYBE_FINAL_OUT=MAYBE_FINAL_OUT, + HAS_FINAL_LSE=has_final_lse, BATCH_NUM=bs, BLOCK_DV=BLOCK_DV, Lv=Lv, @@ -512,11 +534,10 @@ def mla_prefill_fwd( num_kv_splits=None, # for experts only!!! ): device = q.device + num_page, page_size, nhead_kv, qk_head_dim = kv_buffer.shape assert logit_cap <= 0, f"{logit_cap=} is not support yet" if sm_scale is None: sm_scale = 1.0 / (qk_head_dim**0.5) - - num_page, page_size, nhead_kv, qk_head_dim = kv_buffer.shape bs, nhead, v_head_dim = o.shape num_kv_splits = 1 diff --git a/csrc/py_itfs_cu/asm_mla.cu b/csrc/py_itfs_cu/asm_mla.cu index cdc415fc6c..60abff1c03 100644 --- a/csrc/py_itfs_cu/asm_mla.cu +++ b/csrc/py_itfs_cu/asm_mla.cu @@ -133,7 +133,7 @@ void mla_decode_stage1_asm_fwd( int stride_Page = KV->stride(0) * KV->element_size(); uint32_t log2_page = (uint32_t)log2f(page_size); - KernelArgs args; + KernelArgs args = {}; size_t arg_size = sizeof(args); args.ptr_R = splitData->data_ptr(); args.ptr_LSE = splitLse->data_ptr(); @@ -149,10 +149,17 @@ void mla_decode_stage1_asm_fwd( args.s_Q_Bs = stride_Q; args.s_Bs = stride_Page; args.s_log2_plen = log2_page; - args.out_16_nosplit = kv_split; + args.ptr_LSEP = nullptr; + if (lse != nullptr) + { + args.ptr_LSEP = lse->data_ptr(); + } if (persistent) { + args.out_16_nosplit = kv_split; + args.ptr_RP = output->data_ptr(); + if (work_meta_data != nullptr) { args.ptr_STP = work_meta_data->data_ptr(); @@ -178,14 +185,10 @@ void mla_decode_stage1_asm_fwd( } else { + args.out_16_nosplit = 0; + args.ptr_RP = nullptr; args.ptr_STP = num_kv_splits_indptr->data_ptr(); } - args.ptr_RP = output->data_ptr(); //final output - args.ptr_LSEP = nullptr; - if (lse != nullptr) - { - args.ptr_LSEP = lse->data_ptr(); //final lse - } // std::cout << "mla args" << std::endl; // std::cout << "ptr_R: " << args.ptr_R << std::endl; @@ -325,7 +328,11 @@ void mla_decode_stage1_asm_fwd( } else if (gqa_ratio == 64){ if (q_type == "bf16" && kv_type == "bf16"){ if(!persistent){ - config_max_seqlen_q = 0; + if(max_seqlen_q == 1){ + config_max_seqlen_q = 1; + } else { + config_max_seqlen_q = 0; + } sub_Q = 64; } } else if (q_type == "fp8" && kv_type == "fp8"){ diff --git a/hsa/gfx950/mla/mla_a16w16_qh64_qseqlen1_gqaratio64_lse_v3.co b/hsa/gfx950/mla/mla_a16w16_qh64_qseqlen1_gqaratio64_lse_v3.co new file mode 100755 index 0000000000..fac99f0220 Binary files /dev/null and b/hsa/gfx950/mla/mla_a16w16_qh64_qseqlen1_gqaratio64_lse_v3.co differ diff --git a/hsa/gfx950/mla/mla_a16w16_qh64_qseqlen1_gqaratio64_v3.co b/hsa/gfx950/mla/mla_a16w16_qh64_qseqlen1_gqaratio64_v3.co new file mode 100755 index 0000000000..1b99687a48 Binary files /dev/null and b/hsa/gfx950/mla/mla_a16w16_qh64_qseqlen1_gqaratio64_v3.co differ diff --git a/hsa/gfx950/mla/mla_asm.csv b/hsa/gfx950/mla/mla_asm.csv index 33559da53c..7a955177fa 100644 --- a/hsa/gfx950/mla/mla_asm.csv +++ b/hsa/gfx950/mla/mla_asm.csv @@ -26,5 +26,7 @@ fp8,fp8,1,1,0,1,1,0,_ZN5aiter40mla_pfl_qh192_vh128_m32x8_n128x1_causal1E,mla_pfl fp8,fp8,1,1,0,1,0,0,_ZN5aiter40mla_pfl_qh192_vh128_m32x8_n128x1_causal0E,mla_pfl_qh192_vh128_m32x8_n128x1_causal0.co bf16,bf16,32,0,0,0,0,0,_ZN5aiter39mla_a16w16_qh16_m32x1_n16x1_coex0_mask1E,MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16.co bf16,bf16,64,0,0,0,0,0,_ZN5aiter39mla_a16w16_qh16_m64x1_n16x1_coex0_mask1E,MLA_A16W16_1TG_4W_64mx1_16nx1_Coex0_Msk1_QH16.co +bf16,bf16,64,0,1,0,0,0,_ZN5aiter38mla_a16w16_qh64_qseqlen1_gqaratio64_v3E,mla_a16w16_qh64_qseqlen1_gqaratio64_v3.co +bf16,bf16,64,0,1,0,0,1,_ZN5aiter42mla_a16w16_qh64_qseqlen1_gqaratio64_lse_v3E,mla_a16w16_qh64_qseqlen1_gqaratio64_lse_v3.co fp8,fp8,8,1,4,0,0,0,_ZN5aiter35mla_a8w8_qh32_qseqlen4_gqaratio8_psE,mla_a8w8_qh32_qseqlen4_gqaratio8_ps.co fp8,fp8,8,1,4,0,0,1,_ZN5aiter39mla_a8w8_qh32_qseqlen4_gqaratio8_lse_psE,mla_a8w8_qh32_qseqlen4_gqaratio8_lse_ps.co \ No newline at end of file diff --git a/op_tests/test_mla.py b/op_tests/test_mla.py index e81121b36f..cfc61e648f 100644 --- a/op_tests/test_mla.py +++ b/op_tests/test_mla.py @@ -56,10 +56,11 @@ def ref_masked_attention( attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) attn_weights += attn_bias + lse = attn_weights.logsumexp(dim=-1) attn_weights = torch.softmax(attn_weights, dim=-1) out = torch.einsum("hqk,khd->qhd", attn_weights.float(), value.float()) - return out.to(dtype) + return out.to(dtype), lse def torch_mha_extend( @@ -82,7 +83,7 @@ def torch_mha_extend( q = qs[i] k = ks[i] v = vs[i] - o = ref_masked_attention(q, k, v, sm_scale, dtype) + o, _ = ref_masked_attention(q, k, v, sm_scale, dtype) os.append(o) o = torch.concat(os) return o @@ -106,15 +107,18 @@ def torch_mla_extend( bs = qo_indptr.shape[0] - 1 os = [] + lses = [] for i in range(bs): kvc = kvs[i] q = qs[i] k = kvc v, _ = torch.split(kvc, [kv_lora_rank, qk_rope_head_dim], dim=-1) - o = ref_masked_attention(q, k, v, sm_scale, dtype, is_causal=is_causal) + o, lse = ref_masked_attention(q, k, v, sm_scale, dtype, is_causal=is_causal) os.append(o) + lses.append(lse) o = torch.concat(os) - return o + lse = torch.concat(lses, dim=1).transpose(0, 1) + return o, lse @benchmark() @@ -132,6 +136,7 @@ def test_mla( varlen, decode_qlen, split_per_batch=None, + return_lse=False, ): ret = {} @@ -236,7 +241,7 @@ def test_normal_prefill(): def test_absorb_prefill(): q = torch.randn((total_qo, nhead, qk_head_dim), dtype=torch.bfloat16) - out_ref = torch_mla_extend( + out_ref, _ = torch_mla_extend( q, kv_buffer, qo_indptr, @@ -326,7 +331,7 @@ def test_absorb_prefill(): q = torch.randn((total_q, nhead, qk_head_dim), dtype=torch.bfloat16) # troch implementation - out_ref = torch_mla_extend( + out_ref, lse_ref = torch_mla_extend( q, kv_buffer, qo_indptr, @@ -390,19 +395,20 @@ def test_absorb_decode_bf16(): nhead_kv, sm_scale, num_kv_splits=split_per_batch, + return_lse=return_lse, ) - # print(f"{out_ref.view(total_q, -1)=}") - # print(f"{out_asm.view(total_q, -1)=}") - # checkAllclose(logits_ref, attn_logits, - # msg=f'attn_logits [golden vs aiter_asm]') - # checkAllclose(lse_ref, attn_lse, - # msg=f'attn_lse [golden vs aiter_asm]') err = checkAllclose( out_ref, out_asm, msg=f"mla_decode-absorb [golden vs aiter_asm]: {us_asm_decode:>8.2f} us......", ) + if return_lse and attn_lse is not None: + checkAllclose( + lse_ref, + attn_lse.reshape(total_q, nhead), + msg=f"mla_decode-absorb [lse_ref vs attn_lse]: {us_asm_decode:>8.2f} us......", + ) return err, us_asm_decode def test_absorb_decode_fp8(): @@ -573,7 +579,7 @@ def test_absorb_decode_fp8(): "-n", "--nhead", type=dtypes.str2tuple, - choices=[(16, 1), (16, 2), (16, 4), (128, 1), (128, 2), (128, 4)], + choices=[(16, 1), (16, 2), (16, 4), (64, 1), (128, 1), (128, 2), (128, 4)], nargs="*", const=None, default=[(16, 1), (16, 2), (16, 4), (128, 1), (128, 2)], @@ -595,6 +601,13 @@ def test_absorb_decode_fp8(): help="""variable kv seqlens per batch. Default: False. --varlen # True""", ) +parser.add_argument( + "-lse", + "--return_lse", + action="store_true", + help="""return lse. Default: False. + --lse # True""", +) args = parser.parse_args() @@ -619,6 +632,7 @@ def test_absorb_decode_fp8(): varlen=args.varlen, decode_qlen=decode_qlen, split_per_batch=split_per_batch, + return_lse=args.return_lse, ) df.append(ret) df = pd.DataFrame(df)