Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
249 changes: 150 additions & 99 deletions aiter/ops/triton/_triton_kernels/gather_kv_b_proj.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,45 +5,55 @@
import triton.language as tl


def _next_pow2(n):
"""Return the smallest power of 2 >= n (Python-side helper, not a JIT function)."""
return 1 << (n - 1).bit_length()


@triton.jit
def _load_unshuffle_segment(
base_ptr,
seg_idx,
QkNopeHeadDim: tl.constexpr,
HeadDim: tl.constexpr,
PaddedHeadDim: tl.constexpr,
KV_CDim: tl.constexpr,
ScaleKGranularity: tl.constexpr,
):
"""Load one [QkNopeHeadDim, ScaleKGranularity] weight segment from a
"""Load one [PaddedHeadDim, ScaleKGranularity] weight segment from a
preshuffled weight matrix via coalesced row-major loads, then unshuffle
in registers.

Each n_blk (16 original N rows) occupies KV_CDim//32 shuffled rows.
A ScaleK segment of 128 K values spans SegKBlocks=4 consecutive rows
within each n_blk. We gather these rows across all n_blks (with row
stride KV_CDim), producing a [NumNBlk*SegKBlocks, KV_CDim] tensor,
then reshape + permute to recover [QkNopeHeadDim, ScaleKGranularity].
in registers. PaddedHeadDim is HeadDim rounded up to the next power of 2.
Out-of-range rows are zero-filled so dot-products stay correct.
"""
NumNBlk: tl.constexpr = QkNopeHeadDim // 16
NumNBlk: tl.constexpr = HeadDim // 16
PaddedNumNBlk: tl.constexpr = PaddedHeadDim // 16
SegKBlocks: tl.constexpr = ScaleKGranularity // 32
NumKBlkTotal: tl.constexpr = KV_CDim // 32
TotalRows: tl.constexpr = NumNBlk * SegKBlocks
PaddedTotalRows: tl.constexpr = PaddedNumNBlk * SegKBlocks

offs_nb = tl.arange(0, NumNBlk)
offs_nb = tl.arange(0, PaddedNumNBlk)
offs_kb = tl.arange(0, SegKBlocks)
row_indices = (
offs_nb[:, None] * NumKBlkTotal + seg_idx * SegKBlocks + offs_kb[None, :]
)
row_indices_flat = tl.reshape(row_indices, (TotalRows,))
row_indices_flat = tl.reshape(row_indices, (PaddedTotalRows,))
mask_flat = tl.reshape(
(offs_nb[:, None] < NumNBlk).broadcast_to(PaddedNumNBlk, SegKBlocks),
(PaddedTotalRows,),
)

offs_col = tl.arange(0, KV_CDim)
raw = tl.load(base_ptr + row_indices_flat[:, None] * KV_CDim + offs_col[None, :])
raw = tl.load(
base_ptr + row_indices_flat[:, None] * KV_CDim + offs_col[None, :],
mask=mask_flat[:, None],
other=0.0,
)

w = tl.reshape(
tl.permute(
tl.reshape(raw, (NumNBlk, SegKBlocks, 2, 16, 16)),
tl.reshape(raw, (PaddedNumNBlk, SegKBlocks, 2, 16, 16)),
(0, 3, 1, 2, 4),
),
(QkNopeHeadDim, ScaleKGranularity),
(PaddedHeadDim, ScaleKGranularity),
)
return w

Expand All @@ -56,22 +66,25 @@ def _triton_gather_kv_b_proj(
kv_indptr, # [batch_size + 1]
kv_indices, # [total_kv]
kv_prefix_sum_context_lens, # [batch_size + 1]
kv_proj_weight, # [tp_k_head_num * 2 * qk_nope_head_dim, kv_c_dim]
kv_proj_scale, # block: [n//128, k//128]; per-row: [tp_heads * 2 * qk_nope_head_dim]
k_prefix, # [total_kv, tp_k_head_num * qk_nope_head_dim + kv_pe_dim]
v_prefix, # [total_kv, tp_k_head_num * qk_nope_head_dim]
kv_proj_weight, # [tp_k_head_num * (qk_nope_head_dim + v_head_dim), kv_c_dim]
kv_proj_scale, # block: [n//128, k//128]; per-row: [weight_n] or [weight_n, 1]
k_prefix, # [total_kv, tp_k_head_num, qk_nope_head_dim + kv_pe_dim]
v_prefix, # [total_kv, tp_k_head_num, v_head_dim]
KBlockSize: tl.constexpr,
TpNumHeads: tl.constexpr,
QkNopeHeadDim: tl.constexpr,
VHeadDim: tl.constexpr,
KV_CDim: tl.constexpr,
KV_PeDim: tl.constexpr,
ChunkK: tl.constexpr,
PaddedK: tl.constexpr,
PaddedV: tl.constexpr,
WEIGHT_PRESHUFFLE: tl.constexpr = False,
PER_ROW_SCALE: tl.constexpr = False,
):
stride_k_buffer: tl.constexpr = KBlockSize * (KV_CDim + KV_PeDim)
stride_k_prefix: tl.constexpr = TpNumHeads * (QkNopeHeadDim + KV_PeDim)
stride_v_prefix: tl.constexpr = TpNumHeads * QkNopeHeadDim
stride_v_prefix: tl.constexpr = TpNumHeads * VHeadDim

ScaleKGranularity: tl.constexpr = 128
ScaleNGranularity: tl.constexpr = 128
Expand Down Expand Up @@ -103,112 +116,150 @@ def _triton_gather_kv_b_proj(
else:
k_scalar_scale = tl.load(k_scale)

offs_n = tl.arange(0, QkNopeHeadDim)
offs_n_k = tl.arange(0, PaddedK)
offs_n_v = tl.arange(0, PaddedV)
mask_k = offs_n_k < QkNopeHeadDim
mask_v = offs_n_v < VHeadDim
offs_k = tl.arange(0, ScaleKGranularity)
k_head_base = kv_proj_weight + pid_head * 2 * QkNopeHeadDim * KV_CDim
k_head_base = kv_proj_weight + pid_head * (QkNopeHeadDim + VHeadDim) * KV_CDim
v_head_base = k_head_base + QkNopeHeadDim * KV_CDim

if PER_ROW_SCALE:
k_row0 = pid_head * 2 * QkNopeHeadDim
k_nope_scale_vec = tl.load(kv_proj_scale + k_row0 + offs_n).to(tl.float32)
v_nope_scale_vec = tl.load(kv_proj_scale + k_row0 + QkNopeHeadDim + offs_n).to(
tl.float32
)
k_row0 = pid_head * (QkNopeHeadDim + VHeadDim)
k_nope_scale_vec = tl.load(
kv_proj_scale + k_row0 + offs_n_k, mask=mask_k, other=1.0
).to(tl.float32)
v_nope_scale_vec = tl.load(
kv_proj_scale + k_row0 + QkNopeHeadDim + offs_n_v, mask=mask_v, other=1.0
).to(tl.float32)
else:
k_nope_scale_base_offset = (
kv_proj_scale
+ pid_head
* 2
* QkNopeHeadDim
* KV_CDim
// ScaleKGranularity
// ScaleNGranularity
+ tl.arange(0, QkNopeHeadDim // ScaleNGranularity)
* (KV_CDim // ScaleKGranularity)
)
num_scale_cols: tl.constexpr = KV_CDim // ScaleKGranularity
k_abs_rows = pid_head * (QkNopeHeadDim + VHeadDim) + offs_n_k
k_scale_n_idx = k_abs_rows // ScaleNGranularity
v_abs_rows = pid_head * (QkNopeHeadDim + VHeadDim) + QkNopeHeadDim + offs_n_v
v_scale_n_idx = v_abs_rows // ScaleNGranularity

if WEIGHT_PRESHUFFLE:
# _load_unshuffle_segment returns [PaddedHeadDim, ScaleKGranularity]
# with zero-filled rows beyond HeadDim
k_nope_weight_0 = _load_unshuffle_segment(
k_head_base, 0, QkNopeHeadDim, KV_CDim, ScaleKGranularity
k_head_base, 0, QkNopeHeadDim, PaddedK, KV_CDim, ScaleKGranularity
).to(k_type)
k_nope_weight_1 = _load_unshuffle_segment(
k_head_base, 1, QkNopeHeadDim, KV_CDim, ScaleKGranularity
k_head_base, 1, QkNopeHeadDim, PaddedK, KV_CDim, ScaleKGranularity
).to(k_type)
k_nope_weight_2 = _load_unshuffle_segment(
k_head_base, 2, QkNopeHeadDim, KV_CDim, ScaleKGranularity
k_head_base, 2, QkNopeHeadDim, PaddedK, KV_CDim, ScaleKGranularity
).to(k_type)
k_nope_weight_3 = _load_unshuffle_segment(
k_head_base, 3, QkNopeHeadDim, KV_CDim, ScaleKGranularity
k_head_base, 3, QkNopeHeadDim, PaddedK, KV_CDim, ScaleKGranularity
).to(k_type)

v_nope_weight_0 = _load_unshuffle_segment(
v_head_base, 0, QkNopeHeadDim, KV_CDim, ScaleKGranularity
v_head_base, 0, VHeadDim, PaddedV, KV_CDim, ScaleKGranularity
).to(k_type)
v_nope_weight_1 = _load_unshuffle_segment(
v_head_base, 1, QkNopeHeadDim, KV_CDim, ScaleKGranularity
v_head_base, 1, VHeadDim, PaddedV, KV_CDim, ScaleKGranularity
).to(k_type)
v_nope_weight_2 = _load_unshuffle_segment(
v_head_base, 2, QkNopeHeadDim, KV_CDim, ScaleKGranularity
v_head_base, 2, VHeadDim, PaddedV, KV_CDim, ScaleKGranularity
).to(k_type)
v_nope_weight_3 = _load_unshuffle_segment(
v_head_base, 3, QkNopeHeadDim, KV_CDim, ScaleKGranularity
v_head_base, 3, VHeadDim, PaddedV, KV_CDim, ScaleKGranularity
).to(k_type)
else:
k_nope_weight_base_offset = (
k_head_base + offs_n[:, None] * KV_CDim + offs_k[None, :]
)
k_nope_weight_0 = tl.load(k_nope_weight_base_offset + 0 * ScaleKGranularity).to(
k_type
)
k_nope_weight_1 = tl.load(k_nope_weight_base_offset + 1 * ScaleKGranularity).to(
k_type
)
k_nope_weight_2 = tl.load(k_nope_weight_base_offset + 2 * ScaleKGranularity).to(
k_type
)
k_nope_weight_3 = tl.load(k_nope_weight_base_offset + 3 * ScaleKGranularity).to(
k_type
k_head_base + offs_n_k[:, None] * KV_CDim + offs_k[None, :]
)
k_mask_2d = mask_k[:, None]
k_nope_weight_0 = tl.load(
k_nope_weight_base_offset + 0 * ScaleKGranularity,
mask=k_mask_2d,
other=0.0,
).to(k_type)
k_nope_weight_1 = tl.load(
k_nope_weight_base_offset + 1 * ScaleKGranularity,
mask=k_mask_2d,
other=0.0,
).to(k_type)
k_nope_weight_2 = tl.load(
k_nope_weight_base_offset + 2 * ScaleKGranularity,
mask=k_mask_2d,
other=0.0,
).to(k_type)
k_nope_weight_3 = tl.load(
k_nope_weight_base_offset + 3 * ScaleKGranularity,
mask=k_mask_2d,
other=0.0,
).to(k_type)

v_nope_weight_base_offset = (
v_head_base + offs_n_v[:, None] * KV_CDim + offs_k[None, :]
)
v_mask_2d = mask_v[:, None]
v_nope_weight_0 = tl.load(
k_nope_weight_base_offset + QkNopeHeadDim * KV_CDim + 0 * ScaleKGranularity
v_nope_weight_base_offset + 0 * ScaleKGranularity,
mask=v_mask_2d,
other=0.0,
).to(k_type)
v_nope_weight_1 = tl.load(
k_nope_weight_base_offset + QkNopeHeadDim * KV_CDim + 1 * ScaleKGranularity
v_nope_weight_base_offset + 1 * ScaleKGranularity,
mask=v_mask_2d,
other=0.0,
).to(k_type)
v_nope_weight_2 = tl.load(
k_nope_weight_base_offset + QkNopeHeadDim * KV_CDim + 2 * ScaleKGranularity
v_nope_weight_base_offset + 2 * ScaleKGranularity,
mask=v_mask_2d,
other=0.0,
).to(k_type)
v_nope_weight_3 = tl.load(
k_nope_weight_base_offset + QkNopeHeadDim * KV_CDim + 3 * ScaleKGranularity
v_nope_weight_base_offset + 3 * ScaleKGranularity,
mask=v_mask_2d,
other=0.0,
).to(k_type)

if not PER_ROW_SCALE:
k_nope_scale_0 = tl.load(k_nope_scale_base_offset + 0)
k_nope_scale_1 = tl.load(k_nope_scale_base_offset + 1)
k_nope_scale_2 = tl.load(k_nope_scale_base_offset + 2)
k_nope_scale_3 = tl.load(k_nope_scale_base_offset + 3)
k_nope_scale_0 = tl.load(
kv_proj_scale + k_scale_n_idx * num_scale_cols + 0,
mask=mask_k,
other=0.0,
).to(tl.float32)
k_nope_scale_1 = tl.load(
kv_proj_scale + k_scale_n_idx * num_scale_cols + 1,
mask=mask_k,
other=0.0,
).to(tl.float32)
k_nope_scale_2 = tl.load(
kv_proj_scale + k_scale_n_idx * num_scale_cols + 2,
mask=mask_k,
other=0.0,
).to(tl.float32)
k_nope_scale_3 = tl.load(
kv_proj_scale + k_scale_n_idx * num_scale_cols + 3,
mask=mask_k,
other=0.0,
).to(tl.float32)

v_nope_scale_0 = tl.load(
k_nope_scale_base_offset
+ QkNopeHeadDim * KV_CDim // ScaleNGranularity // ScaleKGranularity
+ 0
)
kv_proj_scale + v_scale_n_idx * num_scale_cols + 0,
mask=mask_v,
other=0.0,
).to(tl.float32)
v_nope_scale_1 = tl.load(
k_nope_scale_base_offset
+ QkNopeHeadDim * KV_CDim // ScaleNGranularity // ScaleKGranularity
+ 1
)
kv_proj_scale + v_scale_n_idx * num_scale_cols + 1,
mask=mask_v,
other=0.0,
).to(tl.float32)
v_nope_scale_2 = tl.load(
k_nope_scale_base_offset
+ QkNopeHeadDim * KV_CDim // ScaleNGranularity // ScaleKGranularity
+ 2
)
kv_proj_scale + v_scale_n_idx * num_scale_cols + 2,
mask=mask_v,
other=0.0,
).to(tl.float32)
v_nope_scale_3 = tl.load(
k_nope_scale_base_offset
+ QkNopeHeadDim * KV_CDim // ScaleNGranularity // ScaleKGranularity
+ 3
)
kv_proj_scale + v_scale_n_idx * num_scale_cols + 3,
mask=mask_v,
other=0.0,
).to(tl.float32)

for chunk_id in range(total_kv_chunk):
kv_block_idx = tl.load(
Expand All @@ -225,8 +276,8 @@ def _triton_gather_kv_b_proj(
+ tl.arange(0, ScaleKGranularity)[None, :]
) # [ChunkK, kv_c_dim]

accum_k = tl.zeros((ChunkK, QkNopeHeadDim), dtype=tl.float32)
accum_v = tl.zeros((ChunkK, QkNopeHeadDim), dtype=tl.float32)
accum_k = tl.zeros((ChunkK, PaddedK), dtype=tl.float32)
accum_v = tl.zeros((ChunkK, PaddedV), dtype=tl.float32)

kv_c_data_0 = tl.load(k_buffer + kv_c_data_base_offset + 0 * ScaleKGranularity)
kv_c_data_1 = tl.load(k_buffer + kv_c_data_base_offset + 1 * ScaleKGranularity)
Expand Down Expand Up @@ -266,14 +317,14 @@ def _triton_gather_kv_b_proj(
tl.dot(kv_c_data_3, v_nope_weight_3.T) * v_nope_scale_vec[None, :]
)
else:
accum_k += tl.dot(kv_c_data_0, k_nope_weight_0.T) * k_nope_scale_0
accum_v += tl.dot(kv_c_data_0, v_nope_weight_0.T) * v_nope_scale_0
accum_k += tl.dot(kv_c_data_1, k_nope_weight_1.T) * k_nope_scale_1
accum_v += tl.dot(kv_c_data_1, v_nope_weight_1.T) * v_nope_scale_1
accum_k += tl.dot(kv_c_data_2, k_nope_weight_2.T) * k_nope_scale_2
accum_v += tl.dot(kv_c_data_2, v_nope_weight_2.T) * v_nope_scale_2
accum_k += tl.dot(kv_c_data_3, k_nope_weight_3.T) * k_nope_scale_3
accum_v += tl.dot(kv_c_data_3, v_nope_weight_3.T) * v_nope_scale_3
accum_k += tl.dot(kv_c_data_0, k_nope_weight_0.T) * k_nope_scale_0[None, :]
accum_v += tl.dot(kv_c_data_0, v_nope_weight_0.T) * v_nope_scale_0[None, :]
accum_k += tl.dot(kv_c_data_1, k_nope_weight_1.T) * k_nope_scale_1[None, :]
accum_v += tl.dot(kv_c_data_1, v_nope_weight_1.T) * v_nope_scale_1[None, :]
accum_k += tl.dot(kv_c_data_2, k_nope_weight_2.T) * k_nope_scale_2[None, :]
accum_v += tl.dot(kv_c_data_2, v_nope_weight_2.T) * v_nope_scale_2[None, :]
accum_k += tl.dot(kv_c_data_3, k_nope_weight_3.T) * k_nope_scale_3[None, :]
accum_v += tl.dot(kv_c_data_3, v_nope_weight_3.T) * v_nope_scale_3[None, :]

accum_k *= k_scalar_scale
accum_v *= k_scalar_scale
Expand All @@ -297,16 +348,16 @@ def _triton_gather_kv_b_proj(
+ (context_start + chunk_id * ChunkK + tl.arange(0, ChunkK))[:, None]
* stride_k_prefix
+ pid_head * (QkNopeHeadDim + KV_PeDim)
+ tl.arange(0, QkNopeHeadDim)[None, :],
+ offs_n_k[None, :],
accum_k,
mask=context_mask[:, None],
mask=context_mask[:, None] & mask_k[None, :],
)
tl.store(
v_prefix
+ (context_start + chunk_id * ChunkK + tl.arange(0, ChunkK))[:, None]
* stride_v_prefix
+ pid_head * QkNopeHeadDim
+ tl.arange(0, QkNopeHeadDim)[None, :],
+ pid_head * VHeadDim
+ offs_n_v[None, :],
accum_v,
mask=context_mask[:, None],
mask=context_mask[:, None] & mask_v[None, :],
)
Loading
Loading