Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
4293bb0
add ua3d gluon kernel for gfx12
k50112113 Mar 25, 2026
0c620b2
update clean
k50112113 Mar 26, 2026
eb4a310
fix print
k50112113 Mar 26, 2026
2c855c7
clean up
k50112113 Mar 26, 2026
653fe9b
skip main loop seq mask
k50112113 Mar 30, 2026
4b368b3
update UT
k50112113 Mar 30, 2026
ebb8bba
unroll2 triple buffer block_idx
k50112113 Mar 31, 2026
640130d
config changes
k50112113 Apr 6, 2026
da97be3
clean
k50112113 Apr 6, 2026
b710d0e
clean
k50112113 Apr 7, 2026
63307bd
clean
k50112113 Apr 8, 2026
7c3504e
resolve merge conflict and clean up
k50112113 Apr 8, 2026
47e63e4
clean
k50112113 Apr 8, 2026
671646e
update UT
k50112113 Apr 8, 2026
b1c2843
ruff
k50112113 Apr 8, 2026
d12b8b7
fix ua interface
k50112113 Apr 9, 2026
ea12693
remove q_descale assertion
k50112113 Apr 9, 2026
ee9fbec
fix bug
k50112113 Apr 10, 2026
162baee
ut
k50112113 Apr 10, 2026
c689a80
move qk_factor to prologue
k50112113 Apr 13, 2026
66fae54
Merge remote-tracking branch 'origin/main' into shaoclee/ua3d-gfx12
k50112113 Apr 13, 2026
f0355b1
add back kv_cache shuffle support for triton unified attention
k50112113 Apr 13, 2026
5128796
new shuffle style
k50112113 Apr 14, 2026
d836bb7
clean
k50112113 Apr 14, 2026
940a580
import fix
k50112113 Apr 14, 2026
0fb45ed
fix
k50112113 Apr 14, 2026
4f6ee8a
UT
k50112113 Apr 14, 2026
d7b4455
add preshuffle kv cache to UA2D triton, include block_size assertion …
k50112113 Apr 14, 2026
0646d32
skip binary search if all_decode in UA2D triton
k50112113 Apr 14, 2026
9ea6135
add repr
k50112113 Apr 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
261 changes: 201 additions & 60 deletions aiter/ops/triton/_triton_kernels/attention/unified_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import triton.language as tl
import torch
from aiter.ops.triton.utils.types import e4m3_dtype
from aiter.ops.triton.utils._triton.kernel_repr import make_kernel_repr

float8_info = torch.finfo(e4m3_dtype)

Expand Down Expand Up @@ -99,6 +100,8 @@ def kernel_unified_attention_2d(
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max,
ALL_DECODE: tl.constexpr = False, # bool
SHUFFLED_KV_CACHE: tl.constexpr = False, # bool
K_WIDTH: tl.constexpr = 0, # int
):
kv_head_idx = tl.program_id(0)
q_block_global_idx = tl.program_id(1)
Expand All @@ -107,27 +110,37 @@ def kernel_unified_attention_2d(
RCP_LN2 = 1.4426950408889634
qk_scale = scale * RCP_LN2

seq_idx = find_seq_idx(
query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True
)
if ALL_DECODE:
seq_idx = q_block_global_idx
q_block_local_idx: tl.int32 = 0
cur_batch_query_len: tl.int32 = 1
cur_batch_in_all_start_index: tl.int32 = q_block_global_idx
else:
seq_idx = find_seq_idx(
query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True
)

q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx
q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx

q_block_local_idx = q_block_global_idx - q_block_start_idx
q_block_local_idx = q_block_global_idx - q_block_start_idx

cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)
cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)

cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index
cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index

if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
return
if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
return

offs_m = tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
offs_t = tl.arange(0, TILE_SIZE)
query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv

offs_shfl = None
if SHUFFLED_KV_CACHE:
offs_shfl = tl.arange(0, TILE_SIZE * HEAD_SIZE_PADDED)

