Skip to content
Open
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
4293bb0
add ua3d gluon kernel for gfx12
k50112113 Mar 25, 2026
0c620b2
update clean
k50112113 Mar 26, 2026
eb4a310
fix print
k50112113 Mar 26, 2026
2c855c7
clean up
k50112113 Mar 26, 2026
653fe9b
skip main loop seq mask
k50112113 Mar 30, 2026
4b368b3
update UT
k50112113 Mar 30, 2026
ebb8bba
unroll2 triple buffer block_idx
k50112113 Mar 31, 2026
640130d
config changes
k50112113 Apr 6, 2026
da97be3
clean
k50112113 Apr 6, 2026
b710d0e
clean
k50112113 Apr 7, 2026
63307bd
clean
k50112113 Apr 8, 2026
7c3504e
resolve merge conflict and clean up
k50112113 Apr 8, 2026
47e63e4
clean
k50112113 Apr 8, 2026
671646e
update UT
k50112113 Apr 8, 2026
b1c2843
ruff
k50112113 Apr 8, 2026
d12b8b7
fix ua interface
k50112113 Apr 9, 2026
ea12693
remove q_descale assertion
k50112113 Apr 9, 2026
ee9fbec
fix bug
k50112113 Apr 10, 2026
162baee
ut
k50112113 Apr 10, 2026
c689a80
move qk_factor to prologue
k50112113 Apr 13, 2026
66fae54
Merge remote-tracking branch 'origin/main' into shaoclee/ua3d-gfx12
k50112113 Apr 13, 2026
f0355b1
add back kv_cache shuffle support for triton unified attention
k50112113 Apr 13, 2026
5128796
new shuffle style
k50112113 Apr 14, 2026
d836bb7
clean
k50112113 Apr 14, 2026
940a580
import fix
k50112113 Apr 14, 2026
0fb45ed
fix
k50112113 Apr 14, 2026
4f6ee8a
UT
k50112113 Apr 14, 2026
d7b4455
add preshuffle kv cache to UA2D triton, include block_size assertion …
k50112113 Apr 14, 2026
0646d32
skip binary search if all_decode in UA2D triton
k50112113 Apr 14, 2026
9ea6135
add repr
k50112113 Apr 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 94 additions & 30 deletions aiter/ops/triton/_triton_kernels/attention/unified_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,30 +429,44 @@ def kernel_unified_attention_3d(
BLOCK_M: tl.constexpr, # int
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
ALL_DECODE: tl.constexpr = False, # bool
SHUFFLED_KV_CACHE: tl.constexpr = False, # bool
K_WIDTH: tl.constexpr = 0, # int
):
q_block_global_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1)
segm_idx = tl.program_id(2)

if SHUFFLED_KV_CACHE:
tl.static_assert(
TILE_SIZE == BLOCK_SIZE,
"TILE_SIZE must be equal to BLOCK_SIZE if SHUFFLED_KV_CACHE is True",
)

# needed to use exp2 (exp2 -> exp conversion)
RCP_LN2 = 1.4426950408889634
qk_scale = scale * RCP_LN2

seq_idx = find_seq_idx(
query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True
)
if ALL_DECODE:
seq_idx = q_block_global_idx
q_block_local_idx: tl.int32 = 0
cur_batch_query_len: tl.int32 = 1
cur_batch_in_all_start_index: tl.int32 = q_block_global_idx
else:
seq_idx = find_seq_idx(
query_start_len_ptr, q_block_global_idx, num_seqs, BLOCK_Q, True
)

q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx
q_block_start_idx = tl.load(query_start_len_ptr + seq_idx) // BLOCK_Q + seq_idx

q_block_local_idx = q_block_global_idx - q_block_start_idx
q_block_local_idx = q_block_global_idx - q_block_start_idx

cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)
cur_batch_in_all_start_index = tl.load(query_start_len_ptr + seq_idx)
cur_batch_in_all_stop_index = tl.load(query_start_len_ptr + seq_idx + 1)

cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index
cur_batch_query_len = cur_batch_in_all_stop_index - cur_batch_in_all_start_index

if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
return
if q_block_local_idx * BLOCK_Q >= cur_batch_query_len:
return

