Skip to content

gather support qk_nope_head_dim != v_head_dim#2739

Open
jiayyu wants to merge 2 commits intomainfrom
support_gather
Open

gather support qk_nope_head_dim != v_head_dim#2739
jiayyu wants to merge 2 commits intomainfrom
support_gather

Conversation

@jiayyu
Copy link
Copy Markdown
Contributor

@jiayyu jiayyu commented Apr 14, 2026

Motivation

Technical Details

Test Plan

Test Result

Submission Checklist

@jiayyu jiayyu requested review from a team and Copilot April 14, 2026 10:04
@github-actions
Copy link
Copy Markdown
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

Label Tests
ci:triton-355 Run Triton tests on MI355 in addition to MI325
ci:sglang SGLang integration tests
ci:atom ATOM benchmark (DeepSeek-R1 + GPT-OSS)
ci:vllm vLLM benchmark
ci:all All of the above

Add labels via the sidebar or gh pr edit 2739 --add-label <label>

Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR extends the Triton gather_kv_b_proj path to support asymmetric head dimensions (qk_nope_head_dim != v_head_dim) by updating the Python wrapper, Triton kernel, and expanding test coverage (including bf16-weight and asymmetric-dim scenarios).

Changes:

  • Generalize gather_kv_b_proj wrapper and Triton kernel to accept separate QkNopeHeadDim and VHeadDim, including power-of-2 padding for these dims.
  • Update the reference implementation and existing tests to use (qk_nope_head_dim + v_head_dim) weight layout and v_head_dim-sized V outputs.
  • Add new test coverage for bf16-weight (no quantization scaling) and asymmetric head dims (e.g., 192/256).

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 7 comments.

File Description
aiter/ops/triton/gather_kv_b_proj.py Updates the Python wrapper to infer qk_nope_head_dim, pass VHeadDim, and provide padded dims to the kernel.
aiter/ops/triton/_triton_kernels/gather_kv_b_proj.py Updates the Triton kernel to handle separate K/V head dims and padded loads/stores (including preshuffled weight handling).
op_tests/triton_tests/test_gather_kv_b_proj.py Updates reference/test shapes and adds new tests for bf16 weight and asymmetric head dims; introduces a shared test-data helper.

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.

Comment on lines 72 to 74
KV_CDim=weight_k,
KV_PeDim=qk_nope_pe_dim - qk_nope_dim,
KV_PeDim=qk_nope_pe_dim - qk_nope_head_dim,
ChunkK=ChunkK,


def _next_pow2(n):
"""Return the smallest power of 2 >= n (Python-side helper, not a JIT function)."""
Comment on lines +27 to +31
NumNBlk: tl.constexpr = HeadDim // 16
PaddedNumNBlk: tl.constexpr = PaddedHeadDim // 16
SegKBlocks: tl.constexpr = ScaleKGranularity // 32
NumKBlkTotal: tl.constexpr = KV_CDim // 32
TotalRows: tl.constexpr = NumNBlk * SegKBlocks
PaddedTotalRows: tl.constexpr = PaddedNumNBlk * SegKBlocks
@@ -41,9 +46,7 @@ def ref_gather_kv_b_proj(
scale_granularity_k = weight_k // kv_proj_scale.shape[1]
Comment on lines +168 to +187
.cuda()
.to(torch.int32)
)
context_blocks = torch.div(
context_lens + block_size - 1, block_size, rounding_mode="trunc"
)

kv_indptr = torch.zeros((batch_size + 1,), device="cuda", dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(context_blocks, dim=0)

kv_prefix_sum_context_lens = torch.zeros(
(batch_size + 1,), device="cuda", dtype=torch.int32
)
kv_prefix_sum_context_lens[1:] = torch.cumsum(context_lens, dim=0)

kv_indices = torch.zeros(kv_indptr[-1], device="cuda", dtype=torch.int32)
for b in range(batch_size):
ctx_len = int(context_blocks[b].item())
kv_indices[kv_indptr[b] : kv_indptr[b + 1]] = torch.randperm(
num_block, device="cuda"
Comment on lines +142 to +200
def _make_kv_test_data(
batch_size,
block_size,
avg_kv_length,
kv_c_dim,
kv_pe_dim,
k_buffer_type,
device="cuda",
):
"""Create common test data: k_buffer, k_scale, kv_indptr, kv_indices, etc."""
num_block = 2 * avg_kv_length // block_size

k_buffer = torch.randn(
(num_block, block_size, kv_c_dim + kv_pe_dim),
device=device,
dtype=torch.float32,
).to(k_buffer_type)
k_scale = torch.randn(1, device=device, dtype=torch.float32).abs()

var_ratio = 0.2
context_lens = (
torch.randint(
int((1 - var_ratio) * avg_kv_length),
int(((1 + var_ratio)) * avg_kv_length) + 1,
(batch_size,),
)
.cuda()
.to(torch.int32)
)
context_blocks = torch.div(
context_lens + block_size - 1, block_size, rounding_mode="trunc"
)

kv_indptr = torch.zeros((batch_size + 1,), device="cuda", dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(context_blocks, dim=0)

kv_prefix_sum_context_lens = torch.zeros(
(batch_size + 1,), device="cuda", dtype=torch.int32
)
kv_prefix_sum_context_lens[1:] = torch.cumsum(context_lens, dim=0)

kv_indices = torch.zeros(kv_indptr[-1], device="cuda", dtype=torch.int32)
for b in range(batch_size):
ctx_len = int(context_blocks[b].item())
kv_indices[kv_indptr[b] : kv_indptr[b + 1]] = torch.randperm(
num_block, device="cuda"
)[:ctx_len]

return (
k_buffer,
k_scale,
kv_indptr,
kv_indices,
kv_prefix_sum_context_lens,
context_lens,
num_block,
)


Comment on lines +29 to +30

qk_nope_head_dim = weight_n // tp_k_head_num_k - v_head_dim
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants