Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions sage-attention/benchmarks/benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import torch
import torch.nn.functional as F

from kernels.benchmark import Benchmark

# SageAttention is approximate (INT8 quantized QK) so element-wise allclose
# is too strict.  The harness presumably verifies kernel output against the
# reference via torch.allclose (TODO confirm), so we monkeypatch it with a
# cosine-similarity check (threshold 0.99).  The original implementation is
# kept in _orig_allclose should exact semantics ever be needed again.
_orig_allclose = torch.allclose


def _cosine_allclose(a: torch.Tensor, b: torch.Tensor, **_kw) -> bool:
    """Return True when ``a`` and ``b`` have cosine similarity > 0.99.

    Extra keyword arguments (``rtol``/``atol``/``equal_nan``) are accepted
    only for signature compatibility with ``torch.allclose`` and are
    deliberately ignored.
    """
    flat_a = a.flatten().float().unsqueeze(0)
    flat_b = b.flatten().float().unsqueeze(0)
    return F.cosine_similarity(flat_a, flat_b).item() > 0.99


torch.allclose = _cosine_allclose


def _ref(q, k, v, is_causal=False):
    """Full-precision reference attention via PyTorch's fused SDPA."""
    reference = F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
    return reference


class SageAttentionBenchmark(Benchmark):
    """Benchmarks the SageAttention kernel across several problem sizes.

    Each scenario is a trio of methods: ``setup_*`` allocates the inputs,
    ``benchmark_*`` runs ``self.kernel.sageattn`` on them, and ``verify_*``
    returns the bf16 SDPA reference output for comparison.
    ``self.device`` and ``self.kernel`` are presumably supplied by the
    ``Benchmark`` harness -- confirm against ``kernels.benchmark``.
    """

    # Fixed RNG seed so benchmark runs are reproducible.
    seed: int = 42

    def _make_inputs(self, B: int, H: int, L: int, D: int) -> None:
        """Allocate random bf16 Q/K/V of shape (B, H, L, D) plus an output buffer."""
        shape = (B, H, L, D)
        self.q = torch.randn(shape, dtype=torch.bfloat16, device=self.device)
        self.k = torch.randn(shape, dtype=torch.bfloat16, device=self.device)
        self.v = torch.randn(shape, dtype=torch.bfloat16, device=self.device)
        self.out = torch.empty_like(self.q)

    # --- base: B=2, H=32, L=1024, D=128 ---

    def setup_base(self):
        self._make_inputs(2, 32, 1024, 128)

    def benchmark_base(self):
        self.out = self.kernel.sageattn(self.q, self.k, self.v, tensor_layout="HND")

    def verify_base(self) -> torch.Tensor:
        return _ref(self.q, self.k, self.v)

    # --- causal: B=2, H=32, L=1024, D=128 with causal mask ---

    def setup_causal(self):
        self._make_inputs(2, 32, 1024, 128)

    def benchmark_causal(self):
        self.out = self.kernel.sageattn(
            self.q, self.k, self.v, tensor_layout="HND", is_causal=True
        )

    def verify_causal(self) -> torch.Tensor:
        return _ref(self.q, self.k, self.v, is_causal=True)

    # --- large: B=4, H=32, L=4096, D=128 ---

    def setup_large(self):
        self._make_inputs(4, 32, 4096, 128)

    def benchmark_large(self):
        self.out = self.kernel.sageattn(self.q, self.k, self.v, tensor_layout="HND")

    def verify_large(self) -> torch.Tensor:
        return _ref(self.q, self.k, self.v)

    # --- d64: B=4, H=32, L=2048, D=64 (smaller head dim) ---

    def setup_d64(self):
        self._make_inputs(4, 32, 2048, 64)

    def benchmark_d64(self):
        self.out = self.kernel.sageattn(self.q, self.k, self.v, tensor_layout="HND")

    def verify_d64(self) -> torch.Tensor:
        return _ref(self.q, self.k, self.v)
29 changes: 29 additions & 0 deletions sage-attention/scripts/readme_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# /// script
# dependencies = [
# "numpy",
# "torch",
# "kernels",
# "triton",
# ]
# ///
from pathlib import Path

import torch

# NOTE(review): get_kernel is unused here; kept so the example mirrors the
# hub-loading workflow from the README.
from kernels import get_kernel, get_local_kernel  # noqa: F401

# Setup: fix the RNG seed for reproducibility, then load the locally built
# kernel from ./build.
torch.manual_seed(42)
sage_attention = get_local_kernel(Path("build"), "sage_attention")

print(sage_attention)

# Try calling sageattn to verify there is no duplicate op-registration error.
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda":
    B, H, L, D = 1, 8, 256, 64
    q = torch.randn(B, H, L, D, dtype=torch.bfloat16, device=device)
    k = torch.randn(B, H, L, D, dtype=torch.bfloat16, device=device)
    v = torch.randn(B, H, L, D, dtype=torch.bfloat16, device=device)
    out = sage_attention.sageattn(q, k, v)
    print(f"sageattn output shape: {out.shape}")
else:
    print("No CUDA device available - but kernel loaded without registration errors")
74 changes: 7 additions & 67 deletions sage-attention/torch-ext/sage_attention/sm100_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,31 +28,6 @@
# Low-level ops with torch.compile support (custom_op + register_fake)
# ---------------------------------------------------------------------------

@torch.library.custom_op(
add_op_namespace_prefix("mha_fwd"), mutates_args=(), device_types="cuda"
)
def mha_fwd(
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
sfq: torch.Tensor,
sfk: torch.Tensor,
sfv: torch.Tensor,
delta_s: torch.Tensor,
unpadded_k: int,
out: Optional[torch.Tensor],
softmax_scale: float,
is_causal: bool,
per_block_mean: bool,
is_bf16: bool,
) -> List[torch.Tensor]:
return ops.mha_fwd(
q, k, v, sfq, sfk, sfv, delta_s,
unpadded_k, out, softmax_scale, is_causal,
per_block_mean, is_bf16,
)


@torch.library.register_fake(add_op_namespace_prefix("mha_fwd"))
def mha_fwd_fake(
q: torch.Tensor,
Expand Down Expand Up @@ -86,20 +61,6 @@ def mha_fwd_fake(
return [fake_out, fake_lse]


@torch.library.custom_op(
add_op_namespace_prefix("scaled_fp4_quant"),
mutates_args=("output", "output_sf"),
device_types="cuda",
)
def scaled_fp4_quant(
input: torch.Tensor,
output: torch.Tensor,
output_sf: torch.Tensor,
tensor_layout: int,
) -> None:
ops.scaled_fp4_quant(input, output, output_sf, tensor_layout)


@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant"))
def scaled_fp4_quant_fake(
input: torch.Tensor,
Expand All @@ -110,20 +71,6 @@ def scaled_fp4_quant_fake(
pass


@torch.library.custom_op(
add_op_namespace_prefix("scaled_fp4_quant_permute"),
mutates_args=("output", "output_sf"),
device_types="cuda",
)
def scaled_fp4_quant_permute(
input: torch.Tensor,
output: torch.Tensor,
output_sf: torch.Tensor,
tensor_layout: int,
) -> None:
ops.scaled_fp4_quant_permute(input, output, output_sf, tensor_layout)


@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_permute"))
def scaled_fp4_quant_permute_fake(
input: torch.Tensor,
Expand All @@ -134,20 +81,6 @@ def scaled_fp4_quant_permute_fake(
pass


@torch.library.custom_op(
add_op_namespace_prefix("scaled_fp4_quant_trans"),
mutates_args=("output", "output_sf"),
device_types="cuda",
)
def scaled_fp4_quant_trans(
input: torch.Tensor,
output: torch.Tensor,
output_sf: torch.Tensor,
tensor_layout: int,
) -> None:
ops.scaled_fp4_quant_trans(input, output, output_sf, tensor_layout)


@torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_trans"))
def scaled_fp4_quant_trans_fake(
input: torch.Tensor,
Expand All @@ -158,6 +91,13 @@ def scaled_fp4_quant_trans_fake(
pass


# Direct references to C++ ops (same pattern as sm89_compile.py)
mha_fwd = ops.mha_fwd
scaled_fp4_quant = ops.scaled_fp4_quant
scaled_fp4_quant_permute = ops.scaled_fp4_quant_permute
scaled_fp4_quant_trans = ops.scaled_fp4_quant_trans


# ---------------------------------------------------------------------------
# Triton kernel for grouped mean subtraction
# ---------------------------------------------------------------------------
Expand Down
Loading