diff --git a/benchmarks/gemm/benchmark_gemm.py b/benchmarks/gemm/benchmark_gemm.py new file mode 100644 index 0000000000..5b83b0b8c1 --- /dev/null +++ b/benchmarks/gemm/benchmark_gemm.py @@ -0,0 +1,1609 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + + +"""Unified GEMM benchmark for BF16, FP8 Block, MXFP8, and NVFP4 precisions. + +Compares matrix-multiplication throughput across precisions using +Transformer Engine on NVIDIA GPUs. Supports two timing back-ends, +pre-quantized and autocast quantization modes, arbitrary MxKxN matrix +shapes, Nsight Systems profiling integration, and bar-chart output. + +Timing back-ends +---------------- +* **cuda-events** -- CUDA event pairs with a leading-kernel trick to + hide CPU dispatch latency. Measures the full GPU-side duration of + the timed loop (includes quantization when using autocast mode). +* **profiler** -- ``torch.profiler`` (CUPTI) kernel timestamps. + Only the matched GEMM compute kernels (names containing gemm, nvjet, xmma, or cutlass) + are summed, giving a kernel-only measurement. 
+ +Usage examples:: + + # Kernel-only timing via torch.profiler: + python benchmarks/gemm/benchmark_gemm.py --timing profiler --pre-quantize -o kernel.png + + # End-to-end timing via CUDA events: + python benchmarks/gemm/benchmark_gemm.py --timing cuda-events -o e2e.png + + # Custom non-square shapes: + python benchmarks/gemm/benchmark_gemm.py --shapes 88064x2560x10240,88064x10240x2560 + + # Nsight profiling of a single shape: + nsys profile --capture-range=cudaProfilerApi \\ + python benchmarks/gemm/benchmark_gemm.py --profile --profile-shape 4096 + + # Model config mode (derives all 12 GEMM shapes from hyperparameters): + python benchmarks/gemm/benchmark_gemm.py \\ + --hidden_size 4096 --intermediate_size 16384 \\ + --num_attention_heads 32 --num_hidden_layers 24 \\ + --micro_batch_size 31 --sequence_length 512 +""" + +import argparse +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Optional + +import matplotlib.pyplot as plt +import numpy as np +import torch +from torch.profiler import ProfilerActivity, profile + +try: + import transformer_engine.pytorch as te + import transformer_engine_torch as tex + from transformer_engine.common.recipe import ( + Float8BlockScaling, + Format, + MXFP8BlockScaling, + NVFP4BlockScaling, + ) + + TE_AVAILABLE = True +except ImportError: + TE_AVAILABLE = False + + +GEMM_KERNEL_PATTERNS = ("gemm", "nvjet", "xmma", "cutlass") + +PRECISION_COLORS = { + "BF16": "#808080", + "FP8Block": "#006400", + "MXFP8": "#4B0082", + "NVFP4": "#B22222", +} + + +# --------------------------------------------------------------------------- +# Data structures +# --------------------------------------------------------------------------- +@dataclass +class GEMMResult: + """Single GEMM benchmark measurement.""" + + tflops: float + avg_time_ms: float + shape: tuple[int, int, int] + precision: str + + +@dataclass +class ModelConfig: + """Transformer model hyperparameters for GEMM shape derivation.""" + + 
hidden_size: int + intermediate_size: int + num_attention_heads: int + num_hidden_layers: int + micro_batch_size: int + sequence_length: int + + +# --------------------------------------------------------------------------- +# Hardware helpers +# --------------------------------------------------------------------------- +def is_blackwell_available() -> bool: + """Return True when the current device is Blackwell (SM100+) for NVFP4 support.""" + if not torch.cuda.is_available(): + return False + major, _ = torch.cuda.get_device_capability() + return major >= 10 + + +def compute_gemm_flops(M: int, K: int, N: int) -> int: + """Theoretical FLOP count for C = A @ B: 2 * M * N * K.""" + return 2 * M * N * K + + +# --------------------------------------------------------------------------- +# torch.profiler helpers (kernel-only timing) +# --------------------------------------------------------------------------- +def _is_gemm_kernel(name: str) -> bool: + """Return True when *name* looks like a GEMM compute kernel.""" + low = name.lower() + return any(p in low for p in GEMM_KERNEL_PATTERNS) + + +def _extract_gemm_kernel_time_us( + prof_result: profile, + num_iters: int, + verbose: bool = False, +) -> float: + """Average GEMM-kernel time in microseconds from profiler events.""" + total_us = 0.0 + count = 0 + seen: dict[str, float] = {} + + for evt in prof_result.events(): + if evt.device_type == torch.autograd.DeviceType.CUDA and _is_gemm_kernel(evt.name): + total_us += evt.device_time + count += 1 + seen[evt.name] = seen.get(evt.name, 0.0) + evt.device_time + + if verbose and seen: + print(f" Matched GEMM kernels ({count} invocations):") + for kname, kus in seen.items(): + print(f" {kname}: {kus:.0f} us total") + + if count == 0: + if verbose: + print(" WARNING: No GEMM kernels found. 
All CUDA events:") + for evt in prof_result.events(): + if evt.device_type == torch.autograd.DeviceType.CUDA: + print(f" {evt.name}: {evt.device_time:.0f} us") + return 0.0 + + return total_us / num_iters + + +# --------------------------------------------------------------------------- +# Timing wrappers +# --------------------------------------------------------------------------- +def _time_with_profiler( + run_fn, + num_iters: int, + flops: int, + verbose: bool = False, +) -> tuple[float, float]: + """Return (tflops, avg_ms) using torch.profiler kernel extraction.""" + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + for _ in range(num_iters): + run_fn() + torch.cuda.synchronize() + + avg_us = _extract_gemm_kernel_time_us(prof, num_iters, verbose=verbose) + avg_s = avg_us / 1e6 + tflops = (flops / avg_s) / 1e12 if avg_s > 0 else 0.0 + return tflops, avg_us / 1000.0 + + +def _time_with_cuda_events( + run_fn, + num_iters: int, + flops: int, + leading_fn=None, +) -> tuple[float, float]: + """Return (tflops, avg_ms) using CUDA events with optional leading kernel.""" + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + + if leading_fn is not None: + leading_fn() + + start.record() + for _ in range(num_iters): + run_fn() + end.record() + torch.cuda.synchronize() + + avg_ms = start.elapsed_time(end) / num_iters + avg_s = avg_ms / 1000.0 + tflops = (flops / avg_s) / 1e12 if avg_s > 0 else 0.0 + return tflops, avg_ms + + +# --------------------------------------------------------------------------- +# BF16 benchmark +# --------------------------------------------------------------------------- +def benchmark_bf16( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100, + timing: str = "cuda-events", + verbose: bool = False, +) -> GEMMResult: + """Benchmark BF16 torch.matmul.""" + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + A = torch.randn(M, K, 
dtype=torch.bfloat16, device=device) + B = torch.randn(K, N, dtype=torch.bfloat16, device=device) + + for _ in range(num_warmup): + torch.matmul(A, B) + torch.cuda.synchronize() + + def _run(): + torch.matmul(A, B) + + if timing == "profiler": + tflops, avg_ms = _time_with_profiler(_run, num_iters, flops, verbose=verbose) + else: + A_lg = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + B_lg = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + tflops, avg_ms = _time_with_cuda_events( + _run, num_iters, flops, leading_fn=lambda: torch.matmul(A_lg, B_lg) + ) + del A_lg, B_lg + + return GEMMResult(tflops=tflops, avg_time_ms=avg_ms, shape=(M, K, N), precision="BF16") + + +# --------------------------------------------------------------------------- +# MXFP8 benchmarks +# --------------------------------------------------------------------------- +def benchmark_fp8( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100, + timing: str = "cuda-events", + verbose: bool = False, +) -> Optional[GEMMResult]: + """MXFP8 GEMM via te.Linear autocast.""" + if not TE_AVAILABLE: + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) + x = torch.randn(M, K, dtype=torch.bfloat16, device=device) + recipe = MXFP8BlockScaling(fp8_format=Format.E4M3) + + with te.autocast(enabled=True, recipe=recipe): + for _ in range(num_warmup): + linear(x) + torch.cuda.synchronize() + + def _run(): + linear(x) + + if timing == "profiler": + tflops, avg_ms = _time_with_profiler(_run, num_iters, flops, verbose=verbose) + else: + lin_lg = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device) + x_lg = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + tflops, avg_ms = _time_with_cuda_events( + _run, num_iters, flops, leading_fn=lambda: lin_lg(x_lg) + ) + del lin_lg, x_lg + + return GEMMResult(tflops=tflops, 
avg_time_ms=avg_ms, shape=(M, K, N), precision="MXFP8") + + +def benchmark_fp8_prequantized( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100, + timing: str = "cuda-events", + verbose: bool = False, +) -> Optional[GEMMResult]: + """Pre-quantized MXFP8 GEMM via tex.generic_gemm (raw kernel throughput).""" + if not TE_AVAILABLE: + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + try: + quantizer = te.MXFP8Quantizer(tex.DType.kFloat8E4M3) + + # tex.generic_gemm uses column-major convention: A=(K,M), B=(K,N), + # D=(N,M) with transa=False, transb=True for a logical C(M,N) GEMM. + A_q = quantizer.quantize(torch.randn(K, M, dtype=torch.bfloat16, device=device)) + B_q = quantizer.quantize(torch.randn(K, N, dtype=torch.bfloat16, device=device)) + D = torch.empty(N, M, dtype=torch.bfloat16, device=device) + ws_size = 32 * 1024 * 1024 + ws = torch.empty(ws_size, dtype=torch.uint8, device=device) + + def _run(): + tex.generic_gemm( + A_q, + False, + B_q, + True, + D, + None, + tex.DType.kBFloat16, + None, + tex.DType.kBFloat16, + False, + None, + False, + ws, + ws_size, + False, + False, + ) + + for _ in range(num_warmup): + _run() + torch.cuda.synchronize() + + if timing == "profiler": + tflops, avg_ms = _time_with_profiler(_run, num_iters, flops, verbose=verbose) + else: + A_lg_q = quantizer.quantize( + torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + ) + B_lg_q = quantizer.quantize( + torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + ) + D_lg = torch.empty(4096, 4096, dtype=torch.bfloat16, device=device) + + def _lead(): + tex.generic_gemm( + A_lg_q, + False, + B_lg_q, + True, + D_lg, + None, + tex.DType.kBFloat16, + None, + tex.DType.kBFloat16, + False, + None, + False, + ws, + ws_size, + False, + False, + ) + + tflops, avg_ms = _time_with_cuda_events(_run, num_iters, flops, leading_fn=_lead) + del A_lg_q, B_lg_q, D_lg + + return GEMMResult(tflops=tflops, avg_time_ms=avg_ms, 
shape=(M, K, N), precision="MXFP8") + except Exception as e: + print(f"Warning: FP8 prequantized benchmark failed: {e}") + return None + + +# --------------------------------------------------------------------------- +# Float8 Block-Scaling benchmarks +# --------------------------------------------------------------------------- +def benchmark_fp8_block( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100, + timing: str = "cuda-events", + verbose: bool = False, +) -> Optional[GEMMResult]: + """FP8 GEMM with Float8BlockScaling recipe via te.Linear autocast.""" + if not TE_AVAILABLE: + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) + x = torch.randn(M, K, dtype=torch.bfloat16, device=device) + recipe = Float8BlockScaling(fp8_format=Format.E4M3) + + with te.autocast(enabled=True, recipe=recipe): + for _ in range(num_warmup): + linear(x) + torch.cuda.synchronize() + + def _run(): + linear(x) + + if timing == "profiler": + tflops, avg_ms = _time_with_profiler(_run, num_iters, flops, verbose=verbose) + else: + lin_lg = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device) + x_lg = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + tflops, avg_ms = _time_with_cuda_events( + _run, num_iters, flops, leading_fn=lambda: lin_lg(x_lg) + ) + del lin_lg, x_lg + + return GEMMResult(tflops=tflops, avg_time_ms=avg_ms, shape=(M, K, N), precision="FP8Block") + + +def benchmark_fp8_block_prequantized( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100, + timing: str = "cuda-events", + verbose: bool = False, +) -> Optional[GEMMResult]: + """Pre-quantized FP8 GEMM with Float8BlockScaling via tex.generic_gemm.""" + if not TE_AVAILABLE: + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + try: + quantizer = te.Float8BlockQuantizer( + 
tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=True, + ) + + A_q = quantizer.quantize(torch.randn(K, M, dtype=torch.bfloat16, device=device)) + B_q = quantizer.quantize(torch.randn(K, N, dtype=torch.bfloat16, device=device)) + D = torch.empty(N, M, dtype=torch.bfloat16, device=device) + ws_size = 32 * 1024 * 1024 + ws = torch.empty(ws_size, dtype=torch.uint8, device=device) + + def _run(): + tex.generic_gemm( + A_q, + False, + B_q, + True, + D, + None, + tex.DType.kBFloat16, + None, + tex.DType.kBFloat16, + False, + None, + False, + ws, + ws_size, + False, + False, + ) + + for _ in range(num_warmup): + _run() + torch.cuda.synchronize() + + if timing == "profiler": + tflops, avg_ms = _time_with_profiler(_run, num_iters, flops, verbose=verbose) + else: + A_lg_q = quantizer.quantize( + torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + ) + B_lg_q = quantizer.quantize( + torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + ) + D_lg = torch.empty(4096, 4096, dtype=torch.bfloat16, device=device) + + def _lead(): + tex.generic_gemm( + A_lg_q, + False, + B_lg_q, + True, + D_lg, + None, + tex.DType.kBFloat16, + None, + tex.DType.kBFloat16, + False, + None, + False, + ws, + ws_size, + False, + False, + ) + + tflops, avg_ms = _time_with_cuda_events(_run, num_iters, flops, leading_fn=_lead) + del A_lg_q, B_lg_q, D_lg + + return GEMMResult(tflops=tflops, avg_time_ms=avg_ms, shape=(M, K, N), precision="FP8Block") + except Exception as e: + print(f"Warning: FP8 Block-Scaling prequantized benchmark failed: {e}") + return None + + +# --------------------------------------------------------------------------- +# NVFP4 benchmarks (Blackwell SM100+ only) +# --------------------------------------------------------------------------- +def benchmark_fp4( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100, + timing: str = "cuda-events", + verbose: bool = False, +) -> Optional[GEMMResult]: + """NVFP4 GEMM via te.Linear autocast 
(Blackwell only).""" + if not TE_AVAILABLE or not is_blackwell_available(): + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + linear = te.Linear(K, N, bias=False, params_dtype=torch.bfloat16).to(device) + x = torch.randn(M, K, dtype=torch.bfloat16, device=device) + recipe = NVFP4BlockScaling(fp4_format=Format.E2M1) + + with te.autocast(enabled=True, recipe=recipe): + for _ in range(num_warmup): + linear(x) + torch.cuda.synchronize() + + def _run(): + linear(x) + + if timing == "profiler": + tflops, avg_ms = _time_with_profiler(_run, num_iters, flops, verbose=verbose) + else: + lin_lg = te.Linear(4096, 4096, bias=False, params_dtype=torch.bfloat16).to(device) + x_lg = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + tflops, avg_ms = _time_with_cuda_events( + _run, num_iters, flops, leading_fn=lambda: lin_lg(x_lg) + ) + del lin_lg, x_lg + + return GEMMResult(tflops=tflops, avg_time_ms=avg_ms, shape=(M, K, N), precision="NVFP4") + + +def benchmark_fp4_prequantized( + M: int, + K: int, + N: int, + num_warmup: int = 10, + num_iters: int = 100, + timing: str = "cuda-events", + verbose: bool = False, +) -> Optional[GEMMResult]: + """Pre-quantized NVFP4 GEMM via tex.generic_gemm (Blackwell only).""" + if not TE_AVAILABLE or not is_blackwell_available(): + return None + + device = torch.device("cuda") + flops = compute_gemm_flops(M, K, N) + + try: + quantizer = te.NVFP4Quantizer(tex.DType.kFloat4E2M1) + + # tex.generic_gemm uses column-major convention: A=(K,M), B=(K,N), + # D=(N,M) with transa=False, transb=True for a logical C(M,N) GEMM. 
+ A_q = quantizer.quantize(torch.randn(K, M, dtype=torch.bfloat16, device=device)) + B_q = quantizer.quantize(torch.randn(K, N, dtype=torch.bfloat16, device=device)) + D = torch.empty(N, M, dtype=torch.bfloat16, device=device) + ws_size = 32 * 1024 * 1024 + ws = torch.empty(ws_size, dtype=torch.uint8, device=device) + + def _run(): + tex.generic_gemm( + A_q, + False, + B_q, + True, + D, + None, + tex.DType.kBFloat16, + None, + tex.DType.kBFloat16, + False, + None, + False, + ws, + ws_size, + False, + False, + ) + + for _ in range(num_warmup): + _run() + torch.cuda.synchronize() + + if timing == "profiler": + tflops, avg_ms = _time_with_profiler(_run, num_iters, flops, verbose=verbose) + else: + A_lg_q = quantizer.quantize( + torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + ) + B_lg_q = quantizer.quantize( + torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + ) + D_lg = torch.empty(4096, 4096, dtype=torch.bfloat16, device=device) + + def _lead(): + tex.generic_gemm( + A_lg_q, + False, + B_lg_q, + True, + D_lg, + None, + tex.DType.kBFloat16, + None, + tex.DType.kBFloat16, + False, + None, + False, + ws, + ws_size, + False, + False, + ) + + tflops, avg_ms = _time_with_cuda_events(_run, num_iters, flops, leading_fn=_lead) + del A_lg_q, B_lg_q, D_lg + + return GEMMResult(tflops=tflops, avg_time_ms=avg_ms, shape=(M, K, N), precision="NVFP4") + except Exception as e: + print(f"Warning: FP4 prequantized benchmark failed: {e}") + return None + + +# --------------------------------------------------------------------------- +# Shape helpers +# --------------------------------------------------------------------------- +def get_default_shapes() -> list[tuple[int, int, int]]: + """Default set of square matrix shapes for benchmarking.""" + return [ + (256, 256, 256), + (512, 512, 512), + (768, 768, 768), + (1024, 1024, 1024), + (1536, 1536, 1536), + (2048, 2048, 2048), + (3072, 3072, 3072), + (4096, 4096, 4096), + (6144, 6144, 6144), + (8192, 8192, 
8192), + (16384, 16384, 16384), + ] + + +def parse_shapes_arg(shapes_arg: str) -> list[tuple[int, int, int]]: + """Parse ``--shapes`` into a list of (M, K, N) tuples. + + Accepts either square sizes (``1024,2048,4096``) or explicit + triplets (``8192x5120x10240,8192x10240x5120``), or a mix. + + Raises: + ValueError: On malformed input. + """ + items = [s.strip() for s in shapes_arg.split(",") if s.strip()] + if not items: + raise ValueError("Empty --shapes argument.") + + shapes: list[tuple[int, int, int]] = [] + for item in items: + if "x" in item: + parts = [p.strip() for p in item.lower().split("x")] + if len(parts) != 3: + raise ValueError(f"Invalid shape '{item}'. Expected 'MxKxN'.") + shapes.append((int(parts[0]), int(parts[1]), int(parts[2]))) + else: + size = int(item) + shapes.append((size, size, size)) + return shapes + + +def compute_gemm_shapes( + config: ModelConfig, +) -> tuple[ + list[tuple[str, int, int, int]], + list[tuple[str, int, int, int]], + list[tuple[str, int, int, int]], +]: + """Derive Fprop, Dgrad, and Wgrad GEMM shapes from a transformer model config. + + For forward Y = X @ W with shape (M, K, N): + - Dgrad: dX = dY @ Wᵀ → (M, N, K) (K and N swap) + - Wgrad: dW = Xᵀ @ dY → (K, M, N) (M moves to contraction axis) + + Returns: + (fprop_shapes, dgrad_shapes, wgrad_shapes) where each is a list of + (label, M, K, N) tuples. 
+ """ + H = config.hidden_size + I = config.intermediate_size + M = config.micro_batch_size * config.sequence_length + + if H % config.num_attention_heads != 0: + raise ValueError( + f"hidden_size ({H}) must be divisible by " + f"num_attention_heads ({config.num_attention_heads})" + ) + + N_qkv = 3 * H + + fprop_shapes = [ + ("QKV Proj", M, H, N_qkv), + ("Attn Out", M, H, H), + ("MLP Up", M, H, I), + ("MLP Down", M, I, H), + ] + + dgrad_shapes = [ + ("QKV Proj (Dgrad)", M, N_qkv, H), + ("Attn Out (Dgrad)", M, H, H), + ("MLP Up (Dgrad)", M, I, H), + ("MLP Down (Dgrad)", M, H, I), + ] + + wgrad_shapes = [ + ("QKV Proj (Wgrad)", H, M, N_qkv), + ("Attn Out (Wgrad)", H, M, H), + ("MLP Up (Wgrad)", H, M, I), + ("MLP Down (Wgrad)", I, M, H), + ] + + return fprop_shapes, dgrad_shapes, wgrad_shapes + + +# --------------------------------------------------------------------------- +# GPU warmup +# --------------------------------------------------------------------------- +def warmup_gpu(duration_seconds: float = 5.0) -> None: + """Run sustained matmuls to stabilize GPU clocks before benchmarking.""" + print(f"Warming up GPU for {duration_seconds:.1f} seconds...") + device = torch.device("cuda") + A = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + B = torch.randn(4096, 4096, dtype=torch.bfloat16, device=device) + + torch.cuda.synchronize() + t0 = time.time() + while time.time() - t0 < duration_seconds: + for _ in range(10): + torch.matmul(A, B) + torch.cuda.synchronize() + + del A, B + torch.cuda.empty_cache() + print("GPU warmup complete.\n") + + +# --------------------------------------------------------------------------- +# Main orchestrator +# --------------------------------------------------------------------------- +def run_benchmarks( + shapes: list[tuple[int, int, int]], + num_warmup: int = 10, + num_iters: int = 100, + include_fp8: bool = True, + include_fp4: bool = True, + gpu_warmup_seconds: float = 5.0, + pre_quantize: bool = False, + timing: 
str = "cuda-events", + profile_shape: Optional[int] = None, +) -> dict[str, list[float]]: + """Run GEMM benchmarks for every shape and enabled precision. + + Returns: + Dict mapping precision name to a list of TFLOPS values, one per shape. + """ + results: dict[str, list[float]] = {"BF16": [], "MXFP8": [], "NVFP4": []} + time_results: dict[str, list[float]] = {"BF16": [], "MXFP8": [], "NVFP4": []} + + has_blackwell = is_blackwell_available() + run_fp8 = include_fp8 and TE_AVAILABLE + run_fp4 = include_fp4 and TE_AVAILABLE and has_blackwell + + gpu_name = torch.cuda.get_device_name(0) + timing_label = ( + "torch.profiler (CUPTI kernel timestamps)" if timing == "profiler" else "CUDA events" + ) + + print(f"\nGEMM Benchmark on {gpu_name}") + print(f"Timing method: {timing_label}") + print(f"Warmup iterations: {num_warmup}, Timed iterations: {num_iters}") + if pre_quantize: + print("Mode: Pre-quantized inputs (raw kernel throughput)") + else: + print("Mode: Autocast (includes quantization overhead)") + if not has_blackwell and include_fp4: + print("Note: NVFP4 requires Blackwell (SM100+), skipping FP4 benchmarks") + + if profile_shape is not None: + shapes = [(profile_shape, profile_shape, profile_shape)] + print(f"\n*** PROFILING MODE: shape {profile_shape}x{profile_shape}x{profile_shape} ***") + print( + "*** Run with: nsys profile --capture-range=cudaProfilerApi python