Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions aiter/mla.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,12 @@ def mla_decode_fwd(
and kv_buffer.dtype == dtypes.fp8
and max_seqlen_q == 2
)
or (
get_gfx() == "gfx950"
and nhead * max_seqlen_q % 128 == 0
and q.dtype == dtypes.bf16
and kv_buffer.dtype == dtypes.bf16
)
or (
get_gfx() == "gfx950"
and nhead == 8
Expand Down
6 changes: 6 additions & 0 deletions aiter/ops/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -941,6 +941,12 @@ def get_mla_metadata_info_v1(
and num_head_qo == 128
and kv_dtype == dtypes.fp8
and q_dtype == dtypes.fp8
)
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is trailing whitespace after the closing parenthesis on this line. This will cause black to reformat the file and fail the formatting check in CI; please remove the trailing spaces (or run black on the file).

Suggested change
)
)

Copilot uses AI. Check for mistakes.
or (
get_gfx() == "gfx950"
and (num_head_qo * max_seqlen_qo) % 128 == 0
and kv_dtype == dtypes.bf16
and q_dtype == dtypes.bf16
)
or (
get_gfx() == "gfx950"
Expand Down
3 changes: 2 additions & 1 deletion aiter/ops/flydsl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
"so its version cannot be validated."
) from exc

if installed_flydsl_version != _REQUIRED_FLYDSL_VERSION:
_base_version = installed_flydsl_version.split("+")[0]
if _base_version != _REQUIRED_FLYDSL_VERSION:
raise ImportError(
"Unsupported `flydsl` version: "
f"expected `{_REQUIRED_FLYDSL_VERSION}`, "
Expand Down
6 changes: 4 additions & 2 deletions csrc/kernels/mla/metadata/v1_2_device.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,8 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
((arch_id == "gfx950") && (num_heads == 32) && q_is_fp8 && kv_is_fp8 &&
(max_seqlen_qo == 4)) ||
((arch_id == "gfx950") && (num_heads == 64) && q_is_fp8 && kv_is_fp8 &&
(max_seqlen_qo == 1)) ||
(max_seqlen_qo == 1)) ||
((arch_id == "gfx950") && ((num_heads * max_seqlen_qo) % 128 == 0) && !q_is_fp8 && !kv_is_fp8) ||
((arch_id == "gfx942") && (num_heads == 128) && q_is_fp8 && kv_is_fp8);

const bool use_qseqlen_fold =
Expand Down Expand Up @@ -493,7 +494,8 @@ void get_mla_metadata_v1_2_device(const torch::Tensor& seqlens_qo_indptr, // [ba
TORCH_CHECK((num_heads == 16) || (num_heads == 128) ||
((num_heads == 32) && q_is_fp8 && kv_is_fp8) ||
((num_heads == 64) && q_is_fp8 && kv_is_fp8 && (max_seqlen_qo == 1)) ||
((num_heads == 8) && (max_seqlen_qo == 4) && q_is_fp8 && kv_is_fp8),
((num_heads == 8) && (max_seqlen_qo == 4) && q_is_fp8 && kv_is_fp8) ||
((arch_id == "gfx950") && ((num_heads * max_seqlen_qo) % 128 == 0) && !q_is_fp8 && !kv_is_fp8),
__func__,
": only supports #heads in [16, 64, 128], or (#head, uni_seqlen_qo) = (16*N, 1) where "
"N is in [2, 8), or (#head, max_seqlen_qo) = (8, 4) where q and kv are fp8.")
Comment on lines 500 to 501
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TORCH_CHECK error text is now outdated: on gfx950 this function also accepts cases where neither q nor kv is fp8 (e.g. bf16/bf16) and (num_heads * max_seqlen_qo) % 128 == 0, but the message still says it "only supports #heads in [16, 64, 128]" and otherwise lists only fp8-specific special cases. Please update the message to include the new gfx950 non-fp8 rule so that a failing check tells the caller which configurations are actually accepted.

Suggested change
": only supports #heads in [16, 64, 128], or (#head, uni_seqlen_qo) = (16*N, 1) where "
"N is in [2, 8), or (#head, max_seqlen_qo) = (8, 4) where q and kv are fp8.")
": only supports #heads in [16, 128], or (#head, uni_seqlen_qo) = (16*N, 1) where "
"N is in [2, 8) and q and kv are fp8, or (#head, max_seqlen_qo) = (8, 4) where q and kv "
"are fp8, or on gfx950 with non-fp8 q/kv when (#head * max_seqlen_qo) % 128 == 0.")

