Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions op_tests/op_benchmarks/triton/bench_unified_attention.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import itertools
import sys

import argparse
import torch
import triton

Expand Down Expand Up @@ -369,7 +369,7 @@ def parse_int_or_list(value):
return int(value)


def parse_args():
def parse_args(args: list[str] | None = None) -> argparse.Namespace:
parser = get_parser(kernel_name="Unified Attention")

parser.add_argument("-b", type=int, default=0)
Expand Down Expand Up @@ -445,11 +445,11 @@ def parse_args():
help="Sliding window size (default: disabled)",
)

return parser.parse_args()
return parser.parse_args(args=args)


def main():
args = parse_args()
def main(args: list[str] | None = None) -> None:
args = parse_args(args=args)

if args.fp8 and args.fp8_kv:
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@
from op_tests.op_benchmarks.triton.bench_rope import main as bench_rope_main
from op_tests.op_benchmarks.triton.bench_mha import main as bench_mha_main
from op_tests.op_benchmarks.triton.bench_mla_decode import main as bench_mla_main
from op_tests.op_benchmarks.triton.bench_unified_attention import (
main as bench_unified_attention_main,
)


def disable_aiter_logs() -> None:
Expand All @@ -74,6 +77,7 @@ def disable_aiter_logs() -> None:
"rope": bench_rope_main,
"mha": bench_mha_main,
"mla": bench_mla_main,
"unified_attention": bench_unified_attention_main,
}

# Shape dicts from model_shapes.json (int, str values)
Expand Down Expand Up @@ -377,13 +381,17 @@ def build_args(self) -> str:
# bshd (batch-seq-head-dim) - fwd
# thd (token-head-dim) - fwd_varlen with equal seq lens
fn = "fwd_varlen" if self._mha_layout == "thd" else "fwd"
sliding_window_left = shape.get("sliding_window_left", -1)
sink = shape.get("sink", None)
args = (
f"-fn {fn} -causal true --dtype bf16 -b {self._batch_size} "
f"-hq {shape['hq']} -hk {shape['hkv']} -sq {self._seq_len} -sk {self._seq_len} "
f"-d {shape['dqk']} -dv {shape['dv']} -metric {self._metric}"
f"-d {shape['dqk']} -dv {shape['dv']} --window-size-left {sliding_window_left} -metric {self._metric}"
)
if fn == "fwd_varlen":
args += " -equal_seqlens"
if sink:
args += " -sink"
return args

def parse_stdout(self, stdout: str) -> float:
Expand Down Expand Up @@ -416,6 +424,8 @@ def build_result_row(self, bench_result: float | str) -> ResultRow:
"dqk": shape["dqk"],
"dv": shape["dv"],
"mha_layout": self._mha_layout,
"sink": shape.get("sink", "false"),
"sliding_window": shape.get("sliding_window_left", None),
self._metric: bench_result,
}

Expand Down Expand Up @@ -465,13 +475,65 @@ def build_result_row(self, bench_result: float | str) -> ResultRow:
}


class UnifiedAttnKernelHandler(KernelHandler):
    """Handler for unified attention benchmarks (bench_unified_attention.py)."""

    def get_tp_shapes(self, shapes: list[ShapeDict]) -> list[ShapeDict]:
        """Return per-TP-rank copies of *shapes* with head counts sharded.

        The input dicts are not mutated; each copy has its "hq"/"hkv"
        entries divided across tensor-parallel ranks via ``_shard_keys``.
        """
        sharded: list[ShapeDict] = []
        for original in shapes:
            tp_shape = original.copy()
            # Only head counts shard across TP ranks; head dims are untouched.
            self._shard_keys(tp_shape, ["hq", "hkv"])
            sharded.append(tp_shape)
        return sharded

    def build_args(self) -> str:
        """Build the CLI argument string passed to the unified-attention bench.

        ``block_size`` defaults to 0 when absent from the shape dict;
        ``sliding_window`` is only forwarded when explicitly present.
        """
        shape = self._shape
        parts = [
            f"-b {self._batch_size}",
            f"-hq {shape['hq']}",
            f"-hk {shape['hkv']}",
            f"-d {shape['dqk']}",
            f"-dv {shape['dv']}",
            f"-sq {self._seq_len}",
            f"-sk {self._seq_len}",
            f"-block_size {int(shape.get('block_size', 0))}",
            f"--metric {self._metric}",
        ]
        window = shape.get("sliding_window", None)
        if window is not None:
            parts.append(f"-sliding_window {window}")
        return " ".join(parts)

    def parse_stdout(self, stdout: str) -> float:
        """Extract the benchmark metric from the bench script's stdout.

        Expects a header plus at least one data row; the metric is read
        from token index 9 of the third non-empty line.

        Raises:
            ValueError: if the output has fewer than 3 non-empty lines or
                the data line has fewer than 10 whitespace-separated fields.
        """
        rows = [text.split() for text in stdout.strip().splitlines() if text.strip()]
        if len(rows) < 3:
            raise ValueError(
                f"Unexpected unified_attention bench output: expected at least 3 lines, got {len(rows)}"
            )
        data = rows[2]
        if len(data) < 10:
            raise ValueError(f"Unexpected unified_attention bench data line: {data!r}")
        return float(data[9])

    def build_result_row(self, bench_result: float | str) -> ResultRow:
        """Assemble one result-table row for this model/kernel/shape run."""
        shape = self._shape
        row: ResultRow = {
            "Model": self._model,
            "Kernel": self._kernel,
            "batch_size": self._batch_size,
            "seq_len": self._seq_len,
            "hq": shape["hq"],
            "hkv": shape["hkv"],
            "dqk": shape["dqk"],
            "dv": shape["dv"],
            "sliding_window": shape.get("sliding_window", None),
        }
        # Metric column name is dynamic; keep it as the final key, matching
        # the other handlers' row layout.
        row[self._metric] = bench_result
        return row


# Ordered (predicate, handler-class) pairs mapping a kernel name to its
# KernelHandler. NOTE(review): presumably the lookup walks this list in
# order and takes the first predicate that matches — confirm against the
# consuming code. The "gemm" rule explicitly excludes "moe" so the two
# substring rules cannot both match the same name.
_HANDLER_RULES: list[tuple[Callable[[str], bool], type[KernelHandler]]] = [
    (lambda k: "moe" in k, MoeKernelHandler),
    (lambda k: "gemm" in k and "moe" not in k, GemmKernelHandler),
    (lambda k: k == "rmsnorm", RmsnormKernelHandler),
    (lambda k: k == "rope", RopeKernelHandler),
    (lambda k: k == "mha", MhaKernelHandler),
    (lambda k: k == "mla", MlaKernelHandler),
    (lambda k: k == "unified_attention", UnifiedAttnKernelHandler),
]


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,14 @@
"dqk": 128,
"dv": 128
}
],
"unified_attention": [
{
"hq": 128,
"hkv": 8,
"dqk": 128,
"dv": 128
}
]
},
"Llama3 70B": {
Expand Down Expand Up @@ -135,6 +143,14 @@
"dqk": 128,
"dv": 128
}
],
"unified_attention": [
{
"hq": 64,
"hkv": 8,
"dqk": 128,
"dv": 128
}
]
},
"Llama3 8B": {
Expand Down Expand Up @@ -204,6 +220,14 @@
"dqk": 128,
"dv": 128
}
],
"unified_attention": [
{
"hq": 32,
"hkv": 8,
"dqk": 128,
"dv": 128
}
]
},
"GPT-OSS 120B": {
Expand Down Expand Up @@ -256,6 +280,15 @@
"sink": true,
"sliding_window_left": 128
}
],
"unified_attention": [
{
"hq": 64,
"hkv": 8,
"dqk": 64,
"dv": 64,
"sliding_window": 128
}
]
},
"DeepSeek-R1": {
Expand Down Expand Up @@ -534,6 +567,14 @@
"dqk": 88,
"dv": 88
}
],
"unified_attention": [
{
"hq": 40,
"hkv": 8,
"dqk": 128,
"dv": 128
}
]
},
"Qwen3-235B-A22B": {
Expand Down Expand Up @@ -589,6 +630,14 @@
"dqk": 128,
"dv": 128
}
],
"unified_attention": [
{
"hq": 64,
"hkv": 4,
"dqk": 128,
"dv": 128
}
]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ For a new kernel to be run by the script, the corresponding benchmark must be im
| `K` | int | Inner dimension K. |
| `TP_dim` | string \| null | `"N"`, `"K"`, or `null`. |

### Batched GEMM kernels (`batched_gemm_a8w8`, `batched_gemm_afp4wfp4`)
### Batched GEMM kernels (`batched_gemm_a8w8`, `batched_gemm_afp4wfp4`, `batched_gemm_a16wfp4`)

| Field | Type | Description |
|----------|--------|-------------|
Expand Down Expand Up @@ -74,6 +74,8 @@ For a new kernel to be run by the script, the corresponding benchmark must be im
| `dqk` | int | Query/key head dimension. |
| `dv` | int | Value head dimension. |
| `comment`| string | Optional label (e.g. `"Prefill"`, `"Text"`, `"Vision"`). |
| `sink` | bool | Optional. When true, enables attention sink in the MHA benchmark. |
| `sliding_window_left` | int | Optional. Left sliding-window size (`--window-size-left`); omit for no window. |

### MLA — Multi-head Latent Attention (`mla`)

Expand All @@ -85,6 +87,17 @@ For a new kernel to be run by the script, the corresponding benchmark must be im
| `dv` | int | Value head dimension. |
| `comment`| string | Optional label (e.g. `"Decode"`). |

### Unified Attention (`unified_attention`)

| Field | Type | Description |
|-------------|--------|-------------|
| `hq` | int | Number of query heads. |
| `hkv` | int | Number of key/value heads. |
| `dqk` | int | Query/key head dimension. |
| `dv` | int | Value head dimension. |
| `block_size`| int | Optional. KV cache block size. |
| `sliding_window` | int \| null | Optional. Sliding-window size for unified attention; omit when not used. |

## Example

```json
Expand All @@ -109,6 +122,9 @@ For a new kernel to be run by the script, the corresponding benchmark must be im
],
"mha": [
{ "hq": 128, "hkv": 8, "dqk": 128, "dv": 128 }
],
"unified_attention": [
{ "hq": 128, "hkv": 8, "dqk": 128, "dv": 128 }
]
}
}
Expand Down
Loading