query_offset_0 = cur_batch_in_all_start_index + query_pos
query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv
query_offset = (
Expand Down Expand Up @@ -252,43 +265,86 @@ def kernel_unified_attention_2d(
else:
tile_mask = seq_offset < max_seq_prefix_len

physical_block_idx = tl.load(
block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
).to(tl.int64)
k_mask = None
v_mask = None
other = None
if SHUFFLED_KV_CACHE:
physical_block_idx_shfl = tl.load(
block_tables_ptr + block_table_offset + j
).to(tl.int64)
k_offset = (
physical_block_idx_shfl * stride_k_cache_0
+ kv_head_idx * stride_k_cache_1
+ offs_shfl
)

v_offset = (
physical_block_idx[:, None] * stride_v_cache_0
+ kv_head_idx * stride_v_cache_2
+ offs_d[None, :] * stride_v_cache_3
+ (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
)
v_offset = (
physical_block_idx_shfl * stride_v_cache_0
+ kv_head_idx * stride_v_cache_1
+ offs_shfl
)
else:
physical_block_idx = tl.load(
block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
).to(tl.int64)

v_offset = (
physical_block_idx[:, None] * stride_v_cache_0
+ kv_head_idx * stride_v_cache_2
+ offs_d[None, :] * stride_v_cache_3
+ (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
)
v_mask = dim_mask[None, :] & tile_mask[:, None]

k_offset = (
physical_block_idx[None, :] * stride_k_cache_0
+ kv_head_idx * stride_k_cache_2
+ offs_d[:, None] * stride_k_cache_3
+ (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
)
k_offset = (
physical_block_idx[None, :] * stride_k_cache_0
+ kv_head_idx * stride_k_cache_2
+ offs_d[:, None] * stride_k_cache_3
+ (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
)
k_mask = dim_mask[:, None] & tile_mask[None, :]
other = 0.0

# K : (HEAD_SIZE, TILE_SIZE)
K_load = tl.load(
key_cache_ptr + k_offset,
mask=dim_mask[:, None] & tile_mask[None, :],
other=0.0,
mask=k_mask,
other=other,
cache_modifier=KV_cache_modifier,
)

K = K_load.to(Q.dtype)
if SHUFFLED_KV_CACHE:
K = (
K.reshape(
HEAD_SIZE_PADDED // K_WIDTH,
TILE_SIZE,
K_WIDTH,
)
.permute(1, 0, 2)
.reshape(TILE_SIZE, HEAD_SIZE_PADDED)
.trans(1, 0)
)

# V : (TILE_SIZE, HEAD_SIZE)
V_load = tl.load(
value_cache_ptr + v_offset,
mask=dim_mask[None, :] & tile_mask[:, None],
other=0.0,
mask=v_mask,
other=other,
cache_modifier=KV_cache_modifier,
)

V = V_load.to(Q.dtype)
if SHUFFLED_KV_CACHE:
V = (
V.reshape(
TILE_SIZE // K_WIDTH,
HEAD_SIZE_PADDED,
K_WIDTH,
)
.permute(0, 2, 1)
.reshape(TILE_SIZE, HEAD_SIZE_PADDED)
)

# S : (BLOCK_M, TILE_SIZE)
# qk_scale = scale * RCP_LN2 (log_2 e) so that we can use exp2 later
Expand Down Expand Up @@ -381,7 +437,28 @@ def kernel_unified_attention_2d(
)


@triton.jit

kernel_unified_attention_3d_repr = make_kernel_repr(
"kernel_unified_attention_3d",
[
"num_query_heads",
"num_queries_per_kv",
"BLOCK_SIZE",
"TILE_SIZE",
"HEAD_SIZE",
"NUM_SEGMENTS_PER_SEQ",
"num_warps",
"waves_per_eu",
"num_stages",
"ALL_DECODE",
"SHUFFLED_KV_CACHE",
"IS_Q_FP8",
"IS_KV_FP8",
],
)


@triton.jit(repr=kernel_unified_attention_3d_repr)
def kernel_unified_attention_3d(
segm_output_ptr,
# [num_tokens, num_query_heads, num_segments, head_size]
Expand Down Expand Up @@ -427,8 +504,15 @@ def kernel_unified_attention_3d(
BLOCK_Q: tl.constexpr, # int
num_seqs: tl.int32,
BLOCK_M: tl.constexpr, # int
num_warps: tl.constexpr, # int
waves_per_eu: tl.constexpr, # int
num_stages: tl.constexpr, # int
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
ALL_DECODE: tl.constexpr = False, # bool
SHUFFLED_KV_CACHE: tl.constexpr = False, # bool
K_WIDTH: tl.constexpr = 0, # int
IS_Q_FP8: tl.constexpr = False, # bool
IS_KV_FP8: tl.constexpr = False, # bool
):
q_block_global_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
Expand All @@ -438,21 +522,27 @@ def kernel_unified_attention_3d(
RCP_LN2 = 1.4426950408889634
qk_scale = scale * RCP_LN2

seq_idx = find_seq_idx(
query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True
)
if ALL_DECODE:
seq_idx = q_block_global_idx
q_block_local_idx: tl.int32 = 0
cur_batch_query_len: tl.int32 = 1
cur_batch_in_all_start_index: tl.int32 = q_block_global_idx
else:
seq_idx = find_seq_idx(
query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True
)

q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx
q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx

q_block_local_idx = q_block_global_idx - q_block_start_idx
q_block_local_idx = q_block_global_idx - q_block_start_idx

cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)
cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)

cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index
cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index

if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
return
if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
return

# sequence len for this particular sequence
seq_len = tl.load(seq_lens_ptr + seq_idx)
Expand All @@ -467,6 +557,11 @@ def kernel_unified_attention_3d(
offs_m = tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
offs_t = tl.arange(0, TILE_SIZE)

offs_shfl = None
if SHUFFLED_KV_CACHE:
offs_shfl = tl.arange(0, TILE_SIZE * HEAD_SIZE_PADDED)

query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv

query_offset_0 = cur_batch_in_all_start_index + query_pos
Expand Down Expand Up @@ -570,43 +665,86 @@ def kernel_unified_attention_3d(
else:
tile_mask = seq_offset < max_seq_prefix_len

physical_block_idx = tl.load(
block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
).to(tl.int64)
k_mask = None
v_mask = None
other = None
if SHUFFLED_KV_CACHE:
physical_block_idx_shfl = tl.load(
block_tables_ptr + block_table_offset + j
).to(tl.int64)
k_offset = (
physical_block_idx_shfl * stride_k_cache_0
+ kv_head_idx * stride_k_cache_1
+ offs_shfl
)

v_offset = (
physical_block_idx[:, None] * stride_v_cache_0
+ kv_head_idx * stride_v_cache_2
+ offs_d[None, :] * stride_v_cache_3
+ (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
)
v_offset = (
physical_block_idx_shfl * stride_v_cache_0
+ kv_head_idx * stride_v_cache_1
+ offs_shfl
)
else:
physical_block_idx = tl.load(
block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
).to(tl.int64)

v_offset = (
physical_block_idx[:, None] * stride_v_cache_0
+ kv_head_idx * stride_v_cache_2
+ offs_d[None, :] * stride_v_cache_3
+ (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
)
v_mask = dim_mask[None, :] & tile_mask[:, None]

k_offset = (
physical_block_idx[None, :] * stride_k_cache_0
+ kv_head_idx * stride_k_cache_2
+ offs_d[:, None] * stride_k_cache_3
+ (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
)
k_offset = (
physical_block_idx[None, :] * stride_k_cache_0
+ kv_head_idx * stride_k_cache_2
+ offs_d[:, None] * stride_k_cache_3
+ (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
)
k_mask = dim_mask[:, None] & tile_mask[None, :]
other = 0.0

# K : (HEAD_SIZE, TILE_SIZE)
K_load = tl.load(
key_cache_ptr + k_offset,
mask=dim_mask[:, None] & tile_mask[None, :],
other=0.0,
mask=k_mask,
other=other,
cache_modifier=KV_cache_modifier,
)

K = K_load.to(Q.dtype)
if SHUFFLED_KV_CACHE:
K = (
K.reshape(
HEAD_SIZE_PADDED // K_WIDTH,
TILE_SIZE,
K_WIDTH,
)
.permute(1, 0, 2)
.reshape(TILE_SIZE, HEAD_SIZE_PADDED)
.trans(1, 0)
)

# V : (TILE_SIZE, HEAD_SIZE)
V_load = tl.load(
value_cache_ptr + v_offset,
mask=dim_mask[None, :] & tile_mask[:, None],
other=0.0,
mask=v_mask,
other=other,
cache_modifier=KV_cache_modifier,
)

V = V_load.to(Q.dtype)
if SHUFFLED_KV_CACHE:
V = (
V.reshape(
TILE_SIZE // K_WIDTH,
HEAD_SIZE_PADDED,
K_WIDTH,
)
.permute(0, 2, 1)
.reshape(TILE_SIZE, HEAD_SIZE_PADDED)
)

seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1

Expand Down Expand Up @@ -672,6 +810,7 @@ def kernel_unified_attention_3d(
M = m_j

# acc : (BLOCK_M, HEAD_SIZE_PADDED)
# if TILE_SIZE < 32 and V.dtype.is_fp8():
acc += tl.dot(P.to(V.dtype), V)

if v_descale is not None:
Expand Down Expand Up @@ -791,4 +930,6 @@ def reduce_segments(
+ query_head_idx * output_stride_1
+ tl.arange(0, HEAD_SIZE_PADDED)
)
tl.store(output_ptr + output_offset, acc, mask=dim_mask)
tl.store(
output_ptr + output_offset, acc.to(output_ptr.type.element_ty), mask=dim_mask
)
Loading
Loading