diff --git a/sage-attention/benchmarks/benchmark.py b/sage-attention/benchmarks/benchmark.py
new file mode 100644
index 00000000..e4a3d552
--- /dev/null
+++ b/sage-attention/benchmarks/benchmark.py
@@ -0,0 +1,90 @@
+import torch
+import torch.nn.functional as F
+
+from kernels.benchmark import Benchmark
+
+# SageAttention is approximate (INT8 quantized QK) so element-wise allclose
+# is too strict. Use cosine similarity instead (threshold 0.99).
+_COSINE_THRESHOLD = 0.99
+_orig_allclose = torch.allclose
+
+
+def _cosine_allclose(a, b, *_args, **_kwargs):
+    """Return True when cosine similarity of `a` and `b` exceeds the threshold.
+
+    Stands in for ``torch.allclose``: extra tolerance arguments are accepted
+    positionally or by keyword and ignored.
+    """
+    similarity = F.cosine_similarity(
+        a.flatten().float().unsqueeze(0),
+        b.flatten().float().unsqueeze(0),
+    )
+    return similarity.item() > _COSINE_THRESHOLD
+
+
+# NOTE: this swaps torch.allclose for the whole process, not just for this
+# benchmark; the original implementation stays reachable as _orig_allclose.
+torch.allclose = _cosine_allclose
+
+
+def _ref(q, k, v, is_causal=False):
+    """Return the reference output from F.scaled_dot_product_attention."""
+    return F.scaled_dot_product_attention(q, k, v, is_causal=is_causal)
+
+
+def _init_qkv(bench, B, H, L, D):
+    """Fill bench.q/k/v with random bfloat16 tensors and allocate bench.out."""
+    bench.q = torch.randn(B, H, L, D, dtype=torch.bfloat16, device=bench.device)
+    bench.k = torch.randn(B, H, L, D, dtype=torch.bfloat16, device=bench.device)
+    bench.v = torch.randn(B, H, L, D, dtype=torch.bfloat16, device=bench.device)
+    bench.out = torch.empty_like(bench.q)
+
+
+class SageAttentionBenchmark(Benchmark):
+    seed: int = 42
+
+    # --- base: B=2, H=32, L=1024, D=128 ---
+
+    def setup_base(self):
+        _init_qkv(self, B=2, H=32, L=1024, D=128)
+
+    def benchmark_base(self):
+        self.out = self.kernel.sageattn(self.q, self.k, self.v, tensor_layout="HND")
+
+    def verify_base(self) -> torch.Tensor:
+        return _ref(self.q, self.k, self.v)
+
+    # --- causal: B=2, H=32, L=1024, D=128 with causal mask ---
+
+    def setup_causal(self):
+        _init_qkv(self, B=2, H=32, L=1024, D=128)
+
+    def benchmark_causal(self):
+        self.out = self.kernel.sageattn(
+            self.q, self.k, self.v, tensor_layout="HND", is_causal=True
+        )
+
+    def verify_causal(self) -> torch.Tensor:
+        return _ref(self.q, self.k, self.v, is_causal=True)
+
+    # --- large: B=4, H=32, L=4096, D=128 ---
+
+    def setup_large(self):
+        _init_qkv(self, B=4, H=32, L=4096, D=128)
+
+    def benchmark_large(self):
+        self.out = self.kernel.sageattn(self.q, self.k, self.v, tensor_layout="HND")
+
+    def verify_large(self) -> torch.Tensor:
+        return _ref(self.q, self.k, self.v)
+
+    # --- d64: B=4, H=32, L=2048, D=64 (smaller head dim) ---
+
+    def setup_d64(self):
+        _init_qkv(self, B=4, H=32, L=2048, D=64)
+
+    def benchmark_d64(self):
+        self.out = self.kernel.sageattn(self.q, self.k, self.v, tensor_layout="HND")
+
+    def verify_d64(self) -> torch.Tensor:
+        return _ref(self.q, self.k, self.v)
diff --git a/sage-attention/scripts/readme_example.py b/sage-attention/scripts/readme_example.py
new file mode 100644
index 00000000..1454d555
--- /dev/null
+++ b/sage-attention/scripts/readme_example.py
@@ -0,0 +1,29 @@
+# /// script
+# dependencies = [
+#   "numpy",
+#   "torch",
+#   "kernels",
+#   "triton",
+# ]
+# ///
+import torch
+from kernels import get_kernel, get_local_kernel
+from pathlib import Path
+
+# Setup
+torch.manual_seed(42)
+sage_attention = get_local_kernel(Path("build"), "sage_attention")
+
+print(sage_attention)
+
+# Try calling sageattn to verify no duplicate registration error
+device = "cuda" if torch.cuda.is_available() else "cpu"
+if device == "cuda":
+    B, H, L, D = 1, 8, 256, 64
+    q = torch.randn(B, H, L, D, dtype=torch.bfloat16, device=device)
+    k = torch.randn(B, H, L, D, dtype=torch.bfloat16, device=device)
+    v = torch.randn(B, H, L, D, dtype=torch.bfloat16, device=device)
+    out = sage_attention.sageattn(q, k, v)
+    print(f"sageattn output shape: {out.shape}")
+else:
+    print("No CUDA device available - but kernel loaded without registration errors")
diff --git a/sage-attention/torch-ext/sage_attention/sm100_compile.py b/sage-attention/torch-ext/sage_attention/sm100_compile.py
index 4a4aa996..874225fe 100644
--- a/sage-attention/torch-ext/sage_attention/sm100_compile.py
+++ b/sage-attention/torch-ext/sage_attention/sm100_compile.py
@@ -28,31 +28,6 @@
 # Low-level ops with torch.compile support (custom_op + register_fake)
 # ---------------------------------------------------------------------------
 
-@torch.library.custom_op(
-    add_op_namespace_prefix("mha_fwd"), mutates_args=(), device_types="cuda"
-)
-def mha_fwd(
-    q: torch.Tensor,
-    k: torch.Tensor,
-    v: torch.Tensor,
-    sfq: torch.Tensor,
-    sfk: torch.Tensor,
-    sfv: torch.Tensor,
-    delta_s: torch.Tensor,
-    unpadded_k: int,
-    out: Optional[torch.Tensor],
-    softmax_scale: float,
-    is_causal: bool,
-    per_block_mean: bool,
-    is_bf16: bool,
-) -> List[torch.Tensor]:
-    return ops.mha_fwd(
-        q, k, v, sfq, sfk, sfv, delta_s,
-        unpadded_k, out, softmax_scale, is_causal,
-        per_block_mean, is_bf16,
-    )
-
-
 @torch.library.register_fake(add_op_namespace_prefix("mha_fwd"))
 def mha_fwd_fake(
     q: torch.Tensor,
@@ -86,20 +61,6 @@ def mha_fwd_fake(
     return [fake_out, fake_lse]
 
 
-@torch.library.custom_op(
-    add_op_namespace_prefix("scaled_fp4_quant"),
-    mutates_args=("output", "output_sf"),
-    device_types="cuda",
-)
-def scaled_fp4_quant(
-    input: torch.Tensor,
-    output: torch.Tensor,
-    output_sf: torch.Tensor,
-    tensor_layout: int,
-) -> None:
-    ops.scaled_fp4_quant(input, output, output_sf, tensor_layout)
-
-
 @torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant"))
 def scaled_fp4_quant_fake(
     input: torch.Tensor,
@@ -110,20 +71,6 @@ def scaled_fp4_quant_fake(
     pass
 
 
-@torch.library.custom_op(
-    add_op_namespace_prefix("scaled_fp4_quant_permute"),
-    mutates_args=("output", "output_sf"),
-    device_types="cuda",
-)
-def scaled_fp4_quant_permute(
-    input: torch.Tensor,
-    output: torch.Tensor,
-    output_sf: torch.Tensor,
-    tensor_layout: int,
-) -> None:
-    ops.scaled_fp4_quant_permute(input, output, output_sf, tensor_layout)
-
-
 @torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_permute"))
 def scaled_fp4_quant_permute_fake(
     input: torch.Tensor,
@@ -134,20 +81,6 @@ def scaled_fp4_quant_permute_fake(
     pass
 
 
-@torch.library.custom_op(
-    add_op_namespace_prefix("scaled_fp4_quant_trans"),
-    mutates_args=("output", "output_sf"),
-    device_types="cuda",
-)
-def scaled_fp4_quant_trans(
-    input: torch.Tensor,
-    output: torch.Tensor,
-    output_sf: torch.Tensor,
-    tensor_layout: int,
-) -> None:
-    ops.scaled_fp4_quant_trans(input, output, output_sf, tensor_layout)
-
-
 @torch.library.register_fake(add_op_namespace_prefix("scaled_fp4_quant_trans"))
 def scaled_fp4_quant_trans_fake(
     input: torch.Tensor,
@@ -158,6 +91,13 @@ def scaled_fp4_quant_trans_fake(
     pass
 
 
+# Direct references to C++ ops (same pattern as sm89_compile.py)
+mha_fwd = ops.mha_fwd
+scaled_fp4_quant = ops.scaled_fp4_quant
+scaled_fp4_quant_permute = ops.scaled_fp4_quant_permute
+scaled_fp4_quant_trans = ops.scaled_fp4_quant_trans
+
+
 # ---------------------------------------------------------------------------
 # Triton kernel for grouped mean subtraction
 # ---------------------------------------------------------------------------