Copilot uses AI. Check for mistakes.
Expand Down
14 changes: 10 additions & 4 deletions csrc/py_itfs_cu/asm_mla.cu
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ void mla_decode_stage1_asm_fwd(
args.scalar = softmax_scale;
args.s_MQA = gqa_ratio * max_seqlen_q;
args.s_kv_split = kv_split;
args.s_Q_Bs = stride_Q;
args.s_Q_Bs = stride_Q;
args.s_Bs = stride_Page;
args.s_log2_plen = log2_page;
args.out_16_nosplit = kv_split;
Expand Down Expand Up @@ -260,12 +260,13 @@ void mla_decode_stage1_asm_fwd(
int prefill = 0; // decode stage
int causal = 0;
int config_max_seqlen_q = max_seqlen_q;
int config_gqa_ratio = gqa_ratio;
int sub_Q = 128; // default value

if(gqa_ratio == 128){
config_max_seqlen_q = 0;
sub_Q = 128;
if (q_type == "bf16" && kv_type == "bf16"){
if (q_type == "bf16" && kv_type == "bf16" && arch_id == "gfx942"){
ps = 0; // not use ps
}
}
Expand Down Expand Up @@ -338,9 +339,14 @@ void mla_decode_stage1_asm_fwd(
}
}

if (arch_id == "gfx950" && q_type == "bf16" && kv_type == "bf16" && persistent && (gqa_ratio* max_seqlen_q % 128 == 0)){
config_max_seqlen_q = 4;
config_gqa_ratio = 32;
args.s_Q_Bs = gqa_ratio;
Copy link

Copilot AI Apr 14, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the gfx950 bf16/bf16 persistent fast-path, args.s_Q_Bs is overwritten with gqa_ratio. Everywhere else in this function s_Q_Bs is set to stride_Q (a byte stride), so changing it to a head ratio is very likely to break kernel addressing for this path. If the kernel needs the runtime gqa_ratio, pass it via a dedicated arg (or s_MQA/another existing scalar if that’s what the kernel expects) and keep s_Q_Bs as the Q batch stride.

Suggested change
args.s_Q_Bs = gqa_ratio;

Copilot uses AI. Check for mistakes.
}
int lse_flag = (lse != nullptr) ? 1 : 0;
std::string kernelName = get_heuristic_kernel_mla(q_type, kv_type, gqa_ratio, ps, prefill, causal, config_max_seqlen_q, arch_id, config_map, lse_flag);
std::string kernelName = get_heuristic_kernel_mla(q_type, kv_type, config_gqa_ratio, ps, prefill, causal, config_max_seqlen_q, arch_id, config_map, lse_flag);

AITER_CHECK(!kernelName.empty(), __func__, ": cannot find suitable kernel");

AiterAsmKernel* impl_ptr = nullptr;
Expand Down
Binary file not shown.
3 changes: 2 additions & 1 deletion hsa/gfx950/mla/mla_asm.csv
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ fp8,fp8,1,1,0,1,0,0,_ZN5aiter40mla_pfl_qh192_vh128_m32x8_n128x1_causal0E,mla_pfl
bf16,bf16,32,0,0,0,0,0,_ZN5aiter39mla_a16w16_qh16_m32x1_n16x1_coex0_mask1E,MLA_A16W16_1TG_4W_32mx1_16nx1_Coex0_Msk1_QH16.co
bf16,bf16,64,0,0,0,0,0,_ZN5aiter39mla_a16w16_qh16_m64x1_n16x1_coex0_mask1E,MLA_A16W16_1TG_4W_64mx1_16nx1_Coex0_Msk1_QH16.co
fp8,fp8,8,1,4,0,0,0,_ZN5aiter35mla_a8w8_qh32_qseqlen4_gqaratio8_psE,mla_a8w8_qh32_qseqlen4_gqaratio8_ps.co
fp8,fp8,8,1,4,0,0,1,_ZN5aiter39mla_a8w8_qh32_qseqlen4_gqaratio8_lse_psE,mla_a8w8_qh32_qseqlen4_gqaratio8_lse_ps.co
fp8,fp8,8,1,4,0,0,1,_ZN5aiter39mla_a8w8_qh32_qseqlen4_gqaratio8_lse_psE,mla_a8w8_qh32_qseqlen4_gqaratio8_lse_ps.co
bf16,bf16,32,1,4,0,0,0,_ZN5aiter38mla_a16w16_qh32_qseqlen4_gqaratio32_psE,mla_a16w16_qh32_qseqlen4_gqaratio32_ps.co
9 changes: 7 additions & 2 deletions op_tests/test_mla_persistent.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def line_for(name: str, pick) -> str:
def check_support(dtype, kv_dtype, nhead):
if dtype == dtypes.fp8 and kv_dtype == dtypes.bf16:
return False
if dtype == dtypes.bf16 and nhead == 32:
if dtype == dtypes.bf16 and nhead == 32 and get_gfx() == "gfx942":
return False
return True

Expand Down Expand Up @@ -481,6 +481,12 @@ def torch_mla_extend_split_kv(
and is_fp8_kvc
and max_seqlen_q == 1
)
or (
get_gfx() == "gfx950"
and (nheads * max_seqlen_q) % 128 == 0
and not is_fp8_q
and not is_fp8_kvc
)
):
# Natively support cases
pass
Expand Down Expand Up @@ -1319,7 +1325,6 @@ def test_absorb_decode_bf16():
out_asm,
msg=f"mla_decode-absorb [golden vs aiter_asm]: {us_asm_decode:>8.2f} us......",
)

if not non_persistent_mode:
partial_out_ref, partial_lse_ref, split_out_ref, split_lse_ref = (
torch_mla_split_kv_and_reduce(
Expand Down
Loading