Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions aiter/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,11 @@ def mla_decode_fwd(
attn_lse = torch.empty(
(total_s, num_kv_splits, nhead, 1), dtype=dtypes.fp32, device=device
)
final_lse = torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device)
final_lse = (
torch.empty((total_s, nhead), dtype=dtypes.fp32, device=device)
if return_lse
else None
)

aiter.mla_decode_stage1_asm_fwd(
q,
Expand All @@ -252,19 +256,14 @@ def mla_decode_fwd(
logits,
attn_lse,
o,
None,
final_lse,
q_scale,
kv_scale,
)

if num_kv_splits == 1 and (
q.dtype == dtypes.fp8
or (q.dtype == dtypes.bf16 and max_seqlen_q == 4)
or (
q.dtype == dtypes.bf16
and kv_buffer.dtype == dtypes.bf16
and nhead in [32, 64]
)
):
return logits.view(total_s, nhead, v_head_dim), attn_lse
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the num_kv_splits == 1 early-return path, the function returns attn_lse even when return_lse=True, and it ignores the final_lse buffer that is now passed to mla_decode_stage1_asm_fwd. This makes the return value inconsistent with the main return path (return logits, final_lse) and prevents callers from getting the requested final LSE for these cases. Consider returning final_lse when return_lse is set (and optionally reshaping/squeezing it to the documented layout), otherwise return attn_lse as before.

Suggested change
return logits.view(total_s, nhead, v_head_dim), attn_lse
lse = final_lse if return_lse else attn_lse
return logits.view(total_s, nhead, v_head_dim), lse

Copilot uses AI. Check for mistakes.

Expand Down
25 changes: 16 additions & 9 deletions csrc/py_itfs_cu/asm_mla.cu
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ void mla_decode_stage1_asm_fwd(
int stride_Page = KV->stride(0) * KV->element_size();
uint32_t log2_page = (uint32_t)log2f(page_size);

KernelArgs args;
KernelArgs args = {};
size_t arg_size = sizeof(args);
args.ptr_R = splitData->data_ptr();
args.ptr_LSE = splitLse->data_ptr();
Expand All @@ -149,10 +149,17 @@ void mla_decode_stage1_asm_fwd(
args.s_Q_Bs = stride_Q;
args.s_Bs = stride_Page;
args.s_log2_plen = log2_page;
args.out_16_nosplit = kv_split;
args.ptr_LSEP = nullptr;
if (lse != nullptr)
{
args.ptr_LSEP = lse->data_ptr();
}

if (persistent)
{
args.out_16_nosplit = kv_split;
args.ptr_RP = output->data_ptr();

if (work_meta_data != nullptr)
{
args.ptr_STP = work_meta_data->data_ptr();
Expand All @@ -178,14 +185,10 @@ void mla_decode_stage1_asm_fwd(
}
else
{
args.out_16_nosplit = 0;
args.ptr_RP = nullptr;
args.ptr_STP = num_kv_splits_indptr->data_ptr();
}
args.ptr_RP = output->data_ptr(); //final output
args.ptr_LSEP = nullptr;
if (lse != nullptr)
{
args.ptr_LSEP = lse->data_ptr(); //final lse
}

// std::cout << "mla args" << std::endl;
// std::cout << "ptr_R: " << args.ptr_R << std::endl;
Expand Down Expand Up @@ -325,7 +328,11 @@ void mla_decode_stage1_asm_fwd(
} else if (gqa_ratio == 64){
if (q_type == "bf16" && kv_type == "bf16"){
if(!persistent){
config_max_seqlen_q = 0;
if(max_seqlen_q == 1){
config_max_seqlen_q = 1;
} else {
config_max_seqlen_q = 0;
}
sub_Q = 64;
}
} else if (q_type == "fp8" && kv_type == "fp8"){
Expand Down
Binary file not shown.
Binary file not shown.
2 changes: 2 additions & 0 deletions hsa/gfx950/mla/mla_asm.csv
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,7 @@ fp8,fp8,1,1,0,1,1,0,_ZN5aiter40mla_pfl_qh192_vh128_m32x8_n128x1_causal1E,mla_pfl
fp8,fp8,1,1,0,1,0,0,_ZN5aiter40mla_pfl_qh192_vh128_m32x8_n128x1_causal0E,mla_pfl_qh192_vh128_m32x8_n128x1_causal0.co
bf16,bf16,32,0,0,0,0,0,_ZN5aiter39mla_a16w16_qh16_m32x1_n16x1_coex0_mask1E,MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16.co
bf16,bf16,64,0,0,0,0,0,_ZN5aiter39mla_a16w16_qh16_m64x1_n16x1_coex0_mask1E,MLA_A16W16_1TG_4W_64mx1_16nx1_Coex0_Msk1_QH16.co
bf16,bf16,64,0,1,0,0,0,_ZN5aiter38mla_a16w16_qh64_qseqlen1_gqaratio64_v3E,mla_a16w16_qh64_qseqlen1_gqaratio64_v3.co
bf16,bf16,64,0,1,0,0,1,_ZN5aiter42mla_a16w16_qh64_qseqlen1_gqaratio64_lse_v3E,mla_a16w16_qh64_qseqlen1_gqaratio64_lse_v3.co
fp8,fp8,8,1,4,0,0,0,_ZN5aiter35mla_a8w8_qh32_qseqlen4_gqaratio8_psE,mla_a8w8_qh32_qseqlen4_gqaratio8_ps.co
fp8,fp8,8,1,4,0,0,1,_ZN5aiter39mla_a8w8_qh32_qseqlen4_gqaratio8_lse_psE,mla_a8w8_qh32_qseqlen4_gqaratio8_lse_ps.co
40 changes: 27 additions & 13 deletions op_tests/test_mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,10 +56,11 @@ def ref_masked_attention(
attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
attn_bias.to(query.dtype)
attn_weights += attn_bias
lse = attn_weights.logsumexp(dim=-1)
attn_weights = torch.softmax(attn_weights, dim=-1)

out = torch.einsum("hqk,khd->qhd", attn_weights.float(), value.float())
return out.to(dtype)
return out.to(dtype), lse
Comment on lines 42 to +63
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ref_masked_attention now returns a tuple (out, lse) but the return type annotation is still -> torch.Tensor. This makes the function signature misleading and can break static type checking / IDE inference. Update the annotation to reflect the tuple return (or drop the annotation).

Copilot uses AI. Check for mistakes.


def torch_mha_extend(
Expand All @@ -82,7 +83,7 @@ def torch_mha_extend(
q = qs[i]
k = ks[i]
v = vs[i]
o = ref_masked_attention(q, k, v, sm_scale, dtype)
o, _ = ref_masked_attention(q, k, v, sm_scale, dtype)
os.append(o)
o = torch.concat(os)
return o
Expand All @@ -106,15 +107,18 @@ def torch_mla_extend(
bs = qo_indptr.shape[0] - 1

os = []
lses = []
for i in range(bs):
kvc = kvs[i]
q = qs[i]
k = kvc
v, _ = torch.split(kvc, [kv_lora_rank, qk_rope_head_dim], dim=-1)
o = ref_masked_attention(q, k, v, sm_scale, dtype, is_causal=is_causal)
o, lse = ref_masked_attention(q, k, v, sm_scale, dtype, is_causal=is_causal)
os.append(o)
lses.append(lse)
o = torch.concat(os)
return o
lse = torch.concat(lses, dim=1).transpose(0, 1)
return o, lse


@benchmark()
Expand All @@ -132,6 +136,7 @@ def test_mla(
varlen,
decode_qlen,
split_per_batch=None,
return_lse=False,
):
ret = {}

Expand Down Expand Up @@ -236,7 +241,7 @@ def test_normal_prefill():
def test_absorb_prefill():
q = torch.randn((total_qo, nhead, qk_head_dim), dtype=torch.bfloat16)

out_ref = torch_mla_extend(
out_ref, _ = torch_mla_extend(
q,
kv_buffer,
qo_indptr,
Expand Down Expand Up @@ -326,7 +331,7 @@ def test_absorb_prefill():
q = torch.randn((total_q, nhead, qk_head_dim), dtype=torch.bfloat16)

    # torch implementation
out_ref = torch_mla_extend(
out_ref, lse_ref = torch_mla_extend(
q,
kv_buffer,
qo_indptr,
Expand Down Expand Up @@ -390,19 +395,20 @@ def test_absorb_decode_bf16():
nhead_kv,
sm_scale,
num_kv_splits=split_per_batch,
return_lse=return_lse,
)

# print(f"{out_ref.view(total_q, -1)=}")
# print(f"{out_asm.view(total_q, -1)=}")
# checkAllclose(logits_ref, attn_logits,
# msg=f'attn_logits [golden vs aiter_asm]')
# checkAllclose(lse_ref, attn_lse,
# msg=f'attn_lse [golden vs aiter_asm]')
err = checkAllclose(
out_ref,
out_asm,
msg=f"mla_decode-absorb [golden vs aiter_asm]: {us_asm_decode:>8.2f} us......",
)
if return_lse and attn_lse is not None:
checkAllclose(
lse_ref,
attn_lse.reshape(total_q, nhead),
msg=f"mla_decode-absorb [lse_ref vs attn_lse]: {us_asm_decode:>8.2f} us......",
)
return err, us_asm_decode

def test_absorb_decode_fp8():
Expand Down Expand Up @@ -573,7 +579,7 @@ def test_absorb_decode_fp8():
"-n",
"--nhead",
type=dtypes.str2tuple,
choices=[(16, 1), (16, 2), (16, 4), (128, 1), (128, 2), (128, 4)],
choices=[(16, 1), (16, 2), (16, 4), (64, 1), (128, 1), (128, 2), (128, 4)],
nargs="*",
const=None,
default=[(16, 1), (16, 2), (16, 4), (128, 1), (128, 2)],
Expand All @@ -595,6 +601,13 @@ def test_absorb_decode_fp8():
help="""variable kv seqlens per batch. Default: False.
--varlen # True""",
)
parser.add_argument(
"-lse",
"--return_lse",
action="store_true",
help="""return lse. Default: False.
--lse # True""",
)


args = parser.parse_args()
Expand All @@ -619,6 +632,7 @@ def test_absorb_decode_fp8():
varlen=args.varlen,
decode_qlen=decode_qlen,
split_per_batch=split_per_batch,
return_lse=args.return_lse,
)
df.append(ret)
df = pd.DataFrame(df)
Expand Down
Loading