# sequence len for this particular sequence
seq_len = tl.load(seq_lens_ptr + seq_idx)
Expand All @@ -467,6 +481,11 @@ def kernel_unified_attention_3d(
offs_m = tl.arange(0, BLOCK_M)
offs_d = tl.arange(0, HEAD_SIZE_PADDED)
offs_t = tl.arange(0, TILE_SIZE)

offs_shfl = None
if SHUFFLED_KV_CACHE:
offs_shfl = tl.arange(0, TILE_SIZE * HEAD_SIZE_PADDED)

query_pos = q_block_local_idx * BLOCK_Q + offs_m // num_queries_per_kv

query_offset_0 = cur_batch_in_all_start_index + query_pos
Expand Down Expand Up @@ -570,43 +589,86 @@ def kernel_unified_attention_3d(
else:
tile_mask = seq_offset < max_seq_prefix_len

physical_block_idx = tl.load(
block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
).to(tl.int64)
k_mask = None
v_mask = None
other = None
if SHUFFLED_KV_CACHE:
physical_block_idx_shfl = tl.load(
block_tables_ptr + block_table_offset + j
).to(tl.int64)
k_offset = (
physical_block_idx_shfl * stride_k_cache_0
+ kv_head_idx * stride_k_cache_1
+ offs_shfl
)

v_offset = (
physical_block_idx[:, None] * stride_v_cache_0
+ kv_head_idx * stride_v_cache_2
+ offs_d[None, :] * stride_v_cache_3
+ (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
)
v_offset = (
physical_block_idx_shfl * stride_v_cache_0
+ kv_head_idx * stride_v_cache_1
+ offs_shfl
)
else:
physical_block_idx = tl.load(
block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE
).to(tl.int64)

v_offset = (
physical_block_idx[:, None] * stride_v_cache_0
+ kv_head_idx * stride_v_cache_2
+ offs_d[None, :] * stride_v_cache_3
+ (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1
)
v_mask = dim_mask[None, :] & tile_mask[:, None]

k_offset = (
physical_block_idx[None, :] * stride_k_cache_0
+ kv_head_idx * stride_k_cache_2
+ offs_d[:, None] * stride_k_cache_3
+ (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
)
k_offset = (
physical_block_idx[None, :] * stride_k_cache_0
+ kv_head_idx * stride_k_cache_2
+ offs_d[:, None] * stride_k_cache_3
+ (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1
)
k_mask = dim_mask[:, None] & tile_mask[None, :]
other = 0.0

# K : (HEAD_SIZE, TILE_SIZE)
K_load = tl.load(
key_cache_ptr + k_offset,
mask=dim_mask[:, None] & tile_mask[None, :],
other=0.0,
mask=k_mask,
other=other,
cache_modifier=KV_cache_modifier,
)

K = K_load.to(Q.dtype)
if SHUFFLED_KV_CACHE:
K = (
K.reshape(
HEAD_SIZE_PADDED // K_WIDTH,
TILE_SIZE,
K_WIDTH,
)
.permute(1, 0, 2)
.reshape(TILE_SIZE, HEAD_SIZE_PADDED)
.trans(1, 0)
)

# V : (TILE_SIZE, HEAD_SIZE)
V_load = tl.load(
value_cache_ptr + v_offset,
mask=dim_mask[None, :] & tile_mask[:, None],
other=0.0,
mask=v_mask,
other=other,
cache_modifier=KV_cache_modifier,
)

V = V_load.to(Q.dtype)
if SHUFFLED_KV_CACHE:
V = (
V.reshape(
TILE_SIZE // K_WIDTH,
HEAD_SIZE_PADDED,
K_WIDTH,
)
.permute(0, 2, 1)
.reshape(TILE_SIZE, HEAD_SIZE_PADDED)
)

seq_mask = seq_offset[None, :] < context_len + query_pos[:, None] + 1

Expand Down Expand Up @@ -791,4 +853,6 @@ def reduce_segments(
+ query_head_idx * output_stride_1
+ tl.arange(0, HEAD_SIZE_PADDED)
)
tl.store(output_ptr + output_offset, acc, mask=dim_mask)
tl.store(
output_ptr + output_offset, acc.to(output_ptr.type.element_ty), mask=dim_mask
)
Loading
Loading