diff --git a/scripts/check_kernel_freshness.py b/scripts/check_kernel_freshness.py index 3da2d22c..2be76b46 100644 --- a/scripts/check_kernel_freshness.py +++ b/scripts/check_kernel_freshness.py @@ -36,6 +36,7 @@ "rwkv": "https://github.com/BlinkDL/RWKV-LM", "scattermoe": "https://github.com/shawntan/scattermoe", "sgl-flash-attn3": "https://github.com/sgl-project/sgl-flash-attn", + "sonic-moe": "https://github.com/Dao-AILab/sonic-moe", "tinygrad-rms": "https://github.com/tinygrad/tinygrad", "trimul-gpumode": "https://github.com/davidberard98/gpumode-trimul", "triton-kernels": "https://github.com/triton-lang/triton.git", diff --git a/sonic-moe/README.md b/sonic-moe/README.md new file mode 100644 index 00000000..48eefe78 --- /dev/null +++ b/sonic-moe/README.md @@ -0,0 +1,57 @@ +--- +tags: +- kernels +- moe +- cuda +--- + +# SonicMoE + +Accelerating Mixture-of-Experts with IO and Tile-aware Optimizations. + +**SonicMoE** is a blazing-fast MoE implementation optimized for NVIDIA Hopper and Blackwell GPUs. +It leverages CuTe-DSL and Triton to deliver state-of-the-art performance through IO-aware optimizations. + +- Paper: [arXiv:2512.14080](https://arxiv.org/abs/2512.14080) +- Source: [Dao-AILab/sonic-moe](https://github.com/Dao-AILab/sonic-moe) + +## Requirements + +- NVIDIA Hopper GPUs (H100, H200) or Blackwell GPUs (GB200, B200) +- PyTorch >= 2.7 +- CUDA 12.9+ +- Python 3.12+ + +## Usage + +```python +import torch +from kernels import get_kernel + +sonicmoe = get_kernel("kernels-community/sonic-moe") + +from sonicmoe import MoE, KernelBackendMoE +from sonicmoe.enums import ActivationType + +moe = MoE( + num_experts=128, + num_experts_per_tok=8, + hidden_size=4096, + intermediate_size=1536, + activation_function=ActivationType.SWIGLU, + add_bias=False, + std=0.02, +).to(device="cuda", dtype=torch.bfloat16) + +x = torch.randn(32768, 4096, device="cuda", dtype=torch.bfloat16) +output, aux_loss = moe(x, kernel_backend_moe=KernelBackendMoE.sonicmoe) +``` + +## Vendored Dependencies + +This kernel vendors [QuACK](https://github.com/Dao-AILab/quack) (quack-kernels) for CuTe-DSL +GEMM infrastructure. The vendored copy is located at `torch-ext/sonicmoe/quack/`. 
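+The QuACK-backed GEMM path is opt-in. The sketch below is based on the usage in
+`tests/test_moe.py` and reuses `moe`, `x`, and `KernelBackendMoE` from the example
+above; on this path the kernels currently expect a GLU activation and no
+down-projection bias.
+
+```python
+from sonicmoe import enable_quack_gemm
+
+# Route the grouped GEMMs through the vendored QuACK CuTe-DSL kernels.
+with enable_quack_gemm(True):
+    output, aux_loss = moe(x, kernel_backend_moe=KernelBackendMoE.sonicmoe)
+```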
+ +## License + +Apache-2.0 (SonicMoE and QuACK are both Apache-2.0 licensed) diff --git a/sonic-moe/build.toml b/sonic-moe/build.toml new file mode 100644 index 00000000..19bf9a03 --- /dev/null +++ b/sonic-moe/build.toml @@ -0,0 +1,14 @@ +[general] +name = "sonicmoe" +license = "Apache-2.0" +backends = ["cuda"] +version = 1 + +[general.hub] +repo-id = "kernels-community/sonic-moe" + +[general.cuda] +minver = "12.8" +python-depends = ["nvidia-cutlass-dsl"] + +[kernel] diff --git a/sonic-moe/flake.lock b/sonic-moe/flake.lock new file mode 100644 index 00000000..9032b93f --- /dev/null +++ b/sonic-moe/flake.lock @@ -0,0 +1,117 @@ +{ + "nodes": { + "flake-compat": { + "locked": { + "lastModified": 1767039857, + "narHash": "sha256-vNpUSpF5Nuw8xvDLj2KCwwksIbjua2LZCqhV1LNRDns=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "5edf11c44bc78a0d334f6334cdaf7d60d732daab", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "kernel-builder": { + "inputs": { + "flake-compat": "flake-compat", + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs", + "rust-overlay": "rust-overlay" + }, + "locked": { + "lastModified": 1775239449, + "narHash": "sha256-p5gimOf9ErDZb9OwMbDJuGWqFdnu9JYmqCQkHNvmK18=", + "owner": "huggingface", + "repo": "kernels", + "rev": "205a554bb3421a1aa67e941fe94d4418fa8bf0be", + "type": "github" + }, + "original": { + "owner": "huggingface", + "repo": "kernels", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1774935083, + "narHash": "sha256-Mh6bLcYAcENBAZk3RoMPMFCGGMZmfaGMERE4siZOgP4=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "2f4fd5e1abf9bac8c1d22750c701a7a5e6b524c6", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-unstable-small", + "repo": "nixpkgs", + "type": "github" + } + }, + "root": { + "inputs": { + "kernel-builder": "kernel-builder" + } + }, + "rust-overlay": { + "inputs": { + "nixpkgs": [ + "kernel-builder", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1774926780, + "narHash": "sha256-JMdDYn0F+swYBILlpCeHDbCSyzqkeSGNxZ/Q5J584jM=", + "owner": "oxalica", + "repo": "rust-overlay", + "rev": "962a0934d0e32f42d1b5e49186f9595f9b178d2d", + "type": "github" + }, + "original": { + "owner": "oxalica", + "repo": "rust-overlay", + "type": "github" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/sonic-moe/flake.nix b/sonic-moe/flake.nix new file mode 100644 index 00000000..728dbec3 --- /dev/null +++ b/sonic-moe/flake.nix @@ -0,0 +1,17 @@ +{ + description = "Flake for sonic-moe kernels"; + + inputs = { + kernel-builder.url = "github:huggingface/kernels"; + }; + + outputs = + { + self, + kernel-builder, + }: + kernel-builder.lib.genKernelFlakeOutputs { + inherit self; + path = ./.; + }; +} diff --git 
a/sonic-moe/tests/test_moe.py b/sonic-moe/tests/test_moe.py new file mode 100644 index 00000000..274f85be --- /dev/null +++ b/sonic-moe/tests/test_moe.py @@ -0,0 +1,120 @@ +# ******************************************************************************** +# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao +# ******************************************************************************** + +import pytest +import random + +import numpy as np +import torch +from torch.testing import assert_close + +if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9: + pytest.skip("SonicMoE requires Hopper (SM90) or newer GPU", allow_module_level=True) + +try: + from sonicmoe import KernelBackendMoE, MoE, enable_quack_gemm + from sonicmoe.enums import ActivationType +except ImportError as e: + pytest.skip(f"sonicmoe dependencies not available: {e}", allow_module_level=True) + +_SEED = 42 + + +def set_seed(seed: int) -> None: + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +PROBLEM_SHAPES = [ + (8192, 768, 256, 128, 8), + (8192, 768, 512, 64, 4), + (8192, 4096, 512, 128, 8), + (8192, 4096, 1024, 64, 4), +] + + +@pytest.mark.parametrize("problem_shape", PROBLEM_SHAPES) +@pytest.mark.parametrize("add_bias", [False, True]) +def test_moe_forward_backward(problem_shape, add_bias): + device = torch.device("cuda") + dtype = torch.bfloat16 + + set_seed(_SEED) + + T, H, I, E, K = problem_shape + with torch.device(device): + moe = MoE( + num_experts=E, + num_experts_per_tok=K, + hidden_size=H, + intermediate_size=I, + activation_function=ActivationType.SWIGLU, + add_bias=add_bias, + std=0.02, + ).to(dtype=dtype) + + if add_bias: + torch.nn.init.normal_(moe.c_fc.bias, 0, 0.01) + torch.nn.init.normal_(moe.c_proj.bias, 0, 0.01) + + torch.cuda.empty_cache() + x_torch = 0.02 * torch.randn(T, H, device=device, dtype=dtype, requires_grad=True) + x_kernel = x_torch.clone().detach().requires_grad_() + + with torch.autocast(device.type, torch.float32): + y_kernel = moe(x_kernel, kernel_backend_moe=KernelBackendMoE.sonicmoe)[0] + y_torch = moe(x_torch, kernel_backend_moe=KernelBackendMoE.torch)[0] + + assert_close(y_kernel.float(), y_torch.float(), atol=1.4e-2, rtol=2e-2) + + dy = 0.02 * torch.randn(T, H, device=device, dtype=dtype) + W = list(moe.parameters()) + + with torch.autocast(device.type, torch.float32): + kernel_grads = torch.autograd.grad(y_kernel, [x_kernel] + W, grad_outputs=dy, retain_graph=True) + torch_grads = torch.autograd.grad(y_torch, [x_torch] + W, grad_outputs=dy, retain_graph=True) + + for tg, kg in zip(torch_grads, kernel_grads): + assert_close(kg.float(), tg.float(), atol=2e-2, rtol=2e-2) + + torch.cuda.empty_cache() + + +@pytest.mark.parametrize( + "problem_shape", + [(8192, 4096, 512, 128, 8)], +) +def test_moe_quack_gemm(problem_shape): + device = torch.device("cuda") + dtype = torch.bfloat16 + + set_seed(_SEED) + + T, H, I, E, K = problem_shape + with torch.device(device): + moe = MoE( + num_experts=E, + num_experts_per_tok=K, + hidden_size=H, + intermediate_size=I, + activation_function=ActivationType.SWIGLU, + add_bias=False, + std=0.02, + ).to(dtype=dtype) + + torch.cuda.empty_cache() + x_torch = 0.02 * torch.randn(T, H, device=device, dtype=dtype, requires_grad=True) + x_kernel = x_torch.clone().detach().requires_grad_() + + with torch.autocast(device.type, torch.float32): + with enable_quack_gemm(True): + y_kernel = moe(x_kernel, kernel_backend_moe=KernelBackendMoE.sonicmoe)[0] + + 
y_torch = moe(x_torch, kernel_backend_moe=KernelBackendMoE.torch)[0] + + assert_close(y_kernel.float(), y_torch.float(), atol=1.4e-2, rtol=2e-2) + + torch.cuda.empty_cache() diff --git a/sonic-moe/torch-ext/sonicmoe/__init__.py b/sonic-moe/torch-ext/sonicmoe/__init__.py new file mode 100644 index 00000000..36c5204d --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/__init__.py @@ -0,0 +1,36 @@ +# ******************************************************************************** +# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao +# ******************************************************************************** + +from functools import lru_cache + +__version__ = "0.1.1" + +from .enums import KernelBackendMoE + +_LAZY_IMPORTS = { + "MoE": ".moe", + "enable_quack_gemm": ".functional", + "moe_general_routing_inputs": ".functional", + "moe_TC_softmax_topk_layer": ".functional", +} + +@lru_cache(maxsize=None) +def _load_attr(name: str): + import importlib + module_path = _LAZY_IMPORTS[name] + mod = importlib.import_module(module_path, __name__) + return getattr(mod, name) + +def __getattr__(name): + if name in _LAZY_IMPORTS: + return _load_attr(name) + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + +__all__ = [ + "KernelBackendMoE", + "MoE", + "enable_quack_gemm", + "moe_general_routing_inputs", + "moe_TC_softmax_topk_layer", +] diff --git a/sonic-moe/torch-ext/sonicmoe/_ops_compat.py b/sonic-moe/torch-ext/sonicmoe/_ops_compat.py new file mode 100644 index 00000000..f4d00b10 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/_ops_compat.py @@ -0,0 +1,10 @@ +"""Compatibility helpers for op namespacing in source and built layouts.""" + +try: + from ._ops import add_op_namespace_prefix as _generated_add_op_namespace_prefix +except ImportError: + def _generated_add_op_namespace_prefix(name: str) -> str: + return name if "::" in name else f"sonicmoe::{name}" + +def add_op_namespace_prefix(name: str) -> str: + return _generated_add_op_namespace_prefix(name) diff --git a/sonic-moe/torch-ext/sonicmoe/enums.py b/sonic-moe/torch-ext/sonicmoe/enums.py new file mode 100644 index 00000000..c7d71322 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/enums.py @@ -0,0 +1,30 @@ +# ******************************************************************************** +# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao +# ******************************************************************************** + +from enum import Enum + + +LIBRARY_NAME = "sonicmoe" +TENSORMAP = "tensormap" + + +class KernelBackendMoE(Enum): + scattermoe = "scattermoe" + torch = "torch" + sonicmoe = "sonicmoe" + + +class ActivationType(Enum): + SWIGLU = "swiglu" + GEGLU = "geglu" + REGLU = "reglu" + + RELU_SQ = "relu_sq" + RELU = "relu" + GELU = "gelu" + SILU = "silu" + + +def is_glu(activation_type: ActivationType): + return activation_type in [ActivationType.SWIGLU, ActivationType.REGLU, ActivationType.GEGLU] diff --git a/sonic-moe/torch-ext/sonicmoe/functional/__init__.py b/sonic-moe/torch-ext/sonicmoe/functional/__init__.py new file mode 100644 index 00000000..14e3a3d0 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/functional/__init__.py @@ -0,0 +1,554 @@ +# ******************************************************************************** +# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao +# ******************************************************************************** + +import os + +import torch +import torch.nn.functional as 
F +from ..quack.gemm_interface import gemm + +from ..enums import ActivationType, is_glu +from ..quack_utils import gemm_dgated, gemm_gated +from .backward import ( + _down_projection_backward_act, + _down_projection_backward_weight, + _softmax_topk_bwd, + _token_broadcast_backward, + _up_projection_backward_act, + _up_projection_backward_weight, +) +from .forward import _down_projection_forward, _router_forward, _softmax_topk_fwd, _up_projection_forward +from .triton_kernels import TC_topk_router_metadata_triton, general_routing_router_metadata_triton +from .utils import enable_quack_gemm, is_using_quack_gemm + + +class TC_Softmax_Topk_Router_Function(torch.autograd.Function): + @staticmethod + def forward(ctx, router_logits: torch.Tensor, E: int, K: int) -> tuple[torch.Tensor, torch.Tensor]: + T = router_logits.size(0) + + # change this to router_logits.dtype (bfloat16) increase another 5 tflops at fwd at the cost of numerical accuracy + topk_router_score = torch.empty(T, K, dtype=torch.float32, device=router_logits.device) + topk_router_indices = torch.empty(T, K, dtype=torch.int32, device=router_logits.device) + + _softmax_topk_fwd(router_logits, topk_router_score, topk_router_indices, E, K) + + ctx.save_for_backward(topk_router_score, topk_router_indices) + ctx.E = E + ctx.dtype = router_logits.dtype + + return topk_router_score, topk_router_indices + + @staticmethod + def backward(ctx, dtopk_score: torch.Tensor, _: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + T, K = dtopk_score.size() + + topk_router_score, topk_router_indices = ctx.saved_tensors + dlogits = torch.zeros(T, ctx.E, dtype=ctx.dtype, device=topk_router_score.device) + + _softmax_topk_bwd(dlogits, None, dtopk_score, topk_router_score, topk_router_indices, K) + + return dlogits, None, None + + +class _UpProjection(torch.autograd.Function): + @staticmethod + def forward( + ctx, + x: torch.Tensor, + w1: torch.Tensor, + b1: torch.Tensor | None, + expert_frequency_offset: torch.Tensor, + total_expert_freq: int, + K: int, + stream_id: int, + x_gather_idx: torch.Tensor, + s_scatter_idx: torch.Tensor, + s_reverse_scatter_idx: torch.Tensor, + num_activated_expert_per_token_offset: torch.Tensor, + is_varlen_K: bool, + activation_type: ActivationType, + is_inference_mode_enabled: bool, + ) -> torch.Tensor: + T, H = x.shape + I, H, E = w1.shape + is_glu_activation = is_glu(activation_type) + if is_glu_activation: + I //= 2 + TK = total_expert_freq + + if is_using_quack_gemm(): + assert not torch.compiler.is_compiling() + assert is_glu_activation, "QuACK GEMM does not support non GLU activation yet" + z, y1 = gemm_gated( + x, + w1.permute(2, 1, 0), + activation="swiglu", + cu_seqlens_m=expert_frequency_offset, + A_idx=x_gather_idx, + dynamic_scheduler=False, + ) + else: + z = torch.empty(TK, (2 * I if is_glu_activation else I), dtype=x.dtype, device=x.device) + y1 = torch.empty(TK, I, dtype=x.dtype, device=x.device) + _up_projection_forward( + x=x, + w1=w1, + z=z, + y1=y1, + b1=b1, + expert_frequency_offset=expert_frequency_offset, + expert_schedule_order=None, + x_gather_idx=x_gather_idx, + stream_id=stream_id, + activation_type=activation_type.value, + is_glu_activation=is_glu_activation, + is_inference_mode_enabled=is_inference_mode_enabled, + ) + + ctx.T = T + ctx.TK = TK + ctx.E = E + ctx.K = K + ctx.H = H + ctx.I = I + ctx.is_varlen_K = is_varlen_K + ctx.is_glu_activation = is_glu_activation + ctx.stream_id = stream_id + + ctx.save_for_backward( + x, + w1, + b1, + expert_frequency_offset, + x_gather_idx, + 
s_scatter_idx, + s_reverse_scatter_idx, + num_activated_expert_per_token_offset, + ) + + ctx.mark_non_differentiable(y1) + ctx.set_materialize_grads(False) + + return y1, z + + @staticmethod + def backward(ctx, _: None, dz: torch.Tensor): + is_compiling = torch.compiler.is_compiling() + + if not is_compiling: + assert _ is None + + T = ctx.T + TK = ctx.TK + E = ctx.E + K = ctx.K + H = ctx.H + is_glu_activation = ctx.is_glu_activation + is_varlen_K = ctx.is_varlen_K + stream_id = ctx.stream_id + + ( + x, + w1, + b1, + expert_frequency_offset, + x_gather_idx, + s_scatter_idx, + s_reverse_scatter_idx, + num_activated_expert_per_token_offset, + ) = ctx.saved_tensors + + dw1 = torch.empty_like(w1) + db1 = None if b1 is None else torch.empty_like(b1) + + if is_using_quack_gemm(): + assert not is_compiling + + gemm( + x.T, + dz, + out=dw1.permute(2, 1, 0), + cu_seqlens_k=expert_frequency_offset, + A_idx=x_gather_idx, + batch_idx_permute=None, + dynamic_scheduler=False, + ) + dx_expanded = gemm(dz, w1.permute(2, 0, 1), cu_seqlens_m=expert_frequency_offset, dynamic_scheduler=False) + else: + dx_expanded = torch.empty(TK, H, dtype=dz.dtype, device=dz.device) + + _up_projection_backward_act( + w1=w1, + dx_expanded=dx_expanded, + dz=dz, + db1=db1, + expert_frequency_offset=expert_frequency_offset, + expert_schedule_order=None, + x_gather_idx=x_gather_idx, + s_scatter_idx=s_scatter_idx, + is_glu_activation=is_glu_activation, + stream_id=stream_id, + ) + + _up_projection_backward_weight( + x=x, + dw1=dw1, + dz=dz, + expert_frequency_offset=expert_frequency_offset, + expert_schedule_order=None, + x_gather_idx=x_gather_idx, + is_glu_activation=is_glu_activation, + stream_id=stream_id, + ) + + dx_reduced = torch.empty(T, H, dtype=dz.dtype, device=dz.device) + + _token_broadcast_backward( + dx_reduced=dx_reduced, + dx_expanded=dx_expanded, + s_reverse_scatter_idx=s_reverse_scatter_idx, + num_activated_expert_per_token_offset=num_activated_expert_per_token_offset, + varlen_K_max=(E if is_varlen_K else K), + H=H, + is_varlen_K=is_varlen_K, + ) + + return dx_reduced, dw1, db1, *[None] * 12 + + +class _DownProjection(torch.autograd.Function): + @staticmethod + def forward( + ctx, + y1: torch.Tensor, + z: torch.Tensor, + w2: torch.Tensor, + b2: torch.Tensor | None, + topk_scores: torch.Tensor, + expert_frequency_offset: torch.Tensor, + T: int, + K: int, + stream_id: int, + x_gather_idx: torch.Tensor, + s_scatter_idx: torch.Tensor, + s_reverse_scatter_idx: torch.Tensor, + num_activated_expert_per_token_offset: torch.Tensor, + is_varlen_K: bool, + activation_type: ActivationType, + ) -> torch.Tensor: + TK = y1.size(0) + H, I, E = w2.shape + + if is_using_quack_gemm(): + assert not torch.compiler.is_compiling() + + assert b2 is None + y2 = gemm(y1, w2.permute(2, 1, 0), cu_seqlens_m=expert_frequency_offset) + else: + y2 = torch.empty(TK, H, dtype=y1.dtype, device=y1.device) + _down_projection_forward( + w2=w2, + y1=y1, + y2=y2, + b2=b2, + expert_frequency_offset=expert_frequency_offset, + expert_schedule_order=None, + x_gather_idx=x_gather_idx, + stream_id=stream_id, + ) + + o = torch.empty(T, H, device=z.device, dtype=z.dtype) + topk_scores = topk_scores.flatten() + + _router_forward( + y2=y2, + o=o, + topk_scores=topk_scores, + s_reverse_scatter_idx=s_reverse_scatter_idx, + num_activated_expert_per_token_offset=num_activated_expert_per_token_offset, + varlen_K_max=(E if is_varlen_K else K), + H=H, + is_varlen_K=is_varlen_K, + ) + + ctx.T = T + ctx.K = K + ctx.is_varlen_K = is_varlen_K + ctx.activation_type = 
activation_type + ctx.stream_id = stream_id + + ctx.save_for_backward( + z, + w2, + b2, + topk_scores, + expert_frequency_offset, + x_gather_idx, + s_scatter_idx, + s_reverse_scatter_idx, + ) + + return o + + @staticmethod + def backward(ctx, dout: torch.Tensor): + T = ctx.T + K = ctx.K + stream_id = ctx.stream_id + is_varlen_K = ctx.is_varlen_K + activation_type = ctx.activation_type + + ( + z, + w2, + b2, + topk_scores, + expert_frequency_offset, + x_gather_idx, + s_scatter_idx, + s_reverse_scatter_idx, + ) = ctx.saved_tensors + + dw2 = torch.empty_like(w2) + db2 = None if b2 is None else torch.empty_like(b2) + dz = torch.empty_like(z) + + if is_using_quack_gemm(): + assert not torch.compiler.is_compiling() + assert is_glu(activation_type), "QuACK GEMM does not support non GLU activation yet" + + s = topk_scores[s_scatter_idx] + _, y1s, ds = gemm_dgated( + dout, + w2.permute(2, 0, 1), + PreAct=z, + activation="swiglu", + dx_out=dz, + colvec_scale=s, + colvec_reduce=True, + cu_seqlens_m=expert_frequency_offset, + A_idx=x_gather_idx, + dynamic_scheduler=False, + ) + gemm( + dout.T, + y1s, + out=dw2.permute(2, 0, 1), + cu_seqlens_k=expert_frequency_offset, + A_idx=x_gather_idx, + batch_idx_permute=None, + dynamic_scheduler=False, + ) + + ds = ds[s_reverse_scatter_idx] + else: + ds = torch.empty_like(topk_scores) + + I = w2.size(1) + TK = x_gather_idx.size(0) + + y1s = torch.empty(TK, I, dtype=z.dtype, device=z.device) + is_glu_activation = is_glu(activation_type) + + _down_projection_backward_act( + dout=dout, + z=z, + w2=w2, + dz=dz, + ds=ds, + b2=b2, + db2=db2, + y1s=y1s, + topk_scores=topk_scores, + expert_frequency_offset=expert_frequency_offset, + expert_schedule_order=None, + x_gather_idx=x_gather_idx, + s_scatter_idx=s_scatter_idx, + is_glu_activation=is_glu_activation, + activation_type=activation_type.value, + stream_id=stream_id, + ) + + _down_projection_backward_weight( + dout=dout, + y1s=y1s, + dw2=dw2, + expert_frequency_offset=expert_frequency_offset, + expert_schedule_order=None, + x_gather_idx=x_gather_idx, + stream_id=stream_id, + ) + + # TC top-K routing + if not is_varlen_K: + ds = ds.view(T, K) + + return None, dz, dw2, db2, ds, *[None] * 10 + + +def moe_TC_softmax_topk_layer( + x: torch.Tensor, + router_w: torch.Tensor, + w1: torch.Tensor, + b1: torch.Tensor | None, + w2: torch.Tensor, + b2: torch.Tensor | None, + K: int, + stream_id: int, + activation_type: ActivationType | str = ActivationType.SWIGLU, + is_inference_mode_enabled: bool = False, +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + assert ((b1 is None) and (b2 is None)) or ( + (b1 is not None) and (b2 is not None) + ), "b1 and b2 has to be None or not None at the same time!" 
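+    # Fused MoE layer with dense top-K routing. Overview of the steps below:
+    #   1) router logits via F.linear, then fused top-K + softmax over the selected experts
+    #   2) a Triton metadata kernel builds per-expert frequencies and gather/scatter indices
+    #   3) grouped up-projection (w1/b1) with fused activation
+    #   4) grouped down-projection (w2/b2) with a score-weighted reduction back to (T, H)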
+ E = router_w.size(0) + router_logits = F.linear(x, router_w) + topk_scores, topk_indices = TC_Softmax_Topk_Router_Function.apply(router_logits, E, K) + + T, K = topk_indices.size() + TK = T * K + device = topk_indices.device + + s_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device) + s_reverse_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device) + expert_frequency = torch.empty(E, dtype=torch.int32, device=device) + expert_frequency_offset = torch.empty(E + 1, dtype=torch.int32, device=device) + x_gather_idx = torch.empty(TK, dtype=torch.int32, device=device) + + TC_topk_router_metadata_triton( + topk_indices, E, expert_frequency, expert_frequency_offset, x_gather_idx, s_scatter_idx, s_reverse_scatter_idx + ) + + T = x.size(0) + + if type(activation_type) == str: + activation_type = ActivationType(activation_type) + + y1, z = _UpProjection.apply( + x, + w1, + b1, + expert_frequency_offset, + T * K, + K, + stream_id, + x_gather_idx, + s_scatter_idx, + s_reverse_scatter_idx, + None, + False, # is_varlen_K + activation_type, + is_inference_mode_enabled, + ) + + o = _DownProjection.apply( + y1, + z, + w2, + b2, + topk_scores, + expert_frequency_offset, + T, + K, + stream_id, + x_gather_idx, + s_scatter_idx, + s_reverse_scatter_idx, + None, + False, # is_varlen_K + activation_type, + ) + + return o, router_logits, expert_frequency + + +# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +# Weight format requirements: +# - w1_weight: Shape (2*I, H, E), stride order (2, 0, 1), must be interleaved [gate_row0, up_row0, gate_row1, up_row1, ...] +# - w2_weight: Shape (H, I, E), stride order (2, 0, 1) + + +# We assume token_indices is already SORTED ascendingly !!! +# and len(token_indices) = len(expert_indices) = len(router_scores) +# !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! +def moe_general_routing_inputs( + x: torch.Tensor, + router_scores: torch.Tensor, + token_indices: torch.Tensor, + expert_indices: torch.Tensor, + w1: torch.Tensor, + b1: torch.Tensor | None, + w2: torch.Tensor, + b2: torch.Tensor | None, + E: int, + stream_id: int, + activation_type: ActivationType, + is_inference_mode_enabled: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + assert ((b1 is None) and (b2 is None)) or ( + (b1 is not None) and (b2 is not None) + ), "b1 and b2 has to be None or not None at the same time!" 
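+    # General-routing variant: instead of a dense (T, K) top-K, the caller passes
+    # flattened (token_indices, expert_indices, router_scores) of length TK with
+    # token_indices pre-sorted, so the number of active experts can vary per token
+    # (is_varlen_K=True); num_activated_expert_per_token_offset drives the final reduction.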
+ + T = x.size(0) + TK = router_scores.size(0) + E = w2.size(-1) + device = router_scores.device + + s_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device) + s_reverse_scatter_idx = torch.empty(TK, dtype=torch.int32, device=device) + expert_frequency = torch.empty(E, dtype=torch.int32, device=device) + expert_frequency_offset = torch.empty(E + 1, dtype=torch.int32, device=device) + x_gather_idx = torch.empty(TK, dtype=torch.int32, device=device) + num_activated_expert_per_token_offset = torch.empty(T + 1, dtype=torch.int32, device=device) + + general_routing_router_metadata_triton( + token_indices, + expert_indices, + T, + E, + expert_frequency, + expert_frequency_offset, + x_gather_idx, + s_scatter_idx, + s_reverse_scatter_idx, + num_activated_expert_per_token_offset, + ) + + y1, z = _UpProjection.apply( + x, + w1, + b1, + expert_frequency_offset, + TK, + None, # K, not needed + stream_id, + x_gather_idx, + s_scatter_idx, + s_reverse_scatter_idx, + num_activated_expert_per_token_offset, + True, # is_varlen_K + activation_type, + is_inference_mode_enabled, + ) + + o = _DownProjection.apply( + y1, + z, + w2, + b2, + router_scores, + expert_frequency_offset, + T, + None, # K, not needed + stream_id, + x_gather_idx, + s_scatter_idx, + s_reverse_scatter_idx, + num_activated_expert_per_token_offset, + True, # is_varlen_K + activation_type, + ) + + return o, expert_frequency diff --git a/sonic-moe/torch-ext/sonicmoe/functional/backward.py b/sonic-moe/torch-ext/sonicmoe/functional/backward.py new file mode 100644 index 00000000..3ecda490 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/functional/backward.py @@ -0,0 +1,682 @@ +# ******************************************************************************** +# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao +# ******************************************************************************** + +from typing import Optional + +import cuda.bindings.driver as cuda +import cutlass.cute as cute +import torch +import triton +import triton.language as tl + +from .._ops_compat import add_op_namespace_prefix +from ..enums import LIBRARY_NAME, TENSORMAP, ActivationType +from ..utils import ceil_divide, convert_torch_tensor_to_cute_tensor, get_powers_of_2 +from .moe_config import ( + HopperWgmma_MoE_Down_proj_ActGrad_Bwd, + HopperWgmma_MoE_Down_proj_WeightGrad_Bwd, + HopperWgmma_MoE_Up_proj_ActGrad_Bwd, + HopperWgmma_MoE_Up_proj_WeightGrad_Bwd, +) +from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton + + +def _get_autotune_configs_for_db2_and_ds() -> list[triton.Config]: + configs = [] + for BLOCK_TK in get_powers_of_2(4, 32): + configs.append(triton.Config({"BLOCK_TK": BLOCK_TK}, num_warps=8, num_stages=4)) + return configs + + +@triton.autotune( + configs=_get_autotune_configs_for_db2_and_ds(), + key=["H", "E"], +) +@triton.jit +def db2_and_ds_kernel( + dout_ptr, # (T, H) + s_ptr, # (TK,) + new_ds_partial_ptr, # (TK, n_h_blocks) + old_ds_partial_ptr, # (TK, OLD_DS_PARTIAL_N) + b2_ptr, # (E, H), + db2_ptr, # (E, H), + x_gather_idx_ptr, # (TK,), maps grouped -> token index + s_scatter_idx_ptr, # (TK,), maps grouped -> scatter index + expert_offset_ptr, # (E+1,), offsets in grouped layout + H: tl.constexpr, + E: tl.constexpr, + OLD_DS_PARTIAL_N: tl.constexpr, + BLOCK_H: tl.constexpr, # Block size for H dimension + BLOCK_TK: tl.constexpr, # Block size for token dimension + BLOCK_OLD_DS_PARTIAL_N: tl.constexpr, +): + Eidx = tl.program_id(0) # expert id + Hidx = tl.program_id(1) # h-block id + 
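+    # One program per (expert, H-block): accumulate db2[e, h_block] = sum_t dout[t] * s[t]
+    # over this expert's tokens, and emit per-H-block partials of ds[t] = dot(dout[t], b2[e])
+    # (old partials are folded in on the first H-block); partials are reduced to ds afterwards.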
NUM_H_BLOCKS: tl.constexpr = tl.num_programs(1) + + # Hidden dimension indices for this block + h_offsets = Hidx * BLOCK_H + tl.arange(0, BLOCK_H) + h_mask = h_offsets < H + + E_count_start = tl.load(expert_offset_ptr + Eidx) + E_count_end = tl.load(expert_offset_ptr + Eidx + 1) + n_tokens = E_count_end - E_count_start + + b2 = tl.load(b2_ptr + Eidx * H + h_offsets, mask=h_mask, other=0.0).to(tl.float32) + + db2_acc = tl.zeros([BLOCK_H], dtype=tl.float32) + + # Process tokens in blocks of BLOCK_TK + for block_start in tl.range(0, n_tokens, BLOCK_TK): + # Token offsets within this block + tk_offsets = block_start + tl.arange(0, BLOCK_TK) + tk_mask = tk_offsets < n_tokens + tk_grouped = E_count_start + tk_offsets + + # Gather token indices: [BLOCK_TK] + token_indices = tl.load(x_gather_idx_ptr + tk_grouped, mask=tk_mask, other=0).to(tl.uint32) + + # Get scatter indices: [BLOCK_TK] + scatter_indices = tl.load(s_scatter_idx_ptr + tk_grouped, mask=tk_mask, other=0).to(tl.uint32) + + s = tl.load(s_ptr + scatter_indices, mask=tk_mask, other=0.0).to(tl.float32) + + # Gather dout: [BLOCK_TK, BLOCK_H] + dout_offsets = token_indices[:, None] * H + h_offsets[None, :] + dout_mask = tk_mask[:, None] & h_mask[None, :] + dout = tl.load(dout_ptr + dout_offsets, mask=dout_mask, other=0.0).to(tl.float32) + + # Accumulate db2: sum over tokens of (dout * s) + db2_acc += tl.sum(dout * s[:, None], axis=0) # Sum over BLOCK_TK dimension + + # Compute ds: dot(dout, b2) for this H-block + ds_partial = tl.sum(dout * b2[None, :], axis=1) # [BLOCK_TK] + + # On first H-block, add old_ds_partial.sum(dim=1) + if Hidx == 0: + n_offsets = tl.arange(0, BLOCK_OLD_DS_PARTIAL_N) + old_ds_partial_offsets = scatter_indices[:, None] * OLD_DS_PARTIAL_N + n_offsets[None, :] + old_ds_partial_mask = tk_mask[:, None] & (n_offsets[None, :] < OLD_DS_PARTIAL_N) + old_ds_partial_vals = tl.load( + old_ds_partial_ptr + old_ds_partial_offsets, mask=old_ds_partial_mask, other=0.0 + ).to(tl.float32) + ds_partial += tl.sum(old_ds_partial_vals, axis=1) + + tl.store(new_ds_partial_ptr + scatter_indices * NUM_H_BLOCKS + Hidx, ds_partial, mask=tk_mask) + + tl.store(db2_ptr + Eidx * H + h_offsets, db2_acc, mask=h_mask) + + +def _get_autotune_configs_for_db1() -> list[triton.Config]: + configs = [] + for BLOCK_TK in get_powers_of_2(4, 128): + for BLOCK_I in get_powers_of_2(64, 4096): + if 4096 <= BLOCK_I * BLOCK_TK <= 16384: + configs.append(triton.Config({"BLOCK_I": BLOCK_I, "BLOCK_TK": BLOCK_TK}, num_warps=8, num_stages=4)) + return configs + + +def _prune_triton_autotune_config(configs, nargs, **kw): + pruned_configs = [] + for c in configs: + if c.kwargs["BLOCK_I"] <= triton.next_power_of_2(nargs["I"]): + pruned_configs.append(c) + return pruned_configs + + +@triton.autotune( + configs=_get_autotune_configs_for_db1(), + key=["I", "E"], + prune_configs_by={"early_config_prune": _prune_triton_autotune_config}, +) +@triton.jit +def db1_kernel( + dz_ptr, # (T, H) + db1_ptr, # (E, H), + expert_offset_ptr, # (E+1,), offsets in grouped layout + I: tl.constexpr, + E: tl.constexpr, + BLOCK_I: tl.constexpr, # Block size for H dimension + BLOCK_TK: tl.constexpr, # Block size for token dimension +): + Eidx = tl.program_id(0) # expert id + + E_count_start = tl.load(expert_offset_ptr + Eidx).to(tl.int64) + E_count_end = tl.load(expert_offset_ptr + Eidx + 1).to(tl.int64) + n_tokens = E_count_end - E_count_start + + NUM_I_BLOCKS: tl.constexpr = triton.cdiv(I, BLOCK_I) + for Iidx in tl.static_range(0, NUM_I_BLOCKS, 1): + i_offsets = Iidx * BLOCK_I + tl.arange(0, 
BLOCK_I) + i_mask = i_offsets < I + + db1_acc = tl.zeros([BLOCK_I], dtype=tl.float32) + + # Process tokens in blocks of BLOCK_TK + for block_start in tl.range(0, n_tokens, BLOCK_TK): + # Token offsets within this block + tk_offsets = block_start + tl.arange(0, BLOCK_TK) + tk_mask = tk_offsets < n_tokens + tk_grouped = E_count_start + tk_offsets + + dz_offsets = tk_grouped[:, None] * I + i_offsets[None, :] + dz_mask = tk_mask[:, None] & i_mask[None, :] + dz = tl.load(dz_ptr + dz_offsets, mask=dz_mask, other=0.0).to(tl.float32) + + db1_acc += tl.sum(dz, axis=0) # Sum over BLOCK_TK dimension + + db1_offsets = Eidx.to(tl.int64) * I + i_offsets + tl.store(db1_ptr + db1_offsets, db1_acc, mask=i_mask) + + +@triton.jit +def _colsum_smallN_kernel( + y_ptr, # *mut T, shape [M] + x_ptr, # *const T, shape [M, N] + stride_xm: tl.constexpr, + stride_xn: tl.constexpr, # strides of X + stride_y: tl.constexpr, # stride of Y (usually 1) + N: tl.constexpr, # sizes + BLOCK_N: tl.constexpr, # tile size along N +): + row = tl.program_id(0) + + # assume BLOCK_N >= N + offs = tl.arange(0, BLOCK_N) + mask = offs < N + # Load a tile from the row; cast to fp32 for the reduction + x = tl.load(x_ptr + row * stride_xm + offs * stride_xn, mask=mask, other=0).to(tl.float32) + # Reduce this tile to a scalar and add + acc = tl.sum(x, axis=0) + + # Store the row-sum (cast back to y dtype) + tl.store(y_ptr + row * stride_y, acc) + + +@torch.library.custom_op(add_op_namespace_prefix("_up_projection_backward_act"), mutates_args={"dx_expanded", "db1"}) +def _up_projection_backward_act( + w1: torch.Tensor, + dx_expanded: torch.Tensor, + dz: torch.Tensor, + db1: torch.Tensor | None, + expert_frequency_offset: torch.Tensor, + expert_schedule_order: torch.Tensor | None, + x_gather_idx: torch.Tensor, + s_scatter_idx: torch.Tensor, + is_glu_activation: bool, + stream_id: int, +) -> None: + I, H, E = w1.size() + if is_glu_activation: + I //= 2 + + # db1 computation + if db1 is not None: + db1_kernel[(E,)](dz, db1, expert_frequency_offset, (2 * I if is_glu_activation else I), E) + + mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id) + mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id) + mS_scatter = convert_torch_tensor_to_cute_tensor(s_scatter_idx, (0,), 0, 4, 1, stream=stream_id) + mDz = convert_torch_tensor_to_cute_tensor(dz, (0, 1), 1, 16, 8, stream=stream_id) + mDx_expanded = convert_torch_tensor_to_cute_tensor(dx_expanded, (0, 1), 1, 16, 8, stream=stream_id) + mW1_trans = convert_torch_tensor_to_cute_tensor(w1.permute(1, 0, 2), (2, 1, 0), 0, 16, 8, stream=stream_id) + + if expert_schedule_order is None: + mE_permute_order = None + else: + mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id) + current_stream = cuda.CUstream(stream_id) + + compile_dx_key = ("dx", E, H, I, is_glu_activation, dx_expanded.dtype) + if compile_dx_key not in _up_projection_backward_act.compile_cache: + dx_module = HopperWgmma_MoE_Up_proj_ActGrad_Bwd(E, H, I, is_glu_activation) + tensormaps = [dx_module.module.generate_tensormap(None, None, None) for _ in range(2)] + _up_projection_backward_act.compile_cache[compile_dx_key] = cute.compile( + dx_module, + mDz, + mW1_trans, + mDx_expanded, + mE_offset, + mX_gather, + mS_scatter, + tensormaps, + mE_permute_order, + current_stream, + ) + _up_projection_backward_act.compile_cache[f"dx-{TENSORMAP}"] = tensormaps + + dx_tensormaps = 
_up_projection_backward_act.compile_cache[f"dx-{TENSORMAP}"] + _up_projection_backward_act.compile_cache[compile_dx_key]( + mDz, + mW1_trans, + mDx_expanded, + mE_offset, + mX_gather, + mS_scatter, + dx_tensormaps, + mE_permute_order, + current_stream, + ) + + +_up_projection_backward_act.compile_cache = {} + + +@torch.library.custom_op(add_op_namespace_prefix("_up_projection_backward_weight"), mutates_args={"dw1"}) +def _up_projection_backward_weight( + x: torch.Tensor, + dw1: torch.Tensor, + dz: torch.Tensor, + expert_frequency_offset: torch.Tensor, + expert_schedule_order: torch.Tensor | None, + x_gather_idx: torch.Tensor, + is_glu_activation: bool, + stream_id: int, +) -> None: + I, H, E = dw1.size() + if is_glu_activation: + I //= 2 + + x = x.detach() + + mDz_trans = convert_torch_tensor_to_cute_tensor(dz.T, (1, 0), 0, 16, 8, stream=stream_id) + mDw1_trans = convert_torch_tensor_to_cute_tensor(dw1.permute(1, 0, 2), (2, 1, 0), 0, 16, 8, stream=stream_id) + + mX_trans = convert_torch_tensor_to_cute_tensor(x.T, (1, 0), 0, 16, 8, stream=stream_id) + mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id) + mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id) + + if expert_schedule_order is None: + mE_permute_order = None + else: + mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id) + current_stream = cuda.CUstream(stream_id) + + compile_dw1_key = ("dw1", E, H, I, is_glu_activation, x.dtype) + if compile_dw1_key not in _up_projection_backward_weight.compile_cache: + dw1_module = HopperWgmma_MoE_Up_proj_WeightGrad_Bwd(E, H, I, is_glu_activation) + tensormaps = [dw1_module.module.generate_tensormap(None, None, None) for _ in range(1)] + _up_projection_backward_weight.compile_cache[compile_dw1_key] = cute.compile( + dw1_module, + mX_trans, + mDz_trans, + mDw1_trans, + mE_offset, + mX_gather, + tensormaps, + mE_permute_order, + current_stream, + ) + _up_projection_backward_weight.compile_cache[f"dw1-{TENSORMAP}"] = tensormaps + + dw1_tensormaps = _up_projection_backward_weight.compile_cache[f"dw1-{TENSORMAP}"] + _up_projection_backward_weight.compile_cache[compile_dw1_key]( + mX_trans, + mDz_trans, + mDw1_trans, + mE_offset, + mX_gather, + dw1_tensormaps, + mE_permute_order, + current_stream, + ) + + +_up_projection_backward_weight.compile_cache = {} + + +@torch.library.custom_op(add_op_namespace_prefix("_down_projection_backward_act"), mutates_args={"dz", "ds", "db2", "y1s"}) +def _down_projection_backward_act( + dout: torch.Tensor, + z: torch.Tensor, + w2: torch.Tensor, + dz: torch.Tensor, + ds: torch.Tensor, + b2: torch.Tensor | None, + db2: torch.Tensor | None, + y1s: torch.Tensor, + topk_scores: torch.Tensor, + expert_frequency_offset: torch.Tensor, + expert_schedule_order: torch.Tensor | None, + x_gather_idx: torch.Tensor, + s_scatter_idx: torch.Tensor, + is_glu_activation: bool, + activation_type: str, + stream_id: int, +) -> None: + H, I, E = w2.size() + TK = x_gather_idx.size(0) + + dout = dout.detach() + w2 = w2.detach() + topk_scores = topk_scores.detach() + + mDout = convert_torch_tensor_to_cute_tensor(dout, (0, 1), 1, 16, 8, stream=stream_id) + mW2_trans = convert_torch_tensor_to_cute_tensor(w2.permute(1, 0, 2), (2, 1, 0), 0, 16, 8, stream=stream_id) + mS = convert_torch_tensor_to_cute_tensor(topk_scores, (0,), 0, 4, 1, stream=stream_id) + if is_glu_activation: + mDz_kernel_input = convert_torch_tensor_to_cute_tensor( + 
dz.view(torch.float32), (0, 1), 1, 16, 8, stream=stream_id + ) + mZ_kernel_input = convert_torch_tensor_to_cute_tensor( + z.view(torch.float32), (0, 1), 1, 16, 8, stream=stream_id + ) + else: + mDz_kernel_input = convert_torch_tensor_to_cute_tensor(dz.detach(), (0, 1), 1, 16, 8, stream=stream_id) + mZ_kernel_input = convert_torch_tensor_to_cute_tensor(z.detach(), (0, 1), 1, 16, 8, stream=stream_id) + + mY1S = convert_torch_tensor_to_cute_tensor(y1s, (0, 1), 1, 16, 8, stream=stream_id) + mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id) + mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id) + mS_scatter = convert_torch_tensor_to_cute_tensor(s_scatter_idx, (0,), 0, 4, 1, stream=stream_id) + + if expert_schedule_order is None: + mE_permute_order = None + else: + mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id) + current_stream = cuda.CUstream(stream_id) + ds_partial = None + + compile_dz_key = ("dz", E, H, I, z.dtype, activation_type) + if compile_dz_key not in _down_projection_backward_act.compile_cache: + # I don't know why but this sync appears to fix a mysterious initialization bug?? + torch.cuda.synchronize() + dz_module = HopperWgmma_MoE_Down_proj_ActGrad_Bwd(E, H, I, ActivationType(activation_type)) + tensormaps = [dz_module.module.generate_tensormap(None, None, None) for _ in range(3)] + + ds_partial_N = max(ceil_divide(I, dz_module.module.tile_shape_mnk[1]), 1) + ds_partial = torch.empty(TK, ds_partial_N, dtype=torch.float32, device=topk_scores.device) + mDS_partial = convert_torch_tensor_to_cute_tensor(ds_partial, (0, 1), 1, 4, 1, stream=stream_id) + + _down_projection_backward_act.compile_cache["ds_partial_N"] = ds_partial_N + _down_projection_backward_act.compile_cache[compile_dz_key] = cute.compile( + dz_module, + mDout, + mW2_trans, + mZ_kernel_input, + mDz_kernel_input, + mY1S, + mS, + mDS_partial, + mE_offset, + mX_gather, + mS_scatter, + tensormaps, + mE_permute_order, + current_stream, + ) + _down_projection_backward_act.compile_cache[f"dz-{TENSORMAP}"] = tensormaps + + if ds_partial is None: + ds_partial_N = _down_projection_backward_act.compile_cache["ds_partial_N"] + ds_partial = torch.empty(TK, ds_partial_N, dtype=torch.float32, device=topk_scores.device) + mDS_partial = convert_torch_tensor_to_cute_tensor(ds_partial, (0, 1), 1, 4, 1, stream=stream_id) + + dz_tensormaps = _down_projection_backward_act.compile_cache[f"dz-{TENSORMAP}"] + _down_projection_backward_act.compile_cache[compile_dz_key]( + mDout, + mW2_trans, + mZ_kernel_input, + mDz_kernel_input, + mY1S, + mS, + mDS_partial, + mE_offset, + mX_gather, + mS_scatter, + dz_tensormaps, + mE_permute_order, + current_stream, + ) + + if db2 is None: + # we don't need to update ds + if ds_partial.size(1) == 1: + ds.copy_(ds_partial.view(-1).to(dtype=ds.dtype)) + elif ds_partial.size(1) <= 32: + ds.copy_(ds_partial.sum(dim=-1, dtype=ds.dtype)) + else: + M, N = ds_partial.size() + + _colsum_smallN_kernel[M,]( + y_ptr=ds, + x_ptr=ds_partial, + stride_xm=ds_partial.stride(0), + stride_xn=ds_partial.stride(1), + stride_y=1, + N=N, + BLOCK_N=triton.next_power_of_2(N), + ) + else: + # db2 and ds update + BLOCK_H = min(triton.next_power_of_2(H), 2048) + NUM_H_BLOCKS = triton.cdiv(H, BLOCK_H) + + new_ds_partial = torch.empty(TK, NUM_H_BLOCKS, device=ds.device, dtype=torch.float32) + + db2_and_ds_kernel[(E, NUM_H_BLOCKS)]( + dout, + topk_scores, + new_ds_partial, + 
ds_partial, + b2, + db2, + x_gather_idx, + s_scatter_idx, + expert_frequency_offset, + H, + E, + ds_partial_N, + BLOCK_H=BLOCK_H, + BLOCK_OLD_DS_PARTIAL_N=triton.next_power_of_2(ds_partial_N), + ) + + if NUM_H_BLOCKS == 1: + ds.copy_(new_ds_partial.view(-1).to(dtype=ds.dtype)) + else: + ds.copy_(new_ds_partial.sum(dim=-1, dtype=ds.dtype)) + + +_down_projection_backward_act.compile_cache = {} + + +@torch.library.custom_op(add_op_namespace_prefix("_down_projection_backward_weight"), mutates_args={"dw2"}) +def _down_projection_backward_weight( + dout: torch.Tensor, + y1s: torch.Tensor, + dw2: torch.Tensor, + expert_frequency_offset: torch.Tensor, + expert_schedule_order: torch.Tensor | None, + x_gather_idx: torch.Tensor, + stream_id: int, +) -> None: + H, I, E = dw2.size() + + mDout_trans = convert_torch_tensor_to_cute_tensor(dout.T, (1, 0), 0, 16, 8, stream=stream_id) + mDw2 = convert_torch_tensor_to_cute_tensor(dw2, (2, 0, 1), 1, 16, 8, stream=stream_id) + mY1S_trans = convert_torch_tensor_to_cute_tensor(y1s.T, (1, 0), 0, 16, 8, stream=stream_id) + mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id) + mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id) + + if expert_schedule_order is None: + mE_permute_order = None + else: + mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id) + current_stream = cuda.CUstream(stream_id) + + compile_dw2_key = ("dw2", E, H, I, dw2.dtype) + if compile_dw2_key not in _down_projection_backward_weight.compile_cache: + dw2_module = HopperWgmma_MoE_Down_proj_WeightGrad_Bwd(E, H, I) + tensormaps = [dw2_module.module.generate_tensormap(None, None, None) for _ in range(1)] + _down_projection_backward_weight.compile_cache[compile_dw2_key] = cute.compile( + dw2_module, + mDout_trans, + mY1S_trans, + mDw2, + mE_offset, + mX_gather, + tensormaps, + mE_permute_order, + current_stream, + ) + _down_projection_backward_weight.compile_cache[f"dw2-{TENSORMAP}"] = tensormaps + + dw2_tensormaps = _down_projection_backward_weight.compile_cache[f"dw2-{TENSORMAP}"] + _down_projection_backward_weight.compile_cache[compile_dw2_key]( + mDout_trans, mY1S_trans, mDw2, mE_offset, mX_gather, dw2_tensormaps, mE_permute_order, current_stream + ) + + +_down_projection_backward_weight.compile_cache = {} + + +@torch.library.custom_op(add_op_namespace_prefix("_token_broadcast_backward"), mutates_args={"dx_reduced"}) +def _token_broadcast_backward( + dx_reduced: torch.Tensor, + dx_expanded: torch.Tensor, + s_reverse_scatter_idx: torch.Tensor, + num_activated_expert_per_token_offset: Optional[torch.Tensor], + varlen_K_max: int, + H: int, + is_varlen_K: bool, +) -> None: + if num_activated_expert_per_token_offset is None: + assert not is_varlen_K, "`num_activated_expert_per_token_offset` as None requires fixed top-K routing" + token_gather_and_sum_varlen_K_triton( + dx_expanded, + None, + dx_reduced, + s_reverse_scatter_idx, + num_activated_expert_per_token_offset, + dx_reduced.size(0), + varlen_K_max, + H, + is_varlen_K, + ) + + +@triton.jit +def _softmax_bwd_scatter_small_kernel( + dlogits_ptr, + dlogits_full_ptr, + score_ptr, + dscore_ptr, + idx_ptr, + stride_dm: tl.constexpr, + stride_dn: tl.constexpr, + stride_sm: tl.constexpr, + stride_sn: tl.constexpr, + stride_gm: tl.constexpr, + stride_gk: tl.constexpr, + stride_im: tl.constexpr, + stride_ik: tl.constexpr, + K: tl.constexpr, + BLOCK_K: tl.constexpr, + dlogits_is_none: tl.constexpr, +): + 
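+    # Per-row softmax backward restricted to the K selected experts:
+    #   dlogits[row, idx_k] = s_k * (g_k - sum_j g_j * s_j)
+    # scattered into the dense (T, E) gradient buffer (adding any pre-existing
+    # dlogits contribution when dlogits_is_none is False).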
row = tl.program_id(axis=0) + + # tl.assume(K <= BLOCK_K) + k_offs = tl.arange(0, BLOCK_K) + k_mask = k_offs < K + + idx = tl.load(idx_ptr + row * stride_im + k_offs * stride_ik, mask=k_mask, other=0).to(tl.int32) + s_sel = tl.load(score_ptr + row * stride_sm + k_offs * stride_sn, mask=k_mask, other=0).to(tl.float32) + g_sel = tl.load(dscore_ptr + row * stride_gm + k_offs * stride_gk, mask=k_mask, other=0).to(tl.float32) + + # dot = sum_j g_j * y_j over selected columns + dot = tl.sum(g_sel * s_sel, axis=0) + + # scatter-only: dx[idx] += y_sel * (g_sel - dot) + add_vals = s_sel * (g_sel - dot) + + indices = row * stride_dm + idx * stride_dn + if not dlogits_is_none: + add_vals += tl.load(dlogits_ptr + indices, mask=k_mask) + tl.store(dlogits_full_ptr + indices, add_vals, mask=k_mask) + + +@torch.library.custom_op(add_op_namespace_prefix("_softmax_topk_bwd"), mutates_args={"dlogits_full"}) +def _softmax_topk_bwd( + dlogits_full: torch.Tensor, + dlogits: Optional[torch.Tensor], + dtopk_score: torch.Tensor, + topk_router_score: torch.Tensor, + topk_router_indices: torch.Tensor, + K: int, +) -> None: + T = dtopk_score.shape[0] + + _softmax_bwd_scatter_small_kernel[T,]( + dlogits, + dlogits_full, + topk_router_score, + dtopk_score, + topk_router_indices, + dlogits_full.stride(0), + dlogits_full.stride(1), + topk_router_score.stride(0), + topk_router_score.stride(1), + dtopk_score.stride(0), + dtopk_score.stride(1), + topk_router_indices.stride(0), + topk_router_indices.stride(1), + K, + triton.next_power_of_2(K), + (dlogits is None), + ) + + +@triton.jit +def _topk_bwd_scatter_small_kernel( + dlogits_full_ptr, + dscore_ptr, + idx_ptr, + stride_dm: tl.constexpr, + stride_dn: tl.constexpr, + stride_gm: tl.constexpr, + stride_gk: tl.constexpr, + stride_im: tl.constexpr, + stride_ik: tl.constexpr, + K: tl.constexpr, + BLOCK_K: tl.constexpr, +): + row = tl.program_id(axis=0) + + # tl.assume(K <= BLOCK_K) + k_offs = tl.arange(0, BLOCK_K) + k_mask = k_offs < K + + idx = tl.load(idx_ptr + row * stride_im + k_offs * stride_ik, mask=k_mask, other=0).to(tl.int32) + g_sel = tl.load(dscore_ptr + row * stride_gm + k_offs * stride_gk, mask=k_mask, other=0).to(tl.float32) + + # scatter-only: dx[idx] += y_sel * (g_sel - dot) + add_vals = g_sel + + indices = row * stride_dm + idx * stride_dn + tl.store(dlogits_full_ptr + indices, add_vals, mask=k_mask) + + +@torch.library.custom_op(add_op_namespace_prefix("_topk_bwd"), mutates_args={"dlogits_full"}) +def _topk_bwd( + dlogits_full: torch.Tensor, + dtopk_values: torch.Tensor, + topk_indices: torch.Tensor, + K: int, +) -> None: + T = dtopk_values.shape[0] + + _topk_bwd_scatter_small_kernel[T,]( + dlogits_full, + dtopk_values, + topk_indices, + dlogits_full.stride(0), + dlogits_full.stride(1), + dtopk_values.stride(0), + dtopk_values.stride(1), + topk_indices.stride(0), + topk_indices.stride(1), + K, + triton.next_power_of_2(K), + ) diff --git a/sonic-moe/torch-ext/sonicmoe/functional/forward.py b/sonic-moe/torch-ext/sonicmoe/functional/forward.py new file mode 100644 index 00000000..f9f837f0 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/functional/forward.py @@ -0,0 +1,238 @@ +# ******************************************************************************** +# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao +# ******************************************************************************** + +import cuda.bindings.driver as cuda +import cutlass.cute as cute +import torch +import triton +import triton.language as tl +from 
cutlass.cute.runtime import from_dlpack +from ..quack.cute_dsl_utils import torch2cute_dtype_map + +from ..enums import LIBRARY_NAME, TENSORMAP, ActivationType +from .._ops_compat import add_op_namespace_prefix +from ..utils import convert_torch_tensor_to_cute_tensor +from .moe_config import HopperWgmma_MoE_Down_proj_Fwd, HopperWgmma_MoE_Up_proj_Fwd +from .reduction_over_k_gather import token_gather_and_sum_varlen_K_triton +from .topk_softmax import TopK_Softmax + + +@torch.library.custom_op(add_op_namespace_prefix("_topk_fwd"), mutates_args={"values", "indices"}) +def _topk_fwd( + x: torch.Tensor, k: int, values: torch.Tensor, indices: torch.Tensor, require_softmax_fusion: bool = True +) -> None: + """Top-k forward pass. + Args: + x: Input tensor of shape (M, N) + k: Number of top elements to return + Returns: + Tuple of (values tensor of shape (M, k), indices tensor of shape (M, k)) + """ + N = x.size(1) + + input_dtype = torch2cute_dtype_map[x.dtype] + output_dtype = torch2cute_dtype_map[values.dtype] + convert_from_dlpack = lambda tensor: ( + from_dlpack(tensor.detach(), assumed_align=16).mark_compact_shape_dynamic(mode=0, stride_order=(0, 1)) + ) + + x_tensor, values_tensor, indices_tensor = [convert_from_dlpack(tensor) for tensor in (x, values, indices)] + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + compile_key = (input_dtype, output_dtype, N, k, require_softmax_fusion) + if compile_key not in _topk_fwd.compile_cache: + topk_op = TopK_Softmax(input_dtype, output_dtype, N, k, require_softmax_fusion) + _topk_fwd.compile_cache[compile_key] = cute.compile( + topk_op, x_tensor, values_tensor, indices_tensor, current_stream + ) + _topk_fwd.compile_cache[compile_key](x_tensor, values_tensor, indices_tensor, current_stream) + + +_topk_fwd.compile_cache = {} + + +@torch.library.custom_op(add_op_namespace_prefix("_up_projection_forward"), mutates_args={"z", "y1"}) +def _up_projection_forward( + x: torch.Tensor, + w1: torch.Tensor, + z: torch.Tensor, + y1: torch.Tensor, + b1: torch.Tensor | None, + expert_frequency_offset: torch.Tensor, + expert_schedule_order: torch.Tensor, + x_gather_idx: torch.Tensor, + stream_id: int, + activation_type: str, + is_glu_activation: bool, + is_inference_mode_enabled: bool = False, +) -> None: + I, H, E = w1.size() + if is_glu_activation: + I //= 2 + + mX = convert_torch_tensor_to_cute_tensor(x.detach(), (0, 1), 1, 16, 8, stream=stream_id) + mW1 = convert_torch_tensor_to_cute_tensor(w1.detach(), (2, 0, 1), 1, 16, 8, stream=stream_id) + mZ = convert_torch_tensor_to_cute_tensor(z, (0, 1), 1, 16, 8, stream=stream_id) + mY1 = convert_torch_tensor_to_cute_tensor(y1, (0, 1), 1, 16, 8, stream=stream_id) + mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id) + mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id) + + if expert_schedule_order is None: + mE_permute_order = None + else: + mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id) + + if b1 is None: + mB1 = None + else: + mB1 = convert_torch_tensor_to_cute_tensor(b1.detach(), (0, 1), 1, 16, 8, stream=stream_id) + + current_stream = cuda.CUstream(stream_id) + + compile_w1_key = (E, H, I, (b1 is None), x.dtype, activation_type, is_inference_mode_enabled) + if compile_w1_key not in _up_projection_forward.compile_cache: + w1_module = HopperWgmma_MoE_Up_proj_Fwd( + E, H, I, activation_type=ActivationType(activation_type), 
inference_mode=is_inference_mode_enabled + ) + tensormaps = [w1_module.module.generate_tensormap(None, None, None) for _ in range(2)] + _up_projection_forward.compile_cache[compile_w1_key] = cute.compile( + w1_module, + mX, + mW1, + mZ, + mY1, + mB1, + mE_offset, + mX_gather, + tensormaps[0], + tensormaps[1], + mE_permute_order, + current_stream, + ) + _up_projection_forward.compile_cache[TENSORMAP] = tensormaps + + w1_tensormaps = _up_projection_forward.compile_cache[TENSORMAP] + _up_projection_forward.compile_cache[compile_w1_key]( + mX, + mW1, + mZ, + mY1, + mB1, + mE_offset, + mX_gather, + w1_tensormaps[0], + w1_tensormaps[1], + mE_permute_order, + current_stream, + ) + + +_up_projection_forward.compile_cache = {} + + +@torch.library.custom_op(add_op_namespace_prefix("_down_projection_forward"), mutates_args={"y2"}) +def _down_projection_forward( + w2: torch.Tensor, + y1: torch.Tensor, + y2: torch.Tensor, + b2: torch.Tensor | None, + expert_frequency_offset: torch.Tensor, + expert_schedule_order: torch.Tensor, + x_gather_idx: torch.Tensor, + stream_id: int, +) -> None: + H, I, E = w2.size() + + mW2 = convert_torch_tensor_to_cute_tensor(w2.detach(), (2, 0, 1), 1, 16, 8, stream=stream_id) + mY1 = convert_torch_tensor_to_cute_tensor(y1.detach(), (0, 1), 1, 16, 8, stream=stream_id) + mY2 = convert_torch_tensor_to_cute_tensor(y2, (0, 1), 1, 16, 8, stream=stream_id) + mE_offset = convert_torch_tensor_to_cute_tensor(expert_frequency_offset, (0,), 0, 4, 1, stream=stream_id) + mX_gather = convert_torch_tensor_to_cute_tensor(x_gather_idx, (0,), 0, 4, 1, stream=stream_id) + + if expert_schedule_order is None: + mE_permute_order = None + else: + mE_permute_order = convert_torch_tensor_to_cute_tensor(expert_schedule_order, (0,), 0, 4, 1, stream=stream_id) + + if b2 is None: + mB2 = None + else: + mB2 = convert_torch_tensor_to_cute_tensor(b2.detach(), (0, 1), 1, 16, 8, stream=stream_id) + + current_stream = cuda.CUstream(stream_id) + + compile_w2_key = (E, H, I, (b2 is None), w2.dtype) + if compile_w2_key not in _down_projection_forward.compile_cache: + w2_module = HopperWgmma_MoE_Down_proj_Fwd(E, H, I) + tensormaps = [w2_module.module.generate_tensormap(None, None, None) for _ in range(1)] + _down_projection_forward.compile_cache[compile_w2_key] = cute.compile( + w2_module, mY1, mW2, mY2, mB2, mE_offset, mX_gather, tensormaps[0], mE_permute_order, current_stream + ) + _down_projection_forward.compile_cache[TENSORMAP] = tensormaps + + w2_tensormaps = _down_projection_forward.compile_cache[TENSORMAP] + _down_projection_forward.compile_cache[compile_w2_key]( + mY1, mW2, mY2, mB2, mE_offset, mX_gather, w2_tensormaps[0], mE_permute_order, current_stream + ) + + +_down_projection_forward.compile_cache = {} + + +@torch.library.custom_op(add_op_namespace_prefix("_router_forward"), mutates_args={"o"}) +def _router_forward( + y2: torch.Tensor, + o: torch.Tensor, + topk_scores: torch.Tensor, + s_reverse_scatter_idx: torch.Tensor, + num_activated_expert_per_token_offset: torch.Tensor, + varlen_K_max: int, + H: int, + is_varlen_K: bool, +) -> None: + token_gather_and_sum_varlen_K_triton( + y2, + topk_scores, + o, + s_reverse_scatter_idx, + num_activated_expert_per_token_offset, + o.size(0), + varlen_K_max, + H, + is_varlen_K, + ) + + +@triton.jit +def _softmax_fwd_small_kernel( + logits_ptr, stride_lm: tl.constexpr, stride_ln: tl.constexpr, K: tl.constexpr, BLOCK_K: tl.constexpr +): + row = tl.program_id(axis=0) + + # tl.assume(K <= BLOCK_K) + k_offs = tl.arange(0, BLOCK_K) + k_mask = k_offs < K + + # load full 
row (all columns) in one go (N is small) + x = tl.load(logits_ptr + row * stride_lm + k_offs * stride_ln, mask=k_mask, other=-float("inf")).to(tl.float32) + x = x - tl.max(x, axis=0) + ex = tl.exp(x) + y = ex / tl.sum(ex, axis=0) + + tl.store(logits_ptr + row * stride_lm + k_offs * stride_ln, y, mask=k_mask) + + +@torch.library.custom_op( + add_op_namespace_prefix("_softmax_topk_fwd"), mutates_args={"topk_router_score", "topk_router_indices"} +) +def _softmax_topk_fwd( + router_logits: torch.Tensor, topk_router_score: torch.Tensor, topk_router_indices: torch.Tensor, E: int, K: int +) -> None: + # T = router_logits.shape[0] + if E <= 4096 and K <= 16 and E % 8 == 0: + # fast topk-softmax fusion that covers most common MoE configs + _topk_fwd(router_logits, K, topk_router_score, topk_router_indices, require_softmax_fusion=True) + else: + topk_results = router_logits.topk(K, dim=-1) + topk_router_score.copy_(topk_results.values.softmax(dim=-1, dtype=torch.float32).to(topk_router_score.dtype)) + topk_router_indices.copy_(topk_results.indices.to(topk_router_indices.dtype)) diff --git a/sonic-moe/torch-ext/sonicmoe/functional/grouped_gemm.py b/sonic-moe/torch-ext/sonicmoe/functional/grouped_gemm.py new file mode 100644 index 00000000..13e6d8e7 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/functional/grouped_gemm.py @@ -0,0 +1,3069 @@ +# ******************************************************************************** +# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao +# ******************************************************************************** + +# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause + +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: + +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. + +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. + +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. + +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
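# ---------------------------------------------------------------------------
# Editor's note: the sketch below is added for readability only and is NOT part
# of the upstream patch. It shows, in plain PyTorch, the segmented per-expert
# matmul ("grouped GEMM") that this file implements with CuTe-DSL on Hopper
# WGMMA. Every name in the sketch (segmented_expert_matmul, x, w, expert_offset)
# is hypothetical; the real kernels additionally fuse token gather/scatter,
# bias, and activation epilogues, and update TMA tensormaps per expert group.
# ---------------------------------------------------------------------------
def segmented_expert_matmul(x, w, expert_offset):
    """Reference semantics only: x is (T, K) with tokens sorted by expert,
    w is (N, K, E) per-expert weights, expert_offset is (E + 1,) cumulative
    token counts per expert."""
    out = x.new_empty(x.size(0), w.size(0))
    for e in range(w.size(-1)):
        lo, hi = int(expert_offset[e]), int(expert_offset[e + 1])
        # tokens routed to expert e occupy the contiguous row range [lo, hi)
        out[lo:hi] = x[lo:hi] @ w[:, :, e].T
    return out
# --------------------------- end editor's note ------------------------------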
+ +import enum +import math +import operator +from functools import partial +from typing import Callable, Optional, Tuple, Type, Union + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +import cutlass.utils as utils +import cutlass.utils.hopper_helpers as sm90_utils +import torch +from cutlass import Float32, Int32, const_expr +from cutlass._mlir.dialects import llvm, vector +from cutlass.cute.nvgpu import cpasync, warp, warpgroup +from cutlass.cute.runtime import from_dlpack +from cutlass.cutlass_dsl import T, dsl_user_op +from ..quack.copy_utils import sm90_get_smem_load_op +from ..quack.cute_dsl_utils import ParamsBase +from ..quack.layout_utils import make_acc_tensor_mn_view + +# return PipelineStateWAdvance instead of PipelineState +from ..quack.pipeline import PipelineTmaCpAsync, make_pipeline_state +from ..quack.sm90_utils import partition_for_epilogue +from ..quack.tensormap_manager import TensorMapManagerSm90 +from ..quack.tile_scheduler import RasterOrderOption, TileSchedulerArguments, VarlenMTileSchedulerArguments + +from .tile_scheduler import SonicMoETileScheduler, SonicMoEVarlenMTileScheduler + + +class NamedBarrierGemm(enum.IntEnum): + Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() + EpilogueLoad = enum.auto() + MmaWG0 = enum.auto() + MmaWG1 = enum.auto() + EpiWG0 = enum.auto() + EpiWG1 = enum.auto() + Prolog = enum.auto() + + +class HopperWgmma_MoE_kernel: + def __init__( + self, + E: int, + acc_dtype: Type[cutlass.Numeric], + tile_shape_mnk: Tuple[int, int, int], + cluster_shape_mnk: Tuple[int, int, int], + pingpong: bool = False, + is_persistent: bool = True, + compute_dz_and_partial_ds_and_y1s: bool = False, + compute_weight_gradient: bool = False, + compute_relu: bool = False, + compute_silu: bool = False, + compute_gelu: bool = False, + compute_relu_sq: bool = False, + compute_swiglu: bool = False, + compute_reglu: bool = False, + compute_geglu: bool = False, + is_normal_act: bool = False, + is_glu: bool = False, + is_A_gather: bool = False, + is_scatter_idx_prefetched: bool = False, + epi_tile_size: int = 32, + initial_d_epi_stage: int = 4, + index_dtype: Type[cutlass.Numeric] = cutlass.Int32, + prefetch_idx_store_to_smem: int = 2048, + inference_mode: bool = False, + L2_group_size: int = 8, + raster_order: RasterOrderOption = RasterOrderOption.Heuristic, + ): + self.epi_tile_size = epi_tile_size + self.initial_d_epi_stage = initial_d_epi_stage + + self.is_A_gather = is_A_gather + self.is_scatter_idx_prefetched = is_scatter_idx_prefetched + + self.compute_swiglu = compute_swiglu + self.compute_geglu = compute_geglu + self.compute_reglu = compute_reglu + + self.compute_relu = compute_relu + self.compute_silu = compute_silu + self.compute_gelu = compute_gelu + self.compute_relu_sq = compute_relu_sq + + self.is_glu = is_glu or (compute_swiglu or compute_geglu or compute_reglu) + self.is_normal_act = is_normal_act or (compute_gelu or compute_relu_sq or compute_relu or compute_silu) + + self.compute_dz_and_partial_ds_and_y1s = compute_dz_and_partial_ds_and_y1s + self.compute_weight_gradient = compute_weight_gradient + + self.need_adhoc_epilogue_store = self.is_glu or self.is_normal_act or compute_dz_and_partial_ds_and_y1s + self.need_epilogue_load = compute_dz_and_partial_ds_and_y1s + + self.L2_group_size = L2_group_size + self.raster_order = raster_order + + self.E = E + self.acc_dtype = acc_dtype + assert self.acc_dtype == cutlass.Float32 + self.pingpong = pingpong 
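# Editor's note (added for readability, not part of the upstream patch):
# "pingpong" follows the usual Hopper warp-specialized meaning: two consumer
# warp groups work on different output tiles and alternate between mainloop
# and epilogue, so one group's epilogue overlaps the other's MMA, whereas the
# cooperative (non-pingpong) mode has all MMA warp groups share a single tile.
# That alternation only works with a persistent tile scheduler, hence the
# assertion just below.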
+ self.is_persistent = is_persistent + if self.pingpong: + assert self.is_persistent, "Pingpong gemm requires persistent scheduler" + + self.cluster_shape_mnk = cluster_shape_mnk + self.tile_shape_mnk = tuple(tile_shape_mnk) + tile_M, tile_N = tile_shape_mnk[0], tile_shape_mnk[1] + # check the cta tile shape + if not self.pingpong: + if tile_M not in [64, 128, 192, 256, 320]: + raise ValueError("CTA tile shape M must be 64/128/192/256/320") + if tile_M in [192, 320]: # special case + tile_N_max = 256 if tile_M == 192 else 160 + if not (tile_N % 32 == 0 and tile_N <= tile_N_max): + raise ValueError( + f"If tile_m == {tile_M}, CTA tile shape N must be divisible by 32 and <= {tile_N_max}" + ) + else: + if not ((tile_N % 16 == 0 and tile_N <= 256) or (tile_N % 32 == 0 and tile_N <= 512)): + raise ValueError( + "CTA tile shape N must be divisible by 16 and <= 256, or divisible by 32 and <= 512" + ) + else: + if tile_M not in [64, 128, 192]: + raise ValueError("CTA tile shape M must be 64/128/192 if pingpong") + tile_N_max = 256 if tile_M == 64 else (208 if tile_M == 128 else 128) + if not (tile_N % 16 == 0 and tile_N <= tile_N_max): + raise ValueError(f"CTA tile shape N must be divisible by 16 and <= {tile_N_max}") + if not self.tile_shape_mnk[2] % 16 == 0: + raise ValueError("CTA tile shape K must be divisible by 16") + + self.tile_M, self.tile_N, self.tile_K = tile_shape_mnk + + if not self.pingpong: + if tile_M == 320: # tile_M / 64 is not even so we have to split along N + atom_layout_m, atom_layout_n = 1, 2 + elif tile_M == 192: + if tile_N <= 128: + atom_layout_m, atom_layout_n = 3, 1 + else: + atom_layout_m, atom_layout_n = 1, 2 + else: + atom_layout_m = tile_shape_mnk[0] // 64 if tile_shape_mnk[0] < 256 else 2 + atom_layout_n = 1 + assert atom_layout_m in [1, 2, 3] and atom_layout_n in [1, 2] + else: + atom_layout_m, atom_layout_n = 1, 1 + self.atom_layout_mnk = (atom_layout_m, atom_layout_n, 1) + + if is_A_gather: + assert self.cluster_shape_mnk[1] == 1 + self.num_mcast_ctas_a = None + self.is_a_mcast = False + else: + self.num_mcast_ctas_a = self.cluster_shape_mnk[1] + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.num_mcast_ctas_b = self.cluster_shape_mnk[0] + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + self.occupancy = 1 + self.mma_warp_groups = math.prod(self.atom_layout_mnk) * (1 if not self.pingpong else 2) + if self.pingpong: + assert self.mma_warp_groups == 2 + self.num_threads_per_warp_group = 128 + self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group + self.smem_capacity = utils.get_smem_capacity_in_bytes("sm_90") + self.num_mma_threads = (self.mma_warp_groups if not self.pingpong else 1) * self.num_threads_per_warp_group + self.num_epi_threads = (self.mma_warp_groups if not self.pingpong else 1) * self.num_threads_per_warp_group + self.tma_warp_id = self.mma_warp_groups * 4 + self.universal_copy_bits = 128 + # assumed BF16 now + self.num_load_A_threads = ( + min(self.tile_M * self.tile_K // 8, self.threads_per_cta - self.tma_warp_id * cute.arch.WARP_SIZE) + if is_A_gather + else 0 + ) + if self.compute_weight_gradient and self.is_A_gather: + if tile_M == 192: # contiguous dimension + self.num_load_A_threads = 3 * 32 + assert tile_M in [64, 128, 192, 256] + + self.num_epi_load_threads = 0 + if self.need_epilogue_load: + # 3 warps to load A, 1 warp to load C, (and 1 warp to load S) + self.num_load_A_threads = 4 * cute.arch.WARP_SIZE + self.num_epi_load_threads = self.num_epi_threads + + regs_per_thread = math.prod(self.tile_shape_mnk[:2]) 
// self.num_mma_threads + heavy_register_pressure = regs_per_thread >= 208 + + if not is_A_gather: + if self.mma_warp_groups == 3: + self.num_regs_load, self.num_regs_mma = 32, 160 + else: + heavy_register_pressure = regs_per_thread >= 208 + self.num_regs_load, self.num_regs_mma = (40, 232) if not heavy_register_pressure else (24, 240) + else: + if self.mma_warp_groups == 3: + self.num_regs_load, self.num_regs_mma = 56, 152 + else: + self.num_regs_load, self.num_regs_mma = (56, 224) + + self.ab_stage = None + self.c_epi_stage = None + self.d_epi_stage = None + + self.a_smem_layout_staged = None + self.b_smem_layout_staged = None + self.d_epi_smem_layout_staged = None + self.d_epi_tile = None + + self.shared_storage = None + self.buffer_align_bytes = 1024 + + self.tensormap_update_mode = cutlass.utils.TensorMapUpdateMode.SMEM + self.bytes_per_tensormap = 128 + self.tensor_memory_management_bytes = 12 + + self.inference_mode = inference_mode + + if is_A_gather: + if self.need_adhoc_epilogue_store: + if self.inference_mode: + self.num_tensormaps = self.mma_warp_groups if self.pingpong else 1 + else: + self.num_tensormaps = 2 * self.mma_warp_groups if self.pingpong else 2 + else: + self.num_tensormaps = 1 * self.mma_warp_groups if self.pingpong else 1 + else: + if self.need_adhoc_epilogue_store: + if self.inference_mode: + self.num_tensormaps = 2 * self.mma_warp_groups + 1 if self.pingpong else 3 + else: + self.num_tensormaps = self.mma_warp_groups + 1 if self.pingpong else 2 + else: + self.num_tensormaps = 1 * self.mma_warp_groups + 1 if self.pingpong else 2 + + if self.need_epilogue_load: + self.num_tensormaps += 2 * self.mma_warp_groups if self.pingpong else 1 + + if self.compute_weight_gradient: + if self.is_A_gather: + self.num_tensormaps = 1 + self.prefetch_token_idx_size = prefetch_idx_store_to_smem + self.index_dtype = index_dtype + + assert ( + self.prefetch_token_idx_size % self.tile_K == 0 + and self.prefetch_token_idx_size >= self.tile_K + and self.prefetch_token_idx_size % self.num_load_A_threads == 0 + ) + else: + self.num_tensormaps = 2 + self.prefetch_token_idx_size = 0 + self.index_dtype = None + else: + self.prefetch_token_idx_size = 0 + self.index_dtype = None + + self.tensormap_bytes_total = self.num_tensormaps * self.bytes_per_tensormap + + def _setup_attributes(self): + self.cta_layout_mnk = cute.make_layout(self.cluster_shape_mnk) + + self.d_epi_tile = self._sm90_compute_tile_shape_or_override( + self.tile_shape_mnk, + self.atom_layout_mnk, + self.d_dtype, + ) + self.c_epi_tile = self.d_epi_tile + if const_expr(self.compute_dz_and_partial_ds_and_y1s): + self.y_epi_tile = self.d_epi_tile + elif const_expr(self.is_glu): + self.y_epi_tile = (self.d_epi_tile[0], self.d_epi_tile[1] // 2) + elif const_expr(self.is_normal_act): + self.y_epi_tile = self.d_epi_tile + else: + self.y_epi_tile = None + + if const_expr(self.use_bias): + self.bias_epi_tile = self.d_epi_tile + self.initial_d_epi_stage -= 1 # for safety + else: + self.bias_epi_tile = None + + # Compute stage before compute smem layout + self.ab_stage, self.c_epi_stage, self.d_epi_stage, self.y_epi_stage = self._compute_stages( + self.tile_shape_mnk, + self.initial_d_epi_stage, + # epi_smem will reuse smem ab if not persistent. + self.d_epi_tile, + self.c_epi_tile, + self.y_epi_tile, + self.a_dtype, + self.b_dtype, + self.d_dtype, + self.c_dtype, + self.y_dtype, + self.smem_capacity, + self.occupancy, + # epi_smem will reuse smem ab if not persistent. 
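# (Editor's note, not part of the upstream patch: in the non-persistent case
# no separate D epilogue buffer is allocated; later in the kernel sD is recast
# from the sA operand buffer, which is what the overlap_sD_sA flag below
# communicates to the stage computation.)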
+ overlap_sD_sA=not self.is_persistent, + ) + + if const_expr((not self.inference_mode) and self.need_adhoc_epilogue_store): + assert self.d_epi_stage == self.y_epi_stage + + self.sched_stage = 2 if self.pingpong else 1 + + ( + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.c_epi_smem_layout_staged, + self.bias_epi_smem_layout_staged, + self.d_epi_smem_layout_staged, + self.y_epi_smem_layout_staged, + self.s_epi_smem_layout_staged, + self.prefetch_AIdx_smem_layout_staged, + ) = self._make_smem_layouts( + self.tile_shape_mnk, + self.c_epi_tile, + self.bias_epi_tile, + self.d_epi_tile, + self.y_epi_tile, + self.a_dtype, + self.a_layout, + self.b_dtype, + self.b_layout, + self.prefetch_token_idx_size, + self.ab_stage, + self.c_dtype, + self.c_layout, + self.bias_dtype, + self.bias_layout, + self.d_dtype, + self.d_layout, + self.y_dtype, + self.y_layout, + self.s_dtype, + self.c_epi_stage, + self.d_epi_stage, + self.y_epi_stage, + ) + + @dsl_user_op + def tanh(self, a: float | Float32, *, loc=None, ip=None) -> Float32: + return Float32( + llvm.inline_asm( + T.f32(), + [Float32(a).ir_value(loc=loc, ip=ip)], + "tanh.approx.f32 $0, $1;", + "=f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + @dsl_user_op + def fma(self, a: float | Float32, b: float | Float32, c: float | Float32, *, loc=None, ip=None) -> Float32: + return Float32( + llvm.inline_asm( + T.f32(), + [ + Float32(a).ir_value(loc=loc, ip=ip), + Float32(b).ir_value(loc=loc, ip=ip), + Float32(c).ir_value(loc=loc, ip=ip), + ], + "fma.rn.f32 $0, $1, $2, $3;", + "=f,f,f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + @dsl_user_op + def silu(self, a: float | Float32, *, loc=None, ip=None) -> Float32: + """ + silu(a) = a * sigmoid(a) = a * (1 + tanh(a / 2)) / 2 = (0.5 * a) * tanh(0.5 * a) + (0.5 * a) + This compiles down to 3 SASS instructions: FMUL to get 0.5 * a, MUFU.TANH, and FFMA. 
+ """ + # return a / (1.0 + cute.arch.exp2(-a * math.log2(math.e))) + a_half = 0.5 * a + # return a_half * self.tanh(a_half) + a_half + return self.fma(a_half, self.tanh(a_half), a_half) + + @dsl_user_op + def relu(self, a: float | Float32, *, loc=None, ip=None) -> Float32: + return cute.arch.fmax(a, 0.0) + + @dsl_user_op + def relu_sq(self, a: float | Float32, *, loc=None, ip=None) -> Float32: + return a * cute.arch.fmax(a, 0.0) + + @dsl_user_op + def gelu(self, a: Float32, *, loc=None, ip=None) -> Float32: + # gelu(x) ≈ 0.5*x*(1 + tanh(√(2/π)*(x + 0.044715*x^3))) + c0 = const_expr(math.sqrt(2 / math.pi)) # √(2/π) + c1 = 0.044715 + a2 = a * a + # inner = √(2/π) * (x + 0.044715*x^3) + inner = c0 * self.fma(c1, a2 * a, a) + return 0.5 * a * self.fma(1.0, self.tanh(inner), 1.0) + + @dsl_user_op + def elem_pointer(self, x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer: + return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) + + @dsl_user_op + def min_i32(self, a: int | Int32, b: int | Int32, *, loc=None, ip=None) -> Int32: + return Int32( + llvm.inline_asm( + T.i32(), # return type + [Int32(a).ir_value(loc=loc, ip=ip), Int32(b).ir_value(loc=loc, ip=ip)], + "min.s32 $0, $1, $2;", + "=r,r,r", # output, input constraints + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + @cute.jit + def prefetch_gather_idx_for_A_when_vary_M( + self, mAIdx: cute.Tensor, M_offset: int, M_boundary: int, copy_elems_per_thr_load: int # m n k l + ) -> cute.Tensor: + assert const_expr(not self.compute_weight_gradient) + M, K = self.tile_M, self.tile_K + + tidx, _, _ = cute.arch.thread_idx() + tidx = tidx - self.tma_warp_id * cute.arch.WARP_SIZE + + stride_1_tile, other_tile = K, M + + threads_per_stride_1_dim = const_expr(stride_1_tile // copy_elems_per_thr_load) + num_other_dim_per_load = const_expr(self.num_load_A_threads // threads_per_stride_1_dim) + + num_other_dim_per_thread = const_expr(other_tile // num_other_dim_per_load) + tmAIdx = cute.make_rmem_tensor((num_other_dim_per_load,), dtype=mAIdx.element_type) + + for i in cutlass.range_constexpr(num_other_dim_per_thread): + other_dim_offset = const_expr(i * num_other_dim_per_load) + tidx // threads_per_stride_1_dim + + if other_dim_offset < M_boundary: + M_i = M_offset + other_dim_offset + tmAIdx[i] = mAIdx[M_i] + + return tmAIdx + + @cute.jit + def prefetch_scatter_idx_for_D_when_vary_M( + self, + mD: cute.Tensor, # unused, kept for symmetry + mDIdx: cute.Tensor, + D_r2g_thr_copy, + tcDgcD_flat_partition: cute.Tensor, + epi_tile_layout: cute.Layout, + epi_tile_num: int, + copy_elems_per_thr_load: int, # unused here, but fine to keep + tile_coord_mnkl: Tuple[int, int, None, int], # (block_M, block_N, _, batch) + MIdx_cur_group: int, + MIdx_next_group: int, + ) -> cute.Tensor: + # Same base M offset as store_D_scatter + block_M, block_N = tile_coord_mnkl[0], tile_coord_mnkl[1] + M_offset = block_M * const_expr(self.tile_M) + MIdx_cur_group + + tDcD0 = D_r2g_thr_copy.partition_D(tcDgcD_flat_partition[None, None, *epi_tile_layout.get_hier_coord(0)]) + num_load_per_thread = const_expr(cute.size(tDcD0, mode=[1])) + + tmDIdx = cute.make_rmem_tensor((epi_tile_num * num_load_per_thread,), dtype=mDIdx.element_type) + tmDIdx = cute.make_rmem_tensor((epi_tile_num * num_load_per_thread,), dtype=mDIdx.element_type) + + for epi_idx in cutlass.range_constexpr(epi_tile_num): + tDcD_slice = D_r2g_thr_copy.partition_D( + tcDgcD_flat_partition[None, None, *epi_tile_layout.get_hier_coord(epi_idx)] + 
) + + for i in cutlass.range_constexpr(num_load_per_thread): + # Same coordinate source as in store_D_scatter + MIdx_in_tile, _ = tDcD_slice[0, i, 0] + MIdx = M_offset + MIdx_in_tile + + if MIdx < MIdx_next_group: + tmDIdx[epi_idx * num_load_per_thread + i] = mDIdx[MIdx] + + return tmDIdx + + @cute.jit + def prefetch_gather_idx_for_A_when_vary_K( + self, + mAIdx: cute.Tensor, + sAIdx: cute.Tensor, + token_group_size: int, + K_offset: int, + ) -> cute.Tensor: + assert const_expr(self.compute_weight_gradient and self.is_A_gather) + + tidx, _, _ = cute.arch.thread_idx() + tidx = tidx - self.tma_warp_id * cute.arch.WARP_SIZE + + # !!! cannot be removed for correctness !!! + cute.arch.barrier(barrier_id=NamedBarrierGemm.Prolog, number_of_threads=self.num_load_A_threads) + + for i in cutlass.range_constexpr(cute.ceil_div(self.prefetch_token_idx_size, self.num_load_A_threads)): + offset = const_expr(i * self.num_load_A_threads) + tidx + kidx = K_offset + offset + + if kidx < token_group_size: + sAIdx[offset] = mAIdx[kidx] + + # !!! cannot be removed for correctness !!! + cute.arch.barrier(barrier_id=NamedBarrierGemm.Prolog, number_of_threads=self.num_load_A_threads) + + @cute.jit + def load_A_gather( + self, + mA: cute.Tensor, + tmAIdx: Optional[cute.Tensor], + sAIdx_prefetch: cute.Tensor, + M_offset: cutlass.Int32, + tAsA: cute.Tensor, + tApA: cute.Tensor, + A_g2s_thr_copy, + K_offset: cutlass.Int32, + token_group_size: cutlass.Int32, + copy_elems_per_thr_load: cutlass.Int32, + ): + M, K = self.tile_M, self.tile_K + + tidx, _, _ = cute.arch.thread_idx() + tidx = tidx - self.tma_warp_id * cute.arch.WARP_SIZE + + if const_expr(self.compute_weight_gradient): + stride_1_tile, other_tile = M, K + else: + stride_1_tile, other_tile = K, M + + threads_per_stride_1_dim = const_expr(stride_1_tile // copy_elems_per_thr_load) + num_other_dim_per_load = const_expr(self.num_load_A_threads // threads_per_stride_1_dim) + + K_offset_mod_smem_load = K_offset % const_expr(self.prefetch_token_idx_size) + for i in cutlass.range_constexpr(cute.ceil_div(other_tile, num_other_dim_per_load)): + stride_1_dim_offset = (tidx % threads_per_stride_1_dim) * copy_elems_per_thr_load + other_dim_offset = const_expr(i * num_other_dim_per_load) + tidx // threads_per_stride_1_dim + + if const_expr(self.compute_weight_gradient): + MIdx = M_offset + stride_1_dim_offset + KIdx_global = K_offset + other_dim_offset + + if KIdx_global < token_group_size and MIdx < mA.shape[0]: + KIdx = sAIdx_prefetch[K_offset_mod_smem_load + other_dim_offset] + # KIdx = mAIdx_mk[K_offset + other_dim_offset] + tPrAptr = self.elem_pointer(mA, (MIdx, KIdx)).align( + self.universal_copy_bits // copy_elems_per_thr_load + ) + mA_cur_copy = cute.make_tensor(tPrAptr, ((copy_elems_per_thr_load, 1), 1)) + + cute.copy(A_g2s_thr_copy, mA_cur_copy, tAsA[None, None, i]) + + else: + MIdx = tmAIdx[i] + KIdx = K_offset + stride_1_dim_offset + + tPrAptr = self.elem_pointer(mA, (MIdx, KIdx)).align( + self.universal_copy_bits // copy_elems_per_thr_load + ) + mA_cur_copy = cute.make_tensor(tPrAptr, ((copy_elems_per_thr_load, 1), 1)) + cute.copy(A_g2s_thr_copy, mA_cur_copy, tAsA[None, i, None], pred=tApA[None, i, None]) + + @cute.jit + def store_D_scatter( + self, + mD: cute.Tensor, # m, n, k, l + mDIdx: cute.Tensor, + tmDIdx: cute.Tensor, # assume to have same size as mD + tDrD: cute.Tensor, + tDcD_slice: cute.Tensor, # ((8, 1), 16, 1) + D_r2g_thr_copy, + epi_idx: cutlass.Int32, + copy_elems_per_thr_load: cutlass.Int32, + tile_coord_mnkl: Tuple[int, int, None, int], # m n k 
l + MIdx_cur_group: int, + MIdx_next_group: int, + ): + block_M, block_N = tile_coord_mnkl[0], tile_coord_mnkl[1] + + M_offset = block_M * const_expr(self.tile_M) + MIdx_cur_group + N_offset = block_N * const_expr(self.tile_N) + + num_load_per_thread = const_expr(cute.size(tDcD_slice, mode=[1])) + for i in cutlass.range_constexpr(num_load_per_thread): + MIdx_in_epi_tile, NIdx_in_epi_tile = tDcD_slice[0, i, 0] + + MIdx = M_offset + MIdx_in_epi_tile + NIdx = N_offset + NIdx_in_epi_tile + + if MIdx < MIdx_next_group and NIdx < mD.shape[1]: + if const_expr(self.is_scatter_idx_prefetched): + SIdx = tmDIdx[i + epi_idx * num_load_per_thread] + else: + SIdx = mDIdx[MIdx] # equivalent + tPDptr = self.elem_pointer(mD, (SIdx, NIdx)).align(self.universal_copy_bits // copy_elems_per_thr_load) + + mD_cur_copy = cute.make_tensor(tPDptr, ((copy_elems_per_thr_load, 1), 1)) + + cute.copy( + D_r2g_thr_copy, + tDrD[None, i, None], + mD_cur_copy, + ) + + @cute.jit + def fetch_scattered_S( + self, + tidx: int, + mS: cute.Tensor, + mS_scatter_idx: cute.Tensor, + sS_staged: cute.Tensor, + tile_coord_mnkl: Tuple[int, int, None, int], # m n k l + MIdx_cur_group: int, + MIdx_next_group: int, + ): + block_M = tile_coord_mnkl[0] + M = self.tile_M + + M_s = block_M * M + MIdx_cur_group + + for i in cutlass.range_constexpr(cute.ceil_div(M, self.num_epi_threads)): + sS_offset = const_expr(i * self.num_epi_threads) + tidx + M_i = M_s + sS_offset + + if M_i < MIdx_next_group and sS_offset < M: + sIdx = mS_scatter_idx[M_i] + sS_staged[sS_offset] = self.s_dtype(mS[sIdx]) + + @dsl_user_op + def prmt(self, a: int | Int32, b: int | Int32, c: int | Int32, *, loc=None, ip=None) -> Int32: + return Int32( + llvm.inline_asm( + T.i32(), + [ + Int32(a).ir_value(loc=loc, ip=ip), + Int32(b).ir_value(loc=loc, ip=ip), + Int32(c).ir_value(loc=loc, ip=ip), + ], + "prmt.b32 $0, $1, $2, $3;", + "=r,r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + @dsl_user_op + def pack2x16_as_f32( + self, + a: Union[cutlass.BFloat16, cutlass.Float16], + b: Union[cutlass.BFloat16, cutlass.Float16], + *, + loc=None, + ip=None, + ) -> cutlass.Float32: + vec_src_type = T.bf16() if a.dtype == cutlass.BFloat16 else T.f16() + + vec_f16x2 = vector.from_elements(T.vector(2, vec_src_type), (a.ir_value(), b.ir_value()), loc=loc, ip=ip) + vec_f32x1 = vector.bitcast(T.vector(1, T.f32()), vec_f16x2) + return cutlass.Float32(vector.extract(vec_f32x1, dynamic_position=[], static_position=[0], loc=loc, ip=ip)) + + @dsl_user_op + def unpack2x16_as_2xf32( + self, a: Float32, dtype: cutlass.Numeric, *, loc=None, ip=None + ) -> Tuple[cutlass.Float32, cutlass.Float32]: + + vec_dst_type = T.bf16() if dtype == cutlass.BFloat16 else T.f16() + + vec_f32x1 = vector.from_elements(T.vector(1, T.f32()), (a.ir_value(),), loc=loc, ip=ip) + vec_f16x2 = vector.bitcast(T.vector(2, vec_dst_type), vec_f32x1) + res0 = Float32(vector.extract(vec_f16x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip)) + res1 = Float32(vector.extract(vec_f16x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip)) + return res0, res1 + + @cute.jit + def permute_gated_Cregs_b16(self, t: cute.Tensor) -> None: + assert t.element_type.width == 16 + assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation" + t_u32 = cute.recast_tensor(t, Int32) + + quad_idx = cute.arch.lane_idx() % 4 + lane_03 = quad_idx == 0 or quad_idx == 3 + selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054) + selector_lower = 
Int32(0x7632) if lane_03 else Int32(0x3276) + # upper_map = [0, 3, 1, 2] + # lower_map = [1, 2, 0, 3] + # upper_idx = upper_map[quad_idx] + # indexing isn't supported so we have to do arithmetic + upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2 + lower_idx = upper_idx ^ 1 + + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + width = 4 + mask = cute.arch.WARP_SIZE - width + clamp = cute.arch.WARP_SIZE - 1 + mask_and_clamp = const_expr(mask << 8 | clamp) + + for i in cutlass.range_constexpr(cute.size(t_u32.shape) // 2): + upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1] + upper0 = upper if lane_03 else lower + lower0 = lower if lane_03 else upper + upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp) + lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp) + t_u32[i * 2 + 0] = self.prmt(upper0, lower0, selector_upper) + t_u32[i * 2 + 1] = self.prmt(upper0, lower0, selector_lower) + + @cute.jit + def __call__( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mC: Optional[cute.Tensor], + mBias: Optional[cute.Tensor], + mD: cute.Tensor, + mY: Optional[cute.Tensor], + mS: Optional[cute.Tensor], + mDS_partial: Optional[cute.Tensor], + mMoffset: cute.Tensor, + mAIdx: Optional[cute.Tensor], + mDIdx: Optional[cute.Tensor], + mS_scatter_idx: Optional[cute.Tensor], + mA_tensormap: Optional[cute.Tensor], + mB_tensormap: Optional[cute.Tensor], + mC_tensormap: Optional[cute.Tensor], + mD_tensormap: Optional[cute.Tensor], + mY_tensormap: Optional[cute.Tensor], + mTileCount_semaphore: Optional[cute.Pointer], + mBatchIdx_schedule_order: Optional[cute.Tensor], + max_active_clusters: Int32, + stream: cuda.CUstream, + ): + # setup static attributes before smem/grid/tma computation + self.a_dtype = mA.element_type + self.b_dtype = mB.element_type + self.c_dtype = mC.element_type if mC is not None else None + self.d_dtype = mD.element_type + self.s_dtype = cutlass.Float32 + + self.a_layout = utils.LayoutEnum.from_tensor(mA) + self.b_layout = utils.LayoutEnum.from_tensor(mB) + self.c_layout = cutlass.utils.LayoutEnum.from_tensor(mC) if mC is not None else None + self.d_layout = utils.LayoutEnum.from_tensor(mD) + + self.use_bias = const_expr(mBias is not None) + if const_expr(self.use_bias): + assert not self.compute_weight_gradient, "Bias addition is not supported when computing weight gradients" + self.bias_dtype = mBias.element_type + self.bias_layout = utils.LayoutEnum.from_tensor(mBias) + else: + self.bias_dtype = None + self.bias_layout = None + + if const_expr(self.need_adhoc_epilogue_store): + self.y_dtype = mY.element_type + self.y_layout = utils.LayoutEnum.from_tensor(mY) + else: + self.y_layout = self.y_dtype = None + + if const_expr(mC is not None): + assert self.acc_dtype == cutlass.Float32 + assert self.need_epilogue_load, "Set need_epilogue_load = True or set mC = None" + + if const_expr(mS is not None): + assert self.compute_dz_and_partial_ds_and_y1s, "Set compute_dz_and_partial_ds = True or set mS = None" + assert mDS_partial is not None + assert mY is not None + + if const_expr(self.a_dtype.width == 16 and self.a_dtype != self.b_dtype): + raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}") + if const_expr(self.a_dtype.width != self.b_dtype.width): + raise TypeError(f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}") + if const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8): + raise TypeError("a_dtype should be float16 or 
float8") + + if const_expr(mBatchIdx_schedule_order is not None): + assert ( + mTileCount_semaphore is None + ), "we only define a static scheduling order for static persistent tile scheduler" + + self.tensormap_management_bytes = ( + self.tensormap_bytes_total + if const_expr(self.tensormap_update_mode == cutlass.utils.TensorMapUpdateMode.SMEM) + else 0 + ) + self.tensor_memory_management_bytes + + self._setup_attributes() + + tiled_mma = sm90_utils.make_trivial_tiled_mma( + self.a_dtype, + self.b_dtype, + self.a_layout.sm90_mma_major_mode(), + self.b_layout.sm90_mma_major_mode(), + self.acc_dtype, + self.atom_layout_mnk, + tiler_mn=(64, self.tile_shape_mnk[1] // self.atom_layout_mnk[1]), + ) + if const_expr(self.atom_layout_mnk[1] > 1): + # If N dimension is split among 2 WGs, we need to permute the N dimension so + # that in the epilogue, WG0 and WG1 can write to epi smem of size e.g. (64, 32) + # containing accumulators that are next to each other in the N dimension. + # Without permutation WG0 would write to epi smem of size (64, 16) and + # WG1 would write to a separate epi smem of size (64, 16) that's far away. + atom_n = self.atom_layout_mnk[1] + permutation_n = cute.make_ordered_layout( + (8, self.tile_shape_mnk[1] // atom_n // 8, atom_n), order=(0, 2, 1) + ) + tiled_mma = cute.make_tiled_mma( + cute.make_mma_atom(tiled_mma.op), + self.atom_layout_mnk, + permutation_mnk=(None, permutation_n, None), + ) + + if const_expr(self.is_A_gather): + A_tiled_copy = self._make_tiled_copy_2D( + mA, + self.tile_M, + self.tile_K, + self.a_layout == cutlass.utils.LayoutEnum.ROW_MAJOR, + self.num_load_A_threads, + self.universal_copy_bits, + is_g2s=True, + ) + tma_atom_a = tma_tensor_a = None + else: + A_tiled_copy = None + tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors( + mA, + self.a_smem_layout_staged, + (self.tile_shape_mnk[0], self.tile_shape_mnk[2]), + self.cluster_shape_mnk[1], + ) + + tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors( + mB, + self.b_smem_layout_staged, + (self.tile_shape_mnk[1], self.tile_shape_mnk[2]), + self.cluster_shape_mnk[0], + ) + + if const_expr(self.need_epilogue_load): + tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors( + mC, self.c_epi_smem_layout_staged, self.c_epi_tile, store_or_load="load" + ) + else: + tma_atom_c, tma_tensor_c = None, None + + atom_bias = None + if const_expr(self.use_bias): + atom_bias = cute.make_copy_atom( + cute.nvgpu.cpasync.CopyG2SOp(cache_mode=cute.nvgpu.cpasync.LoadCacheMode.ALWAYS), + mBias.element_type, + num_bits_per_copy=self.universal_copy_bits, + ) + + thread_per_row = self.tile_shape_mnk[1] // (self.universal_copy_bits // mBias.element_type.width) + thread_layout = cute.make_ordered_layout((1, thread_per_row), order=(1, 0)) + value_layout = cute.make_layout((1, self.universal_copy_bits // mBias.element_type.width)) + atom_bias = cute.make_tiled_copy_tv(atom_bias, thread_layout, value_layout) + + if const_expr(self.d_epi_smem_layout_staged is not None): + tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors( + mD, self.d_epi_smem_layout_staged, self.d_epi_tile, store_or_load="store" + ) + else: + tma_atom_d, tma_tensor_d = None, None + + if const_expr(mDIdx is not None): + copy_elems = self.universal_copy_bits // mD.element_type.width + assert self.num_epi_threads % (self.d_epi_tile[1] // copy_elems) == 0 + + D_tiled_copy = self._make_tiled_copy_2D( + mD, + self.d_epi_tile[0], + self.d_epi_tile[1], + self.d_layout.is_n_major_c(), + self.num_epi_threads, + 
self.universal_copy_bits, + is_g2s=False, + ) + else: + D_tiled_copy = None + + if const_expr(self.need_adhoc_epilogue_store): + tma_atom_y, tma_tensor_y = self._make_tma_epi_atoms_and_tensors( + mY, self.y_epi_smem_layout_staged, self.y_epi_tile, store_or_load="store" + ) + else: + tma_atom_y, tma_tensor_y = None, None + + if const_expr(self.compute_weight_gradient): + assert const_expr( + not self.compute_dz_and_partial_ds_and_y1s + ), "weight grad computation conflicts with activation grad computation" + + problem_shape_ntile_mnl = cute.ceil_div(mD.shape[:2], self.tile_shape_mnk[:2]) + (mD.shape[2],) + TileScheduler = SonicMoETileScheduler + tile_sched_args = TileSchedulerArguments( + problem_shape_ntile_mnl=problem_shape_ntile_mnl, + raster_order=self.raster_order, + group_size=self.L2_group_size, + cluster_shape_mnk=self.cluster_shape_mnk, + is_persistent=self.is_persistent, + tile_count_semaphore=mTileCount_semaphore, + batch_idx_permute=mBatchIdx_schedule_order, + ) + else: + problem_shape_ntile_mnl = ( + None, + cute.ceil_div(mD.shape[1], self.tile_shape_mnk[1]), + mMoffset.shape[0] - 1, + ) + TileScheduler = SonicMoEVarlenMTileScheduler + tile_sched_args = VarlenMTileSchedulerArguments( + problem_shape_ntile_mnl=problem_shape_ntile_mnl, + total_m=mD.shape[0], + cu_seqlens_m=mMoffset, + raster_order=self.raster_order, + group_size=self.L2_group_size, + tile_shape_mn=self.tile_shape_mnk[:2], + cluster_shape_mnk=self.cluster_shape_mnk, + is_persistent=self.is_persistent, + tile_count_semaphore=mTileCount_semaphore, + ) + + tile_sched_params = TileScheduler.to_underlying_arguments(tile_sched_args) + grid = TileScheduler.get_grid_shape(tile_sched_params, max_active_clusters) + + c_epi_smem_size = cute.cosize(self.c_epi_smem_layout_staged) if const_expr(self.need_epilogue_load) else 0 + bias_epi_smem_size = cute.cosize(self.bias_epi_smem_layout_staged) if const_expr(self.use_bias) else 0 + d_epi_smem_size = ( + cute.cosize(self.d_epi_smem_layout_staged) + if const_expr(self.is_persistent and (self.d_epi_stage > 0)) + else 0 + ) + y_epi_smem_size = ( + cute.cosize(self.y_epi_smem_layout_staged) + if const_expr(self.need_adhoc_epilogue_store) and self.is_persistent + else 0 + ) + s_epi_smem_size = ( + cute.cosize(self.s_epi_smem_layout_staged) if const_expr(self.compute_dz_and_partial_ds_and_y1s) else 0 + ) + + @cute.struct + class SharedStorage: + mainloop_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2] + tensormap_buffer: cute.struct.Align[cute.struct.MemRange[cutlass.Int64, self.num_tensormaps], 64] + sD: cute.struct.Align[ + cute.struct.MemRange[self.d_dtype, d_epi_smem_size], + self.buffer_align_bytes, + ] + sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2] + tile_count: cute.struct.MemRange[cutlass.Int32, self.sched_stage] + if const_expr(self.need_epilogue_load): + sC: cute.struct.Align[ + cute.struct.MemRange[self.c_dtype, c_epi_smem_size], + self.buffer_align_bytes, + ] + epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.c_epi_stage * 2] + + if const_expr(self.use_bias): + sBias: cute.struct.Align[ + cute.struct.MemRange[self.bias_dtype, bias_epi_smem_size], + self.buffer_align_bytes, + ] + + if const_expr(self.compute_dz_and_partial_ds_and_y1s): + sS: cute.struct.Align[ + cute.struct.MemRange[self.s_dtype, s_epi_smem_size], + self.buffer_align_bytes, + ] + + if const_expr(self.need_adhoc_epilogue_store): + sY: cute.struct.Align[ + cute.struct.MemRange[self.y_dtype, y_epi_smem_size], + 
self.buffer_align_bytes, + ] + sA: cute.struct.Align[ + cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged)], + self.buffer_align_bytes, + ] + sB: cute.struct.Align[ + cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged)], + self.buffer_align_bytes, + ] + if const_expr(self.compute_weight_gradient and self.is_A_gather): + sAIdx_prefetch: cute.struct.Align[ + cute.struct.MemRange[self.index_dtype, self.prefetch_token_idx_size], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + allocated_smem_size = self.shared_storage.size_in_bytes() + self.tensormap_management_bytes + # Launch the kernel synchronously + self.kernel( + A_tiled_copy, + mA, + tma_atom_a, + tma_tensor_a, + mB, + tma_atom_b, + tma_tensor_b, + tma_atom_c, + tma_tensor_c, + mC, + mBias, + atom_bias, + D_tiled_copy, + tma_atom_d, + tma_tensor_d, + mD, + tma_atom_y, + tma_tensor_y, + mY, + mS, + mDS_partial, + mMoffset, + mAIdx, + mDIdx, + mS_scatter_idx, + mA_tensormap, + mB_tensormap, + mC_tensormap, + mD_tensormap, + mY_tensormap, + tiled_mma, + self.cta_layout_mnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.prefetch_AIdx_smem_layout_staged, + self.c_epi_smem_layout_staged, + self.bias_epi_smem_layout_staged, + self.d_epi_smem_layout_staged, + self.y_epi_smem_layout_staged, + self.s_epi_smem_layout_staged, + tile_sched_params, + TileScheduler, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + smem=allocated_smem_size, + stream=stream, + min_blocks_per_mp=1, + ) + + @cute.jit + def update_tma_desc_ptr( + self, + mTensor: cute.Tensor, + tma_atom: cute.CopyAtom, + tensormap_manager: TensorMapManagerSm90, + tensormap_ptr: cute.Pointer, + token_start: Int32, + token_group_size: Int32, + is_tma_warp: bool, + tensormap_smem_ptr: Optional[cute.Pointer] = None, + address_space: cute.AddressSpace = cute.AddressSpace.generic, + ) -> cute.Pointer: + if const_expr(self.compute_weight_gradient): + tensor_shape = (mTensor.shape[0], token_group_size) + start_ptr = (mTensor.iterator + token_start * mTensor.stride[1]).toint() + else: + tensor_shape = (token_group_size, mTensor.shape[1]) + start_ptr = (mTensor.iterator + token_start * mTensor.stride[0]).toint() + + tensor_gmem_ptr = cute.make_ptr( + mTensor.element_type, + start_ptr, + cute.AddressSpace.gmem, + assumed_align=16, + ) + real_tensor = cute.make_tensor( + tensor_gmem_ptr, + cute.make_layout(tensor_shape, stride=mTensor.stride), + ) + if const_expr(self.tensormap_update_mode == cutlass.utils.TensorMapUpdateMode.GMEM): + tensormap_manager.update_tensormap( + (real_tensor,), + (tma_atom,), + tensormap_gmem_ptr=(tensormap_ptr,), + is_manager_warp=is_tma_warp, + tensormap_smem_ptr=None, + ) + else: + assert tensormap_smem_ptr is not None + tensormap_manager.update_tensormap( + (real_tensor,), + (tma_atom,), + tensormap_gmem_ptr=(tensormap_ptr,), + is_manager_warp=is_tma_warp, + tensormap_smem_ptr=(tensormap_smem_ptr,), + ) + + tensormap_manager.fence_tensormap_update(tensormap_ptr) + + @cute.jit + def align_tensormap_smem_ptr(self, base_ptr: cute.Pointer): + return cute.make_ptr( + cutlass.Int64, + base_ptr.toint(), + cute.AddressSpace.smem, + assumed_align=128, + ) + + @cute.jit + def allocate_new_tensormap_smem_ptr(self, tensormap_smem_ptr: cute.Pointer): + return self.align_tensormap_smem_ptr(tensormap_smem_ptr + self.bytes_per_tensormap // 8) + + @cute.jit + def swiglu_derivative(self, g: Float32, u: Float32, dy1: Float32) -> Tuple[Float32, Float32, 
Float32]: + half_g = 0.5 * g + tanh_half_g = self.tanh(half_g) + + sig_g = self.fma(0.5, tanh_half_g, 0.5) + sig_n_g = 1 - sig_g + + silu_g = self.fma(half_g, tanh_half_g, half_g) + + dg = dy1 * (u * self.fma(silu_g, sig_n_g, sig_g)) + du = dy1 * silu_g + + swiglu_output = silu_g * u + return dg, du, swiglu_output + + @cute.jit + def reglu_derivative(self, g: Float32, u: Float32, dy1: Float32) -> Tuple[Float32, Float32, Float32]: + relu_g = cute.arch.fmax(0.0, g) + + relu_prime_g = 1.0 + if g < Float32(0.0): + relu_prime_g = 0.0 # derivative of ReLU + + dg = dy1 * u * relu_prime_g + du = dy1 * relu_g + + reglu_output = u * relu_g + return dg, du, reglu_output + + @cute.jit + def geglu_derivative(self, g: Float32, u: Float32, dy1: Float32) -> Tuple[Float32, Float32, Float32]: + # gelu(g) = 0.5 * g * (1 + tanh(sqrt(2/pi) * (g + 0.044715*g^3))) + sqrt_2_over_pi = const_expr(math.sqrt(2 / math.pi)) + c = 0.044715 + + g2 = g * g + g3 = g2 * g + + # t = sqrt(2/pi) * (g + c*g^3) + t = sqrt_2_over_pi * self.fma(c, g3, g) + + tanh_t = self.tanh(t) + one_plus_th = 1.0 + tanh_t # 1 + tanh(t) + gelu_g = 0.5 * g * one_plus_th # gelu(g) + + # d th / d g = (1 - tanh(t)^2) * sqrt(2/pi) * (1 + 3c g^2) + sech2 = self.fma(-tanh_t, tanh_t, 1.0) + dt_dg = sech2 * sqrt_2_over_pi * self.fma(3.0 * c, g2, 1.0) + + # d gelu / d g = 0.5*(1 + tanh(t)) + 0.5*g*dt_dg + gelu_prime = 0.5 * self.fma(g, dt_dg, one_plus_th) + + # Chain rule for y = gelu(g) * u + dg = dy1 * u * gelu_prime + du = dy1 * gelu_g + + geglu_output = u * gelu_g + return dg, du, geglu_output + + @cute.jit + def silu_derivative(self, x: Float32, dy1: Float32) -> Tuple[Float32, Float32]: + half_x = 0.5 * x + tanh_half_x = self.tanh(half_x) + + sig_x = self.fma(0.5, tanh_half_x, 0.5) + sig_n_x = 1 - sig_x + + silu_x = self.fma(half_x, tanh_half_x, half_x) + dx = dy1 * self.fma(silu_x, sig_n_x, sig_x) + + return dx, silu_x + + @cute.jit + def relu_derivative(self, x: Float32, dy1: Float32) -> Tuple[Float32, Float32]: + relu_x = cute.arch.fmax(0.0, x) + + relu_prime_x = 1.0 + if x < Float32(0.0): + relu_prime_x = 0.0 # derivative of ReLU + + dx = dy1 * relu_prime_x + return dx, relu_x + + @cute.jit + def relu_sq_derivative(self, x: Float32, dy1: Float32) -> Tuple[Float32, Float32]: + relu_x = cute.arch.fmax(x, 0.0) + relu_sq_output = relu_x * x + dx = dy1 * (2.0 * relu_x) + return dx, relu_sq_output + + @cute.jit + def gelu_derivative(self, x: Float32, dy1: Float32) -> Tuple[Float32, Float32]: + # gelu(g) = 0.5 * g * (1 + tanh(sqrt(2/pi) * (g + 0.044715*g^3))) + sqrt_2_over_pi = const_expr(math.sqrt(2 / math.pi)) + c = 0.044715 + + x2 = x * x + x3 = x2 * x + + # t = sqrt(2/pi) * (g + c*g^3) + t = sqrt_2_over_pi * self.fma(c, x3, x) + + tanh_t = self.tanh(t) + one_plus_tanh_t = 1.0 + tanh_t # 1 + tanh(t) + gelu_x = 0.5 * x * one_plus_tanh_t + + # d th / d g = (1 - tanh(t)^2) * sqrt(2/pi) * (1 + 3c g^2) + sech2 = self.fma(-tanh_t, tanh_t, 1.0) + dt_dg = sech2 * sqrt_2_over_pi * self.fma(3.0 * c, x2, 1.0) + + # d gelu / d g = 0.5*(1 + tanh(t)) + 0.5*g*dt_dg + gelu_prime = 0.5 * self.fma(x, dt_dg, one_plus_tanh_t) + + # Chain rule for y = gelu(g) * u + dx = dy1 * gelu_prime + + return dx, gelu_x + + @cute.jit + def compute_activation(self, tRS_rD, tRS_rY): + if const_expr(self.is_glu): + # tRS_sY: (((2, 4), 1), 1, 1, (1, 4)) + # (((2, 4), 1), 1, 1) + if const_expr(self.compute_swiglu): + act_func = self.silu + elif const_expr(self.compute_reglu): + act_func = self.relu + elif const_expr(self.compute_geglu): + act_func = self.gelu + else: + raise 
NotImplementedError() + + for i in cutlass.range_constexpr(cute.size(tRS_rD) // 2): + tRS_rY[i] = (act_func(tRS_rD[const_expr(2 * i)]) * tRS_rD[const_expr(2 * i + 1)]).to(self.y_dtype) + + self.permute_gated_Cregs_b16(tRS_rY) + + elif const_expr(self.is_normal_act): + assert cute.size(tRS_rD) == cute.size(tRS_rY) + if const_expr(self.compute_relu_sq): + act_func = self.relu_sq + elif const_expr(self.compute_relu): + act_func = self.relu + elif const_expr(self.compute_silu): + act_func = self.silu + elif const_expr(self.compute_gelu): + act_func = self.gelu + else: + raise NotImplementedError() + + for i in cutlass.range_constexpr(cute.size(tRS_rD)): + tRS_rY[i] = act_func(tRS_rD[i]).to(self.y_dtype) + + else: + raise NotImplementedError() + + @cute.jit + def compute_backward_activation(self, tRS_rAcc, sS, tRS_rcD, tRS_rC, tRS_rD, tRS_rD_out, tRS_rY, epi_idx: Int32): + if const_expr(self.is_glu): + # if we compute glu activation, + # we will assume the incoming C dtype as FP32, and we will output final result in FP32 (decompress to BF16 in caller side) + + if const_expr(self.compute_swiglu): + bwd_act_func = self.swiglu_derivative + elif const_expr(self.compute_reglu): + bwd_act_func = self.reglu_derivative + elif const_expr(self.compute_geglu): + bwd_act_func = self.geglu_derivative + else: + raise NotImplementedError() + + for i in cutlass.range_constexpr(cute.size(tRS_rD)): + g, u = self.unpack2x16_as_2xf32(tRS_rC[i], self.a_dtype) + dy = tRS_rD[i] + dg, du, fwd_output = bwd_act_func(g, u, dy) + tRS_rAcc[const_expr(epi_idx * cute.size(tRS_rD) + i)] = dy * fwd_output + s = sS[tRS_rcD[i]] + tRS_rD_out[i] = self.pack2x16_as_f32(self.a_dtype(dg * s), self.a_dtype(du * s)) + tRS_rY[i] = self.y_dtype(fwd_output * s) + + elif const_expr(self.is_normal_act): + if const_expr(self.compute_relu_sq): + bwd_act_func = self.relu_sq_derivative + elif const_expr(self.compute_relu): + bwd_act_func = self.relu_derivative + elif const_expr(self.compute_gelu): + bwd_act_func = self.gelu_derivative + elif const_expr(self.compute_silu): + bwd_act_func = self.silu_derivative + else: + raise NotImplementedError() + + for i in cutlass.range_constexpr(cute.size(tRS_rD)): + z = tRS_rC[i] + dy = tRS_rD[i] + dz, fwd_output = bwd_act_func(z, dy) + tRS_rAcc[const_expr(epi_idx * cute.size(tRS_rD) + i)] = dy * fwd_output + s = sS[tRS_rcD[i]] + tRS_rD_out[i] = self.a_dtype(dz * s) + tRS_rY[i] = self.y_dtype(fwd_output * s) + + else: + raise NotImplementedError() + + # GPU device kernel + @cute.kernel + def kernel( + self, + A_tiled_copy: Optional[cute.TiledCopy], + mA_mkl: cute.Tensor, + tma_atom_a: Optional[cute.CopyAtom], + mA_mkl_tma: Optional[cute.Tensor], + mB_nkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl_tma: cute.Tensor, + tma_atom_c: Optional[cute.CopyAtom], + mC_mnl_tma: Optional[cute.Tensor], + mC_mnl: cute.Tensor, + mBias_nl: Optional[cute.Tensor], + cpasync_atom_bias: Optional[cute.CopyAtom], + D_tiled_copy: Optional[cute.TiledCopy], + tma_atom_d: Optional[cute.CopyAtom], + mD_mnl_tma: cute.Tensor, + mD_mnl: cute.Tensor, + tma_atom_y: Optional[cute.CopyAtom], + mY_mnl_tma: Optional[cute.Tensor], + mY_mnl: Optional[cute.Tensor], + mS_ml: Optional[cute.Tensor], + mDS_partial: Optional[cute.Tensor], + mTokenoffset: cute.Tensor, + mAIdx_mkl: cute.Tensor, + mDIdx_mnl: Optional[cute.Tensor], + mS_scatter_idx: Optional[cute.Tensor], + mA_tensormap: Optional[cute.Tensor], + mB_tensormap: Optional[cute.Tensor], + mC_tensormap: Optional[cute.Tensor], + mD_tensormap: cute.Tensor, + mY_tensormap: 
Optional[cute.Tensor], + tiled_mma: cute.TiledMma, + cta_layout_mnk: cute.Layout, + a_smem_layout_staged: cute.ComposedLayout, + b_smem_layout_staged: cute.ComposedLayout, + prefetch_AIdx_smem_layout_staged: Optional[cute.Layout], + c_epi_smem_layout_staged: Optional[cute.ComposedLayout], + bias_epi_smem_layout_staged: Optional[cute.Layout], + d_epi_smem_layout_staged: cute.ComposedLayout, + y_epi_smem_layout_staged: Optional[cute.ComposedLayout], + s_epi_smem_layout_staged: Optional[cute.Layout], + tile_sched_params: ParamsBase, + TileScheduler: cutlass.Constexpr[Callable], + ): + tidx, _, _ = cute.arch.thread_idx() + # Assume: M: 2048, N: 512, K: 1024, L: 4 + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + if warp_idx == self.tma_warp_id: + if const_expr(not self.is_A_gather): + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + if const_expr(not self.inference_mode): + cpasync.prefetch_descriptor(tma_atom_d) + if const_expr(self.need_adhoc_epilogue_store): + cpasync.prefetch_descriptor(tma_atom_y) + if const_expr(tma_atom_c is not None): + cpasync.prefetch_descriptor(tma_atom_c) + + A_thr_copy_elems = self.universal_copy_bits // mA_mkl.element_type.width + + a_smem_layout = cute.slice_(a_smem_layout_staged, (None, None, 0)) + b_smem_layout = cute.slice_(b_smem_layout_staged, (None, None, 0)) + if const_expr(self.is_A_gather): + tma_copy_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout) + else: + tma_copy_bytes = cute.size_in_bytes(self.a_dtype, a_smem_layout) + cute.size_in_bytes( + self.b_dtype, b_smem_layout + ) + + smem = cutlass.utils.SmemAllocator() + shared_storage = smem.allocate(self.shared_storage) + + # Threads/warps participating in this pipeline + if const_expr(self.is_A_gather): + mainloop_pipeline_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, 1 + self.num_load_A_threads + ) + # Each warp will constribute to the arrive count with the number of mcast size + mcast_size = self.num_mcast_ctas_b + pipeline_class = PipelineTmaCpAsync + else: + # Threads/warps participating in this pipeline + mainloop_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + # Each warp will constribute to the arrive count with the number of mcast size + mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + pipeline_class = pipeline.PipelineTmaAsync + + consumer_arrive_cnt = mcast_size * (self.num_mma_threads // cute.arch.WARP_SIZE) + mainloop_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, consumer_arrive_cnt) + cta_layout_vmnk = cute.make_layout((1, *cta_layout_mnk.shape)) + mainloop_pipeline = pipeline_class.create( + barrier_storage=shared_storage.mainloop_pipeline_array_ptr.data_ptr(), + num_stages=self.ab_stage, + producer_group=mainloop_pipeline_producer_group, + consumer_group=mainloop_pipeline_consumer_group, + tx_count=tma_copy_bytes, + cta_layout_vmnk=cta_layout_vmnk, + ) + + if const_expr(self.need_epilogue_load): + # Threads/warps participating in this pipeline + epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + # Each warp will contribute 1 to the arrive count + consumer_arrive_cnt = self.num_epi_threads // cute.arch.WARP_SIZE + epi_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, consumer_arrive_cnt) + c_smem_layout = cute.slice_(c_epi_smem_layout_staged, (None, None, 0)) + tma_copy_c_bytes = cute.size_in_bytes(self.c_dtype, c_smem_layout) + epi_pipeline = pipeline.PipelineTmaAsync.create( + 
barrier_storage=shared_storage.epi_pipeline_array_ptr.data_ptr(), + num_stages=self.c_epi_stage, + producer_group=epi_pipeline_producer_group, + consumer_group=epi_pipeline_consumer_group, + tx_count=tma_copy_c_bytes, + ) + else: + epi_pipeline = None + + sA = shared_storage.sA.get_tensor(a_smem_layout_staged.outer, swizzle=a_smem_layout_staged.inner) + sB = shared_storage.sB.get_tensor(b_smem_layout_staged.outer, swizzle=b_smem_layout_staged.inner) + + if const_expr(not self.is_persistent): + sD_ptr = cute.recast_ptr(sA.iterator, d_epi_smem_layout_staged.inner, dtype=self.d_dtype) + sD = cute.make_tensor(sD_ptr, d_epi_smem_layout_staged.outer) + + if const_expr(self.inference_mode and self.need_adhoc_epilogue_store): + next_ptr = sD_ptr + else: + next_ptr = sD_ptr + cute.cosize(d_epi_smem_layout_staged) + + if const_expr(self.need_adhoc_epilogue_store): + sY_ptr = cute.recast_ptr(next_ptr, y_epi_smem_layout_staged.inner, dtype=self.y_dtype) + sY = cute.make_tensor(sY_ptr, y_epi_smem_layout_staged.outer) + next_ptr = sY_ptr + cute.cosize(y_epi_smem_layout_staged) + else: + sY = None + + else: + if const_expr(self.need_adhoc_epilogue_store): + sY = shared_storage.sY.get_tensor( + y_epi_smem_layout_staged.outer, swizzle=y_epi_smem_layout_staged.inner + ) + if const_expr(self.inference_mode): + sD = cute.make_tensor( + cute.recast_ptr(sY.iterator, d_epi_smem_layout_staged.inner, dtype=self.d_dtype), + d_epi_smem_layout_staged.outer, + ) + else: + sD = shared_storage.sD.get_tensor( + d_epi_smem_layout_staged.outer, swizzle=d_epi_smem_layout_staged.inner + ) + else: + sY = None + sD = shared_storage.sD.get_tensor( + d_epi_smem_layout_staged.outer, swizzle=d_epi_smem_layout_staged.inner + ) + + if const_expr(self.compute_weight_gradient and self.is_A_gather): + sAIdx_prefetch = shared_storage.sAIdx_prefetch.get_tensor(prefetch_AIdx_smem_layout_staged) + else: + sAIdx_prefetch = None + + if const_expr(self.compute_dz_and_partial_ds_and_y1s): + sS = shared_storage.sS.get_tensor(s_epi_smem_layout_staged, dtype=self.s_dtype) + else: + sS = None + + if const_expr(self.need_epilogue_load): + sC = shared_storage.sC.get_tensor(c_epi_smem_layout_staged.outer, swizzle=c_epi_smem_layout_staged.inner) + else: + sC = None + + if const_expr(self.use_bias): + sBias = shared_storage.sBias.get_tensor(bias_epi_smem_layout_staged) + else: + sBias = None + + sched_pipeline = None + tile_count = None + if const_expr(tile_sched_params.tile_count_semaphore is not None): + sched_pipeline = self.make_sched_pipeline( + cta_layout_mnk, + sched_pipeline_mbar_ptr=shared_storage.sched_pipeline_array_ptr.data_ptr(), + ) + tile_count = shared_storage.tile_count.get_tensor((self.sched_stage,)) + + a_tensormap_smem_ptr = b_tensormap_smem_ptr = c_tensormap_smem_ptr = d_tensormap_smem_ptr = ( + y_tensormap_smem_ptr + ) = None + if cutlass.const_expr(self.tensormap_update_mode == utils.TensorMapUpdateMode.SMEM): + tensormap_smem_ptr = shared_storage.tensormap_buffer.data_ptr() + tensormap_smem_ptr = self.align_tensormap_smem_ptr(tensormap_smem_ptr) + + if const_expr(self.compute_weight_gradient): + if const_expr(not self.is_A_gather): + tensormap_smem_ptr = a_tensormap_smem_ptr = self.allocate_new_tensormap_smem_ptr( + tensormap_smem_ptr + ) + tensormap_smem_ptr = b_tensormap_smem_ptr = self.allocate_new_tensormap_smem_ptr(tensormap_smem_ptr) + else: + if const_expr(self.pingpong): + if const_expr(not (self.inference_mode and self.need_adhoc_epilogue_store)): + tensormap_smem_ptr = d0_tensormap_smem_ptr = 
self.allocate_new_tensormap_smem_ptr( + tensormap_smem_ptr + ) + tensormap_smem_ptr = d1_tensormap_smem_ptr = self.allocate_new_tensormap_smem_ptr( + tensormap_smem_ptr + ) + d_tensormap_smem_ptr = d0_tensormap_smem_ptr if warp_idx // 4 == 0 else d1_tensormap_smem_ptr + + if const_expr(self.need_adhoc_epilogue_store): + tensormap_smem_ptr = y0_tensormap_smem_ptr = self.allocate_new_tensormap_smem_ptr( + tensormap_smem_ptr + ) + tensormap_smem_ptr = y1_tensormap_smem_ptr = self.allocate_new_tensormap_smem_ptr( + tensormap_smem_ptr + ) + y_tensormap_smem_ptr = y0_tensormap_smem_ptr if warp_idx // 4 == 0 else y1_tensormap_smem_ptr + + if const_expr(self.need_epilogue_load): + tensormap_smem_ptr = c0_tensormap_smem_ptr = self.allocate_new_tensormap_smem_ptr( + tensormap_smem_ptr + ) + tensormap_smem_ptr = c1_tensormap_smem_ptr = self.allocate_new_tensormap_smem_ptr( + tensormap_smem_ptr + ) + c_tensormap_smem_ptr = c0_tensormap_smem_ptr if warp_idx // 4 == 0 else c1_tensormap_smem_ptr + + else: + if const_expr(not (self.inference_mode and self.need_adhoc_epilogue_store)): + tensormap_smem_ptr = d_tensormap_smem_ptr = self.allocate_new_tensormap_smem_ptr( + tensormap_smem_ptr + ) + if const_expr(self.need_adhoc_epilogue_store): + tensormap_smem_ptr = y_tensormap_smem_ptr = self.allocate_new_tensormap_smem_ptr( + tensormap_smem_ptr + ) + + if const_expr(self.need_epilogue_load): + tensormap_smem_ptr = c_tensormap_smem_ptr = self.allocate_new_tensormap_smem_ptr( + tensormap_smem_ptr + ) + + grid_dim = cute.arch.grid_dim() + bid = cute.arch.block_idx() + tensormap_workspace_idx = bid[2] * grid_dim[1] * grid_dim[0] + bid[1] * grid_dim[0] + bid[0] + tensormap_manager = TensorMapManagerSm90(self.tensormap_update_mode, self.bytes_per_tensormap) + + if const_expr(self.compute_weight_gradient): + if const_expr(not self.is_A_gather and (mA_tensormap is not None)): + a_tensormap_ptr = tensormap_manager.get_tensormap_ptr( + mA_tensormap[tensormap_workspace_idx, None].iterator + ) + else: + a_tensormap_ptr = None + + if const_expr(mB_tensormap is not None): + b_tensormap_ptr = tensormap_manager.get_tensormap_ptr( + mB_tensormap[tensormap_workspace_idx, None].iterator + ) + else: + b_tensormap_ptr = None + + if cutlass.const_expr(self.tensormap_update_mode == utils.TensorMapUpdateMode.SMEM): + if const_expr(not self.is_A_gather): + tensormap_a_init_ptr = a_tensormap_smem_ptr + tensormap_b_init_ptr = b_tensormap_smem_ptr + else: + if const_expr(not self.is_A_gather): + tensormap_a_init_ptr = b_tensormap_ptr + tensormap_b_init_ptr = b_tensormap_ptr + + else: + if const_expr(self.pingpong): + tensormap_workspace_idx = tensormap_workspace_idx * 2 + warp_idx // 4 + + if const_expr( + (mD_tensormap is not None) and (not (self.inference_mode and self.need_adhoc_epilogue_store)) + ): + d_tensormap_ptr = tensormap_manager.get_tensormap_ptr( + mD_tensormap[tensormap_workspace_idx, None].iterator + ) + else: + d_tensormap_ptr = None + + if const_expr(self.need_adhoc_epilogue_store): + assert mY_tensormap is not None + y_tensormap_ptr = tensormap_manager.get_tensormap_ptr( + mY_tensormap[tensormap_workspace_idx, None].iterator + ) + else: + y_tensormap_ptr = None + + if const_expr(self.need_epilogue_load): + assert mC_tensormap is not None + c_tensormap_ptr = tensormap_manager.get_tensormap_ptr( + mC_tensormap[tensormap_workspace_idx, None].iterator + ) + else: + c_tensormap_ptr = None + + if cutlass.const_expr(self.tensormap_update_mode == utils.TensorMapUpdateMode.SMEM): + tensormap_d_init_ptr = d_tensormap_smem_ptr + 
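# Editor's note (added for readability, not part of the upstream patch): in
# TensorMapUpdateMode.SMEM the TMA descriptors are staged in the shared-memory
# buffers allocated above and written back to the per-CTA global-memory
# workspace by the tensormap manager, whereas in GMEM mode (else branch just
# below) the global-memory descriptors are initialized and updated directly.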
tensormap_y_init_ptr = y_tensormap_smem_ptr + tensormap_c_init_ptr = c_tensormap_smem_ptr + else: + tensormap_d_init_ptr = d_tensormap_ptr + tensormap_y_init_ptr = y_tensormap_ptr + tensormap_c_init_ptr = c_tensormap_ptr + + TileSchedulerCls = partial(TileScheduler.create, tile_sched_params, tile_count, sched_pipeline) + + k_tile_cnt = cute.ceil_div(cute.size(mA_mkl.shape[1]), self.tile_shape_mnk[2]) + c_tile_cnt = ( + cute.size(cute.ceil_div(self.tile_shape_mnk[:2], self.c_epi_tile)) + if const_expr(self.need_epilogue_load) + else Int32(0) + ) + + if warp_idx >= self.tma_warp_id: + cute.arch.setmaxregister_decrease(self.num_regs_load) + cute.arch.setmaxregister_decrease(self.num_regs_load) + + prolog_loading_warp_ids = ( + [const_expr(self.tma_warp_id + i) for i in range(self.num_load_A_threads // cute.arch.WARP_SIZE)] + if const_expr(self.is_A_gather) + else [const_expr(self.tma_warp_id)] + ) + + if warp_idx in prolog_loading_warp_ids: + is_tma_warp = cutlass.Boolean(warp_idx == self.tma_warp_id) + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + cluster_coord_mnk = cta_layout_mnk.get_flat_coord(cta_rank_in_cluster) + + a_mcast_mask = cute.make_layout_image_mask(cta_layout_mnk, cluster_coord_mnk, mode=1) + a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0 + b_mcast_mask = cute.make_layout_image_mask(cta_layout_mnk, cluster_coord_mnk, mode=0) + b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0 + + mainloop_producer_state = make_pipeline_state(pipeline.PipelineUserType.Producer, self.ab_stage) + is_scheduler_warp = warp_idx == self.tma_warp_id + if const_expr(cute.size(cta_layout_mnk) > 1): + is_scheduler_warp = is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0 + + tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp) + work_tile = tile_scheduler.initial_work_tile_info() + + last_batch_idx = cutlass.Int32(-1) + token_group_size = cutlass.Int32(0) + + mcA_mkl = cute.make_identity_tensor((mA_mkl.shape[0], mA_mkl.shape[1])) + + TIdx_cur_group = TIdx_next_group = cutlass.Int32(0) + mAIdx_mk = cute.domain_offset((0,), mAIdx_mkl) + + gA_mk = None + A_g2s_thr_copy = None + if const_expr(self.is_A_gather): + A_g2s_thr_copy = A_tiled_copy.get_slice(tidx - self.tma_warp_id * cute.arch.WARP_SIZE) + gA_mk = cute.local_tile(mA_mkl, (self.tile_M, self.tile_K), (0, None)) + tAgA = A_g2s_thr_copy.partition_S(gA_mk) + + if const_expr(self.compute_weight_gradient): + if const_expr(not self.is_A_gather): + tensormap_manager.init_tensormap_from_atom( + tma_atom_a, + tensormap_a_init_ptr, + is_manager_warp=is_tma_warp, + ) + tensormap_manager.init_tensormap_from_atom( + tma_atom_b, + tensormap_b_init_ptr, + is_manager_warp=is_tma_warp, + ) + tensormap_manager.fence_tensormap_initialization() + + while work_tile.is_valid_tile: + tile_coord_mnkl = work_tile.tile_idx + batch_idx = tile_coord_mnkl[3] + # (bM, bK, RestK) + if batch_idx != last_batch_idx: + TIdx_cur_group, TIdx_next_group = cute.arch.make_warp_uniform( + mTokenoffset[batch_idx] + ), cute.arch.make_warp_uniform(mTokenoffset[batch_idx + 1]) + token_group_size = TIdx_next_group - TIdx_cur_group + + if const_expr(self.is_A_gather): + if const_expr(self.compute_weight_gradient): + mcA_mkl = cute.make_identity_tensor((mA_mkl.shape[0], token_group_size)) + else: + mcA_mkl = cute.make_identity_tensor((token_group_size, mA_mkl.shape[1])) + + mAIdx_mk = cute.domain_offset((TIdx_cur_group,), mAIdx_mkl) + + if const_expr(self.compute_weight_gradient): + if const_expr(not self.is_A_gather): 
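+ # mTokenoffset holds the running prefix sum of per-expert token counts, so expert
+ # `batch_idx` owns rows [mTokenoffset[batch_idx], mTokenoffset[batch_idx + 1]) and
+ # token_group_size is the number of rows its grouped-GEMM problem spans. Whenever
+ # the scheduler hands this CTA a tile from a different expert, the TMA descriptors
+ # below are rewritten to point at that expert's slice. Illustrative plain-Python
+ # sketch of the bookkeeping (the token_counts values here are hypothetical):
+ #     token_counts = [5, 0, 3]                         # tokens routed to each expert
+ #     token_offsets = [0]
+ #     for c in token_counts:
+ #         token_offsets.append(token_offsets[-1] + c)  # -> [0, 5, 5, 8]
+ #     start, end = token_offsets[2], token_offsets[3]  # expert 2 owns rows [5, 8)
+ #     group_size = end - start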
+ assert a_tensormap_ptr is not None + self.update_tma_desc_ptr( + mA_mkl, + tma_atom_a, + tensormap_manager, + a_tensormap_ptr, + TIdx_cur_group, + token_group_size, + is_tma_warp, + tensormap_smem_ptr=a_tensormap_smem_ptr, + ) + if const_expr(b_tensormap_ptr is not None): + self.update_tma_desc_ptr( + mB_nkl, + tma_atom_b, + tensormap_manager, + b_tensormap_ptr, + TIdx_cur_group, + token_group_size, + is_tma_warp, + tensormap_smem_ptr=b_tensormap_smem_ptr, + # cute.AddressSpace.generic + ) + k_tile_cnt = cute.ceil_div(token_group_size, self.tile_shape_mnk[2]) + + last_batch_idx = batch_idx + + if const_expr(self.is_A_gather): + cA = cute.local_tile(mcA_mkl, (self.tile_M, self.tile_K), (tile_coord_mnkl[0], None)) + + tAsA = A_g2s_thr_copy.partition_D(sA) + tAcA = A_g2s_thr_copy.partition_D(cA) + + tApA = cute.make_rmem_tensor( + cute.make_layout( + ( + tAgA.shape[0][1], + cute.size(tAgA, mode=[1]), + cute.size(tAgA, mode=[2]), + ), + stride=(cute.size(tAgA, mode=[1]), 1, 0), + ), + cutlass.Boolean, + ) + + for rest_v in cutlass.range_constexpr(tApA.shape[0]): + for m in cutlass.range_constexpr(tApA.shape[1]): + if const_expr(self.compute_weight_gradient): + tApA[rest_v, m, 0] = cute.elem_less(tAcA[(0, rest_v), m, 0, 0][0], mA_mkl.shape[0]) + else: + tApA[rest_v, m, 0] = cute.elem_less( + tAcA[(0, rest_v), m, 0, 0][0], token_group_size + ) + else: + if const_expr(self.compute_weight_gradient): + # update TMA map instead + mA_mk = cute.domain_offset((0, 0), mA_mkl_tma) + else: + mA_mk = cute.domain_offset((TIdx_cur_group, 0), mA_mkl_tma) + + gA_mk_cur = cute.local_tile(mA_mk, (self.tile_M, self.tile_K), (tile_coord_mnkl[0], None)) + + a_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (0, None, 0)).shape) + a_cta_crd = cluster_coord_mnk[1] + + tAsA, tAgA_mkl = cpasync.tma_partition( + tma_atom_a, + a_cta_crd, + a_cta_layout, + cute.group_modes(sA, 0, 2), + cute.group_modes(gA_mk_cur, 0, 2), + ) + + if const_expr(self.compute_weight_gradient): + gB_nk = cute.local_tile(mB_nkl_tma, (self.tile_N, self.tile_K), (tile_coord_mnkl[1], None)) + else: + gB_nk = cute.local_tile(mB_nkl_tma, self.tile_shape_mnk, tile_coord_mnkl, proj=(None, 1, 1)) + + b_cta_layout = cute.make_layout(cute.slice_(cta_layout_mnk, (None, 0, 0)).shape) + b_cta_crd = cluster_coord_mnk[0] + tBsB, tBgB_nkl = cpasync.tma_partition( + tma_atom_b, + b_cta_crd, + b_cta_layout, + cute.group_modes(sB, 0, 2), + cute.group_modes(gB_nk, 0, 2), + ) + + peek_ab_empty_status = cutlass.Boolean(True) + if 0 < k_tile_cnt: + peek_ab_empty_status = mainloop_pipeline.producer_try_acquire(mainloop_producer_state) + + if const_expr(self.is_A_gather): + M_offset = cute.arch.make_warp_uniform(tile_coord_mnkl[0] * const_expr(self.tile_M)) + if const_expr(self.compute_weight_gradient): + tmAIdx = None + M_boundary = mA_mkl.shape[0] + else: + M_boundary = cute.arch.make_warp_uniform( + self.min_i32(const_expr(self.tile_M), token_group_size - M_offset) + ) + tmAIdx = self.prefetch_gather_idx_for_A_when_vary_M( + mAIdx_mk, M_offset, M_boundary, A_thr_copy_elems + ) + + if const_expr(self.compute_weight_gradient): + if const_expr(self.is_A_gather): + a_tma_desc_ptr = None + else: + a_tma_desc_ptr = tensormap_manager.get_tensormap_ptr( + a_tensormap_ptr, cute.AddressSpace.generic + ) + + b_tma_desc_ptr = tensormap_manager.get_tensormap_ptr( + b_tensormap_ptr, cute.AddressSpace.generic + ) + else: + a_tma_desc_ptr = None + b_tma_desc_ptr = None + + for k_tile in cutlass.range(k_tile_cnt, unroll=1): + if const_expr(self.is_A_gather): + 
mainloop_pipeline.producer_acquire( + mainloop_producer_state, peek_ab_empty_status, is_tma_warp=is_tma_warp + ) + else: + mainloop_pipeline.producer_acquire(mainloop_producer_state, peek_ab_empty_status) + + if is_tma_warp: + cute.copy( + tma_atom_b, + tBgB_nkl[None, k_tile], + tBsB[None, mainloop_producer_state.index], + tma_bar_ptr=mainloop_pipeline.producer_get_barrier(mainloop_producer_state), + mcast_mask=b_mcast_mask, + tma_desc_ptr=b_tma_desc_ptr, + ) + + K_offset = k_tile * const_expr(self.tile_K) + + if const_expr(self.compute_weight_gradient and self.is_A_gather): + if K_offset % const_expr(self.prefetch_token_idx_size) == 0: + self.prefetch_gather_idx_for_A_when_vary_K( + mAIdx_mk, sAIdx_prefetch, token_group_size, K_offset + ) + + if const_expr(self.is_A_gather): + self.load_A_gather( + mA_mkl, + tmAIdx, + sAIdx_prefetch, + M_offset, + tAsA[None, None, None, mainloop_producer_state.index], + tApA, + A_g2s_thr_copy, + K_offset, + token_group_size, + A_thr_copy_elems, + ) + else: + cute.copy( + tma_atom_a, + tAgA_mkl[None, k_tile], + tAsA[None, mainloop_producer_state.index], + tma_bar_ptr=mainloop_pipeline.producer_get_barrier(mainloop_producer_state), + mcast_mask=a_mcast_mask, + tma_desc_ptr=a_tma_desc_ptr, + ) + + if const_expr(not self.is_A_gather): + # Mainloop pipeline's producer commit is a NOP + mainloop_pipeline.producer_commit(mainloop_producer_state) + else: + mainloop_pipeline.producer_cpasync_commit(mainloop_producer_state) + mainloop_producer_state.advance() + + peek_ab_empty_status = cutlass.Boolean(True) + if k_tile + 1 < k_tile_cnt: + peek_ab_empty_status = mainloop_pipeline.producer_try_acquire(mainloop_producer_state) + + tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp) + tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp) + work_tile = tile_scheduler.get_current_work() + + if const_expr(self.pingpong): + # Need to write the tile_idx to smem for the next WG in the pingpong mode + tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp) + # End of persistent scheduler loop + mainloop_pipeline.producer_tail(mainloop_producer_state) + if is_scheduler_warp: + tile_scheduler.producer_tail() + + if warp_idx < self.tma_warp_id: + cute.arch.setmaxregister_increase(self.num_regs_mma) + cute.arch.setmaxregister_increase(self.num_regs_mma) + is_tma_warp = cutlass.Boolean( + (not self.pingpong and warp_idx == 0) or (self.pingpong and (warp_idx == 0 or warp_idx == 4)) + ) + if const_expr(not self.compute_weight_gradient): + if const_expr(not (self.inference_mode and self.need_adhoc_epilogue_store)): + tensormap_manager.init_tensormap_from_atom( + tma_atom_d, + tensormap_d_init_ptr, + is_manager_warp=is_tma_warp, + ) + if const_expr(self.need_adhoc_epilogue_store): + tensormap_manager.init_tensormap_from_atom( + tma_atom_y, + tensormap_y_init_ptr, + is_manager_warp=is_tma_warp, + ) + if const_expr(self.need_epilogue_load): + tensormap_manager.init_tensormap_from_atom( + tma_atom_c, + tensormap_c_init_ptr, + is_manager_warp=is_tma_warp, + ) + + tidx, _, _ = cute.arch.thread_idx() + warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) + if const_expr(self.pingpong): + tidx = tidx % self.num_threads_per_warp_group + warp_group_thread_layout = cute.make_layout( + self.mma_warp_groups if not self.pingpong else 1, + stride=self.num_threads_per_warp_group, + ) + thr_mma = tiled_mma.get_slice(warp_group_thread_layout(warp_group_idx if not self.pingpong else 0)) + + tCrA = 
tiled_mma.make_fragment_A(thr_mma.partition_A(sA)) + tCrB = tiled_mma.make_fragment_B(thr_mma.partition_B(sB)) + + acc_shape = tiled_mma.partition_shape_C(cute.select(self.tile_shape_mnk, mode=[0, 1])) + acc = cute.make_rmem_tensor(acc_shape, self.acc_dtype) + acc = cute.make_rmem_tensor(acc_shape, self.acc_dtype) + + if const_expr(self.pingpong): + if warp_group_idx == 0: + # WG0 needs a start signal at the very beginning + self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma") + self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi") + + mainloop_consumer_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage) + epi_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.c_epi_stage) + epi_producer_state = make_pipeline_state(pipeline.PipelineUserType.Producer, self.c_epi_stage) + + if const_expr(not self.compute_weight_gradient): + tensormap_manager.fence_tensormap_initialization() + + tile_scheduler = TileSchedulerCls() + if const_expr(self.pingpong): + if warp_idx >= 4: + # Advance 2nd Math WG pipeline states to the end of 1st Math WG + if const_expr(self.compute_weight_gradient): + wg0_batch_idx = tile_scheduler.initial_work_tile_info().tile_idx[-1] + wg0_token_group_size = cute.arch.make_warp_uniform( + mTokenoffset[wg0_batch_idx + 1] + ) - cute.arch.make_warp_uniform(mTokenoffset[wg0_batch_idx]) + tile_scheduler.advance_to_next_work() + mainloop_consumer_read_state.advance_iters( + cute.ceil_div(wg0_token_group_size, self.tile_shape_mnk[2]) + ) + else: + tile_scheduler.advance_to_next_work() + mainloop_consumer_read_state.advance_iters(k_tile_cnt) + + # mainloop_consumer_read_state.advance_iters(k_tile_cnt) + if const_expr(self.need_epilogue_load): + epi_read_state.advance_iters(c_tile_cnt) + epi_producer_state.advance_iters(c_tile_cnt) + + work_tile = tile_scheduler.initial_work_tile_info() + last_batch_idx = cutlass.Int32(-1) + token_group_size = cutlass.Int32(0) + + TIdx_cur_group = TIdx_next_group = cutlass.Int32(0) + while work_tile.is_valid_tile: + tile_coord_mnkl = work_tile.tile_idx + batch_idx = tile_coord_mnkl[3] + is_group_changed = batch_idx != last_batch_idx + if is_group_changed: + # construct tensor D based on real address, shape and stride information + TIdx_cur_group, TIdx_next_group = cute.arch.make_warp_uniform( + mTokenoffset[batch_idx] + ), cute.arch.make_warp_uniform(mTokenoffset[batch_idx + 1]) + token_group_size = cute.arch.make_warp_uniform(TIdx_next_group - TIdx_cur_group) + if const_expr(self.compute_weight_gradient): + k_tile_cnt = cute.arch.make_warp_uniform( + cute.ceil_div(token_group_size, self.tile_shape_mnk[2]) + ) + else: + if const_expr((not self.inference_mode) or (not self.need_adhoc_epilogue_store)): + assert d_tensormap_smem_ptr is not None and d_tensormap_ptr is not None + self.update_tma_desc_ptr( + mD_mnl, + tma_atom_d, + tensormap_manager, + d_tensormap_ptr, + TIdx_cur_group, + token_group_size, + is_tma_warp, + tensormap_smem_ptr=d_tensormap_smem_ptr, + # cute.AddressSpace.generic + ) + if const_expr(self.need_adhoc_epilogue_store): + assert y_tensormap_smem_ptr is not None and y_tensormap_ptr is not None + self.update_tma_desc_ptr( + mY_mnl, + tma_atom_y, + tensormap_manager, + y_tensormap_ptr, + TIdx_cur_group, + token_group_size, + is_tma_warp, + tensormap_smem_ptr=y_tensormap_smem_ptr, + # cute.AddressSpace.generic + ) + if const_expr(self.need_epilogue_load): + assert c_tensormap_smem_ptr is not None and c_tensormap_ptr is not None + self.update_tma_desc_ptr( + mC_mnl, + 
tma_atom_c, + tensormap_manager, + c_tensormap_ptr, + TIdx_cur_group, + token_group_size, + is_tma_warp, + tensormap_smem_ptr=c_tensormap_smem_ptr, + # cute.AddressSpace.generic + ) + last_batch_idx = batch_idx + + k_pipe_mmas = 1 + mainloop_consumer_release_state = mainloop_consumer_read_state.clone() + num_prologue_mma = min(k_pipe_mmas, k_tile_cnt) + if const_expr(self.pingpong): + self.pingpong_barrier_sync(warp_group_idx, stage="mma") + + peek_ab_full_status = cutlass.Boolean(True) + + if const_expr(self.compute_weight_gradient): + if k_tile_cnt == 0: + acc.fill(0.0) + + if k_tile_cnt > 0: + peek_ab_full_status = mainloop_pipeline.consumer_try_wait(mainloop_consumer_read_state) + + tiled_mma.set(warpgroup.Field.ACCUMULATE, False) + num_k_blocks = cute.size(tCrA, mode=[2]) + + for k_tile in cutlass.range(num_prologue_mma): + # Wait for A/B buffer to be ready + mainloop_pipeline.consumer_wait(mainloop_consumer_read_state, peek_ab_full_status) + warpgroup.fence() + for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True): + k_blk_coord = (None, None, k_blk_idx, mainloop_consumer_read_state.index) + cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc) + tiled_mma.set(warpgroup.Field.ACCUMULATE, True) + warpgroup.commit_group() + mainloop_consumer_read_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if k_tile + 1 < k_tile_cnt: + peek_ab_full_status = mainloop_pipeline.consumer_try_wait(mainloop_consumer_read_state) + + for k_tile in cutlass.range(num_prologue_mma, k_tile_cnt, unroll=1): + # Wait for TMA copies to complete + mainloop_pipeline.consumer_wait(mainloop_consumer_read_state, peek_ab_full_status) + # WGMMA + warpgroup.fence() + for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True): + k_blk_coord = (None, None, k_blk_idx, mainloop_consumer_read_state.index) + cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc) + warpgroup.commit_group() + # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete + warpgroup.wait_group(k_pipe_mmas) + mainloop_pipeline.consumer_release(mainloop_consumer_release_state) + mainloop_consumer_read_state.advance() + mainloop_consumer_release_state.advance() + peek_ab_full_status = cutlass.Boolean(1) + if k_tile + 1 < k_tile_cnt: + peek_ab_full_status = mainloop_pipeline.consumer_try_wait(mainloop_consumer_read_state) + if const_expr(self.pingpong): + # Cue for next WG's MMA to start + self.pingpong_barrier_arrive(1 - warp_group_idx, stage="mma") + warpgroup.wait_group(0) + for k_tile in cutlass.range(num_prologue_mma, unroll=1): + mainloop_pipeline.consumer_release(mainloop_consumer_release_state) + mainloop_consumer_release_state.advance() + + if const_expr(self.pingpong): + if const_expr(self.compute_weight_gradient): + other_batch_idx = tile_scheduler.prefetch_next_work().tile_idx[-1] + other_token_group_size = cute.arch.make_warp_uniform( + mTokenoffset[other_batch_idx + 1] + ) - cute.arch.make_warp_uniform(mTokenoffset[other_batch_idx]) + mainloop_consumer_read_state.advance_iters( + cute.ceil_div(other_token_group_size, self.tile_shape_mnk[2]) + ) + else: + mainloop_consumer_read_state.advance_iters(k_tile_cnt) + + # Update starting mainloop pipeline state for the next tile + + if const_expr(self.pingpong): + self.pingpong_barrier_sync(warp_group_idx, "epi") + + epilogue_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierGemm.Epilogue), num_threads=self.num_epi_threads + ) + + # Wait for all warp groups in the thread block to finish, because smem for tensor + # A in 
the mainloop is reused in the epilogue if not persistent. + if const_expr(not self.is_persistent): + epilogue_barrier.arrive_and_wait() + + copy_atom_D_r2s = sm90_utils.sm90_get_smem_store_op( + self.d_layout, + elem_ty_d=self.d_dtype, + elem_ty_acc=self.acc_dtype, + ) + copy_atom_D = cute.make_copy_atom( + warp.StMatrix8x8x16bOp(self.d_layout.is_m_major_c(), 4), + self.d_dtype, + ) + tiled_copy_D_atom = cute.make_tiled_copy_C_atom(copy_atom_D, tiled_mma) + tiled_copy_D_r2s = cute.make_tiled_copy_S(copy_atom_D_r2s, tiled_copy_D_atom) + # (R2S, R2S_M, R2S_N, PIPE_D) + tRS_sD = tiled_copy_D_r2s.get_slice(tidx).partition_D(sD) + tRS_rD_layout = cute.make_layout(tiled_copy_D_r2s.get_slice(tidx).partition_S(sD).shape[:3]) + tRS_rD = cute.make_rmem_tensor(tRS_rD_layout, self.acc_dtype) + + if const_expr(self.need_epilogue_load): + copy_atom_C = cute.make_copy_atom( + warp.StMatrix8x8x16bOp( + self.c_layout.is_m_major_c(), + num_matrices=(4 if self.c_epi_tile[1] % 16 == 0 else 2), + ), + cutlass.Float16, # this is just to get the right source layout + ) + tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma) + copy_atom_C_s2r = sm90_get_smem_load_op(self.c_layout, self.c_dtype) + tiled_copy_C_s2r = cute.make_tiled_copy_S(copy_atom_C_s2r, tiled_copy_C_atom) + thr_copy_C_s2r = tiled_copy_C_s2r.get_slice(tidx) + tSR_sC = thr_copy_C_s2r.partition_S(sC) + tRS_rC = cute.make_rmem_tensor(tRS_rD_layout, self.c_dtype) + tRS_rC = cute.make_rmem_tensor(tRS_rD_layout, self.c_dtype) + tSR_rC = thr_copy_C_s2r.retile(tRS_rC) + else: + thr_copy_C_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None + + if const_expr(self.need_adhoc_epilogue_store): + copy_atom_Y_r2s = sm90_utils.sm90_get_smem_store_op( + self.y_layout, + elem_ty_d=self.y_dtype, + elem_ty_acc=self.acc_dtype, + ) + copy_atom_Y = cute.make_copy_atom( + warp.StMatrix8x8x16bOp(self.y_layout.is_m_major_c(), 4), + self.y_dtype, + ) + tiled_copy_Y_atom = cute.make_tiled_copy_C_atom(copy_atom_Y, tiled_mma) + tiled_copy_Y_r2s = cute.make_tiled_copy_S(copy_atom_Y_r2s, tiled_copy_Y_atom) + tRS_sY = tiled_copy_Y_r2s.get_slice(tidx).partition_D(sY) + + # (R2S, R2S_M, R2S_N) + tRS_rAcc = tiled_copy_D_r2s.retile(acc) + # tRS_rAcc: tensor> o ((8,8),3,1):((1,8),64,0)> + + # (bM, bN) + batch_idx = tile_coord_mnkl[3] + if const_expr(self.compute_weight_gradient): + gD_mn = cute.local_tile( + mD_mnl_tma[None, None, batch_idx], (self.tile_M, self.tile_N), tile_coord_mnkl[:2] + ) + else: + gD_mn = cute.local_tile(mD_mnl_tma, (self.tile_M, self.tile_N), tile_coord_mnkl[:2]) + + copy_elems_D = self.universal_copy_bits // mD_mnl.element_type.width + tdgd_for_tma_partition = cute.zipped_divide(gD_mn, self.d_epi_tile) + + if const_expr(self.need_adhoc_epilogue_store): + y_tile_size = (self.tile_M, self.tile_N) + if const_expr(self.is_glu and not self.compute_dz_and_partial_ds_and_y1s): + y_tile_size = (self.tile_M, self.tile_N // 2) + + gY_mn = cute.local_tile(mY_mnl_tma, y_tile_size, tile_coord_mnkl[:2]) + + tygy_for_tma_partition = cute.zipped_divide(gY_mn, self.y_epi_tile) + # bSG_sD: tensor, S<2,4,3>> o ((2048,1),(1,4)):((1,0),(0,2048))> + # bSG_gD: tensor<(?{div=128},?{div=192},?) 
o (((32,64),1),(3,4)):(((1@0,1@1),0),(64@1,32@0))> + if const_expr(self.inference_mode and self.need_adhoc_epilogue_store): + bSG_sD = bSG_gD = None + else: + bSG_sD, bSG_gD = cpasync.tma_partition( + tma_atom_d, + 0, + cute.make_layout(1), + cute.group_modes(sD, 0, 2), + tdgd_for_tma_partition, + ) + + if const_expr(self.need_adhoc_epilogue_store): + bSG_sY, bSG_gY = cpasync.tma_partition( + tma_atom_y, + 0, + cute.make_layout(1), + cute.group_modes(sY, 0, 2), + tygy_for_tma_partition, + ) + assert const_expr(cute.size(tdgd_for_tma_partition, mode=[1])) == const_expr( + cute.size(tygy_for_tma_partition, mode=[1]) + ) + + if const_expr(self.use_bias): + expert_idx = tile_coord_mnkl[-1] + expert_elem_load = const_expr(self.universal_copy_bits // mBias_nl.element_type.width) + gBias = cute.local_tile(mBias_nl, (1, self.tile_shape_mnk[1]), (expert_idx, tile_coord_mnkl[1])) + cBias = cute.local_tile( + cute.make_identity_tensor((1, mBias_nl.shape[1])), + (1, self.tile_shape_mnk[1]), + (0, tile_coord_mnkl[1]), + ) + + thr_copy_bias = cpasync_atom_bias.get_slice(tidx) + tBiasgBias = thr_copy_bias.partition_S(gBias) + tBiassBias = thr_copy_bias.partition_D(sBias) + tBiascBias = thr_copy_bias.partition_S(cBias) + + thread_per_row = const_expr(self.tile_shape_mnk[1] // expert_elem_load) + if tidx < thread_per_row: + if tBiascBias[0][1] < mBias_nl.shape[1]: + cute.copy(thr_copy_bias, tBiasgBias, tBiassBias) + else: + tBiassBias.fill(0.0) + + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + # cannot be removed for correctness! + epilogue_barrier.arrive_and_wait() + + partition_for_epi_fn = partial( + partition_for_epilogue, + epi_tile=self.d_epi_tile, + tiled_copy=tiled_copy_D_r2s, + tidx=tidx, + reference_src=True, + ) + sBias_retiled = partition_for_epi_fn( + cute.make_tensor(sBias.iterator, cute.make_layout((self.tile_M, self.tile_N), stride=(0, 1))) + ) + + epi_tile_num = const_expr(cute.size(tdgd_for_tma_partition, mode=[1])) + + epi_tile_shape = tdgd_for_tma_partition.shape[1] + num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num + epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1)) + + if const_expr(mDIdx_mnl is not None): + mcD = cute.make_identity_tensor((self.tile_M, self.tile_N)) + + tcDgcD_flat_partition = cute.flat_divide(mcD, self.d_epi_tile) + D_r2g_thr_copy = D_tiled_copy.get_slice(tidx) + + TIdx_cur_group, TIdx_next_group = mTokenoffset[batch_idx], mTokenoffset[batch_idx + 1] + if const_expr(self.is_scatter_idx_prefetched): + tmDIdx = self.prefetch_scatter_idx_for_D_when_vary_M( + mD_mnl, + mDIdx_mnl, + D_r2g_thr_copy, + tcDgcD_flat_partition, + epi_tile_layout, + epi_tile_num, + const_expr(self.universal_copy_bits // mD_mnl.element_type.width), + tile_coord_mnkl, + TIdx_cur_group, + TIdx_next_group, + ) + else: + tmDIdx = None + else: + mcD = tcDgcD_flat_partition = D_r2g_thr_copy = tmDIdx = None + + if const_expr(not self.compute_weight_gradient): + if const_expr(self.inference_mode and self.need_adhoc_epilogue_store): + d_tma_desc_ptr = None + else: + d_tma_desc_ptr = tensormap_manager.get_tensormap_ptr( + d_tensormap_ptr, + cute.AddressSpace.generic, + ) + if const_expr(self.need_adhoc_epilogue_store): + y_tma_desc_ptr = tensormap_manager.get_tensormap_ptr( + y_tensormap_ptr, + cute.AddressSpace.generic, + ) + if const_expr(self.need_epilogue_load): + c_tma_desc_ptr = tensormap_manager.get_tensormap_ptr( + c_tensormap_ptr, + cute.AddressSpace.generic, + ) + else: + d_tma_desc_ptr = y_tma_desc_ptr = 
c_tma_desc_ptr = None + + if const_expr(self.compute_dz_and_partial_ds_and_y1s): + TIdx_cur_group, TIdx_next_group = cute.arch.make_warp_uniform( + mTokenoffset[batch_idx] + ), cute.arch.make_warp_uniform(mTokenoffset[batch_idx + 1]) + self.fetch_scattered_S( + tidx, + mS_ml, + mS_scatter_idx, + sS, + tile_coord_mnkl, + TIdx_cur_group, + TIdx_next_group, + ) + epilogue_barrier.arrive_and_wait() + + cD = cute.make_identity_tensor((self.tile_M, self.tile_N)) + tDcD = tiled_mma.get_slice(tidx).partition_C(cD) + tRS_rcD_retiled = tiled_copy_D_r2s.retile(tDcD) + tRS_rcD = cute.make_rmem_tensor_like(tRS_rD, dtype=mS_scatter_idx.element_type) + + if const_expr(self.need_epilogue_load): + # mC_mn = cute.domain_offset((mTokenoffset[batch_idx], 0), mC_mnl_tma) + gC = cute.local_tile(mC_mnl_tma, (self.tile_M, self.tile_N), tile_coord_mnkl[:2]) + tCgC_for_tma_partition = cute.zipped_divide(gC, self.c_epi_tile) + bGS_sC, bGS_gC = cpasync.tma_partition( + tma_atom_c, + 0, + cute.make_layout(1), + cute.group_modes(sC, 0, 2), + tCgC_for_tma_partition, + ) + + for epi_idx in cutlass.range(min(epi_tile_num, self.c_epi_stage), unroll=1): + if is_tma_warp: + epi_pipeline.producer_acquire(epi_producer_state) + # Get the global memory coordinate for the current epi tile + gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) + cute.copy( + tma_atom_c, + bGS_gC[None, gmem_coord], + bGS_sC[None, epi_producer_state.index], + tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state), + tma_desc_ptr=c_tma_desc_ptr, + ) + # Epi pipeline's producer commit is a NOP + epi_pipeline.producer_commit(epi_producer_state) + epi_producer_state.advance() + + for epi_idx in cutlass.range_constexpr(epi_tile_num): + # Copy from acc to D registers + # tRS_sD: (((2, 4), 1), 1, 2, (1, 4)) + # tRS_rD = cute.make_fragment_like(tRS_sD[None, None, None, 0], self.acc_dtype) # (((2, 4), 1), 1, 2) + + # tRS_rD: tensor> o (((2,4),1),1,2):(((1,2),0),0,8)> + for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)): # cute.size(tRS_rD): 16 + tRS_rD[epi_v] = tRS_rAcc[const_expr(epi_idx * cute.size(tRS_rD) + epi_v)] + if const_expr(self.compute_dz_and_partial_ds_and_y1s): + tRS_rcD[epi_v] = tRS_rcD_retiled[const_expr(epi_idx * cute.size(tRS_rD) + epi_v)][0] + + if const_expr(self.need_epilogue_load): + epi_pipeline.consumer_wait(epi_read_state) + cute.copy(thr_copy_C_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC) + # Fence to make sure shared memory read is visible to TMA load + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + cute.arch.sync_warp() + with cute.arch.elect_one(): + epi_pipeline.consumer_release(epi_read_state) + epi_read_state.advance() + if const_expr(epi_idx + self.c_epi_stage < epi_tile_num): + if is_tma_warp: + epi_pipeline.producer_acquire(epi_producer_state) + # Get the global memory coordinate for the current epi tile + gmem_coord = epi_tile_layout.get_hier_coord(epi_idx + self.c_epi_stage) + cute.copy( + tma_atom_c, + bGS_gC[None, gmem_coord], + bGS_sC[None, epi_producer_state.index], + tma_bar_ptr=epi_pipeline.producer_get_barrier(epi_producer_state), + tma_desc_ptr=c_tma_desc_ptr, + ) + # Epi pipeline's producer commit is a NOP + epi_pipeline.producer_commit(epi_producer_state) + epi_producer_state.advance() + + if const_expr(self.use_bias): + sBias_retiled_and_grouped = cute.group_modes(sBias_retiled, 3, cute.rank(sBias_retiled)) + sBias_retiled_and_grouped_epi = sBias_retiled_and_grouped[ + None, None, None, epi_tile_layout.get_hier_coord(epi_idx) + ] 
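+ # The per-expert bias row was staged into smem with cp.async before the epilogue
+ # loop and is viewed here with an M-stride of 0, so the single (1, tile_N) row is
+ # broadcast to every row of the CTA tile. The slice above selects this sub-tile's
+ # columns; it is copied to registers below and added to the accumulator in
+ # acc_dtype before the activation and output-type conversion.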
+ rBias_retiled_epi_r = cute.make_rmem_tensor( + sBias_retiled_and_grouped_epi.layout, dtype=mBias_nl.element_type + ) + cute.autovec_copy( + cute.filter_zeros(sBias_retiled_and_grouped_epi), cute.filter_zeros(rBias_retiled_epi_r) + ) + + for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)): + tRS_rD[epi_v] = tRS_rD[epi_v] + self.acc_dtype(rBias_retiled_epi_r[epi_v]) + + if const_expr(self.compute_dz_and_partial_ds_and_y1s): + tRS_rD_out = cute.make_rmem_tensor_like( + tRS_rD, (cutlass.Float32 if const_expr(self.is_glu) else self.d_dtype) + ) + tRS_rY = cute.make_rmem_tensor_like(tRS_sY[None, None, None, 0], self.y_dtype) + self.compute_backward_activation( + tRS_rAcc, sS, tRS_rcD, tRS_rC, tRS_rD, tRS_rD_out, tRS_rY, epi_idx + ) + + elif const_expr(not (self.inference_mode and self.need_adhoc_epilogue_store)): + tRS_rD_out = cute.make_rmem_tensor_like(tRS_rD, self.d_dtype) + tRS_rD_out.store(tRS_rD.load().to(self.d_dtype)) + + if const_expr((self.is_glu or self.is_normal_act) and not self.compute_dz_and_partial_ds_and_y1s): + tRS_rY = cute.make_rmem_tensor_like(tRS_sY[None, None, None, 0], self.y_dtype) + self.compute_activation(tRS_rD, tRS_rY) + + # Copy from D registers to shared memory + if const_expr(self.inference_mode and self.need_adhoc_epilogue_store): + epi_buffer = (num_prev_subtiles + epi_idx) % cute.size(tRS_sY, mode=[3]) + else: + epi_buffer = (num_prev_subtiles + epi_idx) % cute.size(tRS_sD, mode=[3]) + + if const_expr(not (self.inference_mode and self.need_adhoc_epilogue_store)): + cute.copy(tiled_copy_D_r2s, tRS_rD_out, tRS_sD[(None, None, None, epi_buffer)]) + if const_expr(self.need_adhoc_epilogue_store): + cute.copy(tiled_copy_Y_r2s, tRS_rY, tRS_sY[(None, None, None, epi_buffer)]) + + if const_expr(mDIdx_mnl is not None): + epilogue_barrier.arrive_and_wait() + tDsD = D_r2g_thr_copy.partition_S(sD[None, None, epi_buffer]) + tDrD = cute.make_rmem_tensor_like(tDsD) + tDrD = cute.make_rmem_tensor_like(tDsD) + cute.autovec_copy(tDsD, tDrD) + + tDcD_slice = D_r2g_thr_copy.partition_D( + tcDgcD_flat_partition[None, None, *epi_tile_layout.get_hier_coord(epi_idx)] + ) + self.store_D_scatter( + mD_mnl, + mDIdx_mnl, + tmDIdx, + tDrD, + tDcD_slice, + D_r2g_thr_copy, + epi_idx, + copy_elems_D, + tile_coord_mnkl, + TIdx_cur_group, + TIdx_next_group, + ) + epilogue_barrier.arrive_and_wait() + + else: + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy(cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta) + epilogue_barrier.arrive_and_wait() + # Get the global memory coordinate for the current epi tile. 
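+ # epi_tile_layout maps the linear sub-tile index epi_idx to its (m, n) coordinate
+ # within the CTA tile. Sub-tiles are stored one at a time so the TMA warp can keep
+ # up to d_epi_stage - 1 (or y_epi_stage - 1) bulk stores in flight while later
+ # sub-tiles are still being computed; see cp_async_bulk_wait_group below.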
+ gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) + # Copy from shared memory to global memory + if is_tma_warp: + if const_expr(not (self.inference_mode and self.need_adhoc_epilogue_store)): + cute.copy( + tma_atom_d, + bSG_sD[None, epi_buffer], + bSG_gD[None, gmem_coord], + tma_desc_ptr=d_tma_desc_ptr, + ) + if const_expr(self.need_adhoc_epilogue_store): + cute.copy( + tma_atom_y, + bSG_sY[None, epi_buffer], + bSG_gY[None, gmem_coord], + tma_desc_ptr=y_tma_desc_ptr, + ) + cute.arch.cp_async_bulk_commit_group() + if const_expr(self.inference_mode and self.need_adhoc_epilogue_store): + cute.arch.cp_async_bulk_wait_group(const_expr(self.y_epi_stage - 1), read=True) + else: + cute.arch.cp_async_bulk_wait_group(const_expr(self.d_epi_stage - 1), read=True) + + epilogue_barrier.arrive_and_wait() + + if const_expr(self.compute_dz_and_partial_ds_and_y1s): + y1 = make_acc_tensor_mn_view(acc) + cD = cute.make_identity_tensor((self.tile_M, self.tile_N)) + tDcD = tiled_mma.get_slice(tidx).partition_C(cD) + tDcD_mn = make_acc_tensor_mn_view(tDcD) + + tile_M_offset = cute.arch.make_warp_uniform(TIdx_cur_group + tile_coord_mnkl[0] * self.tile_M) + + mDS_partial_M, mDS_partial_N = mDS_partial.shape + mDS_partial_flatten_view = cute.make_tensor(mDS_partial.iterator, (mDS_partial_M * mDS_partial_N,)) + for r in cutlass.range_constexpr(cute.size(y1, mode=[0])): + col_sum = cutlass.Float32(0.0) + + M_tile_idx = tDcD_mn[r, 0][0] + for c in cutlass.range_constexpr(cute.size(y1, mode=[1])): + col_sum = col_sum + y1[r, c] + + col_sum = cute.arch.warp_reduction(col_sum, operator.add, threads_in_group=4) + + M_idx_raw = tile_M_offset + M_tile_idx + if tidx % 4 == 0 and M_idx_raw < TIdx_next_group: + M_idx = mS_scatter_idx[M_idx_raw] + N_idx = tile_coord_mnkl[1] + mDS_partial_flatten_view[M_idx * mDS_partial_N + N_idx] = col_sum.to( + mDS_partial.element_type + ) + + if const_expr(self.pingpong): + # With pingpong, 2 WGs write two different output tiles to the same smem, + # so we have to make sure the smem content is done reading before signalling + # the next WG's epilogue. 
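+ # Concretely: only warp 0 / warp 4 drains all outstanding TMA bulk stores
+ # (cp_async_bulk_wait_group(0)) before arriving at the other warp group's "epi"
+ # named barrier, and the C-load pipeline states are advanced by c_tile_cnt to
+ # account for the sub-tiles the other warp group consumes for its tile.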
+ if const_expr(self.need_epilogue_load): + epi_read_state.advance_iters(c_tile_cnt) + epi_producer_state.advance_iters(c_tile_cnt) + if warp_idx == 0 or warp_idx == 4: + cute.arch.cp_async_bulk_wait_group(0, read=True) + self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi") + + tile_scheduler.advance_to_next_work(advance_count=1 if not self.pingpong else self.mma_warp_groups) + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + if const_expr(not self.pingpong): + if warp_idx == 0: + cute.arch.cp_async_bulk_wait_group(0, read=True) + + def generate_tensormap(self, m, n, l): + if not self.is_persistent: + total_m = m * l + block_size_m = self.tile_M * self.cluster_shape_mnk[0] + block_size_n = self.tile_N * self.cluster_shape_mnk[1] + total_clusters_m_max = (total_m + l * (block_size_m - 1)) // block_size_m + total_clusters_max = total_clusters_m_max * ((n + block_size_n - 1) // block_size_n) + total_ctas = total_clusters_max * self.cluster_shape_mnk[0] * self.cluster_shape_mnk[1] + else: + total_ctas = cutlass.utils.HardwareInfo().get_device_multiprocessor_count() + if self.pingpong: + total_ctas *= 2 + # 128 bytes per tensormap + tensormaps_torch = torch.empty(total_ctas, 128 // 8, dtype=torch.int64, device="cuda") + tensormaps_tensor = from_dlpack(tensormaps_torch, assumed_align=128).mark_compact_shape_dynamic( + mode=0, stride_order=(0, 1) + ) + return tensormaps_tensor + + def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: str): + assert stage in ["mma", "epi"] + barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0 + cute.arch.barrier( + barrier_id=int(barrier) + warp_group_idx, + number_of_threads=2 * self.num_threads_per_warp_group, + ) + + def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage: str): + assert stage in ["mma", "epi"] + barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0 + cute.arch.barrier_arrive( + barrier_id=int(barrier) + warp_group_idx, + number_of_threads=2 * self.num_threads_per_warp_group, + ) + + def make_sched_pipeline(self, cluster_layout_mnk: cute.Layout, sched_pipeline_mbar_ptr: cute.Pointer): + # Threads/warps participating in this pipeline + sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + cluster_size = cute.size(cluster_layout_mnk) + # Each warp that are not the scheduler warp will contribute 1 to the arrive count + consumer_arrive_cnt = ( + (self.mma_warp_groups if not self.pingpong else 1) * 4 + + max(self.num_load_A_threads // cute.arch.WARP_SIZE, 1) + ) * cluster_size - 1 + sched_pipeline_consumer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, consumer_arrive_cnt) + return pipeline.PipelineAsync.create( + barrier_storage=sched_pipeline_mbar_ptr, + num_stages=self.sched_stage, + producer_group=sched_pipeline_producer_group, + consumer_group=sched_pipeline_consumer_group, + # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster. 
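+ # Every non-scheduler warp arrives once per stage; e.g. with 2 math warp groups
+ # (no pingpong), one TMA load warp and a 2-CTA cluster, consumer_arrive_cnt is
+ # (2 * 4 + 1) * 2 - 1 = 17 (illustrative numbers, not a fixed configuration).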
+ consumer_mask=None if const_expr(cluster_size == 1) else 0, + ) + + def _compute_stages( + self, + tile_shape_mnk: Tuple[int, int, int], + initial_d_epi_stage: int, + d_epi_tile: Optional[Tuple[int, int]], + c_epi_tile: Optional[Tuple[int, int]], + y_epi_tile: Optional[Tuple[int, int]], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + d_dtype: Type[cutlass.Numeric], + c_dtype: Optional[Type[cutlass.Numeric]], + y_dtype: Optional[Type[cutlass.Numeric]], + smem_capacity: int, + occupancy: int, + overlap_sD_sA: bool, + ) -> Tuple[int, int, int]: + d_epi_stage = initial_d_epi_stage if const_expr(not self.need_epilogue_load) else initial_d_epi_stage // 2 + y_epi_stage = d_epi_stage + + if self.inference_mode and self.need_adhoc_epilogue_store: + d_epi_stage = 0 + + if overlap_sD_sA: + epi_bytes = 0 + else: + d_bytes_per_stage = cute.size(d_epi_tile) * d_dtype.width // 8 + epi_bytes = d_bytes_per_stage * d_epi_stage + + if y_dtype is not None or const_expr(self.need_adhoc_epilogue_store): + y_bytes_per_stage = cute.size(y_epi_tile) * y_dtype.width // 8 + epi_bytes += y_bytes_per_stage * y_epi_stage + else: + y_bytes_per_stage = 0 + + c_epi_stage = 0 if (c_dtype is None or const_expr(not self.need_epilogue_load)) else d_epi_stage + if c_dtype is not None and const_expr(self.need_epilogue_load): + c_bytes_per_stage = cute.size(c_epi_tile) * c_dtype.width // 8 * c_epi_stage + epi_bytes += c_bytes_per_stage * c_epi_stage + d_epi_stage = c_epi_stage + else: + c_bytes_per_stage = 0 + + a_shape = cute.slice_(tile_shape_mnk, (None, 0, None)) + b_shape = cute.slice_(tile_shape_mnk, (0, None, None)) + ab_bytes_per_stage = cute.size(a_shape) * a_dtype.width // 8 + cute.size(b_shape) * b_dtype.width // 8 + mbar_helpers_bytes = 1024 + + remaining_bytes = ( + (smem_capacity - occupancy * 1024) // occupancy + - mbar_helpers_bytes + - epi_bytes + - self.prefetch_token_idx_size * 4 + - (self.tile_shape_mnk[1] * (self.bias_dtype.width // 8) if self.use_bias else 0) + - 1024 # aligned self.tensormap_management_bytes + ) + ab_stage = remaining_bytes // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + if not overlap_sD_sA: + if self.inference_mode and self.need_adhoc_epilogue_store: + epi_stage_delta = (remaining_bytes - ab_bytes_per_stage * ab_stage) // ( + y_bytes_per_stage + c_bytes_per_stage + ) + y_epi_stage += epi_stage_delta + else: + epi_stage_delta = (remaining_bytes - ab_bytes_per_stage * ab_stage) // ( + d_bytes_per_stage + y_bytes_per_stage + c_bytes_per_stage + ) + d_epi_stage += epi_stage_delta + y_epi_stage += epi_stage_delta + + if c_epi_stage > 0: + c_epi_stage += epi_stage_delta + + if not self.need_adhoc_epilogue_store: + y_epi_stage = 0 + + return ab_stage, c_epi_stage, d_epi_stage, y_epi_stage + + def _sm90_compute_tile_shape_or_override( + self, + tile_shape_mnk: Tuple[int, int, int], + atom_layout_mnk: Tuple[int, int, int], + element_type: Type[cutlass.Numeric], + epi_tile_override: Tuple[int, int] | None = None, + ) -> Tuple[int, int]: + """Compute the epilogue tile shape or use override if provided. 
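+
+ Without an override, the M extent is the gcd of tile_M with 128, 192 or 64
+ (depending on divisibility and on whether the atom layout has more than one atom
+ along M), and the N extent is derived from epi_tile_size. For example, a
+ (128, 256, 64) CTA tile with a (2, 1, 1) atom layout and epi_tile_size = 32 gives
+ a (128, 32) epilogue tile.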
+ + :param tile_shape_mnk: CTA tile shape (M,N,K) + :type tile_shape_mnk: Tuple[int, int, int] + :param element_type: Data type of elements + :type element_type: type[cutlass.Numeric] + :param is_cooperative: Whether to use cooperative approach + :type is_cooperative: bool + :param epi_tile_override: Optional override for epilogue tile shape + :type epi_tile_override: Tuple[int, int] or None + + :return: Computed epilogue tile shape + :rtype: Tuple[int, int] + """ + if epi_tile_override is not None: + return epi_tile_override + if tile_shape_mnk[0] % 128 == 0 and atom_layout_mnk[0] > 1: + tile_m = math.gcd(128, cute.size(tile_shape_mnk, mode=[0])) + tile_n = math.gcd(self.epi_tile_size, cute.size(tile_shape_mnk, mode=[1])) + return (tile_m, tile_n) + elif tile_shape_mnk[0] % 192 == 0 and atom_layout_mnk[0] > 1: + tile_m = math.gcd(192, cute.size(tile_shape_mnk, mode=[0])) + tile_n = math.gcd(self.epi_tile_size, cute.size(tile_shape_mnk, mode=[1])) + return (tile_m, tile_n) + else: + # In the case of tile shape 128 x N but atom_layout 1 x 2, we need to set + # epi_tile_m = 64. If epi_tile_m = 128, the epilogue would iterate along the + # M dimension first, then move to the N dimension. But the accumulator in registers + # iterate along the N dimension first, then move to the M dimension. + # We could change the epilogue to accommodate this, + # but it's easier to just set epi_tile_m = 64. + n_perf = 64 if element_type.width == 8 else min(self.epi_tile_size, tile_shape_mnk[1]) + tile_m = math.gcd(64, cute.size(tile_shape_mnk, mode=[0])) + tile_n = math.gcd(n_perf, cute.size(tile_shape_mnk, mode=[1])) + return (tile_m, tile_n) + + @staticmethod + def _make_smem_layouts( + tile_shape_mnk: Tuple[int, int, int], + c_epi_tile: Tuple[int, int], + bias_epi_tile: Tuple[int, int], + d_epi_tile: Tuple[int, int], + y_epi_tile: Optional[Tuple[int, int]], + a_dtype: Type[cutlass.Numeric], + a_layout: utils.LayoutEnum, + b_dtype: Type[cutlass.Numeric], + b_layout: utils.LayoutEnum, + prefetch_idx_size: Optional[int], + ab_stage: int, + c_dtype: Optional[Type[cutlass.Numeric]], + c_layout: Optional[cutlass.utils.LayoutEnum], + bias_dtype: Optional[Type[cutlass.Numeric]], + bias_layout: Optional[cutlass.utils.LayoutEnum], + d_dtype: Type[cutlass.Numeric], + d_layout: utils.LayoutEnum, + y_dtype: Optional[Type[cutlass.Numeric]], + y_layout: Optional[utils.LayoutEnum], + s_dtype: Optional[Type[cutlass.Numeric]], + c_epi_stage: int, + d_epi_stage: int, + y_epi_stage: int, + ) -> Tuple[ + cute.ComposedLayout, + cute.ComposedLayout, + Optional[cute.ComposedLayout], + cute.ComposedLayout, + Optional[cute.ComposedLayout], + ]: + a_smem_shape = cute.slice_(tile_shape_mnk, (None, 0, None)) + + a_is_k_major = a_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K + b_is_k_major = b_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K + + a_major_mode_size = tile_shape_mnk[2 if a_is_k_major else 0] + a_smem_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils.get_smem_layout_atom( + a_layout, + a_dtype, + a_major_mode_size, + ), + a_dtype, + ) + a_smem_layout_staged = cute.tile_to_shape( + a_smem_layout_atom, + cute.append(a_smem_shape, ab_stage), + order=(0, 1, 2) if a_is_k_major else (1, 0, 2), + ) + + b_smem_shape = cute.slice_(tile_shape_mnk, (0, None, None)) + + b_major_mode_size = tile_shape_mnk[2 if b_is_k_major else 1] + b_smem_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils.get_smem_layout_atom( + b_layout, + b_dtype, + b_major_mode_size, + ), + b_dtype, + ) + 
b_smem_layout_staged = cute.tile_to_shape( + b_smem_layout_atom, + cute.append(b_smem_shape, ab_stage), + order=(0, 1, 2) if b_is_k_major else (1, 0, 2), + ) + + d_smem_shape = d_epi_tile + d_major_mode_size = d_epi_tile[1] if d_layout.is_n_major_c() else d_epi_tile[0] + d_smem_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils.get_smem_layout_atom( + d_layout, + d_dtype, + d_major_mode_size, + ), + d_dtype, + ) + if d_epi_stage > 0: + d_epi_smem_layout_staged = cute.tile_to_shape( + d_smem_layout_atom, + cute.append(d_smem_shape, d_epi_stage), + order=(1, 0, 2) if d_layout.is_m_major_c() else (0, 1, 2), + ) + else: + # calculating the layout + d_epi_smem_layout_staged = cute.tile_to_shape( + d_smem_layout_atom, + cute.append(d_smem_shape, 1), + order=(1, 0, 2) if d_layout.is_m_major_c() else (0, 1, 2), + ) + + if y_epi_tile is not None: + y_smem_shape = y_epi_tile + # we force `y` to have same major mode as `z`. Otherwise the epilogue write is tricky + y_major_mode_size = y_epi_tile[1] if y_layout.is_n_major_c() else y_epi_tile[0] + y_smem_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils.get_smem_layout_atom( + y_layout, + y_dtype, + y_major_mode_size, + ), + y_dtype, + ) + y_epi_smem_layout_staged = cute.tile_to_shape( + y_smem_layout_atom, + cute.append(y_smem_shape, y_epi_stage), + order=(1, 0, 2) if y_layout.is_m_major_c() else (0, 1, 2), + ) + else: + y_epi_smem_layout_staged = None + + if c_dtype is not None: + assert c_layout is not None + c_smem_shape = c_epi_tile + c_major_mode_size = c_epi_tile[1] if c_layout.is_n_major_c() else c_epi_tile[0] + c_smem_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils.get_smem_layout_atom(c_layout, c_dtype, c_major_mode_size), + c_dtype, + ) + c_epi_smem_layout_staged = cute.tile_to_shape( + c_smem_layout_atom, + cute.append(c_smem_shape, c_epi_stage), + order=(1, 0, 2) if c_layout.is_m_major_c() else (0, 1, 2), + ) + else: + c_epi_smem_layout_staged = None + + if bias_dtype is not None and bias_layout is not None and bias_epi_tile is not None: + bias_epi_smem_layout_staged = cute.make_layout((1, tile_shape_mnk[1])) + else: + bias_epi_smem_layout_staged = None + + if s_dtype is not None: + s_epi_smem_layout_staged = cute.make_layout((tile_shape_mnk[0],)) + else: + s_epi_smem_layout_staged = None + + if prefetch_idx_size > 0: + prefetched_token_idx_smem_layout = cute.make_layout((prefetch_idx_size,)) + else: + prefetched_token_idx_smem_layout = None + + return ( + a_smem_layout_staged, + b_smem_layout_staged, + c_epi_smem_layout_staged, + bias_epi_smem_layout_staged, + d_epi_smem_layout_staged, + y_epi_smem_layout_staged, + s_epi_smem_layout_staged, + prefetched_token_idx_smem_layout, + ) + + @staticmethod + def _make_tma_epi_atoms_and_tensors( + tensor_d: cute.Tensor, + epi_smem_layout_staged: cute.ComposedLayout, + epi_tile: Tuple[int, int], + store_or_load: str, + ) -> Tuple[cute.CopyAtom, cute.Tensor]: + """Create TMA atoms and tensors for storing D or loading C. 
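+
+ store_or_load selects the copy direction: "load" builds a bulk-tensor (TMA) G2S
+ atom for reading the C input, while "store" builds an S2G atom for writing the D
+ output; in both cases the tensor is tiled by epi_tile.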
+ + :param tensor_d: Output tensor D + :type tensor_d: cute.Tensor + :param epi_smem_layout_staged: Shared memory layout for epilogue + :type epi_smem_layout_staged: cute.ComposedLayout + :param epi_tile: Epilogue tile shape + :type epi_tile: Tuple[int, int] + + :return: TMA atom and tensor for C + :rtype: Tuple[cute.CopyAtom, cute.Tensor] + """ + assert store_or_load in ["load", "store"] + epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0)) + d_cta_v_layout = cute.composition(cute.make_identity_layout(tensor_d.shape), epi_tile) + op = cpasync.CopyBulkTensorTileG2SOp() if store_or_load == "load" else cpasync.CopyBulkTensorTileS2GOp() + tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom(op, tensor_d, epi_smem_layout, d_cta_v_layout) + return tma_atom_d, tma_tensor_d + + @staticmethod + def _make_tma_atoms_and_tensors( + tensor: cute.Tensor, + smem_layout_staged: cute.ComposedLayout, + smem_tile: Tuple[int, int], + mcast_dim: int, + ) -> Tuple[cute.CopyAtom, cute.Tensor]: + """Create TMA atoms and tensors for input tensors. + + :param tensor: Input tensor (A or B) + :type tensor: cute.Tensor + :param smem_layout_staged: Shared memory layout for the tensor + :type smem_layout_staged: cute.ComposedLayout + :param smem_tile: Shared memory tile shape + :type smem_tile: Tuple[int, int] + :param mcast_dim: Multicast dimension + :type mcast_dim: int + + :return: TMA atom and tensor + :rtype: Tuple[cute.CopyAtom, cute.Tensor] + """ + op = cpasync.CopyBulkTensorTileG2SOp() if mcast_dim == 1 else cpasync.CopyBulkTensorTileG2SMulticastOp() + + smem_layout = cute.slice_(smem_layout_staged, (None, None, 0)) + tma_atom, tma_tensor = cpasync.make_tiled_tma_atom( + op, + tensor, + smem_layout, + smem_tile, + num_multicast=mcast_dim, + ) + return tma_atom, tma_tensor + + def _make_tiled_copy_2D( + self, + tensor: cute.Tensor, + tile_shape_0: cute.Int32, + tile_shape_1: cute.Int32, + is_row_major: bool, + threads_for_copy: Union[cutlass.Int32, int], + universal_copy_bits: cutlass.Int32, + is_g2s: Optional[bool] = True, + ) -> cute.TiledCopy: + copy_atom = cute.make_copy_atom( + ( + cute.nvgpu.cpasync.CopyG2SOp(cache_mode=cute.nvgpu.cpasync.LoadCacheMode.GLOBAL) + if const_expr(is_g2s) + else cute.nvgpu.CopyUniversalOp() + ), + tensor.element_type, + num_bits_per_copy=universal_copy_bits, + ) + copy_elems = universal_copy_bits // tensor.element_type.width + shape_dim_1 = cute.size(tile_shape_1) // copy_elems + # thread layout for copy + thread_layout = cute.make_layout((threads_for_copy // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1)) + if not is_row_major: + shape_dim_0 = cute.size(tile_shape_0) // copy_elems + thread_layout = cute.make_layout((shape_dim_0, threads_for_copy // shape_dim_0), stride=(1, shape_dim_0)) + # Value layout for copy + value_layout = cute.make_layout((1, copy_elems)) if is_row_major else cute.make_layout((copy_elems, 1)) + return cute.make_tiled_copy_tv(copy_atom, thread_layout, value_layout) + + @staticmethod + def is_valid_dtypes( + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + out_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + ) -> bool: + """ + Check if the dtypes are valid + + :param a_dtype: The data type of tensor A + :type a_dtype: Type[cutlass.Numeric] + :param b_dtype: The data type of tensor B + :type b_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param d_dtype: The data type of the output 
tensor + :type d_dtype: Type[cutlass.Numeric] + :param a_major: major mode of tensor A + :type a_major: str + :param b_major: major mode of tensor B + :type b_major: str + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + is_valid = True + # tested a_dtype + if a_dtype not in { + cutlass.Float16, + cutlass.BFloat16, + }: + is_valid = False + # tested b_dtype + if b_dtype not in { + cutlass.Float16, + cutlass.BFloat16, + }: + is_valid = False + # tested acc_dtype + if acc_dtype not in {cutlass.Float32, cutlass.Float16}: + is_valid = False + # tested d_dtype + if out_dtype not in { + cutlass.Float32, + cutlass.Float16, + cutlass.BFloat16, + }: + is_valid = False + # make sure a_dtype == b_dtype for Float16 + if a_dtype.width == 16 and a_dtype != b_dtype: + is_valid = False + # make sure a_dtype.width == b_dtype.width (i.e, Float8E4M3FN or Float8E5M2) + if a_dtype.width != b_dtype.width: + is_valid = False + # for Float8 types, this implementation only supports k-major layout + if (a_dtype.width == 8 and a_major != "k") or (b_dtype.width == 8 and b_major != "k"): + is_valid = False + return is_valid diff --git a/sonic-moe/torch-ext/sonicmoe/functional/moe_config.py b/sonic-moe/torch-ext/sonicmoe/functional/moe_config.py new file mode 100644 index 00000000..42c23502 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/functional/moe_config.py @@ -0,0 +1,581 @@ +# ******************************************************************************** +# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao +# ******************************************************************************** + +import math +from dataclasses import dataclass + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +import torch +from cutlass import const_expr +from ..quack.tile_scheduler import RasterOrderOption + +from ..enums import ActivationType, is_glu +from .grouped_gemm import HopperWgmma_MoE_kernel + + +LIBRARY_NAME = "cutedsl_kernels" + + +def ceil_div(a: int, b: int): + return int(math.ceil(a / b)) + + +@dataclass +class HopperGEMMConfig: + tile_shape_mnk: cutlass.Constexpr[cute.Shape] = (128, 256, 64) + cluster_shape_mnk: cutlass.Constexpr[cute.Shape] = (2, 1) + epi_tile_size: cutlass.Constexpr[int] = 32 + ## assume we always use persistent kernel + # is_persistent: cutlass.Constexpr[bool] = True + is_pingpong: cutlass.Constexpr[bool] = False + raster_order: RasterOrderOption = RasterOrderOption.Heuristic + L2_group_size: int = 8 + initial_d_epi_stage: cutlass.Constexpr[int] = 4 + + +class HopperWgmma_MoE_Up_proj_Fwd: + def __init__(self, E: int, H: int, I: int, activation_type: ActivationType, inference_mode=False): + super().__init__() + is_glu_activation = is_glu(activation_type) + if is_glu_activation: + assert ( + H % 64 == 0 and H >= 512 and I % 64 == 0 + ), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0" + else: + assert ( + H % 64 == 0 and H >= 512 and I % 128 == 0 + ), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0" + # TODO: this assertion does not mean that the MoE impl prohibits such config. 
+ # Instead, we just do not search for the best configs manually yet for small-shaped MoE + if (I >= 128 and is_glu_activation) or (I >= 256 and not is_glu_activation): + up_config = HopperGEMMConfig( + tile_shape_mnk=(128, 256, 64), + cluster_shape_mnk=(2, 1), + epi_tile_size=(32 if not inference_mode else 64), + is_pingpong=False, + initial_d_epi_stage=2, + raster_order=RasterOrderOption.AlongM, + ) + elif (I == 64 and is_glu_activation) or (I == 128 and not is_glu_activation): + up_config = HopperGEMMConfig( + tile_shape_mnk=(192, 128, 64), + cluster_shape_mnk=(1, 1), + epi_tile_size=(32 if not inference_mode else 64), + is_pingpong=True, + initial_d_epi_stage=8, + raster_order=RasterOrderOption.AlongM, + ) + else: + raise NotImplementedError() + + compute_swiglu = False + compute_geglu = False + compute_reglu = False + + compute_relu_sq = False + compute_silu = False + compute_relu = False + compute_gelu = False + + if activation_type == ActivationType.SWIGLU: + compute_swiglu = True + elif activation_type == ActivationType.GEGLU: + compute_geglu = True + elif activation_type == ActivationType.REGLU: + compute_reglu = True + + elif activation_type == ActivationType.RELU_SQ: + compute_relu_sq = True + elif activation_type == ActivationType.RELU: + compute_relu = True + elif activation_type == ActivationType.SILU: + compute_silu = True + elif activation_type == ActivationType.GELU: + compute_gelu = True + + else: + raise NotImplementedError(f"Activation function {activation_type} not supported yet!") + + self.module = HopperWgmma_MoE_kernel( + E, + cutlass.Float32, + up_config.tile_shape_mnk, + (*up_config.cluster_shape_mnk, 1), + pingpong=up_config.is_pingpong, + is_persistent=True, + compute_swiglu=compute_swiglu, + compute_reglu=compute_reglu, + compute_geglu=compute_geglu, + compute_relu_sq=compute_relu_sq, + compute_relu=compute_relu, + compute_silu=compute_silu, + compute_gelu=compute_gelu, + is_A_gather=True, + epi_tile_size=up_config.epi_tile_size, + initial_d_epi_stage=up_config.initial_d_epi_stage, + inference_mode=inference_mode, + ) + self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters( + up_config.cluster_shape_mnk[0] * up_config.cluster_shape_mnk[1] + ) + self.current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + @cute.jit + def __call__( + self, mX, mW1, mZ, mY1, mB1, mE_offset, mX_gather, mD_tensormap, mY1_tensormap, mE_permute_order, stream + ): + return self.module( + mX, + mW1, + None, + mB1, + mZ, + mY1, + None, + None, + mE_offset, + mX_gather, + None, + None, + None, + None, + None, + mD_tensormap, + mY1_tensormap, + None, + mE_permute_order, + const_expr(self.max_active_clusters), + stream, + ) + + +class HopperWgmma_MoE_Down_proj_Fwd: + def __init__(self, E: int, H: int, I: int): + super().__init__() + assert ( + H % 64 == 0 and H >= 512 and I % 64 == 0 + ), f"{LIBRARY_NAME} only supports MoE with H % 64 == 0 (H >= 512) and I % 64 == 0" + if I >= 1024: + down_config = HopperGEMMConfig( + tile_shape_mnk=(128, 256, 64), + cluster_shape_mnk=(2, 1), + epi_tile_size=32, + is_pingpong=False, + initial_d_epi_stage=4, + raster_order=RasterOrderOption.AlongN, + ) + elif I >= 256: + down_config = HopperGEMMConfig( + tile_shape_mnk=(128, 192, 64), + cluster_shape_mnk=(2, 1), + epi_tile_size=(96 if H % 96 == 0 else 64), + is_pingpong=True, + initial_d_epi_stage=5, + raster_order=RasterOrderOption.AlongN, + ) + elif I >= 64: + down_config = HopperGEMMConfig( + tile_shape_mnk=(128, 192, 64), + cluster_shape_mnk=(1, 2), + 
epi_tile_size=64, + is_pingpong=True, + initial_d_epi_stage=8, + raster_order=RasterOrderOption.AlongN, + ) + else: + raise NotImplementedError() + + self.module = HopperWgmma_MoE_kernel( + E, + cutlass.Float32, + down_config.tile_shape_mnk, + (*down_config.cluster_shape_mnk, 1), + pingpong=down_config.is_pingpong, + is_persistent=True, + compute_swiglu=False, + is_A_gather=False, + epi_tile_size=down_config.epi_tile_size, + initial_d_epi_stage=down_config.initial_d_epi_stage, + ) + self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters( + down_config.cluster_shape_mnk[0] * down_config.cluster_shape_mnk[1] + ) + + @cute.jit + def __call__(self, mY1, mW2, mY2, mB2, mE_offset, mX_gather, mD_tensormap, mE_permute_order, stream): + # we are not really using mX_gather in the Grouped GEMM, + # but CuTe-DSL compiler disallows dynamic flow so we still need to pass this argument + return self.module( + mY1, + mW2, + None, + mB2, + mY2, + None, + None, + None, + mE_offset, + mX_gather, + None, + None, + None, + None, + None, + mD_tensormap, + None, + None, + mE_permute_order, + const_expr(self.max_active_clusters), + stream, + ) + + +class HopperWgmma_MoE_Down_proj_ActGrad_Bwd: + def __init__(self, E: int, H: int, I: int, activation_type: ActivationType): + super().__init__() + is_glu_activation = is_glu(activation_type) + if is_glu_activation: + assert ( + H % 64 == 0 and H >= 512 and I % 64 == 0 + ), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0" + else: + assert ( + H % 64 == 0 and H >= 512 and I % 128 == 0 + ), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0" + + # heavy register pressure due to pingpong + heavy epilogue + # effectively no alternatives to this config + dz_partial_ds_config = HopperGEMMConfig( + tile_shape_mnk=(128, 128, 64), + cluster_shape_mnk=(2, 1), + epi_tile_size=32, + initial_d_epi_stage=4, + is_pingpong=True, + raster_order=RasterOrderOption.Heuristic, + ) + + compute_swiglu = False + compute_geglu = False + compute_reglu = False + + compute_relu_sq = False + compute_silu = False + compute_relu = False + compute_gelu = False + + if activation_type == ActivationType.SWIGLU: + compute_swiglu = True + elif activation_type == ActivationType.GEGLU: + compute_geglu = True + elif activation_type == ActivationType.REGLU: + compute_reglu = True + + elif activation_type == ActivationType.RELU_SQ: + compute_relu_sq = True + elif activation_type == ActivationType.RELU: + compute_relu = True + elif activation_type == ActivationType.SILU: + compute_silu = True + elif activation_type == ActivationType.GELU: + compute_gelu = True + + else: + raise NotImplementedError(f"Activation function {activation_type} not supported yet!") + + self.module = HopperWgmma_MoE_kernel( + E, + cutlass.Float32, + dz_partial_ds_config.tile_shape_mnk, + (*dz_partial_ds_config.cluster_shape_mnk, 1), + pingpong=dz_partial_ds_config.is_pingpong, + is_persistent=True, + compute_swiglu=compute_swiglu, + compute_reglu=compute_reglu, + compute_geglu=compute_geglu, + compute_relu_sq=compute_relu_sq, + compute_relu=compute_relu, + compute_silu=compute_silu, + compute_gelu=compute_gelu, + compute_dz_and_partial_ds_and_y1s=True, + is_A_gather=True, + epi_tile_size=dz_partial_ds_config.epi_tile_size, + initial_d_epi_stage=dz_partial_ds_config.initial_d_epi_stage, + ) + self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters( + dz_partial_ds_config.cluster_shape_mnk[0] * 
dz_partial_ds_config.cluster_shape_mnk[1] + ) + + @cute.jit + def __call__( + self, + mDout, + mW2_trans, + mZ_FP32_if_GLU_else_BF16, + mDz_FP32_if_GLU_else_BF16, + mY1S, + mS, + mDS_partial, + mE_offset, + mX_gather, + mS_scatter, + tensormaps, + mE_permute_order, + stream, + ): + return self.module( + mDout, + mW2_trans, + mZ_FP32_if_GLU_else_BF16, + None, + mDz_FP32_if_GLU_else_BF16, + mY1S, + mS, + mDS_partial, + mE_offset, + mX_gather, + None, + mS_scatter, + None, + None, + tensormaps[0], + tensormaps[1], + tensormaps[2], + None, + mE_permute_order, + const_expr(self.max_active_clusters), + stream, + ) + + +class HopperWgmma_MoE_Down_proj_WeightGrad_Bwd: + def __init__(self, E: int, H: int, I: int): + super().__init__() + assert ( + H % 64 == 0 and H >= 512 and I % 64 == 0 + ), f"{LIBRARY_NAME} only supports MoE with H % 64 == 0 (H >= 512) and I % 64 == 0" + + if I >= 128: + dw2_config = HopperGEMMConfig( + tile_shape_mnk=(128, 256, 64), + cluster_shape_mnk=(2, 1), + epi_tile_size=16, + is_pingpong=False, + initial_d_epi_stage=6, + raster_order=RasterOrderOption.AlongN, + ) + elif I == 64: + dw2_config = HopperGEMMConfig( + tile_shape_mnk=(64, 192, 64), + cluster_shape_mnk=(2, 1), + epi_tile_size=32, + is_pingpong=True, + initial_d_epi_stage=6, + raster_order=RasterOrderOption.AlongN, + ) + else: + raise NotImplementedError() + + self.module = HopperWgmma_MoE_kernel( + E, + cutlass.Float32, + dw2_config.tile_shape_mnk, + (*dw2_config.cluster_shape_mnk, 1), + pingpong=dw2_config.is_pingpong, + is_persistent=True, + compute_swiglu=False, + compute_weight_gradient=True, + compute_dz_and_partial_ds_and_y1s=False, + is_A_gather=True, + epi_tile_size=dw2_config.epi_tile_size, + initial_d_epi_stage=dw2_config.initial_d_epi_stage, + ) + self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters( + dw2_config.cluster_shape_mnk[0] * dw2_config.cluster_shape_mnk[1] + ) + + @cute.jit + def __call__(self, mDout_trans, mY1S_trans, mDw2, mE_offset, mX_gather, tensormaps, mE_permute_order, stream): + return self.module( + mDout_trans, + mY1S_trans, + None, + None, + mDw2, + None, + None, + None, + mE_offset, + mX_gather, + None, + None, + None, + tensormaps[0], + None, + None, + None, + None, + mE_permute_order, + const_expr(self.max_active_clusters), + stream, + ) + + +class HopperWgmma_MoE_Up_proj_ActGrad_Bwd: + def __init__(self, E: int, H: int, I: int, is_glu_activation: bool): + super().__init__() + if is_glu_activation: + assert ( + H % 64 == 0 and H >= 512 and I % 64 == 0 + ), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0" + else: + assert ( + H % 64 == 0 and H >= 512 and I % 128 == 0 + ), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0" + + if (I >= 512 and is_glu_activation) or (I >= 1024 and not is_glu_activation): + dx_config = HopperGEMMConfig( + tile_shape_mnk=(128, 256, 64), + cluster_shape_mnk=(2, 1), + epi_tile_size=32, + is_pingpong=False, + initial_d_epi_stage=4, + raster_order=RasterOrderOption.AlongN, + ) + elif (I >= 64 and is_glu_activation) or (I >= 128 and not is_glu_activation): + dx_config = HopperGEMMConfig( + tile_shape_mnk=(128, 192, 64), + cluster_shape_mnk=(2, 1), + epi_tile_size=64, + is_pingpong=True, + initial_d_epi_stage=8, + raster_order=RasterOrderOption.AlongN, + ) + else: + raise NotImplementedError() + + self.module = HopperWgmma_MoE_kernel( + E, + cutlass.Float32, + dx_config.tile_shape_mnk, + (*dx_config.cluster_shape_mnk, 1), + pingpong=dx_config.is_pingpong, + 
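# persistent scheduling: clusters are launched once and loop over work tiles, with max_active_clusters (queried below) bounding the resident clusters +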
is_persistent=True, + compute_swiglu=False, + compute_dz_and_partial_ds_and_y1s=False, + is_A_gather=False, + epi_tile_size=dx_config.epi_tile_size, + ) + + self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters( + dx_config.cluster_shape_mnk[0] * dx_config.cluster_shape_mnk[1] + ) + self.current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + @cute.jit + def __call__( + self, mDz, mW1_trans, mDx_expanded, mE_offset, mX_gather, mS_scatter, tensormaps, mE_permute_order, stream + ): + return self.module( + mDz, + mW1_trans, + None, + None, + mDx_expanded, + None, + None, + None, + mE_offset, + mX_gather, + None, + mS_scatter, + None, + None, + None, + tensormaps[0], + tensormaps[1], + None, + mE_permute_order, + const_expr(self.max_active_clusters), + stream, + ) + + +class HopperWgmma_MoE_Up_proj_WeightGrad_Bwd: + def __init__(self, E: int, H: int, I: int, is_glu_activation: bool): + super().__init__() + if is_glu_activation: + assert ( + H % 64 == 0 and H >= 512 and I % 64 == 0 + ), f"{LIBRARY_NAME} only supports GLU MoE with H % 64 == 0 (H >= 512) and I % 64 == 0" + else: + assert ( + H % 64 == 0 and H >= 512 and I % 128 == 0 + ), f"{LIBRARY_NAME} only supports non-GLU MoE with H % 64 == 0 (H >= 512) and I % 128 == 0" + + if (I >= 128 and is_glu_activation) or (I >= 256 and not is_glu_activation): + dw1_config = HopperGEMMConfig( + tile_shape_mnk=(128, 256, 64), + cluster_shape_mnk=(2, 1), + epi_tile_size=16, + is_pingpong=False, + initial_d_epi_stage=6, + raster_order=RasterOrderOption.Heuristic, + ) + elif (I == 64 and is_glu_activation) or (I == 128 and not is_glu_activation): + dw1_config = HopperGEMMConfig( + tile_shape_mnk=(256, 128, 64), + cluster_shape_mnk=(2, 1), + epi_tile_size=16, + is_pingpong=False, + initial_d_epi_stage=6, + raster_order=RasterOrderOption.AlongN, + ) + else: + raise NotImplementedError() + + self.module = HopperWgmma_MoE_kernel( + E, + cutlass.Float32, + dw1_config.tile_shape_mnk, + (*dw1_config.cluster_shape_mnk, 1), + pingpong=dw1_config.is_pingpong, + is_persistent=True, + compute_swiglu=False, + compute_weight_gradient=True, + compute_dz_and_partial_ds_and_y1s=False, + is_A_gather=True, + epi_tile_size=dw1_config.epi_tile_size, + ) + + self.max_active_clusters = cutlass.utils.HardwareInfo().get_max_active_clusters( + dw1_config.cluster_shape_mnk[0] * dw1_config.cluster_shape_mnk[1] + ) + + @cute.jit + def __call__(self, mX_trans, mDz_trans, mDw1_trans, mE_offset, mX_gather, tensormaps, mE_permute_order, stream): + return self.module( + mX_trans, + mDz_trans, + None, + None, + mDw1_trans, + None, + None, + None, + mE_offset, + mX_gather, + None, + None, + None, + tensormaps[0], + None, + None, + None, + None, + mE_permute_order, + const_expr(self.max_active_clusters), + stream, + ) diff --git a/sonic-moe/torch-ext/sonicmoe/functional/reduction_over_k_gather.py b/sonic-moe/torch-ext/sonicmoe/functional/reduction_over_k_gather.py new file mode 100644 index 00000000..0c726964 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/functional/reduction_over_k_gather.py @@ -0,0 +1,164 @@ +# ******************************************************************************** +# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao +# ******************************************************************************** + +from typing import Optional + +import torch +import triton +import triton.language as tl + +from ..utils import get_powers_of_2 + + +### This triton impl is equivalent as the cute-dsl impl 
shown above, +# and also achieves similar memory bandwidth on H100 for large K and H. +# However, for small K and H, this impl is better by autotuning so we use it as the default. +def _get_triton_autotune_configs() -> list[triton.Config]: + configs = [] + for BLOCK_H in get_powers_of_2(256, 4096): + for BLOCK_K in get_powers_of_2(1, 128): + for num_warps in [4, 8]: + if BLOCK_K * BLOCK_H <= 32768: + configs.append( + triton.Config({"BLOCK_H": BLOCK_H, "BLOCK_K": BLOCK_K}, num_warps=num_warps, num_stages=4) + ) + return configs + + +def _prune_triton_autotune_config(configs, nargs, **kw): + pruned_configs = [] + for c in configs: + BLOCK_H = c.kwargs["BLOCK_H"] + BLOCK_K = c.kwargs["BLOCK_K"] + H = kw["H"] + MAX_K = kw["MAX_K"] + if ( + BLOCK_H <= triton.next_power_of_2(H) + and BLOCK_K <= triton.next_power_of_2(MAX_K) + and min(H * MAX_K, 1024) <= (BLOCK_H * BLOCK_K) + ): + pruned_configs.append(c) + + if len(pruned_configs) == 0: + return configs + else: + return pruned_configs + + +@triton.autotune( + configs=_get_triton_autotune_configs(), + key=["H", "MAX_K", "w_is_None", "is_varlen_K"], + prune_configs_by={"early_config_prune": _prune_triton_autotune_config}, +) +@triton.jit +def token_gather_sum_kernel( + x_ptr, # (Mtotal, H) + w_ptr, # (Mtotal,) + M_perm_ptr, # (Mtotal,) int32 + M_offset_ptr, # (T+1,) int32 + out_ptr, # (T, H) + T, + H: tl.constexpr, + MAX_K: tl.constexpr, + # strides + stride_xM: tl.constexpr, + stride_xH: tl.constexpr, + stride_outT: tl.constexpr, + stride_outH: tl.constexpr, + # tile sizes + BLOCK_H: tl.constexpr, + BLOCK_K: tl.constexpr, + w_is_None: tl.constexpr, + is_varlen_K: tl.constexpr, +): + # 1D tiling over T only + pid_t = tl.program_id(axis=0) + t_idx = pid_t.to(tl.uint32) + + # Load segment starts and ends for this token + if is_varlen_K: + Ms = tl.load(M_offset_ptr + t_idx).to(tl.uint32) + Me = tl.load(M_offset_ptr + t_idx + 1).to(tl.uint32) + K_this_token = Me - Ms # actual K for this token + else: + Ms = MAX_K * t_idx + K_this_token: tl.constexpr = MAX_K + + # Outer loop over H tiles + for h_tile in tl.static_range(triton.cdiv(H, BLOCK_H)): + h_idx = (h_tile * BLOCK_H + tl.arange(0, BLOCK_H)).to(tl.uint32) # [BLOCK_H] + m_h = h_idx < H + + # Initialize accumulator for this H tile + acc = tl.zeros([BLOCK_H], dtype=tl.float32) # [BLOCK_H] + + # Inner loop over K tiles + for k_tile in tl.range(tl.cdiv(K_this_token, BLOCK_K)): + k_offset = k_tile * BLOCK_K + + k_idx = (k_offset + tl.arange(0, BLOCK_K)).to(tl.uint32) # [BLOCK_K] + + # Mask for valid K indices + m_k = k_idx < K_this_token # [BLOCK_K] + + # Absolute positions into M_perm and w + m_abs = Ms + k_idx # [BLOCK_K] + + # Gather permuted indices + perm_idx = tl.load(M_perm_ptr + m_abs, mask=m_k, other=0).to(tl.uint32) # [BLOCK_K] + + # Load x values: [BLOCK_K, BLOCK_H] + x_ptrs = x_ptr + perm_idx[:, None] * stride_xM + h_idx[None, :] * stride_xH + x_mask = m_k[:, None] & m_h[None, :] + x_vals = tl.load(x_ptrs, mask=x_mask, other=0.0).to(tl.float32) + + # Reduce along K dimension and add to accumulator + if w_is_None: + acc += tl.sum(x_vals, axis=0) # [BLOCK_H] + else: + w_vals = tl.load(w_ptr + m_abs, mask=m_k, other=0.0).to(tl.float32) # [BLOCK_K] + acc += tl.sum(x_vals * w_vals[:, None], axis=0) # [BLOCK_H] + + # Store final result for this H tile (only once!) 
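+ # acc[h] == sum_k w[Ms + k] * x[M_perm[Ms + k], h] over this token's K entries (unweighted when w_is_None)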
+ out_ptrs = out_ptr + t_idx * stride_outT + h_idx * stride_outH + tl.store(out_ptrs, acc, mask=m_h) + + +def token_gather_and_sum_varlen_K_triton( + x: torch.Tensor, # (Mtotal, H) + w: Optional[torch.Tensor], # (Mtotal,) + out: torch.Tensor, # (T, H) + M_perm: torch.Tensor, # (Mtotal,) int32 + M_offset: torch.Tensor, # (T+1,) int32, variable K per token + T: int, + MAX_K: int, # maximum K across all tokens + H: int, + is_varlen_K: bool, +): + """ + 1D parallelization over T, with iterative accumulation over K tiles and H tiles. + Supports variable K per token. + + out[i, :] = sum_{j=0..K[i]-1} x[M_perm[M_offset[i] + j], :] * w[M_offset[i] + j] + + where K[i] = M_offset[i+1] - M_offset[i] can vary per token. + """ + + # 1D grid over T only + token_gather_sum_kernel[(T,)]( + x, + w, + M_perm, + M_offset, + out, + T=T, + H=H, + MAX_K=MAX_K, + stride_xM=x.stride(0), + stride_xH=x.stride(1), + stride_outT=out.stride(0), + stride_outH=out.stride(1), + w_is_None=(w is None), + is_varlen_K=is_varlen_K, + ) diff --git a/sonic-moe/torch-ext/sonicmoe/functional/tile_scheduler.py b/sonic-moe/torch-ext/sonicmoe/functional/tile_scheduler.py new file mode 100644 index 00000000..f9d9dd81 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/functional/tile_scheduler.py @@ -0,0 +1,91 @@ +# ******************************************************************************** +# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao +# ******************************************************************************** + +from __future__ import annotations + +import cutlass +import cutlass.cute as cute +from cutlass import Boolean, Int32, const_expr +from ..quack.pipeline import PipelineStateWAdvance +from ..quack.tile_scheduler import TileScheduler, VarlenMTileScheduler + + +class SonicMoETileScheduler(TileScheduler): + @staticmethod + @cute.jit + def create( + params: TileScheduler.Params, + tile_count: cute.Tensor | None = None, + scheduler_pipeline: cutlass.pipeline.PipelineAsync | None = None, + is_scheduler_warp: bool | Boolean = False, + *, + loc=None, + ip=None, + ) -> SonicMoETileScheduler: + """is_scheduler_warp should only be true for one warp in the whole cluster""" + stages = 0 + if const_expr(not params.is_persistent): + cidx, cidy, _ = cute.arch.cluster_idx() + cdimx, _, _ = cute.arch.cluster_dim() + cluster_id = cidx + cidy * cdimx + current_work_linear_idx = Int32(cluster_id) + else: + _, _, bidz = cute.arch.block_idx() + current_work_linear_idx = Int32(bidz) + if const_expr(params.tile_count_semaphore is not None): + assert tile_count is not None + assert scheduler_pipeline is not None + stages = const_expr(cute.size(tile_count)) + return SonicMoETileScheduler( + current_work_linear_idx, + Int32(0), # num_tiles_executed + tile_count, + scheduler_pipeline, + PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)), + params, + loc=loc, + ip=ip, + ) + + def prefetch_next_work(self, *, advance_count: int = 1, loc=None, ip=None): + old_current_work_linear_idx = self._current_work_linear_idx + if const_expr(self.params.is_persistent): + num_persistent_clusters = cute.arch.grid_dim()[2] + self._current_work_linear_idx += advance_count * Int32(num_persistent_clusters) + future_tile_coord_mnkl = self.get_current_work() + self._current_work_linear_idx = old_current_work_linear_idx + return future_tile_coord_mnkl + + +class SonicMoEVarlenMTileScheduler(VarlenMTileScheduler, SonicMoETileScheduler): + @staticmethod + @cute.jit + def create( + params: 
VarlenMTileScheduler.Params, + tile_count: cute.Tensor | None = None, + scheduler_pipeline: cutlass.pipeline.PipelineAsync | None = None, + is_scheduler_warp: bool | Boolean = False, + *, + loc=None, + ip=None, + ) -> SonicMoEVarlenMTileScheduler: + stages = 0 + _, _, bidz = cute.arch.block_idx() + current_work_linear_idx = Int32(bidz) + if const_expr(params.tile_count_semaphore is not None): + assert tile_count is not None + assert scheduler_pipeline is not None + stages = const_expr(cute.size(tile_count)) + return SonicMoEVarlenMTileScheduler( + current_work_linear_idx, + Int32(0), # num_tiles_executed + Int32(0), # current_batch_idx + Int32(0), # num_work_idx_before_cur_batch + tile_count, + scheduler_pipeline, + PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)), + params, + loc=loc, + ip=ip, + ) diff --git a/sonic-moe/torch-ext/sonicmoe/functional/topk_softmax.py b/sonic-moe/torch-ext/sonicmoe/functional/topk_softmax.py new file mode 100644 index 00000000..6ed5a79a --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/functional/topk_softmax.py @@ -0,0 +1,195 @@ +# ******************************************************************************** +# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao +# ******************************************************************************** + +# this impl is adapted from QuACK's topk https://github.com/Dao-AILab/quack/blob/main/quack/topk.py +import math +from typing import Type + +import cuda.bindings.driver as cuda +import cutlass +import cutlass.cute as cute +from ..quack import utils +from cutlass import const_expr +from ..quack.sort.bitonic_sort import bitonic_topk +from triton import next_power_of_2 + +from ..utils import domain_offset_i64 + + +class TopK_Softmax: + def __init__( + self, + input_dtype: Type[cutlass.Numeric], + output_dtype: Type[cutlass.Numeric], + N: int, + k: int, + require_softmax_fusion: bool = True, + ): + self.input_dtype = input_dtype + self.output_dtype = output_dtype + self.N = N + self.input_vecsize = 128 // input_dtype.width + self.output_vecsize = 128 // output_dtype.width + self.k = k + self.next_power_of_2_N = next_power_of_2(N) + self.next_power_of_2_K = next_power_of_2(k) + assert k <= 128 and k <= N + assert N <= 4096 and N % 8 == 0 + assert input_dtype.width <= output_dtype.width, "input bitwidth must <= output bitwidth" + + self.require_softmax_fusion = require_softmax_fusion + + def _calculate_threads_per_row(self): + # we want num_elems_per_thread >= self.k + # and each thread can handle at most 64 elements + N = self.next_power_of_2_N + num_threads_per_row = max(min(N // self.k, 32, N // 64), 1) + return num_threads_per_row + + def _get_tv_layout(self, vecsize): + N = self.next_power_of_2_N + num_threads = 128 if N <= 16384 else 256 + threads_per_row = self._calculate_threads_per_row() + cols_per_block = num_threads // threads_per_row + num_blocks_N = cute.ceil_div(min(N, 16384) // vecsize, threads_per_row) + tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row) + tv_layout = cute.make_layout( + ((threads_per_row, cols_per_block), (vecsize, num_blocks_N)), + stride=( + (vecsize * cols_per_block, 1), + (cols_per_block, cols_per_block * vecsize * threads_per_row), + ), + ) + return tiler_mn, tv_layout + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mValues: cute.Tensor, + mIndices: cute.Tensor, + stream: cuda.CUstream, + ): + assert mX.element_type == self.input_dtype + assert mValues.element_type == 
self.output_dtype + assert mIndices.element_type == cutlass.Int32 + input_tiler_mn, input_tv_layout = self._get_tv_layout(self.input_vecsize) + output_tiler_mn, output_tv_layout = self._get_tv_layout(self.output_vecsize) + + num_threads = cute.size(input_tv_layout, mode=[0]) + self.kernel(mX, mValues, mIndices, input_tv_layout, input_tiler_mn, output_tv_layout, output_tiler_mn).launch( + grid=[cute.ceil_div(mX.shape[0], input_tiler_mn[0]), 1, 1], + block=[num_threads, 1, 1], + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mValues: cute.Tensor, + mIndices: cute.Tensor, + input_tv_layout: cute.Layout, + input_tiler_mn: cute.Shape, + output_tv_layout: cute.Layout, + output_tiler_mn: cute.Shape, + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + shape = mX.shape + idX = cute.make_identity_tensor(shape) + # slice for CTAs + # We use domain_offset_i64 to deal with tensors larger than 2^31 elements + mX = domain_offset_i64((bidx * input_tiler_mn[0], 0), mX) + gX = cute.local_tile(mX, input_tiler_mn, (0, 0)) + cX = cute.local_tile(idX, input_tiler_mn, (bidx, 0)) + + # declare the atoms which will be used later for memory copy + copy_atom_load_X = cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), gX.element_type, num_bits_per_copy=128) + thr_copy_X = cute.make_tiled_copy(copy_atom_load_X, input_tv_layout, input_tiler_mn).get_slice(tidx) + tXgX = thr_copy_X.partition_S(gX) + tXcX = thr_copy_X.partition_S(cX)[(0, None), None, None] + + # allocate fragments for gmem->rmem + tXrX = cute.make_rmem_tensor_like(tXgX) + + is_even_N = const_expr(shape[1] == input_tiler_mn[1]) + tXpX = ( + utils.predicate_k(thr_copy_X.partition_S(cX), limit=shape[1]) + if const_expr((not is_even_N) or (self.N != self.next_power_of_2_N)) + else None + ) + if tXcX[0][0] < shape[0]: + cute.copy(copy_atom_load_X, tXgX, tXrX, pred=tXpX) + tXrX_f32 = cute.make_rmem_tensor(tXrX.shape, cutlass.Float32) + tXrX_f32.store(tXrX.load().to(cutlass.Float32)) + + # Encode the indices into the bottom bits of values. + log_N = int(math.log2(self.next_power_of_2_N)) + idx_mask = const_expr((1 << log_N) - 1) + input_vecsize = cutlass.const_expr(input_tv_layout.shape[1][0]) + tXrX_u32 = cute.recast_tensor(tXrX_f32, cutlass.Uint32) + # Encode indices into the last log_N bits of tXrX_u32 + for i in cutlass.range(cute.size(tXrX_u32), unroll_full=True): + # tXcX only keeps track of the indices for every @vecsize elements + col_idx = cutlass.Uint32(tXcX[i // input_vecsize][1] + i % input_vecsize) + # If positive, invert the bits of the index, so that if there's a tie, + # indices coming from a earlier column will win. 
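+ # (For non-negative IEEE-754 floats a larger bit pattern means a larger value, so storing ~col_idx
+ # in the low bits lets the smaller column index compare greater on ties; the bit-pattern order
+ # flips for negative values, which is why col_idx is stored un-inverted in that case.)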
+ encoded_idx = ~col_idx if tXrX_f32[i] >= 0 else col_idx + # Mask to keep only the last log_N bits of the encoded index + encoded_idx = encoded_idx & idx_mask + # Clear the last log_N bits and set them to our encoded index + tXrX_u32[i] = (tXrX_u32[i] & ~idx_mask) | encoded_idx + + # Fill OOB values with -inf for top-k + if const_expr((not is_even_N) or (self.N != self.next_power_of_2_N)): + utils.fill_oob(tXrX_f32, tXpX, -tXrX_f32.element_type.inf) + + threads_per_row = input_tv_layout.shape[0][0] + topk_vals = bitonic_topk(tXrX_f32, self.next_power_of_2_K, warp_width=threads_per_row) + + # Extract indices and clean values + topk_vals_u32 = cute.recast_tensor(topk_vals, cutlass.Uint32) + topk_indices = cute.make_rmem_tensor(self.k, cutlass.Int32) + for i in cutlass.range_constexpr(self.k): + # Extract the encoded index from the last log_N bits + encoded_idx = topk_vals_u32[i] & idx_mask + # Check if original value was positive by looking at the cleaned value + topk_vals_u32[i] = topk_vals_u32[i] & ~idx_mask # Clear last log_N bits + # If positive, we need to invert the bits back to get original index + col_idx = ~encoded_idx if topk_vals[i] >= 0 else encoded_idx + topk_indices[i] = cutlass.Int32(col_idx & idx_mask) + + if const_expr(self.require_softmax_fusion): + topk_vals_max = -cutlass.Float32.inf + for i in cutlass.range_constexpr(self.k): + topk_vals_max = cute.arch.fmax(topk_vals[i], topk_vals_max) + + topk_exp_sum = cutlass.Int32(0.0) + for i in cutlass.range_constexpr(self.k): + topk_vals[i] = cute.math.exp(topk_vals[i] - topk_vals_max) + topk_exp_sum = topk_exp_sum + topk_vals[i] + + for i in cutlass.range_constexpr(self.k): + topk_vals[i] = topk_vals[i] / topk_exp_sum + + # Convert cleaned values to output type + topk_vals_out = cute.make_rmem_tensor_like(topk_indices, mValues.element_type) + for i in cutlass.range_constexpr(self.k): + topk_vals_out[i] = topk_vals[i].to(mValues.element_type) + + row = tXcX[0][0] + # Only the 1st thread in this row writes the top-k values and indices + output_vecsize = cutlass.const_expr(output_tv_layout.shape[1][0]) + if row < shape[0] and tXcX[0][1] == 0: + # Vectorized write + elems_per_store = const_expr(math.gcd(output_vecsize, self.k)) + mValues_store = cute.tiled_divide(mValues[row, None], (elems_per_store,)) + mIndices_store = cute.tiled_divide(mIndices[row, None], (elems_per_store,)) + topk_vals_out_store = cute.tiled_divide(topk_vals_out, (elems_per_store,)) + topk_indices_store = cute.tiled_divide(topk_indices, (elems_per_store,)) + for i in cutlass.range_constexpr(cute.size(topk_vals_out_store.shape, [1])): + cute.autovec_copy(topk_vals_out_store[None, i], mValues_store[None, i]) + cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i]) diff --git a/sonic-moe/torch-ext/sonicmoe/functional/triton_kernels/README.md b/sonic-moe/torch-ext/sonicmoe/functional/triton_kernels/README.md new file mode 100644 index 00000000..d8f179a1 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/functional/triton_kernels/README.md @@ -0,0 +1 @@ +The `bitmatrix.py` contains 3 functions adapted from the ![triton official example](https://github.com/triton-lang/triton/blob/main/python/triton_kernels/triton_kernels/tensor_details/bitmatrix.py). We make some minor modifications to `_bitmatrix_metadata_compute_stage1` and `_bitmatrix_metadata_compute_stage2`. 
\ No newline at end of file diff --git a/sonic-moe/torch-ext/sonicmoe/functional/triton_kernels/TRITON_LICENSE b/sonic-moe/torch-ext/sonicmoe/functional/triton_kernels/TRITON_LICENSE new file mode 100644 index 00000000..0f3852f0 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/functional/triton_kernels/TRITON_LICENSE @@ -0,0 +1,23 @@ +/* +* Copyright 2018-2020 Philippe Tillet +* Copyright 2020-2022 OpenAI +* +* Permission is hereby granted, free of charge, to any person obtaining +* a copy of this software and associated documentation files +* (the "Software"), to deal in the Software without restriction, +* including without limitation the rights to use, copy, modify, merge, +* publish, distribute, sublicense, and/or sell copies of the Software, +* and to permit persons to whom the Software is furnished to do so, +* subject to the following conditions: +* +* The above copyright notice and this permission notice shall be +* included in all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +* MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY +* CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, +* TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE +* SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +*/ \ No newline at end of file diff --git a/sonic-moe/torch-ext/sonicmoe/functional/triton_kernels/__init__.py b/sonic-moe/torch-ext/sonicmoe/functional/triton_kernels/__init__.py new file mode 100644 index 00000000..e2bc2802 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/functional/triton_kernels/__init__.py @@ -0,0 +1,351 @@ +import math + +import torch +import triton +import triton.language as tl + +from ..._ops_compat import add_op_namespace_prefix +from .bitmatrix import _bitmatrix_metadata_compute_stage1, _bitmatrix_metadata_compute_stage2, _keyed_add + + +@triton.jit +def _compute_col_partial_sum_kernel( + topk_indices_ptr, + partial_sum_ptr, + T, + E: tl.constexpr, + n_tiles, + TOKENS_PER_TILE: tl.constexpr, + K_POW2: tl.constexpr, # next_power_of_2(K), + K: tl.constexpr, # actual number of experts per token + E_POW2: tl.constexpr, # next_power_of_2(E) +): + # One CTA per tile. Tile `t` covers tokens [t * TOKENS_PER_TILE, (t+1) * TOKENS_PER_TILE). + # Produces partial_sum[e, tile_id] = number of entries in this tile routed to expert e. + # Layout: partial_sum is [E, n_tiles] (row-major), so partial_sum[e, t] = partial_sum_ptr + e * n_tiles + t. + # Caller transposes to [n_tiles, E] before passing to stage1/stage2. + tile_id = tl.program_id(0) + + # Zero this tile's column in partial_sum[*, tile_id]. + # Chunked by E_POW2 to keep vector width a power of 2. + for e_start in tl.static_range(0, E, E_POW2): + e_offs = e_start + tl.arange(0, E_POW2) + tl.store( + partial_sum_ptr + e_offs * n_tiles + tile_id, + tl.zeros([E_POW2], tl.int32), + mask=e_offs < E, + ) + + # Load expert ids for this tile: shape [TOKENS_PER_TILE, K_POW2]. + # Tokens beyond T and k-slots beyond K are masked out (other=-1). 
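+ # e.g. T=10, TOKENS_PER_TILE=4: tile 2 covers tokens 8..11 and tok_mask keeps only tokens 8 and 9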
+ tok_offs = tile_id * TOKENS_PER_TILE + tl.arange(0, TOKENS_PER_TILE) + k_offs = tl.arange(0, K_POW2) + tok_mask = tok_offs < T + + load_mask = tok_mask[:, None] & (k_offs[None, :] < K) + safe_k = tl.minimum(k_offs, K - 1) # avoid OOB when k_offs >= K + expert_ids = tl.load( + topk_indices_ptr + tok_offs[:, None] * K + safe_k[None, :], + mask=load_mask, + other=-1, + ) + + # Flatten to [TOKENS_PER_TILE * K_POW2] and histogram into partial_sum. + # safe_experts remaps masked (-1) entries to expert 0 (harmless: flat_mask=False). + flat_experts = tl.reshape(expert_ids, [TOKENS_PER_TILE * K_POW2]) + flat_mask = tl.reshape(load_mask, [TOKENS_PER_TILE * K_POW2]) + safe_experts = tl.where(flat_mask, flat_experts, 0) + + tl.atomic_add( + partial_sum_ptr + safe_experts * n_tiles + tile_id, + tl.full([TOKENS_PER_TILE * K_POW2], 1, dtype=tl.int32), + mask=flat_mask, + ) + + +@torch.library.custom_op( + add_op_namespace_prefix("triton_kernels__TC_topk_router_metadata"), + mutates_args={ + "expert_frequency", + "expert_frequency_offset", + "x_gather_idx", + "s_scatter_idx", + "s_reverse_scatter_idx", + }, +) +def TC_topk_router_metadata_triton( + topk_router_indices: torch.Tensor, + E: int, + expert_frequency: torch.Tensor, + expert_frequency_offset: torch.Tensor, + x_gather_idx: torch.Tensor, + s_scatter_idx: torch.Tensor, + s_reverse_scatter_idx: torch.Tensor, +) -> None: + T, K = topk_router_indices.size() + TK = T * K + device = topk_router_indices.device + E_POW2 = triton.next_power_of_2(E) + K_POW2 = triton.next_power_of_2(K) + TOKENS_PER_BLOCK = 1024 // K_POW2 + n_tiles = triton.cdiv(T, TOKENS_PER_BLOCK) + + # ── Kernel 1: tiled histogram ───────────────────────────────────────────── + # col_partial_sum_trans[E, n_tiles]: raw per-expert-per-tile counts. + # Stored transposed so each CTA writes to its own column (tile_id), avoiding + # cross-CTA write conflicts. Transposed back to [n_tiles, E] for stage1/stage2. + col_partial_sum_trans = torch.empty(E, n_tiles, dtype=torch.int32, device=device) + _compute_col_partial_sum_kernel[(n_tiles,)]( + topk_router_indices, + col_partial_sum_trans, + T, + E, + n_tiles, + TOKENS_PER_TILE=TOKENS_PER_BLOCK, + K_POW2=K_POW2, + K=K, + E_POW2=E_POW2, + ) + + expert_frequency.copy_(col_partial_sum_trans.sum(dim=1, dtype=torch.int32)) + col_partial_sum = col_partial_sum_trans.T # [n_tiles, E] + + # ── Kernel 2: stage1 ───────────────────────────────────────────────────── + # - For each expert e (pid < E): convert col_partial_sum[*, e] from raw + # counts to exclusive prefix sums over tiles in-place. + # - For pid == E: write exclusive cumsum of expert_freq_offset into + # expert_freq_off[0:E] (= col_offs, a view into expert_freq_off). + + _bitmatrix_metadata_compute_stage1[(E + 2,)]( + expert_frequency, + expert_frequency_offset, + E, + col_partial_sum, + n_tiles, + TK, + BLOCK_M=128, + BLOCK_N=E_POW2, + ) + + # ── Kernel 3: stage2 ───────────────────────────────────────────────────── + # For each tile: sort entries by expert, compute output positions, scatter. 
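+ # An entry routed to expert e lands at expert_offs[e] (global start of e, from stage1)
+ # + partial_sum[tile, e] (expert-e entries in earlier tiles) + its rank within this tile.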
+ _bitmatrix_metadata_compute_stage2[(n_tiles,)]( + s_scatter_idx, + s_reverse_scatter_idx, + x_gather_idx, + topk_router_indices, + T, + col_partial_sum, + n_tiles, + expert_frequency_offset[:E], + K_POW2=K_POW2, + TOKENS_PER_BLOCK=TOKENS_PER_BLOCK, + K=K, + ) + + +# ── general_routing_router_metadata_triton --- Kernel 1: tiled histogram over flat selected_E ──────────────────────────── +@triton.jit +def _general_compute_col_partial_sum_kernel( + selected_E_ptr, + partial_sum_ptr, # [E, n_tiles], column-major per tile + TK, + E: tl.constexpr, + n_tiles, + BLOCK_SIZE: tl.constexpr, + E_POW2: tl.constexpr, +): + tile_id = tl.program_id(0) + + # Zero this tile's column in partial_sum[*, tile_id]. + for e_start in tl.static_range(0, E, E_POW2): + e_offs = e_start + tl.arange(0, E_POW2) + tl.store( + partial_sum_ptr + e_offs * n_tiles + tile_id, + tl.zeros([E_POW2], tl.int32), + mask=e_offs < E, + ) + + # Load expert ids for this tile (flat indexing into selected_E). + offs = tile_id * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offs < TK + expert_ids = tl.load(selected_E_ptr + offs, mask=mask, other=-1) + + safe_experts = tl.where(mask, expert_ids, 0) + tl.atomic_add( + partial_sum_ptr + safe_experts * n_tiles + tile_id, + tl.full([BLOCK_SIZE], 1, dtype=tl.int32), + mask=mask, + ) + + +# ── general_routing_router_metadata_triton --- Kernel 3: sort entries by expert within each tile, scatter ──────────────── +@triton.jit +def _general_metadata_compute_stage2( + s_scatter_idx_ptr, + s_reverse_scatter_idx_ptr, + x_gather_idx_ptr, + selected_E_ptr, + sorted_selected_T_ptr, + TK, + partial_sum_ptr, # [n_tiles, E] with strides (1, n_tiles) + n_tiles, + expert_offs_ptr, + BLOCK_SIZE: tl.constexpr, +): + tl.static_assert(BLOCK_SIZE <= 32768) + + pid_m = tl.program_id(0) + offs_local = tl.arange(0, BLOCK_SIZE) + offs_global = pid_m * BLOCK_SIZE + offs_local + mask = offs_global < TK + + # Load expert id for each entry in this tile. + expert = tl.load(selected_E_ptr + offs_global, mask=mask, other=-1).to(tl.uint32) + + # Pack (expert, local_offset) into uint32 and sort by expert. + # Upper 16 bits = expert id, lower 16 bits = pre-sort local offset. + kv_pairs = tl.sort(((expert << 16) | offs_local).to(tl.uint32), 0) + expert = kv_pairs >> 16 + mask = expert != 0xFFFF + + # Segmented scan for within-expert rank. + scan_input = (kv_pairs & 0xFFFF0000) | 0x00000001 + inclusive_run_lengths = tl.associative_scan(scan_input, 0, _keyed_add) + within_expert_rank = (inclusive_run_lengths - 1) & 0xFFFF + + # Output position = expert_offs[e] + partial_sum[tile, e] + within_expert_rank. + s_reverse_scatter_val = tl.load(partial_sum_ptr + pid_m + expert * n_tiles, mask=mask) + s_reverse_scatter_val += tl.load(expert_offs_ptr + expert, mask=mask) + s_reverse_scatter_val += within_expert_rank + + # Recover pre-sort entry index and look up the token index. 
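+ # (the static_assert BLOCK_SIZE <= 32768 above guarantees the pre-sort offset fits in these low 16 bits)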
+ presort_offs = kv_pairs & 0xFFFF + entry_idx = pid_m * BLOCK_SIZE + presort_offs + token_idx = tl.load(sorted_selected_T_ptr + entry_idx, mask=mask) + + tl.store(s_reverse_scatter_idx_ptr + entry_idx, s_reverse_scatter_val, mask=mask) + tl.store(s_scatter_idx_ptr + s_reverse_scatter_val, entry_idx, mask=mask) + tl.store(x_gather_idx_ptr + s_reverse_scatter_val, token_idx, mask=mask) + + +# ── general_routing_router_metadata_triton --- Kernel 4: parallel binary search for token offset ───────────────────────── +# Since sorted_selected_T is sorted ascending, num_activated_expert_per_token_offset[t] +# is exactly searchsorted_left(sorted_selected_T, t): the index of the first entry +# with token index >= t. We compute this via parallel binary search over T+1 queries, +# replacing the PyTorch bincount + cumsum path. +@triton.jit +def _token_offset_searchsorted_kernel( + sorted_T_ptr, # [TK] int32, sorted ascending + offset_ptr, # [T+1] int32, output + T, # number of tokens + TK, # length of sorted_T + BLOCK_SIZE: tl.constexpr, + N_ITERS: tl.constexpr, # ceil(log2(TK + 1)), controls binary search depth +): + pid = tl.program_id(0) + t_offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = t_offs <= T # T+1 total values: offset[0], ..., offset[T] + + t_vals = t_offs.to(tl.int32) + + # Binary search: find smallest i such that sorted_T[i] >= t_vals + lo = tl.zeros([BLOCK_SIZE], dtype=tl.int32) + hi = tl.full([BLOCK_SIZE], TK, dtype=tl.int32) + + for _ in tl.static_range(0, N_ITERS): + mid = (lo + hi) >> 1 + # When mid >= TK, treat the value as +inf (>= any t), so hi = mid. + safe_mid = tl.where(mid < TK, mid, 0) + val = tl.load(sorted_T_ptr + safe_mid, mask=mask & (TK > 0), other=T) + go_right = (val < t_vals) & (mid < TK) + lo = tl.where(go_right, mid + 1, lo) + hi = tl.where(go_right, hi, mid) + + tl.store(offset_ptr + t_offs, lo, mask=mask) + + +@torch.library.custom_op( + add_op_namespace_prefix("triton_kernels__general_routing_router_metadata"), + mutates_args={ + "expert_frequency", + "expert_frequency_offset", + "x_gather_idx", + "s_scatter_idx", + "s_reverse_scatter_idx", + "num_activated_expert_per_token_offset", + }, +) +def general_routing_router_metadata_triton( + sorted_selected_T: torch.Tensor, + selected_E: torch.Tensor, + T: int, + E: int, + expert_frequency: torch.Tensor, + expert_frequency_offset: torch.Tensor, + x_gather_idx: torch.Tensor, + s_scatter_idx: torch.Tensor, + s_reverse_scatter_idx: torch.Tensor, + num_activated_expert_per_token_offset: torch.Tensor, +) -> None: + TK = selected_E.size(0) + device = selected_E.device + E_POW2 = triton.next_power_of_2(E) + BLOCK_SIZE = 1024 + n_tiles = triton.cdiv(TK, BLOCK_SIZE) + + # ── Kernel 1: tiled histogram ───────────────────────────────────────── + col_partial_sum_trans = torch.empty(E, n_tiles, dtype=torch.int32, device=device) + _general_compute_col_partial_sum_kernel[(n_tiles,)]( + selected_E, + col_partial_sum_trans, + TK, + E, + n_tiles, + BLOCK_SIZE=BLOCK_SIZE, + E_POW2=E_POW2, + ) + + expert_frequency.copy_(col_partial_sum_trans.sum(dim=1, dtype=torch.int32)) + col_partial_sum = col_partial_sum_trans.T # [n_tiles, E], strides (1, n_tiles) + + # ── Kernel 2: stage1 ───────────────────────────────────────────────── + _bitmatrix_metadata_compute_stage1[(E + 2,)]( + expert_frequency, + expert_frequency_offset, + E, + col_partial_sum, + n_tiles, + TK, + BLOCK_M=128, + BLOCK_N=E_POW2, + ) + + # ── Kernel 3: stage2 ───────────────────────────────────────────────── + _general_metadata_compute_stage2[(n_tiles,)]( + 
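# same expert-sort and scatter as the top-k path, but over a flat list of (token, expert) assignments +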
s_scatter_idx, + s_reverse_scatter_idx, + x_gather_idx, + selected_E, + sorted_selected_T, + TK, + col_partial_sum, + n_tiles, + expert_frequency_offset[:E], + BLOCK_SIZE=BLOCK_SIZE, + ) + + # ── Kernel 4: num_activated_expert_per_token_offset via searchsorted ── + # sorted_selected_T is sorted ascending, so offset[t] = searchsorted_left(sorted_T, t). + # Parallel binary search: each thread handles one token index, O(log TK) work. + N_ITERS = max(1, math.ceil(math.log2(TK + 1))) + TOKEN_BLOCK = 1024 + n_token_blocks = triton.cdiv(T + 1, TOKEN_BLOCK) + _token_offset_searchsorted_kernel[(n_token_blocks,)]( + sorted_selected_T, + num_activated_expert_per_token_offset, + T, + TK, + BLOCK_SIZE=TOKEN_BLOCK, + N_ITERS=N_ITERS, + ) diff --git a/sonic-moe/torch-ext/sonicmoe/functional/triton_kernels/bitmatrix.py b/sonic-moe/torch-ext/sonicmoe/functional/triton_kernels/bitmatrix.py new file mode 100644 index 00000000..1b010806 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/functional/triton_kernels/bitmatrix.py @@ -0,0 +1,147 @@ +import triton +import triton.language as tl + + +# https://github.com/triton-lang/triton/blob/434aecbe933af6a8d49595d4197bfc3df7618748/python/triton_kernels/triton_kernels/tensor_details/bitmatrix.py#L33 +@triton.jit +def _keyed_add(x, y): + # we keep the key in the upper 16 bits of a uint32: + key_mask: tl.constexpr = 0xFFFF0000 + + kx = x & key_mask + ky = y & key_mask + z = tl.where(kx == ky, x + y - kx, y) + return z + + +# Adapted from https://github.com/triton-lang/triton/blob/434aecbe933af6a8d49595d4197bfc3df7618748/python/triton_kernels/triton_kernels/tensor_details/bitmatrix.py#L44 +@triton.jit +def _bitmatrix_metadata_compute_stage1( + expert_freq_ptr, + expert_freq_offs_ptr, + E: tl.constexpr, + partial_sum_ptr, + n_tiles, + TK, + BLOCK_M: tl.constexpr, # chunk size for iterating over tiles per expert + BLOCK_N: tl.constexpr, # chunk size for iterating over experts in cumsum +): + # Assume grid size == E + 1 + + pid = tl.program_id(0) + if pid < E: + # convert partial_sum[e, *] from raw counts to exclusive prefix + # sums over tiles. After this kernel, partial_sum[e, t] = + # number of entries for expert e in tiles 0..t-1. + + # This is read by stage2 to locate each entry's position within expert e's contiguous output segment. + expert_partial_sum_ptr = partial_sum_ptr + pid * n_tiles + curr_sum = 0 + for start in range(0, n_tiles, BLOCK_M): + offs = start + tl.arange(0, BLOCK_M) + tile_counts = tl.load(expert_partial_sum_ptr + offs, mask=offs < n_tiles, other=0) + excl_cumsum = tl.cumsum(tile_counts, 0) - tile_counts + curr_sum + curr_sum += tl.sum(tile_counts, 0) + tl.store(expert_partial_sum_ptr + offs, excl_cumsum, mask=offs < n_tiles) + elif pid == E: + # Exclusive prefix sum of per-expert total counts → expert_offs[e]. + # expert_freq_offset[e] = total entries routed to expert e (from A.sum(dim=1)). + # expert_offs[e] = sum of expert_freq_offset[0..e-1] = global start of expert e. 
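+ # e.g. per-expert totals [3, 1, 4] -> expert_offs [0, 3, 4]; curr_sum carries the running total across BLOCK_N-sized chunks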
+ curr_sum = 0 + for start in tl.static_range(0, E, BLOCK_N): + offs = start + tl.arange(0, BLOCK_N) + expert_freq = tl.load(expert_freq_ptr + offs, mask=offs < E, other=0) + excl_cumsum = tl.cumsum(expert_freq, 0) - expert_freq + curr_sum + curr_sum += tl.sum(expert_freq, 0) + tl.store(expert_freq_offs_ptr + offs, excl_cumsum, mask=offs < E) + elif pid == E + 1: + # expert_freq_off[E] = TK (total number of entries) + tl.store(expert_freq_offs_ptr + E, TK) + + +# Adapted from https://github.com/triton-lang/triton/blob/434aecbe933af6a8d49595d4197bfc3df7618748/python/triton_kernels/triton_kernels/tensor_details/bitmatrix.py#L44 +@triton.jit +def _bitmatrix_metadata_compute_stage2( + s_scatter_idx_ptr, + s_reverse_scatter_idx_ptr, + x_gather_idx_ptr, + topk_indices_ptr, + T, + partial_sum_ptr, + n_tiles, + expert_offs_ptr, + K_POW2: tl.constexpr, # padded K, == BLOCK_SIZE / BLOCK + K: tl.constexpr, # actual experts per token + TOKENS_PER_BLOCK: tl.constexpr, # tokens per tile +): + # One CTA per tile, same tiling as _compute_col_partial_sum_kernel. + # For each entry (token t, k-slot k) in this tile: + # s_reverse_scatter_idx[entry_idx] = output position in expert-sorted order + # s_scatter_idx[output_pos] = entry_idx (inverse permutation) + # x_gather_idx[output_pos] = token index (= entry_idx // K) + # + # Output position = expert_offs[e] (global start of expert e) + # + partial_sum[tile, e] (entries for e in earlier tiles, after stage1) + # + within_expert_rank (position within this tile's group for e) + BLOCK_SIZE: tl.constexpr = TOKENS_PER_BLOCK * K_POW2 + IS_POW2_K: tl.constexpr = K == K_POW2 # fast path: no padding waste + tl.static_assert(BLOCK_SIZE <= 32768) + + pid_m = tl.program_id(0) + offs_local = tl.arange(0, BLOCK_SIZE) # position within this tile's flat [BLOCK*K_POW2] space + offs_global = pid_m * BLOCK_SIZE + offs_local + mask = offs_global < T * K_POW2 + + # Load expert id for each slot. IS_POW2_K fast path reads topk_indices as a + # flat 1D array (no padding gaps). Non-pow2 path reads 2D with k_slot masking. + if IS_POW2_K: + expert = tl.load(topk_indices_ptr + offs_global, mask=mask, other=-1).to(tl.uint32) + else: + token_i_local = offs_local // K_POW2 + k_slot = offs_local % K_POW2 + token_i_global = pid_m * TOKENS_PER_BLOCK + token_i_local + load_mask = mask & (k_slot < K) + safe_k = tl.minimum(k_slot, K - 1) + expert = tl.load( + topk_indices_ptr + token_i_global * K + safe_k, + mask=load_mask, + other=-1, + ).to(tl.uint32) + + # Pack (expert, presort_offs) into a uint32 kv pair and sort by expert. + # Upper 16 bits = expert id (sort key), lower 16 bits = pre-sort local offset. + # Invalid slots have expert=0xffff (from other=-1 cast to uint32 >> 16). + kv_pairs = tl.sort(((expert << 16) | offs_local).to(tl.uint32), 0) + expert = kv_pairs >> 16 + mask = expert != 0xFFFF # exclude padding/OOB slots + + # Segmented scan to compute within-expert rank (0-based exclusive count). + # scan_input packs expert id in upper 16 bits and count=1 in lower 16 bits. + # _keyed_add resets the count at each expert boundary. + scan_input = (kv_pairs & 0xFFFF0000) | 0x00000001 + inclusive_run_lengths = tl.associative_scan(scan_input, 0, _keyed_add) + within_expert_rank = (inclusive_run_lengths - 1) & 0xFFFF # exclusive = inclusive - 1 + + # Output position for this entry in the expert-sorted output array. + # partial_sum layout after stage1: [n_tiles, E], stride (1, n_tiles). + # So partial_sum[pid_m, expert] = partial_sum_ptr + pid_m*1 + expert*n_tiles. 
+ s_reverse_scatter_idx = tl.load(partial_sum_ptr + pid_m + expert * n_tiles, mask=mask) + s_reverse_scatter_idx += tl.load(expert_offs_ptr + expert, mask=mask) + s_reverse_scatter_idx += within_expert_rank + + if IS_POW2_K: + # presort_offs == offs_local before sort; entry_idx is the flat index into + # topk_router_indices.view(-1), i.e. token * K + k_slot. + presort_offs = kv_pairs & 0xFFFF + entry_idx = pid_m * BLOCK_SIZE + presort_offs + tl.store(s_reverse_scatter_idx_ptr + entry_idx, s_reverse_scatter_idx, mask=mask) + tl.store(s_scatter_idx_ptr + s_reverse_scatter_idx, entry_idx, mask=mask) + tl.store(x_gather_idx_ptr + s_reverse_scatter_idx, entry_idx // K_POW2, mask=mask) + else: + # presort_offs is in K_POW2-padded space; convert to unpadded entry_idx. + presort_offs = kv_pairs & 0xFFFF + token_i_global_s = pid_m * TOKENS_PER_BLOCK + presort_offs // K_POW2 + entry_idx = token_i_global_s * K + presort_offs % K_POW2 + tl.store(s_reverse_scatter_idx_ptr + entry_idx, s_reverse_scatter_idx, mask=mask) + tl.store(s_scatter_idx_ptr + s_reverse_scatter_idx, entry_idx, mask=mask) + tl.store(x_gather_idx_ptr + s_reverse_scatter_idx, token_i_global_s, mask=mask) diff --git a/sonic-moe/torch-ext/sonicmoe/functional/utils.py b/sonic-moe/torch-ext/sonicmoe/functional/utils.py new file mode 100644 index 00000000..94c15c44 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/functional/utils.py @@ -0,0 +1,25 @@ +# ******************************************************************************** +# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao +# ******************************************************************************** + +import os +from contextlib import contextmanager + + +_IS_USING_QUACK_GEMM = os.getenv("USE_QUACK_GEMM", "0") == "1" + + +@contextmanager +def enable_quack_gemm(enable: bool = True): + global _IS_USING_QUACK_GEMM + + previous_value = _IS_USING_QUACK_GEMM + _IS_USING_QUACK_GEMM = enable + + yield + + _IS_USING_QUACK_GEMM = previous_value + + +def is_using_quack_gemm() -> bool: + return _IS_USING_QUACK_GEMM diff --git a/sonic-moe/torch-ext/sonicmoe/include/utils.h b/sonic-moe/torch-ext/sonicmoe/include/utils.h new file mode 100644 index 00000000..b0594df1 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/include/utils.h @@ -0,0 +1,45 @@ +// ******************************************************************************** +// Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao +// ******************************************************************************** + +// basic CUDA launch utils copied from https://github.com/open-lm-engine/accelerated-model-architectures/ + +#include +#include +#include + +#define CHECK_CUDA_TENSOR(x) TORCH_CHECK(x.device().is_cuda(), #x " must be a CUDA tensor") +#define CHECK_CONTIGUOUS_TENSOR(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_CONTIGUOUS_CUDA_TENSOR(x) \ + CHECK_CUDA_DEVICE(x); \ + CHECK_CONTIGUOUS(x) + +#define DISPATCH_CASE(ENUM_TYPE, SCALAR_NAME, ...) AT_PRIVATE_CASE_TYPE_USING_HINT(ENUM_TYPE, SCALAR_NAME, __VA_ARGS__) + +#define DISPATCH_FLOAT_KERNEL(TYPE, NAME, SCALAR_NAME, ...) \ + AT_DISPATCH_SWITCH(TYPE, \ + NAME, \ + DISPATCH_CASE(at::ScalarType::Half, SCALAR_NAME, __VA_ARGS__) \ + DISPATCH_CASE(at::ScalarType::BFloat16, SCALAR_NAME, __VA_ARGS__) \ + DISPATCH_CASE(at::ScalarType::Float, SCALAR_NAME, __VA_ARGS__)) +#define DISPATCH_INT_KERNEL(TYPE, NAME, SCALAR_NAME, ...) 
\ + AT_DISPATCH_SWITCH(TYPE, \ + NAME, \ + DISPATCH_CASE(at::ScalarType::Int, SCALAR_NAME, __VA_ARGS__) \ + DISPATCH_CASE(at::ScalarType::UInt32, SCALAR_NAME, __VA_ARGS__) \ + DISPATCH_CASE(at::ScalarType::Long, SCALAR_NAME, __VA_ARGS__)) + +template +inline __device__ T *load_128_bits(const T *array, const uint64_t &index) { + const int4 *vector_array = reinterpret_cast(array); + int4 vector_element = vector_array[index]; + T *output = reinterpret_cast(&vector_element); + return output; +} + +template +inline __device__ void store_128_bits(const T *source, T *destination, const uint64_t &index) { + int4 *destination_vector_array = reinterpret_cast(destination); + const int4 source_vector = reinterpret_cast(&source[0])[0]; + destination_vector_array[index] = source_vector; +} diff --git a/sonic-moe/torch-ext/sonicmoe/jit.py b/sonic-moe/torch-ext/sonicmoe/jit.py new file mode 100644 index 00000000..d4b931a7 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/jit.py @@ -0,0 +1,159 @@ +# ******************************************************************************** +# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao +# ******************************************************************************** + +import inspect +import os +from shutil import rmtree +from typing import Callable +from uuid import uuid4 + +import torch +from torch.utils.cpp_extension import load as load_cpp_extension + + +_CPP_MODULE_PREFIX = "sonicmoe" +_GLOBAL_RANK = int(os.getenv("RANK", 0)) +_WORLD_SIZE = int(os.getenv("WORLD_SIZE", 1)) + +_ALL_COMPILED_MODULES = {} + + +@torch.compiler.disable +def _get_cpp_function(function_name: str, module_name: str, source_files: list[str], build_directory: str) -> Callable: + module_name = f"{_CPP_MODULE_PREFIX}_{module_name}" + + extra_cflags = ["-O3", "-Wall", "-shared", "-fPIC", "-fdiagnostics-color"] + extra_cuda_cflags = ["-O3", "-lineinfo"] + extra_include_paths = [ + os.path.dirname(__file__), # sonicmoe/include + os.path.dirname(os.path.dirname(__file__)) + "/cutlass/include", # cutlass + os.path.dirname(os.path.dirname(__file__)) + "/cutlass/tools/util/include", # cutlass + ] + + module = _ALL_COMPILED_MODULES.get(module_name, None) + + if module is None: + if torch.distributed.is_initialized(): + os.makedirs(build_directory, exist_ok=True) + + if _GLOBAL_RANK == 0: + module = load_cpp_extension( + module_name, + sources=source_files, + with_cuda=True, + extra_cflags=extra_cflags, + extra_cuda_cflags=extra_cuda_cflags, + extra_include_paths=extra_include_paths, + build_directory=build_directory, + verbose=True, + ) + + torch.distributed.barrier() + + if _GLOBAL_RANK != 0: + module = load_cpp_extension( + module_name, + sources=source_files, + with_cuda=True, + extra_cflags=extra_cflags, + extra_cuda_cflags=extra_cuda_cflags, + extra_include_paths=extra_include_paths, + build_directory=build_directory, + verbose=False, + ) + else: + if _WORLD_SIZE > 1: + build_directory = os.path.join(build_directory, str(uuid4())) + + os.makedirs(build_directory, exist_ok=True) + + module = load_cpp_extension( + module_name, + sources=source_files, + with_cuda=True, + extra_cflags=extra_cflags, + extra_cuda_cflags=extra_cuda_cflags, + extra_include_paths=extra_include_paths, + build_directory=build_directory, + verbose=True, + ) + + if _WORLD_SIZE > 1: + rmtree(build_directory, ignore_errors=True) + + _ALL_COMPILED_MODULES[module_name] = module + + return getattr(module, function_name) + + +def cpp_jit( + function_name: str | None = None, + 
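# note: the shared default [] is only read (copied into a fresh list inside cpp_jit), never mutated +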
extra_source_files: list[str] = [], + build_directory: str | None = None, + depth: int = 0, +) -> Callable: + """wrapper to compile C++/CUDA source code at runtime. + + Args: + function_name (str | None, optional): name of the function to expose from the C++ file, the python function + name should match the funcion name in the C++ file if this is not specified. Defaults to None. + extra_source_files (list[str], optional): any extra files to use for compilation, by default it scans the + directory of the python stub file. Defaults to []. + build_directory (str | None, optional): directory in which to place the build artifacts. Defaults to None. + depth (int, optional): number of times dirname is called to get the build path. Defaults to 2. + + Returns: + Callable: returns the wrapped function that can be used to call the C++ functions from python + """ + cpp_function = None + args_spec = None + + source_files = [] + source_files.extend(extra_source_files) + + calling_filename = inspect.stack()[1].filename + calling_directory = os.path.dirname(calling_filename) + + for dirname, _, filenames in os.walk(calling_directory): + filenames = [os.path.join(dirname, f) for f in filenames] + filenames = filter(lambda f: os.path.splitext(f)[1] in [".cu", ".cpp"], filenames) + source_files.extend(filenames) + + if build_directory is None: + module_name = calling_directory + for _ in range(depth): + module_name = os.path.dirname(module_name) + module_name = os.path.basename(module_name) + + build_directory = os.path.join(os.path.dirname(os.path.dirname(__file__)), "build", module_name) + + def _run(*args, **kwargs): + nonlocal cpp_function + + if cpp_function is None: + cpp_function = _get_cpp_function( + function_name=_run.__name__, + module_name=module_name, + source_files=source_files, + build_directory=build_directory, + ) + + full_args = [] + full_args.extend(args) + for variable_name in args_spec.args[len(args) :]: + full_args.append(kwargs[variable_name]) + + return cpp_function(*full_args) + + def _wrapper(function: Callable) -> Callable: + nonlocal args_spec + args_spec = inspect.getfullargspec(function) + + _run.__doc__ = function.__doc__ + _run.__name__ = function.__name__ if function_name is None else function_name + _run.__signature__ = inspect.signature(function) + + return _run + + return _wrapper diff --git a/sonic-moe/torch-ext/sonicmoe/moe.py b/sonic-moe/torch-ext/sonicmoe/moe.py new file mode 100644 index 00000000..b6b3da11 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/moe.py @@ -0,0 +1,368 @@ +# ******************************************************************************** +# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao +# ******************************************************************************** + +from typing import Callable + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .enums import ActivationType, KernelBackendMoE, is_glu +from .functional import moe_TC_softmax_topk_layer + + +try: + from xma.modules.moe import scattered_experts + + _IS_XMA_AVAILABLE = True +except ImportError: + _IS_XMA_AVAILABLE = False + + +def _swiglu(x: torch.Tensor) -> torch.Tensor: + u = x[..., 1::2] + g = x[..., ::2] + return u * F.silu(g) + + +def _geglu(x: torch.Tensor) -> torch.Tensor: + u = x[..., 1::2] + g = x[..., ::2] + return (F.gelu(g.to(dtype=torch.float32)) * u).to(dtype=g.dtype) + + +def _gelu(x: torch.Tensor) -> torch.Tensor: + return F.gelu(x.to(dtype=torch.float32)).to(dtype=x.dtype) + + +def _reglu(x: 
torch.Tensor) -> torch.Tensor: + u = x[..., 1::2] + g = x[..., ::2] + return (F.relu(g) * u).to(dtype=g.dtype) + + +def _relu(x: torch.Tensor) -> torch.Tensor: + return F.relu(x) + + +def _relu_sq(x: torch.Tensor) -> torch.Tensor: + return F.relu(x) ** 2 + + +def _silu(x: torch.Tensor) -> torch.Tensor: + return F.silu(x) + + +class Experts(nn.Module): + def __init__( + self, num_experts: int, in_features: int, out_features: int, add_bias: bool = True, std: float | None = None + ) -> None: + super().__init__() + + self.weight = nn.Parameter(torch.empty(num_experts, out_features, in_features)) + + self.bias = None + if add_bias: + self.bias = nn.Parameter(torch.empty(num_experts, out_features)) + + self.std = std + + self.num_experts = num_experts + self.in_features = in_features + self.out_features = out_features + + self.reset_parameters() + + def up_projection_scattermoe_forward( + self, + input: torch.Tensor, + num_experts_per_token: int | None = None, + sorted_expert_idxs: torch.Tensor | None = None, + sorted_scattered_idxs: torch.Tensor | None = None, + expert_offsets: torch.Tensor | None = None, + ) -> torch.Tensor: + assert self.bias is None + + if not _IS_XMA_AVAILABLE: + raise ImportError( + "install accelerated-model-architectures from https://github.com/open-lm-engine/accelerated-model-architectures" + ) + + input = scattered_experts( + inputs=input, + expert_weights=self.weight.permute(0, 2, 1), + k=num_experts_per_token, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + expert_offsets=expert_offsets, + gates=None, + grouped_in=False, + grouped_out=True, + ) + + return input + + def down_projection_scattermoe_forward( + self, + input: torch.Tensor, + num_experts_per_token: int | None = None, + sorted_expert_idxs: torch.Tensor | None = None, + sorted_scattered_idxs: torch.Tensor | None = None, + expert_offsets: torch.Tensor | None = None, + gates: torch.Tensor | None = None, + ) -> torch.Tensor: + assert self.bias is None + + if not _IS_XMA_AVAILABLE: + raise ImportError( + "install accelerated-model-architectures from https://github.com/open-lm-engine/accelerated-model-architectures" + ) + + input = scattered_experts( + inputs=input, + expert_weights=self.weight.permute(0, 2, 1), + k=num_experts_per_token, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + expert_offsets=expert_offsets, + gates=gates, + grouped_in=True, + grouped_out=False, + ) + + return input + + def torch_forward( + self, input: torch.Tensor, expert_frequency: torch.Tensor | None, return_list: bool = False + ) -> list[torch.Tensor] | torch.Tensor: + if isinstance(input, torch.Tensor): + input = input.split(expert_frequency.tolist(), dim=0) + else: + assert expert_frequency is None + + input = [ + F.linear(input[i], self.weight[i], None if self.bias is None else self.bias[i]) + for i in range(self.num_experts) + ] + + if not return_list: + input = torch.cat(input, dim=0) + + return input + + def extra_repr(self): + return "num_experts={}, in_features={}, out_features={}".format( + self.num_experts, self.in_features, self.out_features + ) + + @torch.no_grad() + def reset_parameters(self) -> None: + nn.init.normal_(self.weight, mean=0, std=self.std) + if hasattr(self, "bias") and self.bias is not None: + self.bias.zero_() + + +class MoE(nn.Module): + def __init__( + self, + num_experts: int, + num_experts_per_tok: int, + hidden_size: int, + intermediate_size: int, + activation_function: ActivationType, + add_bias: bool, + std: float, 
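+ # std: stddev for the normal initialization of the expert weight matrices (see Experts.reset_parameters)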
+ ) -> None: + super().__init__() + + self.num_experts = num_experts + self.top_k = num_experts_per_tok + + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + + self.router = nn.Linear(in_features=self.hidden_size, out_features=num_experts, bias=False) + + self.activation_function = activation_function + + self.c_fc = Experts( + num_experts=num_experts, + in_features=self.hidden_size, + out_features=2 * self.intermediate_size if is_glu(activation_function) else self.intermediate_size, + add_bias=add_bias, + std=std, + ) + + self.c_proj = Experts( + num_experts=num_experts, + in_features=self.intermediate_size, + out_features=self.hidden_size, + add_bias=add_bias, + std=std, + ) + + self.stream_id = torch.cuda.current_stream().cuda_stream + + def forward( + self, + hidden_states: torch.Tensor, + kernel_backend_moe: KernelBackendMoE = KernelBackendMoE.sonicmoe, + is_inference_mode: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: + original_shape = hidden_states.shape + + # hidden_states -> (batch_size, query_length, hidden_size) + hidden_states = hidden_states.view(-1, self.hidden_size) + + if kernel_backend_moe == KernelBackendMoE.sonicmoe and self.num_experts <= 32768: + hidden_states, router_logits, expert_frequency = moe_TC_softmax_topk_layer( + hidden_states, + self.router.weight, + self.c_fc.weight.permute(1, 2, 0), + self.c_fc.bias, + self.c_proj.weight.permute(1, 2, 0), + self.c_proj.bias, + self.top_k, + self.stream_id, + self.activation_function, + is_inference_mode or not self.training, + ) + else: + # hidden_states -> (total_q, hidden_size) + router_logits, router_weights, selected_experts = self._compute_routing_weights(hidden_states) + + # router_logits -> (total_q, num_experts) + # router_weights -> (total_q, top_k) + # selected_experts -> (total_q, top_k) + + hidden_states, expert_frequency = self._compute_experts( + hidden_states, + router_weights, + selected_experts, + kernel_backend_moe=kernel_backend_moe, + ) + + hidden_states = hidden_states.view(original_shape) + + # hidden_states -> (batch_size, query_length, hidden_size) + + if is_inference_mode: + aux_loss = None + else: + aux_loss = self._compute_switch_loss( + logits=router_logits, + probs=F.softmax(router_logits, dim=-1, dtype=torch.float32), + expert_frequency=expert_frequency, + ) + + return hidden_states, aux_loss + + # copied from https://github.com/open-lm-engine/lm-engine/blob/1447883df709727839bbbb367ce727fa56962a6a/lm_engine/hf_models/modeling_utils/mlp_blocks/moe.py#L432-L455 + # NOTE we don't do all_reduce here for expert frequency for simplicity across data parallel workers + def _compute_switch_loss( + self, logits: torch.Tensor, probs: torch.Tensor, expert_frequency: torch.Tensor + ) -> torch.Tensor: + logits = logits.view(-1, logits.size(-1)) + probs = probs.view(-1, probs.size(-1)) + + num_experts = logits.size(1) + acc_probs = probs.sum(0) + + expert_frequency = expert_frequency.float() + + aux_loss = num_experts * (F.normalize(acc_probs, p=1, dim=0) * F.normalize(expert_frequency, p=1, dim=0)).sum() + + return aux_loss + + def _compute_routing_weights(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor]: + # hidden_states -> (total_q, hidden_size) + router_logits = self.router(hidden_states) + # router_logits -> (total_q, num_experts) + + router_weights, selected_experts = self._get_topk(router_logits) + + # router_weights -> (total_q, top_k) + # selected_experts -> (total_q, top_k) + + router_weights = F.softmax(router_weights.float(), dim=-1) + 
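# softmax over only the k selected logits, so each token's routing weights sum to 1 across its chosen experts +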
router_weights = router_weights.type_as(hidden_states) + + return router_logits, router_weights, selected_experts + + def _compute_experts( + self, + hidden_states: torch.Tensor, + router_weights: torch.Tensor, + selected_experts: torch.Tensor, + kernel_backend_moe: KernelBackendMoE, + ) -> tuple[torch.Tensor, torch.Tensor]: + selected_experts = selected_experts.flatten() + + with torch.no_grad(): + sorted_expert_idxs, sorted_scattered_idxs = selected_experts.sort() + + expert_frequency = selected_experts.bincount(minlength=self.num_experts).to(torch.int32) + expert_offsets = expert_frequency.cumsum(-1).to(torch.int32) + + act_func = { + ActivationType.SWIGLU: _swiglu, + ActivationType.GEGLU: _geglu, + ActivationType.REGLU: _reglu, + ActivationType.GELU: _gelu, + ActivationType.RELU: _relu, + ActivationType.SILU: _silu, + ActivationType.RELU_SQ: _relu_sq, + }[self.activation_function] + + T = hidden_states.size(0) + + if kernel_backend_moe == KernelBackendMoE.scattermoe: + hidden_states = self.c_fc.up_projection_scattermoe_forward( + input=hidden_states, + num_experts_per_token=self.top_k, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + expert_offsets=expert_offsets, + ) + hidden_states = act_func(hidden_states) + hidden_states = self.c_proj.down_projection_scattermoe_forward( + input=hidden_states, + num_experts_per_token=1, + sorted_expert_idxs=sorted_expert_idxs, + sorted_scattered_idxs=sorted_scattered_idxs, + expert_offsets=expert_offsets, + gates=router_weights, + ) + elif kernel_backend_moe == KernelBackendMoE.torch: + # sort and group input tokens according to expert assignment + fan_in_index = sorted_scattered_idxs // self.top_k + + # gather the gate values for grouped input tokens + router_weights = router_weights.flatten() + batch_gates = router_weights[sorted_scattered_idxs] + + hidden_states = hidden_states[fan_in_index] + + hidden_states = self.c_fc.torch_forward( + input=hidden_states, expert_frequency=expert_frequency, return_list=True + ) + + hidden_states = [act_func(i) for i in hidden_states] + hidden_states = self.c_proj.torch_forward(input=hidden_states, expert_frequency=None, return_list=False) + + hidden_states = hidden_states * batch_gates.unsqueeze(-1) + zeros = torch.zeros((T, self.hidden_size), dtype=torch.float32, device=hidden_states.device) + hidden_states = zeros.index_add(0, fan_in_index, hidden_states) + else: + raise ValueError(f"unexpected kernel_backend_moe ({kernel_backend_moe})") + + return hidden_states, expert_frequency + + def _get_topk(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + if self.top_k == 1: + x, indices = x.max(dim=-1, keepdim=True) + else: + x, indices = x.topk(self.top_k, dim=-1) + + return x, indices diff --git a/sonic-moe/torch-ext/sonicmoe/quack/__init__.py b/sonic-moe/torch-ext/sonicmoe/quack/__init__.py new file mode 100644 index 00000000..b614e9f5 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/__init__.py @@ -0,0 +1,8 @@ +__version__ = "0.2.5" + +import os + +if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None: + from . 
import cute_dsl_ptxas + + cute_dsl_ptxas.patch() diff --git a/sonic-moe/torch-ext/sonicmoe/quack/_ops_compat.py b/sonic-moe/torch-ext/sonicmoe/quack/_ops_compat.py new file mode 100644 index 00000000..9ce465bd --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/_ops_compat.py @@ -0,0 +1,4 @@ +from .._ops_compat import add_op_namespace_prefix + +def add_quack_op_namespace_prefix(name: str) -> str: + return add_op_namespace_prefix(f"quack__{name}") diff --git a/sonic-moe/torch-ext/sonicmoe/quack/activation.py b/sonic-moe/torch-ext/sonicmoe/quack/activation.py new file mode 100644 index 00000000..96552e2e --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/activation.py @@ -0,0 +1,524 @@ +# Copyright (c) 2025, Tri Dao. + +import math +from typing import Tuple + +import cutlass.cute as cute +from cutlass import Float32, Boolean, const_expr +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm + +from . import utils as utils + + +F32_or_F32x2 = Float32 | Tuple[Float32, Float32] + + +@dsl_user_op +def tanh(a: float | Float32, *, loc=None, ip=None) -> Float32: + return Float32( + llvm.inline_asm( + T.f32(), + [Float32(a).ir_value(loc=loc, ip=ip)], + "tanh.approx.f32 $0, $1;", + "=f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def sigmoid(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + # return 0.5 + 0.5 * cute.math.tanh(0.5 * x, fastmath=True) + return 0.5 + 0.5 * tanh(0.5 * x) + else: + x_half = utils.mul_packed_f32x2((0.5, 0.5), x) + tanh_x_half = (tanh(x_half[0]), tanh(x_half[1])) + return utils.fma_packed_f32x2(tanh_x_half, (0.5, 0.5), (0.5, 0.5)) + + +@dsl_user_op +def dsigmoid_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32: + # return dout * out * (1.0 - out) + return dout * (out - out * out) + + +@dsl_user_op +def relu(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + return cute.arch.fmax(x, Float32(0.0)) + else: + return cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0)) + + +@dsl_user_op +@cute.jit +def drelu( + x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2]: + if const_expr(not isinstance(x, tuple)): + x_pos = Boolean(x > 0) + return dout if x_pos else Float32(0.0), cute.arch.fmax(x, Float32(0.0)) + else: + x0_pos = Boolean(x[0] > 0) + x1_pos = Boolean(x[1] > 0) + dx = (dout[0] if x0_pos else Float32(0.0), dout[1] if x1_pos else Float32(0.0)) + return dx, relu(x) + + +@dsl_user_op +def relu_sq(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + return cute.arch.fmax(x, Float32(0.0)) * x + else: + relu_x = (cute.arch.fmax(x[0], Float32(0.0)), cute.arch.fmax(x[1], Float32(0.0))) + return utils.mul_packed_f32x2(relu_x, x) + + +@dsl_user_op +@cute.jit +def drelu_sq( + x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2]: + """ + ReLU squared backward pass: computes gradient w.r.t. x and recomputes forward + Given: relu_sq_out = max(x, 0) * x, and dout = grad w.r.t. 
relu_sq_out + Returns: (dx, relu_sq_out) where: + - dx = dout * 2 * x if x > 0, else 0 + - relu_sq_out = max(x, 0) * x + """ + if const_expr(not isinstance(x, tuple)): + relu_x = relu(x) + relu_sq_out = relu_x * x + # Derivative: d/dx[max(x,0) * x] = 2*x if x > 0, else 0 + dx = 2.0 * (dout * relu_x) + return dx, relu_sq_out + else: + relu_x = relu(x) + relu_sq_out = utils.mul_packed_f32x2(relu_x, x) + dx = utils.mul_packed_f32x2((2.0, 2.0), utils.mul_packed_f32x2(dout, relu_x)) + return dx, relu_sq_out + + +@dsl_user_op +def gelu_tanh_approx(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + """ + gelu(x) = 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))) + = 0.5 * x * (1 + tanh(x * (0.797885 + 0.0356774 * x * x))) + """ + sqrt_2_over_pi = math.sqrt(2 / math.pi) # ~0.797885 + sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi # ~0.0356774 + if const_expr(not isinstance(x, tuple)): + return 0.5 * ( + x + # Currently cute.math.tanh(x, fastmath=True) generates very slow code + # * (1 + cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)), fastmath=True)) + * (1.0 + tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * (x * x)))) + ) + else: + x_sq = utils.mul_packed_f32x2(x, x) + x_sq_scaled = utils.fma_packed_f32x2( + x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi) + ) + z = utils.mul_packed_f32x2(x, x_sq_scaled) + tanh_z = (tanh(z[0]), tanh(z[1])) + x_tanh_z = utils.fma_packed_f32x2(tanh_z, x, x) + return utils.mul_packed_f32x2((0.5, 0.5), x_tanh_z) + + +@dsl_user_op +def dgelu_tanh_approx( + x: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2]: + """ + GELU tanh approximation backward pass: computes gradient w.r.t. x and recomputes forward + Given: gelu_out = 0.5 * x * (1 + tanh(x * (c1 + c2 * x^2))), and dout = grad w.r.t. 
gelu_out + Returns: (dx, gelu_out) + + Derivative uses the chain rule: + d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx + where z = x * (c1 + c2 * x^2), dz/dx = c1 + 3 * c2 * x^2 + and sech^2(z) = 1 - tanh^2(z) + """ + sqrt_2_over_pi = math.sqrt(2 / math.pi) # c1 ~0.797885 + sqrt_2_over_pi_coeff = 0.044715 * sqrt_2_over_pi # c2 ~0.0356774 + sqrt_2_over_pi_coeff_3 = 3.0 * sqrt_2_over_pi_coeff # c3 ~0.01070322 + + if const_expr(not isinstance(x, tuple)): + # Compute z = x * (c1 + c2 * x^2) + x_sq = x * x + # tanh_z = cute.math.tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq), fastmath=True) + tanh_z = tanh(x * (sqrt_2_over_pi + sqrt_2_over_pi_coeff * x_sq)) + half_tanh_z_plus_one = 0.5 + 0.5 * tanh_z + gelu_out = x * half_tanh_z_plus_one + + # Compute gradient + # sech^2(z) = 1 - tanh^2(z) + sech2_z = 1 - tanh_z * tanh_z + # dz/dx = c1 + 3 * c2 * x^2 + dz_dx = sqrt_2_over_pi + sqrt_2_over_pi_coeff_3 * x_sq + # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx + dgelu = half_tanh_z_plus_one + x * (0.5 * (sech2_z * dz_dx)) + + dx = dout * dgelu + return dx, gelu_out + else: + # Compute z = x * (c1 + c2 * x^2) + x_sq = utils.mul_packed_f32x2(x, x) + x_sq_scaled = utils.fma_packed_f32x2( + x_sq, (sqrt_2_over_pi_coeff, sqrt_2_over_pi_coeff), (sqrt_2_over_pi, sqrt_2_over_pi) + ) + z = utils.mul_packed_f32x2(x, x_sq_scaled) + tanh_z = (tanh(z[0]), tanh(z[1])) + half_tanh_z_plus_one = utils.fma_packed_f32x2(tanh_z, (0.5, 0.5), (0.5, 0.5)) + gelu_out = utils.mul_packed_f32x2(x, half_tanh_z_plus_one) + + # Compute gradient + # sech^2(z) = 1 - tanh^2(z) + sech2_z = utils.fma_packed_f32x2(tanh_z, (-tanh_z[0], -tanh_z[1]), (1.0, 1.0)) + # dz/dx = c1 + 3 * c2 * x^2 + dz_dx = utils.fma_packed_f32x2( + x_sq, (sqrt_2_over_pi_coeff_3, sqrt_2_over_pi_coeff_3), (sqrt_2_over_pi, sqrt_2_over_pi) + ) + # d/dx[gelu(x)] = 0.5 * (1 + tanh(z)) + 0.5 * x * sech^2(z) * dz/dx + sech2_dz_dx = utils.mul_packed_f32x2(sech2_z, dz_dx) + x_sech2_dz_dx = utils.mul_packed_f32x2(x, sech2_dz_dx) + dgelu = utils.fma_packed_f32x2(x_sech2_dz_dx, (0.5, 0.5), half_tanh_z_plus_one) + + dx = utils.mul_packed_f32x2(dout, dgelu) + return dx, gelu_out + + +@dsl_user_op +@cute.jit +def softplus(x: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + use_linear = Boolean(x > 20.0) + return ( + cute.math.log(Float32(cute.math.exp(x, fastmath=True)) + 1.0, fastmath=True) + if not use_linear + else x + ) + else: + log2_e = math.log2(math.e) + x_log2e = utils.mul_packed_f32x2(x, (log2_e, log2_e)) + x_exp = (cute.math.exp(x_log2e[0], fastmath=True), cute.math.exp(x_log2e[1], fastmath=True)) + x_exp_p1 = utils.add_packed_f32x2(x_exp, (1.0, 1.0)) + log_x_exp_p1 = ( + cute.math.log2(x_exp_p1[0], fastmath=True), + cute.math.log2(x_exp_p1[1], fastmath=True), + ) + ln2 = math.log(2.0) + softplus_x = utils.mul_packed_f32x2(log_x_exp_p1, (ln2, ln2)) + use_linear_0 = Boolean(x[0] > 20.0) + use_linear_1 = Boolean(x[1] > 20.0) + return ( + softplus_x[0] if not use_linear_0 else x[0], + softplus_x[1] if not use_linear_1 else x[1], + ) + + +@dsl_user_op +@cute.jit +def dsoftplus_from_output(out: Float32, dout: Float32, *, loc=None, ip=None) -> Float32: + use_linear = Boolean(out > 20.0) + # dx = dout * (1.0 - cute.math.exp(-out, fastmath=True)) if not use_linear else dout + dx = dout - dout * cute.math.exp(-out, fastmath=True) + return dx if not use_linear else dout + + +@dsl_user_op +def silu(x: F32_or_F32x2, *, already_halved: bool = False, loc=None, ip=None) -> 
F32_or_F32x2: + """ + silu(x) = x * sigmoid(x) = x * (1 + tanh(x / 2)) / 2 = (0.5 * x) * tanh(0.5 * x) + (0.5 * x) + This compiles down to 3 SASS instructions: FMUL to get 0.5 * x, MUFU.TANH, and FFMA. + """ + if const_expr(not isinstance(x, tuple)): + x_half = 0.5 * x if const_expr(not already_halved) else x + # return x_half * cute.math.tanh(x_half, fastmath=True) + x_half + return x_half * tanh(x_half) + x_half + else: + x_half = utils.mul_packed_f32x2((0.5, 0.5), x) if const_expr(not already_halved) else x + tanh_x_half = (tanh(x_half[0]), tanh(x_half[1])) + return utils.fma_packed_f32x2(x_half, tanh_x_half, x_half) + + +@dsl_user_op +def swiglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + if const_expr(not isinstance(x, tuple)): + return silu(x) * y + else: + return utils.mul_packed_f32x2(silu(x), y) + + +@dsl_user_op +def dswiglu( + x: F32_or_F32x2, + y: F32_or_F32x2, + dout: F32_or_F32x2, + *, + already_halved: bool = False, + loc=None, + ip=None, +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + SwiGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection) + Given: swiglu_out = silu(x) * y, and dout = grad w.r.t. swiglu_out + Returns: (dx, dy, swiglu_out) where dx = dout * y * d_silu(x), dy = dout * silu(x) + + d_silu(x) = sigmoid(x) * (1 + x * (1 - sigmoid(x))) + + This has been optimized to use fewer instructions (i.e. we expand things out + to use FFMA instead of FADD and FMUL). + """ + if const_expr(not isinstance(x, tuple)): + # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(0.5 * x)) + # FMUL, MUFU.TANH, then FFMA + if const_expr(not already_halved): + sigmoid_x = sigmoid(x) + silu_x = x * sigmoid_x # FMUL + else: + tanh_x = tanh(x) # MUFU.TANH + sigmoid_x = 0.5 * tanh_x + 0.5 # FFMA + silu_x = x * tanh_x + x # FFMA + silu_x_dout = silu_x * dout # FMUL + # d_silu(x) * dout + # = sigmoid_x * (1 + x * (1 - sigmoid_x)) * dout + # = (sigmoid_x + sigmoid_x * x * (1 - sigmoid_x)) * dout + # = (sigmoid_x + silu_x * (1 - sigmoid_x)) * dout + # = (sigmoid_x + silu_x - silu_x * sigmoid_x) * dout + # = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout + d_silu_x_dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x_dout # FFMA, FFMA + dx = d_silu_x_dout * y # FMUL + dy = silu_x_dout + swiglu_out = silu_x * y # FMUL + # Overall it's 1 MUFU.TANH, 5 FMUL, 3 FFMA + return dx, dy, swiglu_out + else: + # Compute sigmoid(x) and silu(x) + if const_expr(not already_halved): + sigmoid_x = sigmoid(x) + silu_x = utils.mul_packed_f32x2(x, sigmoid_x) + else: + tanh_x = (tanh(x[0]), tanh(x[1])) + sigmoid_x = utils.fma_packed_f32x2(tanh_x, (0.5, 0.5), (0.5, 0.5)) + silu_x = utils.fma_packed_f32x2(x, tanh_x, x) + silu_x_dout = utils.mul_packed_f32x2(silu_x, dout) + # d_silu(x) * dout = (sigmoid_x - silu_x * sigmoid_x) * dout + silu_x * dout + sigmoid_x_minus_silu_x_sigmoid_x = utils.fma_packed_f32x2( + sigmoid_x, (-silu_x[0], -silu_x[1]), sigmoid_x + ) + d_silu_x_dout = utils.fma_packed_f32x2(sigmoid_x_minus_silu_x_sigmoid_x, dout, silu_x_dout) + dx = utils.mul_packed_f32x2(d_silu_x_dout, y) + dy = silu_x_dout + swiglu_out = utils.mul_packed_f32x2(silu_x, y) + return dx, dy, swiglu_out + + +@dsl_user_op +def swiglu_oai( + x: F32_or_F32x2, y: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None +) -> F32_or_F32x2: + """The swiglu variant used in gpt-oss, which has a scaling factor on x and bias of 1 to y. 
+ https://github.com/openai/gpt-oss/blob/7be9334950053a888e24887a57dac797a17d6e00/gpt_oss/torch/model.py#L249 + x * sigmoid(alpha * x) * (y + 1) + Compile down to FMUL, FMUL, TANH, FFMA, FFMA + """ + # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2)) + if const_expr(not isinstance(x, tuple)): + x_half = 0.5 * x + # silu_x = x_half * cute.math.tanh(alpha * x_half, fastmath=True) + x_half + silu_x = x_half * tanh(alpha * x_half) + x_half + return silu_x * y + silu_x + else: + x_half = utils.mul_packed_f32x2((0.5, 0.5), x) + alpha_x_half = utils.mul_packed_f32x2((alpha, alpha), x_half) + tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1])) + silu_x = utils.fma_packed_f32x2(x_half, tanh_alpha_x_half, x_half) + return utils.fma_packed_f32x2(silu_x, y, silu_x) + + +@dsl_user_op +def dswiglu_oai( + x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, alpha: float = 1.702, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + Swiglu OAI backward pass: computes gradients w.r.t. x and y + Given: swiglu_oai_out = x * sigmoid(alpha * x) * (y + 1), and dout = grad w.r.t. swiglu_oai_out + Returns: (dx, dy, swiglu_oai_out) + + Derivative of x * sigmoid(alpha * x) w.r.t. x: + d/dx[x * sigmoid(alpha * x)] = sigmoid(alpha * x) + alpha * x * sigmoid(alpha * x) * (1 - sigmoid(alpha * x)) + """ + if const_expr(not isinstance(x, tuple)): + # Compute sigmoid(alpha * x) using tanh: sigmoid(z) = 0.5 * (1 + tanh(z/2)) + alpha_x_half = (0.5 * alpha) * x # FMUL + # MUFU.TANH, then FFMA + # sigmoid_alpha_x = 0.5 + 0.5 * cute.math.tanh(alpha_x_half, fastmath=True) + sigmoid_alpha_x = 0.5 + 0.5 * tanh(alpha_x_half) + silu_x = x * sigmoid_alpha_x # FMUL + silu_x_dout = silu_x * dout # FMUL + # FFMA, FFMA, FMUL + d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout + dx = d_silu_x_dout * y + d_silu_x_dout # FFMA, instead of multiply by y + 1 + dy = silu_x_dout + swiglu_out = silu_x * y + silu_x # FFMA, instead of multiply by y + 1 + # Overall it's 1 MUFU.TANH, 4 FMUL, 5 FFMA + return dx, dy, swiglu_out + else: + # Compute sigmoid(alpha * x) + alpha_x_half = utils.mul_packed_f32x2(((0.5 * alpha), (0.5 * alpha)), x) + tanh_alpha_x_half = (tanh(alpha_x_half[0]), tanh(alpha_x_half[1])) + sigmoid_alpha_x = utils.fma_packed_f32x2(tanh_alpha_x_half, (0.5, 0.5), (0.5, 0.5)) + silu_x = utils.mul_packed_f32x2(x, sigmoid_alpha_x) + silu_x_dout = utils.mul_packed_f32x2(silu_x, dout) + # d_silu_x_dout = (sigmoid_alpha_x + alpha * (silu_x - silu_x * sigmoid_alpha_x)) * dout + silu_x_minus_product = utils.fma_packed_f32x2( + silu_x, (-sigmoid_alpha_x[0], -sigmoid_alpha_x[1]), silu_x + ) + sigmoid_plus_alpha_diff = utils.fma_packed_f32x2( + (alpha, alpha), silu_x_minus_product, sigmoid_alpha_x + ) + d_silu_x_dout = utils.mul_packed_f32x2(sigmoid_plus_alpha_diff, dout) + dx = utils.fma_packed_f32x2(d_silu_x_dout, y, d_silu_x_dout) + dy = silu_x_dout + swiglu_out = utils.fma_packed_f32x2(silu_x, y, silu_x) + return dx, dy, swiglu_out + + +@dsl_user_op +def glu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + """GLU: Gated Linear Unit + glu(x, y) = sigmoid(x) * y + Using tanh to compute sigmoid: sigmoid(x) = 0.5 * (1 + tanh(x/2)) + """ + if const_expr(not isinstance(x, tuple)): + sigmoid_x = sigmoid(x) # FMUL, MUFU.TANH, then FFMA + return sigmoid_x * y # FMUL + else: + sigmoid_x = sigmoid(x) + return utils.mul_packed_f32x2(sigmoid_x, y) + + +@dsl_user_op +def dglu( + x: F32_or_F32x2, y: F32_or_F32x2, dout: 
F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + GLU backward pass: computes gradients w.r.t. x (gate) and y (up projection) + Given: glu_out = sigmoid(x) * y, and dout = grad w.r.t. glu_out + Returns: (dx, dy, glu_out) where: + - dx = dout * y * sigmoid(x) * (1 - sigmoid(x)) + - dy = dout * sigmoid(x) + - glu_out = sigmoid(x) * y + """ + if const_expr(not isinstance(x, tuple)): + # Compute sigmoid(x) using tanh: sigmoid(x) = 0.5 * (1 + tanh(x/2)) + sigmoid_x = sigmoid(x) # FMUL, MUFU.TANH, then FFMA + sigmoid_x_dout = sigmoid_x * dout # FMUL + glu_out = sigmoid_x * y # FMUL + # dx = y * sigmoid(x) * (1 - sigmoid(x)) * dout + # = y * (1 - sigmoid(x)) * sigmoid_x_dout + # = (y - y * sigmoid(x)) * sigmoid_x_dout + # = (y - glu_out) * sigmoid_x_dout + dx = (y - glu_out) * sigmoid_x_dout # FADD, FMUL + dy = sigmoid_x_dout + # Total: 1 MUFU.TANH, 4 FMUL, 1 FADD, 1 FFMA + return dx, dy, glu_out + else: + sigmoid_x = sigmoid(x) + sigmoid_x_dout = utils.mul_packed_f32x2(sigmoid_x, dout) + glu_out = utils.mul_packed_f32x2(sigmoid_x, y) + # dx = (y - glu_out) * sigmoid_x_dout + y_minus_glu_out = utils.sub_packed_f32x2(y, glu_out) + dx = utils.mul_packed_f32x2(y_minus_glu_out, sigmoid_x_dout) + dy = sigmoid_x_dout + return dx, dy, glu_out + + +@dsl_user_op +def reglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + """ReGLU: ReLU Gated Linear Unit + reglu(x, y) = relu(x) * y = max(x, 0) * y + """ + if const_expr(not isinstance(x, tuple)): + return cute.arch.fmax(x, Float32(0.0)) * y + else: + relu_x = relu(x) + return utils.mul_packed_f32x2(relu_x, y) + + +@dsl_user_op +@cute.jit +def dreglu( + x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + ReGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection) + Given: reglu_out = relu(x) * y, and dout = grad w.r.t. reglu_out + Returns: (dx, dy, reglu_out) where: + - dx = dout * y if x > 0, else 0 + - dy = dout * relu(x) + - reglu_out = relu(x) * y + """ + if const_expr(not isinstance(x, tuple)): + x_pos = Boolean(x > 0) + relu_x = cute.arch.fmax(x, Float32(0.0)) + dx = (dout * y) if x_pos else Float32(0.0) + dy = dout * relu_x + reglu_out = relu_x * y + return dx, dy, reglu_out + else: + x0_pos = Boolean(x[0] > 0) + x1_pos = Boolean(x[1] > 0) + relu_x = relu(x) + dout_y = utils.mul_packed_f32x2(dout, y) + dx = ((dout_y[0] if x0_pos else Float32(0.0)), (dout_y[1] if x1_pos else Float32(0.0))) + dy = utils.mul_packed_f32x2(dout, relu_x) + reglu_out = utils.mul_packed_f32x2(relu_x, y) + return dx, dy, reglu_out + + +@dsl_user_op +def geglu(x: F32_or_F32x2, y: F32_or_F32x2, *, loc=None, ip=None) -> F32_or_F32x2: + """GeGLU: GELU Gated Linear Unit + geglu(x, y) = gelu(x) * y + Uses the tanh approximation of GELU + """ + if const_expr(not isinstance(x, tuple)): + return gelu_tanh_approx(x) * y + else: + return utils.mul_packed_f32x2(gelu_tanh_approx(x), y) + + +@dsl_user_op +def dgeglu( + x: F32_or_F32x2, y: F32_or_F32x2, dout: F32_or_F32x2, *, loc=None, ip=None +) -> Tuple[F32_or_F32x2, F32_or_F32x2, F32_or_F32x2]: + """ + GeGLU backward pass: computes gradients w.r.t. x (gate) and y (up projection) + Given: geglu_out = gelu(x) * y, and dout = grad w.r.t. 
geglu_out + Returns: (dx, dy, geglu_out) where: + - dx = dout * y * d_gelu(x) + - dy = dout * gelu(x) + - geglu_out = gelu(x) * y + """ + if const_expr(not isinstance(x, tuple)): + # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x) + dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout) + # Compute gradients for geglu + dx = dgelu_x_dout * y + dy = gelu_x * dout + geglu_out = gelu_x * y + return dx, dy, geglu_out + else: + # Reuse dgelu_tanh_approx to compute d_gelu(x) * dout and gelu(x) + dgelu_x_dout, gelu_x = dgelu_tanh_approx(x, dout) + # Compute gradients for geglu + dx = utils.mul_packed_f32x2(dgelu_x_dout, y) + dy = utils.mul_packed_f32x2(gelu_x, dout) + geglu_out = utils.mul_packed_f32x2(gelu_x, y) + return dx, dy, geglu_out diff --git a/sonic-moe/torch-ext/sonicmoe/quack/autotuner.py b/sonic-moe/torch-ext/sonicmoe/quack/autotuner.py new file mode 100644 index 00000000..de1f63a1 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/autotuner.py @@ -0,0 +1,369 @@ +# Adapted from https://github.com/triton-lang/triton/blob/main/python/triton/runtime/autotuner.py +# Copyright (C) 2025, Tri Dao. +from __future__ import annotations + +import builtins +import os +import time +import inspect +import base64 +import hashlib +import json +from pathlib import Path +from functools import cached_property, partial +from typing import Dict, Tuple, List, Optional, Any + +import torch +from torch import Tensor + +import triton + +from . import __version__ + + +PACKAGE_NAME = "quack" +VERSION = __version__ + + +def get_home_dir(): + return os.getenv(f"{PACKAGE_NAME.upper()}_HOME", Path.home()) + + +def default_cache_dir(): + return os.path.join(get_home_dir(), f".{PACKAGE_NAME}", "cache") + + +class FileCacheManager(triton.runtime.cache.FileCacheManager): + def __init__(self, key): + super().__init__(key) + self.cache_dir = ( + os.getenv(f"{PACKAGE_NAME.upper()}_CACHE_DIR", "").strip() or default_cache_dir() + ) + if self.cache_dir: + self.cache_dir = os.path.join(self.cache_dir, self.key) + self.lock_path = os.path.join(self.cache_dir, "lock") + os.makedirs(self.cache_dir, exist_ok=True) + else: + raise RuntimeError("Could not create or locate cache dir") + + +def _base32(key): + # Assume key is a hex string. + return base64.b32encode(bytes.fromhex(key)).decode("utf-8").rstrip("=") + + +class Autotuner: + def __init__( + self, + fn, + key, + configs, + restore_value=None, + prune_configs_by: Optional[Dict] = None, + do_bench=None, + cache_results=False, + ): + """ + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predicate running time with different configs, returns running time + 'top_k': number of configs to bench + 'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs. 
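+ :param restore_value: a list of argument names whose value will be restored after evaluating any configs.
+ :param do_bench: a benchmark function to measure the time of each run; defaults to triton.testing.do_bench.
+ :param cache_results: whether to cache autotune timings to disk.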
+ """ + if not configs: + self.configs = [AutotuneConfig()] + else: + self.configs = configs + signature = inspect.signature(fn) + self.keys = key + self.cache: Dict[Tuple, AutotuneConfig] = {} + self.arg_names = list(signature.parameters.keys()) + self.cache_results = ( + cache_results or os.getenv(f"{PACKAGE_NAME.upper()}_CACHE_AUTOTUNING", None) == "1" + ) + + self.restore_value = [] + if restore_value is not None: + self.restore_value = list(restore_value) + + if len(self.restore_value) > 0: + + def _pre_hook(kwargs): + self.restore_copies = {name: kwargs[name].clone() for name in self.restore_value} + + self.pre_hook = _pre_hook + else: + self.pre_hook = None + + if len(self.restore_value) > 0: + + def _post_hook(kwargs, exception): + for name in self.restore_value: + kwargs[name].copy_(self.restore_copies[name]) + self.restore_copies = {} + + self.post_hook = _post_hook + else: + self.post_hook = None + + self.perf_model = None + self.configs_top_k = 1.0 + self.early_config_prune = None + if prune_configs_by: + self.perf_model = prune_configs_by.get("perf_model", self.perf_model) + self.configs_top_k = prune_configs_by.get("top_k", self.configs_top_k) + self.early_config_prune = prune_configs_by.get( + "early_config_prune", self.early_config_prune + ) + + self.fn = fn + self._do_bench = do_bench + + @cached_property + def do_bench(self): + if self._do_bench is None: + return partial(triton.testing.do_bench, warmup=5, rep=25) + return self._do_bench + + def _bench(self, *args, config, **meta): + verbose = os.environ.get(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1" + if verbose: + print(f"Autotuning kernel {self.fn.__name__} with config {config}") + + # check for conflicts, i.e. meta-parameters both provided + # as kwargs and by the autotuner + conflicts = meta.keys() & config.kwargs.keys() + if conflicts: + raise ValueError( + f"Conflicting meta-parameters: {', '.join(conflicts)}." + " Make sure that you don't re-define auto-tuned symbols." 
+ ) + # augment meta-parameters with tunable ones + current = dict(meta, **config.all_kwargs()) + full_nargs = {**self.nargs, **current} + + def kernel_call(): + if self.pre_hook is not None: + self.pre_hook(full_nargs) + try: + self.fn.__call__( + *args, + **current, + ) + except Exception as e: + try: + if self.post_hook is not None: + self.post_hook(full_nargs, exception=e) + finally: + # Throw exception raised by `self.fn.run` + raise + + if self.post_hook is not None: + self.post_hook(full_nargs, exception=None) + + try: + return self.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8)) + except Exception as e: + if verbose: + print(f"Autotuning failed with {e}") + return [float("inf"), float("inf"), float("inf")] + + @torch.compiler.disable + def check_disk_cache(self, tuning_key, configs, bench_fn): + if not tuning_key: + bench_fn() + return + + fn = self.fn + config_str_list = [str(c) for c in configs] + assert len(config_str_list) == len(set(config_str_list)), "Config strings must be unique" + cache_key = [VERSION, str(tuning_key)] + config_str_list + cache_key = hashlib.sha256("-".join(cache_key).encode("utf-8")).hexdigest() + cache = FileCacheManager(_base32(cache_key)) + file_name = f"{fn.__name__[:150]}.autotune.json" + path = cache.get_file(file_name) + # There's an environment variable to force cache update + if path and not os.environ.get(f"{PACKAGE_NAME.upper()}_FORCE_CACHE_UPDATE", False): + str2config = {s: c for s, c in zip(config_str_list, configs)} + with open(path, "r") as cached_configs: + timings = json.load(cached_configs)["configs_timings"] + timings = {str2config[config]: timing for config, timing in timings} + self.cache[tuning_key] = builtins.min(timings, key=timings.get) + self.configs_timings = timings + self.bench_time = 0 + return + + bench_fn() + cache.put( + json.dumps( + { + "key": tuning_key, + "configs_timings": [ + (str(config), timings) for config, timings in self.configs_timings.items() + ], + } + ), + file_name, + binary=False, + ) + + def __call__(self, *args, **kwargs): + self.nargs = dict(zip(self.arg_names, args)) + used_cached_result = True + if len(self.configs) > 1: + all_args = {**self.nargs, **kwargs} + _args = {k: v for (k, v) in all_args.items() if k in self.arg_names} + # Need "str" to make it json-serializable + key = [str(_args[key]) for key in self.keys if key in _args] + for _, arg in _args.items(): + if isinstance(arg, Tensor): + key.append(str(arg.shape)) + # If stride != 0, 1, we just cache it as 2 + key.append(str([s if s in {0, 1} else 2 for s in arg.stride()])) + key.append(str(arg.dtype)) + key = tuple(key) + if key not in self.cache: + used_cached_result = False + pruned_configs = self.prune_configs(kwargs) + + @torch.compiler.disable # Don't want any tracing here + def benchmark(): + bench_start = time.time() + timings = { + config: self._bench(*args, config=config, **kwargs) + for config in pruned_configs + } + bench_end = time.time() + if os.getenv(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1": + for config, time_ in timings.items(): + print(f"[{config}] -> {time_[0]:.3f}ms") + self.bench_time = bench_end - bench_start + self.cache[key] = builtins.min(timings, key=timings.get) + self.configs_timings = timings + + if self.cache_results: + self.check_disk_cache(key, pruned_configs, benchmark) + else: + benchmark() + + config = self.cache[key] + else: + config = self.configs[0] + self.best_config = config + if ( + os.getenv(f"{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING", None) == "1" + and not used_cached_result + ): + 
print( + f"{PACKAGE_NAME} autotuning for function {self.fn.__name__} finished after " + f"{self.bench_time:.2f}s; best config selected: {self.best_config};" + ) + ret = self.fn.__call__( + *args, + **kwargs, + **config.all_kwargs(), + ) + self.nargs = None + return ret + + def prune_configs(self, kwargs: Dict) -> List[Any]: + pruned_configs = self.configs + if self.early_config_prune: + pruned_configs = self.early_config_prune(self.configs, self.nargs, **kwargs) + if self.perf_model: + top_k = self.configs_top_k + if isinstance(top_k, float) and top_k <= 1.0: + top_k = int(len(self.configs) * top_k) + elif not isinstance(top_k, int): + # Slice index must be an integer + raise TypeError( + "Error while pruning configs, top_k must be either 1) a float <= 1.0 or 2) an int" + ) + + if len(pruned_configs) > top_k: + est_timing = { + config: self.perf_model( + **self.nargs, + **kwargs, + **config.all_kwargs(), + ) + for config in pruned_configs + } + pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k] + return pruned_configs + + +class AutotuneConfig: + """ + An object that represents a possible kernel configuration for the auto-tuner to try. + + :ivar kwargs: a dictionary of meta-parameters to pass to the kernel as keyword arguments. + :type kwargs: dict[Str, Any] + """ + + def __init__(self, **kwargs): + self.kwargs = kwargs + + def __setstate__(self, state): + self.kwargs = state.get("kwargs", {}) + + def all_kwargs(self): + return self.kwargs + + def __str__(self): + res = [] + for k, v in self.kwargs.items(): + res.append(f"{k}: {v}") + return ", ".join(res) + + def __hash__(self): + return hash(tuple(self.all_kwargs().items())) + + def __eq__(self, other): + self_tuple = tuple(self.all_kwargs().items()) + other_tuple = tuple(other.all_kwargs().items()) + return self_tuple == other_tuple + + +def autotune( + configs, key=None, prune_configs_by=None, restore_value=None, do_bench=None, cache_results=True +): + f""" + Decorator for auto-tuning a function. + + .. highlight:: python + + If the environment variable :code:`{PACKAGE_NAME.upper()}_PRINT_AUTOTUNING` is set to + :code:`"1"`, we will print a message to stdout after autotuning each + kernel, including the time spent autotuning and the best configuration. + + :param configs: a list of :code:`AutotuneConfig` objects + :type configs: list[AutotuneConfig] + :param key: a list of argument names whose change in value will trigger the evaluation of all provided configs. + :type key: list[str] + :param prune_configs_by: a dict of functions that are used to prune configs, fields: + 'perf_model': performance model used to predict running time with different configs, returns running time + 'top_k': number of configs to bench + 'early_config_prune'(optional): a function used to do early pruning (e.g. by num_stages). It takes configs:List[Config] as its input, and returns pruned configs. + :param restore_value: a list of argument names whose value will be restored after evaluating any configs. + :type restore_value: list[str] + :param do_bench: a benchmark function to measure the time of each run. + :type do_bench: lambda fn, quantiles + :param cache_results: whether to cache autotune timings to disk. Defaults to True. 
+ "type cache_results: bool + """ + + if key is None: + key = [] + + def decorator(fn): + return Autotuner( + fn, + key, + configs, + restore_value=restore_value, + prune_configs_by=prune_configs_by, + do_bench=do_bench, + cache_results=cache_results, + ) + + return decorator diff --git a/sonic-moe/torch-ext/sonicmoe/quack/broadcast_utils.py b/sonic-moe/torch-ext/sonicmoe/quack/broadcast_utils.py new file mode 100644 index 00000000..2bfe3f8f --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/broadcast_utils.py @@ -0,0 +1,29 @@ +# Copyright (c) 2025, Tri Dao. +from typing import Callable + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, const_expr + +from .layout_utils import make_acc_tensor_mn_view + + +@cute.jit +def vec_op(tCrC: cute.Tensor, tCrVec: cute.Tensor, op: Callable, is_colvec: bool) -> None: + if const_expr(tCrC.element_type != Float32): # Convert to f32 + tCrC_f32 = cute.make_fragment(tCrC.shape, Float32) + tCrC_f32.store(tCrC.load().to(Float32)) + else: + tCrC_f32 = tCrC + # this happens to work for frgA layout too, not just acc layout + tCrC_f32_mn = make_acc_tensor_mn_view(tCrC_f32) + if const_expr(is_colvec): + assert cute.size(tCrC_f32_mn, mode=[0]) == cute.size(tCrVec) + for r in cutlass.range(cute.size(tCrC_f32_mn, mode=[0]), unroll_full=True): + tCrC_f32_mn[r, None].store(op(tCrC_f32_mn[r, None].load(), tCrVec[r])) + else: + assert cute.size(tCrC_f32_mn, mode=[1]) == cute.size(tCrVec) + for c in cutlass.range(cute.size(tCrC_f32_mn, mode=[1]), unroll_full=True): + tCrC_f32_mn[None, c].store(op(tCrC_f32_mn[None, c].load(), tCrVec[c])) + if const_expr(tCrC.element_type != Float32): # Convert back to original dtype + tCrC.store(tCrC_f32.load().to(tCrC.element_type)) diff --git a/sonic-moe/torch-ext/sonicmoe/quack/compile_utils.py b/sonic-moe/torch-ext/sonicmoe/quack/compile_utils.py new file mode 100644 index 00000000..43755946 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/compile_utils.py @@ -0,0 +1,19 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. + +from typing import Optional + +import cutlass.cute as cute + + +def make_fake_tensor(dtype, shape, divisibility=1, leading_dim=-1) -> Optional[cute.Tensor]: + if leading_dim < 0: + leading_dim = len(shape) + leading_dim + if dtype is None: + return None + stride = tuple( + cute.sym_int64(divisibility=divisibility) if i != leading_dim else 1 + for i in range(len(shape)) + ) + return cute.runtime.make_fake_tensor( + dtype, shape, stride=stride, assumed_align=divisibility * dtype.width // 8 + ) diff --git a/sonic-moe/torch-ext/sonicmoe/quack/copy_utils.py b/sonic-moe/torch-ext/sonicmoe/quack/copy_utils.py new file mode 100644 index 00000000..52549e4d --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/copy_utils.py @@ -0,0 +1,614 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. 
+ +import re +from typing import Optional, Type, Tuple, Callable + +import cutlass +import cutlass.cute as cute + +from cutlass import Int32, Boolean, const_expr +from cutlass.cute.nvgpu import cpasync, warpgroup +from cutlass.cutlass_dsl import dsl_user_op +import cutlass.pipeline + + +@dsl_user_op +def cvt_copy( + tiled_copy: cute.TiledCopy, + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + retile: bool = False, + loc=None, + ip=None, + **kwargs, +) -> None: + assert isinstance(src.iterator, cute.Pointer) and src.memspace == cute.AddressSpace.rmem + if const_expr(src.element_type != dst.element_type): + src_cvt = cute.make_fragment_like(src, dst.element_type) + src_cvt.store(src.load().to(dst.element_type)) + src = src_cvt + if const_expr(retile): + src = tiled_copy.retile(src) + cute.copy(tiled_copy, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +@dsl_user_op +def load_s2r(src: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: + dst = cute.make_fragment_like(src, src.element_type, loc=loc, ip=ip) + cute.autovec_copy(src, dst, loc=loc, ip=ip) + return dst + + +@dsl_user_op +def load_s2r_retile( + tiled_copy: cute.TiledCopy, + src: cute.Tensor, + dst_shape: cute.Tensor | cute.Shape, + *, + loc=None, + ip=None, +) -> cute.Tensor: + # Will also accept dst_shape being a tensor, in which case we write into that tensor + if const_expr(not isinstance(dst_shape, cute.Tensor)): + dst = cute.make_fragment(dst_shape, src.element_type, loc=loc, ip=ip) + else: + dst = dst_shape + cute.copy(tiled_copy, src, tiled_copy.retile(dst), loc=loc, ip=ip) + return dst + + +@dsl_user_op +def get_copy_atom( + dtype: Type[cutlass.Numeric], num_copy_elems: int, is_async: bool = False, *, loc=None, ip=None +) -> cute.CopyAtom: + num_copy_bits = const_expr(min(128, num_copy_elems * dtype.width)) + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + return cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + + +@dsl_user_op +def copy( + src: cute.Tensor, + dst: cute.Tensor, + *, + pred: Optional[cute.Tensor] = None, + is_async: bool = False, + loc=None, + ip=None, + **kwargs, +) -> None: + num_copy_elems = src.shape[0][0] + copy_atom = get_copy_atom(src.element_type, num_copy_elems, is_async) + cute.copy(copy_atom, src, dst, pred=pred, loc=loc, ip=ip, **kwargs) + + +def tiled_copy_1d( + dtype: Type[cutlass.Numeric], num_threads: int, num_copy_elems: int = 1, is_async: bool = False +) -> cute.TiledCopy: + num_copy_bits = num_copy_elems * dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + thr_layout = cute.make_layout(num_threads) + val_layout = cute.make_layout(num_copy_elems) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +def tiled_copy_2d( + dtype: Type[cutlass.Numeric], + threads_per_row: int, + num_threads: int, + num_copy_elems: int = 1, + is_async: bool = False, +) -> cute.TiledCopy: + num_copy_bits = num_copy_elems * dtype.width + copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() + copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) + assert num_threads % threads_per_row == 0 + thr_layout = cute.make_ordered_layout( + (num_threads // threads_per_row, threads_per_row), + order=(1, 0), + ) + val_layout = cute.make_layout((1, num_copy_elems)) + return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +@cute.jit +def 
predicate_k(tAcA: cute.Tensor, limit: Int32) -> cute.Tensor: + # Only compute predicates for the "k" dimension. For the mn dimension, we will use "if" + tApA = cute.make_fragment( + cute.make_layout( + (cute.size(tAcA, mode=[0, 1]), cute.size(tAcA, mode=[1]), cute.size(tAcA, mode=[2])), + stride=(cute.size(tAcA, mode=[2]), 0, 1), + ), + Boolean, + ) + for rest_v in cutlass.range_constexpr(tApA.shape[0]): + for rest_k in cutlass.range_constexpr(tApA.shape[2]): + tApA[rest_v, 0, rest_k] = cute.elem_less(tAcA[(0, rest_v), 0, rest_k][1], limit) + return tApA + + +# def tiled_copy_2d( +# dtype: Type[cutlass.Numeric], major_mode_size: int, num_threads: int, is_async: bool = False +# ) -> cute.TiledCopy: +# num_copy_bits = math.gcd(major_mode_size, 128 // dtype.width) * dtype.width +# copy_elems = num_copy_bits // dtype.width +# copy_op = cpasync.CopyG2SOp() if is_async else cute.nvgpu.CopyUniversalOp() +# copy_atom = cute.make_copy_atom(copy_op, dtype, num_bits_per_copy=num_copy_bits) +# gmem_threads_per_row = major_mode_size // copy_elems +# assert num_threads % gmem_threads_per_row == 0 +# thr_layout = cute.make_ordered_layout( +# (num_threads // gmem_threads_per_row, gmem_threads_per_row), +# order=(1, 0), +# ) +# val_layout = cute.make_layout((1, copy_elems)) +# return cute.make_tiled_copy_tv(copy_atom, thr_layout, val_layout) + + +def parse_swizzle_from_pointer(ptr: cute.Pointer) -> Tuple[int, int, int]: + """Extract swizzle parameters from a pointer's swizzle_type. + + The swizzle_type string has the form '!cute.swizzle<"S<b,m,s>">' where + b, m, s are the swizzle parameters (bits, base, shift). + + Returns: + A tuple (b, m, s) of swizzle parameters + + Raises: + ValueError: If the swizzle_type string cannot be parsed + """ + # Ideally there should be a better API to get swizzle parameters, but we'll just parse + # the string here. + swizzle_str = str(ptr.type.swizzle_type) + # Extract the inner part "S<b,m,s>" + match = re.search(r"S<(\d+),(\d+),(\d+)>", swizzle_str) + if match: + b, m, s = int(match.group(1)), int(match.group(2)), int(match.group(3)) + return b, m, s + else: + raise ValueError(f"Could not parse swizzle_type: {swizzle_str}") + + +def swizzle_int(ptr_int: Int32, b: int, m: int, s: int) -> Int32: + bit_msk = (1 << b) - 1 + yyy_msk = bit_msk << (m + s) + return ptr_int ^ ((ptr_int & yyy_msk) >> s) + + +def swizzle_ptr(ptr: cute.Pointer): + b, m, s = parse_swizzle_from_pointer(ptr) + ptr_int = swizzle_int(ptr.toint(), b, m, s) + return cute.make_ptr(ptr.dtype, ptr_int, ptr.memspace, assumed_align=ptr.alignment) + + +def as_position_independent_swizzle_tensor(tensor: cute.Tensor) -> cute.Tensor: + outer = tensor.layout + width = tensor.element_type.width + inner = cute.make_swizzle(*parse_swizzle_from_pointer(tensor.iterator)) + # Need to recast the swizzle from byte (e.g. <3, 4, 3> to element units (e.g. 
<3, 3, 3> for + # for 16 bits and <3, 2, 3> for 32 bits) + new_layout = cute.recast_layout( + width, 8, cute.make_composed_layout(inner, 0, cute.recast_layout(8, width, outer)) + ) + # recast_ptr to remove the pointer swizzle + return cute.make_tensor(cute.recast_ptr(tensor.iterator, dtype=tensor.element_type), new_layout) + + +def partition_D_position_independent( + thr_copy: cute.core.ThrCopy, tensor: cute.Tensor +) -> cute.Tensor: + return cute.make_tensor( + swizzle_ptr(thr_copy.partition_D(tensor).iterator), + thr_copy.partition_D(as_position_independent_swizzle_tensor(tensor)).layout, + ) + + +def partition_S_position_independent( + thr_copy: cute.core.ThrCopy, tensor: cute.Tensor +) -> cute.Tensor: + return cute.make_tensor( + swizzle_ptr(thr_copy.partition_S(tensor).iterator), + thr_copy.partition_S(as_position_independent_swizzle_tensor(tensor)).layout, + ) + + +@dsl_user_op +def sm90_get_smem_load_op( + layout_c: cutlass.utils.LayoutEnum, + elem_ty_c: Type[cutlass.Numeric], + *, + loc=None, + ip=None, +) -> cute.CopyAtom: + """ + Selects the largest vectorized smem load atom available subject to constraint of gmem layout. + + Parameters: + ----------- + layout_c : LayoutEnum + The layout enum of the output tensor D. + + elem_ty_c : Type[Numeric] + The element type for output tensor D. + + Returns: + -------- + Either SmemLoadMatrix or SimtSyncCopy, based on the input parameters. + """ + + if not isinstance(elem_ty_c, cutlass.cutlass_dsl.NumericMeta): + raise TypeError(f"elem_ty_c must be a Numeric, but got {elem_ty_c}") + is_m_major = layout_c.is_m_major_c() + if elem_ty_c.width == 16: + return cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp(is_m_major, 4), elem_ty_c, loc=loc, ip=ip + ) + else: + return cute.make_copy_atom(cute.nvgpu.CopyUniversalOp(), elem_ty_c, loc=loc, ip=ip) + + +def get_smem_store_atom( + arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False +) -> cute.CopyAtom: + if const_expr(arch < 90 or element_type.width != 16): + return cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + element_type, + num_bits_per_copy=(2 if not transpose else 1) * element_type.width, + ) + else: + return cute.make_copy_atom( + cute.nvgpu.warp.StMatrix8x8x16bOp(transpose=transpose, num_matrices=4), + element_type, + ) + + +def get_smem_load_atom( + arch: cutlass.Constexpr[int], element_type: Type[cute.Numeric], transpose: bool = False +) -> cute.CopyAtom: + if const_expr(arch < 90 or element_type.width != 16): + return cute.make_copy_atom( + cute.nvgpu.CopyUniversalOp(), + element_type, + num_bits_per_copy=(2 if not transpose else 1) * element_type.width, + ) + else: + return cute.make_copy_atom( + cute.nvgpu.warp.LdMatrix8x8x16bOp(transpose=transpose, num_matrices=4), + element_type, + ) + + +def get_smem_store_C( + tiled_mma: cute.TiledMma, + sC: cute.Tensor, + tidx: Int32, + arch: int, + transpose: bool = False, + position_independent=False, +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]: + dtype = sC.element_type + copy_atom = get_smem_store_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma) + thr_copy = tiled_copy.get_slice(tidx) + if const_expr(not position_independent): + tRS_sC = thr_copy.partition_D(sC) + else: + tRS_sC = partition_D_position_independent(thr_copy, sC) + + def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs): + cvt_copy(tiled_copy, src, tRS_sC[None, None, None, dst_idx], retile=True, **new_kwargs) + + return copy_fn, thr_copy, tRS_sC + + +def get_smem_load_C( + 
tiled_mma: cute.TiledMma, + sC: cute.Tensor, + tidx: Int32, + arch: int, + transpose: bool = False, + position_independent=False, +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]: + dtype = sC.element_type + copy_atom = get_smem_load_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_C(copy_atom, tiled_mma) + thr_copy = tiled_copy.get_slice(tidx) + if const_expr(not position_independent): + tSR_sC = thr_copy.partition_S(sC) + else: + tSR_sC = partition_S_position_independent(thr_copy, sC) + copy_atom_RS = get_smem_store_atom(arch, dtype, transpose) + thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx) + tRS_shape = thr_copy_RS.partition_S(cute.make_identity_tensor(sC.shape[:2])).shape + + def copy_fn(src_idx: Int32, **new_kwargs): + return load_s2r_retile( + tiled_copy, tSR_sC[None, None, None, src_idx], dst_shape=tRS_shape, **new_kwargs + ) + + return copy_fn, thr_copy, tSR_sC + + +def get_smem_store_A( + tiled_mma: cute.TiledMma, sA: cute.Tensor, tidx: Int32, arch: int, position_independent=False +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]: + dtype = sA.element_type + transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN + copy_atom = get_smem_store_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma) + thr_copy = tiled_copy.get_slice(tidx) + if const_expr(not position_independent): + tRS_sA = thr_copy.partition_D(sA) + else: + tRS_sA = partition_D_position_independent(thr_copy, sA) + + def copy_fn(src: cute.Tensor, dst_idx: Int32, **new_kwargs): + cvt_copy(tiled_copy, src, tRS_sA[None, None, None, dst_idx], retile=True, **new_kwargs) + + return copy_fn, thr_copy, tRS_sA + + +def get_smem_load_A( + tiled_mma: cute.TiledMma, + sA: cute.Tensor, + tidx: Int32, + arch: int, + with_dst_tensor: bool = False, + position_independent=False, +) -> Tuple[Callable, cute.TiledCopy, cute.Tensor]: + dtype = sA.element_type + transpose = tiled_mma.op.a_major_mode == warpgroup.OperandMajorMode.MN + copy_atom = get_smem_load_atom(arch, dtype, transpose) + tiled_copy = cute.make_tiled_copy_A(copy_atom, tiled_mma) + thr_copy = tiled_copy.get_slice(tidx) + if const_expr(not position_independent): + tSR_sA = thr_copy.partition_S(sA) + else: + tSR_sA = partition_S_position_independent(thr_copy, sA) + copy_atom_RS = get_smem_store_atom(arch, dtype, transpose) + thr_copy_RS = cute.make_tiled_copy_C(copy_atom_RS, tiled_mma).get_slice(tidx) + tRS_shape = tiled_mma.partition_shape_A(sA.shape[:2]) + + def copy_fn(src_idx: Int32, **new_kwargs): + return load_s2r_retile( + tiled_copy, tSR_sA[None, None, None, src_idx], dst_shape=tRS_shape, **new_kwargs + ) + + def copy_fn_w_dst_tensor(src_idx: Int32, dst: cute.Tensor, **new_kwargs): + return load_s2r_retile(tiled_copy, tSR_sA[None, None, None, src_idx], dst, **new_kwargs) + + return copy_fn if not with_dst_tensor else copy_fn_w_dst_tensor, thr_copy, tSR_sA + + +def tma_get_copy_fn( + atom: cute.CopyAtom, + cta_coord: cute.Coord, + cta_layout: cute.Layout, + src_tensor: cute.Tensor, + dst_tensor: cute.Tensor, + filter_zeros: bool = False, + single_stage: bool = False, + **kwargs, +) -> Callable: + src_is_smem = const_expr( + isinstance(src_tensor.iterator, cute.Pointer) + and src_tensor.memspace == cute.AddressSpace.smem + ) + smem_tensor, gmem_tensor = (src_tensor, dst_tensor) if src_is_smem else (dst_tensor, src_tensor) + group_rank_smem = const_expr(cute.rank(smem_tensor) - (1 if not single_stage else 0)) + group_rank_gmem = const_expr(cute.rank(gmem_tensor) - (1 if 
not single_stage else 0)) + # ((atom_v, rest_v), STAGE), ((atom_v, rest_v), RestK) + s, g = cpasync.tma_partition( + atom, + cta_coord, + cta_layout, + cute.group_modes(smem_tensor, 0, group_rank_smem), + cute.group_modes(gmem_tensor, 0, group_rank_gmem), + ) + if const_expr(filter_zeros): + s = cute.filter_zeros(s) + g = cute.filter_zeros(g) + src, dst = (s, g) if src_is_smem else (g, s) + + def copy_tma(src_idx, dst_idx, **new_kwargs): + cute.copy(atom, src[None, src_idx], dst[None, dst_idx], **new_kwargs, **kwargs) + + def copy_tma_single_stage(**new_kwargs): + cute.copy(atom, src, dst, **new_kwargs, **kwargs) + + return (copy_tma if const_expr(not single_stage) else copy_tma_single_stage), s, g + + +def tma_producer_copy_fn(copy: Callable, pipeline: cutlass.pipeline.PipelineAsync): + def copy_fn(src_idx, producer_state: cutlass.pipeline.PipelineState, **new_kwargs): + copy( + src_idx=src_idx, + dst_idx=producer_state.index, + tma_bar_ptr=pipeline.producer_get_barrier(producer_state), + **new_kwargs, + ) + + return copy_fn + + +@cute.jit +def gather_m_get_copy_fn( + thr_copy_A: cute.ThrCopy, + mA: cute.Tensor, # (whatever, K) + sA: cute.Tensor, # (tile_M, tile_N, STAGE) + gsAIdx: cute.Tensor, # (tile_M), either gmem or smem + limit_m: Int32, + limit_k: Int32, +) -> Callable: + tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1])) + tAsA = thr_copy_A.partition_D(sA) + # k-major + assert tAsA.shape[2] == 1 + tAsA = cute.group_modes(cute.slice_(tAsA, (None, None, 0, None)), 0, 2) + + is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0 + if const_expr(not is_even_m_smem): + limit_m = min(limit_m, tile_shape_mk[0]) + elems_per_load = cute.size(tAsA.shape[0][0]) + cA = cute.make_identity_tensor(tile_shape_mk) + tAcA = thr_copy_A.partition_S(cA) + t0AcA = thr_copy_A.get_slice(0).partition_S(cA) + # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0] + # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0]. + # This is so that when we do the comparison, t0AcA is known at compile time. + limit_m = limit_m - tAcA[0][0] + limit_k = limit_k - tAcA[0][1] + # Read and cache indices for A + rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1])) + cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2])) + tApA_m = cute.make_fragment(rows_per_thread, Boolean) + for m in cutlass.range(rows_per_thread, unroll_full=True): + tApA_m[m] = t0AcA[0, m, 0][0] < limit_m + m_idx = cute.make_fragment(rows_per_thread, Int32) + for m in cutlass.range(rows_per_thread, unroll_full=True): + row_idx = tAcA[0, m, 0][0] + if tApA_m[m]: + m_idx[m] = gsAIdx[row_idx] + else: + m_idx[m] = 0 # It's ok to load row 0 in the case of OOB + + mA_k = cute.logical_divide(mA, (None, tile_shape_mk[1])) + + def copy_fn(src_idx, dst_idx, pred: bool = False): + tApA_k = None + if const_expr(pred): + tApA_k = cute.make_fragment(cols_per_thread, Boolean) + limit_k_cur = limit_k - src_idx * tile_shape_mk[1] + for k in cutlass.range(cols_per_thread, unroll_full=True): + tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur + mA_cur = mA_k[None, (None, src_idx)] + for m in cutlass.range_constexpr(tAcA.shape[1]): + # cute.tiled_divide(mA_cur[m_idx[m], None], (elems_per_load,)) would give shape + # ((elems_per_load), thread_per_row) + # But we actually want shape ((elems_per_load, 1), thread_per_row) to match tAsA + # So we append 1s to the last dimension and then do tiled_divide, then slice. 
+ mA_row = cute.tiled_divide( + cute.append_ones(mA_cur[m_idx[m], None], up_to_rank=2), (elems_per_load, 1) + )[None, None, 0] + if const_expr(is_even_m_smem) or tApA_m[m]: + # There's only 1 load per row + assert cute.size(tAcA.shape, mode=[2]) == 1 + ki = tAcA[0, 0, 0][1] // elems_per_load + cute.copy(thr_copy_A, mA_row[None, ki], tAsA[(None, m), dst_idx], pred=tApA_k) + + return copy_fn + + +@cute.jit +def gather_k_get_copy_fn( + thr_copy_A: cute.ThrCopy, + mA: cute.Tensor, # (tile_M, whatever) + sA: cute.Tensor, # (tile_M, tile_N, STAGE) + gsAIdx: cute.Tensor, # (tile_K, RestK), either gmem or smem + limit_m: Int32, + limit_k: Int32, +) -> Callable: + gAIdx, sAIdx = None, None + if const_expr(gsAIdx.memspace == cute.AddressSpace.gmem): + gAIdx = gsAIdx + else: + assert gsAIdx.memspace == cute.AddressSpace.smem + sAIdx = gsAIdx + tile_shape_mk = (cute.size(sA, mode=[0]), cute.size(sA, mode=[1])) + # (atom_v, CPY_M, 1, STAGE) + tAsA = thr_copy_A.partition_D(sA) + # m-major + tAsA = cute.group_modes(tAsA, 0, 3) + + is_even_m_smem = tile_shape_mk[0] % thr_copy_A.tiler_mn[0].shape == 0 + if const_expr(not is_even_m_smem): + limit_m = min(limit_m, tile_shape_mk[0]) + elems_per_load = cute.size(tAsA.shape[0][0]) + cA = cute.make_identity_tensor(tile_shape_mk) + tAcA = thr_copy_A.partition_S(cA) + t0AcA = thr_copy_A.get_slice(0).partition_S(cA) + # Instead of comparing tAcA to limit_m, we instead compare t0AcA to limit_m - tAcA[0][0] + # since we know that tAcA[m][0] = t0AcA[m][0] + tAcA[0][0]. + # This is so that when we do the comparison, t0AcA is known at compile time. + limit_m = limit_m - tAcA[0][0] + limit_k = limit_k - tAcA[0][1] + # Read and cache indices for A + rows_per_thread = const_expr(cute.size(tAcA.shape, mode=[1])) + cols_per_thread = const_expr(cute.size(tAcA.shape, mode=[2])) + tApA_m = cute.make_fragment(rows_per_thread, Boolean) + for m in cutlass.range(rows_per_thread, unroll_full=True): + tApA_m[m] = t0AcA[0, m, 0][0] < limit_m + threads_per_col = const_expr(thr_copy_A.tiler_mn[0].shape // elems_per_load) + # This is very convoluted but idk a better way + # for tile_M=128, flat_divide gives (8, 16, K), + # then logical_divide gives ((8, 1), (8, 2), K). 
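+ # each thread then slices out the row blocks it loads, selected by (tidx % threads_per_col)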
+ tidx = thr_copy_A.thr_idx + tAmA = cute.logical_divide( + cute.flat_divide(mA, (elems_per_load,)), (elems_per_load, threads_per_col) + )[None, (tidx % threads_per_col, None), None] # ((8, 1), 2, K) + + def prefetch_from_gmem_fn(src_idx, pred: bool = False) -> Tuple[cute.Tensor, cute.Tensor]: + # Prefetch mAIdx early, even before smem is free + tApA_k = None + if const_expr(pred): + tApA_k = cute.make_fragment(cols_per_thread, Boolean) + limit_k_cur = limit_k - src_idx * tile_shape_mk[1] + for k in cutlass.range(cols_per_thread, unroll_full=True): + tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur + gAIdx_cur = gAIdx[None, src_idx] + k_idx = cute.make_fragment(cols_per_thread, Int32) + for k in cutlass.range(cols_per_thread): + col_idx = tAcA[0, 0, k][1] + if const_expr(not pred): + k_idx[k] = gAIdx_cur[col_idx] + else: + if tApA_k[k]: + k_idx[k] = gAIdx_cur[col_idx] + else: + k_idx[k] = -1 + return k_idx, tApA_k + + def prefetch_from_smem_fn( + a_prefetch_pipeline, src_idx, dst_idx, a_prefetch_consumer_state, pred: bool = False + ) -> Tuple[cute.Tensor, cute.Tensor]: + tApA_k = None + if const_expr(pred): + tApA_k = cute.make_fragment(cols_per_thread, Boolean) + limit_k_cur = limit_k - src_idx * tile_shape_mk[1] + for k in cutlass.range(cols_per_thread, unroll_full=True): + tApA_k[k] = t0AcA[0, 0, k][1] < limit_k_cur + a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state) + sAIdx_cur = sAIdx[None, dst_idx] + k_idx = cute.make_fragment(cols_per_thread, Int32) + for k in cutlass.range(cols_per_thread): + col_idx = tAcA[0, 0, k][1] + k_idx[k] = sAIdx_cur[col_idx] + cute.arch.sync_warp() + with cute.arch.elect_one(): + a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state) + return k_idx, tApA_k + + def copy_fn( + src_idx, dst_idx, k_idx_tApA_k: Tuple[cute.Tensor, cute.Tensor], pred: bool = False + ): + k_idx, tApA_k = k_idx_tApA_k + tApA_k_pred = None + if const_expr(pred): + tApA_k_pred = cute.prepend_ones(tApA_k, up_to_rank=2) # (1, cols_per_thread) + for k in cutlass.range_constexpr(tAcA.shape[2]): + # copy_A(tAmA[None, None, k_idx[k]], tAsA[(None, None, k), smem_idx], pred=cute.prepend_ones(tApA_m, up_to_rank=2)) + for m in cutlass.range_constexpr(tAcA.shape[1]): + if tApA_m[m]: + cute.copy( + thr_copy_A, + tAmA[None, m, k_idx[k]], + tAsA[(None, m, k), dst_idx], + pred=None if const_expr(tApA_k_pred is None) else tApA_k_pred[None, k], + ) + + return copy_fn, prefetch_from_gmem_fn if const_expr( + gAIdx is not None + ) else prefetch_from_smem_fn diff --git a/sonic-moe/torch-ext/sonicmoe/quack/cute_dsl_ptxas.py b/sonic-moe/torch-ext/sonicmoe/quack/cute_dsl_ptxas.py new file mode 100644 index 00000000..4e00f3f0 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/cute_dsl_ptxas.py @@ -0,0 +1,151 @@ +""" +System ptxas replacement for CUTLASS DSL. 
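+Calling patch() swaps in the system ptxas named by CUTE_DSL_PTXAS_PATH: the PTX dumped by the DSL is recompiled with that binary and the resulting cubin is loaded in place of the embedded compiler's output. +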
+Environment variables: + CUTE_DSL_PTXAS_PATH - Path to ptxas (e.g., /usr/local/cuda/bin/ptxas) + CUTE_DSL_PTXAS_VERBOSE - Set to 1 for verbose output +""" + +import os +import sys +import re +import ctypes +import subprocess +from pathlib import Path + +import cutlass + + +CUTE_DSL_PTXAS_PATH = os.environ.get("CUTE_DSL_PTXAS_PATH", None) +VERBOSE = os.environ.get("CUTE_DSL_PTXAS_VERBOSE", "0") == "1" + +_original_load_cuda_library = None +_user_wanted_ptx = False # True if user originally set CUTE_DSL_KEEP_PTX=1 + + +def _log(msg): + if VERBOSE: + print(f"[ptxas] {msg}", file=sys.stderr) + + +def _get_ptx(compiled_func) -> tuple[str, Path] | None: + """Find and read PTX file, stripping null bytes.""" + func_name = getattr(compiled_func, "function_name", None) + if not func_name: + return None + + dump_dir = os.environ.get("CUTE_DSL_DUMP_DIR", Path.cwd()) + for ptx_path in Path(dump_dir).glob(f"*{func_name}*.ptx"): + content = ptx_path.read_text().rstrip("\x00") + if ".entry " in content and content.rstrip().endswith("}"): + _log(f"Found PTX: {ptx_path}") + return content, ptx_path + return None + + +def _compile_ptx(ptx_path: Path, ptx_content: str) -> bytes: + """Compile PTX to cubin using system ptxas.""" + # Extract arch from PTX + match = re.search(r"\.target\s+(sm_\d+[a-z]?)", ptx_content) + arch = match.group(1) if match else "sm_90a" + + # Write stripped content back if needed + if ptx_path.read_text() != ptx_content: + ptx_path.write_text(ptx_content) + + # Compile + cubin_tmp = ptx_path.with_suffix(".cubin.tmp") + try: + assert CUTE_DSL_PTXAS_PATH is not None + result = subprocess.run( + [CUTE_DSL_PTXAS_PATH, f"-arch={arch}", "-O3", "-o", str(cubin_tmp), str(ptx_path)], + capture_output=True, + text=True, + ) + if result.returncode != 0: + raise RuntimeError(f"ptxas failed: {result.stderr}") + + cubin_data = cubin_tmp.read_bytes() + _log(f"Compiled {ptx_path.name} -> {len(cubin_data)} bytes ({arch})") + + # Save cubin if CUTE_DSL_KEEP_CUBIN is set + if os.environ.get("CUTE_DSL_KEEP_CUBIN", "0") == "1": + cubin_out = ptx_path.with_suffix(".cubin") + cubin_out.write_bytes(cubin_data) + _log(f"Saved: {cubin_out}") + + return cubin_data + finally: + cubin_tmp.unlink(missing_ok=True) + + +def _patched_load_cuda_library(self): + """Replacement for _load_cuda_library that uses system ptxas.""" + + result = _get_ptx(self) + if not result: + _log("PTX not found, falling back to embedded ptxas") + return _original_load_cuda_library(self) + + ptx_content, ptx_path = result + + try: + cubin = _compile_ptx(ptx_path, ptx_content) + except Exception as e: + _log(f"Compilation failed ({e}), falling back to embedded ptxas") + return _original_load_cuda_library(self) + + # Load cubin + import cuda.bindings.runtime as cuda_runtime + + err, library = cuda_runtime.cudaLibraryLoadData(cubin, None, None, 0, None, None, 0) + if err != cuda_runtime.cudaError_t.cudaSuccess: + _log(f"cudaLibraryLoadData failed ({err}), falling back to embedded ptxas") + return _original_load_cuda_library(self) + + # Register kernels on all devices + _, cuda_load_to_device = self._get_cuda_init_and_load() + lib_ptr = ctypes.c_void_p(int(library)) + dev_id = ctypes.c_int32(0) + err_val = ctypes.c_int32(0) + args = (ctypes.c_void_p * 3)( + ctypes.cast(ctypes.pointer(lib_ptr), ctypes.c_void_p), + ctypes.cast(ctypes.pointer(dev_id), ctypes.c_void_p), + ctypes.cast(ctypes.pointer(err_val), ctypes.c_void_p), + ) + + for dev in range(self.num_devices): + dev_id.value = dev + cuda_load_to_device(args) + if err_val.value != 0: + 
_log("cuda_load_to_device failed, falling back to embedded ptxas") + return _original_load_cuda_library(self) + + _log(f"Loaded kernel from {ptx_path.name}") + + # Delete PTX if user didn't originally want it kept + if not _user_wanted_ptx: + ptx_path.unlink(missing_ok=True) + + return [cuda_runtime.cudaLibrary_t(lib_ptr.value)] + + +def patch(): + """Install system ptxas hook. Call before importing cutlass.""" + global _original_load_cuda_library, _user_wanted_ptx + + assert CUTE_DSL_PTXAS_PATH is not None + if not os.path.isfile(CUTE_DSL_PTXAS_PATH) or not os.access(CUTE_DSL_PTXAS_PATH, os.X_OK): + raise RuntimeError(f"ptxas not found: {CUTE_DSL_PTXAS_PATH}") + + # Track if user originally wanted PTX kept + _user_wanted_ptx = os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1" + # os.environ['CUTE_DSL_KEEP_PTX'] = '1' + assert os.environ.get("CUTE_DSL_KEEP_PTX", "0") == "1", ( + "Require CUTE_DSL_KEEP_PTX=1 to use system's ptxas" + ) + + cls = cutlass.cutlass_dsl.cuda_jit_executor.CudaDialectJitCompiledFunction + _original_load_cuda_library = cls._load_cuda_library + cls._load_cuda_library = _patched_load_cuda_library + _log("Patch applied") + return diff --git a/sonic-moe/torch-ext/sonicmoe/quack/cute_dsl_utils.py b/sonic-moe/torch-ext/sonicmoe/quack/cute_dsl_utils.py new file mode 100644 index 00000000..9c92cf39 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/cute_dsl_utils.py @@ -0,0 +1,104 @@ +# Copyright (c) 2025, Tri Dao. + +from typing import Tuple +from functools import lru_cache +from dataclasses import dataclass, fields + +import torch + +try: + from triton.tools.disasm import extract +except ImportError: + extract = None + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Int64, Float16, BFloat16, Float32 +from cutlass.base_dsl.typing import JitArgument +from cutlass.cutlass_dsl import NumericMeta + + +StaticTypes = (cutlass.Constexpr, NumericMeta, int, bool, str, float, type(None)) + + +load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data +cute_compile_og = cute.compile + + +torch2cute_dtype_map = { + torch.float16: Float16, + torch.bfloat16: BFloat16, + torch.float32: Float32, + torch.int32: Int32, + torch.int64: Int64, +} + + +@lru_cache +def get_max_active_clusters(cluster_size): + return cutlass.utils.HardwareInfo().get_max_active_clusters(cluster_size=cluster_size) + + +@lru_cache +def get_device_capacity(device: torch.device = None) -> Tuple[int, int]: + return torch.cuda.get_device_capability(device) + + +@dataclass +class ParamsBase: + def __extract_mlir_values__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + values, self._values_pos = [], [] + for obj in non_constexpr_fields: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) + + +@dataclass +class 
ArgumentsBase(JitArgument): + def __c_pointers__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + c_ptrs = [] + for obj in non_constexpr_fields: + if hasattr(obj, "__c_pointers__"): + c_ptrs.extend(obj.__c_pointers__()) + return c_ptrs + + def __get_mlir_types__(self): + all_fields = [getattr(self, field.name) for field in fields(self)] + non_constexpr_fields = [f for f in all_fields if not isinstance(f, StaticTypes)] + types, self._values_pos = [], [] + for obj in non_constexpr_fields: + if hasattr(obj, "__get_mlir_types__"): + obj_types = obj.__get_mlir_types__() + types.extend(obj_types) + self._values_pos.append(len(obj_types)) + else: + self._values_pos.append(0) + return types + + def __new_from_mlir_values__(self, values): + all_fields = {field.name: getattr(self, field.name) for field in fields(self)} + constexpr_fields = {n: f for n, f in all_fields.items() if isinstance(f, StaticTypes)} + non_constexpr_fields = { + n: f for n, f in all_fields.items() if not isinstance(f, StaticTypes) + } + for (name, field), n_items in zip(non_constexpr_fields.items(), self._values_pos): + non_constexpr_fields[name] = cutlass.new_from_mlir_values(field, values[:n_items]) + values = values[n_items:] + return self.__class__(**non_constexpr_fields, **constexpr_fields) diff --git a/sonic-moe/torch-ext/sonicmoe/quack/fast_math.py b/sonic-moe/torch-ext/sonicmoe/quack/fast_math.py new file mode 100644 index 00000000..e581084c --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/fast_math.py @@ -0,0 +1,80 @@ +# Copyright (c) 2025, Tri Dao. + +from typing import Tuple +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Uint32 +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm + +from .cute_dsl_utils import ParamsBase + + +@cute.jit +def clz(x: Int32) -> Int32: + # for i in cutlass.range_constexpr(32): + # if (1 << (31 - i)) & x: + # return Int32(i) + # return Int32(32) + # Early exit is not supported yet + res = Int32(32) + done = False + for i in cutlass.range(32): + if ((1 << (31 - i)) & x) and not done: + res = Int32(i) + done = True + return res + + +def find_log2(x: Int32) -> Int32: + a: Int32 = Int32(31 - clz(x)) + return a + ((x & (x - 1)) != 0) # Round up, add 1 if not a power of 2. + + +@dsl_user_op +def umulhi(a: Int32, b: Int32, *, loc=None, ip=None) -> Uint32: + return Uint32( + llvm.inline_asm( + T.i32(), + [Int32(a).ir_value(loc=loc, ip=ip), Int32(b).ir_value(loc=loc, ip=ip)], + "mul.hi.u32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dataclass +class FastDivmod(ParamsBase): + divisor: Int32 + multiplier: Uint32 + shift_right: Uint32 + + # called by host + @staticmethod + def create(divisor: Int32) -> "FastDivmod": + """Construct the FastDivmod object, in host code. + This precomputes some values based on the divisor and is computationally expensive. 
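+
+ Worked example (editorial): divisor=3 gives p=33, multiplier=0xAAAAAAAB
+ (= ceil(2**33 / 3)) and shift_right=1, so div(x) = umulhi(x, 0xAAAAAAAB) >> 1;
+ e.g. div(9) == 3 and div(8) == 2.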
+ """ + p = Uint32(31 + find_log2(divisor)) + divisor_u32 = Uint32(divisor) + multiplier = Uint32(((cutlass.Uint64(1) << p) + divisor_u32 - 1) // divisor_u32) + shift_right = Uint32(p - 32) + return FastDivmod(divisor, multiplier, shift_right) + + @cute.jit + def div(self, dividend: Int32) -> Int32: + return ( + Int32(umulhi(dividend, self.multiplier) >> self.shift_right) + if self.divisor != 1 + else dividend + ) + + def divmod(self, dividend: Int32) -> Tuple[Int32, Int32]: + quotient = self.div(dividend) + remainder = dividend - quotient * self.divisor + return quotient, remainder diff --git a/sonic-moe/torch-ext/sonicmoe/quack/gemm.py b/sonic-moe/torch-ext/sonicmoe/quack/gemm.py new file mode 100644 index 00000000..d3d3f1af --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/gemm.py @@ -0,0 +1,194 @@ +from typing import Optional +from functools import partial + +from torch import Tensor + +import cutlass.cute as cute +import cutlass.torch as cutlass_torch +from cutlass import Float32 +from cutlass.cute.runtime import from_dlpack, make_ptr + +from .cute_dsl_utils import get_device_capacity, get_max_active_clusters +from .gemm_wrapper_utils import GemmWrapperBase +from .gemm_default_epi import GemmDefaultSm90, GemmDefaultSm100 + + +def gemm( + # (l, m, k) or (total_m, k) if varlen_m or (m, total_k) if varlen_k or (whatever, k) if gather_A_varlen_m or (m, whatever) if gather_A_varlen_k + A: Tensor, + B: Tensor, # (l, n, k) or (n, total_k) if varlen_k + D: Tensor, # (l, m, n) or (total_m, n) if varlen_m + C: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m + tile_count_semaphore: Optional[Tensor], # (1,) + tile_M: int, + tile_N: int, + cluster_M: int, + cluster_N: int, + pingpong: bool = False, + persistent: bool = True, + max_swizzle_size: int = 8, + rowvec_bias: Optional[Tensor] = None, # (l, n) + colvec_bias: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m + alpha: float | Tensor = 1.0, + beta: float | Tensor = 1.0, + cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length + cu_seqlens_k: Optional[Tensor] = None, # (l+1,) cumulative sum of k values for variable length + A_idx: Optional[Tensor] = None, # (total_m,) or (total_k,) indices for gather_A when varlen + batch_idx_permute: Optional[Tensor] = None, # (l,) permutation of batch indices for scheduler + add_to_output: bool = False, +) -> None: + varlen = cu_seqlens_m is not None or cu_seqlens_k is not None + assert not (cu_seqlens_m is not None and cu_seqlens_k is not None), ( + "Only one of cu_seqlens_m and cu_seqlens_k can be specified" + ) + gather_A = A_idx is not None + if gather_A: + assert varlen, "gather_A requires varlen (cu_seqlens_m or cu_seqlens_k must be specified)" + assert cluster_N == 1, "gather_A requires cluster_N=1" + if varlen: + assert persistent, "varlen requires persistent=True" + if add_to_output: + assert cu_seqlens_m is None, "Add to output not supported with varlen_m" + if cu_seqlens_m is not None: + assert A.stride(-1) == 1, "varlen_m requires A to be k-major" + assert D.stride(-1) == 1, "varlen_m requires D to be n-major" + if cu_seqlens_k is not None: + assert A.stride(-2) == 1, "varlen_k requires A to be m-major" + assert B.stride(-2) == 1, "varlen_k requires B to be n-major" + + L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors( + A, B, D, C, cu_seqlens_m=cu_seqlens_m, cu_seqlens_k=cu_seqlens_k, A_idx=A_idx + ) + GemmWrapperBase.permute_tensors( + tensor_infos, varlen_m=cu_seqlens_m is not None, varlen_k=cu_seqlens_k 
is not None + ) + GemmWrapperBase.extract_dtypes(tensor_infos) + major_configs = { + "A": ("m", "k", "l"), + "B": ("n", "k", "l"), + "D": ("m", "n", "l"), + "C": ("m", "n", "l"), + } + GemmWrapperBase.determine_major_orders(tensor_infos, major_configs) + + device_capacity = get_device_capacity(A.device) + assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported" + GemmCls = GemmDefaultSm100 if device_capacity[0] > 9 else GemmDefaultSm90 + + acc_dtype = Float32 + tile_shape_mn = (tile_M, tile_N) + cluster_shape_mnk = (cluster_M, cluster_N, 1) + if not GemmCls.is_valid_dtypes( + tensor_infos["A"].dtype, + tensor_infos["B"].dtype, + acc_dtype, + tensor_infos["D"].dtype, + tensor_infos["A"].major, + tensor_infos["B"].major, + ): + raise TypeError("Skipping due to unsupported combination of types and majors") + + max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0 + GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs) + + def scalar_arg(scalar: float | Tensor): + if isinstance(scalar, float): + return Float32(scalar) if scalar != 1.0 else None + else: + assert isinstance(scalar, Tensor) + return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) + + epi_args = GemmCls.EpilogueArguments( + scalar_arg(alpha), + scalar_arg(beta), + mRowVecBroadcast=from_dlpack(rowvec_bias.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=1 + ) + if rowvec_bias is not None + else None, + mColVecBroadcast=from_dlpack(colvec_bias.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=1 if cu_seqlens_m is None else 0 + ) + if colvec_bias is not None + else None, + add_to_output=add_to_output, + ) + scheduler_args = GemmWrapperBase.create_scheduler_args( + max_active_clusters, + tile_count_semaphore, + batch_idx_permute, + max_swizzle_size, + ) + + # Create varlen arguments if needed (assumes persistent=True when varlen) + varlen_args = GemmWrapperBase.create_varlen_args( + cu_seqlens_m, + cu_seqlens_k, + A_idx, + max_active_clusters, + cluster_shape_mnk, + tensor_infos, + GemmCls.num_epi_tensormaps, + pingpong, + ) + + current_stream = cutlass_torch.current_stream() + compile_key = GemmWrapperBase.get_compile_key( + tensor_infos, + None, # activation + tile_shape_mn, + cluster_shape_mnk, + pingpong, + persistent, + tile_count_semaphore is not None, + device_capacity, + # Technically we don't need to recompile for different max_swizzle_size, but currently + # not recompiling will skew the autotuning results due to power throttling. + # Effectively we're recompiling as a way to pause between benchmarks during autotuning. 
+ max_swizzle_size, + rowvec_bias.dtype if rowvec_bias is not None else None, + colvec_bias.dtype if colvec_bias is not None else None, + 2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0), + 2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0), + add_to_output, + cu_seqlens_m is not None, + cu_seqlens_k is not None, + gather_A, + batch_idx_permute is not None, + key_tensor_names=("A", "B", "D", "C"), + ) + cache = gemm.compile_cache + if compile_key not in cache: + if device_capacity[0] == 9: + GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent) + gemm_obj = GemmCls( + acc_dtype, + tensor_infos["A"].dtype, + tile_shape_mn, + cluster_shape_mnk, + gather_A=gather_A, + ) + cache[compile_key] = cute.compile( + gemm_obj, + tensor_infos["A"].cute_tensor, + tensor_infos["B"].cute_tensor, + tensor_infos["D"].cute_tensor, + tensor_infos["C"].cute_tensor, + epi_args, + scheduler_args, + varlen_args, + current_stream, + ) + cache[compile_key]( + tensor_infos["A"].cute_tensor, + tensor_infos["B"].cute_tensor, + tensor_infos["D"].cute_tensor, + tensor_infos["C"].cute_tensor, + epi_args, + scheduler_args, + varlen_args, + current_stream, + ) + + +gemm.compile_cache = {} diff --git a/sonic-moe/torch-ext/sonicmoe/quack/gemm_act.py b/sonic-moe/torch-ext/sonicmoe/quack/gemm_act.py new file mode 100644 index 00000000..efc2d8c0 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/gemm_act.py @@ -0,0 +1,510 @@ +# Copyright (c) 2025, Wentao Guo, Tri Dao. +from typing import Tuple, Optional, Callable +from functools import partial +from dataclasses import dataclass + +from torch import Tensor + +import cutlass +import cutlass.cute as cute +import cutlass.utils.hopper_helpers as sm90_utils_og +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass import Int32, Float32, Boolean, const_expr +from cutlass.cutlass_dsl import if_generate +import cutlass.torch as cutlass_torch +from cutlass.cute.runtime import from_dlpack + +from .cute_dsl_utils import ArgumentsBase, ParamsBase +from .varlen_utils import VarlenManager +from .gemm_sm90 import GemmSm90 +from .gemm_sm100 import GemmSm100 +from .gemm_default_epi import GemmDefaultEpiMixin +from .cute_dsl_utils import get_device_capacity, get_max_active_clusters +from .gemm_wrapper_utils import GemmWrapperBase +from . import sm90_utils as sm90_utils +from . import copy_utils as copy_utils +from . 
import activation + + +class GemmActMixin(GemmDefaultEpiMixin): + num_epi_tensormaps: int = 1 + + @dataclass + class EpilogueArguments(ArgumentsBase): + mPostAct: cute.Tensor + act_fn: cutlass.Constexpr[Optional[Callable]] = None + alpha: Optional[Float32 | cute.Tensor] = None + beta: Optional[Float32 | cute.Tensor] = None + mRowVecBroadcast: Optional[cute.Tensor] = None + mColVecBroadcast: Optional[cute.Tensor] = None + + @dataclass + class EpilogueParams(ParamsBase): + tma_atom_postact: cute.CopyAtom + mPostAct_mnl: cute.Tensor + epi_postact_smem_layout_staged: cute.ComposedLayout + epi_tile_postact: cute.Tile + act_fn: cutlass.Constexpr[Optional[Callable]] = None + alpha: Optional[Float32 | cute.Tensor] = None + beta: Optional[Float32 | cute.Tensor] = None + mRowVecBroadcast: Optional[cute.Tensor] = None + mColVecBroadcast: Optional[cute.Tensor] = None + + def epi_to_underlying_arguments( + self, args: EpilogueArguments, *, loc=None, ip=None + ) -> EpilogueParams: + self.postact_dtype = args.mPostAct.element_type + self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct) + + self.cta_tile_shape_postact_mn = self.cta_tile_shape_mnk[:2] + epi_tile_postact = self.epi_tile + utils_cls = sm100_utils if self.arch == 100 else sm90_utils + epi_postact_smem_layout_staged = utils_cls.make_smem_layout_epi( + self.postact_dtype, self.postact_layout, epi_tile_postact, self.epi_stage + ) + tma_atom_postact, tma_tensor_postact = self._make_tma_epi_atoms_and_tensors( + args.mPostAct, + epi_postact_smem_layout_staged, + epi_tile_postact, + op_type="store", + ) + # Assume all strides are divisible by 32 bits except the last stride + new_stride = lambda t: tuple( + cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s + for s in t.stride + ) + mRowVecBroadcast, mColVecBroadcast = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + if t is not None + else None + for t in (args.mRowVecBroadcast, args.mColVecBroadcast) + ] + return self.EpilogueParams( + tma_atom_postact, + tma_tensor_postact, + epi_postact_smem_layout_staged, + epi_tile_postact, + args.act_fn, + alpha=args.alpha, + beta=args.beta, + mRowVecBroadcast=mRowVecBroadcast, + mColVecBroadcast=mColVecBroadcast, + ) + + def epi_get_tma_atoms( + self, params: EpilogueParams, *, loc=None, ip=None + ) -> list[cute.CopyAtom]: + return [params.tma_atom_postact] + + def epi_get_tensormap_update_shapes_orders( + self, + params: EpilogueParams, + cu_seqlens_m: Optional[cute.Tensor], + batch_idx: Int32, + *, + loc=None, + ip=None, + ) -> tuple[list[Int32], list[int]]: + shapes = [cu_seqlens_m[batch_idx + 1] if cu_seqlens_m is not None else None] + orders = [0 if const_expr(self.postact_layout.is_m_major_c()) else 1] + return shapes, orders + + @staticmethod + def epi_smem_bytes_per_stage( + args: EpilogueArguments, cta_tile_shape_mnk: Tuple[int, int, int], epi_tile: cute.Tile + ) -> int: + postact_dtype = args.mPostAct.element_type + postact_bytes_per_stage = cute.size(cute.shape(epi_tile)) * (postact_dtype.width // 8) + rowvec_colvec_bytes = GemmDefaultEpiMixin.epi_smem_bytes_per_stage( + args, cta_tile_shape_mnk, epi_tile + ) + return postact_bytes_per_stage + rowvec_colvec_bytes + + def epi_get_smem_struct(self, params: EpilogueParams): + row_vec_smem_size = 0 if params.mRowVecBroadcast is None else self.cta_tile_shape_mnk[1] + col_vec_smem_size = 0 if params.mColVecBroadcast is None else self.cta_tile_shape_mnk[0] + row_vec_dtype = ( + params.mRowVecBroadcast.element_type if 
params.mRowVecBroadcast is not None else Float32 + ) + col_vec_dtype = ( + params.mColVecBroadcast.element_type if params.mColVecBroadcast is not None else Float32 + ) + + @cute.struct + class EpiSharedStorage: + sRowVec: cute.struct.Align[cute.struct.MemRange[row_vec_dtype, row_vec_smem_size], 16] + sColVec: cute.struct.Align[cute.struct.MemRange[col_vec_dtype, col_vec_smem_size], 16] + sPostAct: cute.struct.Align[ + cute.struct.MemRange[ + self.postact_dtype, cute.cosize(params.epi_postact_smem_layout_staged) + ], + self.buffer_align_bytes, + ] + + return EpiSharedStorage + + def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]: + sRowVec, sColVec = super().epi_get_smem_tensors(params, storage) + sPostAct = storage.epi.sPostAct.get_tensor( + params.epi_postact_smem_layout_staged.outer, + swizzle=params.epi_postact_smem_layout_staged.inner, + ) + return (sRowVec, sColVec, sPostAct) + + @cute.jit + def epilogue( + self, + params: EpilogueParams, + epi_smem_tensors: Tuple[cute.Tensor, ...], + tma_desc_epi_ptrs: list[Optional[cute.Pointer]], + epi_pipeline: cutlass.pipeline.PipelineAsync, + epi_store_pipeline: cutlass.pipeline.PipelineAsync, + epi_read_state: cutlass.pipeline.PipelineState, + epi_producer_state: cutlass.pipeline.PipelineState, + epi_tile: cute.Tile, + load_acc_subtile: Callable, + tRS_rD: cute.Tensor, + tRS_rC: Optional[cute.Tensor], + tiled_copy_t2r: Optional[cute.TiledCopy], # Only for Sm100 + tiled_copy_r2s: cute.TiledCopy, + tRS_sD: cute.Tensor, + tiled_copy_s2r: Optional[cute.TiledCopy], + tSR_rC: Optional[cute.Tensor], + tSR_sC: Optional[cute.Tensor], + copy_D: Optional[Callable], + copy_C: Optional[Callable], + tile_coord_mnkl: cute.Coord, + varlen_manager: VarlenManager, + epilogue_barrier: cutlass.pipeline.NamedBarrier, + tile_scheduler, + tidx: Int32, + is_tma_warp: Boolean, + ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]: + has_C = const_expr(tRS_rC is not None) + has_D = const_expr(copy_D is not None) + + tma_atom_postact = params.tma_atom_postact + mPostAct_mnl = params.mPostAct_mnl + sRowVec, sColVec, sPostAct = epi_smem_tensors + get_smem_store_op = ( + partial(sm100_utils.get_smem_store_op, tiled_tmem_load=tiled_copy_t2r) + if self.arch == 100 + else sm90_utils_og.sm90_get_smem_store_op + ) + copy_atom_postact_r2s = get_smem_store_op( + self.postact_layout, self.postact_dtype, self.acc_dtype + ) + # tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma) + # tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom) + tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_r2s) + tRS_sPostAct = tiled_copy_postact_r2s.get_slice(tidx).partition_D(sPostAct) + (tma_desc_postact_ptr,) = tma_desc_epi_ptrs + batch_idx = tile_coord_mnkl[3] + copy_postact, _, _ = self.epilog_gmem_copy_and_partition( + tma_atom_postact, + varlen_manager.offset_batch_epi(mPostAct_mnl, batch_idx), + self.cta_tile_shape_postact_mn, + params.epi_tile_postact, + sPostAct, + tile_coord_mnkl, + tma_desc_ptr=tma_desc_postact_ptr, + ) + + # We iterate over epi tiles in the N dimension first before the M dimension + epi_tile_shape = cute.zipped_divide( + cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile + ).shape[1] + epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1)) + epi_tile_num = cute.size(epi_tile_shape) + num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num + + epi_tensors = self.epi_begin( + params, + 
epi_smem_tensors, + epi_tile, + tiled_copy_t2r, + tiled_copy_r2s, + tile_coord_mnkl, + varlen_manager, + epilogue_barrier, + tidx, + ) + + if const_expr(copy_C is not None): + for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1): + gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx) + if is_tma_warp: + epi_pipeline.producer_acquire(epi_producer_state) + copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state) + epi_pipeline.producer_commit(epi_producer_state) + epi_producer_state.advance() + + def tma_store_fn(src_idx, dst_idx): + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + epilogue_barrier.arrive_and_wait() + # Copy from shared memory to global memory + if is_tma_warp: + if const_expr(has_D): + copy_D(src_idx=src_idx, dst_idx=dst_idx) + copy_postact(src_idx=src_idx, dst_idx=dst_idx) + # Can't use if statement here, epi_store_pipeline object isn't captured somehow + if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit()) + if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire()) + epilogue_barrier.arrive_and_wait() + + delay_tma_store = True + + src_idx_prev, dst_idx_prev = None, None + for epi_idx in cutlass.range_constexpr(epi_tile_num): + # The global memory coordinate for the current epi tile + gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) + # Copy from acc to D registers + load_acc_subtile(tRS_rD, epi_idx) + epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord) + if const_expr(has_C): + epi_pipeline.consumer_wait(epi_read_state) + cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC) + # Fence to make sure shared memory read is visible to TMA load + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + cute.arch.sync_warp() + with cute.arch.elect_one(): + epi_pipeline.consumer_release(epi_read_state) + epi_read_state.advance() + if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num): + gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage) + if is_tma_warp: + epi_pipeline.producer_acquire(epi_producer_state) + copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state) + epi_pipeline.producer_commit(epi_producer_state) + epi_producer_state.advance() + tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC) + epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage + if const_expr(delay_tma_store): + if const_expr(epi_idx > 0): + tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev) + src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord + # Copy from D registers to shared memory + if const_expr(has_D): + copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer]) + cute.copy( + tiled_copy_postact_r2s, + tiled_copy_postact_r2s.retile(tRS_rPostAct), + tRS_sPostAct[None, None, None, epi_buffer], + ) + if const_expr(not delay_tma_store): + tma_store_fn(src_idx=epi_buffer, dst_idx=gmem_coord) + + if const_expr(delay_tma_store): + tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev) + + self.epi_end( + params, + epi_tensors, + epi_tile, + tiled_copy_t2r, + tiled_copy_r2s, + tile_coord_mnkl, + varlen_manager, + tidx, + ) + + return epi_read_state, epi_producer_state + + @cute.jit + def epi_visit_subtile( + self, + params: EpilogueParams, + epi_loop_tensors: Tuple[cute.Tensor, ...], + tRS_rD: 
cute.Tensor, + tRS_rC: Optional[cute.Tensor] = None, + ) -> Optional[cute.Tensor]: + GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC) + # Apply activation function if provided + # If we don't have .shape here, the compiler generates local stores and loads + if const_expr(params.act_fn is not None): + tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype) + if const_expr(self.arch < 100): + for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True): + tRS_rPostAct[i] = params.act_fn(tRS_rD[i]) + else: + for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True): + tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1] = params.act_fn( + (tRS_rD[2 * i], tRS_rD[2 * i + 1]) + ) + else: + tRS_rPostAct = tRS_rD + # Type conversion + tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype) + tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype)) + return tRS_rPostAct_out + + +class GemmActSm90(GemmActMixin, GemmSm90): + pass + + +class GemmActSm100(GemmActMixin, GemmSm100): + pass + + +act_fn_map = { + None: None, + "relu": activation.relu, + "relu_sq": activation.relu_sq, + "gelu_tanh_approx": activation.gelu_tanh_approx, +} + + +def gemm_act( + A: Tensor, # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m + B: Tensor, # (l, n, k) + D: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m + C: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m + PostAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m + tile_count_semaphore: Optional[Tensor], # (1,) + activation: Optional[str], + tile_M: int, + tile_N: int, + cluster_M: int, + cluster_N: int, + pingpong: bool = False, + persistent: bool = True, + max_swizzle_size: int = 8, + rowvec_bias: Optional[Tensor] = None, # (l, n) + colvec_bias: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m + cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length + A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m +) -> None: + if cu_seqlens_m is not None: + assert persistent, "varlen_m requires persistent=True" + assert A.stride(-1) == 1, "varlen_m requires A to be k-major" + if D is not None: + assert D.stride(-1) == 1, "varlen_m requires D to be n-major" + assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major" + gather_A = A_idx is not None + if gather_A: + assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)" + assert cluster_N == 1, "gather_A requires cluster_N=1" + assert activation in act_fn_map, f"Unsupported activation {activation}" + + L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors( + A, B, D, C, additional_tensors={"PostAct": PostAct}, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx + ) + GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None) + GemmWrapperBase.extract_dtypes(tensor_infos) + major_configs = { + "A": ("m", "k", "l"), + "B": ("n", "k", "l"), + "D": ("m", "n", "l"), + "C": ("m", "n", "l"), + "PostAct": ("m", "n", "l"), + } + GemmWrapperBase.determine_major_orders(tensor_infos, major_configs) + + device_capacity = get_device_capacity(A.device) + assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported" + GemmCls = GemmActSm100 if device_capacity[0] > 9 else GemmActSm90 + + acc_dtype = Float32 + tile_shape_mn = (tile_M, tile_N) + cluster_shape_mnk = (cluster_M, cluster_N, 1) + if not GemmCls.is_valid_dtypes( + 
tensor_infos["A"].dtype, + tensor_infos["B"].dtype, + acc_dtype, + tensor_infos["D"].dtype, + tensor_infos["A"].major, + tensor_infos["B"].major, + ): + raise TypeError("Skipping due to unsupported combination of types and majors") + + max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0 + GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs) + act_fn = act_fn_map[activation] + epi_args = GemmCls.EpilogueArguments( + tensor_infos["PostAct"].cute_tensor, + act_fn, + mRowVecBroadcast=from_dlpack(rowvec_bias.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=1 + ) + if rowvec_bias is not None + else None, + mColVecBroadcast=from_dlpack(colvec_bias.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=1 if cu_seqlens_m is None else 0 + ) + if colvec_bias is not None + else None, + ) + scheduler_args = GemmWrapperBase.create_scheduler_args( + max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size + ) + + # Create varlen arguments if needed (assumes persistent=True when varlen_m) + varlen_args = GemmWrapperBase.create_varlen_args( + cu_seqlens_m, + None, # cu_seqlens_k + A_idx, + max_active_clusters, + cluster_shape_mnk, + tensor_infos, + GemmCls.num_epi_tensormaps, + pingpong, + ) + + current_stream = cutlass_torch.current_stream() + compile_key = GemmWrapperBase.get_compile_key( + tensor_infos, + activation, + tile_shape_mn, + cluster_shape_mnk, + pingpong, + persistent, + tile_count_semaphore is not None, + device_capacity, + max_swizzle_size, + rowvec_bias.dtype if rowvec_bias is not None else None, + colvec_bias.dtype if colvec_bias is not None else None, + cu_seqlens_m is not None, + A_idx is not None, + key_tensor_names=("A", "B", "D", "PostAct", "C"), + ) + cache = gemm_act.compile_cache + if compile_key not in cache: + if device_capacity[0] == 9: + GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent) + gemm_obj = GemmCls( + acc_dtype, + tensor_infos["A"].dtype, + tile_shape_mn, + cluster_shape_mnk, + gather_A=gather_A, + ) + cache[compile_key] = cute.compile( + gemm_obj, + tensor_infos["A"].cute_tensor, + tensor_infos["B"].cute_tensor, + tensor_infos["D"].cute_tensor, + tensor_infos["C"].cute_tensor, + epi_args, + scheduler_args, + varlen_args, + current_stream, + ) + cache[compile_key]( + tensor_infos["A"].cute_tensor, + tensor_infos["B"].cute_tensor, + tensor_infos["D"].cute_tensor, + tensor_infos["C"].cute_tensor, + epi_args, + scheduler_args, + varlen_args, + current_stream, + ) + + +gemm_act.compile_cache = {} diff --git a/sonic-moe/torch-ext/sonicmoe/quack/gemm_config.py b/sonic-moe/torch-ext/sonicmoe/quack/gemm_config.py new file mode 100644 index 00000000..fa19a28b --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/gemm_config.py @@ -0,0 +1,95 @@ +# Copyright (C) 2025, Fri Dao. 
+import itertools +from typing import Optional, List, Literal +from functools import partial +from dataclasses import dataclass + + +@dataclass(frozen=True) +class GemmConfig: + tile_m: int = 128 + tile_n: int = 192 + pingpong: bool = True + cluster_m: int = 2 + cluster_n: int = 1 + swap_ab: bool = False + # raster_order: int = 1 + max_swizzle_size: int = 8 + + +def get_all_configs( + device_capacity: Literal[9, 10] = 9, + epilogue: Optional[str] = None, + tune_coop: bool = True, + # tune_raster_order=True, +) -> List[GemmConfig]: + assert device_capacity in [9, 10] + if device_capacity == 9: + tile_n_vals = [128, 144, 160, 176, 192, 208] + tile_mn_coop_vals = [(256, tile_n) for tile_n in tile_n_vals] + [ + (128, 224), + (128, 256), + # (192, 256), # Getting IOT instruction (core dumped) in the bwd + ] + tile_mn_pingpong_vals = [(128, tile_n) for tile_n in tile_n_vals] + [(192, 128)] + if epilogue in ["gated"]: + tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if n % 32 == 0 and m != 192] + tile_mn_pingpong_vals = [(m, n) for m, n in tile_mn_pingpong_vals if n % 32 == 0] + elif epilogue in ["lse"]: + tile_mn_coop_vals = [(m, n) for m, n in tile_mn_coop_vals if m != 192] + tile_mn_vals = [] + if tune_coop: + tile_mn_vals += [(m, n, False) for m, n in tile_mn_coop_vals] + tile_mn_vals += [(m, n, True) for m, n in tile_mn_pingpong_vals] + cluster = [(1, 2), (2, 1)] + # cluster = [(1, 1), (1, 2), (2, 1)] + if epilogue in ["lse"]: + cluster = [(1, 2), (2, 1)] + swap_ab_vals = [False, True] + if epilogue in ["lse", "gated"]: + swap_ab_vals = [False] + # raster_swizzle = ( + # [(0, 1)] + # if not tune_raster_order + # else [(1, 1), (1, 2), (1, 4), (1, 8), (2, 1), (2, 2), (2, 4), (2, 8)] + # ) + return [ + GemmConfig( + tile_m=tile_m, + tile_n=tile_n, + pingpong=pingpong, + cluster_m=cluster_m, + cluster_n=cluster_n, + swap_ab=swap_ab, + # raster_order=raster_order, + # max_swizzle_size=max_swizzle_size, + ) + for (tile_m, tile_n, pingpong), (cluster_m, cluster_n), swap_ab in itertools.product( + tile_mn_vals, + cluster, + swap_ab_vals, + # raster_swizzle, + ) + ] + elif device_capacity == 10: + tile_n_vals = [128, 160, 192, 224, 256] + tile_n_64_vals = [128, 192, 256] + tile_mn_cluster_vals = ( + [(128, tile_n, (1, 2)) for tile_n in tile_n_vals] + # + [(128, tile_n, (2, 1)) for tile_n in tile_n_64_vals] + + [(128, tile_n, (2, 1)) for tile_n in tile_n_vals] + + [(256, tile_n, (2, 1)) for tile_n in tile_n_vals] + ) + swap_ab_vals = [False, True] + if epilogue in ["lse", "gated"]: + swap_ab_vals = [False] + max_swizzle_size_vals = [4, 8, 16] + GemmConfigCls = partial(GemmConfig, pingpong=False) # There's no pingpong on Sm100 + return [ + GemmConfigCls( + tile_m=m, tile_n=n, cluster_m=cm, cluster_n=cn, swap_ab=sab, max_swizzle_size=ms + ) + for (m, n, (cm, cn)), sab, ms in itertools.product( + tile_mn_cluster_vals, swap_ab_vals, max_swizzle_size_vals + ) + ] diff --git a/sonic-moe/torch-ext/sonicmoe/quack/gemm_dact.py b/sonic-moe/torch-ext/sonicmoe/quack/gemm_dact.py new file mode 100644 index 00000000..a194933a --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/gemm_dact.py @@ -0,0 +1,215 @@ +# Copyright (c) 2025, Tri Dao. 
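+# Editorial note: this module mirrors gemm_act.py, but the epilogue applies the
+# *backward* of the activation: act_bwd_fn takes (preact, dout) and returns
+# (dpreact, postact), so a single kernel writes both the activation-input gradient
+# (into Out) and the recomputed activation output (into PostAct).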
+from typing import Optional, Tuple +from functools import partial + +from torch import Tensor + +import cutlass +import cutlass.cute as cute +from cutlass import Float32, const_expr +import cutlass.torch as cutlass_torch + +from .gemm_sm90 import GemmSm90 +from .gemm_sm100 import GemmSm100 +from .gemm_default_epi import GemmDefaultEpiMixin +from .gemm_act import GemmActMixin +from .cute_dsl_utils import get_device_capacity, get_max_active_clusters +from .gemm_wrapper_utils import GemmWrapperBase +from . import activation + + +class GemmDActMixin(GemmActMixin): + # Different from GemmActSm90, here act_bwd_fn must take in 2 arguments (x, dout) + # and return 2 arguments (dx, out) + EpilogueArguments = GemmActMixin.EpilogueArguments + EpilogueParams = GemmActMixin.EpilogueParams + + @cute.jit + def epi_visit_subtile( + self, + params: EpilogueParams, + epi_loop_tensors: Tuple[cute.Tensor, ...], + tRS_rD: cute.Tensor, + tRS_rC: Optional[cute.Tensor] = None, + ) -> Optional[cute.Tensor]: + assert tRS_rC is not None + # We don't add C to the accumulator + GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC=None) + tRS_rC_acc = cute.make_fragment_like(tRS_rC, self.acc_dtype) + tRS_rC_acc.store(tRS_rC.load().to(self.acc_dtype)) + # If we don't have .shape here, the compiler generates local stores and loads + if const_expr(params.act_fn is not None): + tRS_rPostAct = cute.make_fragment(tRS_rD.layout.shape, self.acc_dtype) + if const_expr(self.arch < 100): + for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True): + tRS_rD[i], tRS_rPostAct[i] = params.act_fn(tRS_rC_acc[i], tRS_rD[i]) + else: + for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True): + ( + (tRS_rD[2 * i], tRS_rD[2 * i + 1]), + (tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1]), + ) = params.act_fn( + (tRS_rC_acc[2 * i], tRS_rC_acc[2 * i + 1]), + (tRS_rD[2 * i], tRS_rD[2 * i + 1]), + ) + else: + tRS_rPostAct = tRS_rC_acc + # Type conversion + tRS_rPostAct_out = cute.make_fragment_like(tRS_rPostAct, self.postact_dtype) + tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype)) + return tRS_rPostAct_out + + +class GemmDActSm90(GemmDActMixin, GemmSm90): + pass + + +class GemmDActSm100(GemmDActMixin, GemmSm100): + pass + + +dact_fn_map = { + None: None, + "relu": activation.drelu, + "relu_sq": activation.drelu_sq, + "gelu_tanh_approx": activation.dgelu_tanh_approx, +} + + +def gemm_dact( + A: Tensor, # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m + B: Tensor, # (l, n, k) + Out: Tensor, # (l, m, n) or (total_m, n) if varlen_m + PreAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m + PostAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m + tile_count_semaphore: Optional[Tensor], # (1,) + activation: Optional[str], + tile_M: int, + tile_N: int, + cluster_M: int, + cluster_N: int, + pingpong: bool = True, + persistent: bool = True, + max_swizzle_size: int = 8, + cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length + A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m +) -> None: + if cu_seqlens_m is not None: + assert persistent, "varlen_m requires persistent=True" + assert A.stride(-1) == 1, "varlen_m requires A to be k-major" + assert Out.stride(-1) == 1, "varlen_m requires Out to be n-major" + assert PreAct.stride(-1) == 1, "varlen_m requires PreAct to be n-major" + assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major" + gather_A = A_idx is not None + if 
gather_A: + assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)" + assert cluster_N == 1, "gather_A requires cluster_N=1" + assert activation in dact_fn_map, f"Unsupported activation {activation}" + + L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors( + A, + B, + Out, + PreAct, + additional_tensors={"PostAct": PostAct}, + cu_seqlens_m=cu_seqlens_m, + A_idx=A_idx, + ) + GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None) + GemmWrapperBase.extract_dtypes(tensor_infos) + major_configs = { + "A": ("m", "k", "l"), + "B": ("n", "k", "l"), + "D": ("m", "n", "l"), + "C": ("m", "n", "l"), + "PostAct": ("m", "n", "l"), + } + GemmWrapperBase.determine_major_orders(tensor_infos, major_configs) + + device_capacity = get_device_capacity(A.device) + assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported" + GemmCls = GemmDActSm100 if device_capacity[0] > 9 else GemmDActSm90 + + acc_dtype = Float32 + tile_shape_mn = (tile_M, tile_N) + cluster_shape_mnk = (cluster_M, cluster_N, 1) + if not GemmCls.is_valid_dtypes( + tensor_infos["A"].dtype, + tensor_infos["B"].dtype, + acc_dtype, + tensor_infos["D"].dtype, + tensor_infos["A"].major, + tensor_infos["B"].major, + ): + raise TypeError("Skipping due to unsupported combination of types and majors") + + max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0 + GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs) + act_fn = dact_fn_map[activation] + epi_args = GemmCls.EpilogueArguments(tensor_infos["PostAct"].cute_tensor, act_fn) + scheduler_args = GemmWrapperBase.create_scheduler_args( + max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size + ) + + # Create varlen arguments if needed (assumes persistent=True when varlen_m) + varlen_args = GemmWrapperBase.create_varlen_args( + cu_seqlens_m, + None, # cu_seqlens_k + A_idx, + max_active_clusters, + cluster_shape_mnk, + tensor_infos, + GemmCls.num_epi_tensormaps, + pingpong, + ) + + current_stream = cutlass_torch.current_stream() + compile_key = GemmWrapperBase.get_compile_key( + tensor_infos, + activation, + tile_shape_mn, + cluster_shape_mnk, + pingpong, + persistent, + tile_count_semaphore is not None, + device_capacity, + max_swizzle_size, + cu_seqlens_m is not None, + A_idx is not None, + key_tensor_names=("A", "B", "D", "PostAct", "C"), + ) + cache = gemm_dact.compile_cache + if compile_key not in cache: + if device_capacity[0] == 9: + GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent) + gemm = GemmCls( + acc_dtype, + tensor_infos["A"].dtype, + tile_shape_mn, + cluster_shape_mnk, + gather_A=gather_A, + ) + cache[compile_key] = cute.compile( + gemm, + tensor_infos["A"].cute_tensor, + tensor_infos["B"].cute_tensor, + tensor_infos["D"].cute_tensor, + tensor_infos["C"].cute_tensor, + epi_args, + scheduler_args, + varlen_args, + current_stream, + ) + cache[compile_key]( + tensor_infos["A"].cute_tensor, + tensor_infos["B"].cute_tensor, + tensor_infos["D"].cute_tensor, + tensor_infos["C"].cute_tensor, + epi_args, + scheduler_args, + varlen_args, + current_stream, + ) + + +gemm_dact.compile_cache = {} diff --git a/sonic-moe/torch-ext/sonicmoe/quack/gemm_default_epi.py b/sonic-moe/torch-ext/sonicmoe/quack/gemm_default_epi.py new file mode 100644 index 00000000..9d22e4e8 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/gemm_default_epi.py @@ -0,0 +1,259 @@ +# Copyright (c) 2025, Wentao Guo, Tri Dao. 
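+# Editorial note: the default epilogue computed here is, roughly,
+#     D = alpha * acc + beta * C + rowvec_bias[None, :] + colvec_bias[:, None]
+# with alpha/beta treated as 1.0 when not given; the row/col bias vectors are staged
+# through shared memory (see epi_begin) before being added per epi sub-tile.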
+from typing import Optional, Tuple +from functools import partial +from dataclasses import dataclass + + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Float32, Boolean, const_expr + +from .cute_dsl_utils import ArgumentsBase, ParamsBase +from .gemm_sm90 import GemmSm90 +from .gemm_sm100 import GemmSm100 +from .sm90_utils import partition_for_epilogue +from . import utils as utils +from . import copy_utils as copy_utils +from .varlen_utils import VarlenManager + + +class GemmDefaultEpiMixin: + num_epi_tensormaps: int = 0 + + @dataclass + class EpilogueArguments(ArgumentsBase): + alpha: Optional[Float32 | cute.Tensor] = None + beta: Optional[Float32 | cute.Tensor] = None + mRowVecBroadcast: Optional[cute.Tensor] = None + mColVecBroadcast: Optional[cute.Tensor] = None + add_to_output: bool = False + + @dataclass + class EpilogueParams(ParamsBase): + alpha: Optional[Float32 | cute.Tensor] = None + beta: Optional[Float32 | cute.Tensor] = None + mRowVecBroadcast: Optional[cute.Tensor] = None + mColVecBroadcast: Optional[cute.Tensor] = None + + def epi_to_underlying_arguments( + self, args: EpilogueArguments, *, loc=None, ip=None + ) -> EpilogueParams: + # Assume all strides are divisible by 32 bits except the last stride + new_stride = lambda t: tuple( + cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s + for s in t.stride + ) + mRowVecBroadcast, mColVecBroadcast = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + if t is not None + else None + for t in (args.mRowVecBroadcast, args.mColVecBroadcast) + ] + return self.EpilogueParams( + alpha=args.alpha, + beta=args.beta, + mRowVecBroadcast=mRowVecBroadcast, + mColVecBroadcast=mColVecBroadcast, + ) + + @cute.jit + def epi_begin( + self, + params: EpilogueParams, + epi_smem_tensors: Tuple[cute.Tensor, ...], + epi_tile: cute.Tile, + tiled_copy_t2r: Optional[cute.TiledCopy], + tiled_copy_r2s: cute.TiledCopy, + tile_coord_mnkl: cute.Coord, + varlen_manager: VarlenManager, + epilogue_barrier: cutlass.pipeline.NamedBarrier, + tidx: Int32, + ): + alpha, beta = None, None + if const_expr(hasattr(params, "alpha") and params.alpha is not None): + alpha = utils.load_scalar_or_pointer(params.alpha) + if const_expr(hasattr(params, "beta") and params.beta is not None): + beta = utils.load_scalar_or_pointer(params.beta) + sRowVec, sColVec, *rest = epi_smem_tensors + tile_M, tile_N = self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1] + batch_idx = tile_coord_mnkl[3] + num_epi_threads = self.num_epi_warps * cute.arch.WARP_SIZE + # Don't need sync as we assume the previous epilogue has finished + + partition_for_epilogue_fn = partial( + partition_for_epilogue, + epi_tile=epi_tile, + tiled_copy=tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s, + tidx=tidx, + reference_src=tiled_copy_t2r is None, + ) + + tDsRowVec = None + if const_expr(params.mRowVecBroadcast is not None): + rowvec_dtype = params.mRowVecBroadcast.element_type + num_copy_elems = const_expr(max(32, rowvec_dtype.width)) // rowvec_dtype.width + thr_copy_RV = copy_utils.tiled_copy_1d( + params.mRowVecBroadcast.element_type, num_epi_threads, num_copy_elems, is_async=True + ).get_slice(tidx) + mRowVec = params.mRowVecBroadcast[batch_idx, None] + gRowVec = cute.local_tile(mRowVec, (tile_N,), (tile_coord_mnkl[1],)) + tRVgRV = thr_copy_RV.partition_S(gRowVec) + tRVsRV = thr_copy_RV.partition_D(sRowVec) + tRVcRV = thr_copy_RV.partition_S(cute.make_identity_tensor(tile_N)) + limit_n = 
min(mRowVec.shape[0] - tile_coord_mnkl[1] * tile_N, tile_N) + tRVpRV = cute.make_fragment((1, cute.size(tRVsRV.shape[1])), Boolean) + for m in cutlass.range(cute.size(tRVsRV.shape[1]), unroll_full=True): + tRVpRV[0, m] = tRVcRV[0, m] < limit_n + cute.copy(thr_copy_RV, tRVgRV, tRVsRV, pred=tRVpRV) + # (CPY, CPY_M, CPY_N, EPI_M, EPI_N) + tDsRowVec = partition_for_epilogue_fn( + cute.make_tensor( + sRowVec.iterator, cute.make_layout((tile_M, tile_N), stride=(0, 1)) + ) + ) + if const_expr(tiled_copy_t2r is not None): + tDsRowVec = tiled_copy_r2s.retile(tDsRowVec) + + tDsColVec = None + if const_expr(params.mColVecBroadcast is not None): + colvec_dtype = params.mColVecBroadcast.element_type + num_copy_elems = const_expr(max(32, colvec_dtype.width)) // colvec_dtype.width + thr_copy_CV = copy_utils.tiled_copy_1d( + params.mColVecBroadcast.element_type, num_epi_threads, num_copy_elems, is_async=True + ).get_slice(tidx) + if const_expr(not varlen_manager.varlen_m): + mColVec = params.mColVecBroadcast[batch_idx, None] + else: + mColVec = cute.domain_offset( + (varlen_manager.params.cu_seqlens_m[batch_idx],), params.mColVecBroadcast + ) + gColVec = cute.local_tile(mColVec, (tile_M,), (tile_coord_mnkl[0],)) + tCVgCV = thr_copy_CV.partition_S(gColVec) + tCVsCV = thr_copy_CV.partition_D(sColVec) + tCVcCV = thr_copy_CV.partition_S(cute.make_identity_tensor(tile_M)) + limit_m = min(varlen_manager.len_m(batch_idx) - tile_coord_mnkl[0] * tile_M, tile_M) + tCVpCV = cute.make_fragment((1, cute.size(tCVsCV.shape[1])), Boolean) + for m in cutlass.range(cute.size(tCVsCV.shape[1]), unroll_full=True): + tCVpCV[0, m] = tCVcCV[0, m] < limit_m + cute.copy(thr_copy_CV, tCVgCV, tCVsCV, pred=tCVpCV) + tDsColVec = partition_for_epilogue_fn( + cute.make_tensor( + sColVec.iterator, cute.make_layout((tile_M, tile_N), stride=(1, 0)) + ) + ) + if const_expr(tiled_copy_t2r is not None): + tDsColVec = tiled_copy_r2s.retile(tDsColVec) + + if const_expr(params.mRowVecBroadcast is not None or params.mColVecBroadcast is not None): + cute.arch.cp_async_commit_group() + cute.arch.cp_async_wait_group(0) + epilogue_barrier.arrive_and_wait() + return alpha, beta, tDsRowVec, tDsColVec + + def epi_begin_loop(self, params: EpilogueParams, epi_tensors, epi_coord: cute.Coord): + alpha, beta, tDsRowVec, tDsColVec = epi_tensors + tDrRowVec_cvt = None + if const_expr(tDsRowVec is not None): + tDsRowVec_cur = cute.group_modes(tDsRowVec, 3, cute.rank(tDsRowVec))[ + None, None, None, epi_coord + ] + # tDrRowVec = cute.make_fragment_like(tDsRowVec_cur) + tDrRowVec = cute.make_fragment(tDsRowVec_cur.layout, tDsRowVec_cur.element_type) + cute.autovec_copy(cute.filter_zeros(tDsRowVec_cur), cute.filter_zeros(tDrRowVec)) + tDrRowVec_cvt = cute.make_fragment_like(tDrRowVec, self.acc_dtype) + tDrRowVec_cvt.store(tDrRowVec.load().to(self.acc_dtype)) + tDrColVec_cvt = None + if const_expr(tDsColVec is not None): + tDsColVec_cur = cute.group_modes(tDsColVec, 3, cute.rank(tDsColVec))[ + None, None, None, epi_coord + ] + # This somehow doesn't work, some dim with stride 0 turns to non-zero stride + # tDrRowVec = cute.make_fragment_like(tDsRowVec_cur) + tDrColVec = cute.make_fragment(tDsColVec_cur.layout, tDsColVec_cur.element_type) + cute.autovec_copy(cute.filter_zeros(tDsColVec_cur), cute.filter_zeros(tDrColVec)) + tDrColVec_cvt = cute.make_fragment_like(tDrColVec, self.acc_dtype) + tDrColVec_cvt.store(tDrColVec.load().to(self.acc_dtype)) + return alpha, beta, tDrRowVec_cvt, tDrColVec_cvt + + @cute.jit + def epi_visit_subtile( + self, + params: 
EpilogueParams, + epi_loop_tensors: Tuple[cute.Tensor, ...], + tRS_rD: cute.Tensor, + tRS_rC: Optional[cute.Tensor] = None, + ) -> Optional[cute.Tensor]: + alpha, beta, tDrRowVec, tDrColVec = epi_loop_tensors + rD = tRS_rD.load() + # Apply alpha scaling to accumulator if alpha is provided (not None) + if const_expr(hasattr(params, "alpha") and params.alpha is not None): + alpha = utils.load_scalar_or_pointer(params.alpha) + rD *= alpha + # Apply C with beta scaling + if const_expr(tRS_rC is not None): + if const_expr(not hasattr(params, "beta") or params.beta is None): + # beta is None, default behavior: add C (beta=1.0) + rD += tRS_rC.load().to(tRS_rD.element_type) + else: + beta = utils.load_scalar_or_pointer(params.beta) + rD += beta * tRS_rC.load().to(tRS_rD.element_type) + tRS_rD.store(rD) + if const_expr(tDrRowVec is not None): + for i in cutlass.range(cute.size(tDrRowVec), unroll_full=True): + tRS_rD[i] += tDrRowVec[i] + if const_expr(tDrColVec is not None): + for i in cutlass.range(cute.size(tDrColVec), unroll_full=True): + tRS_rD[i] += tDrColVec[i] + return None + + @staticmethod + def epi_smem_bytes_per_stage( + args: Optional[EpilogueArguments], + cta_tile_shape_mnk: Tuple[int, int, int], + epi_tile: cute.Tile, + ) -> int: + row_vec_smem_size = 0 if args.mRowVecBroadcast is None else cta_tile_shape_mnk[1] + col_vec_smem_size = 0 if args.mColVecBroadcast is None else cta_tile_shape_mnk[0] + row_vec_dtype = ( + args.mRowVecBroadcast.element_type if args.mRowVecBroadcast is not None else Float32 + ) + col_vec_dtype = ( + args.mColVecBroadcast.element_type if args.mColVecBroadcast is not None else Float32 + ) + return ( + row_vec_smem_size * row_vec_dtype.width + col_vec_smem_size * col_vec_dtype.width + ) // 8 + + def epi_get_smem_struct(self, params: EpilogueParams): + row_vec_smem_size = 0 if params.mRowVecBroadcast is None else self.cta_tile_shape_mnk[1] + col_vec_smem_size = 0 if params.mColVecBroadcast is None else self.cta_tile_shape_mnk[0] + row_vec_dtype = ( + params.mRowVecBroadcast.element_type if params.mRowVecBroadcast is not None else Float32 + ) + col_vec_dtype = ( + params.mColVecBroadcast.element_type if params.mColVecBroadcast is not None else Float32 + ) + + @cute.struct + class EpiSharedStorage: + sRowVec: cute.struct.Align[cute.struct.MemRange[row_vec_dtype, row_vec_smem_size], 16] + sColVec: cute.struct.Align[cute.struct.MemRange[col_vec_dtype, col_vec_smem_size], 16] + + return EpiSharedStorage + + def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]: + sRowVec = None + if const_expr(params.mRowVecBroadcast is not None): + sRowVec = storage.epi.sRowVec.get_tensor(cute.make_layout(self.cta_tile_shape_mnk[1])) + sColVec = None + if const_expr(params.mColVecBroadcast is not None): + sColVec = storage.epi.sColVec.get_tensor(cute.make_layout(self.cta_tile_shape_mnk[0])) + return (sRowVec, sColVec) + + +class GemmDefaultSm90(GemmDefaultEpiMixin, GemmSm90): + pass + + +class GemmDefaultSm100(GemmDefaultEpiMixin, GemmSm100): + pass diff --git a/sonic-moe/torch-ext/sonicmoe/quack/gemm_interface.py b/sonic-moe/torch-ext/sonicmoe/quack/gemm_interface.py new file mode 100644 index 00000000..8ea5b786 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/gemm_interface.py @@ -0,0 +1,1058 @@ +# Copyright (c) 2025, Tri Dao +from typing import Optional, Tuple, Literal +from functools import partial + +import torch +import torch.nn.functional as F +from torch import Tensor +from ._ops_compat import add_quack_op_namespace_prefix + +from 
.gemm_config import GemmConfig, get_all_configs + +from .autotuner import autotune, AutotuneConfig +from .cute_dsl_utils import get_device_capacity +from .gemm import gemm as gemm_sm90_sm100 +from .gemm_act import gemm_act as gemm_act_sm90_sm100 +from .gemm_dact import gemm_dact as gemm_dact_sm90_sm100 +from .gemm_symmetric import gemm_symmetric as gemm_symmetric_sm90_sm100 + + +# Dictionary mapping activation names to PyTorch functions +act_to_pytorch_fn_map = { + None: lambda x: x, + "relu": F.relu, + "relu_sq": lambda x: F.relu(x).square(), + "gelu_tanh_approx": partial(F.gelu, approximate="tanh"), +} + + +# Dictionary mapping gated activation names to their forward functions +# Each function takes (gate, up) and returns postact +gated_to_pytorch_fn_map = { + "swiglu": lambda gate, up: F.silu(gate) * up, + "swiglu_oai": lambda gate, up: gate * torch.sigmoid(1.702 * gate) * (up + 1), + "reglu": lambda gate, up: F.relu(gate) * up, + "geglu": lambda gate, up: F.gelu(gate, approximate="tanh") * up, + "glu": lambda gate, up: torch.sigmoid(gate) * up, +} + + +def _get_default_device_capacity(): + if not torch.cuda.is_available(): + return (9, 0) + cap = get_device_capacity(torch.device("cuda")) + if cap[0] not in (9, 10): + return (9, 0) + return cap + + +class _LazyDeviceCapacity: + """Defer torch.cuda.get_device_capability until first access so the + module can be imported in environments without a GPU (e.g. nix build).""" + _value = None + def __getitem__(self, idx): + if self._value is None: + self._value = _get_default_device_capacity() + return self._value[idx] + + +default_device_capacity = _LazyDeviceCapacity() + + +def default_config(device): + if get_device_capacity(device)[0] != 10: + return GemmConfig(tile_m=128, tile_n=192, cluster_m=2, cluster_n=1, pingpong=True) + else: + return GemmConfig(tile_m=256, tile_n=256, cluster_m=2, cluster_n=1, pingpong=False) + + +def prune_invalid_gemm_configs(configs, named_args: dict, **kwargs): + kwargs = named_args | kwargs + gather_A = kwargs.get("A_idx", None) is not None + varlen_m = kwargs.get("cu_seqlens_m", None) is not None + if varlen_m or gather_A: # Doesn't support swap_ab + configs = [conf for conf in configs if not conf.kwargs["config"].swap_ab] + if gather_A: + if get_device_capacity(kwargs["A"].device)[0] == 9: + # tile_n == 208 causes register spills, as gather_A requires more registers for the producer + configs = [ + conf + for conf in configs + if conf.kwargs["config"].cluster_n == 1 and conf.kwargs["config"].tile_n != 208 + ] + return configs + + +@autotune( + configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])], + key=["dynamic_scheduler"], + prune_configs_by={"early_config_prune": prune_invalid_gemm_configs}, +) +def gemm_tuned( + # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k + A: Tensor, + B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k + out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + bias: Optional[Tensor] = None, # (N,) or (L, N) + alpha: float | Tensor = 1.0, # (1,) + beta: float | Tensor = 1.0, # (1,) + cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32 + cu_seqlens_k: Optional[Tensor] = None, # (L+1), int32 + A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen + batch_idx_permute: Optional[Tensor] = None, # (L,) 
permutation of batch indices for scheduler + add_to_output: bool = False, + dynamic_scheduler: bool = False, + config: Optional[GemmConfig] = None, +) -> None: + if config is None: + config = default_config(A.device) + varlen_m = cu_seqlens_m is not None + varlen_k = cu_seqlens_k is not None + varlen = varlen_m or varlen_k + gather_A = A_idx is not None + if gather_A: + assert varlen, "gather_A requires either varlen_m or varlen_k" + assert config.cluster_n == 1, "gather_A requires cluster_n=1" + if varlen_m: + assert not config.swap_ab, "Variable-length sequences not supported with swap_ab" + if A.ndim == 2 and not varlen: + A = A.unsqueeze(0) # (1, M, K) + B = B.mT # (N, K) or (L, N, K) or (N, total_K) + if B.ndim == 2 and not varlen_k: + B = B.unsqueeze(0) # (1, N, K) + if C is not None and C.ndim == 2 and not varlen_m: + C = C.unsqueeze(0) # (1, M, N) + if out.ndim == 2 and not varlen_m: + out = out.unsqueeze(0) + if bias is not None and bias.ndim == 1: + bias = bias.unsqueeze(0) # (L, N) + batch_size = B.shape[0] if not varlen_k else cu_seqlens_k.shape[0] - 1 + if varlen_m: + # If gather_A (A_idx provided), use its length; otherwise use A.shape[0] + total_m = A_idx.shape[0] if A_idx is not None else A.shape[0] + out_shape = (total_m, B.shape[-2]) + else: + out_shape = (batch_size, A.shape[-2], B.shape[-2]) + assert out.shape == out_shape, f"out shape mismatch: {out.shape} vs {out_shape}" + tile_count_semaphore = ( + torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None + ) + gemm_sm90_sm100( + A if not config.swap_ab else B, + B if not config.swap_ab else A, + out if not config.swap_ab else out.mT, + (C if not config.swap_ab else C.mT) if C is not None else None, + tile_count_semaphore, + config.tile_m, + config.tile_n, + config.cluster_m, + config.cluster_n, + config.pingpong, + persistent=True, + max_swizzle_size=config.max_swizzle_size, + rowvec_bias=bias if not config.swap_ab else None, + colvec_bias=bias if config.swap_ab else None, + alpha=alpha, + beta=beta, + cu_seqlens_m=cu_seqlens_m, + cu_seqlens_k=cu_seqlens_k, + A_idx=A_idx, + batch_idx_permute=batch_idx_permute, + add_to_output=add_to_output, + ) + + +@autotune( + configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])], + key=["activation", "dynamic_scheduler"], + prune_configs_by={"early_config_prune": prune_invalid_gemm_configs}, +) +def gemm_act_tuned( + # (M, K) or or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m + A: Tensor, + B: Tensor, # (K, N) or (L, K, N) + # (M, N) or (L, M, N) or (total_M, N) if varlen_m - None if not storing preact + preact_out: Optional[Tensor], + postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + bias: Optional[Tensor] = None, # (N,) or (L, N) + activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None, + cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32 + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + dynamic_scheduler: bool = False, + config: Optional[GemmConfig] = None, +) -> None: + if config is None: + config = default_config(A.device) + varlen_m = cu_seqlens_m is not None + if varlen_m: + assert not config.swap_ab, "Variable-length sequences not supported with swap_ab" + if A.ndim == 2 and not varlen_m: + A = A.unsqueeze(0) # (1, M, K) + B = B.mT # (N, K) or (L, N, K) + if B.ndim == 2: + B = B.unsqueeze(0) # (1, N, K) + if C is not None and C.ndim 
== 2 and not varlen_m: + C = C.unsqueeze(0) # (1, M, N) + if preact_out is not None and preact_out.ndim == 2 and not varlen_m: + D = preact_out.unsqueeze(0) + else: + D = preact_out + if postact_out.ndim == 2 and not varlen_m: + PostAct = postact_out.unsqueeze(0) + else: + PostAct = postact_out + if bias is not None and bias.ndim == 1: + bias = bias.unsqueeze(0) # (L, N) + tile_count_semaphore = ( + torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None + ) + gemm_act_sm90_sm100( + A if not config.swap_ab else B, + B if not config.swap_ab else A, + (D if not config.swap_ab else D.mT) if D is not None else None, + (C if not config.swap_ab else C.mT) if C is not None else None, + PostAct if not config.swap_ab else PostAct.mT, + tile_count_semaphore, + activation, + config.tile_m, + config.tile_n, + config.cluster_m, + config.cluster_n, + config.pingpong, + persistent=True, + max_swizzle_size=config.max_swizzle_size, + rowvec_bias=bias if not config.swap_ab else None, + colvec_bias=bias if config.swap_ab else None, + cu_seqlens_m=cu_seqlens_m, + A_idx=A_idx, + ) + + +@autotune( + configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0])], + key=["activation", "dynamic_scheduler"], + prune_configs_by={"early_config_prune": prune_invalid_gemm_configs}, +) +def gemm_dact_tuned( + # (M, K) or or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m + A: Tensor, + B: Tensor, # (K, N) or (L, K, N) + PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + dx_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + postact_out: Tensor, # (M, N) or (L, N, N) or (total_M, N) if varlen_m + activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None, + cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32 + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + dynamic_scheduler: bool = True, + config: Optional[GemmConfig] = None, +) -> None: + if config is None: + config = default_config(A.device) + varlen_m = cu_seqlens_m is not None + if varlen_m: + assert not config.swap_ab, "Variable-length sequences not supported with swap_ab" + if A.ndim == 2 and not varlen_m: + A = A.unsqueeze(0) # (1, M, K) + B = B.mT # (N, K) or (L, N, K) + if B.ndim == 2: + B = B.unsqueeze(0) # (1, N, K) + if PreAct.ndim == 2 and not varlen_m: + PreAct = PreAct.unsqueeze(0) # (1, M, N) + if dx_out.ndim == 2 and not varlen_m: + D = dx_out.unsqueeze(0) + else: + D = dx_out + if postact_out.ndim == 2 and not varlen_m: + PostAct = postact_out.unsqueeze(0) + else: + PostAct = postact_out + tile_count_semaphore = ( + torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None + ) + gemm_dact_sm90_sm100( + A if not config.swap_ab else B, + B if not config.swap_ab else A, + D if not config.swap_ab else D.mT, + PreAct if not config.swap_ab else PreAct.mT, + PostAct if not config.swap_ab else PostAct.mT, + tile_count_semaphore, + activation, + config.tile_m, + config.tile_n, + config.cluster_m, + config.cluster_n, + config.pingpong, + persistent=True, + max_swizzle_size=config.max_swizzle_size, + cu_seqlens_m=cu_seqlens_m, + A_idx=A_idx, + ) + + +def gemm( + # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k + A: Tensor, + B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k + out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + bias: 
Optional[Tensor] = None, # (N,) or (L, N) + alpha: float | Tensor = 1.0, + out_dtype: Optional[torch.dtype] = None, + cu_seqlens_m: Optional[Tensor] = None, + cu_seqlens_k: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen + batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler + dynamic_scheduler: bool = False, + tuned: bool = True, +) -> Tensor: + """GEMM with optional output tensor and tuning control.""" + if out is None: + out_dtype = A.dtype if out_dtype is None else out_dtype + varlen_m = cu_seqlens_m is not None + varlen_k = cu_seqlens_k is not None + if varlen_m: + total_m = A_idx.shape[0] if A_idx is not None else A.shape[0] + out_shape = (total_m, B.shape[-1]) + elif varlen_k: + L = cu_seqlens_k.shape[0] - 1 + # For varlen_k, the first dimension is always A.shape[0] (M dimension) + out_shape = (L, A.shape[0], B.shape[-1]) + else: + out_shape = ( + (A.shape[0], B.shape[-1]) if A.ndim == 2 else (A.shape[0], A.shape[-2], B.shape[-1]) + ) + out = torch.empty(out_shape, dtype=out_dtype, device=A.device) + alpha_tensor = alpha if not isinstance(alpha, float) else None + alpha = alpha if isinstance(alpha, float) else 1.0 + gemm_out( + A, + B, + out, + bias=bias, + alpha=alpha, + alpha_tensor=alpha_tensor, + cu_seqlens_m=cu_seqlens_m, + cu_seqlens_k=cu_seqlens_k, + A_idx=A_idx, + batch_idx_permute=batch_idx_permute, + dynamic_scheduler=dynamic_scheduler, + tuned=tuned, + ) + return out + + +@torch.library.custom_op( + add_quack_op_namespace_prefix("gemm_out"), + mutates_args=("out",), + device_types="cuda", + # We have to split out alpha and alpha_tensor since torch.library requires + # each argument to have a fixed type + # schema="(Tensor A, Tensor B, Tensor(a2!) out, Tensor? bias, float alpha=1.0, Tensor? 
alpha_tensor=None, bool dynamic_scheduler=False, bool tuned=True) -> ()", +) +def gemm_out( + # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k + A: Tensor, + B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k + out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + bias: Optional[Tensor] = None, # (N,) or (L, N) + alpha: float = 1.0, + alpha_tensor: Optional[Tensor] = None, + cu_seqlens_m: Optional[Tensor] = None, + cu_seqlens_k: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen + batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler + dynamic_scheduler: bool = False, + tuned: bool = True, +) -> None: + """GEMM with pre-allocated output tensor.""" + fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None) + alpha = alpha_tensor if alpha_tensor is not None else alpha + fn( + A, + B, + out, + C=None, + bias=bias, + alpha=alpha, + cu_seqlens_m=cu_seqlens_m, + cu_seqlens_k=cu_seqlens_k, + A_idx=A_idx, + batch_idx_permute=batch_idx_permute, + dynamic_scheduler=dynamic_scheduler, + ) + + +def gemm_ref( + # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k + A: Tensor, + B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k + out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + bias: Optional[Tensor] = None, # (N,) or (L, N) + alpha: float | Tensor = 1.0, + cu_seqlens_m: Optional[Tensor] = None, + cu_seqlens_k: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen + out_dtype: Optional[torch.dtype] = None, +) -> Tensor: + """Reference implementation for GEMM with pre-allocated output.""" + # The out_dtype argument requires torch >= 2.8 + out_dtype = A.dtype if out_dtype is None else out_dtype + if cu_seqlens_m is None and cu_seqlens_k is None: + fn = torch.bmm if A.ndim == 3 else torch.mm + out = fn(A, B, out_dtype=out_dtype, out=out) + if not isinstance(alpha, float) or alpha != 1.0: + out *= alpha + if bias is not None: + bias = bias if A.ndim == 2 else bias.unsqueeze(1) + out += bias + elif cu_seqlens_m is not None: + # Handle varlen_m case + if out is None: + # When gather_A (A_idx provided), output size is determined by A_idx length + total_m = A_idx.shape[0] if A_idx is not None else A.shape[0] + out = torch.empty((total_m, B.shape[-1]), dtype=out_dtype, device=A.device) + for i in range(cu_seqlens_m.shape[0] - 1): + A_slice = ( + A[A_idx[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]] + if A_idx is not None + else A[cu_seqlens_m[i] : cu_seqlens_m[i + 1]] + ) + torch.mm(A_slice, B[i], out=out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]) + if not isinstance(alpha, float) or alpha != 1.0: + out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]] *= alpha + if bias is not None: + out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]] += bias[i] + else: # cu_seqlens_k is not None + L = cu_seqlens_k.shape[0] - 1 + if out is None: + out = torch.empty((L, A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device) + for i in range(L): + A_slice = ( + A[:, A_idx[cu_seqlens_k[i] : cu_seqlens_k[i + 1]]] + if A_idx is not None + else A[:, cu_seqlens_k[i] : cu_seqlens_k[i + 1]] + ) + torch.mm(A_slice, B[cu_seqlens_k[i] : cu_seqlens_k[i + 1], :], out=out[i]) + if not isinstance(alpha, 
float) or alpha != 1.0: + out *= alpha + if bias is not None: + out += bias + return out + + +def gemm_add( + # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k + A: Tensor, + B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k + C: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k + out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + alpha: float | Tensor = 1.0, + beta: float | Tensor = 1.0, + out_dtype: Optional[torch.dtype] = None, + cu_seqlens_m: Optional[Tensor] = None, + cu_seqlens_k: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen + batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler + dynamic_scheduler: bool = False, + tuned: bool = True, +) -> Tensor: + """GEMM with addition and optional output tensor.""" + if out is None: + out_dtype = A.dtype if out_dtype is None else out_dtype + varlen_m = cu_seqlens_m is not None + varlen_k = cu_seqlens_k is not None + if varlen_m: + # If A_idx is provided (gather_A), use its length; otherwise use A.shape[0] + total_m = A_idx.shape[0] if A_idx is not None else A.shape[0] + out_shape = (total_m, B.shape[-1]) + elif varlen_k: + L = cu_seqlens_k.shape[0] - 1 + # For varlen_k, the first dimension is always A.shape[0] (M dimension) + out_shape = (L, A.shape[0], B.shape[-1]) + else: + out_shape = ( + (A.shape[0], B.shape[-1]) if A.ndim == 2 else (A.shape[0], A.shape[-2], B.shape[-1]) + ) + out = torch.empty(out_shape, dtype=out_dtype, device=A.device) + add_to_output = C is out and isinstance(beta, float) and beta == 1.0 and cu_seqlens_m is None + alpha_tensor = alpha if not isinstance(alpha, float) else None + alpha = alpha if isinstance(alpha, float) else 1.0 + beta_tensor = beta if not isinstance(beta, float) else None + beta = beta if isinstance(beta, float) else 1.0 + gemm_add_out( + A, + B, + C if not add_to_output else None, + out, + alpha, + beta, + alpha_tensor, + beta_tensor, + cu_seqlens_m=cu_seqlens_m, + cu_seqlens_k=cu_seqlens_k, + A_idx=A_idx, + batch_idx_permute=batch_idx_permute, + add_to_output=add_to_output, + dynamic_scheduler=dynamic_scheduler, + tuned=tuned, + ) + return out + + +@torch.library.custom_op( + add_quack_op_namespace_prefix("gemm_add_out"), + mutates_args=("out",), + device_types="cuda", + # We have to split out alpha and alpha_tensor since torch.library requires + # each argument to have a fixed type + # schema="(Tensor A, Tensor B, Tensor C, Tensor(a3!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, Tensor? 
cu_seqlens_m=None, bool dynamic_scheduler=False, bool tuned=True) -> ()", +) +def gemm_add_out( + # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k + A: Tensor, + B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k + C: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k + out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + alpha: float = 1.0, + beta: float = 1.0, + alpha_tensor: Optional[Tensor] = None, + beta_tensor: Optional[Tensor] = None, + cu_seqlens_m: Optional[Tensor] = None, + cu_seqlens_k: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen + batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler + add_to_output: bool = False, + dynamic_scheduler: bool = False, + tuned: bool = True, +) -> None: + """GEMM with addition and pre-allocated output tensor.""" + fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None) + alpha = alpha_tensor if alpha_tensor is not None else alpha + beta = beta_tensor if beta_tensor is not None else beta + fn( + A, + B, + out, + C, + alpha=alpha, + beta=beta, + cu_seqlens_m=cu_seqlens_m, + cu_seqlens_k=cu_seqlens_k, + A_idx=A_idx, + batch_idx_permute=batch_idx_permute, + add_to_output=add_to_output, + dynamic_scheduler=dynamic_scheduler, + ) + + +def gemm_add_ref( + # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k + A: Tensor, + B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k + C: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + bias: Optional[Tensor] = None, # (N,) or (L, N) + out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + alpha: float | Tensor = 1.0, + beta: float | Tensor = 1.0, + cu_seqlens_m: Optional[Tensor] = None, + cu_seqlens_k: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen + out_dtype: Optional[torch.dtype] = None, +) -> Tensor: + """Reference implementation for GEMM with addition and pre-allocated output.""" + if cu_seqlens_m is None and cu_seqlens_k is None: + if isinstance(alpha, float) and isinstance(beta, float): + out = torch.addmm(C, A, B, out_dtype=out_dtype, alpha=alpha, beta=beta, out=out) + else: + out_dtype = ( + out.dtype if out is not None else (out_dtype if out_dtype is not None else A.dtype) + ) + result = (alpha * (A @ B) + beta * C).to(out_dtype) + if out is not None: + out.copy_(result) + if bias is not None: + bias = bias if A.ndim == 2 else bias.unsqueeze(1) + out += bias + elif cu_seqlens_m is not None: + # Handle varlen_m case + if out is None: + # When gather_A (A_idx provided), output size is determined by A_idx length + total_m = A_idx.shape[0] if A_idx is not None else A.shape[0] + out_dtype = out_dtype if out_dtype is not None else A.dtype + out = torch.empty((total_m, B.shape[-1]), dtype=out_dtype, device=A.device) + for i in range(cu_seqlens_m.shape[0] - 1): + A_slice = ( + A[A_idx[cu_seqlens_m[i] : cu_seqlens_m[i + 1]]] + if A_idx is not None + else A[cu_seqlens_m[i] : cu_seqlens_m[i + 1]] + ) + C_slice = C[cu_seqlens_m[i] : cu_seqlens_m[i + 1]] + out_slice = out[cu_seqlens_m[i] : cu_seqlens_m[i + 1]] + result = alpha * torch.mm(A_slice, B[i]) + beta * C_slice + if bias is not None: + result += 
bias[i] + out_slice.copy_(result) + else: # cu_seqlens_k is not None + # Handle varlen_k case + L = cu_seqlens_k.shape[0] - 1 + out_dtype = out_dtype if out_dtype is not None else A.dtype + if out is None: + out = torch.empty((L, A.shape[0], B.shape[1]), dtype=out_dtype, device=A.device) + for i in range(L): + A_slice = ( + A[:, A_idx[cu_seqlens_k[i] : cu_seqlens_k[i + 1]]] + if A_idx is not None + else A[:, cu_seqlens_k[i] : cu_seqlens_k[i + 1]] + ) + B_slice = B[cu_seqlens_k[i] : cu_seqlens_k[i + 1], :] + result = alpha * torch.mm(A_slice, B_slice) + beta * C[i] + out[i].copy_(result) + if bias is not None: + out += bias + return out + + +def gemm_add_inplace( + # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k + A: Tensor, + B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k + out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k + alpha: float | Tensor = 1.0, + beta: float | Tensor = 1.0, + cu_seqlens_m: Optional[Tensor] = None, + cu_seqlens_k: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen + batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler + dynamic_scheduler: bool = False, + tuned: bool = True, +) -> None: + """In-place GEMM with addition: out = alpha * A @ B + beta * out. + Args: + A: (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k - input tensor + B: (K, N) or (L, K, N) or (total_K, N) if varlen_k - input tensor + out: (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k - tensor to accumulate into (modified in-place) + alpha: Scalar multiplier for A @ B + beta: Scalar multiplier for out + cu_seqlens_m: Optional cumulative sequence lengths for variable M + cu_seqlens_k: Optional cumulative sequence lengths for variable K + dynamic_scheduler: Whether to use dynamic scheduler + tuned: Whether to use autotuned configuration + """ + alpha_tensor = alpha if not isinstance(alpha, float) else None + alpha = alpha if isinstance(alpha, float) else 1.0 + beta_tensor = beta if not isinstance(beta, float) else None + beta = beta if isinstance(beta, float) else 1.0 + gemm_add_inplace_op( + A, + B, + out, + alpha, + beta, + alpha_tensor, + beta_tensor, + cu_seqlens_m, + cu_seqlens_k, + A_idx=A_idx, + batch_idx_permute=batch_idx_permute, + dynamic_scheduler=dynamic_scheduler, + tuned=tuned, + ) + + +@torch.library.custom_op( + add_quack_op_namespace_prefix("gemm_add_inplace"), + mutates_args=("out",), + device_types="cuda", + # We have to split out alpha and alpha_tensor since torch.library requires + # each argument to have a fixed type + # schema="(Tensor A, Tensor B, Tensor(a2!) out, float alpha=1.0, float beta=1.0, Tensor? alpha_tensor=None, Tensor? beta_tensor=None, Tensor? 
cu_seqlens_m=None, bool dynamic_scheduler=False, bool tuned=True) -> ()", +) +def gemm_add_inplace_op( + # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A with varlen_m or (M, whatever) if gather_A with varlen_k + A: Tensor, + B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k + out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m or (L, M, N) if varlen_k + alpha: float = 1.0, + beta: float = 1.0, + alpha_tensor: Optional[Tensor] = None, + beta_tensor: Optional[Tensor] = None, + cu_seqlens_m: Optional[Tensor] = None, + cu_seqlens_k: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) or (total_K,) indices for gather_A when varlen + batch_idx_permute: Optional[Tensor] = None, # (L,) permutation of batch indices for scheduler + dynamic_scheduler: bool = False, + tuned: bool = True, +) -> None: + fn = gemm_tuned if tuned else partial(gemm_tuned.fn, config=None) + alpha = alpha_tensor if alpha_tensor is not None else alpha + beta = beta_tensor if beta_tensor is not None else beta + add_to_output = isinstance(beta, float) and beta == 1.0 and cu_seqlens_m is None + # Use out as both input bias and output + fn( + A, + B, + out, + out if not add_to_output else None, + alpha=alpha, + beta=beta, + cu_seqlens_m=cu_seqlens_m, + cu_seqlens_k=cu_seqlens_k, + A_idx=A_idx, + batch_idx_permute=batch_idx_permute, + add_to_output=add_to_output, + dynamic_scheduler=dynamic_scheduler, + ) + + +def gemm_act( + A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m + B: Tensor, # (K, N) or (L, K, N) + C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + bias: Optional[Tensor] = None, # (N,) or (L, N) + activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None, + preact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + postact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + out_dtype: Optional[torch.dtype] = None, + postact_dtype: Optional[torch.dtype] = None, + cu_seqlens_m: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + store_preact: bool = True, + dynamic_scheduler: bool = False, + tuned: bool = True, +) -> Tuple[Optional[Tensor], Tensor]: + """GEMM with activation and optional output tensors.""" + out_dtype = A.dtype if out_dtype is None else out_dtype + postact_dtype = A.dtype if postact_dtype is None else postact_dtype + varlen_m = cu_seqlens_m is not None + # Determine output shape based on gather_A + if varlen_m: + total_m = A_idx.shape[0] if A_idx is not None else A.shape[0] + out_shape = (total_m, B.shape[-1]) + elif A.ndim == 2: + out_shape = (A.shape[0], B.shape[-1]) + else: + out_shape = (A.shape[0], A.shape[-2], B.shape[-1]) + if preact_out is None and store_preact: + preact_out = torch.empty(out_shape, dtype=out_dtype, device=A.device) + if postact_out is None: + postact_out = torch.empty(out_shape, dtype=postact_dtype, device=A.device) + gemm_act_out( + A, + B, + preact_out, + postact_out, + C, + bias, + activation, + cu_seqlens_m, + A_idx, + dynamic_scheduler, + tuned, + ) + return preact_out, postact_out + + +@torch.library.custom_op( + add_quack_op_namespace_prefix("gemm_act_out"), + mutates_args=("preact_out", "postact_out"), + device_types="cuda", + schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, Tensor? bias=None, str? activation=None, Tensor? 
cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=False, bool tuned=True) -> ()", +) +def gemm_act_out( + A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m + B: Tensor, # (K, N) or (L, K, N) + preact_out: Optional[Tensor], # (M, N) or (L, M, N) or (total_M, N) if varlen_m + postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + bias: Optional[Tensor] = None, # (N,) or (L, N) + activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None, + cu_seqlens_m: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + dynamic_scheduler: bool = False, + tuned: bool = True, +) -> None: + """GEMM with activation and pre-allocated output tensors.""" + fn = gemm_act_tuned if tuned else partial(gemm_act_tuned.fn, config=None) + fn(A, B, preact_out, postact_out, C, bias, activation, cu_seqlens_m, A_idx, dynamic_scheduler) + + +def gemm_act_ref( + A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A + B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k + C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + bias: Optional[Tensor] = None, # (N,) or (L, N) + activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None, + cu_seqlens_m: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + out_dtype: Optional[torch.dtype] = None, + postact_dtype: Optional[torch.dtype] = None, + store_preact: bool = True, +) -> Tuple[Optional[Tensor], Tensor]: + out_dtype = A.dtype if out_dtype is None else out_dtype + postact_dtype = A.dtype if postact_dtype is None else postact_dtype + if C is None: + out = gemm_ref(A, B, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx) + else: + out = gemm_add_ref(A, B, C, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx) + postact = act_to_pytorch_fn_map[activation](out).to(postact_dtype) + return out.to(out_dtype) if store_preact else None, postact + + +def gemm_dact( + A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m + B: Tensor, # (K, N) or (L, K, N) + PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None, + dx_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + postact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + out_dtype: Optional[torch.dtype] = None, + postact_dtype: Optional[torch.dtype] = None, + cu_seqlens_m: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + dynamic_scheduler: bool = True, + tuned: bool = True, +) -> Tuple[Tensor, Tensor]: + """GEMM with activation gradient and optional output tensors.""" + out_dtype = A.dtype if out_dtype is None else out_dtype + postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype + varlen_m = cu_seqlens_m is not None + # Determine output shape based on gather_A + if varlen_m: + total_m = A_idx.shape[0] if A_idx is not None else A.shape[0] + out_shape = (total_m, B.shape[-1]) + elif A.ndim == 2: + out_shape = (A.shape[0], B.shape[-1]) + else: + out_shape = (A.shape[0], A.shape[-2], B.shape[-1]) + if dx_out is None: + dx_out = torch.empty(out_shape, dtype=out_dtype, device=A.device) + if postact_out is 
None: + postact_out = torch.empty(out_shape, dtype=postact_dtype, device=A.device) + gemm_dact_out( + A, B, PreAct, dx_out, postact_out, activation, cu_seqlens_m, A_idx, dynamic_scheduler, tuned + ) + return dx_out, postact_out + + +@torch.library.custom_op( + add_quack_op_namespace_prefix("gemm_dact_out"), + mutates_args=("dx_out", "postact_out"), + device_types="cuda", + schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a3!) dx_out, Tensor(a4!) postact_out, str? activation=None, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=True, bool tuned=True) -> ()", +) +def gemm_dact_out( + A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m + B: Tensor, # (K, N) or (L, K, N) + PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + dx_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None, + cu_seqlens_m: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + dynamic_scheduler: bool = True, + tuned: bool = True, +) -> None: + """GEMM with activation gradient and pre-allocated output tensors.""" + fn = gemm_dact_tuned if tuned else partial(gemm_dact_tuned.fn, config=None) + fn(A, B, PreAct, dx_out, postact_out, activation, cu_seqlens_m, A_idx, dynamic_scheduler) + + +def gemm_dact_ref( + A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A + B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k + PreAct: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + activation: Literal[None, "relu", "relu_sq", "gelu_tanh_approx"] = None, + cu_seqlens_m: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + out_dtype: Optional[torch.dtype] = None, + postact_dtype: Optional[torch.dtype] = None, +) -> Tuple[Tensor, Tensor]: + """Reference implementation for GEMM with activation gradient.""" + out_dtype = A.dtype if out_dtype is None else out_dtype + postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype + dout = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx).to(out_dtype) + postact = act_to_pytorch_fn_map[activation](PreAct) + # Compute gradient using autograd + if activation is None: + dx = dout + else: + PreAct_requires_grad = PreAct.requires_grad + PreAct.requires_grad_(True) + postact_for_grad = act_to_pytorch_fn_map[activation](PreAct) + dx = torch.autograd.grad(postact_for_grad, PreAct, dout, create_graph=False)[0] + PreAct.requires_grad_(PreAct_requires_grad) + return dx.to(out_dtype), postact.to(postact_dtype) + + +def gemm_gated_ref( + A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A + B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k + C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + bias: Optional[Tensor] = None, # (N,) or (L, N) + activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"] = "swiglu", + cu_seqlens_m: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + out_dtype: Optional[torch.dtype] = None, + postact_dtype: Optional[torch.dtype] = None, + store_preact: bool = True, +) -> Tuple[Optional[Tensor], Tensor]: + """Reference implementation for GEMM with gated activation forward. 
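As a concrete illustration of what the `gemm_dact*` functions above compute, here is a minimal pure-PyTorch sketch (independent of the vendored `quack` module; the shapes and the `relu_sq` activation are illustrative choices): the GEMM produces `dout = A @ B`, and the epilogue scales it by the activation derivative evaluated at the saved `PreAct` while also recomputing `postact`.

```python
# Illustrative sketch only, not part of the vendored file.
import torch
import torch.nn.functional as F

torch.manual_seed(0)
A = torch.randn(64, 32)        # incoming gradient from the next layer, (M, K)
B = torch.randn(32, 16)        # weight, (K, N)
PreAct = torch.randn(64, 16)   # saved pre-activation, (M, N)

dout = A @ B                              # the GEMM part
postact = F.relu(PreAct).square()         # recomputed "relu_sq" activation
dx_manual = dout * 2.0 * F.relu(PreAct)   # d(relu(x)^2)/dx = 2 * relu(x)

# gemm_dact_ref obtains the same quantity via autograd instead of a hand-written formula.
PreAct_ = PreAct.clone().requires_grad_(True)
F.relu(PreAct_).square().backward(dout)
assert torch.allclose(dx_manual, PreAct_.grad, atol=1e-5)
```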
+ + Args: + A: (M, K) - input tensor + B: (K, N) - weight tensor with gate and up projections + C: (M, N) - optional bias tensor + activation: Type of gated activation + out_dtype: Output dtype for preact + postact_dtype: Output dtype for postact + store_preact: Whether to return the pre-activation + + Returns: + (preact, postact) where: + - preact: (M, N) pre-activation (if store_preact=True, else None) + - postact: (M, N // 2) post-activation output + """ + out_dtype = A.dtype if out_dtype is None else out_dtype + postact_dtype = A.dtype if postact_dtype is None else postact_dtype + if C is None: + preact = gemm_ref(A, B, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx) + else: + preact = gemm_add_ref(A, B, C, bias=bias, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx) + # Split preact into gate and up projections + gate = preact[..., ::2] # (M, N//2) + up = preact[..., 1::2] # (M, N//2) + postact = gated_to_pytorch_fn_map[activation](gate, up) + return preact.to(out_dtype) if store_preact else None, postact.to(postact_dtype) + + +def gemm_dgated_ref( + A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (M, total_K) if varlen_k or (whatever, K) if gather_A + B: Tensor, # (K, N) or (L, K, N) or (total_K, N) if varlen_k + PreAct: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m + activation: Literal["glu", "swiglu", "swiglu_oai", "reglu", "geglu"], + cu_seqlens_m: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + out_dtype: Optional[torch.dtype] = None, + postact_dtype: Optional[torch.dtype] = None, +) -> Tuple[Tensor, Tensor]: + """Reference implementation for GEMM with gated activation gradient. + + Args: + A: (M, K) - dout input tensor + B: (K, N) - weight tensor + PreAct: (M, 2*N) - pre-activation tensor with gate and up projections interleaved + activation: Type of gated activation + out_dtype: Output dtype for dx + postact_dtype: Output dtype for postact + + Returns: + (dx, postact) where: + - dx: (M, 2*N) gradient w.r.t. PreAct + - postact: (M, N) post-activation output + """ + out_dtype = A.dtype if out_dtype is None else out_dtype + postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype + dout = gemm_ref(A, B, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx).to(out_dtype) + # Split PreAct into gate and up projections + gate = PreAct[..., ::2] # (M, N) + up = PreAct[..., 1::2] # (M, N) + # Use autograd to compute gradients w.r.t. gate and up + gate_requires_grad, up_requires_grad = gate.requires_grad, up.requires_grad + gate.requires_grad_(True) + up.requires_grad_(True) + postact = gated_to_pytorch_fn_map[activation](gate, up) + dgate, dup = torch.autograd.grad(postact, [gate, up], dout, create_graph=False) + gate.requires_grad_(gate_requires_grad) + up.requires_grad_(up_requires_grad) + # Interleave gradients back + dx = torch.stack([dgate, dup], dim=-1).reshape(PreAct.shape) + return dx.to(out_dtype), postact.to(postact_dtype) + + +@torch.library.custom_op( + add_quack_op_namespace_prefix("gemm_symmetric_out"), + mutates_args=("out",), + device_types="cuda", + schema="(Tensor A, Tensor B, Tensor(a2!) out, Tensor? 
C=None, bool dynamic_scheduler=False, float alpha=1.0, float beta=1.0) -> ()", +) +def gemm_symmetric_out( + A: Tensor, # (M, K) or (L, M, K) + B: Tensor, # (K, M) or (L, K, M) + out: Tensor, # (M, M) or (L, M, M) + C: Optional[Tensor] = None, # (M, M) or (L, M, M) + dynamic_scheduler: bool = False, + alpha: float = 1.0, + beta: float = 1.0, +) -> None: + """GEMM with guaranteed symmetric output.""" + if A.ndim == 2: + A = A.unsqueeze(0) # (1, M, K) + B = B.mT # (M, K) or (L, M, K) + if B.ndim == 2: + B = B.unsqueeze(0) # (1, M, K) + if C is not None and C.ndim == 2: + C = C.unsqueeze(0) # (1, M, M) + if out.ndim == 2: + out = out.unsqueeze(0) + else: + out = out + tile_count_semaphore = ( + torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None + ) + gemm_symmetric_sm90_sm100( + A, + B, + out if out is not None else None, + C if C is not None else None, + tile_count_semaphore, + tile_M=128, + tile_N=256, + cluster_M=2, + cluster_N=1, + pingpong=False, + persistent=True, + max_swizzle_size=8, + alpha=alpha, + beta=beta, + ) + + +def gemm_symmetric( + A: Tensor, # (M, K) or (L, M, K) + B: Tensor, # (K, M) or (L, K, M) + C: Optional[Tensor] = None, # (M, M) or (L, M, M) + out: Optional[Tensor] = None, # (M, M) or (L, M, M) + out_dtype: Optional[torch.dtype] = None, + dynamic_scheduler: bool = False, + alpha: float | Tensor = 1.0, + beta: float | Tensor = 1.0, +) -> Tuple[Optional[Tensor], Tensor]: + """GEMM with symmetric output.""" + out_dtype = A.dtype if out_dtype is None else out_dtype + # Determine output shape based on gather_A + if A.ndim == 2: + out_shape = (A.shape[0], B.shape[-1]) + else: + out_shape = (A.shape[0], A.shape[-2], B.shape[-1]) + if out is None: + out = torch.empty(out_shape, dtype=out_dtype, device=A.device) + + alpha_val = alpha if isinstance(alpha, float) else 1.0 + beta_val = beta if isinstance(beta, float) else 1.0 + + gemm_symmetric_out( + A, B, out, C, dynamic_scheduler=dynamic_scheduler, alpha=alpha_val, beta=beta_val + ) + return out + + +# TODO: this is not quite right, do we need to register gemm_add not gemm_add_out? 
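The gate/up interleaving convention used by `gemm_gated_ref` and `gemm_dgated_ref` above can be spelled out in a few lines of plain PyTorch. This is an illustrative sketch with arbitrary shapes, not code from the vendored file:

```python
# Gate and up projections live in the even/odd columns of preact, so the
# post-activation output has half as many columns.
import torch
import torch.nn.functional as F

preact = torch.randn(8, 2 * 16)                 # (M, 2*N), gate/up interleaved
gate, up = preact[..., ::2], preact[..., 1::2]  # each (M, N)
postact = F.silu(gate) * up                     # the "swiglu" entry of gated_to_pytorch_fn_map
assert postact.shape == (8, 16)

# The backward reference re-interleaves dgate/dup into the (M, 2*N) layout
# with the same stack-and-reshape step used in gemm_dgated_ref.
dgate, dup = torch.randn_like(gate), torch.randn_like(up)
dx = torch.stack([dgate, dup], dim=-1).reshape(preact.shape)
assert torch.equal(dx[..., ::2], dgate) and torch.equal(dx[..., 1::2], dup)
```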
+# try: +# from torch._inductor.fx_passes.reinplace import InplaceableOp +# torch._inductor.fx_passes.reinplace.inplaceable_ops.update({ +# torch.ops.quack.gemm_add_out.default: +# InplaceableOp(torch.ops.quack.gemm_add_inplace.default, mutated_arg=2) +# }) +# except ImportError: +# pass diff --git a/sonic-moe/torch-ext/sonicmoe/quack/gemm_sm100.py b/sonic-moe/torch-ext/sonicmoe/quack/gemm_sm100.py new file mode 100644 index 00000000..647e0f53 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/gemm_sm100.py @@ -0,0 +1,2809 @@ +# Based on the cute-dsl example: +# https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/blackwell/dense_gemm_persistent.py + +import argparse +from typing import Optional, Type, Tuple, Union, Callable, Literal +from functools import partial + +import cuda.bindings.driver as cuda +import torch + +import cutlass +import cutlass.cute as cute +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.torch as cutlass_torch +import cutlass.pipeline as pipeline +import cutlass.utils.blackwell_helpers as sm100_utils +import cutlass.utils.blockscaled_layout as blockscaled_utils +from cutlass.cute.nvgpu.warp import ( + LdMatrix8x8x16bOp, + LdMatrix16x16x8bOp, + StMatrix8x8x16bOp, + StMatrix16x8x8bOp, +) +from cutlass import Int32, Float32, Boolean, const_expr +from cutlass.utils import LayoutEnum +from cutlass.cute.runtime import from_dlpack, make_ptr + +from .pipeline import PipelineTmaCpAsyncUmma +from .cute_dsl_utils import ParamsBase, ArgumentsBase +from .tile_scheduler import TileSchedulerOptions +from .varlen_utils import VarlenArguments, VarlenManager +from .gemm_sm90 import GemmSm90, NamedBarrierGemm +from . import copy_utils as copy_utils +from . import sm100_utils as quack_sm100_utils + +# return PipelineStateWAdvance instead of PipelineState + +""" +A high-performance persistent batched dense GEMM example for the NVIDIA Blackwell SM100 architecture +using CUTE DSL. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") +- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes Blackwell's tcgen05.mma for matrix multiply-accumulate (MMA) operations (including 2cta mma instructions) + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Support persistent tile scheduling to better overlap memory load/store with mma between tiles + - Support warp specialization to avoid explicit pipelining between mainloop load and mma + +This GEMM works as follows: +1. DMA warp: Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. MMA warp: Perform matrix multiply-accumulate (MMA) operations using tcgen05.mma instruction. +3. EPILOGUE warp: + - Load completed accumulator from tensor memory (TMEM) to registers (RMEM) using tcgen05.ld. + - Type convert C matrix to output type. + - Optionally store C matrix from registers (RMEM) to shared memory (SMEM) to global memory (GMEM) with TMA operations, + or directly store C matrix from registers (RMEM) to global memory (GMEM) without TMA operations. + +SM100 tcgen05.mma instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Write accumulator to TMEM +The accumulator in TMEM must then be loaded to registers before writing back to GMEM. 
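For reference, the tensor layout stated at the start of this module docstring (A is MxKxL, B is NxKxL, C is MxNxL, with L the batch dimension) corresponds to the following plain-PyTorch computation. The sizes are arbitrary and the snippet only illustrates the convention; it is not kernel code:

```python
# D[m, n, l] = sum_k A[m, k, l] * B[n, k, l], i.e. a per-batch A @ B^T.
import torch

M, N, K, L = 64, 32, 16, 2
A = torch.randn(M, K, L, dtype=torch.float16)
B = torch.randn(N, K, L, dtype=torch.float16)
D = torch.einsum("mkl,nkl->mnl", A.float(), B.float()).to(torch.float16)
assert D.shape == (M, N, L)
```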
+ +Input arguments to this example is same as dense_gemm.py. + +.. code-block:: bash + + python examples/blackwell/dense_gemm_persistent.py \ + --ab_dtype Float16 --d_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,8192,1 \ + --use_2cta_instrs + +To collect performance with NCU profiler: + +.. code-block:: bash + + ncu python examples/blackwell/dense_gemm_persistent.py \ + --ab_dtype Float16 --d_dtype Float16 --acc_dtype Float32 \ + --mma_tiler_mn 256,128 --cluster_shape_mn 2,1 \ + --mnkl 8192,8192,8192,1 \ + --use_2cta_instrs \ + --warmup_iterations 1 --iterations 10 --skip_ref_check + + +Constraints are same as dense_gemm.py: +* Supported input data types: fp16, bf16, tf32, int8, uint8, fp8 (e4m3fn, e5m2), + see detailed valid dtype combinations in below GemmSm100 class documentation +* A/B tensor must have the same data type +* Mma tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True) +* Mma tiler N must be 32-256, step 32 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 16 +* Cluster shape M must be multiple of 2 if use_2cta_instrs=True +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 4, 8, and 16 for TFloat32, + Float16/BFloat16, and Int8/Uint8/Float8, respectively. +* OOB tiles are not allowed when TMA store is disabled +""" + + +class GemmSm100(GemmSm90): + """This class implements batched matrix multiplication (C = A x B) with support for various data types + and architectural features specific to Blackwell GPUs with persistent tile scheduling and warp specialization. + + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param mma_tiler_mn: Shape of the Matrix Multiply-Accumulate (MMA) tile (M,N) + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: Cluster dimensions (M,N) for parallel processing + :type cluster_shape_mn: Tuple[int, int] + + :note: In current version, A and B tensor must have the same data type + - i.e., Float8E4M3FN for A and Float8E5M2 for B is not supported + + :note: Supported A/B data types: + - TFloat32 + - Float16/BFloat16 + - Int8/Uint8 + - Float8E4M3FN/Float8E5M2 + + :note: Supported accumulator data types: + - Float32 (for all floating point A/B data types) + - Float16 (only for fp16 and fp8 A/B data types) + - Int32 (only for uint8/int8 A/B data types) + + :note: Supported C data types: + - Float32 (for float32 and int32 accumulator data types) + - Int32 (for float32 and int32 accumulator data types) + - Float16/BFloat16 (for fp16 and fp8 accumulator data types) + - Int8/Uint8 (for uint8/int8 accumulator data types) + - Float8E4M3FN/Float8E5M2 (for float32 accumulator data types) + + :note: Constraints: + - MMA tiler M must be 64/128 (use_2cta_instrs=False) or 128/256 (use_2cta_instrs=True) + - MMA tiler N must be 32-256, step 32 + - Cluster shape M must be multiple of 2 if use_2cta_instrs=True + - Cluster shape M/N must be positive and power of 2, total cluster size <= 16 + + Example: + >>> gemm = GemmSm100( + ... acc_dtype=Float32, + ... mma_tiler_mn=(128, 128), + ... cluster_shape_mn=(2, 2) + ... 
) + >>> gemm(mA, mB, mD, max_active_clusters, stream) + """ + + arch = 100 + num_epi_tensormaps = GemmSm90.num_epi_tensormaps + + EpilogueArguments = GemmSm90.EpilogueArguments + EpilogueParams = GemmSm90.EpilogueParams + + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + a_dtype: Type[cutlass.Numeric], # ignored for now + mma_tiler_mn: Tuple[int, int], + cluster_shape_mnk: Tuple[int, int, int], + sf_vec_size: Optional[int] = None, + gather_A: bool = False, + ): + """Initializes the configuration for a Blackwell dense GEMM kernel. + + This configuration includes several key aspects: + + 1. MMA Instruction Settings (tcgen05): + - acc_dtype: Data types for MMA accumulator. + - mma_tiler_mn: The (M, N) shape of the MMA instruction tiler. + - use_2cta_instrs: Boolean indicating if the tcgen05 MMA variant + with cta_group=2 should be used. + + 2. Cluster Shape: + - cluster_shape_mnk: The (ClusterM, ClusterN) shape of the CTA cluster. + + :param acc_dtype: Data type of the accumulator. + :type acc_dtype: type[cutlass.Numeric] + :param mma_tiler_mn: Tuple (M, N) shape of the MMA instruction. + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mnk: Tuple (ClusterM, ClusterN) shape of the cluster. + :type cluster_shape_mnk: Tuple[int, int] + """ + + self.acc_dtype: Type[cutlass.Numeric] = acc_dtype + self.use_2cta_instrs = cluster_shape_mnk[0] == 2 and mma_tiler_mn[0] in (256,) + self.cluster_shape_mnk = cluster_shape_mnk + assert cluster_shape_mnk[2] == 1, "Cluster shape K must be 1" + # K dimension is deferred in _setup_attributes + self.mma_tiler = (*mma_tiler_mn, 1) + self.sf_vec_size = sf_vec_size + self.blockscaled = sf_vec_size is not None + self.is_persistent = True + self.pingpong = False # for compatibility with GemmSm90 + self.gather_A = gather_A + if gather_A: + assert cluster_shape_mnk[1] == 1, "Cluster shape N must be 1 for gather A " + + self.cta_group = tcgen05.CtaGroup.TWO if self.use_2cta_instrs else tcgen05.CtaGroup.ONE + + self.num_ab_load_warps = 1 if not self.gather_A else 5 + self.occupancy = 1 + # Set specialized warp ids + self.epilog_warp_id = (0, 1, 2, 3) + self.mma_warp_id = 4 + self.ab_load_warp_id = 5 + self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps + self.scheduler_warp_id = self.epi_load_warp_id + 1 + self.num_epi_warps = len(self.epilog_warp_id) + self.threads_per_cta = cute.arch.WARP_SIZE * ( + self.num_ab_load_warps + + len( + ( + self.mma_warp_id, + self.epi_load_warp_id, + self.scheduler_warp_id, + *self.epilog_warp_id, + ) + ) + ) + + def _setup_attributes(self, epilogue_args: EpilogueArguments, varlen_args: VarlenArguments): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C stage counts in shared memory + - Computing A/B/C shared memory layout + - Computing tensor memory allocation columns + """ + # Compute mma instruction shapes + mma_inst_bits_k = 256 + # (MMA_Tile_Shape_M, MMA_Tile_Shape_N, MMA_Inst_Shape_K) + self.mma_inst_shape_mnk = ( + self.mma_tiler[0], + self.mma_tiler[1], + mma_inst_bits_k // self.a_dtype.width, + ) + # (CTA_Tile_Shape_M, Round_Up(MMA_Tile_Shape_N, 128), MMA_Inst_Shape_K) + self.mma_inst_shape_mnk_sfb = ( + self.mma_inst_shape_mnk[0] // (2 if self.use_2cta_instrs 
else 1), + cute.round_up(self.mma_inst_shape_mnk[1], 128), + self.mma_inst_shape_mnk[2], + ) + + # Configure tiled mma + if const_expr(not self.blockscaled): + self.tiled_mma = sm100_utils.make_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.acc_dtype, + self.cta_group, + self.mma_tiler[:2], + ) + self.tiled_mma_sfb = None + else: + self.tiled_mma = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + self.cta_group, + self.mma_inst_shape_mnk[:2], + ) + self.tiled_mma_sfb = sm100_utils.make_blockscaled_trivial_tiled_mma( + self.a_dtype, + self.a_major_mode, + self.b_major_mode, + self.sf_dtype, + self.sf_vec_size, + tcgen05.CtaGroup.ONE, + self.mma_inst_shape_mnk_sfb[:2], + ) + + # Compute mma/cluster/tile shapes + mma_inst_tile_k = 4 + self.mma_tiler = ( + self.mma_inst_shape_mnk[0], + self.mma_inst_shape_mnk[1], + self.mma_inst_shape_mnk[2] * mma_inst_tile_k, + ) + if const_expr(self.blockscaled): + self.mma_tiler_sfb = ( + self.mma_inst_shape_mnk_sfb[0], + self.mma_inst_shape_mnk_sfb[1], + self.mma_inst_shape_mnk_sfb[2] * mma_inst_tile_k, + ) + else: + self.mma_tiler_sfb = None + self.cta_tile_shape_mnk = ( + self.mma_tiler[0] // cute.size(self.tiled_mma.thr_id.shape), + self.mma_tiler[1], + self.mma_tiler[2], + ) + + # Compute cluster layout + self.cluster_layout_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (self.tiled_mma.thr_id.shape,), + ) + if const_expr(self.blockscaled): + self.cluster_layout_sfb_vmnk = cute.tiled_divide( + cute.make_layout(self.cluster_shape_mnk), + (self.tiled_mma_sfb.thr_id.shape,), + ) + else: + self.cluster_layout_sfb_vmnk = None + + # Compute number of multicast CTAs for A/B + self.num_mcast_ctas_a = cute.size(self.cluster_layout_vmnk.shape[2]) + if self.gather_A: + assert self.num_mcast_ctas_a == 1 + self.num_mcast_ctas_b = cute.size(self.cluster_layout_vmnk.shape[1]) + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + if const_expr(self.blockscaled): + self.num_mcast_ctas_sfb = cute.size(self.cluster_layout_sfb_vmnk.shape[1]) + self.is_sfb_mcast = self.num_mcast_ctas_sfb > 1 + + # Compute epilogue subtile + self.epi_tile = sm100_utils.compute_epilogue_tile_shape( + self.cta_tile_shape_mnk, + self.use_2cta_instrs, + self.d_layout if self.d_layout is not None else LayoutEnum.ROW_MAJOR, + self.d_dtype if self.d_dtype is not None else cutlass.BFloat16, + layout_c=self.c_layout, + elem_ty_c=self.c_dtype, + ) + + # Setup A/B/C stage count in shared memory and ACC stage count in tensor memory + prefetch_A_idx = ( + None + if not self.gather_A + else ("varlen_m" if varlen_args.mCuSeqlensM is not None else "varlen_k") + ) + ( + self.num_acc_stage, + self.ab_stage, + self.epi_stage, + self.epi_c_stage, + ) = self._compute_stages( + self.tiled_mma, + self.mma_tiler, + self.cta_tile_shape_mnk, + self.epi_tile, + self.a_dtype, + self.b_dtype, + self.sf_dtype, + self.sf_vec_size, + self.d_dtype, + self.c_dtype, + self.d_layout, + self.c_layout, + epilogue_args, + prefetch_A_idx, + cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"), # smem_capacity + self.occupancy, + ) + self.sched_stage = 1 + self.a_prefetch_stage = ( + 0 + if not self.gather_A + else (2 if varlen_args.mCuSeqlensM is not None else self.ab_stage) + ) + + # Compute A/B/SFA/SFB/C shared memory layout + self.a_smem_layout_staged = sm100_utils.make_smem_layout_a( + self.tiled_mma, self.mma_tiler, 
self.a_dtype, self.ab_stage + ) + self.a_smem_load_layout_staged = self.a_smem_layout_staged + if const_expr(self.gather_A): + self.a_smem_load_layout_staged = quack_sm100_utils.make_smem_layout_cpasync_a( + self.tiled_mma, self.mma_tiler, self.a_dtype, self.ab_stage + ) + self.b_smem_layout_staged = sm100_utils.make_smem_layout_b( + self.tiled_mma, self.mma_tiler, self.b_dtype, self.ab_stage + ) + self.epi_smem_layout_staged = None + if const_expr(self.d_dtype is not None): + self.epi_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.d_dtype, self.d_layout, self.epi_tile, self.epi_stage + ) + self.epi_c_smem_layout_staged = None + if const_expr(self.c_dtype is not None): + self.epi_c_smem_layout_staged = sm100_utils.make_smem_layout_epi( + self.c_dtype, self.c_layout, self.epi_tile, self.epi_c_stage + ) + if const_expr(self.blockscaled): + self.sfa_smem_layout_staged = blockscaled_utils.make_smem_layout_sfa( + self.tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.ab_stage, + ) + self.sfb_smem_layout_staged = blockscaled_utils.make_smem_layout_sfb( + self.tiled_mma, + self.mma_tiler, + self.sf_vec_size, + self.ab_stage, + ) + else: + self.sfa_smem_layout_staged, self.sfb_smem_layout_staged = None, None + + # Compute the number of tensor memory allocation columns + if const_expr(not self.blockscaled): + self.num_tmem_alloc_cols = self._compute_num_tmem_alloc_cols( + self.tiled_mma, self.mma_tiler, self.num_acc_stage + ) + else: + SM100_TMEM_CAPACITY_COLUMNS = 512 + self.num_tmem_alloc_cols = SM100_TMEM_CAPACITY_COLUMNS + + @cute.jit + def __call__( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mD: Optional[cute.Tensor], + mC: Optional[cute.Tensor], + epilogue_args: ArgumentsBase, + scheduler_args: TileSchedulerOptions, + varlen_args: Optional[VarlenArguments], + stream: cuda.CUstream, + mSFA: Optional[cute.Tensor] = None, + mSFB: Optional[cute.Tensor] = None, + ): + """Execute the GEMM operation in steps: + - Setup static attributes before smem/grid/tma computation + - Setup TMA load/store atoms and tensors + - Compute grid size with regard to hardware constraints + - Define shared storage for kernel + - Launch the kernel synchronously + + :param mA: Input tensor A + :type mA: cute.Tensor + :param mB: Input tensor B + :type mB: cute.Tensor + :param mD: Output tensor D + :type mD: cute.Tensor + :param max_active_clusters: Maximum number of active clusters + :type max_active_clusters: cutlass.Constexpr + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + :raises TypeError: If input data types are incompatible with the MMA instruction. + :raises AssertionError: If OOB (Out-Of-Bounds) tiles are present when TMA store is disabled. 
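The tiler and cluster constraints repeated in the docstrings above (MMA tiler M of 64/128 without 2-CTA instructions or 128/256 with them, tiler N between 32 and 256 in steps of 32, power-of-two cluster shape with at most 16 CTAs, and an even cluster M when 2-CTA instructions are used) can be condensed into a small standalone checker. The helper below is hypothetical and purely illustrative; no such function exists in the vendored code:

```python
# Hypothetical validity check mirroring the documented SM100 constraints.
def check_sm100_tile_config(mma_tiler_mn, cluster_shape_mn, use_2cta_instrs):
    m, n = mma_tiler_mn
    cm, cn = cluster_shape_mn
    ok_m = m in ((128, 256) if use_2cta_instrs else (64, 128))
    ok_n = 32 <= n <= 256 and n % 32 == 0
    is_pow2 = lambda x: x > 0 and (x & (x - 1)) == 0
    ok_cluster = is_pow2(cm) and is_pow2(cn) and cm * cn <= 16
    if use_2cta_instrs:
        ok_cluster = ok_cluster and cm % 2 == 0
    return ok_m and ok_n and ok_cluster

assert check_sm100_tile_config((128, 128), (2, 2), use_2cta_instrs=False)
assert not check_sm100_tile_config((128, 200), (2, 1), use_2cta_instrs=False)
```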
+ """ + if const_expr(self.blockscaled): + assert mSFA is not None and mSFB is not None + # Setup static attributes before smem/grid/tma computation + self.a_dtype = mA.element_type + self.b_dtype = mB.element_type + self.d_dtype = mD.element_type if mD is not None else None + self.c_dtype = mC.element_type if mC is not None else None + self.sf_dtype: Optional[Type[cutlass.Numeric]] = ( + mSFA.element_type if mSFA is not None else None + ) + self.a_layout = LayoutEnum.from_tensor(mA) + self.b_layout = LayoutEnum.from_tensor(mB) + self.d_layout = LayoutEnum.from_tensor(mD) if mD is not None else None + self.c_layout = LayoutEnum.from_tensor(mC) if mC is not None else None + self.a_major_mode = LayoutEnum.from_tensor(mA).mma_major_mode() + self.b_major_mode = LayoutEnum.from_tensor(mB).mma_major_mode() + + # Check if input data types are compatible with MMA instruction + if const_expr(self.a_dtype != self.b_dtype): + raise TypeError(f"Type must match: {self.a_dtype} != {self.b_dtype}") + + if const_expr(varlen_args is None): + varlen_args = VarlenArguments() + assert (varlen_args.mAIdx is not None) == self.gather_A + + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: tuple( + cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s + for s in t.stride + ) + mA, mD = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + if t is not None + else None + for t in (mA, mD) + ] + + # Setup attributes that dependent on gemm inputs + self._setup_attributes(epilogue_args, varlen_args) + + if const_expr(self.blockscaled): + # Setup sfa/sfb tensor by filling A/B tensor to scale factor atom layout + # ((Atom_M, Rest_M),(Atom_K, Rest_K),RestL) + sfa_layout = blockscaled_utils.tile_atom_to_shape_SF(mA.shape, self.sf_vec_size) + mSFA = cute.make_tensor(mSFA.iterator, sfa_layout) + # ((Atom_N, Rest_N),(Atom_K, Rest_K),RestL) + sfb_layout = blockscaled_utils.tile_atom_to_shape_SF(mB.shape, self.sf_vec_size) + mSFB = cute.make_tensor(mSFB.iterator, sfb_layout) + + atom_thr_size = cute.size(self.tiled_mma.thr_id.shape) + + # Setup TMA load for A & B + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, None, 0)) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, None, 0)) + tma_atom_a, tma_tensor_a = None, None + if const_expr(not self.gather_A): + a_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mnk, self.tiled_mma.thr_id + ) + tma_atom_a, tma_tensor_a = cute.nvgpu.make_tiled_tma_atom_A( + a_op, + mA, + a_smem_layout, + self.mma_tiler, + self.tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=(cutlass.TFloat32 if mA.element_type is Float32 else None), + ) + b_op = sm100_utils.cluster_shape_to_tma_atom_B( + self.cluster_shape_mnk, self.tiled_mma.thr_id + ) + tma_atom_b, tma_tensor_b = cute.nvgpu.make_tiled_tma_atom_B( + b_op, + mB, + b_smem_layout, + self.mma_tiler, + self.tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=(cutlass.TFloat32 if mB.element_type is Float32 else None), + ) + + tma_atom_sfa, tma_tensor_sfa = None, None + tma_atom_sfb, tma_tensor_sfb = None, None + if const_expr(self.blockscaled): + # Setup TMA load for SFA + sfa_op = sm100_utils.cluster_shape_to_tma_atom_A( + self.cluster_shape_mnk, self.tiled_mma.thr_id + ) + sfa_smem_layout = cute.slice_(self.sfa_smem_layout_staged, (None, None, None, 0)) + tma_atom_sfa, tma_tensor_sfa = cute.nvgpu.make_tiled_tma_atom_A( + sfa_op, + mSFA, + sfa_smem_layout, + self.mma_tiler, 
+ self.tiled_mma, + self.cluster_layout_vmnk.shape, + internal_type=cutlass.Int16, + ) + # Setup TMA load for SFB + sfb_op = sm100_utils.cluster_shape_to_tma_atom_SFB( + self.cluster_shape_mnk, self.tiled_mma.thr_id + ) + sfb_smem_layout = cute.slice_(self.sfb_smem_layout_staged, (None, None, None, 0)) + tma_atom_sfb, tma_tensor_sfb = cute.nvgpu.make_tiled_tma_atom_B( + sfb_op, + mSFB, + sfb_smem_layout, + self.mma_tiler_sfb, + self.tiled_mma_sfb, + self.cluster_layout_sfb_vmnk.shape, + internal_type=cutlass.Int16, + ) + + self.num_tma_load_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout) + if const_expr(not self.gather_A): + self.num_tma_load_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout) + if const_expr(self.blockscaled): + sfa_copy_size = cute.size_in_bytes(self.sf_dtype, sfa_smem_layout) + sfb_copy_size = cute.size_in_bytes(self.sf_dtype, sfb_smem_layout) + self.num_tma_load_bytes += sfa_copy_size + sfb_copy_size + self.num_tma_load_bytes *= atom_thr_size + + # Setup TMA store for D + tma_atom_d, tma_tensor_d = None, None + if const_expr(mD is not None): + tma_atom_d, tma_tensor_d = self._make_tma_epi_atoms_and_tensors( + mD, + self.epi_smem_layout_staged, + self.epi_tile, + op_type="store" + if not (hasattr(epilogue_args, "add_to_output") and epilogue_args.add_to_output) + else "add", + ) + tma_atom_c, tma_tensor_c = None, None + if const_expr(mC is not None): + tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors( + mC, self.epi_c_smem_layout_staged, self.epi_tile, op_type="load" + ) + + epilogue_params = self.epi_to_underlying_arguments(epilogue_args) + varlen_params = VarlenManager.to_underlying_arguments(varlen_args) + + TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_args.mCuSeqlensM is not None) + tile_sched_args = self.get_scheduler_arguments(mA, mB, mD, scheduler_args, varlen_args) + tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args) + grid = TileSchedulerCls.get_grid_shape( + tile_sched_params, scheduler_args.max_active_clusters + ) + + self.buffer_align_bytes = 1024 + + epi_smem_size = cute.cosize(self.epi_smem_layout_staged) if mD is not None else 0 + epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0 + sf_dtype = self.sf_dtype if const_expr(self.blockscaled) else cutlass.Float8E8M0FNU + sfa_smem_size = ( + cute.cosize(self.sfa_smem_layout_staged) if const_expr(self.blockscaled) else 0 + ) + sfb_smem_size = ( + cute.cosize(self.sfb_smem_layout_staged) if const_expr(self.blockscaled) else 0 + ) + a_idx_smem_size = 0 + if const_expr(self.gather_A): + a_idx_smem_size = self.a_prefetch_stage * ( + self.cta_tile_shape_mnk[0] + if varlen_args.mCuSeqlensM is not None + else self.cta_tile_shape_mnk[2] + ) + + # Define shared storage for kernel + @cute.struct + class SharedStorage: + ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2] + epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2] + acc_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.num_acc_stage * 2] + sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2] + a_prefetch_pipeline_array_ptr: cute.struct.MemRange[ + cutlass.Int64, self.a_prefetch_stage * 2 + ] + tile_count: cute.struct.MemRange[Int32, self.sched_stage] + tmem_dealloc_mbar_ptr: cutlass.Int64 + tmem_holding_buf: Int32 + sAIdx: cute.struct.Align[cute.struct.MemRange[Int32, a_idx_smem_size], 16] + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sD: cute.struct.Align[ + 
cute.struct.MemRange[ + self.d_dtype if self.d_dtype is not None else Int32, epi_smem_size + ], + self.buffer_align_bytes, + ] + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype if self.c_dtype is not None else Int32, epi_c_smem_size + ], + self.buffer_align_bytes, + ] + epi: self.epi_get_smem_struct(epilogue_params) + # (MMA, MMA_M, MMA_K, STAGE) + sA: cute.struct.Align[ + cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sB: cute.struct.Align[ + cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged.outer)], + self.buffer_align_bytes, + ] + # (MMA, MMA_M, MMA_K, STAGE) + sSFA: cute.struct.Align[ + cute.struct.MemRange[sf_dtype, sfa_smem_size], + self.buffer_align_bytes, + ] + # (MMA, MMA_N, MMA_K, STAGE) + sSFB: cute.struct.Align[ + cute.struct.MemRange[sf_dtype, sfb_smem_size], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + self.tiled_mma, + self.tiled_mma_sfb, + tma_atom_a, + tma_tensor_a if const_expr(not self.gather_A) else mA, + tma_atom_b, + tma_tensor_b, + tma_atom_sfa, + tma_tensor_sfa, + tma_atom_sfb, + tma_tensor_sfb, + tma_atom_d, + tma_tensor_d, + tma_atom_c, + tma_tensor_c, + epilogue_params, + varlen_params, + self.cluster_layout_vmnk, + self.cluster_layout_sfb_vmnk, + self.a_smem_layout_staged, + self.a_smem_load_layout_staged, + self.b_smem_layout_staged, + self.sfa_smem_layout_staged, + self.sfb_smem_layout_staged, + self.epi_smem_layout_staged, + self.epi_c_smem_layout_staged, + self.epi_tile, + tile_sched_params, + TileSchedulerCls, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + stream=stream, + min_blocks_per_mp=1, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tiled_mma_sfb: Optional[cute.TiledMma], + tma_atom_a: Optional[cute.CopyAtom], + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_sfa: Optional[cute.CopyAtom], + mSFA_mkl: Optional[cute.Tensor], + tma_atom_sfb: Optional[cute.CopyAtom], + mSFB_nkl: Optional[cute.Tensor], + tma_atom_d: Optional[cute.CopyAtom], + mD_mnl: Optional[cute.Tensor], + tma_atom_c: Optional[cute.CopyAtom], + mC_mnl: Optional[cute.Tensor], + epilogue_params: ParamsBase, + varlen_params: VarlenManager.Params, + cluster_layout_vmnk: cute.Layout, + cluster_layout_sfb_vmnk: Optional[cute.Layout], + a_smem_layout: cute.ComposedLayout, + a_smem_load_layout: cute.ComposedLayout, + b_smem_layout: cute.ComposedLayout, + sfa_smem_layout: Optional[cute.Layout], + sfb_smem_layout: Optional[cute.Layout], + epi_smem_layout: Union[cute.Layout, cute.ComposedLayout, None], + epi_c_smem_layout: Union[cute.Layout, cute.ComposedLayout, None], + epi_tile: cute.Tile, + tile_sched_params: ParamsBase, + TileSchedulerCls: cutlass.Constexpr[Callable], + ): + """ + GPU device kernel performing the Persistent batched GEMM computation. 
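+
+ The kernel is warp-specialized: dedicated load warps bring A/B (and, for block-scaled
+ GEMM, SFA/SFB) into shared memory, an optional scheduler warp distributes work tiles and
+ prefetches A indices when gather_A is enabled, a separate warp loads C when an epilogue
+ source is present, one warp drives the tcgen05 MMA, and the remaining warps perform the
+ epilogue. The warps communicate through the pipelines and named barriers set up below.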
+ """ + + varlen_m = const_expr(varlen_params.cu_seqlens_m is not None) + varlen_k = const_expr(varlen_params.cu_seqlens_k is not None) + assert not (varlen_m and varlen_k) + if const_expr(self.gather_A): + assert varlen_m or varlen_k + has_D = const_expr(mD_mnl is not None) + has_C = const_expr(mC_mnl is not None) + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + # ///////////////////////////////////////////////////////////////////////////// + # Prefetch Tma desc + # ///////////////////////////////////////////////////////////////////////////// + if warp_idx == self.ab_load_warp_id: + for tma_atom in ( + tma_atom_a, + tma_atom_b, + tma_atom_sfa, + tma_atom_sfb, + tma_atom_d, + tma_atom_c, + ): + if const_expr(tma_atom is not None): + cpasync.prefetch_descriptor(tma_atom) + + use_2cta_instrs = cute.size(tiled_mma.thr_id.shape) == 2 + + # + # Setup cta/thread coordinates + # + # Coords inside cluster + bidx, _, _ = cute.arch.block_idx() + mma_tile_coord_v = bidx % cute.size(tiled_mma.thr_id.shape) + is_leader_cta = mma_tile_coord_v == 0 + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + # Coord inside cta + tidx, _, _ = cute.arch.thread_idx() + + # + # Alloc and init: a+b full/empty, accumulator full/empty, tensor memory dealloc barrier + # + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + tmem_dealloc_mbar_ptr = storage.tmem_dealloc_mbar_ptr + tmem_holding_buf = storage.tmem_holding_buf + + # Tensor memory dealloc barrier init + if use_2cta_instrs: + if warp_idx == self.ab_load_warp_id: + num_tmem_dealloc_threads = 32 + cute.arch.mbarrier_init(tmem_dealloc_mbar_ptr, num_tmem_dealloc_threads) + + # Initialize pipelines and states + ab_pipeline = self.make_ab_pipeline( + tiled_mma=tiled_mma, + cluster_layout_vmnk=cluster_layout_vmnk, + ab_pipeline_mbar_ptr=storage.ab_pipeline_array_ptr.data_ptr(), + is_leader_cta=is_leader_cta, + ) + epi_pipeline = None + if const_expr(has_C): + epi_pipeline = self.make_epi_pipeline( + c_smem_layout=cute.slice_(epi_c_smem_layout, (None, None, 0)), + epi_pipeline_mbar_ptr=storage.epi_pipeline_array_ptr.data_ptr(), + ) + acc_pipeline = self.make_acc_pipeline( + cluster_layout_vmnk=cluster_layout_vmnk, + acc_pipeline_mbar_ptr=storage.acc_pipeline_array_ptr.data_ptr(), + ) + sched_pipeline = None + tile_count = None + if const_expr(tile_sched_params.tile_count_semaphore is not None): + # Dynamic persistent scheduler + sched_pipeline = self.make_sched_pipeline( + self.cluster_shape_mnk, + sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(), + has_C=has_C, + ) + tile_count = storage.tile_count.get_tensor((self.sched_stage,)) + a_prefetch_pipeline = None + if const_expr(self.gather_A): + a_prefetch_pipeline = self.make_a_prefetch_pipeline( + storage.a_prefetch_pipeline_array_ptr.data_ptr(), + ) + + # Setup smem tensor A/B/D + # (MMA, MMA_M, MMA_K, STAGE) + sA_mma = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner) + sA = storage.sA.get_tensor(a_smem_load_layout.outer, swizzle=a_smem_load_layout.inner) + # (MMA, MMA_N, MMA_K, STAGE) + sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner) + sAIdx = None + if const_expr(self.gather_A): + a_idx_smem_dim = self.cta_tile_shape_mnk[0] if varlen_m else self.cta_tile_shape_mnk[2] + a_idx_smem_layout = cute.make_layout((a_idx_smem_dim, self.a_prefetch_stage)) + sAIdx = storage.sAIdx.get_tensor(a_idx_smem_layout) + sSFA, sSFB = None, None + if const_expr(self.blockscaled): 
+ # (MMA, MMA_M, MMA_K, STAGE) + sSFA = storage.sSFA.get_tensor(sfa_smem_layout) + # (MMA, MMA_N, MMA_K, STAGE) + sSFB = storage.sSFB.get_tensor(sfb_smem_layout) + sD = None + if const_expr(has_D): + # (EPI_TILE_M, EPI_TILE_N, STAGE) + sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner) + sC = None + if const_expr(has_C): + sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner) + epi_smem_tensors = self.epi_get_smem_tensors(epilogue_params, storage) + + thr_mma = tiled_mma.get_slice(mma_tile_coord_v) + thr_mma_sfb = ( + tiled_mma_sfb.get_slice(mma_tile_coord_v) if const_expr(self.blockscaled) else None + ) + + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(self.mma_tiler[:2]) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, self.num_acc_stage)) + + varlen_manager = VarlenManager.create( + varlen_params, + has_D, + self.num_epi_tensormaps, + # Only used if not varlen_m + len_m_static=Int32( + mA_mkl.shape[0] + if varlen_k or varlen_params.mAIdx is None + else varlen_params.mAIdx.shape[0] + ), + len_k_static=Int32(mA_mkl.shape[1]), + ) + + TileSchedulerCls = partial( + TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline + ) + + tmem_alloc_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierGemm.TmemPtr), + num_threads=cute.arch.WARP_SIZE * len((self.mma_warp_id, *self.epilog_warp_id)), + ) + epi_load_barrier = None + if const_expr(has_C): + epi_load_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierGemm.EpilogueLoad), num_threads=2 * cute.arch.WARP_SIZE + ) + + # + # Specialized AB load warps + # + if warp_idx == self.ab_load_warp_id: + is_tma_warp = True + # initialize tensormap for A & B + varlen_manager.init_tensormap_AB(tma_atom_a, tma_atom_b, is_tma_warp) + tma_desc_a_ptr = varlen_manager.get_tma_desc_a_ptr() + tma_desc_b_ptr = varlen_manager.get_tma_desc_b_ptr() + # Compute multicast mask for A/B buffer full + block_in_cluster_coord_vmnk = cluster_layout_vmnk.get_flat_coord(cta_rank_in_cluster) + block_in_cluster_coord_sfb_vmnk = None + if const_expr(self.blockscaled): + block_in_cluster_coord_sfb_vmnk = cluster_layout_sfb_vmnk.get_flat_coord( + cta_rank_in_cluster + ) + a_mcast_mask, b_mcast_mask = None, None + sfa_mcast_mask, sfb_mcast_mask = None, None + if const_expr(self.is_a_mcast or self.is_b_mcast or use_2cta_instrs): + a_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + b_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=1 + ) + if const_expr(self.blockscaled): + sfa_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_vmnk, block_in_cluster_coord_vmnk, mcast_mode=2 + ) + sfb_mcast_mask = cpasync.create_tma_multicast_mask( + cluster_layout_sfb_vmnk, block_in_cluster_coord_sfb_vmnk, mcast_mode=1 + ) + + # Persistent tile scheduling loop + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.ab_stage + ) + if const_expr(varlen_k): + # wait tensormap initialization complete before update + varlen_manager.fence_tensormap_init() + do_epi_load_barrier_arrive = Boolean(True) + while work_tile.is_valid_tile: + tile_coord_mnkl = work_tile.tile_idx + batch_idx = tile_coord_mnkl[3] + varlen_manager.update_tensormap_AB( + batch_idx, + self.a_layout, + 
self.b_layout, + is_tma_warp, + ) + # /////////////////////////////////////////////////////////////////////////// + # Local_tile partition global tensors + # /////////////////////////////////////////////////////////////////////////// + mma_tile_coord_mnl = ( + tile_coord_mnkl[0] // cute.size(tiled_mma.thr_id.shape), + tile_coord_mnkl[1], + tile_coord_mnkl[3], + ) + gA_mk = None + if const_expr(not self.gather_A): + mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx) + # (bM, bK, RestK) + gA_mk = cute.local_tile( + mA_mk, + cute.select(self.mma_tiler, [0, 2]), + (mma_tile_coord_mnl[0], None), + ) + # (bN, bK, RestK) + gB_nk = cute.local_tile( + varlen_manager.offset_batch_B(mB_nkl, batch_idx), + cute.select(self.mma_tiler, [1, 2]), + (mma_tile_coord_mnl[1], None), + ) + if const_expr(self.blockscaled): + # (bM, bK) + gSFA_mkl = cute.local_tile( + varlen_manager.offset_batch_A(mSFA_mkl, batch_idx), + cute.select(self.mma_tiler, [0, 2]), + (mma_tile_coord_mnl[0], None), + ) + # (bN, bK) + gSFB_nkl = cute.local_tile( + varlen_manager.offset_batch_B(mSFB_nkl, batch_idx), + cute.select(self.mma_tiler, [1, 2]), + (mma_tile_coord_mnl[1], None), + ) + + # Partition global tensor for TiledMMA_A/B/D + # Then partition global/shared tensor for TMA load A/B + varlen_manager.fence_tensormap_update_AB(is_tma_warp) + len_k = varlen_manager.len_k(batch_idx) + # TMA load A partition_S/D + a_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, 0, None, 0)).shape + ) + copy_A = None + if const_expr(not self.gather_A): + # (MMA, MMA_M, MMA_K, RestK) + tCgA = thr_mma.partition_A(gA_mk) + copy_A, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_a, + cta_coord=block_in_cluster_coord_vmnk[2], + cta_layout=a_cta_layout, + src_tensor=tCgA, + dst_tensor=sA, + mcast_mask=a_mcast_mask, + tma_desc_ptr=tma_desc_a_ptr, + ) + # (MMA, MMA_N, MMA_K, RestK) + tCgB = thr_mma.partition_B(gB_nk) + if const_expr(self.blockscaled): + # (MMA, MMA_M, MMA_K) + tCgSFA = thr_mma.partition_A(gSFA_mkl) + # (MMA, MMA_N, MMA_K) + tCgSFB = thr_mma_sfb.partition_B(gSFB_nkl) + # TMA load B partition_S/D + copy_B, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_b, + cta_coord=block_in_cluster_coord_vmnk[1], + cta_layout=cute.make_layout( + cute.slice_(cluster_layout_vmnk, (0, None, 0, 0)).shape + ), + src_tensor=tCgB, + dst_tensor=sB, + mcast_mask=b_mcast_mask, + tma_desc_ptr=tma_desc_b_ptr, + ) + copy_SFA, copy_SFB = None, None + if const_expr(self.blockscaled): + # TMA load SFA partition_S/D + copy_SFA, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_sfa, + cta_coord=block_in_cluster_coord_vmnk[2], + cta_layout=a_cta_layout, + src_tensor=tCgSFA, + dst_tensor=sSFA, + filter_zeros=True, + mcast_mask=sfa_mcast_mask, + # tma_desc_ptr=tma_desc_sfa_ptr, + ) + # TMA load SFB partition_S/D + sfb_cta_layout = cute.make_layout( + cute.slice_(cluster_layout_sfb_vmnk, (0, None, 0, 0)).shape + ) + copy_SFB, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_sfb, + cta_coord=block_in_cluster_coord_sfb_vmnk[1], + cta_layout=sfb_cta_layout, + src_tensor=tCgSFB, + dst_tensor=sSFB, + filter_zeros=True, + mcast_mask=sfb_mcast_mask, + # tma_desc_ptr=tma_desc_sfa_ptr, + ) + k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) + ab_producer_state = self.load_AB( + ab_pipeline, + ab_producer_state, + copy_A, + copy_B, + k_tile_cnt, + copy_SFA, + copy_SFB, + ) + if const_expr(epi_load_barrier is not None): + # In the first work tile, the epi load warp will wait for the signal + # from the mainloop load warp to start loading C, to avoid 
interfering + # with loading A and B. + if do_epi_load_barrier_arrive: + epi_load_barrier.arrive() + do_epi_load_barrier_arrive = Boolean(False) + # Advance to next tile + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # Wait A/B buffer empty + ab_pipeline.producer_tail(ab_producer_state) + + if const_expr(self.gather_A): + if ( + warp_idx >= self.ab_load_warp_id + 1 + and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps + ): + # Persistent tile scheduling loop + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + ab_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.ab_stage + ) + a_prefetch_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.a_prefetch_stage + ) + while work_tile.is_valid_tile: + tile_coord_mnkl = work_tile.tile_idx + batch_idx = tile_coord_mnkl[3] + # /////////////////////////////////////////////////////////////////////////// + # Local_tile partition global tensors + # /////////////////////////////////////////////////////////////////////////// + mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx) + if const_expr(varlen_m): + # (M, K) + mA_mk = mA_mkl + else: + assert varlen_k + # (tile_M, K) + mA_mk = cute.local_tile( + mA_mkl, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0], None) + ) + # Partition global tensor for TiledMMA_A/B/D + len_m = varlen_manager.len_m(batch_idx) + len_k = varlen_manager.len_k(batch_idx) + # TMA load A partition_S/D + tiled_copy_A = self._make_gmem_tiled_copy_A( + mA_mkl.element_type, self.a_layout, (self.num_ab_load_warps - 1) * 32 + ) + tidx = cute.arch.thread_idx()[0] - (self.ab_load_warp_id + 1) * 32 + thr_copy_A = tiled_copy_A.get_slice(tidx) + copy_A, prefetch_A = None, None + if const_expr(varlen_m): + a_prefetch_pipeline.consumer_wait(a_prefetch_consumer_state) + copy_A = copy_utils.gather_m_get_copy_fn( + thr_copy_A, + mA_mk, + sA, + sAIdx[None, a_prefetch_consumer_state.index], + limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0], + limit_k=len_k, + ) + cute.arch.sync_warp() + with cute.arch.elect_one(): + a_prefetch_pipeline.consumer_release(a_prefetch_consumer_state) + a_prefetch_consumer_state.advance() + else: + copy_A, prefetch_A = copy_utils.gather_k_get_copy_fn( + thr_copy_A, + mA_mk, + sA, + sAIdx, + limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0], + limit_k=len_k, + ) + prefetch_A = partial(prefetch_A, a_prefetch_pipeline) + k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) + ab_producer_state, a_prefetch_consumer_state = self.load_A_gather_A( + ab_pipeline, + ab_producer_state, + a_prefetch_consumer_state, + copy_A, + prefetch_A, + k_tile_cnt, + ) + # Advance to next tile + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + # + # Specialized scheduler warp. 
Will also prefetch A indices if gatherA + # + if const_expr(tile_sched_params.tile_count_semaphore is not None or self.gather_A): + if warp_idx == self.scheduler_warp_id: + is_scheduler_warp = True + if const_expr(cute.size(cluster_layout_vmnk) > 1): + is_scheduler_warp = cute.arch.block_idx_in_cluster() == 0 + tile_M = self.cta_tile_shape_mnk[0] + tile_K = self.cta_tile_shape_mnk[2] + thr_copy_AIdx, tAsAIdx, tAcAIdx = None, None, None + if const_expr(self.gather_A): + tiled_copy_AIdx = copy_utils.tiled_copy_1d(Int32, num_threads=32, is_async=True) + thr_copy_AIdx = tiled_copy_AIdx.get_slice(cute.arch.lane_idx()) + tAsAIdx = thr_copy_AIdx.partition_D(sAIdx) + tAcAIdx = thr_copy_AIdx.partition_S( + cute.make_identity_tensor(tile_M if varlen_m else tile_K) + ) + # Persistent tile scheduling loop + tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp) + work_tile = tile_scheduler.initial_work_tile_info() + a_prefetch_producer_state = None + if const_expr(self.gather_A): + a_prefetch_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.a_prefetch_stage + ) + while work_tile.is_valid_tile: + if const_expr(self.gather_A): + tile_coord_mnkl = work_tile.tile_idx + batch_idx = tile_coord_mnkl[3] + mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx) + if const_expr(varlen_m): + # (tile_M,) + gAIdx = cute.local_tile(mAIdx_mk, (tile_M,), (tile_coord_mnkl[0],)) + tAgAIdx = thr_copy_AIdx.partition_S(gAIdx) + len_m = varlen_manager.len_m(batch_idx) + m_limit = len_m - tile_coord_mnkl[0] * tile_M + tApAIdx_m = cute.make_fragment((1, tAsAIdx.shape[1]), Boolean) + for m in cutlass.range(tAsAIdx.shape[1], unroll_full=True): + tApAIdx_m[0, m] = tAcAIdx[0, m] < m_limit + a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state) + cute.copy( + thr_copy_AIdx, + tAgAIdx, + tAsAIdx[None, None, a_prefetch_producer_state.index], + pred=tApAIdx_m, + ) + a_prefetch_pipeline.producer_commit(a_prefetch_producer_state) + a_prefetch_producer_state.advance() + else: + # (tile_K, RestK) + gAIdx = cute.flat_divide(mAIdx_mk, (tile_K,)) + tAgAIdx = thr_copy_AIdx.partition_S(gAIdx) + len_k = varlen_manager.len_k(batch_idx) + k_tile_cnt = cute.ceil_div(len_k, tile_K) + for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1): + a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state) + cute.copy( + thr_copy_AIdx, + tAgAIdx[None, None, k_tile], + tAsAIdx[None, None, a_prefetch_producer_state.index], + ) + a_prefetch_pipeline.producer_commit(a_prefetch_producer_state) + a_prefetch_producer_state.advance() + if 0 < k_tile_cnt: + k_tile = k_tile_cnt - 1 + k_limit = len_k - k_tile * tile_K + tApAIdx_k = cute.make_fragment((1, tAsAIdx.shape[1]), Boolean) + for m in cutlass.range(tAsAIdx.shape[1], unroll_full=True): + tApAIdx_k[0, m] = tAcAIdx[0, m] < k_limit + a_prefetch_pipeline.producer_acquire(a_prefetch_producer_state) + cute.copy( + tiled_copy_AIdx, + tAgAIdx[None, None, k_tile], + tAsAIdx[None, None, a_prefetch_producer_state.index], + pred=tApAIdx_k, + ) + a_prefetch_pipeline.producer_commit(a_prefetch_producer_state) + a_prefetch_producer_state.advance() + # Advance to next tile + tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp) + tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp) + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + if is_scheduler_warp: + tile_scheduler.producer_tail() + + # + # Specialized TMA epi load warp + # + if const_expr(mC_mnl is not None): + if warp_idx == 
self.epi_load_warp_id: + epi_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.epi_c_stage + ) + do_epi_load_barrier_wait = Boolean(True) + # Persistent tile scheduling loop + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + tile_coord_mnkl = work_tile.tile_idx + batch_idx = tile_coord_mnkl[3] + copy_C_fn, _, bGS_gC = self.epilog_gmem_copy_and_partition( + tma_atom_c, + varlen_manager.offset_batch_epi(mC_mnl, batch_idx), + self.cta_tile_shape_mnk[:2], + epi_tile, + sC, + tile_coord_mnkl, + ) + copy_C = copy_utils.tma_producer_copy_fn(copy_C_fn, epi_pipeline) + if do_epi_load_barrier_wait: + epi_load_barrier.arrive_and_wait() + do_epi_load_barrier_wait = Boolean(False) + epi_tile_num = const_expr(cute.size(bGS_gC, mode=[1])) + for epi_idx in cutlass.range(epi_tile_num, unroll=1): + epi_pipeline.producer_acquire(epi_producer_state) + copy_C(src_idx=epi_idx, producer_state=epi_producer_state) + # Epi pipeline's producer commit is a NOP + epi_pipeline.producer_commit(epi_producer_state) + epi_producer_state.advance() + # Advance to next tile + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + epi_pipeline.producer_tail(epi_producer_state) + + # + # Specialized MMA warp + # + if warp_idx == self.mma_warp_id: + tmem_alloc_barrier.arrive_and_wait() + # Retrieving tensor memory ptr and make accumulator tensor + acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf + ) + # Partition shared/tensor memory tensor for TiledMMA_A/B/D + # (MMA, MMA_M, MMA_K, STAGE) + tCrA = tiled_mma.make_fragment_A(sA_mma) + # (MMA, MMA_N, MMA_K, STAGE) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + if const_expr(self.blockscaled): + # Make SFA tmem tensor + sfa_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base), + dtype=self.sf_dtype, + ) + # (MMA, MMA_M, MMA_K) + tCtSFA_layout = blockscaled_utils.make_tmem_layout_sfa( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfa_smem_layout, (None, None, None, 0)), + ) + tCtSFA = cute.make_tensor(sfa_tmem_ptr, tCtSFA_layout) + # Make SFB tmem tensor + sfb_tmem_ptr = cute.recast_ptr( + acc_tmem_ptr + + tcgen05.find_tmem_tensor_col_offset(tCtAcc_base) + + tcgen05.find_tmem_tensor_col_offset(tCtSFA), + dtype=self.sf_dtype, + ) + # (MMA, MMA_N, MMA_K) + tCtSFB_layout = blockscaled_utils.make_tmem_layout_sfb( + tiled_mma, + self.mma_tiler, + self.sf_vec_size, + cute.slice_(sfb_smem_layout, (None, None, None, 0)), + ) + tCtSFB = cute.make_tensor(sfb_tmem_ptr, tCtSFB_layout) + # Partition for S2T copy of SFA/SFB + ( + tiled_copy_s2t_sfa, + tCsSFA_compact_s2t, + tCtSFA_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFA, tCtSFA) + ( + tiled_copy_s2t_sfb, + tCsSFB_compact_s2t, + tCtSFB_compact_s2t, + ) = self.mainloop_s2t_copy_and_partition(sSFB, tCtSFB) + else: + tCtSFA, tCtSFB = None, None + tiled_copy_s2t_sfa, tCsSFA_compact_s2t, tCtSFA_compact_s2t = None, None, None + tiled_copy_s2t_sfb, tCsSFB_compact_s2t, tCtSFB_compact_s2t = None, None, None + + # Persistent tile scheduling loop + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + ab_consumer_state = pipeline.make_pipeline_state( + 
pipeline.PipelineUserType.Consumer, self.ab_stage + ) + acc_producer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Producer, self.num_acc_stage + ) + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + tile_coord_mnkl = work_tile.tile_idx + batch_idx = tile_coord_mnkl[3] + k_len = varlen_manager.len_k(batch_idx) + k_tile_cnt = cute.ceil_div(k_len, self.mma_tiler[2]) + # Set tensor memory buffer for current tile + # (MMA, MMA_M, MMA_N) + tCtAcc = tCtAcc_base[None, None, None, acc_producer_state.index] + ab_consumer_state, acc_producer_state, tiled_mma = self.mma( + ab_pipeline, + acc_pipeline, + ab_consumer_state, + acc_producer_state, + tiled_mma, + tCrA, + tCrB, + tCtAcc, + k_tile_cnt, + is_leader_cta, + cta_rank_in_cluster, + tCtSFA, + tCtSFB, + tiled_copy_s2t_sfa, + tiled_copy_s2t_sfb, + tCsSFA_compact_s2t, + tCsSFB_compact_s2t, + tCtSFA_compact_s2t, + tCtSFB_compact_s2t, + ) + # Advance to next tile + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + # Wait for accumulator buffer empty + acc_pipeline.producer_tail(acc_producer_state) + + # + # Specialized epilogue warps + # + if warp_idx < self.mma_warp_id: + # Alloc tensor memory buffer + if warp_idx == self.epilog_warp_id[0]: + cute.arch.alloc_tmem( + self.num_tmem_alloc_cols, tmem_holding_buf, is_two_cta=use_2cta_instrs + ) + # Bar sync for retrieve tensor memory ptr from shared memory + tmem_alloc_barrier.arrive_and_wait() + + is_tma_warp = Boolean(warp_idx == self.epilog_warp_id[0]) + varlen_manager.init_tensormap_epi( + tma_atom_d, self.epi_get_tma_atoms(epilogue_params), is_tma_warp + ) + tma_desc_d_ptr = varlen_manager.get_tma_desc_d_ptr() + tma_desc_epi_ptrs = varlen_manager.get_tma_desc_epi_ptrs() + + # Retrieving tensor memory ptr and make accumulator tensor + acc_tmem_ptr = cute.arch.retrieve_tmem_ptr( + self.acc_dtype, alignment=16, ptr_to_buffer_holding_addr=tmem_holding_buf + ) + # (MMA, MMA_M, MMA_N, STAGE) + tCtAcc_base = cute.make_tensor(acc_tmem_ptr, tCtAcc_fake.layout) + + epilogue_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierGemm.Epilogue), + num_threads=self.num_epi_warps * cute.arch.WARP_SIZE, + ) + + # Partition for epilogue + epi_tidx = tidx + tiled_copy_t2r, tTR_tAcc_base, tTR_rAcc = self.epilog_tmem_copy_and_partition( + epi_tidx, tCtAcc_base, epi_tile, use_2cta_instrs + ) + + tTR_rD = cute.make_fragment(tTR_rAcc.shape, self.acc_dtype) + tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition( + tiled_copy_t2r, self.d_layout, self.d_dtype, tTR_rD, sD, epi_tidx + ) + tRS_rC, tSR_rC, tSR_sC = None, None, None + tiled_copy_s2r = None + if const_expr(mC_mnl is not None): + tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition( + tiled_copy_t2r, self.c_layout, self.c_dtype, sC, tRS_rD.layout, epi_tidx + ) + + # Persistent tile scheduling loop + tile_scheduler = TileSchedulerCls() + work_tile = tile_scheduler.initial_work_tile_info() + acc_consumer_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.num_acc_stage + ) + epi_store_pipeline = self.make_epi_store_pipeline() + epi_read_state = pipeline.make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.epi_c_stage + ) + if const_expr(varlen_m): + # wait tensormap initialization complete before update + varlen_manager.fence_tensormap_init() + while work_tile.is_valid_tile: + # Get tile coord from tile scheduler + tile_coord_mnkl = work_tile.tile_idx + batch_idx = tile_coord_mnkl[3] + epi_shapes, 
epi_orders = self.epi_get_tensormap_update_shapes_orders( + epilogue_params, varlen_params.cu_seqlens_m, batch_idx + ) + varlen_manager.update_tensormap_epi( + batch_idx, + self.d_layout, + epi_shapes, + epi_orders, + is_tma_warp, + ) + + # Set tensor memory buffer for current tile + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M) + tTR_tAcc = tTR_tAcc_base[None, None, None, None, None, acc_consumer_state.index] + + # Wait for accumulator buffer full + acc_pipeline.consumer_wait(acc_consumer_state) + + varlen_manager.fence_tensormap_update_epi(is_tma_warp) + + copy_D = None + if const_expr(has_D): + copy_D, _, _ = self.epilog_gmem_copy_and_partition( + tma_atom_d, + varlen_manager.offset_batch_epi(mD_mnl, batch_idx), + self.cta_tile_shape_mnk[:2], + epi_tile, + sD, + tile_coord_mnkl, + tma_desc_ptr=tma_desc_d_ptr, + ) + copy_C = None # We're using a separate warp to load C + + tTR_tAcc = cute.group_modes(tTR_tAcc, 3, cute.rank(tTR_tAcc)) + k_len = varlen_manager.len_k(batch_idx) + load_acc_subtile = partial( + self.epi_load_acc_subtile, + tiled_copy_t2r, + tiled_copy_r2s, + tTR_tAcc, + tTR_rAcc, + clear_acc=varlen_k and k_len == 0, + ) + + epi_read_state, _ = self.epilogue( + epilogue_params, + epi_smem_tensors, + tma_desc_epi_ptrs, + epi_pipeline, + epi_store_pipeline, + epi_read_state, + None, # epi_producer_state + epi_tile, + load_acc_subtile, + tRS_rD, + tRS_rC, + tiled_copy_t2r, + tiled_copy_r2s, + tRS_sD, + tiled_copy_s2r, + tSR_rC, + tSR_sC, + copy_D, + copy_C, + tile_coord_mnkl, + varlen_manager, + epilogue_barrier, + tile_scheduler, + epi_tidx, + is_tma_warp, + ) + + # Async arrive accumulator buffer empty + with cute.arch.elect_one(): + acc_pipeline.consumer_release(acc_consumer_state) + acc_consumer_state.advance() + + # Advance to next tile + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + + # Dealloc the tensor memory buffer + if warp_idx == self.epilog_warp_id[0]: + cute.arch.relinquish_tmem_alloc_permit(is_two_cta=use_2cta_instrs) + epilogue_barrier.arrive_and_wait() + if warp_idx == self.epilog_warp_id[0]: + if const_expr(use_2cta_instrs): + cute.arch.mbarrier_arrive(tmem_dealloc_mbar_ptr, cta_rank_in_cluster ^ 1) + cute.arch.mbarrier_wait(tmem_dealloc_mbar_ptr, 0) + cute.arch.dealloc_tmem( + acc_tmem_ptr, self.num_tmem_alloc_cols, is_two_cta=use_2cta_instrs + ) + + # Wait for D store complete + if is_tma_warp: + epi_store_pipeline.producer_tail() + + @cute.jit + def load_A_gather_A( + self, + a_pipeline: cutlass.pipeline.PipelineAsync, + a_producer_state: cutlass.pipeline.PipelineState, + a_prefetch_consumer_state: Optional[cutlass.pipeline.PipelineState], + copy_A: Callable, + prefetch_A: Optional[Callable], + k_tile_cnt: Int32, + ) -> Tuple[cutlass.pipeline.PipelineState, Optional[cutlass.pipeline.PipelineState]]: + # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt + peek_a_empty_status = Boolean(True) + if 0 < k_tile_cnt: + peek_a_empty_status = a_pipeline.producer_try_acquire(a_producer_state) + # ///////////////////////////////////////////////////////////////////////// + # cp.async on A + # ///////////////////////////////////////////////////////////////////////// + is_tma_warp = False + for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1): + smem_idx = a_producer_state.index + prefetch_out = () + if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free + prefetch_out = (prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state),) + a_prefetch_consumer_state.advance() + 
a_pipeline.producer_acquire(a_producer_state, peek_a_empty_status, is_tma_warp) + copy_A(k_tile, smem_idx, *prefetch_out) + # This tells mbarrier to track the completion of cp.async + a_pipeline.producer_cpasync_commit(a_producer_state) + a_producer_state.advance() + peek_a_empty_status = Boolean(True) + if k_tile + 1 < k_tile_cnt: + peek_a_empty_status = a_pipeline.producer_try_acquire(a_producer_state) + # bound checking in the K dimension on the last k_tile + if 0 < k_tile_cnt: + k_tile = k_tile_cnt - 1 + smem_idx = a_producer_state.index + prefetch_out = () + if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free + prefetch_out = (prefetch_A(k_tile, smem_idx, a_prefetch_consumer_state, pred=True),) + a_prefetch_consumer_state.advance() + a_pipeline.producer_acquire(a_producer_state, peek_a_empty_status, is_tma_warp) + copy_A(k_tile, smem_idx, *prefetch_out, pred=True) + a_pipeline.producer_cpasync_commit(a_producer_state) + a_producer_state.advance() + return a_producer_state, a_prefetch_consumer_state + + @cute.jit + def mma( + self, + ab_pipeline: cutlass.pipeline.PipelineAsync, + acc_pipeline: cutlass.pipeline.PipelineAsync, + ab_consumer_state: cutlass.pipeline.PipelineState, + acc_producer_state: cutlass.pipeline.PipelineState, + tiled_mma: cute.TiledMma, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + acc: cute.Tensor, + k_tile_cnt: Int32, + is_leader_cta: Boolean, + cta_rank_in_cluster: Int32, + tCtSFA: Optional[cute.Tensor] = None, + tCtSFB: Optional[cute.Tensor] = None, + tiled_copy_s2t_sfa: Optional[cute.TiledCopy] = None, + tiled_copy_s2t_sfb: Optional[cute.TiledCopy] = None, + tCsSFA_compact_s2t: Optional[cute.Tensor] = None, + tCsSFB_compact_s2t: Optional[cute.Tensor] = None, + tCtSFA_compact_s2t: Optional[cute.Tensor] = None, + tCtSFB_compact_s2t: Optional[cute.Tensor] = None, + ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState, cute.TiledMma]: + blockscaled = const_expr(tiled_copy_s2t_sfa is not None) + if const_expr(blockscaled): + assert all(x is not None for x in (tCtSFA, tCtSFB)) + assert all(x is not None for x in (tiled_copy_s2t_sfa, tiled_copy_s2t_sfb)) + assert all(x is not None for x in (tCsSFA_compact_s2t, tCsSFB_compact_s2t)) + assert all(x is not None for x in (tCtSFA_compact_s2t, tCtSFB_compact_s2t)) + # If gather_A and use_2cta_instrs, the cp.async for the non-leader CTA will + # arrive at an mbarrier on the non-leader CTA side, then the mma warp of the non-leader + # CTA will wait for that then arrive at the mbarrier on the leader CTA. 
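+ # (Plain cp.async completions can only be tracked by an mbarrier in the issuing CTA's
+ # own shared memory, hence the extra hop through the non-leader CTA described above.)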
+ need_nonleader_cta = const_expr(self.gather_A and self.use_2cta_instrs) + # Peek (try_wait) AB buffer full for k_tile = 0 + peek_ab_full_status = Boolean(True) + if 0 < k_tile_cnt and (is_leader_cta or need_nonleader_cta): + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) + # Wait for accumulator buffer empty + if is_leader_cta: + acc_pipeline.producer_acquire(acc_producer_state) + # Reset the ACCUMULATE field for each tile + tiled_mma.set(tcgen05.Field.ACCUMULATE, False) + # Mma mainloop + num_k_blocks = cute.size(tCrA, mode=[2]) + for k_tile in cutlass.range(k_tile_cnt, unroll=1): + if const_expr(need_nonleader_cta): + if not is_leader_cta: + ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status) + with cute.arch.elect_one(): + # The odd CTA signals the even CTA + ab_pipeline.sync_object_full.arrive_mbarrier( + ab_consumer_state.index, dst_rank=cta_rank_in_cluster & 0xFE + ) + if is_leader_cta: + # Conditionally wait for AB buffer full + ab_pipeline.consumer_wait(ab_consumer_state, peek_ab_full_status) + # Copy SFA/SFB from smem to tmem + if const_expr(blockscaled): + s2t_stage_coord = (None, None, None, None, ab_consumer_state.index) + tCsSFA_compact_s2t_staged = tCsSFA_compact_s2t[s2t_stage_coord] + tCsSFB_compact_s2t_staged = tCsSFB_compact_s2t[s2t_stage_coord] + cute.copy(tiled_copy_s2t_sfa, tCsSFA_compact_s2t_staged, tCtSFA_compact_s2t) + cute.copy(tiled_copy_s2t_sfb, tCsSFB_compact_s2t_staged, tCtSFB_compact_s2t) + for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True): + k_blk_coord = (None, None, k_blk_idx, ab_consumer_state.index) + if const_expr(blockscaled): + # Set SFA/SFB tensor to tiled_mma + sf_kblock_coord = (None, None, k_blk_idx) + tiled_mma.set(tcgen05.Field.SFA, tCtSFA[sf_kblock_coord].iterator) + tiled_mma.set(tcgen05.Field.SFB, tCtSFB[sf_kblock_coord].iterator) + cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc) + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + # Async arrive AB buffer empty + ab_pipeline.consumer_release(ab_consumer_state) + ab_consumer_state.advance() + # Peek (try_wait) AB buffer full for k_tile = k_tile + 1 + peek_ab_full_status = Boolean(True) + if k_tile + 1 < k_tile_cnt and (is_leader_cta or need_nonleader_cta): + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_consumer_state) + # Async arrive accumulator buffer full + if is_leader_cta: + acc_pipeline.producer_commit(acc_producer_state) + acc_producer_state.advance() + # If we don't return the tiled_mma, we get compiler error + # "operand #0 does not dominate this use" + return ab_consumer_state, acc_producer_state, tiled_mma + + @cute.jit + def epi_load_acc_subtile( + self, + tiled_copy_t2r: cute.TiledCopy, + tiled_copy_r2s: cute.TiledCopy, + tTR_tAcc: cute.Tensor, + tTR_rAcc: cute.Tensor, + tRS_rD: cute.Tensor, + epi_idx: int, + clear_acc: Boolean = False, + ): + if not clear_acc: + # Load accumulator from tensor memory buffer to register + cute.copy(tiled_copy_t2r, tTR_tAcc[None, None, None, epi_idx], tTR_rAcc) + tRS_rAcc = tiled_copy_r2s.retile(tTR_rAcc) + tRS_rD.store(tRS_rAcc.load()) + else: + tRS_rD.fill(0.0) + + def mainloop_s2t_copy_and_partition( + self, + sSF: cute.Tensor, + tSF: cute.Tensor, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for smem to tmem load for scale factor tensor, then use it to partition smem memory (source) and tensor memory (destination). 
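+ Both tensors are first compacted with cute.filter_zeros so that only the actual
+ scale-factor elements are moved, and the copy itself uses the tcgen05 Cp4x32x128bOp atom.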
+ + :param sSF: The scale factor tensor in smem + :type sSF: cute.Tensor + :param tSF: The scale factor tensor in tmem + :type tSF: cute.Tensor + + :return: A tuple containing (tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t) where: + - tiled_copy_s2t: The tiled copy operation for smem to tmem load for scale factor tensor(s2t) + - tCsSF_compact_s2t: The partitioned scale factor tensor in smem + - tSF_compact_s2t: The partitioned scale factor tensor in tmem + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # (MMA, MMA_MN, MMA_K, STAGE) + tCsSF_compact = cute.filter_zeros(sSF) + # (MMA, MMA_MN, MMA_K) + tCtSF_compact = cute.filter_zeros(tSF) + # Make S2T CopyAtom and tiledCopy + copy_atom_s2t = cute.make_copy_atom(tcgen05.Cp4x32x128bOp(self.cta_group), self.sf_dtype) + tiled_copy_s2t = tcgen05.make_s2t_copy(copy_atom_s2t, tCtSF_compact) + thr_copy_s2t = tiled_copy_s2t.get_slice(0) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSF_compact_s2t_ = thr_copy_s2t.partition_S(tCsSF_compact) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K, STAGE) + tCsSF_compact_s2t = tcgen05.get_s2t_smem_desc_tensor(tiled_copy_s2t, tCsSF_compact_s2t_) + # ((ATOM_V, REST_V), Rest_Tiler, MMA_MN, MMA_K) + tCtSF_compact_s2t = thr_copy_s2t.partition_D(tCtSF_compact) + return tiled_copy_s2t, tCsSF_compact_s2t, tCtSF_compact_s2t + + def epilog_tmem_copy_and_partition( + self, + tidx: Int32, + tAcc: cute.Tensor, + epi_tile: cute.Tile, + use_2cta_instrs: Union[Boolean, bool], + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for tensor memory load, then use it to partition tensor memory (source) and register array (destination). + + :param tidx: The thread index in epilogue warp groups + :type tidx: Int32 + :param tAcc: The accumulator tensor to be copied and partitioned + :type tAcc: cute.Tensor + :param epi_tile: The epilogue tiler + :type epi_tile: cute.Tile + :param use_2cta_instrs: Whether use_2cta_instrs is enabled + :type use_2cta_instrs: bool + + :return: A tuple containing (tiled_copy_t2r, tTR_tAcc, tTR_rAcc) where: + - tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + - tTR_tAcc: The partitioned accumulator tensor + - tTR_rAcc: The accumulated tensor in register used to hold t2r results + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + # Make tiledCopy for tensor memory load + copy_atom_t2r = sm100_utils.get_tmem_load_op( + self.cta_tile_shape_mnk, + self.d_layout if self.d_layout is not None else LayoutEnum.ROW_MAJOR, + self.d_dtype if self.d_dtype is not None else cutlass.BFloat16, + self.acc_dtype, + epi_tile, + use_2cta_instrs, + ) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N, STAGE) + tAcc_epi = cute.flat_divide(tAcc[((None, None), 0, 0, None)], epi_tile) + # (EPI_TILE_M, EPI_TILE_N) + tiled_copy_t2r = tcgen05.make_tmem_copy(copy_atom_t2r, tAcc_epi[(None, None, 0, 0, 0)]) + + thr_copy_t2r = tiled_copy_t2r.get_slice(tidx) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_M, STAGE) + tTR_tAcc = thr_copy_t2r.partition_S(tAcc_epi) + + cAcc = cute.make_identity_tensor((self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1])) + # (EPI_TILE_M, EPI_TILE_N, EPI_M, EPI_N) + cAcc_epi = cute.flat_divide(cAcc, epi_tile) + # (T2R, T2R_M, T2R_N, EPI_M, EPI_N) + tTR_cAcc = thr_copy_t2r.partition_D(cAcc_epi) + # (T2R, T2R_M, T2R_N) + tTR_rAcc = cute.make_fragment(tTR_cAcc[None, None, None, 0, 0].shape, self.acc_dtype) + return tiled_copy_t2r, tTR_tAcc, tTR_rAcc + + def epilog_smem_store_and_partition( + self, + tiled_copy_t2r: 
cute.TiledCopy, + d_layout: Optional[LayoutEnum], + dtype: Optional[Type[cutlass.Numeric]], + tTR_rD: cute.Tensor, + sD: cute.Tensor, + tidx: Int32, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + """ + Make tiledCopy for shared memory store, then use it to partition register array (source) and shared memory (destination). + + :param tiled_copy_t2r: The tiled copy operation for tmem to register copy(t2r) + :type tiled_copy_t2r: cute.TiledCopy + :param tTR_rD: The partitioned accumulator tensor + :type tTR_rD: cute.Tensor + :param tidx: The thread index in epilogue warp groups + :type tidx: Int32 + :param sD: The shared memory tensor to be copied and partitioned + :type sD: cute.Tensor + :type sepi: cute.Tensor + + :return: A tuple containing (tiled_copy_r2s, tRS_rD, tRS_sD) where: + - tiled_copy_r2s: The tiled copy operation for register to smem copy(r2s) + - tRS_rD: The partitioned tensor C (register source) + - tRS_sD: The partitioned tensor C (smem destination) + :rtype: Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor] + """ + copy_atom_r2s = sm100_utils.get_smem_store_op( + d_layout if d_layout is not None else LayoutEnum.ROW_MAJOR, + dtype if dtype is not None else cutlass.BFloat16, + self.acc_dtype, + tiled_copy_t2r, + ) + tiled_copy_r2s = cute.make_tiled_copy_D(copy_atom_r2s, tiled_copy_t2r) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None + # (R2S, R2S_M, R2S_N) + tRS_rD = tiled_copy_r2s.retile(tTR_rD) + return tiled_copy_r2s, tRS_rD, tRS_sD + + def epilog_smem_load_and_partition( + self, + tiled_copy_t2r: cute.TiledCopy, + c_layout: LayoutEnum, + dtype: Type[cutlass.Numeric], + # tTR_rC: cute.Tensor, + sC: cute.Tensor, + tRS_rD_layout: cutlass.Layout, + tidx: Int32, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + copy_atom_r2s = sm100_utils.get_smem_store_op( + c_layout, dtype, self.acc_dtype, tiled_copy_t2r + ) + store_op = copy_atom_r2s.op + # m8n8 16-bit path + if isinstance(store_op, StMatrix8x8x16bOp): + op = LdMatrix8x8x16bOp(num_matrices=store_op.num_matrices, transpose=store_op.transpose) + # m16n8 8-bit store -> m16n16 8-bit load + elif isinstance(store_op, StMatrix16x8x8bOp) and store_op.num_matrices in [2, 4]: + # transpose=True is enforced by the class + op = LdMatrix16x16x8bOp(num_matrices=store_op.num_matrices // 2) + else: + op = cute.nvgpu.CopyUniversalOp() + copy_atom_s2r = cute.make_copy_atom(op, dtype) + tiled_copy_s2r = cute.make_tiled_copy_D(copy_atom_s2r, tiled_copy_t2r) + thr_copy_s2r = tiled_copy_s2r.get_slice(tidx) + # (R2S, R2S_M, R2S_N, PIPE_D) + tSR_sC = thr_copy_s2r.partition_S(sC) + tRS_rC = cute.make_fragment(tRS_rD_layout, dtype) + # (R2S, R2S_M, R2S_N) + tSR_rC = tiled_copy_s2r.retile(tRS_rC) + return tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC + + @cute.jit + def make_ab_pipeline( + self, + tiled_mma: cute.TiledMma, + cluster_layout_vmnk: cute.Layout, + ab_pipeline_mbar_ptr: cute.Pointer, + is_leader_cta: Boolean, + ) -> pipeline.PipelineAsync: + # If gather_A and use_2cta_instrs, the cp.async for the non-leader CTA will + # arrive at an mbarrier on the non-leader CTA side, then the mma warp of the non-leader + # CTA will wait for that then arrive at the mbarrier on the leader CTA. + # The producer count for the leader CTA is 1 (TMA) + num_cpasync_threads + # + 1 (from non-leader CTA). + # The producer count for the non-leader CTA is num_cpasync_threads + # (TMA doesn't arrive there). 
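+ # Worked example (hypothetical sizes): with gather_A, num_ab_load_warps == 4 and
+ # 2-CTA MMA instructions, the leader CTA expects (4 - 1) * 32 + 2 = 98 producer
+ # arrivals per stage, matching 1 (TMA) + 96 (cp.async threads) + 1 (non-leader CTA).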
+ if const_expr(not self.gather_A): + producer_cnt = 1 + else: + producer_cnt = (self.num_ab_load_warps - 1) * 32 + ( + 1 if const_expr(not self.use_2cta_instrs) else 2 + ) + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt) + # Each warp will contribute to the arrive count with the number of mcast size + mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + consumer_arrive_cnt = mcast_size + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, consumer_arrive_cnt + ) + if const_expr(not self.gather_A): + pipeline_ab = pipeline.PipelineTmaUmma.create( + barrier_storage=ab_pipeline_mbar_ptr, + num_stages=self.ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + else: + pipeline_ab = PipelineTmaCpAsyncUmma.create( + barrier_storage=ab_pipeline_mbar_ptr, + num_stages=self.ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + producer_drop_count=None + if not self.use_2cta_instrs + else (2 if not is_leader_cta else 0), + ) + return pipeline_ab + + def make_acc_pipeline( + self, cluster_layout_vmnk: cute.Layout, acc_pipeline_mbar_ptr: cute.Pointer + ) -> pipeline.PipelineAsync: + acc_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + num_acc_consumer_threads = self.num_epi_warps * (2 if self.use_2cta_instrs else 1) + acc_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, num_acc_consumer_threads + ) + return pipeline.PipelineUmmaAsync.create( + barrier_storage=acc_pipeline_mbar_ptr, + num_stages=self.num_acc_stage, + producer_group=acc_pipeline_producer_group, + consumer_group=acc_pipeline_consumer_group, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + def make_sched_pipeline( + self, + cluster_layout_mnk: cute.Layout, + sched_pipeline_mbar_ptr: cute.Pointer, + has_C: bool = False, + ) -> pipeline.PipelineAsync: + # Threads/warps participating in this pipeline + sched_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + cluster_size = cute.size(cluster_layout_mnk) + # Each warp that are not the scheduler warp will contribute 1 to the arrive count + warps_per_cta = self.num_ab_load_warps + len( + (self.mma_warp_id, *self.epilog_warp_id, self.scheduler_warp_id) + ) + if has_C: + warps_per_cta += 1 + consumer_arrive_cnt = warps_per_cta * cluster_size - 1 + sched_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, consumer_arrive_cnt + ) + return pipeline.PipelineAsync.create( + barrier_storage=sched_pipeline_mbar_ptr, + num_stages=self.sched_stage, + producer_group=sched_pipeline_producer_group, + consumer_group=sched_pipeline_consumer_group, + # If there's cluster, the consumers must arrive at the mbar of CTA 0 in the cluster. 
+ consumer_mask=None if const_expr(cluster_size == 1) else 0, + ) + + @cute.jit + def make_a_prefetch_pipeline( + self, a_prefetch_pipeline_mbar_ptr: cute.Pointer + ) -> pipeline.PipelineAsync: + producer_cnt = 32 + a_prefetch_producer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, producer_cnt, alignment=producer_cnt + ) + consumer_arrive_cnt = self.num_ab_load_warps - 1 + a_prefetch_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, consumer_arrive_cnt + ) + return pipeline.PipelineCpAsync.create( + barrier_storage=a_prefetch_pipeline_mbar_ptr, + num_stages=self.a_prefetch_stage, + producer_group=a_prefetch_producer_group, + consumer_group=a_prefetch_consumer_group, + ) + + @classmethod + def _compute_stages( + cls, + tiled_mma: cute.TiledMma, + mma_tiler_mnk: Tuple[int, int, int], + cta_tile_shape_mnk: Tuple[int, int, int], + epi_tile: cute.Tile, + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + sf_dtype: Optional[Type[cutlass.Numeric]], + sf_vec_size: Optional[int], + d_dtype: Optional[Type[cutlass.Numeric]], + c_dtype: Optional[Type[cutlass.Numeric]], + d_layout: Optional[LayoutEnum], + c_layout: Optional[LayoutEnum], + epilogue_args: EpilogueArguments, + prefetch_A_idx: Literal[None, "varlen_m", "varlen_k"], + smem_capacity: int, + occupancy: int, + ) -> Tuple[int, int, int]: + """Computes the number of stages for A/B/C operands based on heuristics. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The shape (M, N, K) of the MMA tiler. + :type mma_tiler_mnk: tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param epi_tile: The epilogue tile shape. + :type epi_tile: cute.Tile + :param d_dtype: Data type of operand C (output). + :type d_dtype: type[cutlass.Numeric] + :param d_layout: Layout enum of operand D. + :type d_layout: LayoutEnum + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy). 
+ :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (ACC stages, A/B operand stages, C stages) + :rtype: tuple[int, int, int] + """ + blockscaled = sf_dtype is not None + # Default ACC stages + if const_expr(not blockscaled): + num_acc_stage = 2 + else: + num_acc_stage = 1 if mma_tiler_mnk[1] == 256 else 2 + + # Default D stages + epi_stage = 4 if cute.size(epi_tile[1]) <= 16 else 2 + epi_c_stage = 0 if c_dtype is None else (4 if cute.size(epi_tile[1]) <= 16 else 2) + + # Calculate smem layout and size for one stage of A, B, and C + a_smem_layout_staged_one = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a_dtype, + 1, # a tmp 1 stage is provided + ) + b_smem_layout_staged_one = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b_dtype, + 1, # a tmp 1 stage is provided + ) + d_smem_layout_staged_one = ( + sm100_utils.make_smem_layout_epi(d_dtype, d_layout, epi_tile, 1) + if d_dtype is not None + else None + ) + c_smem_layout_staged_one = ( + sm100_utils.make_smem_layout_epi(c_dtype, c_layout, epi_tile, 1) + if c_dtype is not None + else None + ) + if const_expr(blockscaled): + sfa_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfa( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, # a tmp 1 stage is provided + ) + sfb_smem_layout_staged_one = blockscaled_utils.make_smem_layout_sfb( + tiled_mma, + mma_tiler_mnk, + sf_vec_size, + 1, # a tmp 1 stage is provided + ) + + ab_bytes_per_stage = cute.size_in_bytes( + a_dtype, a_smem_layout_staged_one + ) + cute.size_in_bytes(b_dtype, b_smem_layout_staged_one) + if const_expr(prefetch_A_idx == "varlen_k"): # Need smem to prefetch A indices + ab_bytes_per_stage += Int32.width // 8 * cta_tile_shape_mnk[2] + if const_expr(blockscaled): + ab_bytes_per_stage += cute.size_in_bytes( + sf_dtype, sfa_smem_layout_staged_one + ) + cute.size_in_bytes(sf_dtype, sfb_smem_layout_staged_one) + mbar_helpers_bytes = 1024 + if const_expr(prefetch_A_idx == "varlen_m"): + mbar_helpers_bytes += Int32.width // 8 * cta_tile_shape_mnk[0] * 2 + d_bytes_per_stage = ( + cute.size_in_bytes(d_dtype, d_smem_layout_staged_one) if d_dtype is not None else 0 + ) + epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage( + epilogue_args, cta_tile_shape_mnk, epi_tile + ) + epi_bytes = epi_bytes_per_stage * epi_stage + if const_expr(c_dtype is not None): + c_bytes_per_stage = cute.size_in_bytes(c_dtype, c_smem_layout_staged_one) + epi_bytes += c_bytes_per_stage * epi_c_stage + + # Calculate A/B/SFA/SFB stages: + # Start with total smem per CTA (capacity / occupancy) + # Subtract reserved bytes and initial C stages bytes + # Divide remaining by bytes needed per A/B/SFA/SFB stage + remaining_bytes = smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes + ab_stage = remaining_bytes // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // (epi_bytes_per_stage) + return num_acc_stage, ab_stage, epi_stage, epi_c_stage + + @staticmethod + def _compute_num_tmem_alloc_cols( + tiled_mma: cute.TiledMma, + mma_tiler: Tuple[int, int, int], + num_acc_stage: int, + ) -> int: + """ + Compute the number of tensor memory allocation columns. + + :param tiled_mma: The tiled MMA object defining the core computation. + :type tiled_mma: cute.TiledMma + :param mma_tiler: The shape (M, N, K) of the MMA tile. 
+ :type mma_tiler: tuple[int, int, int] + :param num_acc_stage: The stage of the accumulator tensor. + :type num_acc_stage: int + + :return: The number of tensor memory allocation columns. + :rtype: int + """ + acc_shape = tiled_mma.partition_shape_C(mma_tiler[:2]) + tCtAcc_fake = tiled_mma.make_fragment_C(cute.append(acc_shape, num_acc_stage)) + num_tmem_alloc_cols = cutlass.utils.get_num_tmem_alloc_cols(tCtAcc_fake) + return num_tmem_alloc_cols + + @staticmethod + def is_valid_dtypes( + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + d_dtype: Optional[Type[cutlass.Numeric]], + a_major: str, + b_major: str, + ) -> bool: + """ + Check if the dtypes are valid + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param d_dtype: The data type of the output tensor + :type d_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + is_valid = True + if b_dtype != a_dtype: + is_valid = False + ab_dtype = a_dtype + if ab_dtype not in { + cutlass.Float16, + cutlass.BFloat16, + cutlass.TFloat32, + cutlass.Uint8, + cutlass.Int8, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + if ( + acc_dtype not in {Float32, cutlass.Float16, Int32} + or acc_dtype == cutlass.Float16 + and ab_dtype not in {cutlass.Float16, cutlass.Float8E4M3FN, cutlass.Float8E5M2} + or acc_dtype == Int32 + and ab_dtype not in {cutlass.Uint8, cutlass.Int8} + ): + is_valid = False + if d_dtype is not None and ( + acc_dtype == Float32 + and d_dtype + not in { + Float32, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + Int32, + cutlass.Int8, + cutlass.Uint8, + } + or acc_dtype == cutlass.Float16 + and d_dtype + not in { + cutlass.BFloat16, + cutlass.Float16, + } + or acc_dtype == Int32 + and d_dtype + not in { + cutlass.BFloat16, + cutlass.Float16, + Float32, + Int32, + cutlass.Int8, + cutlass.Uint8, + } + ): + is_valid = False + if ab_dtype is cutlass.Float4E2M1FN and not (a_major == "k" and b_major == "k"): + is_valid = False + return is_valid + + @staticmethod + def is_valid_dtypes_and_scale_factor_vec_size( + ab_dtype: Type[cutlass.Numeric], + sf_dtype: Type[cutlass.Numeric], + sf_vec_size: int, + d_dtype: Type[cutlass.Numeric], + ) -> bool: + """ + Check if the dtypes and sf_vec_size are valid combinations + + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param sf_dtype: The data type of the scale factor + :type sf_dtype: Type[cutlass.Numeric] + :param sf_vec_size: The vector size of the scale factor + :type sf_vec_size: int + :param d_dtype: The data type of the output tensor + :type d_dtype: Type[cutlass.Numeric] + + :return: True if the dtypes and sf_vec_size are valid, False otherwise + :rtype: bool + """ + is_valid = True + + # Check valid ab_dtype + if ab_dtype not in {cutlass.Float4E2M1FN, cutlass.Float8E5M2, cutlass.Float8E4M3FN}: + is_valid = False + + # Check valid sf_vec_size + if sf_vec_size not in {16, 32}: + is_valid = False + + # Check valid sf_dtype + if sf_dtype not in {cutlass.Float8E8M0FNU, cutlass.Float8E4M3FN}: + is_valid = False + + # Check valid sf_dtype and sf_vec_size combinations + if sf_dtype == cutlass.Float8E4M3FN and sf_vec_size == 32: + is_valid = False + if ab_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} and sf_vec_size == 16: + 
is_valid = False + + # Check valid d_dtype + if d_dtype not in { + Float32, + cutlass.Float16, + cutlass.BFloat16, + cutlass.Float8E5M2, + cutlass.Float8E4M3FN, + }: + is_valid = False + + return is_valid + + @staticmethod + def is_valid_mma_tiler_and_cluster_shape( + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + blockscaled: bool, + ) -> bool: + """ + Check if the mma tiler and cluster shape are valid + + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + + :return: True if the mma tiler and cluster shape are valid, False otherwise + :rtype: bool + """ + is_valid = True + # Skip invalid mma tile shape + if mma_tiler_mn[0] not in [64, 128, 256]: + is_valid = False + if not blockscaled: + if mma_tiler_mn[1] not in range(32, 257, 32): + is_valid = False + else: + if mma_tiler_mn[1] not in [128, 256]: + is_valid = False + # Skip invalid cluster shape + is_power_of_2 = lambda x: x > 0 and (x & (x - 1)) == 0 + if ( + cluster_shape_mn[0] * cluster_shape_mn[1] > 16 + or cluster_shape_mn[0] <= 0 + or cluster_shape_mn[1] <= 0 + or not is_power_of_2(cluster_shape_mn[0]) + or not is_power_of_2(cluster_shape_mn[1]) + ): + is_valid = False + if blockscaled: + # Special cluster shape check for scale factor multicasts. + # Due to limited size of scale factors, we can't multicast among more than 4 CTAs. + if cluster_shape_mn[0] > 4 or cluster_shape_mn[1] > 4: + is_valid = False + return is_valid + + @staticmethod + def is_valid_tensor_alignment( + m: int, + n: int, + k: int, + l: int, + ab_dtype: Type[cutlass.Numeric], + d_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + d_major: str, + ) -> bool: + """ + Check if the tensor alignment is valid + + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param d_dtype: The data type of the output tensor + :type d_dtype: Type[cutlass.Numeric] + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param d_major: The major axis of the C tensor + :type d_major: str + + :return: True if the problem shape is valid, False otherwise + :rtype: bool + """ + is_valid = True + + def check_contigous_16B_alignment(dtype, is_mode0_major, tensor_shape): + major_mode_idx = 0 if is_mode0_major else 1 + num_major_elements = tensor_shape[major_mode_idx] + num_contiguous_elements = 16 * 8 // dtype.width + return num_major_elements % num_contiguous_elements == 0 + + if ( + not check_contigous_16B_alignment(ab_dtype, a_major == "m", (m, k, l)) + or not check_contigous_16B_alignment(ab_dtype, b_major == "n", (n, k, l)) + or not check_contigous_16B_alignment(d_dtype, d_major == "m", (m, n, l)) + ): + is_valid = False + return is_valid + + @staticmethod + def can_implement( + ab_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + d_dtype: Type[cutlass.Numeric], + mma_tiler_mn: Tuple[int, int], + cluster_shape_mn: Tuple[int, int], + m: int, + n: int, + k: int, + l: int, + a_major: str, + b_major: str, + d_major: str, + ) -> bool: + """ + Check if the gemm can be implemented + + 
:param ab_dtype: The data type of the A and B operands + :type ab_dtype: Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param d_dtype: The data type of the output tensor + :type d_dtype: Type[cutlass.Numeric] + :param mma_tiler_mn: The (M, N) shape of the MMA instruction tiler + :type mma_tiler_mn: Tuple[int, int] + :param cluster_shape_mn: The (ClusterM, ClusterN) shape of the CTA cluster + :type cluster_shape_mn: Tuple[int, int] + :param m: The number of rows in the A tensor + :type m: int + :param n: The number of columns in the B tensor + :type n: int + :param k: The number of columns in the A tensor + :type k: int + :param l: The number of columns in the C tensor + :type l: int + :param a_major: The major axis of the A tensor + :type a_major: str + :param b_major: The major axis of the B tensor + :type b_major: str + :param d_major: The major axis of the C tensor + :type d_major: str + + :return: True if the gemm can be implemented, False otherwise + :rtype: bool + """ + can_implement = True + # Skip unsupported types + if not GemmSm100.is_valid_dtypes(ab_dtype, ab_dtype, acc_dtype, d_dtype, a_major, b_major): + can_implement = False + # Skip invalid mma tile shape and cluster shape + if not GemmSm100.is_valid_mma_tiler_and_cluster_shape( + mma_tiler_mn, cluster_shape_mn, blockscaled=False + ): + can_implement = False + # Skip illegal problem shape for load/store alignment + if not GemmSm100.is_valid_tensor_alignment( + m, n, k, l, ab_dtype, d_dtype, a_major, b_major, d_major + ): + can_implement = False + return can_implement + + +def run( + mnkl: Tuple[int, int, int, int], + ab_dtype: Type[cutlass.Numeric], + d_dtype: Type[cutlass.Numeric], + c_dtype: Optional[Type[cutlass.Numeric]], + acc_dtype: Type[cutlass.Numeric], + a_major: str, + b_major: str, + d_major: str, + c_major: str, + mma_tiler_mn: Tuple[int, int] = (256, 256), + cluster_shape_mn: Tuple[int, int] = (2, 1), + tolerance: float = 1e-01, + warmup_iterations: int = 0, + iterations: int = 1, + skip_ref_check: bool = False, + dynamic_persistent: bool = False, + **kwargs, +): + """Execute a persistent batched dense GEMM operation on Blackwell architecture with performance benchmarking. + + This function prepares input tensors, configures and launches the persistent GEMM kernel, + optionally performs reference validation, and benchmarks the execution performance. + + :param mnkl: Problem size (M, N, K, L) + :type mnkl: Tuple[int, int, int, int] + :param ab_dtype: Data type for input tensors A and B + :type ab_dtype: Type[cutlass.Numeric] + :param d_dtype: Data type for output tensor C + :type d_dtype: Type[cutlass.Numeric] + :param acc_dtype: Data type for accumulation during matrix multiplication + :type acc_dtype: Type[cutlass.Numeric] + :param a_major/b_major/d_major: Memory layout of tensor A/B/C + :type a_major/b_major/d_major: str + :param mma_tiler_mn: MMA tiling size. If not specified in the decorator parameters, the autotuner will use the + default value of (256, 256). Otherwise, the autotuner will use the value specified in the decorator parameters. + :type mma_tiler_mn: Tuple[int, int], optional + :param cluster_shape_mn: Cluster shape. If not specified in the decorator parameters, the autotuner will use the + default value of (2, 1). Otherwise, the autotuner will use the value specified in the decorator parameters. 
+ :type cluster_shape_mn: Tuple[int, int], optional + :param tolerance: Tolerance value for reference validation comparison, defaults to 1e-01 + :type tolerance: float, optional + :param warmup_iterations: Number of warmup iterations before benchmarking, defaults to 0 + :type warmup_iterations: int, optional + :param iterations: Number of benchmark iterations to run, defaults to 1 + :type iterations: int, optional + :param skip_ref_check: Whether to skip reference result validation, defaults to False + :type skip_ref_check: bool, optional + :raises RuntimeError: If CUDA GPU is not available + :raises ValueError: If the configuration is invalid or unsupported by the kernel + :return: Execution time of the GEMM kernel + :rtype: float + """ + print("Running Blackwell Persistent Dense GEMM test with:") + print(f"mnkl: {mnkl}") + print(f"AB dtype: {ab_dtype}, C dtype: {d_dtype}, Acc dtype: {acc_dtype}") + print(f"Matrix majors - A: {a_major}, B: {b_major}, C: {d_major}") + print(f"Mma Tiler (M, N): {mma_tiler_mn}, Cluster Shape (M, N): {cluster_shape_mn}") + print(f"Tolerance: {tolerance}") + print(f"Warmup iterations: {warmup_iterations}") + print(f"Iterations: {iterations}") + print(f"Skip reference checking: {skip_ref_check}") + + assert not dynamic_persistent, "Dynamic persistent mode is not supported yet." + + # Unpack parameters + m, n, k, l = mnkl + + # Skip unsupported testcase + if not GemmSm100.can_implement( + ab_dtype, + acc_dtype, + d_dtype, + mma_tiler_mn, + cluster_shape_mn, + m, + n, + k, + l, + a_major, + b_major, + d_major, + ): + raise TypeError( + f"Unsupported testcase {ab_dtype}, {acc_dtype}, {d_dtype}, {mma_tiler_mn}, {cluster_shape_mn}, {m}, {n}, {k}, {l}, {a_major}, {b_major}, {d_major}" + ) + + if not torch.cuda.is_available(): + raise RuntimeError("GPU is required to run this example!") + + torch.manual_seed(1111) + + # Create and permute tensor A/B/C + def create_and_permute_tensor(l, mode0, mode1, is_mode0_major, dtype, is_dynamic_layout=True): + # is_mode0_major: (l, mode1, mode0) -> (mode0, mode1, l) + # else: (l, mode0, mode1) -> (mode0, mode1, l) + shape = (l, mode1, mode0) if is_mode0_major else (l, mode0, mode1) + permute_order = (2, 1, 0) if is_mode0_major else (1, 2, 0) + is_unsigned = dtype in {cutlass.Uint8} + # Temporarily use uint8 as torch does not support fp8 type + torch_dtype = cutlass_torch.dtype(dtype) + gen_dtype = ( + torch_dtype + if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} + else torch.bfloat16 + ) + + # Create dtype torch tensor (cpu) + torch_tensor_cpu = cutlass_torch.create_and_permute_torch_tensor( + shape, + gen_dtype, + permute_order=permute_order, + # init_type=cutlass.torch.TensorInitType.RANDOM, + # init_config=cutlass.torch.RandomInitConfig( + # min_val=0 if is_unsigned else -2, max_val=4 if is_unsigned else 2 + # ), + init_type=cutlass.torch.TensorInitType.GAUSSIAN, + init_config=cutlass.torch.GaussianInitConfig(std=k ** (-0.5), scale=1), + ).to(torch_dtype) + # Create dtype torch tensor (gpu) + torch_tensor = torch_tensor_cpu.cuda() + + # Create f32 torch tensor (cpu) + f32_torch_tensor = torch_tensor_cpu.to(dtype=torch.float32) + + # Create dtype cute tensor (gpu) + torch_tensor_view = ( + torch_tensor + if dtype not in {cutlass.Float8E5M2, cutlass.Float8E4M3FN} + else torch_tensor.view(torch.uint8) + ) + cute_tensor = from_dlpack(torch_tensor_view, assumed_align=16) + cute_tensor.element_type = dtype + if is_dynamic_layout: + cute_tensor = cute_tensor.mark_layout_dynamic(leading_dim=(0 if is_mode0_major else 1)) + 
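+        # Fill the target-dtype cute tensor from the f32 reference data; the
+        # cutlass_torch helper also covers the fp8 dtypes that torch cannot
+        # represent directly (viewed as uint8 above), keeping the GPU input and
+        # the CPU reference numerically consistent.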
cute_tensor = cutlass_torch.convert_cute_tensor( + f32_torch_tensor, + cute_tensor, + dtype, + is_dynamic_layout=is_dynamic_layout, + ) + + return f32_torch_tensor, cute_tensor, torch_tensor, torch_tensor_cpu + + a_ref, mA, a_torch, a_torch_cpu = create_and_permute_tensor( + l, m, k, a_major == "m", ab_dtype, is_dynamic_layout=True + ) + b_ref, mB, b_torch, b_torch_cpu = create_and_permute_tensor( + l, n, k, b_major == "n", ab_dtype, is_dynamic_layout=True + ) + _, mD, d_torch, d_torch_cpu = create_and_permute_tensor( + l, m, n, d_major == "m", d_dtype, is_dynamic_layout=True + ) + if c_dtype is not None: + c, mC, c_torch, d_torch_cpu = create_and_permute_tensor(l, m, n, c_major == "m", c_dtype) + else: + c, mC, c_torch = None, None, None + + # Configure gemm kernel + cluster_shape_mnk = (*cluster_shape_mn, 1) + gemm = GemmSm100(acc_dtype, ab_dtype, mma_tiler_mn, cluster_shape_mnk) + + # Compute max active clusters on current device + hardware_info = cutlass.utils.HardwareInfo() + max_active_clusters = hardware_info.get_max_active_clusters( + cluster_shape_mn[0] * cluster_shape_mn[1] + ) + if dynamic_persistent: + tile_count_semaphore = torch.zeros(1, dtype=torch.int32, device="cuda") + else: + tile_count_semaphore = None + + scheduler_args = TileSchedulerOptions( + Int32(max_active_clusters), + tile_count_semaphore=make_ptr( + Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4 + ) + if tile_count_semaphore is not None + else None, + ) + epi_args = gemm.EpilogueArguments() + varlen_args = VarlenArguments() + + # Get current CUDA stream from PyTorch + torch_stream = torch.cuda.current_stream() + # Get the raw stream pointer as a CUstream + current_stream = cuda.CUstream(torch_stream.cuda_stream) + # Compile gemm kernel + compiled_gemm = cute.compile( + gemm, + mA, + mB, + mD, + mC, + epi_args, + scheduler_args, + varlen_args, + current_stream, + ) + + if not skip_ref_check: + compiled_gemm(mA, mB, mD, mC, epi_args, scheduler_args, varlen_args, current_stream) + if ab_dtype in { + cutlass.Int8, + cutlass.Uint8, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + ref = torch.einsum("mkl,nkl->mnl", a_ref.cpu(), b_ref.cpu()) + else: + ref = torch.einsum("mkl,nkl->mnl", a_ref, b_ref) + if c is not None: + ref = ref + c + ref = ref.cpu() + + # Copy gpu result back + gpu_d = d_torch.cpu() + + # Convert ref to c_type + if d_dtype == Float32: + ref_d = ref + elif d_dtype in {cutlass.Float8E5M2, cutlass.Float8E4M3FN}: + # m major: (l, n, m) -> (m, n, l) + # n major: (l, m, n) -> (m, n, l) + permute_order = (1, 2, 0) if d_major == "n" else (2, 1, 0) + shape = (l, m, n) if d_major == "n" else (l, n, m) + f8_torch_tensor = cutlass_torch.create_and_permute_torch_tensor( + shape, + torch.uint8, + permute_order=permute_order, + init_type=cutlass_torch.TensorInitType.SKIP, + ).cuda() + # Create dtype cute tensor (gpu) + ref_d_tensor = from_dlpack(f8_torch_tensor, assumed_align=16).mark_layout_dynamic( + leading_dim=(1 if d_major == "n" else 0) + ) + ref_d_tensor.element_type = d_dtype + ref_d_tensor = cutlass_torch.convert_cute_tensor( + ref, + ref_d_tensor, + d_dtype, + is_dynamic_layout=True, + ) + + ref_d = f8_torch_tensor.cpu() + else: + ref_d = ref.to(cutlass_torch.dtype(d_dtype)) + + # Reference checking ref_d and gpu_d + torch.testing.assert_close(gpu_d, ref_d, atol=tolerance, rtol=1e-05) + + from triton.testing import do_bench + + current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) + + flops = 2 * m * n * k * l + + repeats = iterations + warmup = 
warmup_iterations + + import time + + time.sleep(0.5) + if ab_dtype.width == 8: + assert l == 1 + scale_ab = torch.ones((1,), dtype=torch.float32, device="cuda") + fn_cublas = lambda: torch._scaled_mm( + a_torch[:, :, 0], + b_torch[:, :, 0].mT, + scale_a=scale_ab, + scale_b=scale_ab, + out_dtype=torch.bfloat16, + # use_fast_accum=fp8_fast_accum, + ) + else: + if c_torch is None: + fn_cublas = lambda: torch.matmul(a_torch.permute(2, 0, 1), b_torch.permute(2, 0, 1).mT) + else: + c_torch_convert = c_torch.to(a_torch.dtype) # In case C is in FP32 + fn_cublas = lambda: torch.baddbmm( + c_torch_convert.permute(2, 0, 1), + a_torch.permute(2, 0, 1), + b_torch.permute(2, 0, 1).mT, + ) + timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats) + tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops + print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}") + + time.sleep(0.5) + fn = lambda: compiled_gemm( + mA, mB, mD, mC, epi_args, scheduler_args, varlen_args, current_stream + ) + timing = do_bench(fn, warmup=warmup, rep=repeats) + tflops = flops / (timing * 1e9) # Convert to TFlops + print(f"Cute-DSL Average time: {timing:.3f} ms, TFLOPS: {tflops:.1f}") + + # time.sleep(0.5) + # timing_cublas = do_bench(fn_cublas, warmup=warmup, rep=repeats) + # tflops_cublas = flops / (timing_cublas * 1e9) # Convert to TFlops + # print(f"CuBLAS Average time: {timing_cublas:.3f} ms, TFLOPS: {tflops_cublas:.1f}") + + +if __name__ == "__main__": + + def parse_comma_separated_ints(s: str) -> Tuple[int, ...]: + try: + return tuple(int(x.strip()) for x in s.split(",")) + except ValueError: + raise argparse.ArgumentTypeError("Invalid format. Expected comma-separated integers.") + + parser = argparse.ArgumentParser(description="Example of Dense Persistent GEMM on Blackwell.") + + parser.add_argument( + "--mnkl", + type=parse_comma_separated_ints, + default=(256, 256, 512, 1), + help="mnkl dimensions (comma-separated)", + ) + parser.add_argument( + "--mma_tiler_mn", + type=parse_comma_separated_ints, + default=(128, 128), + help="Mma tile shape (comma-separated)", + ) + parser.add_argument( + "--cluster_shape_mn", + type=parse_comma_separated_ints, + default=(1, 1), + help="Cluster shape (comma-separated)", + ) + parser.add_argument("--ab_dtype", type=cutlass.dtype, default=cutlass.BFloat16) + parser.add_argument("--d_dtype", type=cutlass.dtype, default=cutlass.BFloat16) + parser.add_argument("--c_dtype", type=cutlass.dtype, default=None) + parser.add_argument("--acc_dtype", type=cutlass.dtype, default=Float32) + parser.add_argument("--a_major", choices=["k", "m"], type=str, default="k") + parser.add_argument("--b_major", choices=["k", "n"], type=str, default="k") + parser.add_argument("--d_major", choices=["n", "m"], type=str, default="n") + parser.add_argument("--c_major", choices=["n", "m"], type=str, default="n") + + parser.add_argument("--tolerance", type=float, default=3e-02, help="Tolerance for validation") + parser.add_argument("--warmup_iterations", type=int, default=5, help="Warmup iterations") + parser.add_argument( + "--iterations", + type=int, + default=30, + help="Number of iterations to run the kernel", + ) + parser.add_argument("--skip_ref_check", action="store_true", help="Skip reference checking") + parser.add_argument( + "--dynamic_persistent", action="store_true", help="Dynamic persistent kernel" + ) + + args = parser.parse_args() + + if len(args.mnkl) != 4: + parser.error("--mnkl must contain exactly 4 values") + + if len(args.mma_tiler_mn) != 2: + 
parser.error("--mma_tiler_mn must contain exactly 2 values") + + if len(args.cluster_shape_mn) != 2: + parser.error("--cluster_shape_mn must contain exactly 2 values") + + run( + args.mnkl, + args.ab_dtype, + args.d_dtype, + args.c_dtype, + args.acc_dtype, + args.a_major, + args.b_major, + args.d_major, + args.c_major, + args.mma_tiler_mn, + args.cluster_shape_mn, + args.tolerance, + args.warmup_iterations, + args.iterations, + args.skip_ref_check, + args.dynamic_persistent, + ) + print("PASS") diff --git a/sonic-moe/torch-ext/sonicmoe/quack/gemm_sm90.py b/sonic-moe/torch-ext/sonicmoe/quack/gemm_sm90.py new file mode 100644 index 00000000..e5e132af --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/gemm_sm90.py @@ -0,0 +1,2070 @@ +# Based on the cute-dsl example: +# https://github.com/NVIDIA/cutlass/blob/main/examples/python/CuTeDSL/hopper/dense_gemm.py + +import enum +from typing import Tuple, Type, Callable, Optional, Union, Literal +from functools import partial +import math + + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +import cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, warp, warpgroup +import cutlass.utils.hopper_helpers as sm90_utils +from cutlass import Int32, Float32, Float16, Boolean, const_expr +from cutlass.cutlass_dsl import if_generate +from cutlass.utils import LayoutEnum + + +from .cute_dsl_utils import ParamsBase, ArgumentsBase +from .tile_scheduler import ( + TileSchedulerOptions, + TileSchedulerArguments, + TileScheduler, + VarlenMTileSchedulerArguments, + VarlenMTileScheduler, +) +from .varlen_utils import VarlenArguments, VarlenManager + +# return PipelineStateWAdvance instead of PipelineState +from .pipeline import make_pipeline_state, PipelineTmaCpAsync +from . import copy_utils as copy_utils +from . import sm90_utils as quack_sm90_utils + +""" +A high-performance batched dense GEMM (C = A * B) example for the NVIDIA Hopper architecture +using CUTE DSL. +- Matrix A is MxKxL, L is batch dimension, A can be row-major("K") or column-major("M") +- Matrix B is NxKxL, L is batch dimension, B can be row-major("N") or column-major("K") +- Matrix C is MxNxL, L is batch dimension, C can be row-major("N") or column-major("M") + +This GEMM kernel supports the following features: + - Utilizes Tensor Memory Access (TMA) for efficient memory operations + - Utilizes Hopper's WGMMA for matrix multiply-accumulate (MMA) operations + - Implements TMA multicast with cluster to reduce L2 memory traffic + - Supports multi-stage pipeline to overlap computation and memory access + +This GEMM works as follows: +1. Load A and B matrices from global memory (GMEM) to shared memory (SMEM) using TMA operations. +2. Perform matrix multiply-accumulate (MMA) operations using WGMMA instruction. +3. Store results from registers (RMEM) to shared memory (SMEM), then to global memory (GMEM) with TMA operations. 
+ +Hopper WGMMA instructions operate as follows: +- Read matrix A from SMEM +- Read matrix B from SMEM +- Perform MMA operation and store the result in Accumulator(register) + +Constraints: +* Supported input data types: fp16, fp8 (e4m3fn, e5m2) +* For fp16 types, A and B must have the same data type +* For fp8 types, A and B can have different types (e4m3fn or e5m2) but both must be 8-bit +* Fp8 types only support k-major layout +* Only fp32 accumulation is supported in this example +* CTA tile shape M must be 64/128 +* CTA tile shape N must be 64/128/256 +* CTA tile shape K must be 64 +* Cluster shape M/N must be positive and power of 2, total cluster size <= 4 +* The contiguous dimension of A/B/C tensors must be at least 16 bytes aligned, + i.e, number of elements is a multiple of 8, 16 for Float16, and Float8, respectively. +""" + + +class NamedBarrierGemm(enum.IntEnum): + Epilogue = enum.auto() # starts from 1 as barrier 0 is reserved for sync_threads() + # For mainloop load warps to signal that the epilogue load warp can start. + # This is to avoid loading C too early, interfering with loading A and B. + EpilogueLoad = enum.auto() + MmaWG0 = enum.auto() + MmaWG1 = enum.auto() + EpiWG0 = enum.auto() + EpiWG1 = enum.auto() + TmemPtr = enum.auto() + + +class GemmSm90: + """ + This class implements batched matrix multiplication (C = A x B) with support for various data types + and architectural features specific to Hopper GPUs with persistent tile scheduling and warp specialization. + + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param tile_shape_mn: Shape of the CTA tile (M,N) + :type tile_shape_mn: Tuple[int, int, int] + :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing + :type cluster_shape_mnk: Tuple[int, int, int] + + :note: Data type requirements: + - For 16-bit types: A and B must have the same data type + - For 8-bit types: A and B can have different types (Float8E4M3FN/Float8E5M2) as long as both are 8-bit + - Float8 types only support k-major layout + + :note: Supported data types: + - Float16 + - BFloat16 + - Float8E4M3FN/Float8E5M2 + + :note: Supported accumulation types: + - Float32 (for all floating point inputs) + + :note: Constraints: + - Cluster shape M/N must be positive and power of 2, total cluster size <= 4 + + Example: + >>> gemm = GemmSm90( + ... acc_dtype=Float32, + ... tile_shape_mn=(128, 256), + ... cluster_shape_mnk=(1, 1, 1) + ... ) + >>> gemm(a_tensor, b_tensor, c_tensor, stream) + """ + + arch = 90 + num_epi_tensormaps: int = 0 + + EpilogueArguments = ArgumentsBase + EpilogueParams = ParamsBase + + def __init__( + self, + acc_dtype: Type[cutlass.Numeric], + a_dtype: Type[cutlass.Numeric], + tile_shape_mn: Tuple[int, int], + cluster_shape_mnk: Tuple[int, int, int], + pingpong: bool = False, + is_persistent: bool = True, + fp8_fast_accum: bool = False, + gather_A: bool = False, + ): + """ + Initializes the configuration for a Hopper dense GEMM kernel. + + This configuration includes data types for operands, tile shape, cluster configuration, + and thread layout. 
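+
+        A minimal construction example (the argument values below are illustrative
+        assumptions, not defaults):
+
+        >>> gemm = GemmSm90(
+        ...     acc_dtype=Float32,
+        ...     a_dtype=cutlass.BFloat16,
+        ...     tile_shape_mn=(128, 256),
+        ...     cluster_shape_mnk=(2, 1, 1),
+        ... )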
+ + :param acc_dtype: Data type for accumulation during computation + :type acc_dtype: type[cutlass.Numeric] + :param tile_shape_mn: Shape of the CTA tile (M,N) + :type tile_shape_mn: Tuple[int, int] + :param cluster_shape_mnk: Cluster dimensions (M,N,K) for parallel processing + :type cluster_shape_mnk: Tuple[int, int, int] + """ + + self.acc_dtype = acc_dtype + self.pingpong = pingpong + self.is_persistent = is_persistent + if self.pingpong: + assert self.is_persistent, "Pingpong gemm requires persistent scheduler" + self.fp8_slow_accum = not fp8_fast_accum and a_dtype.width == 8 + self.gather_A = gather_A + if gather_A: + assert cluster_shape_mnk[1] == 1, "Cluster shape N must be 1 for gather A " + + self.cluster_shape_mnk = cluster_shape_mnk + # K dimension is deferred in _setup_attributes + self.cta_tile_shape_mnk = (*tile_shape_mn, 1) + tile_M, tile_N = self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[1] + # check the cta tile shape + if not self.pingpong: + if tile_M not in [64, 128, 192, 256, 320]: + raise ValueError("CTA tile shape M must be 64/128/192/256/320") + if tile_M in [192, 320]: # special case + tile_N_max = 256 if tile_M == 192 else 160 + if not (tile_N % 32 == 0 and tile_N <= tile_N_max): + raise ValueError( + f"If tile_m == {tile_M}, CTA tile shape N must be divisible by 32 and <= {tile_N_max}" + ) + else: + if not ( + (tile_N % 16 == 0 and tile_N <= 256) or (tile_N % 32 == 0 and tile_N <= 512) + ): + raise ValueError( + "CTA tile shape N must be divisible by 16 and <= 256, or divisible by 32 and <= 512" + ) + else: + if tile_M not in [64, 128, 192]: + raise ValueError("CTA tile shape M must be 64/128/192 if pingpong") + tile_N_max = 256 if tile_M == 64 else (208 if tile_M == 128 else 128) + if not (tile_N % 16 == 0 and tile_N <= tile_N_max): + raise ValueError(f"CTA tile shape N must be divisible by 16 and <= {tile_N_max}") + + if not self.pingpong: + if tile_M == 320: # tile_M / 64 is not even so we have to split along N + atom_layout_m, atom_layout_n = 1, 2 + elif tile_M == 192: + if tile_N <= 128: + atom_layout_m, atom_layout_n = 3, 1 + else: + atom_layout_m, atom_layout_n = 1, 2 + else: + atom_layout_m = ( + self.cta_tile_shape_mnk[0] // 64 if self.cta_tile_shape_mnk[0] < 256 else 2 + ) + atom_layout_n = 1 + assert atom_layout_m in [1, 2, 3] and atom_layout_n in [1, 2] + else: + atom_layout_m, atom_layout_n = 1, 1 + self.atom_layout_mnk = (atom_layout_m, atom_layout_n, 1) + + self.num_mcast_ctas_a = self.cluster_shape_mnk[1] + if self.gather_A: + assert self.num_mcast_ctas_a == 1 + self.num_mcast_ctas_b = self.cluster_shape_mnk[0] + self.is_a_mcast = self.num_mcast_ctas_a > 1 + self.is_b_mcast = self.num_mcast_ctas_b > 1 + + self.occupancy = 1 + self.mma_warp_groups = math.prod(self.atom_layout_mnk) * (1 if not self.pingpong else 2) + if self.pingpong: + assert self.mma_warp_groups == 2 + assert self.mma_warp_groups in [1, 2, 3] + self.num_threads_per_warp_group = 128 + self.threads_per_cta = (self.mma_warp_groups + 1) * self.num_threads_per_warp_group + self.smem_capacity = cutlass.utils.get_smem_capacity_in_bytes("sm_90") + self.num_epi_warps = (self.mma_warp_groups if not self.pingpong else 1) * 4 + self.num_ab_load_warps = 1 if not self.gather_A else 4 + self.ab_load_warp_id = self.mma_warp_groups * 4 + # self.num_epi_load_threads = cute.arch.WARP_SIZE * 1 + # self.epi_load_warp_id = self.ab_load_warp_id + self.num_ab_load_warps + + regs_per_thread = math.prod(self.cta_tile_shape_mnk[:2]) // ( + math.prod(self.atom_layout_mnk) * 
self.num_threads_per_warp_group + ) + if self.fp8_slow_accum: + regs_per_thread *= 2 + if not self.gather_A: + if self.mma_warp_groups == 3: + self.num_regs_load, self.num_regs_mma = 32, 160 + else: + heavy_register_pressure = regs_per_thread >= 208 + self.num_regs_load, self.num_regs_mma = ( + (40, 232) if not heavy_register_pressure else (24, 240) + ) + else: + if self.mma_warp_groups == 3: + self.num_regs_load, self.num_regs_mma = 56, 152 + else: + self.num_regs_load, self.num_regs_mma = (56, 224) + + self.ab_stage = None + self.epi_stage = None + + self.a_smem_layout_staged = None + self.b_smem_layout_staged = None + self.epi_smem_layout_staged = None + self.epi_tile = None + + self.shared_storage = None + self.buffer_align_bytes = 1024 + + def _setup_attributes(self, epilogue_args: EpilogueArguments): + """Set up configurations that are dependent on GEMM inputs + + This method configures various attributes based on the input tensor properties + (data types, leading dimensions) and kernel settings: + - Configuring tiled MMA + - Computing MMA/cluster/tile shapes + - Computing cluster layout + - Computing multicast CTAs for A/B + - Computing epilogue subtile + - Setting up A/B/C stage counts in shared memory + - Computing A/B/C shared memory layout + """ + + self.tiled_mma = sm90_utils.make_trivial_tiled_mma( + self.a_dtype, + self.b_dtype, + self.a_layout.sm90_mma_major_mode(), + self.b_layout.sm90_mma_major_mode(), + self.acc_dtype, + self.atom_layout_mnk, + tiler_mn=(64, self.cta_tile_shape_mnk[1] // self.atom_layout_mnk[1]), + ) + if const_expr(self.atom_layout_mnk[1] > 1): + # If N dimension is split among 2 WGs, we need to permute the N dimension so + # that in the epilogue, WG0 and WG1 can write to epi smem of size e.g. (64, 32) + # containing accumulators that are next to each other in the N dimension. + # Without permutation WG0 would write to epi smem of size (64, 16) and + # WG1 would write to a separate epi smem of size (64, 16) that's far away. + atom_n = self.atom_layout_mnk[1] + permutation_n = cute.make_ordered_layout( + (8, self.cta_tile_shape_mnk[1] // atom_n // 8, atom_n), order=(0, 2, 1) + ) + self.tiled_mma = cute.make_tiled_mma( + cute.make_mma_atom(self.tiled_mma.op), + self.atom_layout_mnk, + permutation_mnk=(None, permutation_n, None), + ) + mma_inst_shape_k = cute.size(self.tiled_mma.shape_mnk, mode=[2]) + mma_inst_tile_k = 4 + self.cta_tile_shape_mnk = ( + self.cta_tile_shape_mnk[0], + self.cta_tile_shape_mnk[1], + mma_inst_shape_k * mma_inst_tile_k, + ) + + self.cluster_layout_mnk = cute.make_layout(self.cluster_shape_mnk) + + self.epi_tile = self._sm90_compute_tile_shape_or_override( + self.cta_tile_shape_mnk, + self.atom_layout_mnk, + self.d_dtype, + ) + + # Compute stage before compute smem layout + self.ab_stage, self.epi_stage, self.epi_c_stage = self._compute_stages( + self.cta_tile_shape_mnk, + self.epi_tile, + self.a_dtype, + self.b_dtype, + self.d_dtype, + self.c_dtype, + epilogue_args, + cutlass.utils.get_smem_capacity_in_bytes(f"sm_{self.arch}"), # smem_capacity + self.occupancy, + # epi_smem will reuse smem ab if not persistent. 
+ overlap_sD_sA=not self.is_persistent, + ) + self.sched_stage = 2 if self.pingpong else 1 + + ( + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.epi_smem_layout_staged, + self.epi_c_smem_layout_staged, + ) = self._make_smem_layouts( + self.cta_tile_shape_mnk, + self.epi_tile, + self.a_dtype, + self.a_layout, + self.b_dtype, + self.b_layout, + self.ab_stage, + self.d_dtype, + self.d_layout, + self.epi_stage, + self.c_dtype, + self.c_layout, + self.epi_c_stage, + ) + + @cute.jit + def __call__( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mD: Optional[cute.Tensor], + mC: Optional[cute.Tensor], + epilogue_args: ArgumentsBase, + scheduler_args: TileSchedulerOptions, + varlen_args: Optional[VarlenArguments], + stream: cuda.CUstream, + ): + """Execute the GEMM operation in steps: + - Setup static attributes + - Setup TMA load/store atoms and tensors + - Compute grid size + - Define shared storage for kernel + - Launch the kernel synchronously + + :param mA: Input tensor A + :type mA: cute.Tensor + :param mB: Input tensor B + :type mB: cute.Tensor + :param mD: Output tensor D + :type mD: cute.Tensor + :param stream: CUDA stream for asynchronous execution + :type stream: cuda.CUstream + """ + + # setup static attributes before smem/grid/tma computation + self.a_dtype = mA.element_type + self.b_dtype = mB.element_type + self.d_dtype = mD.element_type if mD is not None else None + self.c_dtype = mC.element_type if mC is not None else None + self.a_layout = LayoutEnum.from_tensor(mA) + self.b_layout = LayoutEnum.from_tensor(mB) + self.d_layout = LayoutEnum.from_tensor(mD) if mD is not None else None + self.c_layout = LayoutEnum.from_tensor(mC) if mC is not None else None + + if const_expr(self.a_dtype.width == 16 and self.a_dtype != self.b_dtype): + raise TypeError(f"Type mismatch: {self.a_dtype} != {self.b_dtype}") + if const_expr(self.a_dtype.width != self.b_dtype.width): + raise TypeError(f"Type width mismatch: {self.a_dtype.width} != {self.b_dtype.width}") + if const_expr(self.a_dtype.width != 16 and self.a_dtype.width != 8): + raise TypeError("a_dtype should be float16 or float8") + + if const_expr(varlen_args is None): + varlen_args = VarlenArguments() + assert (varlen_args.mAIdx is not None) == self.gather_A + + # Assume all strides are divisible by 128 bits except the last stride + new_stride = lambda t: tuple( + cute.assume(s, divby=128 // t.element_type.width) if not cute.is_static(s) else s + for s in t.stride + ) + mA, mD = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) + if t is not None + else None + for t in (mA, mD) + ] + + self._setup_attributes(epilogue_args) + + a_smem_layout = cute.slice_(self.a_smem_layout_staged, (None, None, 0)) + b_smem_layout = cute.slice_(self.b_smem_layout_staged, (None, None, 0)) + tma_atom_a, tma_tensor_a = None, None + if const_expr(not self.gather_A): + tma_atom_a, tma_tensor_a = self._make_tma_atoms_and_tensors( + mA, + a_smem_layout, + (self.cta_tile_shape_mnk[0], self.cta_tile_shape_mnk[2]), + self.cluster_shape_mnk[1], + ) + tma_atom_b, tma_tensor_b = self._make_tma_atoms_and_tensors( + mB, + b_smem_layout, + (self.cta_tile_shape_mnk[1], self.cta_tile_shape_mnk[2]), + self.cluster_shape_mnk[0], + ) + + self.num_tma_load_bytes = cute.size_in_bytes(self.b_dtype, b_smem_layout) + if const_expr(not self.gather_A): + self.num_tma_load_bytes += cute.size_in_bytes(self.a_dtype, a_smem_layout) + + tma_atom_d, tma_tensor_d = None, None + if const_expr(mD is not None): + tma_atom_d, tma_tensor_d = 
self._make_tma_epi_atoms_and_tensors( + mD, + self.epi_smem_layout_staged, + self.epi_tile, + op_type="store" + if not (hasattr(epilogue_args, "add_to_output") and epilogue_args.add_to_output) + else "add", + ) + tma_atom_c, tma_tensor_c = None, None + if const_expr(mC is not None): + tma_atom_c, tma_tensor_c = self._make_tma_epi_atoms_and_tensors( + mC, self.epi_c_smem_layout_staged, self.epi_tile, op_type="load" + ) + + epilogue_params = self.epi_to_underlying_arguments(epilogue_args) + varlen_params = VarlenManager.to_underlying_arguments(varlen_args) + + TileSchedulerCls = self.get_scheduler_class(varlen_m=varlen_args.mCuSeqlensM is not None) + tile_sched_args = self.get_scheduler_arguments(mA, mB, mD, scheduler_args, varlen_args) + tile_sched_params = TileSchedulerCls.to_underlying_arguments(tile_sched_args) + grid = TileSchedulerCls.get_grid_shape( + tile_sched_params, scheduler_args.max_active_clusters + ) + + epi_smem_size = ( + cute.cosize(self.epi_smem_layout_staged) if self.is_persistent and mD is not None else 0 + ) + epi_c_smem_size = cute.cosize(self.epi_c_smem_layout_staged) if mC is not None else 0 + + @cute.struct + class SharedStorage: + ab_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.ab_stage * 2] + epi_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.epi_c_stage * 2] + sched_pipeline_array_ptr: cute.struct.MemRange[cutlass.Int64, self.sched_stage * 2] + tile_count: cute.struct.MemRange[Int32, self.sched_stage] + sD: cute.struct.Align[ + cute.struct.MemRange[ + self.d_dtype if self.d_dtype is not None else Int32, epi_smem_size + ], + self.buffer_align_bytes, + ] + sC: cute.struct.Align[ + cute.struct.MemRange[ + self.c_dtype if self.c_dtype is not None else Int32, epi_c_smem_size + ], + self.buffer_align_bytes, + ] + epi: self.epi_get_smem_struct(epilogue_params) + sA: cute.struct.Align[ + cute.struct.MemRange[self.a_dtype, cute.cosize(self.a_smem_layout_staged)], + self.buffer_align_bytes, + ] + sB: cute.struct.Align[ + cute.struct.MemRange[self.b_dtype, cute.cosize(self.b_smem_layout_staged)], + self.buffer_align_bytes, + ] + + self.shared_storage = SharedStorage + + # Launch the kernel synchronously + self.kernel( + self.tiled_mma, + tma_atom_a, + tma_tensor_a if const_expr(not self.gather_A) else mA, + tma_atom_b, + tma_tensor_b, + tma_atom_d, + tma_tensor_d, + tma_atom_c, + tma_tensor_c, + epilogue_params, + varlen_params, + self.cluster_layout_mnk, + self.a_smem_layout_staged, + self.b_smem_layout_staged, + self.epi_smem_layout_staged, + self.epi_c_smem_layout_staged, + tile_sched_params, + TileSchedulerCls, + ).launch( + grid=grid, + block=[self.threads_per_cta, 1, 1], + cluster=self.cluster_shape_mnk, + stream=stream, + min_blocks_per_mp=1, + ) + return + + # GPU device kernel + @cute.kernel + def kernel( + self, + tiled_mma: cute.TiledMma, + tma_atom_a: Optional[cute.CopyAtom], + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + tma_atom_d: Optional[cute.CopyAtom], + mD_mnl: Optional[cute.Tensor], + tma_atom_c: Optional[cute.CopyAtom], + mC_mnl: Optional[cute.Tensor], + epilogue_params: ParamsBase, + varlen_params: VarlenManager.Params, + cluster_layout_mnk: cute.Layout, + a_smem_layout: cute.ComposedLayout, + b_smem_layout: cute.ComposedLayout, + epi_smem_layout: cute.ComposedLayout, + epi_c_smem_layout: cute.ComposedLayout, + tile_sched_params: ParamsBase, + TileSchedulerCls: cutlass.Constexpr[Callable], + ): + """ + GPU device kernel performing the batched GEMM computation. 
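+
+        The thread block is warp-specialized: warps below ab_load_warp_id form the
+        MMA/epilogue warp groups, while the remaining load warp(s) issue TMA (or, when
+        gather_A is enabled, cp.async) copies of A and B into the multi-stage shared
+        memory pipeline.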
+ + :param tma_atom_a: TMA copy atom for A tensor + :type tma_atom_a: cute.CopyAtom + :param mA_mkl: Input tensor A + :type mA_mkl: cute.Tensor + :param tma_atom_b: TMA copy atom for B tensor + :type tma_atom_b: cute.CopyAtom + :param mB_nkl: Input tensor B + :type mB_nkl: cute.Tensor + :param tma_atom_d: TMA copy atom for D tensor + :type tma_atom_d: cute.CopyAtom + :param mD_mnl: Output tensor D + :type mD_mnl: cute.Tensor + :param tiled_mma: Tiled MMA object + :type tiled_mma: cute.TiledMma + :param cluster_layout_mnk: CTA layout + :type cluster_layout_mnk: cute.Layout + :param a_smem_layout: Shared memory layout for A + :type a_smem_layout: cute.ComposedLayout + :param b_smem_layout: Shared memory layout for B + :type b_smem_layout: cute.ComposedLayout + :param epi_smem_layout: Shared memory layout for epilogue + :type epi_smem_layout: cute.ComposedLayout + """ + + varlen_m = const_expr(varlen_params.cu_seqlens_m is not None) + varlen_k = const_expr(varlen_params.cu_seqlens_k is not None) + assert not (varlen_m and varlen_k) + if const_expr(self.gather_A): + assert varlen_m or varlen_k + has_D = const_expr(mD_mnl is not None) + has_C = const_expr(mC_mnl is not None) + + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + + # ///////////////////////////////////////////////////////////////////////////// + # Prefetch Tma desc + # ///////////////////////////////////////////////////////////////////////////// + if warp_idx == self.ab_load_warp_id: + for tma_atom in (tma_atom_a, tma_atom_b, tma_atom_d, tma_atom_c): + if const_expr(tma_atom is not None): + cpasync.prefetch_descriptor(tma_atom) + + # ///////////////////////////////////////////////////////////////////////////// + # Alloc and init AB full/empty + ACC full mbar (pipeline) + # ///////////////////////////////////////////////////////////////////////////// + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(self.shared_storage) + + ab_pipeline = self.make_ab_pipeline( + tiled_mma=tiled_mma, + cluster_layout_vmnk=cute.make_layout((1, *cluster_layout_mnk.shape)), + ab_pipeline_mbar_ptr=storage.ab_pipeline_array_ptr.data_ptr(), + ) + epi_pipeline = None + if const_expr(has_C): + epi_pipeline = self.make_epi_pipeline( + c_smem_layout=cute.slice_(epi_c_smem_layout, (None, None, 0)), + epi_pipeline_mbar_ptr=storage.epi_pipeline_array_ptr.data_ptr(), + ) + sched_pipeline = None + tile_count = None + if const_expr(tile_sched_params.tile_count_semaphore is not None): + # Dynamic persistent scheduler + sched_pipeline = self.make_sched_pipeline( + cluster_layout_mnk, + sched_pipeline_mbar_ptr=storage.sched_pipeline_array_ptr.data_ptr(), + varlen_k=varlen_k, + ) + tile_count = storage.tile_count.get_tensor((self.sched_stage,)) + + # /////////////////////////////////////////////////////////////////////////////// + # Generate smem tensor A/B + # /////////////////////////////////////////////////////////////////////////////// + sA = storage.sA.get_tensor(a_smem_layout.outer, swizzle=a_smem_layout.inner) + sB = storage.sB.get_tensor(b_smem_layout.outer, swizzle=b_smem_layout.inner) + sD = None + if const_expr(has_D): + if const_expr(not self.is_persistent): + sD_ptr = cute.recast_ptr(sA.iterator, epi_smem_layout.inner, dtype=self.d_dtype) + sD = cute.make_tensor(sD_ptr, epi_smem_layout.outer) + else: + sD = storage.sD.get_tensor(epi_smem_layout.outer, swizzle=epi_smem_layout.inner) + sC = None + if const_expr(has_C): + sC = storage.sC.get_tensor(epi_c_smem_layout.outer, swizzle=epi_c_smem_layout.inner) + epi_smem_tensors = 
self.epi_get_smem_tensors(epilogue_params, storage) + + varlen_manager = VarlenManager.create( + varlen_params, + has_D, + self.num_epi_tensormaps, + # Only used if not varlen_m + len_m_static=Int32( + mA_mkl.shape[0] + if varlen_k or varlen_params.mAIdx is None + else varlen_params.mAIdx.shape[0] + ), + len_k_static=Int32(mA_mkl.shape[1]), + pingpong=self.pingpong, + warp_idx=warp_idx, + ) + + TileSchedulerCls = partial( + TileSchedulerCls.create, tile_sched_params, tile_count, sched_pipeline + ) + + if warp_idx >= self.ab_load_warp_id: + cute.arch.warpgroup_reg_dealloc(self.num_regs_load) + if ( + warp_idx >= self.ab_load_warp_id + and warp_idx < self.ab_load_warp_id + self.num_ab_load_warps + ): + is_tma_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id + # initialize tensormap for A & B + varlen_manager.init_tensormap_AB(tma_atom_a, tma_atom_b, is_tma_warp) + tma_desc_a_ptr = varlen_manager.get_tma_desc_a_ptr() + tma_desc_b_ptr = varlen_manager.get_tma_desc_b_ptr() + # /////////////////////////////////////////////////////////////////////////////// + # Get mcast mask + # /////////////////////////////////////////////////////////////////////////////// + cta_rank_in_cluster = cute.arch.make_warp_uniform(cute.arch.block_idx_in_cluster()) + block_in_cluster_coord_mnk = cluster_layout_mnk.get_flat_coord(cta_rank_in_cluster) + a_mcast_mask = cute.make_layout_image_mask( + cluster_layout_mnk, block_in_cluster_coord_mnk, mode=1 + ) + b_mcast_mask = cute.make_layout_image_mask( + cluster_layout_mnk, block_in_cluster_coord_mnk, mode=0 + ) + a_mcast_mask = a_mcast_mask if self.is_a_mcast else 0 + b_mcast_mask = b_mcast_mask if self.is_b_mcast else 0 + + # Persistent tile scheduling loop + is_scheduler_warp = self.num_ab_load_warps == 1 or warp_idx == self.ab_load_warp_id + if const_expr(cute.size(cluster_layout_mnk) > 1): + is_scheduler_warp = is_scheduler_warp and cute.arch.block_idx_in_cluster() == 0 + tile_scheduler = TileSchedulerCls(is_scheduler_warp=is_scheduler_warp) + work_tile = tile_scheduler.initial_work_tile_info() + ab_producer_state = make_pipeline_state( + pipeline.PipelineUserType.Producer, self.ab_stage + ) + if const_expr(varlen_k): + # wait tensormap initialization complete before update + varlen_manager.fence_tensormap_init() + while work_tile.is_valid_tile: + tile_coord_mnkl = work_tile.tile_idx + batch_idx = tile_coord_mnkl[3] + varlen_manager.update_tensormap_AB( + batch_idx, + self.a_layout, + self.b_layout, + is_tma_warp, + ) + # /////////////////////////////////////////////////////////////////////////// + # Local_tile partition global tensors + # /////////////////////////////////////////////////////////////////////////// + if const_expr(not self.gather_A): + mA_mk = varlen_manager.offset_batch_A(mA_mkl, batch_idx) + # (bM, bK, RestK) + gA_mk = cute.local_tile( + mA_mk, + cute.select(self.cta_tile_shape_mnk, [0, 2]), + (tile_coord_mnkl[0], None), + ) + else: + mAIdx_mk = varlen_manager.offset_batch_AIdx(batch_idx) + if const_expr(varlen_m): + gAIdx = cute.local_tile( + mAIdx_mk, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0],) + ) + # (M, K) + mA_mk = mA_mkl + else: + assert varlen_k + # (tile_K, RestK) + gAIdx = cute.flat_divide(mAIdx_mk, (self.cta_tile_shape_mnk[2],)) + # (tile_M, K) + mA_mk = cute.local_tile( + mA_mkl, (self.cta_tile_shape_mnk[0],), (tile_coord_mnkl[0], None) + ) + # (bN, bK, RestK) + gB_nk = cute.local_tile( + varlen_manager.offset_batch_B(mB_nkl, batch_idx), + cute.select(self.cta_tile_shape_mnk, [1, 2]), + (tile_coord_mnkl[1], 
None), + ) + # ////////////////////////////////////////////////////////////////////////// + # Partition shared tensor for TMA load A/B + # ////////////////////////////////////////////////////////////////////////// + varlen_manager.fence_tensormap_update_AB(is_tma_warp) + len_m = varlen_manager.len_m(batch_idx) + len_k = varlen_manager.len_k(batch_idx) + # TMA load A partition_S/D + copy_A = None + if const_expr(not self.gather_A): + copy_A, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_a, + cta_coord=block_in_cluster_coord_mnk[1], + cta_layout=cute.make_layout( + cute.slice_(cluster_layout_mnk, (0, None, 0)).shape + ), + src_tensor=gA_mk, + dst_tensor=sA, + mcast_mask=a_mcast_mask, + tma_desc_ptr=tma_desc_a_ptr, + ) + else: + tiled_copy_A = self._make_gmem_tiled_copy_A( + mA_mkl.element_type, self.a_layout, self.num_ab_load_warps * 32 + ) + tidx = ( + cute.arch.thread_idx()[0] - cute.arch.WARP_SIZE * self.ab_load_warp_id + ) + thr_copy_A = tiled_copy_A.get_slice(tidx) + copy_A, prefetch_A = None, None + if const_expr(varlen_m): + copy_A = copy_utils.gather_m_get_copy_fn( + thr_copy_A, + mA_mk, + sA, + gAIdx, + limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0], + limit_k=len_k, + ) + else: + copy_A, prefetch_A = copy_utils.gather_k_get_copy_fn( + thr_copy_A, + mA_mk, + sA, + gAIdx, + limit_m=len_m - tile_coord_mnkl[0] * self.cta_tile_shape_mnk[0], + limit_k=len_k, + ) + # TMA load B partition_S/D + copy_B, _, _ = copy_utils.tma_get_copy_fn( + tma_atom_b, + cta_coord=block_in_cluster_coord_mnk[0], + cta_layout=cute.make_layout( + cute.slice_(cluster_layout_mnk, (None, 0, 0)).shape + ), + src_tensor=gB_nk, + dst_tensor=sB, + mcast_mask=b_mcast_mask, + tma_desc_ptr=tma_desc_b_ptr, + ) + k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) + if const_expr(not self.gather_A): + ab_producer_state = self.load_AB( + ab_pipeline, ab_producer_state, copy_A, copy_B, k_tile_cnt + ) + else: + ab_producer_state = self.load_AB_gather_A( + ab_pipeline, + ab_producer_state, + copy_A, + prefetch_A, + copy_B, + k_tile_cnt, + varlen_m=varlen_m, + ) + tile_scheduler.fetch_next_work(is_scheduler_warp=is_scheduler_warp) + tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp) + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + if const_expr(self.pingpong and not varlen_k): + # Need to write the tile_idx to smem for the next WG in the pingpong mode + tile_scheduler.advance_to_next_work(is_scheduler_warp=is_scheduler_warp) + ab_pipeline.producer_tail(ab_producer_state) + if is_scheduler_warp: + tile_scheduler.producer_tail() + + if warp_idx < self.ab_load_warp_id: + cute.arch.warpgroup_reg_alloc(self.num_regs_mma) + is_tma_warp = Boolean( + (not self.pingpong and warp_idx == 0) + or (self.pingpong and (warp_idx == 0 or warp_idx == 4)) + ) + varlen_manager.init_tensormap_epi( + tma_atom_d, self.epi_get_tma_atoms(epilogue_params), is_tma_warp + ) + tma_desc_d_ptr = varlen_manager.get_tma_desc_d_ptr() + tma_desc_epi_ptrs = varlen_manager.get_tma_desc_epi_ptrs() + # ////////////////////////////////////////////////////////////////////////////// + # Partition global tensor for TiledMMA_A/B/C + # ////////////////////////////////////////////////////////////////////////////// + tidx, _, _ = cute.arch.thread_idx() + warp_group_idx = cute.arch.make_warp_uniform(tidx // self.num_threads_per_warp_group) + if const_expr(self.pingpong): + tidx = tidx % self.num_threads_per_warp_group + warp_group_thread_layout = cute.make_layout( + self.mma_warp_groups if 
not self.pingpong else 1, + stride=self.num_threads_per_warp_group, + ) + thr_mma = tiled_mma.get_slice( + warp_group_thread_layout(warp_group_idx if not self.pingpong else 0) + ) + + # ////////////////////////////////////////////////////////////////////////////// + # Make fragments + # ////////////////////////////////////////////////////////////////////////////// + tCrA = tiled_mma.make_fragment_A(thr_mma.partition_A(sA)) + tCrB = tiled_mma.make_fragment_B(thr_mma.partition_B(sB)) + + acc_shape = tiled_mma.partition_shape_C( + cute.select(self.cta_tile_shape_mnk, mode=[0, 1]) + ) + acc = cute.make_fragment(acc_shape, self.acc_dtype) + acc_slow = None + if const_expr(self.fp8_slow_accum): + acc_slow = cute.make_fragment(acc_shape, self.acc_dtype) + + if const_expr(self.pingpong): + if warp_group_idx == 0: + # WG0 needs a start signal at the very beginning + self.pingpong_barrier_arrive(warp_group_idx=0, stage="mma") + self.pingpong_barrier_arrive(warp_group_idx=0, stage="epi") + + k_tile_cnt_static = cute.ceil_div(mA_mkl.shape[1], self.cta_tile_shape_mnk[2]) + c_tile_cnt = cute.size(cute.ceil_div(self.cta_tile_shape_mnk[:2], self.epi_tile)) + + ab_read_state = make_pipeline_state(pipeline.PipelineUserType.Consumer, self.ab_stage) + epi_store_pipeline = self.make_epi_store_pipeline() + epi_read_state = make_pipeline_state( + pipeline.PipelineUserType.Consumer, self.epi_c_stage + ) + epi_producer_state = make_pipeline_state( + pipeline.PipelineUserType.Producer, self.epi_c_stage + ) + tile_scheduler = TileSchedulerCls() + work_tile = None + if const_expr(self.pingpong): + if const_expr(varlen_k): + work_tile = tile_scheduler.initial_work_tile_info() + if warp_idx >= 4: + # Advance 2nd Math WG pipeline states to the end of 1st Math WG + epi_read_state.advance_iters(c_tile_cnt) + epi_producer_state.advance_iters(c_tile_cnt) + if const_expr(not varlen_k): + ab_read_state.advance_iters(k_tile_cnt_static) + else: + len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3]) + k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) + ab_read_state.advance_iters(k_tile_cnt) + tile_scheduler.advance_to_next_work() + if const_expr(varlen_k): + work_tile = tile_scheduler.get_current_work() + if const_expr(not varlen_k): + work_tile = tile_scheduler.initial_work_tile_info() + else: + work_tile = tile_scheduler.initial_work_tile_info() + if const_expr(varlen_m): + # wait tensormap initialization complete before update + varlen_manager.fence_tensormap_init() + while work_tile.is_valid_tile: + tile_coord_mnkl = work_tile.tile_idx + batch_idx = tile_coord_mnkl[3] + epi_shapes, epi_orders = self.epi_get_tensormap_update_shapes_orders( + epilogue_params, varlen_params.cu_seqlens_m, batch_idx + ) + varlen_manager.update_tensormap_epi( + batch_idx, + self.d_layout, + epi_shapes, + epi_orders, + is_tma_warp, + ) + len_k = varlen_manager.len_k(batch_idx) + k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) + ab_read_state, tiled_mma = self.mma( + ab_pipeline, + ab_read_state, + tiled_mma, + tCrA, + tCrB, + acc, + acc_slow, + k_tile_cnt, + warp_group_idx, + ) + if const_expr(varlen_k): + if k_tile_cnt == 0: + acc.fill(0.0) + + # ///////////////////////////////////////////////////////////////////////////// + # EPILOGUE + # ///////////////////////////////////////////////////////////////////////////// + if const_expr(self.pingpong): + self.pingpong_barrier_sync(warp_group_idx, "epi") + + epilogue_barrier = pipeline.NamedBarrier( + barrier_id=int(NamedBarrierGemm.Epilogue), + 
num_threads=self.num_epi_warps * cute.arch.WARP_SIZE, + ) + + varlen_manager.fence_tensormap_update_epi(is_tma_warp) + + copy_D = None + if const_expr(has_D): + copy_D, _, _ = self.epilog_gmem_copy_and_partition( + tma_atom_d, + varlen_manager.offset_batch_epi(mD_mnl, batch_idx), + self.cta_tile_shape_mnk[:2], + self.epi_tile, + sD, + tile_coord_mnkl, + tma_desc_ptr=tma_desc_d_ptr, + ) + copy_C = None + if const_expr(has_C): + copy_C_fn, _, _ = self.epilog_gmem_copy_and_partition( + tma_atom_c, + varlen_manager.offset_batch_epi(mC_mnl, batch_idx), + self.cta_tile_shape_mnk[:2], + self.epi_tile, + sC, + tile_coord_mnkl, + ) + copy_C = copy_utils.tma_producer_copy_fn(copy_C_fn, epi_pipeline) + + d_dtype_for_layout = self.d_dtype if self.d_dtype is not None else cutlass.BFloat16 + tiled_copy_r2s, tRS_rD, tRS_sD = self.epilog_smem_store_and_partition( + tiled_mma, self.d_layout, d_dtype_for_layout, sD, tidx + ) + # (R2S, R2S_M, R2S_N) + tRS_rAcc = tiled_copy_r2s.retile(acc) + load_acc_subtile = partial(self.epi_load_acc_subtile, tRS_rAcc) + if const_expr(has_C): + tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC = self.epilog_smem_load_and_partition( + tiled_mma, self.c_layout, self.c_dtype, sC, tRS_rD.layout, tidx + ) + else: + tiled_copy_s2r, tSR_sC, tRS_rC, tSR_rC = None, None, None, None + + # Wait for all warp groups in the thread block to finish, because smem for tensor + # A in the mainloop is reused in the epilogue if not persistent. + if const_expr(not self.is_persistent): + epilogue_barrier.arrive_and_wait() + + self.epi_visit_acc(epilogue_params, acc, tiled_mma, tile_coord_mnkl, tidx) + + epi_read_state, epi_producer_state = self.epilogue( + epilogue_params, + epi_smem_tensors, + tma_desc_epi_ptrs, + epi_pipeline, + epi_store_pipeline, + epi_read_state, + epi_producer_state, + self.epi_tile, + load_acc_subtile, + tRS_rD, + tRS_rC, + None, # tiled_copy_t2r, for Sm100 only + tiled_copy_r2s, + tRS_sD, + tiled_copy_s2r, + tSR_rC, + tSR_sC, + copy_D, + copy_C, + tile_coord_mnkl, + varlen_manager, + epilogue_barrier, + tile_scheduler, + tidx, + is_tma_warp, + ) + + if const_expr(self.pingpong): + # With pingpong, 2 WGs write two different output tiles to the same smem, + # so we have to make sure the smem content is done reading before signaling + # the next WG's epilogue. 
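+                    # producer_tail() drains this WG's outstanding TMA stores of D
+                    # before the barrier arrive hands the epilogue smem to the other
+                    # warp group.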
+ if is_tma_warp: + epi_store_pipeline.producer_tail() + self.pingpong_barrier_arrive(1 - warp_group_idx, stage="epi") + + if const_expr(not self.pingpong): + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + else: # Skip a tile for pingpong + # Update starting load/store pipeline states for the next tile + epi_read_state.advance_iters(c_tile_cnt) + epi_producer_state.advance_iters(c_tile_cnt) + # Update starting mainloop pipeline state for the next tile + if const_expr(not varlen_k): + ab_read_state.advance_iters(k_tile_cnt_static) + tile_scheduler.advance_to_next_work(advance_count=self.mma_warp_groups) + work_tile = tile_scheduler.get_current_work() + else: + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + if work_tile.is_valid_tile: + len_k = varlen_manager.len_k(batch_idx=work_tile.tile_idx[3]) + k_tile_cnt = cute.ceil_div(len_k, self.cta_tile_shape_mnk[2]) + ab_read_state.advance_iters(k_tile_cnt) + tile_scheduler.advance_to_next_work() + work_tile = tile_scheduler.get_current_work() + # End of persistent scheduler loop + + # Wait for D store complete + if const_expr(not self.pingpong): + if is_tma_warp: + epi_store_pipeline.producer_tail() + + @cute.jit + def load_AB( + self, + ab_pipeline: cutlass.pipeline.PipelineAsync, + ab_producer_state: cutlass.pipeline.PipelineState, + copy_A: Optional[Callable], + copy_B: Callable, + k_tile_cnt: Int32, + # These are for Sm100 blockscaled gemm + copy_SFA: Optional[Callable] = None, + copy_SFB: Optional[Callable] = None, + ) -> cutlass.pipeline.PipelineState: + blockscaled = const_expr(copy_SFA is not None) + if const_expr(blockscaled): + assert copy_SFB is not None + # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt + peek_ab_empty_status = Boolean(True) + if 0 < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + # ///////////////////////////////////////////////////////////////////////// + # TMA load + # ///////////////////////////////////////////////////////////////////////// + for k_tile in cutlass.range(k_tile_cnt, unroll=1): + # Wait for A/B buffers to be empty before loading into them + # Also sets the transaction barrier for the A/B buffers + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status) + tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state) + smem_idx = ab_producer_state.index + if const_expr(copy_A is not None): + copy_A(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr) + copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr) + if const_expr(blockscaled): + copy_SFA(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr) + copy_SFB(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr) + # Mainloop pipeline's producer commit is a NOP + ab_pipeline.producer_commit(ab_producer_state) + ab_producer_state.advance() + peek_ab_empty_status = Boolean(True) + if k_tile + 1 < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + return ab_producer_state + + @cute.jit + def load_AB_gather_A( + self, + ab_pipeline: cutlass.pipeline.PipelineAsync, + ab_producer_state: cutlass.pipeline.PipelineState, + copy_A: Callable, + prefetch_A: Optional[Callable], + copy_B: Callable, + k_tile_cnt: Int32, + varlen_m: bool = True, + ) -> cutlass.pipeline.PipelineState: + warp_idx = cute.arch.make_warp_uniform(cute.arch.warp_idx()) + # Peek (try_wait) AB buffer empty for k_block = prefetch_k_tile_cnt + peek_ab_empty_status = Boolean(True) + if 0 < k_tile_cnt: + peek_ab_empty_status = 
ab_pipeline.producer_try_acquire(ab_producer_state) + # ///////////////////////////////////////////////////////////////////////// + # TMA load on B and cp.async on A + # ///////////////////////////////////////////////////////////////////////// + for k_tile in cutlass.range(k_tile_cnt - 1, unroll=1): + prefetch_out = () + if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free + prefetch_out = (prefetch_A(k_tile),) + # Wait for A/B buffers to be empty before loading into them + # Also sets the transaction barrier for the A/B buffers + # A tiny bit faster to rotate the warp that does TMA + # However, for varlen_k, we must use the warp_idx == self.ab_load_warp_id + # since that's the warp that does the tensormap update. + is_tma_warp = warp_idx == self.ab_load_warp_id + ( + (k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0 + ) + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp) + smem_idx = ab_producer_state.index + # A bit faster to load B first while we calculate the indices for A + if is_tma_warp: + tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state) + copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr) + copy_A(k_tile, smem_idx, *prefetch_out) + # This tells mbarrier to track the completion of cp.async + ab_pipeline.producer_cpasync_commit(ab_producer_state) + ab_producer_state.advance() + peek_ab_empty_status = Boolean(True) + if k_tile + 1 < k_tile_cnt: + peek_ab_empty_status = ab_pipeline.producer_try_acquire(ab_producer_state) + # bound checking in the K dimension on the last k_tile + if 0 < k_tile_cnt: + k_tile = k_tile_cnt - 1 + prefetch_out = () + if const_expr(prefetch_A is not None): # Prefetch early, even before smem is free + prefetch_out = (prefetch_A(k_tile, pred=True),) + is_tma_warp = warp_idx == self.ab_load_warp_id + ( + (k_tile % self.num_ab_load_warps) if const_expr(varlen_m) else 0 + ) + ab_pipeline.producer_acquire(ab_producer_state, peek_ab_empty_status, is_tma_warp) + smem_idx = ab_producer_state.index + if is_tma_warp: + tma_bar_ptr = ab_pipeline.producer_get_barrier(ab_producer_state) + copy_B(k_tile, smem_idx, tma_bar_ptr=tma_bar_ptr) + copy_A(k_tile, smem_idx, *prefetch_out, pred=True) + ab_pipeline.producer_cpasync_commit(ab_producer_state) + ab_producer_state.advance() + return ab_producer_state + + @cute.jit + def mma( + self, + ab_pipeline: cutlass.pipeline.PipelineAsync, + ab_read_state: cutlass.pipeline.PipelineState, + tiled_mma: cute.TiledMma, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + acc: cute.Tensor, + acc_slow: Optional[cute.Tensor], + k_tile_cnt: Int32, + warp_group_idx: Int32, + ) -> Tuple[cutlass.pipeline.PipelineState, cute.TiledMma]: + # ///////////////////////////////////////////////////////////////////////////// + # Prologue MMAs + # ///////////////////////////////////////////////////////////////////////////// + k_pipe_mmas = 1 + ab_release_state = ab_read_state.clone() + num_prologue_mma = min(k_pipe_mmas, k_tile_cnt) + if const_expr(self.pingpong): + self.pingpong_barrier_sync(warp_group_idx, stage="mma") + peek_ab_full_status = Boolean(True) + if 0 < k_tile_cnt: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state) + tiled_mma.set(warpgroup.Field.ACCUMULATE, False) + num_k_blocks = cute.size(tCrA, mode=[2]) + for k_tile in cutlass.range(num_prologue_mma): + # Wait for A/B buffer to be ready + ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status) + warpgroup.fence() + for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True): + 
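+ # Issue one WGMMA per K sub-block of the current smem stage, accumulating into acc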
k_blk_coord = (None, None, k_blk_idx, ab_read_state.index) + cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc) + tiled_mma.set(warpgroup.Field.ACCUMULATE, True) + warpgroup.commit_group() + ab_read_state.advance() + peek_ab_full_status = Boolean(True) + if k_tile + 1 < k_tile_cnt: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state) + # If k_tile_cnt == 0, this is not correct. But we will set acc to 0 in the mainloop + # in that case. + if const_expr(self.fp8_slow_accum): + warpgroup.wait_group(0) + acc_slow.store(acc.load()) + + # ///////////////////////////////////////////////////////////////////////////// + # MAINLOOP + # ///////////////////////////////////////////////////////////////////////////// + for k_tile in cutlass.range(num_prologue_mma, k_tile_cnt, unroll=1): + # Wait for TMA copies to complete + ab_pipeline.consumer_wait(ab_read_state, peek_ab_full_status) + # WGMMA + warpgroup.fence() + if const_expr(self.fp8_slow_accum): + tiled_mma.set(warpgroup.Field.ACCUMULATE, False) + for k_blk_idx in cutlass.range(num_k_blocks, unroll_full=True): + k_blk_coord = (None, None, k_blk_idx, ab_read_state.index) + cute.gemm(tiled_mma, acc, tCrA[k_blk_coord], tCrB[k_blk_coord], acc) + tiled_mma.set(warpgroup.Field.ACCUMULATE, True) + warpgroup.commit_group() + # Wait on the wgmma barrier for previous k_pipe_mmas wgmmas to complete + if const_expr(not self.fp8_slow_accum): + warpgroup.wait_group(k_pipe_mmas) + else: + warpgroup.wait_group(0) + acc_slow.store(acc_slow.load() + acc.load()) + ab_pipeline.consumer_release(ab_release_state) + ab_read_state.advance() + ab_release_state.advance() + peek_ab_full_status = Boolean(True) + if k_tile + 1 < k_tile_cnt: + peek_ab_full_status = ab_pipeline.consumer_try_wait(ab_read_state) + if const_expr(self.pingpong): + # Cue for next WG's MMA to start + self.pingpong_barrier_arrive(1 - warp_group_idx, stage="mma") + if const_expr(not self.fp8_slow_accum): + # fp8_slow_accum would already called wait_group(0) inside the loop + warpgroup.wait_group(0) + for k_tile in cutlass.range(num_prologue_mma, unroll=1): + ab_pipeline.consumer_release(ab_release_state) + ab_release_state.advance() + if const_expr(self.fp8_slow_accum): + acc.store(acc_slow.load()) + # If we don't return the tiled_mma, we get compiler error + # "operand #0 does not dominate this use" + return ab_read_state, tiled_mma + + @cute.jit + def epilogue( + self, + params: EpilogueParams, + epi_smem_tensors: Tuple[cute.Tensor, ...], + tma_desc_epi_ptrs: list[Optional[cute.Pointer]], + epi_pipeline: cutlass.pipeline.PipelineAsync, + epi_store_pipeline: cutlass.pipeline.PipelineAsync, + epi_read_state: cutlass.pipeline.PipelineState, + epi_producer_state: Optional[cutlass.pipeline.PipelineState], + epi_tile: cute.Tile, + load_acc_subtile: Callable, + tRS_rD: cute.Tensor, + tRS_rC: Optional[cute.Tensor], + tiled_copy_t2r: Optional[cute.TiledCopy], # Only for Sm100 + tiled_copy_r2s: cute.TiledCopy, + tRS_sD: cute.Tensor, + tiled_copy_s2r: Optional[cute.ThrCopy], + tSR_rC: Optional[cute.Tensor], + tSR_sC: Optional[cute.Tensor], + copy_D: Optional[Callable], + copy_C: Optional[Callable], + tile_coord_mnkl: cute.Coord, + varlen_manager: VarlenManager, + epilogue_barrier: cutlass.pipeline.NamedBarrier, + tile_scheduler, + tidx: Int32, + is_tma_warp: Boolean, + ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]: + has_C = const_expr(tRS_rC is not None) + has_D = const_expr(copy_D is not None) + epi_tile_shape = cute.zipped_divide( + 
cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile + ).shape[1] + # We iterate over epi tiles in the N dimension first before the M dimension + epi_tile_layout = cute.make_ordered_layout(epi_tile_shape, order=(1, 0)) + epi_tile_num = cute.size(epi_tile_shape) + num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num + + epi_tensors = self.epi_begin( + params, + epi_smem_tensors, + epi_tile, + tiled_copy_t2r, + tiled_copy_r2s, + tile_coord_mnkl, + varlen_manager, + epilogue_barrier, + tidx, + ) + + if const_expr(copy_C is not None): + for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1): + gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx) + if is_tma_warp: + epi_pipeline.producer_acquire(epi_producer_state) + copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state) + epi_pipeline.producer_commit(epi_producer_state) + epi_producer_state.advance() + + def tma_store_fn(src_idx, dst_idx): + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + epilogue_barrier.arrive_and_wait() + # Copy from shared memory to global memory + if is_tma_warp: + if const_expr(has_D): + copy_D(src_idx=src_idx, dst_idx=dst_idx) + # Can't use if statement here, epi_store_pipeline object isn't captured somehow + if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit()) + if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire()) + epilogue_barrier.arrive_and_wait() + + # We could delay the TMA store by 1 epi tile to better overlap the non-TMA ops + # with the TMA store. However, currently this doesn't seem to improve perf. + delay_tma_store = False + + src_idx_prev, dst_idx_prev = None, None + for epi_idx in cutlass.range_constexpr(epi_tile_num): + # The global memory coordinate for the current epi tile + gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) + # Copy from acc to D registers + load_acc_subtile(tRS_rD, epi_idx) + epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord) + if const_expr(has_C): + epi_pipeline.consumer_wait(epi_read_state) + cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC) + # Fence to make sure shared memory read is visible to TMA load + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + cute.arch.sync_warp() + with cute.arch.elect_one(): + epi_pipeline.consumer_release(epi_read_state) + epi_read_state.advance() + if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num): + gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage) + if is_tma_warp: + epi_pipeline.producer_acquire(epi_producer_state) + copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state) + epi_pipeline.producer_commit(epi_producer_state) + epi_producer_state.advance() + tRS_rEpi = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC) + epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage + if const_expr(delay_tma_store): + if const_expr(epi_idx > 0): + tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev) + src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord + # Copy from D registers to shared memory + if const_expr(has_D): + copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer]) + if const_expr(not delay_tma_store): + tma_store_fn(src_idx=epi_buffer, dst_idx=gmem_coord) + + if const_expr(delay_tma_store): + 
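+ # Flush the final epi tile whose TMA store was deferred by one iteration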
tma_store_fn(src_idx=src_idx_prev, dst_idx=dst_idx_prev) + + self.epi_end( + params, + epi_tensors, + epi_tile, + tiled_copy_t2r, + tiled_copy_r2s, + tile_coord_mnkl, + varlen_manager, + tidx, + ) + + return epi_read_state, epi_producer_state + + def get_scheduler_class(self, varlen_m: bool = False): + """Return the scheduler class to use. Override in subclasses for custom schedulers.""" + return TileScheduler if not varlen_m else VarlenMTileScheduler + + def get_scheduler_arguments( + self, + mA: cute.Tensor, + mB: cute.Tensor, + mD: Optional[cute.Tensor], + scheduler_args, + varlen_args, + ): + """Create scheduler arguments. Override in subclasses for custom schedulers.""" + if const_expr(varlen_args.mCuSeqlensM is None): + num_problems = ( + mD.shape[2] + if mD is not None + else ( + mB.shape[2] + if varlen_args.mCuSeqlensK is None + else varlen_args.mCuSeqlensK.shape[0] - 1 + ) + ) + problem_shape_ntile_mnl = ( + cute.ceil_div(mA.shape[0], self.cta_tile_shape_mnk[0]), + cute.ceil_div(mB.shape[0], self.cta_tile_shape_mnk[1]), + num_problems, + ) + tile_sched_args = TileSchedulerArguments( + problem_shape_ntile_mnl=problem_shape_ntile_mnl, + raster_order=scheduler_args.raster_order, + group_size=scheduler_args.max_swizzle_size, + cluster_shape_mnk=self.cluster_shape_mnk, + tile_count_semaphore=scheduler_args.tile_count_semaphore, + batch_idx_permute=scheduler_args.batch_idx_permute, + is_persistent=self.is_persistent, + ) + else: + assert mD is not None or not self.gather_A + problem_shape_ntile_mnl = ( + None, + cute.ceil_div(mB.shape[0], self.cta_tile_shape_mnk[1]), + varlen_args.mCuSeqlensM.shape[0] - 1, + ) + tile_sched_args = VarlenMTileSchedulerArguments( + problem_shape_ntile_mnl=problem_shape_ntile_mnl, + total_m=mD.shape[0] if mD is not None else varlen_args.mAIdx.shape[0], + cu_seqlens_m=varlen_args.mCuSeqlensM, + raster_order=scheduler_args.raster_order, + group_size=scheduler_args.max_swizzle_size, + tile_shape_mn=self.cta_tile_shape_mnk[:2], + cluster_shape_mnk=self.cluster_shape_mnk, + tile_count_semaphore=scheduler_args.tile_count_semaphore, + is_persistent=self.is_persistent, + ) + return tile_sched_args + + @cute.jit + def epi_load_acc_subtile(self, tRS_rAcc: cute.Tensor, tRS_rD: cute.Tensor, epi_idx: int): + for epi_v in cutlass.range_constexpr(cute.size(tRS_rD)): + tRS_rD[epi_v] = tRS_rAcc[epi_idx * cute.size(tRS_rD) + epi_v] + + @cute.jit + def epi_begin( + self, + params: EpilogueParams, + epi_smem_tensors: Tuple[cute.Tensor, ...], + epi_tile: cute.Tile, + tiled_copy_t2r: Optional[cute.TiledCopy], + tiled_copy_r2s: cute.TiledCopy, + tile_coord_mnkl: cute.Coord, + varlen_manager: VarlenManager, + epilogue_barrier: cutlass.pipeline.NamedBarrier, + tidx: Int32, + ) -> Tuple[cute.Tensor, ...]: + return () + + def epi_begin_loop( + self, params: EpilogueParams, epi_tensors: Tuple[cute.Tensor, ...], epi_coord: cute.Coord + ) -> Tuple[cute.Tensor, ...]: + return () + + def epi_visit_subtile( + self, + params: EpilogueParams, + epi_loop_tensors: Tuple[cute.Tensor, ...], + tRS_rD: cute.Tensor, + tRS_rC: Optional[cute.Tensor] = None, + ) -> Optional[cute.Tensor]: + return None + + def epi_visit_acc( + self, + params: EpilogueParams, + acc: cute.Tensor, + tiled_mma: cute.TiledMma, + tile_coord_mnkl: cute.Coord, + tidx: Int32, + ) -> None: + pass + + @cute.jit + def epi_end( + self, + params: EpilogueParams, + epi_tensors: Tuple[cute.Tensor, ...], + epi_tile: cute.Tile, + tiled_copy_t2r: Optional[cute.TiledCopy], + tiled_copy_r2s: cute.TiledCopy, + tile_coord_mnkl: cute.Coord, + 
varlen_manager, + tidx, + ) -> None: + pass + + def epi_to_underlying_arguments( + self, args: EpilogueArguments, *, loc=None, ip=None + ) -> EpilogueParams: + return self.EpilogueParams() + + def epi_get_tma_atoms( + self, params: EpilogueParams, *, loc=None, ip=None + ) -> list[cute.CopyAtom]: + """Subclasses can override this""" + return [] + + def epi_get_tensormap_update_shapes_orders( + self, + params: EpilogueParams, + cu_seqlens_m: cute.Tensor, + batch_idx: Int32, + *, + loc=None, + ip=None, + ) -> tuple[list[Int32], list[int]]: + """Subclasses can override this""" + return [], [] + + @staticmethod + def epi_smem_bytes_per_stage( + args: Optional[EpilogueArguments], + cta_tile_shape_mnk: Tuple[int, int, int], + epi_tile: cute.Tile, + ) -> int: + return 0 + + def epi_get_smem_struct(self, params: EpilogueParams): + return cute.struct.MemRange[Int32, 0] # Dummy struct + + def epi_get_smem_tensors(self, params: EpilogueParams, storage) -> Tuple[cute.Tensor, ...]: + return tuple() + + def pingpong_barrier_sync(self, warp_group_idx: Int32, stage: Literal["mma", "epi"]): + assert stage in ["mma", "epi"] + barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0 + cute.arch.barrier( + barrier_id=int(barrier) + warp_group_idx, + number_of_threads=2 * self.num_threads_per_warp_group, + ) + + def pingpong_barrier_arrive(self, warp_group_idx: Int32, stage: Literal["mma", "epi"]): + assert stage in ["mma", "epi"] + barrier = NamedBarrierGemm.MmaWG0 if stage == "mma" else NamedBarrierGemm.EpiWG0 + cute.arch.barrier_arrive( + barrier_id=int(barrier) + warp_group_idx, + number_of_threads=2 * self.num_threads_per_warp_group, + ) + + def epilog_smem_copy_atom(self, tiled_mma: cute.TiledMma) -> cute.TiledCopy: + copy_atom_C = cute.make_copy_atom( + warp.StMatrix8x8x16bOp( + self.d_layout.is_m_major_c() if self.d_layout is not None else False, + num_matrices=4 if self.epi_tile[1] % 16 == 0 else 2, + ), + Float16, # this is just to get the right source layout + ) + tiled_copy_C_atom = cute.make_tiled_copy_C_atom(copy_atom_C, tiled_mma) + return tiled_copy_C_atom + + def epilog_smem_store_and_partition( + self, + tiled_mma: cute.TiledMma, + d_layout: Optional[LayoutEnum], + dtype: Type[cutlass.Numeric], + sD: Optional[cute.Tensor], + tidx: Int32, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + if d_layout is None: + d_layout = LayoutEnum.ROW_MAJOR + tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma) + # Doesn't work with tile_N % 8 == 0 but tile_n % 16 != since this always + # get st.matrix with num_matrices=4 + copy_atom_r2s = sm90_utils.sm90_get_smem_store_op( + d_layout, elem_ty_d=dtype, elem_ty_acc=self.acc_dtype + ) + tiled_copy_r2s = cute.make_tiled_copy_S(copy_atom_r2s, tiled_copy_C_atom) + # (R2S, R2S_M, R2S_N, PIPE_D) + thr_copy_r2s = tiled_copy_r2s.get_slice(tidx) + tRS_sD = thr_copy_r2s.partition_D(sD) if sD is not None else None + sD_shape = sD.shape[:2] if sD is not None else self.epi_tile + tRS_rD_shape = thr_copy_r2s.partition_S(cute.make_identity_tensor(sD_shape)).shape + tRS_rD = cute.make_fragment(tRS_rD_shape, self.acc_dtype) + return tiled_copy_r2s, tRS_rD, tRS_sD + + def epilog_smem_load_and_partition( + self, + tiled_mma: cute.TiledMma, + c_layout: LayoutEnum, + dtype: Type[cutlass.Numeric], + sC: cute.Tensor, + tRS_rD_layout: cutlass.Layout, + tidx: Int32, + ) -> Tuple[cute.TiledCopy, cute.Tensor, cute.Tensor]: + tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma) + copy_atom_s2r = copy_utils.sm90_get_smem_load_op(c_layout, dtype) + 
tiled_copy_s2r = cute.make_tiled_copy_S(copy_atom_s2r, tiled_copy_C_atom) + thr_copy_s2r = tiled_copy_s2r.get_slice(tidx) + tSR_sC = thr_copy_s2r.partition_S(sC) + tRS_rC = cute.make_fragment(tRS_rD_layout, dtype) + tSR_rC = thr_copy_s2r.retile(tRS_rC) + return tiled_copy_s2r, tRS_rC, tSR_rC, tSR_sC + + def epilog_gmem_copy_and_partition( + self, + atom: Union[cute.CopyAtom, cute.TiledCopy], + mD_mn: cute.Tensor, + tile_shape_mn: cute.Tile, + epi_tile: cute.Tile, + sD: cute.Tensor, + tile_coord_mnkl: cute.Coord, + tma_desc_ptr: Optional[cute.Pointer] = None, + ) -> Tuple[cute.Tensor, cute.Tensor]: + # (bM, bN) + gD = cute.local_tile(mD_mn, tile_shape_mn, tile_coord_mnkl[:2]) + tDgD_for_tma_partition = cute.zipped_divide(gD, epi_tile) + is_s2g = isinstance( + atom.op, (cpasync.CopyBulkTensorTileS2GOp, cpasync.CopyReduceBulkTensorTileS2GOp) + ) + src_tensor, dst_tensor = ( + (sD, tDgD_for_tma_partition) if is_s2g else (tDgD_for_tma_partition, sD) + ) + return copy_utils.tma_get_copy_fn( + atom, + cta_coord=0, + cta_layout=cute.make_layout(1), + src_tensor=src_tensor, + dst_tensor=dst_tensor, + tma_desc_ptr=tma_desc_ptr, + ) + + def make_ab_pipeline( + self, + tiled_mma: cute.TiledMma, + cluster_layout_vmnk: cute.Layout, + ab_pipeline_mbar_ptr: cute.Pointer, + ): + # Threads/warps participating in this pipeline + producer_cnt = 1 if const_expr(not self.gather_A) else 1 + self.num_ab_load_warps * 32 + ab_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, producer_cnt) + # Each warp will contribute to the arrive count with the number of mcast size + mcast_size = self.num_mcast_ctas_a + self.num_mcast_ctas_b - 1 + consumer_arrive_cnt = mcast_size * tiled_mma.size // cute.arch.WARP_SIZE + ab_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, consumer_arrive_cnt + ) + pipeline_cls = pipeline.PipelineTmaAsync if not self.gather_A else PipelineTmaCpAsync + return pipeline_cls.create( + barrier_storage=ab_pipeline_mbar_ptr, + num_stages=self.ab_stage, + producer_group=ab_pipeline_producer_group, + consumer_group=ab_pipeline_consumer_group, + tx_count=self.num_tma_load_bytes, + cta_layout_vmnk=cluster_layout_vmnk, + ) + + def make_epi_pipeline( + self, c_smem_layout: cute.Layout | cute.ComposedLayout, epi_pipeline_mbar_ptr: cute.Pointer + ): + # Threads/warps participating in this pipeline + epi_pipeline_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread) + # Each warp will contribute 1 to the arrive count + consumer_arrive_cnt = self.num_epi_warps + epi_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, consumer_arrive_cnt + ) + tma_copy_c_bytes = cute.size_in_bytes(self.c_dtype, c_smem_layout) + return pipeline.PipelineTmaAsync.create( + barrier_storage=epi_pipeline_mbar_ptr, + num_stages=self.epi_c_stage, + producer_group=epi_pipeline_producer_group, + consumer_group=epi_pipeline_consumer_group, + tx_count=tma_copy_c_bytes, + ) + + def make_epi_store_pipeline(self): + # Threads/warps participating in tma store pipeline + num_epi_threads = self.num_epi_warps * cute.arch.WARP_SIZE + epi_store_producer_group = pipeline.CooperativeGroup(pipeline.Agent.Thread, num_epi_threads) + return pipeline.PipelineTmaStore.create( + num_stages=self.epi_stage, producer_group=epi_store_producer_group + ) + + def make_sched_pipeline( + self, cluster_layout_mnk: cute.Layout, sched_pipeline_mbar_ptr: cute.Pointer, varlen_k: bool + ): + # Threads/warps participating in this pipeline + sched_pipeline_producer_group = 
pipeline.CooperativeGroup(pipeline.Agent.Thread) + cluster_size = cute.size(cluster_layout_mnk) + # Each warp that is not the scheduler warp will contribute 1 to the arrive count + # If pingpong and varlen_k, then all 8 mma warps will participate in the scheduler barrier + # at each round. If pingpong and not varlen_k, then only 4 mma warps will participate. + consumer_arrive_cnt = ( + (self.mma_warp_groups if not (self.pingpong and not varlen_k) else 1) * 4 + + self.num_ab_load_warps + ) * cluster_size - 1 + sched_pipeline_consumer_group = pipeline.CooperativeGroup( + pipeline.Agent.Thread, consumer_arrive_cnt + ) + return pipeline.PipelineAsync.create( + barrier_storage=sched_pipeline_mbar_ptr, + num_stages=self.sched_stage, + producer_group=sched_pipeline_producer_group, + consumer_group=sched_pipeline_consumer_group, + # If there's a cluster, the consumers must arrive at the mbar of CTA 0 in the cluster. + consumer_mask=None if const_expr(cluster_size == 1) else 0, + ) + + @classmethod + def _compute_stages( + cls, + cta_tile_shape_mnk: Tuple[int, int, int], + epi_tile: Tuple[int, int], + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + d_dtype: Optional[Type[cutlass.Numeric]], + c_dtype: Optional[Type[cutlass.Numeric]], + epilogue_args: EpilogueArguments, + smem_capacity: int, + occupancy: int, + overlap_sD_sA: bool = False, + ) -> Tuple[int, int, int]: + """Computes the number of stages for the A/B operands and the D/C epilogue based on heuristics. + + :param cta_tile_shape_mnk: The shape (M, N, K) of the CTA tile. + :type cta_tile_shape_mnk: Tuple[int, int, int] + :param a_dtype: Data type of operand A. + :type a_dtype: type[cutlass.Numeric] + :param b_dtype: Data type of operand B. + :type b_dtype: type[cutlass.Numeric] + :param smem_capacity: Total available shared memory capacity in bytes. + :type smem_capacity: int + :param occupancy: Target number of CTAs per SM (occupancy).
+ :type occupancy: int + + :return: A tuple containing the computed number of stages for: + (A/B operand stages, D epilogue stages, C epilogue stages) + :rtype: Tuple[int, int, int] + """ + + epi_stage = 4 if epi_tile[1] <= 16 else 2 + if overlap_sD_sA: + epi_bytes = 0 + else: + d_bytes_per_stage = ( + cute.size(epi_tile) * d_dtype.width // 8 if d_dtype is not None else 0 + ) + epi_bytes_per_stage = d_bytes_per_stage + cls.epi_smem_bytes_per_stage( + epilogue_args, cta_tile_shape_mnk, epi_tile + ) + epi_bytes = epi_bytes_per_stage * epi_stage + epi_c_stage = 0 if c_dtype is None else (4 if epi_tile[1] <= 16 else 2) + if c_dtype is not None: + epi_bytes += cute.size(epi_tile) * c_dtype.width // 8 * epi_c_stage + + a_shape = cute.slice_(cta_tile_shape_mnk, (None, 0, None)) + b_shape = cute.slice_(cta_tile_shape_mnk, (0, None, None)) + ab_bytes_per_stage = ( + cute.size(a_shape) * a_dtype.width // 8 + cute.size(b_shape) * b_dtype.width // 8 + ) + mbar_helpers_bytes = 1024 + + remaining_bytes = smem_capacity // occupancy - mbar_helpers_bytes - epi_bytes + ab_stage = remaining_bytes // ab_bytes_per_stage + + # Refine epilogue stages: + # Calculate remaining smem after allocating for A/B stages and reserved bytes + # Add remaining unused smem to epilogue + if not overlap_sD_sA and epi_bytes_per_stage > 0: + epi_stage += (remaining_bytes - ab_bytes_per_stage * ab_stage) // epi_bytes_per_stage + return ab_stage, epi_stage, epi_c_stage + + @staticmethod + def _sm90_compute_tile_shape_or_override( + cta_tile_shape_mnk: Tuple[int, int, int], + atom_layout_mnk: Tuple[int, int, int], + element_type: Optional[Type[cutlass.Numeric]] = None, + epi_tile_override: Tuple[int, int] | None = None, + ) -> Tuple[int, int]: + """Compute the epilogue tile shape or use override if provided. + + :param cta_tile_shape_mnk: CTA tile shape (M,N,K) + :type cta_tile_shape_mnk: Tuple[int, int, int] + :param atom_layout_mnk: Atom layout (M, N, K) of the tiled MMA + :type atom_layout_mnk: Tuple[int, int, int] + :param element_type: Data type of elements + :type element_type: type[cutlass.Numeric] + :param epi_tile_override: Optional override for epilogue tile shape + :type epi_tile_override: Tuple[int, int] or None + + :return: Computed epilogue tile shape + :rtype: Tuple[int, int] + """ + if epi_tile_override is not None: + return epi_tile_override + if cta_tile_shape_mnk[0] % 128 == 0 and atom_layout_mnk[0] > 1: + tile_m = math.gcd(128, cute.size(cta_tile_shape_mnk, mode=[0])) + tile_n = math.gcd(32, cute.size(cta_tile_shape_mnk, mode=[1])) + elif cta_tile_shape_mnk[0] % 192 == 0 and atom_layout_mnk[0] > 1: + tile_m = math.gcd(192, cute.size(cta_tile_shape_mnk, mode=[0])) + tile_n = math.gcd(32, cute.size(cta_tile_shape_mnk, mode=[1])) + else: + # In the case of tile shape 128 x N but atom_layout 1 x 2, we need to set + # epi_tile_m = 64. If epi_tile_m = 128, the epilogue would iterate along the + # M dimension first, then move to the N dimension. But the accumulator in registers + # iterates along the N dimension first, then moves to the M dimension. + # We could change the epilogue to accommodate this, + # but it's easier to just set epi_tile_m = 64.
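+ # For example, a (128, 192) CTA tile with atom_layout (1, 2, 1) and a 16-bit element_type resolves to a (64, 32) epi tile here.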
+ n_perf = 64 if element_type is not None and element_type.width == 8 else 32 + tile_m = math.gcd(64, cute.size(cta_tile_shape_mnk, mode=[0])) + tile_n = math.gcd(n_perf, cute.size(cta_tile_shape_mnk, mode=[1])) + return (tile_m, tile_n) + + @staticmethod + def _make_smem_layouts( + cta_tile_shape_mnk: Tuple[int, int, int], + epi_tile: Tuple[int, int], + a_dtype: Type[cutlass.Numeric], + a_layout: LayoutEnum, + b_dtype: Type[cutlass.Numeric], + b_layout: LayoutEnum, + ab_stage: int, + d_dtype: Optional[Type[cutlass.Numeric]], + d_layout: LayoutEnum, + epi_stage: int, + c_dtype: Optional[Type[cutlass.Numeric]], + c_layout: Optional[LayoutEnum], + epi_c_stage: int, + ) -> Tuple[ + cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout, Optional[cute.ComposedLayout] + ]: + """Create shared memory layouts for A, B, and C tensors. + + :param cta_tile_shape_mnk: CTA tile shape (M,N,K) + :type cta_tile_shape_mnk: Tuple[int, int, int] + :param epi_tile: Epilogue tile shape + :type epi_tile: Tuple[int, int] + :param a_dtype: Data type for matrix A + :type a_dtype: type[cutlass.Numeric] + :param a_layout: Layout enum for matrix A + :type a_layout: LayoutEnum + :param b_dtype: Data type for matrix B + :type b_dtype: type[cutlass.Numeric] + :param b_layout: Layout enum for matrix B + :type b_layout: LayoutEnum + :param ab_stage: Number of stages for A/B tensors + :type ab_stage: int + :param d_dtype: Data type for output matrix D + :type d_dtype: type[cutlass.Numeric] + :param d_layout: Layout enum for the output matrix C + :type d_layout: LayoutEnum + :param epi_stage: Number of epilogue stages + :type epi_stage: int + + :return: Tuple of shared memory layouts for A, B, and C + :rtype: Tuple[cute.ComposedLayout, cute.ComposedLayout, cute.ComposedLayout] + """ + a_smem_shape = cute.slice_(cta_tile_shape_mnk, (None, 0, None)) + + a_is_k_major = a_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K + b_is_k_major = b_layout.sm90_mma_major_mode() == warpgroup.OperandMajorMode.K + a_major_mode_size = cta_tile_shape_mnk[2 if a_is_k_major else 0] + a_smem_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils.get_smem_layout_atom(a_layout, a_dtype, a_major_mode_size), + a_dtype, + ) + a_smem_layout_staged = cute.tile_to_shape( + a_smem_layout_atom, + cute.append(a_smem_shape, ab_stage), + order=(0, 1, 2) if a_is_k_major else (1, 0, 2), + ) + + b_smem_shape = cute.slice_(cta_tile_shape_mnk, (0, None, None)) + + b_major_mode_size = cta_tile_shape_mnk[2 if b_is_k_major else 1] + b_smem_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils.get_smem_layout_atom(b_layout, b_dtype, b_major_mode_size), + b_dtype, + ) + b_smem_layout_staged = cute.tile_to_shape( + b_smem_layout_atom, + cute.append(b_smem_shape, ab_stage), + order=(0, 1, 2) if b_is_k_major else (1, 0, 2), + ) + + epi_smem_layout_staged = None + if d_dtype is not None: + epi_smem_layout_staged = quack_sm90_utils.make_smem_layout_epi( + d_dtype, d_layout, epi_tile, epi_stage + ) + + epi_c_smem_layout_staged = None + if c_dtype is not None: + assert c_layout is not None + epi_c_smem_layout_staged = quack_sm90_utils.make_smem_layout_epi( + c_dtype, c_layout, epi_tile, epi_c_stage + ) + + return ( + a_smem_layout_staged, + b_smem_layout_staged, + epi_smem_layout_staged, + epi_c_smem_layout_staged, + ) + + @staticmethod + def _make_tma_epi_atoms_and_tensors( + tensor_d: cute.Tensor, + epi_smem_layout_staged: cute.ComposedLayout, + epi_tile: Tuple[int, int], + op_type: Literal["store", "load", "add"], + ) -> Tuple[cute.CopyAtom, 
cute.Tensor]: + """Create TMA atoms and tensors for storing D or loading C. + + :param tensor_d: Output tensor D + :type tensor_d: cute.Tensor + :param epi_smem_layout_staged: Shared memory layout for epilogue + :type epi_smem_layout_staged: cute.ComposedLayout + :param epi_tile: Epilogue tile shape + :type epi_tile: Tuple[int, int] + + :return: TMA atom and tensor for C + :rtype: Tuple[cute.CopyAtom, cute.Tensor] + """ + assert op_type in ["load", "store", "add"] + epi_smem_layout = cute.slice_(epi_smem_layout_staged, (None, None, 0)) + d_cta_v_layout = cute.composition(cute.make_identity_layout(tensor_d.shape), epi_tile) + op = ( + cpasync.CopyBulkTensorTileG2SOp() + if op_type == "load" + else cpasync.CopyBulkTensorTileS2GOp() + if op_type == "store" + else cpasync.CopyReduceBulkTensorTileS2GOp(cute.ReductionOp.ADD) + ) + tma_atom_d, tma_tensor_d = cpasync.make_tiled_tma_atom( + op, tensor_d, epi_smem_layout, d_cta_v_layout + ) + return tma_atom_d, tma_tensor_d + + @staticmethod + def _make_tma_atoms_and_tensors( + tensor: cute.Tensor, + smem_layout: cute.ComposedLayout, + smem_tile: Tuple[int, int], + mcast_dim: int, + ) -> Tuple[cute.CopyAtom, cute.Tensor]: + """Create TMA atoms and tensors for input tensors. + + :param tensor: Input tensor (A or B) + :type tensor: cute.Tensor + :param smem_layout: Shared memory layout for the tensor + :type smem_layout: cute.ComposedLayout + :param smem_tile: Shared memory tile shape + :type smem_tile: Tuple[int, int] + :param mcast_dim: Multicast dimension + :type mcast_dim: int + + :return: TMA atom and tensor + :rtype: Tuple[cute.CopyAtom, cute.Tensor] + """ + op = ( + cpasync.CopyBulkTensorTileG2SOp() + if mcast_dim == 1 + else cpasync.CopyBulkTensorTileG2SMulticastOp() + ) + tma_atom, tma_tensor = cpasync.make_tiled_tma_atom( + op, + tensor, + smem_layout, + smem_tile, + num_multicast=mcast_dim, + ) + return tma_atom, tma_tensor + + def _make_gmem_tiled_copy_A(self, dtype, major_mode, num_threads, copy_bits=128): + atom_async_copy = cute.make_copy_atom( + cpasync.CopyG2SOp(cache_mode=cpasync.LoadCacheMode.GLOBAL), + dtype, + num_bits_per_copy=copy_bits, + ) + copy_elems = copy_bits // dtype.width + loads_per_cache_line = 128 * 8 // copy_bits # 128 bytes per cache line + shape_dim_1 = cute.size(self.cta_tile_shape_mnk[2]) // copy_elems + if shape_dim_1 > loads_per_cache_line: + shape_dim_1 = math.gcd(shape_dim_1, loads_per_cache_line) + # thread layout for copy + thread_layout = cute.make_layout( + (num_threads // shape_dim_1, shape_dim_1), stride=(shape_dim_1, 1) + ) + if major_mode != LayoutEnum.ROW_MAJOR: + shape_dim_0 = cute.size(self.cta_tile_shape_mnk[0]) // copy_elems + if shape_dim_0 > loads_per_cache_line: + shape_dim_0 = math.gcd(shape_dim_0, loads_per_cache_line) + thread_layout = cute.make_layout( + (shape_dim_0, num_threads // shape_dim_0), stride=(1, shape_dim_0) + ) + # Value layout for copy + value_layout = ( + cute.make_layout((1, copy_elems)) + if major_mode == LayoutEnum.ROW_MAJOR + else cute.make_layout((copy_elems, 1)) + ) + return cute.make_tiled_copy_tv(atom_async_copy, thread_layout, value_layout) + + @staticmethod + def is_valid_dtypes( + a_dtype: Type[cutlass.Numeric], + b_dtype: Type[cutlass.Numeric], + acc_dtype: Type[cutlass.Numeric], + d_dtype: Optional[Type[cutlass.Numeric]], + a_major: str, + b_major: str, + ) -> bool: + """ + Check if the dtypes are valid + + :param a_dtype: The data type of tensor A + :type a_dtype: Type[cutlass.Numeric] + :param b_dtype: The data type of tensor B + :type b_dtype: 
Type[cutlass.Numeric] + :param acc_dtype: The data type of the accumulator + :type acc_dtype: Type[cutlass.Numeric] + :param d_dtype: The data type of the output tensor + :type d_dtype: Type[cutlass.Numeric] + :param a_major: major mode of tensor A + :type a_major: str + :param b_major: major mode of tensor B + :type b_major: str + + :return: True if the dtypes are valid, False otherwise + :rtype: bool + """ + is_valid = True + if a_dtype not in { + Float16, + cutlass.BFloat16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + # tested b_dtype + if b_dtype not in { + Float16, + cutlass.BFloat16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + if acc_dtype not in {Float32, Float16}: + is_valid = False + # tested d_dtype + if d_dtype not in { + None, + Float32, + Float16, + cutlass.BFloat16, + cutlass.Float8E4M3FN, + cutlass.Float8E5M2, + }: + is_valid = False + # make sure a_dtype == b_dtype for Float16 + if a_dtype.width == 16 and a_dtype != b_dtype: + is_valid = False + # make sure a_dtype.width == b_dtype.width (i.e, Float8E4M3FN or Float8E5M2) + if a_dtype.width != b_dtype.width: + is_valid = False + + # for Float8 types, this implementation only supports k-major layout + if (a_dtype.width == 8 and a_major != "k") or (b_dtype.width == 8 and b_major != "k"): + is_valid = False + return is_valid diff --git a/sonic-moe/torch-ext/sonicmoe/quack/gemm_symmetric.py b/sonic-moe/torch-ext/sonicmoe/quack/gemm_symmetric.py new file mode 100644 index 00000000..99348d0b --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/gemm_symmetric.py @@ -0,0 +1,330 @@ +from typing import Tuple, Optional, Callable +from functools import partial +from torch import Tensor +from .gemm_act import GemmActMixin, act_fn_map, gemm_act +from .gemm_sm90 import GemmSm90 +from .gemm_sm100 import GemmSm100 +from .tile_scheduler import TriangularTileScheduler +from .gemm_wrapper_utils import GemmWrapperBase +from .cute_dsl_utils import get_device_capacity, get_max_active_clusters +from .varlen_utils import VarlenManager +from . 
import copy_utils as copy_utils +import cutlass +import cutlass.cute as cute +import cutlass.torch as cutlass_torch +from cutlass.cute.runtime import make_ptr +from cutlass import Int32, Float32, Boolean, const_expr +import cutlass.utils.hopper_helpers as sm90_utils_og +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cutlass_dsl import if_generate + + +class GemmSymmetricMixin(GemmActMixin, GemmSm90): + def get_scheduler_class(self, varlen_m: bool = False): + return TriangularTileScheduler + + @cute.jit + def epilogue( + self, + params: GemmActMixin.EpilogueParams, + epi_smem_tensors: Tuple[cute.Tensor, ...], + tma_desc_epi_ptrs: list[Optional[cute.Pointer]], + epi_pipeline: cutlass.pipeline.PipelineAsync, + epi_store_pipeline: cutlass.pipeline.PipelineAsync, + epi_read_state: cutlass.pipeline.PipelineState, + epi_producer_state: cutlass.pipeline.PipelineState, + epi_tile: cute.Tile, + load_acc_subtile: Callable, + tRS_rD: cute.Tensor, + tRS_rC: Optional[cute.Tensor], + tiled_copy_t2r: Optional[cute.TiledCopy], # Only for Sm100 + tiled_copy_r2s: cute.TiledCopy, + tRS_sD: cute.Tensor, + tiled_copy_s2r: Optional[cute.TiledCopy], + tSR_rC: Optional[cute.Tensor], + tSR_sC: Optional[cute.Tensor], + copy_D: Optional[Callable], + copy_C: Optional[Callable], + tile_coord_mnkl: cute.Coord, + varlen_manager: VarlenManager, + epilogue_barrier: cutlass.pipeline.NamedBarrier, + tile_scheduler, + tidx: Int32, + is_tma_warp: Boolean, + ) -> Tuple[cutlass.pipeline.PipelineState, cutlass.pipeline.PipelineState]: + has_C = const_expr(tRS_rC is not None) + has_D = const_expr(copy_D is not None) + + tma_atom_postact = params.tma_atom_postact + mPostAct_mnl = params.mPostAct_mnl + sRowVec, sColVec, sPostAct = epi_smem_tensors + get_smem_store_op = ( + partial(sm100_utils.get_smem_store_op, tiled_tmem_load=tiled_copy_t2r) + if self.arch == 100 + else sm90_utils_og.sm90_get_smem_store_op + ) + copy_atom_postact_r2s = get_smem_store_op( + self.postact_layout, self.postact_dtype, self.acc_dtype + ) + # tiled_copy_C_atom = self.epilog_smem_copy_atom(tiled_mma) + # tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_C_atom) + tiled_copy_postact_r2s = cute.make_tiled_copy_S(copy_atom_postact_r2s, tiled_copy_r2s) + tRS_sPostAct = tiled_copy_postact_r2s.get_slice(tidx).partition_D(sPostAct) + (tma_desc_postact_ptr,) = tma_desc_epi_ptrs + batch_idx = tile_coord_mnkl[3] + copy_postact, _, _ = self.epilog_gmem_copy_and_partition( + tma_atom_postact, + varlen_manager.offset_batch_epi(mPostAct_mnl, batch_idx), + self.cta_tile_shape_postact_mn, + params.epi_tile_postact, + sPostAct, + tile_coord_mnkl, + tma_desc_ptr=tma_desc_postact_ptr, + ) + + # We iterate over epi tiles in the N dimension first before the M dimension + epi_tile_shape = cute.zipped_divide( + cute.make_layout(self.cta_tile_shape_mnk[:2]), epi_tile + ).shape[1] + epi_tile_layout = cute.make_layout(epi_tile_shape, stride=(epi_tile_shape[1], 1)) + epi_tile_num = cute.size(epi_tile_shape) + num_prev_subtiles = tile_scheduler.num_tiles_executed * epi_tile_num + + epi_tensors = self.epi_begin( + params, + epi_smem_tensors, + epi_tile, + tiled_copy_t2r, + tiled_copy_r2s, + tile_coord_mnkl, + varlen_manager, + epilogue_barrier, + tidx, + ) + + if const_expr(copy_C is not None): + for epi_idx in cutlass.range(min(epi_tile_num, self.epi_c_stage), unroll=1): + gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx) + if is_tma_warp: + epi_pipeline.producer_acquire(epi_producer_state) + copy_C(src_idx=gmem_coord_C, 
producer_state=epi_producer_state) + epi_pipeline.producer_commit(epi_producer_state) + epi_producer_state.advance() + + def tma_store_fn(src_idx, dst_idx, tile_coord_mnkl): + pid_m = tile_coord_mnkl[0] + pid_n = tile_coord_mnkl[1] + # Fence and barrier to make sure shared memory store is visible to TMA store + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + epilogue_barrier.arrive_and_wait() + # Copy from shared memory to global memory + if is_tma_warp: + square_tile_m = pid_m // self.cluster_shape_mnk[0] + square_tile_n = pid_n // self.cluster_shape_mnk[1] + if const_expr(has_D): + copy_D(src_idx=src_idx, dst_idx=dst_idx) + if square_tile_m != square_tile_n: # don't write twice to the same tile + copy_postact(src_idx=src_idx, dst_idx=dst_idx) + # Can't use if statement here, epi_store_pipeline object isn't captured somehow + if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_commit()) + if_generate(is_tma_warp, lambda: epi_store_pipeline.producer_acquire()) + epilogue_barrier.arrive_and_wait() + + delay_tma_store = True + + src_idx_prev, dst_idx_prev = None, None + for epi_idx in cutlass.range_constexpr(epi_tile_num): + # The global memory coordinate for the current epi tile + gmem_coord = epi_tile_layout.get_hier_coord(epi_idx) + # Copy from acc to D registers + load_acc_subtile(tRS_rD, epi_idx) + epi_loop_tensors = self.epi_begin_loop(params, epi_tensors, gmem_coord) + if const_expr(has_C): + epi_pipeline.consumer_wait(epi_read_state) + cute.copy(tiled_copy_s2r, tSR_sC[None, None, None, epi_read_state.index], tSR_rC) + # Fence to make sure shared memory read is visible to TMA load + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + cute.arch.sync_warp() + with cute.arch.elect_one(): + epi_pipeline.consumer_release(epi_read_state) + epi_read_state.advance() + if const_expr(copy_C is not None and epi_idx + self.epi_c_stage < epi_tile_num): + gmem_coord_C = epi_tile_layout.get_hier_coord(epi_idx + self.epi_c_stage) + if is_tma_warp: + epi_pipeline.producer_acquire(epi_producer_state) + copy_C(src_idx=gmem_coord_C, producer_state=epi_producer_state) + epi_pipeline.producer_commit(epi_producer_state) + epi_producer_state.advance() + tRS_rPostAct = self.epi_visit_subtile(params, epi_loop_tensors, tRS_rD, tRS_rC) + epi_buffer = (num_prev_subtiles + epi_idx) % self.epi_stage + if const_expr(delay_tma_store): + if const_expr(epi_idx > 0): + tma_store_fn( + src_idx=src_idx_prev, dst_idx=dst_idx_prev, tile_coord_mnkl=tile_coord_mnkl + ) + src_idx_prev, dst_idx_prev = epi_buffer, gmem_coord + # Copy from D registers to shared memory + if const_expr(has_D): + copy_utils.cvt_copy(tiled_copy_r2s, tRS_rD, tRS_sD[None, None, None, epi_buffer]) + cute.copy( + tiled_copy_postact_r2s, + tiled_copy_postact_r2s.retile(tRS_rPostAct), + tRS_sPostAct[None, None, None, epi_buffer], + ) + if const_expr(not delay_tma_store): + tma_store_fn( + src_idx=epi_buffer, dst_idx=gmem_coord, tile_coord_mnkl=tile_coord_mnkl + ) + + if const_expr(delay_tma_store): + tma_store_fn( + src_idx=src_idx_prev, dst_idx=dst_idx_prev, tile_coord_mnkl=tile_coord_mnkl + ) + + self.epi_end( + params, + epi_tensors, + epi_tile, + tiled_copy_t2r, + tiled_copy_r2s, + tile_coord_mnkl, + varlen_manager, + tidx, + ) + + return epi_read_state, epi_producer_state + + +class GemmSymmetricSm90(GemmSymmetricMixin, GemmSm90): + pass + + +class GemmSymmetricSm100(GemmSymmetricMixin, GemmSm100): + pass + + +def gemm_symmetric( + A: 
Tensor, # (l, m, k) + B: Tensor, # (l, m, k) + D: Optional[Tensor], # (l, m, m) + C: Optional[Tensor], # (l, m, m) + tile_count_semaphore: Optional[Tensor], # (1,) + tile_M: int, + tile_N: int, + cluster_M: int, + cluster_N: int, + pingpong: bool = False, + persistent: bool = True, + max_swizzle_size: int = 8, + alpha: float | Tensor = 1.0, + beta: float | Tensor = 1.0, +) -> None: + # Tranpose D so the "activation" is a write to the mirrored tile + PostAct = D.mT + + L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors( + A, B, D, C, additional_tensors={"PostAct": PostAct} + ) + assert M == N, "M and N must be the same; symmetric gemm only supports square matrices" + GemmWrapperBase.permute_tensors(tensor_infos) + GemmWrapperBase.extract_dtypes(tensor_infos) + major_configs = { + "A": ("m", "k", "l"), + "B": ("n", "k", "l"), + "D": ("m", "n", "l"), + "C": ("m", "n", "l"), + "PostAct": ("m", "n", "l"), + } + GemmWrapperBase.determine_major_orders(tensor_infos, major_configs) + + device_capacity = get_device_capacity(A.device) + assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported" + GemmCls = GemmSymmetricSm90 if device_capacity[0] == 9 else GemmSymmetricSm100 + + acc_dtype = Float32 + tile_shape_mn = (tile_M, tile_N) + cluster_shape_mnk = (cluster_M, cluster_N, 1) + if not GemmCls.is_valid_dtypes( + tensor_infos["A"].dtype, + tensor_infos["B"].dtype, + acc_dtype, + tensor_infos["D"].dtype, + tensor_infos["A"].major, + tensor_infos["B"].major, + ): + raise TypeError("Skipping due to unsupported combination of types and majors") + + max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0 + GemmWrapperBase.create_cute_tensors({k: v for k, v in tensor_infos.items()}, major_configs) + + def scalar_arg(scalar: float | Tensor): + if isinstance(scalar, float): + return Float32(scalar) if scalar != 1.0 else None + else: + assert isinstance(scalar, Tensor) + return make_ptr(Float32, scalar.data_ptr(), cute.AddressSpace.gmem, assumed_align=4) + + activation = None # Equivalent to identity + act_fn = act_fn_map[activation] + epi_args = GemmCls.EpilogueArguments( + tensor_infos["PostAct"].cute_tensor, act_fn, scalar_arg(alpha), scalar_arg(beta) + ) + scheduler_args = GemmWrapperBase.create_scheduler_args( + max_active_clusters, tile_count_semaphore, max_swizzle_size=max_swizzle_size + ) + varlen_args = None + + current_stream = cutlass_torch.current_stream() + compile_key = GemmWrapperBase.get_compile_key( + tensor_infos, + activation, + tile_shape_mn, + cluster_shape_mnk, + pingpong, + persistent, + tile_count_semaphore is not None, + device_capacity, + max_swizzle_size, + 2 if isinstance(alpha, Tensor) else (1 if alpha == 1.0 else 0), + 2 if isinstance(beta, Tensor) else (1 if beta == 1.0 else 0), + key_tensor_names=("A", "B", "D", "PostAct", "C"), + ) + cache = gemm_act.compile_cache + if compile_key not in cache: + if device_capacity[0] == 9: + GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent) + gemm_obj = GemmCls( + acc_dtype, + tensor_infos["A"].dtype, + tile_shape_mn, + cluster_shape_mnk, + gather_A=False, + ) + cache[compile_key] = cute.compile( + gemm_obj, + tensor_infos["A"].cute_tensor, + tensor_infos["B"].cute_tensor, + tensor_infos["D"].cute_tensor, + tensor_infos["C"].cute_tensor, + epi_args, + scheduler_args, + varlen_args, + current_stream, + ) + cache[compile_key]( + tensor_infos["A"].cute_tensor, + tensor_infos["B"].cute_tensor, + tensor_infos["D"].cute_tensor, + 
tensor_infos["C"].cute_tensor, + epi_args, + scheduler_args, + varlen_args, + current_stream, + ) + + +gemm_act.compile_cache = {} diff --git a/sonic-moe/torch-ext/sonicmoe/quack/gemm_wrapper_utils.py b/sonic-moe/torch-ext/sonicmoe/quack/gemm_wrapper_utils.py new file mode 100644 index 00000000..b3ad9411 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/gemm_wrapper_utils.py @@ -0,0 +1,317 @@ +# Copyright (c) 2025, Tri Dao. +from typing import Optional, Tuple, Dict, Any +from dataclasses import dataclass + +import torch +from torch import Tensor + +import cutlass.cute as cute +from cutlass import Int32 +from cutlass.cute.runtime import from_dlpack, make_ptr + +from .cute_dsl_utils import torch2cute_dtype_map +from .varlen_utils import VarlenArguments +from .tile_scheduler import TileSchedulerOptions + + +@dataclass +class GemmTensorInfo: + tensor: Optional[Tensor] + dtype: Optional[Any] = None + major: Optional[str] = None + cute_tensor: Optional[cute.Tensor] = None + + +class GemmWrapperBase: + @staticmethod + def validate_tensor(tensor: Tensor, name: str, ndim: int) -> None: + assert tensor.dim() == ndim and tensor.is_cuda, f"{name} must be a {ndim}D CUDA tensor" + assert tensor.dtype in torch2cute_dtype_map, f"Unsupported dtype for {name}" + + @staticmethod + def validate_shape(tensor: Tensor, expected_shape: Tuple[int, ...], name: str) -> None: + assert tensor.shape == expected_shape, ( + f"{name} must have shape {expected_shape}, got {tensor.shape}" + ) + + @staticmethod + def get_major_order(tensor: Tensor, dims: Tuple[str, str, str]) -> str: + # Tensor is already permuted to (dims[0], dims[1], dims[2]) + # stride(1) == 1 means dims[1] is contiguous (innermost) + return dims[1] if tensor.stride(1) == 1 else dims[0] + + @staticmethod + def create_cute_tensor( + tensor: Optional[Tensor], + major: Optional[str], + dims: Tuple[str, str, str], + assumed_align: int = 16, + ) -> Optional[cute.Tensor]: + if tensor is None: + return None + # Tensor is already permuted to (dims[0], dims[1], dims[2]) or (dim[0], dim[1]) + # If major is dims[1], leading_dim is 1; if major is dims[0], leading_dim is 0 + leading_dim = 1 if major == dims[1] else 0 + return from_dlpack(tensor.detach(), assumed_align=assumed_align).mark_layout_dynamic( + leading_dim=leading_dim + ) + + @staticmethod + def validate_and_prepare_tensors( + A: Tensor, + B: Tensor, + D: Optional[Tensor] = None, + C: Optional[Tensor] = None, + additional_tensors: Optional[Dict[str, Tensor]] = None, + cu_seqlens_m: Optional[Tensor] = None, + cu_seqlens_k: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, + ) -> Tuple[int, int, int, int, Dict[str, GemmTensorInfo]]: + assert not (cu_seqlens_m is not None and cu_seqlens_k is not None), ( + "Only one of cu_seqlens_m and cu_seqlens_k can be specified" + ) + assert B.dtype == A.dtype, "A and B must have the same dtype" + + # Validate A_idx if provided (for gather_A case) + gather_A = A_idx is not None + if gather_A: + assert cu_seqlens_m is not None or cu_seqlens_k is not None, ( + "gather_A requires either varlen_m or varlen_k" + ) + assert A_idx.dtype == torch.int32, f"A_idx must be int32, got {A_idx.dtype}" + assert A_idx.dim() == 1, f"A_idx must be 1D, got {A_idx.dim()}D" + + # Determine mode and extract dimensions + if cu_seqlens_m is not None: + # varlen_m: A is (total_m, k) or (whatever, k) if gather_A, B is (l, n, k), D/C are (total_m, n) + assert A.dim() == 2, f"A must be 2D when using varlen_m, got {A.dim()}D" + assert B.dim() == 3, f"B must be 3D with varlen_m, got 
{B.dim()}D" + + if gather_A: + # When gather_A, A can have any number of rows, we use A_idx.shape[0] as total_M + total_M = A_idx.shape[0] + _, K = A.shape + else: + total_M, K = A.shape + + L, N, K_B = B.shape + assert K == K_B, f"K dimension mismatch: A has {K}, B has {K_B}" + assert cu_seqlens_m.shape == (L + 1,), ( + f"cu_seqlens_m must have shape ({L + 1},), got {cu_seqlens_m.shape}" + ) + M = total_M + dc_shape = (total_M, N) + dc_ndim = 2 + elif cu_seqlens_k is not None: + # varlen_k: A is (m, total_k) or (m, whatever) if gather_A, B is (n, total_k), D/C are (l, m, n) + assert A.dim() == 2, f"A must be 2D when using varlen_k, got {A.dim()}D" + assert B.dim() == 2, f"B must be 2D with varlen_k, got {B.dim()}D" + + if gather_A: + # When gather_A with varlen_k, A can have any number of columns, we use A_idx.shape[0] as total_K + M, _ = A.shape + total_K = A_idx.shape[0] + else: + M, total_K = A.shape + + N, K_B = B.shape + assert total_K == K_B, f"K dimension mismatch: expected {total_K}, B has {K_B}" + L = cu_seqlens_k.shape[0] - 1 + assert cu_seqlens_k.shape == (L + 1,), ( + f"cu_seqlens_k must have shape ({L + 1},), got {cu_seqlens_k.shape}" + ) + K = total_K + dc_shape = (L, M, N) + dc_ndim = 3 + else: + # Normal case - all tensors must be 3D + GemmWrapperBase.validate_tensor(A, "A", 3) + GemmWrapperBase.validate_tensor(B, "B", 3) + L, M, K = A.shape + _, N, K_B = B.shape + assert K == K_B, f"K dimension mismatch: A has {K}, B has {K_B}" + GemmWrapperBase.validate_shape(B, (L, N, K), "B") + dc_shape = (L, M, N) + dc_ndim = 3 + + # Validate D and C shapes uniformly + for tensor, name in [(D, "D"), (C, "C")]: + if tensor is not None: + assert tensor.dim() == dc_ndim, ( + f"{name} must be {dc_ndim}D for this mode, got {tensor.dim()}D" + ) + assert tensor.shape == dc_shape, ( + f"{name} shape {tensor.shape} doesn't match expected {dc_shape}" + ) + + tensors = { + "A": GemmTensorInfo(A), + "B": GemmTensorInfo(B), + "D": GemmTensorInfo(D), + "C": GemmTensorInfo(C), + } + + if additional_tensors: + for name, tensor in additional_tensors.items(): + if tensor is not None: + assert tensor.dim() == dc_ndim, ( + f"{name} must be {dc_ndim}D for this mode, got {tensor.dim()}D" + ) + assert tensor.shape == dc_shape, ( + f"{name} shape {tensor.shape} doesn't match expected {dc_shape}" + ) + tensors[name] = GemmTensorInfo(tensor) + + return L, M, K, N, tensors + + @staticmethod + def permute_tensors( + tensors: Dict[str, GemmTensorInfo], varlen_m: bool = False, varlen_k: bool = False + ) -> None: + # Determine which tensors need permutation + if varlen_m: + # Only B needs permutation (3D tensor) + tensors_to_permute = ["B"] + elif varlen_k: + # Only D and C need permutation (3D tensors) + tensors_to_permute = ["D", "C"] + else: + # All tensors need permutation + tensors_to_permute = None + + # Apply permutation from (L, *, *) -> (*, *, L) for selected tensors + for name, info in tensors.items(): + if info.tensor is not None and info.tensor.ndim == 3: + if tensors_to_permute is None or name in tensors_to_permute: + info.tensor = info.tensor.permute(1, 2, 0) + + @staticmethod + def extract_dtypes(tensors: Dict[str, GemmTensorInfo]) -> None: + for name, info in tensors.items(): + if info.tensor is not None: + info.dtype = torch2cute_dtype_map[info.tensor.dtype] + + @staticmethod + def determine_major_orders( + tensors: Dict[str, GemmTensorInfo], major_configs: Dict[str, Tuple[str, str, str]] + ) -> None: + for name, dims in major_configs.items(): + if name in tensors and tensors[name].tensor is not 
None: + tensors[name].major = GemmWrapperBase.get_major_order(tensors[name].tensor, dims) + + @staticmethod + def create_cute_tensors( + tensors: Dict[str, GemmTensorInfo], major_configs: Dict[str, Tuple[str, str, str]] + ) -> None: + for name, info in tensors.items(): + if info.tensor is not None and name in major_configs: + info.cute_tensor = GemmWrapperBase.create_cute_tensor( + info.tensor, info.major, major_configs[name] + ) + + @staticmethod + def create_scheduler_args( + max_active_clusters: int, + tile_count_semaphore: Optional[Tensor] = None, + batch_idx_permute: Optional[Tensor] = None, + max_swizzle_size: int = 8, + ) -> TileSchedulerOptions: + return TileSchedulerOptions( + Int32(max_active_clusters), + tile_count_semaphore=make_ptr( + Int32, tile_count_semaphore.data_ptr(), cute.AddressSpace.gmem, assumed_align=4 + ) + if tile_count_semaphore is not None + else None, + batch_idx_permute=( + from_dlpack(batch_idx_permute, assumed_align=4).mark_layout_dynamic(leading_dim=0) + ) + if batch_idx_permute is not None + else None, + max_swizzle_size=Int32(max_swizzle_size), + ) + + @staticmethod + def create_varlen_args( + cu_seqlens_m: Optional[Tensor], + cu_seqlens_k: Optional[Tensor], + A_idx: Optional[Tensor], + max_active_clusters: int, + cluster_shape_mnk: Tuple[int, int, int], + tensors: Dict[str, GemmTensorInfo], + num_epi_tensormaps: int = 0, + pingpong: bool = False, + ) -> Optional[Any]: + if cu_seqlens_m is None and cu_seqlens_k is None: + return None + # When varlen_m, we assume persistent=True + # Grid size depends on num_active_clusters and cluster size + cluster_size = cluster_shape_mnk[0] * cluster_shape_mnk[1] + num_blocks = max_active_clusters * cluster_size + # Calculate number of tensormaps needed + if cu_seqlens_m is not None: + # For varlen_m: need tensormaps for D and epilogue tensors + num_tensormaps = num_epi_tensormaps * (1 if not pingpong else 2) + if tensors["D"].tensor is not None: + num_tensormaps += 1 if not pingpong else 2 # D tensormap + else: + # For varlen_k: need tensormaps for A & B + num_tensormaps = 2 if A_idx is None else 1 + # Create tensormap buffer (each tensormap is 128 bytes = 16 int64s) + tensormap_size = 128 // 8 # 16 int64s + if num_tensormaps > 0: + device = cu_seqlens_m.device if cu_seqlens_m is not None else cu_seqlens_k.device + tensormaps = torch.empty( + (num_blocks, num_tensormaps, tensormap_size), + dtype=torch.int64, + device=device, + ) + tensormaps_cute = from_dlpack(tensormaps, assumed_align=128).mark_compact_shape_dynamic( + mode=0, stride_order=(0, 1, 2) + ) + else: + tensormaps_cute = None + + return VarlenArguments( + mCuSeqlensM=( + from_dlpack(cu_seqlens_m, assumed_align=4).mark_layout_dynamic(leading_dim=0) + if cu_seqlens_m is not None + else None + ), + mCuSeqlensK=( + from_dlpack(cu_seqlens_k, assumed_align=4).mark_layout_dynamic(leading_dim=0) + if cu_seqlens_k is not None + else None + ), + mTensormaps=tensormaps_cute, + mAIdx=( + from_dlpack(A_idx, assumed_align=4).mark_layout_dynamic(leading_dim=0) + if A_idx is not None + else None + ), + ) + + @staticmethod + def get_compile_key( + tensors: Dict[str, GemmTensorInfo], + activation: Optional[str], + tile_shape_mn: Tuple[int, int], + cluster_shape_mnk: Tuple[int, int, int], + pingpong: bool, + persistent: bool, + has_semaphore: bool, + *args, + key_tensor_names: Tuple[str, ...] 
= ("A", "B", "D", "C"), + ) -> Tuple: + key_parts = [] + for name in key_tensor_names: + if name in tensors: + key_parts.append(tensors[name].dtype) + key_parts.append(activation) + key_parts.extend([tile_shape_mn, cluster_shape_mnk]) + for name in key_tensor_names: + if name in tensors: + key_parts.append(tensors[name].major) + key_parts.extend([pingpong, persistent, has_semaphore]) + key_parts.extend(args) + return tuple(key_parts) diff --git a/sonic-moe/torch-ext/sonicmoe/quack/layout_utils.py b/sonic-moe/torch-ext/sonicmoe/quack/layout_utils.py new file mode 100644 index 00000000..522ed68c --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/layout_utils.py @@ -0,0 +1,295 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. + + +import cutlass +import cutlass.cute as cute + +from cutlass import Int32, const_expr + +from .utils import prmt + + +def transpose_view(a: cute.Tensor) -> cute.Tensor: + """Transpose the first two dimensions of a tensor on smem.""" + shape = (a.shape[1], a.shape[0], *a.shape[2:]) + order = (1, 0, *range(2, cute.rank(a))) + return cute.composition(a, cute.make_ordered_layout(shape, order=order)) + + +def select(a: cute.Tensor, mode: list[int]) -> cute.Tensor: + return cute.make_tensor(a.iterator, cute.select(a.layout, mode)) + + +def expand(a: cute.Tensor, dim: int, size: Int32 | int) -> cute.Tensor: + shape = (*a.shape[:dim], size, *a.shape[dim:]) + stride = (*a.layout.stride[:dim], 0, *a.layout.stride[dim:]) + return cute.make_tensor(a.iterator, cute.make_layout(shape, stride=stride)) + + +@cute.jit +def permute_gated_Cregs_b16(t: cute.Tensor) -> None: + assert t.element_type.width == 16 + assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b16 permutation" + t_u32 = cute.recast_tensor(t, Int32) + + quad_idx = cute.arch.lane_idx() % 4 + lane_03 = quad_idx == 0 or quad_idx == 3 + selector_upper = Int32(0x5410) if lane_03 else Int32(0x1054) + selector_lower = Int32(0x7632) if lane_03 else Int32(0x3276) + # upper_map = [0, 3, 1, 2] + # lower_map = [1, 2, 0, 3] + # upper_idx = upper_map[quad_idx] + # indexing isn't supported so we have to do arithmetic + upper_idx = quad_idx // 2 if quad_idx % 2 == 0 else 3 - quad_idx // 2 + lower_idx = upper_idx ^ 1 + + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + width = 4 + mask = cute.arch.WARP_SIZE - width + clamp = cute.arch.WARP_SIZE - 1 + mask_and_clamp = mask << 8 | clamp + + for i in cutlass.range(cute.size(t_u32.shape) // 2, unroll_full=True): + upper, lower = t_u32[i * 2 + 0], t_u32[i * 2 + 1] + upper0 = upper if lane_03 else lower + lower0 = lower if lane_03 else upper + upper0 = cute.arch.shuffle_sync(upper0, offset=upper_idx, mask_and_clamp=mask_and_clamp) + lower0 = cute.arch.shuffle_sync(lower0, offset=lower_idx, mask_and_clamp=mask_and_clamp) + t_u32[i * 2 + 0] = prmt(upper0, lower0, selector_upper) + t_u32[i * 2 + 1] = prmt(upper0, lower0, selector_lower) + + +@cute.jit +def permute_Cregs_b32_for_stsm(t: cute.Tensor) -> None: + """Permute and shuffle within 4 threads to change the layout from + T0 | T1 | T2 | T3 + a b | c d | e f | g h + to + T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3 + a | b | c | d | e | f | g | h + This is so that we can use STSM (instead of STS.64) to store C registers without bank conflict. 
+ """ + + assert t.element_type.width == 32 + assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation" + + quad_idx = cute.arch.lane_idx() % 4 + # left_map = [0, 2, 1, 3] + # right_map = [2, 0, 3, 1] + # indexing isn't supported so we have to do arithmetic + left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2 + right_idx = left_idx ^ 0b10 + + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + width = 4 + mask = cute.arch.WARP_SIZE - width + clamp = cute.arch.WARP_SIZE - 1 + mask_and_clamp = mask << 8 | clamp + + for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True): + for r in cutlass.range(2, unroll_full=True): + left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1] + # a b | c d | e f | g h -> a b | c d | f e | h g + left0 = left if quad_idx < 2 else right + right0 = right if quad_idx < 2 else left + # a b | c d | f e | h g -> a b | f d | c e | h g + left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp) + # a b | f d | c e | h g -> a e | f b | c g | h d + right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp) + # a e | f b | c g | h d -> a e | b f | c g | d h + t[i * 4 + r * 2 + 0] = left0 if quad_idx % 2 == 0 else right0 + t[i * 4 + r * 2 + 1] = right0 if quad_idx % 2 == 0 else left0 + t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1] + + +@cute.jit +def permute_Cregs_b32_for_ldsm(t: cute.Tensor) -> None: + """Permute and shuffle within 4 threads to change the layout from + T0 | T1 | T2 | T3 | T0 | T1 | T2 | T3 + a | b | c | d | e | f | g | h + to + T0 | T1 | T2 | T3 + a b | c d | e f | g h + This is so that we can use LDSM (instead of LDS.64) to store C registers without bank conflict. 
+ """ + + assert t.element_type.width == 32 + assert cute.size(t.shape) % 4 == 0, "Tensor size must be a multiple of 4 for b32 permutation" + + quad_idx = cute.arch.lane_idx() % 4 + # left_map = [0, 2, 1, 3] + # right_map = [1, 3, 0, 2] + # indexing isn't supported so we have to do arithmetic + left_idx = quad_idx // 2 if quad_idx % 2 == 0 else 2 + quad_idx // 2 + right_idx = left_idx ^ 0b01 + + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + width = 4 + mask = cute.arch.WARP_SIZE - width + clamp = cute.arch.WARP_SIZE - 1 + mask_and_clamp = mask << 8 | clamp + + # This is just the inverse of permute_Cregs_b32_for_stsm + for i in cutlass.range(cute.size(t.shape) // 4, unroll_full=True): + t[i * 4 + 1], t[i * 4 + 2] = t[i * 4 + 2], t[i * 4 + 1] + for r in cutlass.range(2, unroll_full=True): + left, right = t[i * 4 + r * 2 + 0], t[i * 4 + r * 2 + 1] + # a e | b f | c g | d h -> a e | f b | c g | h d + left0 = left if quad_idx % 2 == 0 else right + right0 = right if quad_idx % 2 == 0 else left + # a e | f b | c g | h d -> a b | f d | c e | h g + right0 = cute.arch.shuffle_sync(right0, offset=right_idx, mask_and_clamp=mask_and_clamp) + # a b | f d | c e | h g -> a b | c d | f e | h g + left0 = cute.arch.shuffle_sync(left0, offset=left_idx, mask_and_clamp=mask_and_clamp) + # a b | c d | f e | h g -> a b | c d | e f | g h + t[i * 4 + r * 2 + 0] = left0 if quad_idx < 2 else right0 + t[i * 4 + r * 2 + 1] = right0 if quad_idx < 2 else left0 + + +@cute.jit +def concat_layout(*layouts: cute.Layout) -> cute.Layout: + return cute.make_layout( + tuple(l.shape for l in layouts), + stride=tuple(l.stride for l in layouts), + ) + + +def convert_layout_acc_mn(acc_layout: cute.Layout) -> cute.Layout: + """ + For Sm80, convert ((2, 2), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, MMA_N), ...). + For Sm90, convert ((2, 2, V), MMA_M, MMA_N, ...) to ((2, MMA_M), (2, V, MMA_N), ...). + """ + acc_layout_col_major = cute.make_layout(acc_layout.shape) + acc_layout_mn = cute.make_layout( + ( + (acc_layout_col_major.shape[0][1], acc_layout_col_major.shape[1]), # MMA_M + ( + acc_layout_col_major.shape[0][0], + *acc_layout_col_major.shape[0][2:], + acc_layout_col_major.shape[2], + ), # MMA_N + *acc_layout_col_major.shape[3:], + ), + stride=( + (acc_layout_col_major.stride[0][1], acc_layout_col_major.stride[1]), # MMA_M + ( + acc_layout_col_major.stride[0][0], + *acc_layout_col_major.stride[0][2:], + acc_layout_col_major.stride[2], + ), # MMA_N + *acc_layout_col_major.stride[3:], + ), + ) + return cute.composition(acc_layout, acc_layout_mn) + + +def make_acc_tensor_mn_view(acc: cute.Tensor) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout)) + + +def reshape_acc_to_mn(acc: cute.Tensor) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_mn(acc.layout)) + + +@cute.jit +def convert_layout_acc_frgA(acc_layout: cute.Layout) -> cute.Layout: + # For back to back gemm, convert layout of acc0 to gemm 1 accept layout. 
+ # For Sm80, as the mma instruction shape is 16x8x16, we need to convert from (4, MMA_M, MMA_N) to ((4, 2), MMA_M, MMA_N / 2) + # For Sm90, FP16/BF16, convert acc_layout from ((2, 2, N / 8), MMA_M, MMA_N) to ((2, 2, 2), MMA_M, (N / 16, MMA_N)) + # TODO: Sm90 FP8 + if const_expr(cute.rank(acc_layout.shape[0]) == 3): # Sm90 + l = cute.logical_divide( + acc_layout, ((None, None, 2), None, None) + ) # ((2, 2, (2, N / 16)), MMA_M, MMA_N) + rA_mma_view = cute.make_layout( + ( + (l.shape[0][0], l.shape[0][1], l.shape[0][2][0]), + l.shape[1], + (l.shape[0][2][1], l.shape[2]), + ), + stride=( + (l.stride[0][0], l.stride[0][1], l.stride[0][2][0]), + l.stride[1], + (l.stride[0][2][1], l.stride[2]), + ), + ) + else: # Sm80 + # (4, MMA_M, MMA_N) -> (4, MMA_M, (2, MMA_N / 2)) + l = cute.logical_divide(acc_layout, (None, None, 2)) + rA_mma_view = cute.make_layout( + ( + (l.shape[0], l.shape[2][0]), + l.shape[1], + l.shape[2][1], + ), + stride=( + (l.stride[0], l.stride[2][0]), + l.stride[1], + l.stride[2][1], + ), + ) + return rA_mma_view + + +def reshape_acc_to_frgA(acc: cute.Tensor) -> cute.Tensor: + return cute.make_tensor(acc.iterator, convert_layout_acc_frgA(acc.layout)) + + +def convert_layout_zero_stride( + input: cute.Tensor | cute.Layout, ref_layout: cute.Layout +) -> cute.Layout: + layout = input.layout if const_expr(isinstance(input, cute.Tensor)) else input + # Group the modes with non-zero stride in the ref_layout together, + # and the modes with zero stride together + layout_flat = cute.flatten(layout) + ref_layout_flat = cute.flatten(ref_layout) + nonzero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride != 0] + zero_modes = [i for i in range(cute.rank(layout_flat)) if ref_layout_flat[i].stride == 0] + # There's an edge case when all modes are zero stride + new_shape = ( + tuple(layout_flat[i].shape for i in nonzero_modes) if len(nonzero_modes) > 0 else (1,), + tuple(layout_flat[i].shape for i in zero_modes), + ) + new_stride = ( + tuple(layout_flat[i].stride for i in nonzero_modes) if len(nonzero_modes) > 0 else (0,), + tuple(layout_flat[i].stride for i in zero_modes), + ) + out_layout = cute.make_layout(new_shape, stride=new_stride) + if const_expr(isinstance(input, cute.Tensor)): + return cute.make_tensor(input.iterator, out_layout) + else: + return out_layout + + +def mma_partition_C_vec( + sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool +) -> cute.Tensor: + assert cute.rank(sVec) == 2 + assert sVec.stride[0] == 1 + stage = sVec.shape[1] + shape = ( + (sVec.shape[0], expand_shape, stage) + if const_expr(is_colvec) + else (expand_shape, sVec.shape[0], stage) + ) + stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1]) + sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride)) + tC_sVec = make_acc_tensor_mn_view(thr_mma.partition_C(sVec_mma)) + return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None] + + +def mma_partition_A_vec( + sVec: cute.Tensor, thr_mma: cute.core.ThrMma, expand_shape: int, is_colvec: bool +) -> cute.Tensor: + assert cute.rank(sVec) == 2 + assert sVec.stride[0] == 1 + stage = sVec.shape[1] + shape = ( + (sVec.shape[0], expand_shape, stage) + if const_expr(is_colvec) + else (expand_shape, sVec.shape[0], stage) + ) + stride = (1, 0, sVec.stride[1]) if const_expr(is_colvec) else (0, 1, sVec.stride[1]) + sVec_mma = cute.make_tensor(sVec.iterator, cute.make_layout(shape, stride=stride)) + tC_sVec = 
make_acc_tensor_mn_view(thr_mma.partition_A(sVec_mma)) + return tC_sVec[None, 0, None] if const_expr(is_colvec) else tC_sVec[0, None, None] diff --git a/sonic-moe/torch-ext/sonicmoe/quack/pipeline.py b/sonic-moe/torch-ext/sonicmoe/quack/pipeline.py new file mode 100644 index 00000000..af915232 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/pipeline.py @@ -0,0 +1,324 @@ +# Copyright (c) 2025, Tri Dao. + +from typing import Optional +from dataclasses import dataclass + +import cutlass.cute as cute +from cutlass import Boolean, Int32, const_expr +from cutlass.cutlass_dsl import if_generate, and_, dsl_user_op +from cutlass.pipeline import MbarrierArray, CooperativeGroup, PipelineOp, pipeline_init_wait +from cutlass.pipeline import PipelineAsync, PipelineTmaAsync, PipelineState, PipelineUserType +from cutlass.pipeline import PipelineTmaUmma + + +class PipelineStateWAdvance(PipelineState): + @dsl_user_op + def advance_iters(self, num_iterations: Int32, *, loc=None, ip=None): + self._count += Int32(num_iterations) + new_index = self._index + Int32(num_iterations) + # How many times did we cross the stages boundary + num_crossings = new_index // self.stages + self._phase ^= num_crossings + self._index = new_index % self.stages + + # This can be overridden by derived classes + def __new_from_mlir_values__(self, values): + return PipelineStateWAdvance( + self.stages, Int32(values[0]), Int32(values[1]), Int32(values[2]) + ) + + +def make_pipeline_state(type: PipelineUserType, stages: int): + """ + Creates a pipeline state. Producers are assumed to start with an empty buffer and have a flipped phase bit of 1. + """ + if type is PipelineUserType.Producer: + return PipelineStateWAdvance( + stages, + Int32(0), + Int32(0), + Int32(1), + ) + elif type is PipelineUserType.Consumer: + return PipelineStateWAdvance( + stages, + Int32(0), + Int32(0), + Int32(0), + ) + else: + assert False, "Error: invalid PipelineUserType specified for make_pipeline_state." + + +@dataclass(frozen=True) +class PipelineTmaCpAsync(PipelineTmaAsync): + """ + PipelineTmaCpAsync is used for CpAsync + TMA producers and AsyncThread consumers + """ + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + tx_count: int, + barrier_storage: cute.Pointer = None, + cta_layout_vmnk: Optional[cute.Layout] = None, + tidx: Optional[Int32] = None, + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineTmaAsync. 
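+        Compared to PipelineTmaAsync, the consumer here is an AsyncThread group and
+        producer_acquire only lets the warp flagged as the TMA warp arrive on the full
+        barrier, so cp.async producer warps can share the same pipeline.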
+ :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: CooperativeGroup for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: CooperativeGroup for the consumer agent + :type consumer_group: CooperativeGroup + :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage + :type tx_count: int + :param cta_layout_vmnk: Layout of the cluster shape + :type cta_layout_vmnk: cute.Layout | None + :param tidx: thread index to consumer async threads + :type tidx: Int32 | None + """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.TmaLoad + consumer_type = PipelineOp.AsyncThread + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_full = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8), num_stages, producer, tx_count + ) + sync_object_empty = PipelineAsync._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + if tidx is None: + tidx, _, _ = cute.arch.thread_idx() + if cta_layout_vmnk is None: + cta_layout_vmnk = cute.make_layout((1, 1, 1, 1)) + ( + dst_rank, + is_signalling_thread, + ) = PipelineTmaAsync.init_empty_barrier_arrive_signal(cta_layout_vmnk, tidx) + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: + dst_rank = None + else: + dst_rank = dst_rank + + producer_mask = None + + pipeline_init_wait(cta_layout_vmnk) + + return PipelineTmaCpAsync( + sync_object_full, + sync_object_empty, + num_stages, + producer_mask, + dst_rank, + is_signalling_thread, + ) + + @dsl_user_op + def producer_acquire( + self, + state: PipelineState, + try_acquire_token: Optional[Boolean] = None, + is_tma_warp: Optional[Boolean] = True, + *, + loc=None, + ip=None, + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the transaction barrier. 
+ """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip), + ) + # This is the difference between this and PipelineTmaAsync: we could have multiple + # warps calling this, but only 1 warp should do the arrive on the full barrier + if_generate( + is_tma_warp, + lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip), + ) + + @dsl_user_op + def producer_cpasync_commit(self, state: PipelineState, *, loc=None, ip=None): + """ + We need the mbarrier to track the completion of cp.async + """ + cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip) + + +class MbarrierArrayWDropCount(MbarrierArray): + @dsl_user_op + def __init__( + self, + barrier_storage: cute.Pointer, + num_stages: int, + agent: tuple[PipelineOp, CooperativeGroup], + tx_count: int = 0, + drop_count: Optional[Int32] = None, + *, + loc=None, + ip=None, + ) -> None: + self.barrier_storage = barrier_storage + self.tx_count = tx_count + self.num_stages = num_stages + self.op_type, self.cg = agent + self.arrive_count = self.cg.size + self.drop_count = drop_count + + if self.num_stages <= 0: + raise ValueError("Error: Mbarrier stage count must be greater than 0.") + if self.arrive_count <= 0: + raise ValueError("Error: Mbarrier arrive count must be greater than 0.") + if self.op_type is PipelineOp.TmaLoad and self.tx_count < 0: + raise ValueError("Error: Mbarrier tx count must not be less than 0 for TMA ops.") + + if const_expr(drop_count is not None): + self.arrive_count = self.arrive_count - drop_count + + # Store mbarrier base pointer + self.mbarrier_base = self.barrier_storage + + # Mbarrier initialization in constructor + self.mbarrier_init(loc=loc, ip=ip) + + def __extract_mlir_values__(self): + return [self.barrier_storage, self.drop_count] + + def __new_from_mlir_values__(self, values): + return MbarrierArrayWDropCount( + values[0], self.num_stages, (self.op_type, self.cg), self.tx_count, values[1] + ) + + +@dataclass(frozen=True) +class PipelineTmaCpAsyncUmma(PipelineTmaUmma): + """ + PipelineTmaCpAsync is used for CpAsync + TMA producers and UMMA consumers + (e.g. Blackwell mainloops) + """ + + @staticmethod + def create( + *, + num_stages: int, + producer_group: CooperativeGroup, + consumer_group: CooperativeGroup, + tx_count: int, + barrier_storage: cute.Pointer = None, + cta_layout_vmnk: Optional[cute.Layout] = None, + producer_drop_count: Optional[Int32] = None, + mcast_mode_mn: tuple[int, int] = (1, 1), + ): + """ + This helper function computes any necessary attributes and returns an instance of PipelineTmaUmma. 
+ :param barrier_storage: Pointer to the smem address for this pipeline's mbarriers + :type barrier_storage: cute.Pointer + :param num_stages: Number of buffer stages for this pipeline + :type num_stages: Int32 + :param producer_group: `CooperativeGroup` for the producer agent + :type producer_group: CooperativeGroup + :param consumer_group: `CooperativeGroup` for the consumer agent + :type consumer_group: CooperativeGroup + :param tx_count: Number of bytes expected to be written to the transaction barrier for one stage + :type tx_count: int + :param cta_layout_vmnk: Layout of the cluster shape + :type cta_layout_vmnk: cute.Layout | None + :param mcast_mode_mn: Tuple specifying multicast modes for m and n dimensions (each 0 or 1) + :type mcast_mode_mn: tuple[int, int], optional + """ + if not isinstance(barrier_storage, cute.Pointer): + raise ValueError( + f"Expected barrier_storage to be a cute.Pointer, but got {type(barrier_storage)}" + ) + + producer_type = PipelineOp.TmaLoad + consumer_type = PipelineOp.TCGen05Mma + + producer = (producer_type, producer_group) + consumer = (consumer_type, consumer_group) + + sync_object_full = MbarrierArrayWDropCount( + barrier_storage.align(min_align=8), + num_stages, + producer, + tx_count, + drop_count=producer_drop_count, + ) + sync_object_empty = PipelineTmaUmma._make_sync_object( + barrier_storage.align(min_align=8) + num_stages, num_stages, consumer + ) + + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk) == 1: + # No mcast mask if not using clusters + producer_mask = None + # All threadblocks are leaders if not using clusters + is_leader_cta = True + else: + producer_mask = PipelineTmaUmma._compute_mcast_arrival_mask(cta_layout_vmnk, mcast_mode_mn) + is_leader_cta = PipelineTmaUmma._compute_is_leader_cta(cta_layout_vmnk) + + cta_group = ( + cute.nvgpu.tcgen05.CtaGroup.ONE + if cta_layout_vmnk is None or cute.size(cta_layout_vmnk, mode=[0]) == 1 + else cute.nvgpu.tcgen05.CtaGroup.TWO + ) + + consumer_mask = producer_mask + + pipeline_init_wait(cta_layout_vmnk) + + return PipelineTmaCpAsyncUmma( + sync_object_full, + sync_object_empty, + num_stages, + producer_mask, + consumer_mask, + is_leader_cta, + cta_group, + ) + + @dsl_user_op + def producer_acquire( + self, + state: PipelineState, + try_acquire_token: Optional[Boolean] = None, + is_tma_warp: Optional[Boolean] = True, + *, + loc=None, + ip=None, + ): + """ + TMA producer commit conditionally waits on buffer empty and sets the + transaction barrier for leader threadblocks. + """ + if_generate( + try_acquire_token is None or try_acquire_token == 0, + lambda: self.sync_object_empty.wait(state.index, state.phase, loc=loc, ip=ip), + ) + # This is the difference between this and PipelineTmaAsync: we could have multiple + # warps calling this, but only 1 warp should do the arrive on the full barrier + if_generate( + and_(self.is_leader_cta, is_tma_warp), + lambda: self.sync_object_full.arrive(state.index, self.producer_mask, loc=loc, ip=ip), + ) + + @dsl_user_op + def producer_cpasync_commit(self, state: PipelineState, *, loc=None, ip=None): + """ + We need the mbarrier to track the completion of cp.async + """ + cute.arch.cp_async_mbarrier_arrive_noinc(self.producer_get_barrier(state, loc=loc, ip=ip), loc=loc, ip=ip) diff --git a/sonic-moe/torch-ext/sonicmoe/quack/reduce.py b/sonic-moe/torch-ext/sonicmoe/quack/reduce.py new file mode 100644 index 00000000..08125d40 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/reduce.py @@ -0,0 +1,279 @@ +# Copyright (c) 2025, Tri Dao. 
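+# Warp-, block- and cluster-level reduction helpers (row reductions, online softmax,
+# and a swap-shuffle warp sum) built on warp shuffles, shared-memory buffers and
+# cluster mbarriers.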
+ +import math +import operator +from typing import Callable, Optional + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Int64, Float32, Boolean, const_expr + +from . import utils as utils + + +@cute.jit +def block_reduce( + val: cute.Numeric, op: Callable, reduction_buffer: cute.Tensor, init_val: cute.Numeric = 0.0 +) -> cute.Numeric: + """reduction_buffer has shape (num_warps / warp_per_row, warps_per_row)""" + lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx() + warps_per_row = cute.size(reduction_buffer.shape[1]) + row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row + if lane_idx == 0: + reduction_buffer[row_idx, col_idx] = val + cute.arch.barrier() + block_reduce_val = init_val + if lane_idx < warps_per_row: + block_reduce_val = reduction_buffer[row_idx, lane_idx] + return cute.arch.warp_reduction(block_reduce_val, op) + + +@cute.jit +def cluster_reduce( + val: cute.Numeric, + op: Callable, + reduction_buffer: cute.Tensor, + mbar_ptr: cute.Pointer, + init_val: cute.Numeric = 0.0, + phase: Optional[Int32] = None, +) -> cute.Numeric: + """reduction_buffer has shape (num_warps / warps_per_row, (warps_per_row, cluster_n))""" + cta_rank_in_cluster = cute.arch.block_idx_in_cluster() + lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx() + rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape + row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row + if warp_idx == 0: + with cute.arch.elect_one(): + num_warps = rows_per_block * warps_per_row + cute.arch.mbarrier_arrive_and_expect_tx( + mbar_ptr, + num_warps * cluster_n * reduction_buffer.element_type.width // 8, + ) + if lane_idx < cluster_n: + utils.store_shared_remote( + val, + utils.elem_pointer(reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster))), + mbar_ptr, + peer_cta_rank_in_cluster=lane_idx, + ) + cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0) + block_reduce_val = init_val + num_iter = cute.ceil_div(warps_per_row * cluster_n, cute.arch.WARP_SIZE) + for i in cutlass.range_constexpr(num_iter): + idx = lane_idx + i * cute.arch.WARP_SIZE + if idx < cute.size(reduction_buffer, mode=[1]): + block_reduce_val = op(block_reduce_val, reduction_buffer[row_idx, idx]) + return cute.arch.warp_reduction(block_reduce_val, op) + + +@cute.jit +def block_or_cluster_reduce( + val: cute.Numeric, + op: Callable, + reduction_buffer: cute.Tensor, + mbar_ptr: Optional[cute.Pointer], + phase: Optional[Int32] = None, + init_val: cute.Numeric = 0.0, +) -> cute.Numeric: + """Perform either block or cluster reduction based on whether mbar_ptr is provided.""" + if const_expr(mbar_ptr is None): + return block_reduce(val, op, reduction_buffer, init_val=init_val) + else: + return cluster_reduce(val, op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val) + + +@cute.jit +def row_reduce( + x: cute.TensorSSA | cute.Numeric, + op: cute.ReductionOp, + threads_per_row: cutlass.Constexpr[int], + reduction_buffer: Optional[cute.Tensor] = None, + mbar_ptr: Optional[cute.Pointer] = None, + phase: Optional[Int32] = None, + init_val: cute.Numeric = 0.0, + hook_fn: Optional[Callable] = None, +) -> cute.Numeric: + """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n))""" + if const_expr(isinstance(x, cute.TensorSSA)): + val = x.reduce(op, init_val=init_val, reduction_profile=0) + else: + val = x + warp_op = { + cute.ReductionOp.ADD: operator.add, + cute.ReductionOp.MAX: cute.arch.fmax if 
const_expr(x.dtype == Float32) else max, + cute.ReductionOp.MIN: min, + cute.ReductionOp.MUL: operator.mul, + }[op] + val = cute.arch.warp_reduction( + val, + warp_op, + threads_in_group=min(threads_per_row, cute.arch.WARP_SIZE), + ) + if const_expr(hook_fn is not None): + hook_fn() + if const_expr(reduction_buffer is not None): + warps_per_row, cluster_n = reduction_buffer.shape[1] + assert cluster_n == 1 or mbar_ptr is not None, ( + "mbar_ptr must be provided for cluster reduction" + ) + if const_expr(warps_per_row > 1 or cluster_n > 1): + val = block_or_cluster_reduce( + val, warp_op, reduction_buffer, mbar_ptr, phase=phase, init_val=init_val + ) + return val + + +@cute.jit +def online_softmax_reduce( + x: cute.TensorSSA, + threads_per_row: cutlass.Constexpr[int], + reduction_buffer: Optional[cute.Tensor] = None, + mbar_ptr: Optional[cute.Pointer] = None, + hook_fn: Optional[Callable] = None, + phase: Optional[Int32] = None, + return_exp_x: bool = False, +) -> [Float32, Float32, Optional[cute.TensorSSA]]: + assert x.dtype == Float32, "x must be of type Float32" + """reduction_buffer must have shape (num_warps / warps_per_row, (warps_per_row, cluster_n), 2)""" + max_x = cute.arch.warp_reduction( + x.reduce(cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0), + cute.arch.fmax, + threads_in_group=min(threads_per_row, cute.arch.WARP_SIZE), + ) + log2_e = math.log2(math.e) + exp_x = cute.math.exp2(x * log2_e - (max_x * log2_e), fastmath=True) + sum_exp_x = cute.arch.warp_reduction( + exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0), + operator.add, + threads_in_group=min(threads_per_row, cute.arch.WARP_SIZE), + ) + if const_expr(hook_fn is not None): + hook_fn() + if const_expr(reduction_buffer is not None): + rows_per_block, (warps_per_row, cluster_n) = reduction_buffer.shape + assert cluster_n == 1 or mbar_ptr is not None, ( + "mbar_ptr must be provided for cluster reduction" + ) + if const_expr(warps_per_row > 1 or cluster_n > 1): + assert reduction_buffer.element_type == Int64, ( + "reduction_buffer must be of type cute.Int64" + ) + lane_idx, warp_idx = cute.arch.lane_idx(), cute.arch.warp_idx() + row_idx, col_idx = warp_idx // warps_per_row, warp_idx % warps_per_row + if const_expr(mbar_ptr is None): + if lane_idx == 0: + reduction_buffer[row_idx, col_idx] = utils.f32x2_to_i64(max_x, sum_exp_x) + cute.arch.barrier() + max_x_single_warp = -Float32.inf + sum_exp_x = 0.0 + if lane_idx < warps_per_row: + max_x_single_warp, sum_exp_x = utils.i64_to_f32x2( + reduction_buffer[row_idx, lane_idx] + ) + max_x_final = cute.arch.warp_reduction(max_x_single_warp, cute.arch.fmax) + sum_exp_x *= cute.math.exp(max_x_single_warp - max_x_final, fastmath=True) + sum_exp_x = cute.arch.warp_reduction(sum_exp_x, operator.add) + if const_expr(return_exp_x): + exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True) + max_x = max_x_final + else: + cta_rank_in_cluster = cute.arch.block_idx_in_cluster() + if warp_idx == 0: + with cute.arch.elect_one(): + num_warps = rows_per_block * warps_per_row + cute.arch.mbarrier_arrive_and_expect_tx( + mbar_ptr, + num_warps * cluster_n * reduction_buffer.element_type.width // 8, + ) + if lane_idx < cluster_n: + utils.store_shared_remote( + utils.f32x2_to_i64(max_x, sum_exp_x), + utils.elem_pointer( + reduction_buffer, (row_idx, (col_idx, cta_rank_in_cluster)) + ), + mbar_ptr, + peer_cta_rank_in_cluster=lane_idx, + ) + cute.arch.mbarrier_wait(mbar_ptr, phase=phase if phase is not None else 0) + num_iter = cute.ceil_div(warps_per_row * 
cluster_n, cute.arch.WARP_SIZE) + max_x_single_warp = cute.make_fragment(num_iter, Float32) + max_x_single_warp.fill(-Float32.inf) + sum_exp_x_single_warp = cute.make_fragment(num_iter, Float32) + sum_exp_x_single_warp.fill(0.0) + for i in cutlass.range_constexpr(num_iter): + idx = lane_idx + i * cute.arch.WARP_SIZE + if idx < cute.size(reduction_buffer, mode=[1]): + max_x_single_warp[i], sum_exp_x_single_warp[i] = utils.i64_to_f32x2( + reduction_buffer[row_idx, idx] + ) + max_x_final = max_x_single_warp.load().reduce( + cute.ReductionOp.MAX, init_val=-Float32.inf, reduction_profile=0 + ) + max_x_final = cute.arch.warp_reduction(max_x_final, cute.arch.fmax) + sum_exp_x = 0.0 + for i in cutlass.range_constexpr(num_iter): + sum_exp_x += sum_exp_x_single_warp[i] * cute.math.exp( + max_x_single_warp[i] - max_x_final, fastmath=True + ) + sum_exp_x = cute.arch.warp_reduction(sum_exp_x, operator.add) + if const_expr(return_exp_x): + exp_x *= cute.math.exp(max_x - max_x_final, fastmath=True) + max_x = max_x_final + return max_x, sum_exp_x, (exp_x if const_expr(return_exp_x) else None) + + +@cute.jit +def sum_swap_shuffle( + X: cute.Tensor, elem_per_lane: int = 1, subwarp_size: int = 1, warp_size: int = 32 +) -> cute.Tensor: + """ + For warp reduction, we use Swap Shuffle + The normal way to reduction among threads: + use shuffle to let *** the first half of threads *** have *** whole data *** from the second half of threads. + After each step of reduction, a half of threads won't work in the following steps. + That is, as the reduction progresses, the efficiency of shuffle & reduction instructions gradually change from 1/2, 1/4 to 1/32 (the worst case). + To overcome this shortcoming, for a NxN matrix to be reduced among N threads as a 1XN vectors, + we use swap & shuffle aiming to let *** each half of threads *** have *** a half of data *** from the other half of threads. + After reduction, each half of threads should deal with a (N/2)x(N/2) sub-matrix independently in the following step. + We can recursively do this until the problem size is 1. 
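+    A rough sketch of the N = 4, one-value-per-lane case (subwarp_size = 1): after the
+    first step (shuffle offset 2), lanes 0-1 hold the pairwise partial sums of columns
+    0-1 while lanes 2-3 hold those of columns 2-3; after the second step (offset 1),
+    lane i is left holding the fully reduced sum of column i.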
+ """ + assert ( + subwarp_size >= 1 + and subwarp_size <= 32 + and subwarp_size == 1 << int(math.log2(subwarp_size)) + ) + assert ( + warp_size <= 32 + and warp_size % subwarp_size == 0 + and warp_size == 1 << int(math.log2(warp_size)) + ) + lane_idx = cute.arch.lane_idx() // subwarp_size + X = cute.logical_divide(X, cute.make_layout(elem_per_lane)) # (elem_per_lane, M) + numvec = cute.size(X, mode=[1]) + assert numvec <= 32 // subwarp_size + # If X has more values than warp_size // subwarp_size, we first do a normal warp reduction + # to sum up values held by lanes further than size(X) away + for i in cutlass.range( + int(math.log2(numvec)), int(math.log2(warp_size // subwarp_size)), unroll_full=True + ): + for v in cutlass.range(cute.size(X), unroll_full=True): + shfl_val = cute.arch.shuffle_sync_bfly(X[v], offset=(1 << i) * subwarp_size) + X[v] = X[v] + shfl_val + for logm in cutlass.range_constexpr(int(math.log2(cute.size(X, mode=[1]))) - 1, -1, -1): + m = 1 << logm + for r in cutlass.range(m, unroll_full=True): + frg_A = X[None, r] + frg_B = X[None, r + m] + # First half of threads swap fragments from the first half of data to the second + should_swap = not Boolean(lane_idx & m) + for v in cutlass.range(cute.size(frg_A), unroll_full=True): + # Step 1: swap + lower, upper = frg_A[v], frg_B[v] + frg_A[v] = upper if should_swap else lower + frg_B[v] = lower if should_swap else upper + # Step 2: shuffle + # each half of threads get a half of data from the other half of threads + shfl_val = cute.arch.shuffle_sync_bfly(frg_A[v], offset=m * subwarp_size) + # Step 3: reduction + frg_A[v] = frg_B[v] + shfl_val + return X[None, 0] diff --git a/sonic-moe/torch-ext/sonicmoe/quack/reduction_base.py b/sonic-moe/torch-ext/sonicmoe/quack/reduction_base.py new file mode 100644 index 00000000..9139f512 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/reduction_base.py @@ -0,0 +1,83 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. + +from typing import Type, Tuple, Optional + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Int64, Float32, const_expr + +from . 
import copy_utils as copy_utils + + +class ReductionBase: + def __init__(self, dtype: Type[cutlass.Numeric], N: int, stage: int, reduction_dtype=Float32): + self.dtype = dtype + self.N = N + self.stage = stage + self.reduction_dtype = reduction_dtype + + def _threads_per_row(self): + raise NotImplementedError() + + def _num_threads(self): + return 128 if self.N <= 16384 else 256 + + def _set_cluster_n(self): + self.cluster_n = 1 + + def _get_tiled_copy(self, vecsize: int = 1): + assert self.N % vecsize == 0, f"Input N {self.N} is not divisible by vector size {vecsize}" + threads_per_row = self._threads_per_row() + num_threads = self._num_threads() + assert num_threads % cute.arch.WARP_SIZE == 0 + num_blocks_N = cute.ceil_div(self.N // vecsize, threads_per_row * self.cluster_n) + tiler_mn = (num_threads // threads_per_row, vecsize * num_blocks_N * threads_per_row) + tiled_copy = copy_utils.tiled_copy_2d(self.dtype, threads_per_row, num_threads, vecsize) + return tiled_copy, tiler_mn, threads_per_row + + def _get_reduction_buffer_layout(self, tv_layout: cute.Layout, cluster_n: int): + num_warps = cute.size(tv_layout, mode=[0]) // cute.arch.WARP_SIZE + warps_per_row = ( + num_warps + if cute.rank(tv_layout.shape[0]) == 1 + else max(tv_layout.shape[0][0] // cute.arch.WARP_SIZE, 1) + ) + return cute.make_ordered_layout( + (num_warps // warps_per_row, (warps_per_row, cluster_n), self.stage), + order=(1, 0, 2), + ) + + def _allocate_reduction_buffer_and_mbar( + self, smem: cutlass.utils.SmemAllocator, tv_layout: cute.Layout, is_persistent: bool = False + ) -> Tuple[cute.Tensor, Optional[cute.Pointer]]: + reduction_buffer = smem.allocate_tensor( + self.reduction_dtype, + self._get_reduction_buffer_layout(tv_layout, self.cluster_n), + byte_alignment=8, + ) + if const_expr(self.cluster_n > 1): + mbar_ptr = smem.allocate_array( + Int64, num_elems=self.stage if not is_persistent else self.stage * 2 + ) + else: + mbar_ptr = None + return reduction_buffer, mbar_ptr + + @cute.jit + def _initialize_cluster( + self, + tidx: Int32, + mbar_ptr: cute.Pointer, + num_warps: int, + is_persistent: bool = False, + ): + if const_expr(self.cluster_n > 1): + if tidx < self.stage: # Initialize full barrier + cute.arch.mbarrier_init(mbar_ptr + tidx, 1) + if const_expr(is_persistent): # Initialize empty barrier + cute.arch.mbarrier_init( + mbar_ptr + self.stage + tidx, num_warps * self.cluster_n + ) + cute.arch.mbarrier_init_fence() + # Cluster arrive after barrier init + cute.arch.cluster_arrive_relaxed() diff --git a/sonic-moe/torch-ext/sonicmoe/quack/sm100_utils.py b/sonic-moe/torch-ext/sonicmoe/quack/sm100_utils.py new file mode 100644 index 00000000..2c12a38b --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/sm100_utils.py @@ -0,0 +1,62 @@ +# Copyright (c) 2025, Tri Dao. 
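+# Blackwell (SM100) helper: builds the staged smem layout for an A operand that is
+# loaded with cp.async instead of TMA.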
+ +from typing import Type, Union + +import cutlass.cute as cute +import cutlass.utils.blackwell_helpers as sm100_utils_og +from cutlass.cute.nvgpu.tcgen05 import OperandMajorMode +from cutlass.cutlass_dsl import Numeric, dsl_user_op + + +@dsl_user_op +def make_smem_layout_cpasync_a( + tiled_mma: cute.TiledMma, + mma_tiler_mnk: cute.Tile, + a_dtype: Type[Numeric], + num_stages: int, + *, + loc=None, + ip=None, +) -> Union[cute.Layout, cute.ComposedLayout]: + """ + :param tiled_mma: The tiled MMA used to partition tensor A + :type tiled_mma: cute.TiledMma + :param mma_tiler_mnk: The MMA tile shape + :type mma_tiler_mnk: cute.cute.Tile + :param a_dtype: The element type for tensor A + :type a_dtype: Type[Numeric] + :param num_stages: The number of pipeline stages for tensor A + :type num_stages: int + + :return: SMEM layout for tensor A + :rtype: Union[cute.Layout, cute.ComposedLayout] + """ + + is_k_major = tiled_mma.op.a_major_mode == OperandMajorMode.K + a_smem_shape = tiled_mma.partition_shape_A( + cute.dice(mma_tiler_mnk, (1, None, 1), loc=loc, ip=ip) + ) + a_smem_shape_mn_k = ( + cute.size(a_smem_shape[0][0], loc=loc, ip=ip) * a_smem_shape[1], + cute.size(a_smem_shape[0][1], loc=loc, ip=ip) * a_smem_shape[2], + ) + a_smem_layout_atom = sm100_utils_og.make_smem_layout_atom( + sm100_utils_og.get_smem_layout_atom_ab( + tiled_mma.op.a_major_mode, + a_dtype, + a_smem_shape_mn_k, + loc=loc, + ip=ip, + ), + a_dtype, + loc=loc, + ip=ip, + ) + a_smem_layout_staged = cute.tile_to_shape( + a_smem_layout_atom, + cute.append(a_smem_shape_mn_k, num_stages, loc=loc, ip=ip), + order=((1, 0, 2) if not is_k_major else (0, 1, 2)), + loc=loc, + ip=ip, + ) + return a_smem_layout_staged diff --git a/sonic-moe/torch-ext/sonicmoe/quack/sm90_utils.py b/sonic-moe/torch-ext/sonicmoe/quack/sm90_utils.py new file mode 100644 index 00000000..659ae2a9 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/sm90_utils.py @@ -0,0 +1,157 @@ +# Copyright (c) 2025, Tri Dao. 
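+# Hopper (SM90) helpers: smem layout construction, epilogue partitioning, and thin
+# wrappers around warpgroup GEMM (zero-init and indexed-operand variants, plus
+# fragment partitioning with optional swap_AB).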
+ +from typing import Type, Union, Optional + +import cutlass +import cutlass.cute as cute +import cutlass.utils.hopper_helpers as sm90_utils_og +from cutlass.cute.nvgpu import warpgroup +from cutlass.cutlass_dsl import Numeric, dsl_user_op +from cutlass import Float32, Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum + + +@dsl_user_op +def make_smem_layout( + dtype: Type[Numeric], + layout: LayoutEnum, + tile: cute.Tile, + stage: Optional[int] = None, + *, + loc=None, + ip=None, +) -> Union[cute.Layout, cute.ComposedLayout]: + shape = cute.product_each(cute.shape(tile, loc=loc, ip=ip), loc=loc, ip=ip) + major_mode_size = shape[1] if layout.is_n_major_c() else shape[0] + smem_layout_atom = warpgroup.make_smem_layout_atom( + sm90_utils_og.get_smem_layout_atom(layout, dtype, major_mode_size), + dtype, + ) + order = (1, 0, 2) if const_expr(layout.is_m_major_c()) else (0, 1, 2) + smem_layout_staged = cute.tile_to_shape( + smem_layout_atom, + cute.append(shape, stage) if const_expr(stage is not None) else shape, + order=order if const_expr(stage is not None) else order[:2], + ) + return smem_layout_staged + + +# For compatibility with blackwell_helpers.py +make_smem_layout_epi = make_smem_layout + + +@dsl_user_op +def partition_for_epilogue( + cT: cute.Tensor, + epi_tile: cute.Tile, + tiled_copy: cute.TiledCopy, + tidx: Int32, + reference_src: bool, # do register tensors reference the src or dst layout of the tiled copy + *, + loc=None, + ip=None, +) -> cute.Tensor: + thr_copy = tiled_copy.get_slice(tidx) + cT_epi = cute.flat_divide(cT, epi_tile) + # (CPY, CPY_M, CPY_N, EPI_M, EPI_N) + if const_expr(reference_src): + return thr_copy.partition_S(cT_epi, loc=loc, ip=ip) + else: + return thr_copy.partition_D(cT_epi, loc=loc, ip=ip) + + +@cute.jit +def gemm( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + zero_init: cutlass.Constexpr[bool] = False, + wg_wait: cutlass.Constexpr[int] = 0, + # A_in_regs: cutlass.Constexpr[bool] = False, + swap_AB: cutlass.Constexpr[bool] = False, +) -> None: + if const_expr(swap_AB): + gemm(tiled_mma, acc, tCrB, tCrA, zero_init=zero_init, wg_wait=wg_wait, swap_AB=False) + else: + warpgroup.fence() + # We make a new mma_atom since we'll be modifying its attribute (accumulate). 
+ # Otherwise the compiler complains "operand #0 does not dominate this use" + mma_atom = cute.make_mma_atom(tiled_mma.op) + mma_atom.set(warpgroup.Field.ACCUMULATE, not zero_init) + for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])): + cute.gemm(mma_atom, acc, tCrA[None, None, k], tCrB[None, None, k], acc) + mma_atom.set(warpgroup.Field.ACCUMULATE, True) + warpgroup.commit_group() + if const_expr(wg_wait >= 0): + warpgroup.wait_group(wg_wait) + + +def gemm_zero_init( + tiled_mma: cute.TiledMma, + shape: cute.Shape, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + wg_wait: int = -1, + swap_AB: bool = False, +) -> cute.Tensor: + if const_expr(swap_AB): + return gemm_zero_init( + tiled_mma, shape[::-1], tCrB, tCrA, B_idx, A_idx, wg_wait, swap_AB=False + ) + else: + acc = cute.make_fragment(tiled_mma.partition_shape_C(shape), Float32) + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + gemm(tiled_mma, acc, rA, rB, zero_init=True, wg_wait=wg_wait) + return acc + + +def gemm_w_idx( + tiled_mma: cute.TiledMma, + acc: cute.Tensor, + tCrA: cute.Tensor, + tCrB: cute.Tensor, + zero_init: Boolean, + A_idx: Optional[Int32] = None, + B_idx: Optional[Int32] = None, + wg_wait: int = -1, + swap_AB: bool = False, +) -> None: + if const_expr(swap_AB): + gemm_w_idx(tiled_mma, acc, tCrB, tCrA, zero_init, B_idx, A_idx, wg_wait, swap_AB=False) + else: + rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx] + rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx] + gemm(tiled_mma, acc, rA, rB, zero_init=zero_init, wg_wait=wg_wait) + + +def partition_fragment_ABC( + thr_mma: cute.ThrMma, + shape_mnk: cute.Shape, + sA: Optional[cute.Tensor], + sB: Optional[cute.Tensor], + swap_AB: bool = False, +): + is_rs = thr_mma.op.a_src == warpgroup.OperandSource.RMEM + if const_expr(not swap_AB): + acc = cute.make_fragment(thr_mma.partition_shape_C(shape_mnk[:2]), Float32) + if const_expr(not is_rs): + assert sA is not None + tCrA = thr_mma.make_fragment_A(thr_mma.partition_A(sA)) + else: + tCrA = thr_mma.make_fragment_A(thr_mma.partition_shape_A((shape_mnk[0], shape_mnk[2]))) + assert sB is not None + tCrB = thr_mma.make_fragment_B(thr_mma.partition_B(sB)) + else: + acc = cute.make_fragment(thr_mma.partition_shape_C((shape_mnk[1], shape_mnk[0])), Float32) + if const_expr(not is_rs): + assert sB is not None + tCrB = thr_mma.make_fragment_A(thr_mma.partition_A(sB)) + else: # B in rmem + tCrB = thr_mma.make_fragment_A(thr_mma.partition_shape_A((shape_mnk[1], shape_mnk[2]))) + assert sA is not None + tCrA = thr_mma.make_fragment_B(thr_mma.partition_B(sA)) + return acc, tCrA, tCrB diff --git a/sonic-moe/torch-ext/sonicmoe/quack/sort/__init__.py b/sonic-moe/torch-ext/sonicmoe/quack/sort/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/sort/__init__.py @@ -0,0 +1 @@ + diff --git a/sonic-moe/torch-ext/sonicmoe/quack/sort/bitonic_sort.py b/sonic-moe/torch-ext/sonicmoe/quack/sort/bitonic_sort.py new file mode 100644 index 00000000..c93463ea --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/sort/bitonic_sort.py @@ -0,0 +1,129 @@ +# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao. + +import math +from typing import Optional + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Float32, const_expr + +from .. 
import utils +from .utils import compare_and_swap +from .sorting_networks import optimal_sort + + +@cute.jit +def bitonic_merge( + arr: cute.Tensor, + n: Optional[cutlass.Constexpr[int]] = None, + start: cutlass.Constexpr[int] = 0, + ascending: cutlass.Constexpr[bool] = True, +) -> None: + """Merge a bitonic sequence into a sorted sequence using iterative approach.""" + if const_expr(n is None): + n = cute.size(arr.shape) + if const_expr(n > 1): + num_levels = int(math.log2(n)) + assert n == 2**num_levels, "n must be a power of 2" + # This one must be range_constexpr otherwise it's very slow for n = 128 + for level in cutlass.range_constexpr(num_levels): + length = n >> level # n // (2^level) + step = length // 2 + for i in cutlass.range(n // length, unroll_full=True): + start_i = start + i * length + for j in cutlass.range(step, unroll_full=True): + compare_and_swap(arr, start_i + j, start_i + j + step, ascending) + + +@cute.jit +def bitonic_sort( + arr: cute.Tensor, + n: Optional[cutlass.Constexpr[int]] = None, + start: cutlass.Constexpr[int] = 0, + ascending: cutlass.Constexpr[bool] = True, +) -> None: + """ + Bitonic sort for small arrays of size N (power of 2, N <= 128). + + Args: + arr: Array to sort + n: Size of array (must be power of 2 and <= 128) + start: Starting index (default 0) + ascending: Sort in ascending order (default True) + """ + if const_expr(n is None): + n = cute.size(arr.shape) + assert n <= 128 + if const_expr(n > 1): + if const_expr(n in [2, 4, 8, 16, 32, 64]): + optimal_sort(arr, n, start, ascending) + else: # Fall back to bitonic sort + assert n % 2 == 0 + # Sort first half in ascending order + bitonic_sort(arr, n // 2, start, True) + # Sort second half in descending order + bitonic_sort(arr, n // 2, start + n // 2, False) + # Merge the whole sequence + bitonic_merge(arr, n, start, ascending) + + +@cute.jit +def bitonic_topk_merge( + arr0: cute.Tensor, + arr1: cute.Tensor, + k: Optional[cutlass.Constexpr[int]] = None, + start0: cutlass.Constexpr[int] = 0, + start1: cutlass.Constexpr[int] = 0, + ascending: cutlass.Constexpr[bool] = False, +) -> None: + if const_expr(k is None): + k = cute.size(arr0.shape) + if const_expr(arr0.element_type == Float32): + minmax_fn = utils.fmin if ascending else cute.arch.fmax + else: + minmax_fn = min if ascending else max + # Write the top k elements to the first half of the array + for i in cutlass.range(k, unroll_full=True): + arr0[start0 + i] = minmax_fn(arr0[start0 + i], arr1[start1 + k - 1 - i]) + # Now the 1st half is bitonic, we just need to merge it + bitonic_merge(arr0, k, start0, ascending) + + +@cute.jit +def bitonic_topk( + arr: cute.Tensor, + k: cutlass.Constexpr[int], + ascending: cutlass.Constexpr[bool] = False, + warp_width: cutlass.Constexpr[int] = cute.arch.WARP_SIZE, +) -> cute.Tensor: + """ + Bitonic top-k for small arrays of size N (power of 2, N <= 128). 
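+    The array is processed in n // k blocks of k values: each block is sorted with
+    bitonic_sort and folded into the running top-k via bitonic_topk_merge, and the
+    per-lane results are then combined across the warp with butterfly shuffles.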
+ + Args: + arr: Array to sort + k: must be power of 2 and <= 128 + ascending: Sort in ascending order (default False) + """ + assert arr.element_type in [Float32, Int32] + n = cute.size(arr.shape) + assert k == 1 << int(math.log2(k)), "k must be a power of 2" + assert n % k == 0, "n must be divisible by k" + topk_vals = cute.make_fragment(k, arr.element_type) + for v in cutlass.range(k, unroll_full=True): + topk_vals[v] = arr[v] + bitonic_sort(topk_vals, ascending=ascending) + for i in cutlass.range(1, n // k, unroll_full=True): + other_vals = cute.make_fragment(k, arr.element_type) + for v in cutlass.range(k, unroll_full=True): + other_vals[v] = arr[i * k + v] + bitonic_sort(other_vals, ascending=ascending) + # Merge 2 sorted top-k sequences to get a new top-k sequence + bitonic_topk_merge(topk_vals, other_vals, ascending=ascending) + # TODO: this is not efficient for large k (e.g. >= 16) since threads in the same warps + # do duplicate work. + for i in cutlass.range(int(math.log2(warp_width)), unroll_full=True): + other_vals = cute.make_fragment(k, arr.element_type) + for v in cutlass.range(k, unroll_full=True): + other_vals[v] = cute.arch.shuffle_sync_bfly(topk_vals[v], offset=1 << i) + bitonic_topk_merge(topk_vals, other_vals, ascending=ascending) + return topk_vals diff --git a/sonic-moe/torch-ext/sonicmoe/quack/sort/generate_sorting_networks.py b/sonic-moe/torch-ext/sonicmoe/quack/sort/generate_sorting_networks.py new file mode 100644 index 00000000..25d10151 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/sort/generate_sorting_networks.py @@ -0,0 +1,326 @@ +#!/usr/bin/env python3 +""" +Generate optimized sorting network code from the optimal sorting network data. +Based on data from: https://bertdobbelaere.github.io/sorting_networks.html + +This script generates CUTE DSL functions for optimal sorting networks of various sizes. 
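+
+Typical usage (regenerates sorting_networks.py next to this script):
+
+    python generate_sorting_networks.py --max-size 64 --stats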
+""" + +import argparse +import os +import re +from typing import List, Tuple, Dict + +# Network strings from bertdobbelaere.github.io/sorting_networks.html +# Copy-paste network strings here, then run initialize_networks() to parse them +NETWORK_STRINGS = { + # Size 2: 1 CE, depth 1 + 2: """ +[(0,1)] + """, + # Size 4: 5 CEs, depth 3 + 4: """ +[(0,2),(1,3)] +[(0,1),(2,3)] +[(1,2)] + """, + # Size 8: 19 CEs, depth 6 + 8: """ +[(0,2),(1,3),(4,6),(5,7)] +[(0,4),(1,5),(2,6),(3,7)] +[(0,1),(2,3),(4,5),(6,7)] +[(2,4),(3,5)] +[(1,4),(3,6)] +[(1,2),(3,4),(5,6)] + """, + # Size 16: 60 CEs, depth 10 + 16: """ +[(0,13),(1,12),(2,15),(3,14),(4,8),(5,6),(7,11),(9,10)] +[(0,5),(1,7),(2,9),(3,4),(6,13),(8,14),(10,15),(11,12)] +[(0,1),(2,3),(4,5),(6,8),(7,9),(10,11),(12,13),(14,15)] +[(0,2),(1,3),(4,10),(5,11),(6,7),(8,9),(12,14),(13,15)] +[(1,2),(3,12),(4,6),(5,7),(8,10),(9,11),(13,14)] +[(1,4),(2,6),(5,8),(7,10),(9,13),(11,14)] +[(2,4),(3,6),(9,12),(11,13)] +[(3,5),(6,8),(7,9),(10,12)] +[(3,4),(5,6),(7,8),(9,10),(11,12)] +[(6,7),(8,9)] + """, + # Size 32: 185 CEs, depth 14 + 32: """ +[(0,1),(2,3),(4,5),(6,7),(8,9),(10,11),(12,13),(14,15),(16,17),(18,19),(20,21),(22,23),(24,25),(26,27),(28,29),(30,31)] +[(0,2),(1,3),(4,6),(5,7),(8,10),(9,11),(12,14),(13,15),(16,18),(17,19),(20,22),(21,23),(24,26),(25,27),(28,30),(29,31)] +[(0,4),(1,5),(2,6),(3,7),(8,12),(9,13),(10,14),(11,15),(16,20),(17,21),(18,22),(19,23),(24,28),(25,29),(26,30),(27,31)] +[(0,8),(1,9),(2,10),(3,11),(4,12),(5,13),(6,14),(7,15),(16,24),(17,25),(18,26),(19,27),(20,28),(21,29),(22,30),(23,31)] +[(0,16),(1,8),(2,4),(3,12),(5,10),(6,9),(7,14),(11,13),(15,31),(17,24),(18,20),(19,28),(21,26),(22,25),(23,30),(27,29)] +[(1,2),(3,5),(4,8),(6,22),(7,11),(9,25),(10,12),(13,14),(17,18),(19,21),(20,24),(23,27),(26,28),(29,30)] +[(1,17),(2,18),(3,19),(4,20),(5,10),(7,23),(8,24),(11,27),(12,28),(13,29),(14,30),(21,26)] +[(3,17),(4,16),(5,21),(6,18),(7,9),(8,20),(10,26),(11,23),(13,25),(14,28),(15,27),(22,24)] +[(1,4),(3,8),(5,16),(7,17),(9,21),(10,22),(11,19),(12,20),(14,24),(15,26),(23,28),(27,30)] +[(2,5),(7,8),(9,18),(11,17),(12,16),(13,22),(14,20),(15,19),(23,24),(26,29)] +[(2,4),(6,12),(9,16),(10,11),(13,17),(14,18),(15,22),(19,25),(20,21),(27,29)] +[(5,6),(8,12),(9,10),(11,13),(14,16),(15,17),(18,20),(19,23),(21,22),(25,26)] +[(3,5),(6,7),(8,9),(10,12),(11,14),(13,16),(15,18),(17,20),(19,21),(22,23),(24,25),(26,28)] +[(3,4),(5,6),(7,8),(9,10),(11,12),(13,14),(15,16),(17,18),(19,20),(21,22),(23,24),(25,26),(27,28)] + """, + # Size 64: 512 CEs, depth 21 + 64: """ +[(0,2),(1,3),(4,6),(5,7),(8,10),(9,11),(12,14),(13,15),(16,18),(17,19),(20,22),(21,23),(24,26),(25,27),(28,30),(29,31),(32,34),(33,35),(36,38),(37,39),(40,42),(41,43),(44,46),(45,47),(48,50),(49,51),(52,54),(53,55),(56,58),(57,59),(60,62),(61,63)] +[(0,1),(2,3),(4,5),(6,7),(8,9),(10,11),(12,13),(14,15),(16,17),(18,19),(20,21),(22,23),(24,25),(26,27),(28,29),(30,31),(32,33),(34,35),(36,37),(38,39),(40,41),(42,43),(44,45),(46,47),(48,49),(50,51),(52,53),(54,55),(56,57),(58,59),(60,61),(62,63)] +[(0,52),(1,2),(3,55),(4,48),(5,6),(7,51),(8,60),(9,10),(11,63),(12,56),(13,14),(15,59),(16,32),(17,18),(19,35),(20,24),(21,22),(23,27),(25,26),(28,44),(29,30),(31,47),(33,34),(36,40),(37,38),(39,43),(41,42),(45,46),(49,50),(53,54),(57,58),(61,62)] +[(0,20),(1,53),(2,54),(3,23),(4,28),(5,49),(6,50),(7,31),(8,36),(9,61),(10,62),(11,39),(12,16),(13,57),(14,58),(15,19),(17,33),(18,34),(21,25),(22,26),(24,52),(27,55),(29,45),(30,46),(32,56),(35,59),(37,41),(38,42),(40,60),(43,63),(44,48),(47,51)] 
+[(0,4),(1,21),(2,22),(3,7),(5,29),(6,30),(8,12),(9,37),(10,38),(11,15),(13,17),(14,18),(16,20),(19,23),(24,32),(25,53),(26,54),(27,35),(28,36),(31,39),(33,57),(34,58),(40,44),(41,61),(42,62),(43,47),(45,49),(46,50),(48,52),(51,55),(56,60),(59,63)] +[(0,8),(1,5),(2,6),(3,11),(4,12),(7,15),(9,13),(10,14),(16,40),(17,21),(18,22),(19,43),(20,44),(23,47),(24,28),(25,33),(26,34),(27,31),(29,37),(30,38),(32,36),(35,39),(41,45),(42,46),(48,56),(49,53),(50,54),(51,59),(52,60),(55,63),(57,61),(58,62)] +[(1,9),(2,10),(4,8),(5,13),(6,14),(7,11),(12,48),(15,51),(16,24),(17,41),(18,42),(19,27),(20,28),(21,45),(22,46),(23,31),(25,29),(26,30),(32,40),(33,37),(34,38),(35,43),(36,44),(39,47),(49,57),(50,58),(52,56),(53,61),(54,62),(55,59)] +[(4,16),(5,9),(6,10),(7,19),(8,24),(11,27),(13,49),(14,50),(17,25),(18,26),(20,32),(21,29),(22,30),(23,35),(28,40),(31,43),(33,41),(34,42),(36,52),(37,45),(38,46),(39,55),(44,56),(47,59),(53,57),(54,58)] +[(1,4),(5,17),(6,18),(8,16),(9,25),(10,26),(11,19),(12,24),(15,27),(21,33),(22,34),(29,41),(30,42),(36,48),(37,53),(38,54),(39,51),(44,52),(45,57),(46,58),(47,55),(59,62)] +[(2,8),(9,17),(10,18),(12,20),(13,25),(14,26),(15,23),(24,32),(27,35),(28,36),(31,39),(37,49),(38,50),(40,48),(43,51),(45,53),(46,54),(55,61)] +[(2,4),(12,16),(13,21),(14,22),(15,19),(20,24),(23,27),(25,33),(26,34),(28,32),(29,37),(30,38),(31,35),(36,40),(39,43),(41,49),(42,50),(44,48),(47,51),(59,61)] +[(4,16),(5,20),(10,40),(13,17),(14,18),(21,25),(22,26),(23,53),(24,28),(27,31),(29,33),(30,34),(32,36),(35,39),(37,41),(38,42),(43,58),(45,49),(46,50),(47,59)] +[(3,17),(6,36),(7,21),(8,32),(9,24),(11,41),(13,28),(14,44),(15,45),(18,48),(19,49),(22,52),(25,29),(26,30),(27,57),(31,55),(33,37),(34,38),(35,50),(39,54),(42,56),(46,60)] +[(6,20),(8,16),(10,24),(11,25),(14,28),(15,29),(17,33),(18,32),(21,37),(22,36),(26,42),(27,41),(30,46),(31,45),(34,48),(35,49),(38,52),(39,53),(43,57),(47,55)] +[(3,18),(5,8),(6,12),(7,22),(15,21),(17,32),(19,33),(23,37),(26,40),(30,44),(31,46),(41,56),(42,48),(45,60),(51,57),(55,58)] +[(3,16),(7,20),(11,26),(18,24),(19,25),(22,28),(23,29),(27,33),(30,36),(34,40),(35,41),(37,52),(38,44),(39,45),(43,56),(47,60)] +[(3,9),(7,13),(10,16),(11,17),(14,20),(15,30),(19,34),(21,36),(23,38),(25,40),(26,32),(27,42),(29,44),(31,37),(33,48),(43,49),(46,52),(47,53),(50,56),(54,60)] +[(3,8),(7,10),(9,12),(11,18),(13,14),(15,24),(17,22),(19,28),(21,26),(23,25),(27,34),(29,36),(30,32),(31,33),(35,44),(37,42),(38,40),(39,48),(41,46),(45,52),(49,50),(51,54),(53,56),(55,60)] +[(3,6),(7,12),(11,16),(15,17),(18,20),(19,24),(21,22),(23,30),(25,32),(26,28),(27,29),(31,38),(33,40),(34,36),(35,37),(39,44),(41,42),(43,45),(46,48),(47,52),(51,56),(57,60)] +[(3,5),(6,8),(7,9),(10,12),(11,13),(14,16),(15,18),(17,20),(19,21),(22,24),(23,26),(25,28),(27,30),(29,32),(31,34),(33,36),(35,38),(37,40),(39,41),(42,44),(43,46),(45,48),(47,49),(50,52),(51,53),(54,56),(55,57),(58,60)] +[(3,4),(7,8),(11,12),(13,14),(15,16),(17,18),(19,20),(21,22),(23,24),(25,26),(27,28),(29,30),(31,32),(33,34),(35,36),(37,38),(39,40),(41,42),(43,44),(45,46),(47,48),(49,50),(51,52),(55,56),(59,60)] + """, +} + +# This will be populated by initialize_networks() +OPTIMAL_NETWORKS: Dict[int, Tuple[int, int, List[List[Tuple[int, int]]]]] = {} + + +def parse_network_string(network_str: str) -> List[List[Tuple[int, int]]]: + """ + Parse a sorting network string from bertdobbelaere.github.io format. 
+ + Examples: + Input: "[(0,2),(1,3)], [(0,1),(2,3)], [(1,2)]" + Output: [[(0, 2), (1, 3)], [(0, 1), (2, 3)], [(1, 2)]] + + Input: "[(0,1)], [(1,2)], [(0,1)]" + Output: [[(0, 1)], [(1, 2)], [(0, 1)]] + """ + # Remove whitespace and split by '], [' + network_str = network_str.strip() + if not network_str: + return [] + + # Split into layer strings + layer_pattern = r"\[((?:\(\d+,\d+\)(?:,\(\d+,\d+\))*)?)\]" + layers = [] + + for match in re.finditer(layer_pattern, network_str): + layer_str = match.group(1) + if not layer_str.strip(): + layers.append([]) + continue + + # Parse comparisons in this layer: (i,j), (k,l), ... + comparisons = [] + comp_pattern = r"\((\d+),(\d+)\)" + + for comp_match in re.finditer(comp_pattern, layer_str): + i, j = int(comp_match.group(1)), int(comp_match.group(2)) + comparisons.append((i, j)) + + layers.append(comparisons) + + return layers + + +def calculate_network_stats(layers: List[List[Tuple[int, int]]]) -> Tuple[int, int, int]: + """Calculate depth, total comparisons, and max index from network layers.""" + depth = len(layers) + total_comparisons = sum(len(layer) for layer in layers) + + # Find maximum index to determine network size + max_index = 0 + for layer in layers: + for i, j in layer: + max_index = max(max_index, i, j) + + network_size = max_index + 1 # Since indices are 0-based + return depth, total_comparisons, network_size + + +def add_network_from_string(size: int, network_str: str, description: str = ""): + """ + Add a network from a string representation to the OPTIMAL_NETWORKS dictionary. + + Args: + size: Size of the network (number of elements) + network_str: Network string in bertdobbelaere.github.io format + description: Optional description for debugging + """ + try: + layers = parse_network_string(network_str) + depth, comparisons, detected_size = calculate_network_stats(layers) + + if detected_size != size: + print(f"Warning: Network size mismatch! 
Expected {size}, detected {detected_size}") + print(f"Network string: {network_str[:100]}...") + return False + + OPTIMAL_NETWORKS[size] = (depth, comparisons, layers) + + if description: + print(f"Added network for size {size}: {description}") + print(f" Depth: {depth}, Comparisons: {comparisons}") + return True + + except Exception as e: + print(f"Error parsing network for size {size}: {e}") + print(f"Network string: {network_str[:100]}...") + return False + + +def generate_networks_dict( + networks_data: Dict[int, Tuple[int, int, List[List[Tuple[int, int]]]]] +) -> str: + """Generate the global networks dictionary.""" + lines = ["networks = {"] + + for size, (depth, num_comparisons, layers) in sorted(networks_data.items()): + # Format the network with proper indentation and newlines + network_lines = [] + for i, layer in enumerate(layers): + if i == 0: + network_lines.append(f" {layer}") + else: + network_lines.append(f",\n {layer}") + + if len(layers) == 1: + network_str = f"[{network_lines[0].strip()}]" + else: + network_str = "[\n" + "".join(network_lines) + "\n ]" + + lines.append(f" # Size {size}: {num_comparisons} CEs, depth {depth}") + lines.append(f" {size}: {network_str},") + lines.append("") + + lines.append("}") + return "\n".join(lines) + + +def generate_optimal_sort_function() -> str: + """Generate the single optimal_sort function that looks up networks by size.""" + return """@cute.jit +def optimal_sort( + arr: cute.Tensor, + n: cutlass.Constexpr[int], + start: cutlass.Constexpr[int] = 0, + ascending: cutlass.Constexpr[bool] = True +) -> None: + \"\"\" + Optimal sorting network dispatcher. + + Args: + arr: Array to sort + n: Size of array (must be power of 2 and available in networks) + start: Starting index (default 0) + ascending: Sort in ascending order (default True) + + Source: https://bertdobbelaere.github.io/sorting_networks.html + \"\"\" + assert n in networks + for level in networks[n]: + for i, j in level: + compare_and_swap(arr, start + i, start + j, ascending) +""" + + +def generate_sorting_networks_file(max_size: int = 64): + """Generate a complete sorting networks file with optimal networks up to max_size.""" + + output_file = os.path.join(os.path.dirname(__file__), "sorting_networks.py") + + # Header + header = '''# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao. +""" +Optimal sorting networks generated from: https://bertdobbelaere.github.io/sorting_networks.html + +This file was auto-generated by quack/sort/generate_sorting_networks.py. Do not edit it directly. 
+""" + +# fmt: off +# ruff: noqa +# isort: skip_file + +import cutlass +import cutlass.cute as cute + +from .utils import compare_and_swap + + +''' + + # Generate networks dictionary and optimal_sort function + sizes = [n for n in range(2, max_size + 1) if n in OPTIMAL_NETWORKS] + networks_dict = generate_networks_dict(OPTIMAL_NETWORKS) + optimal_sort_func = generate_optimal_sort_function() + + # Combine everything + content = header + networks_dict + "\n\n\n" + optimal_sort_func + + with open(output_file, "w") as f: + f.write(content) + + print(f"Generated optimal sorting networks for sizes {sizes}") + print(f"Output written to: {output_file}") + return sizes + + +def initialize_networks(): + """Initialize the OPTIMAL_NETWORKS dictionary by parsing NETWORK_STRINGS.""" + global OPTIMAL_NETWORKS + OPTIMAL_NETWORKS.clear() + + for size, network_str in NETWORK_STRINGS.items(): + success = add_network_from_string(size, network_str, f"Size {size} optimal network") + if not success: + print(f"Warning: Failed to parse network for size {size}") + + +def main(): + parser = argparse.ArgumentParser( + description="Generate optimal sorting network code from bertdobbelaere.github.io data" + ) + parser.add_argument( + "--max-size", + "-m", + type=int, + default=64, + help="Maximum sorting network size to generate (default: 32)", + ) + parser.add_argument( + "--stats", "-s", action="store_true", help="Print statistics about the optimal networks" + ) + + args = parser.parse_args() + + # Initialize networks from strings + initialize_networks() + + if args.stats: + print("Optimal Sorting Network Statistics:") + print("Size\tDepth\tComparisons\tLayers") + print("-" * 35) + for n in sorted(OPTIMAL_NETWORKS.keys()): + if n <= args.max_size: + depth, comparisons, layers = OPTIMAL_NETWORKS[n] + print(f"{n}\t{depth}\t{comparisons}\t\t{len(layers)}") + + # Generate the sorting networks file + sizes = generate_sorting_networks_file(args.max_size) + + print(f"\nGenerated optimal sorting networks for {len(sizes)} sizes") + print(f"Total networks: {len(sizes)}") + print(f"Max network size: {max(sizes)}") + + +if __name__ == "__main__": + main() diff --git a/sonic-moe/torch-ext/sonicmoe/quack/sort/sorting_networks.py b/sonic-moe/torch-ext/sonicmoe/quack/sort/sorting_networks.py new file mode 100644 index 00000000..5390f4d7 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/sort/sorting_networks.py @@ -0,0 +1,120 @@ +# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao. +""" +Optimal sorting networks generated from: https://bertdobbelaere.github.io/sorting_networks.html + +This file was auto-generated by quack/sort/generate_sorting_networks.py. Do not edit it directly. 
+""" + +# fmt: off +# ruff: noqa +# isort: skip_file + +import cutlass +import cutlass.cute as cute + +from .utils import compare_and_swap + + +networks = { + # Size 2: 1 CEs, depth 1 + 2: [[(0, 1)]], + + # Size 4: 5 CEs, depth 3 + 4: [ + [(0, 2), (1, 3)], + [(0, 1), (2, 3)], + [(1, 2)] + ], + + # Size 8: 19 CEs, depth 6 + 8: [ + [(0, 2), (1, 3), (4, 6), (5, 7)], + [(0, 4), (1, 5), (2, 6), (3, 7)], + [(0, 1), (2, 3), (4, 5), (6, 7)], + [(2, 4), (3, 5)], + [(1, 4), (3, 6)], + [(1, 2), (3, 4), (5, 6)] + ], + + # Size 16: 60 CEs, depth 10 + 16: [ + [(0, 13), (1, 12), (2, 15), (3, 14), (4, 8), (5, 6), (7, 11), (9, 10)], + [(0, 5), (1, 7), (2, 9), (3, 4), (6, 13), (8, 14), (10, 15), (11, 12)], + [(0, 1), (2, 3), (4, 5), (6, 8), (7, 9), (10, 11), (12, 13), (14, 15)], + [(0, 2), (1, 3), (4, 10), (5, 11), (6, 7), (8, 9), (12, 14), (13, 15)], + [(1, 2), (3, 12), (4, 6), (5, 7), (8, 10), (9, 11), (13, 14)], + [(1, 4), (2, 6), (5, 8), (7, 10), (9, 13), (11, 14)], + [(2, 4), (3, 6), (9, 12), (11, 13)], + [(3, 5), (6, 8), (7, 9), (10, 12)], + [(3, 4), (5, 6), (7, 8), (9, 10), (11, 12)], + [(6, 7), (8, 9)] + ], + + # Size 32: 185 CEs, depth 14 + 32: [ + [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (12, 13), (14, 15), (16, 17), (18, 19), (20, 21), (22, 23), (24, 25), (26, 27), (28, 29), (30, 31)], + [(0, 2), (1, 3), (4, 6), (5, 7), (8, 10), (9, 11), (12, 14), (13, 15), (16, 18), (17, 19), (20, 22), (21, 23), (24, 26), (25, 27), (28, 30), (29, 31)], + [(0, 4), (1, 5), (2, 6), (3, 7), (8, 12), (9, 13), (10, 14), (11, 15), (16, 20), (17, 21), (18, 22), (19, 23), (24, 28), (25, 29), (26, 30), (27, 31)], + [(0, 8), (1, 9), (2, 10), (3, 11), (4, 12), (5, 13), (6, 14), (7, 15), (16, 24), (17, 25), (18, 26), (19, 27), (20, 28), (21, 29), (22, 30), (23, 31)], + [(0, 16), (1, 8), (2, 4), (3, 12), (5, 10), (6, 9), (7, 14), (11, 13), (15, 31), (17, 24), (18, 20), (19, 28), (21, 26), (22, 25), (23, 30), (27, 29)], + [(1, 2), (3, 5), (4, 8), (6, 22), (7, 11), (9, 25), (10, 12), (13, 14), (17, 18), (19, 21), (20, 24), (23, 27), (26, 28), (29, 30)], + [(1, 17), (2, 18), (3, 19), (4, 20), (5, 10), (7, 23), (8, 24), (11, 27), (12, 28), (13, 29), (14, 30), (21, 26)], + [(3, 17), (4, 16), (5, 21), (6, 18), (7, 9), (8, 20), (10, 26), (11, 23), (13, 25), (14, 28), (15, 27), (22, 24)], + [(1, 4), (3, 8), (5, 16), (7, 17), (9, 21), (10, 22), (11, 19), (12, 20), (14, 24), (15, 26), (23, 28), (27, 30)], + [(2, 5), (7, 8), (9, 18), (11, 17), (12, 16), (13, 22), (14, 20), (15, 19), (23, 24), (26, 29)], + [(2, 4), (6, 12), (9, 16), (10, 11), (13, 17), (14, 18), (15, 22), (19, 25), (20, 21), (27, 29)], + [(5, 6), (8, 12), (9, 10), (11, 13), (14, 16), (15, 17), (18, 20), (19, 23), (21, 22), (25, 26)], + [(3, 5), (6, 7), (8, 9), (10, 12), (11, 14), (13, 16), (15, 18), (17, 20), (19, 21), (22, 23), (24, 25), (26, 28)], + [(3, 4), (5, 6), (7, 8), (9, 10), (11, 12), (13, 14), (15, 16), (17, 18), (19, 20), (21, 22), (23, 24), (25, 26), (27, 28)] + ], + + # Size 64: 521 CEs, depth 21 + 64: [ + [(0, 2), (1, 3), (4, 6), (5, 7), (8, 10), (9, 11), (12, 14), (13, 15), (16, 18), (17, 19), (20, 22), (21, 23), (24, 26), (25, 27), (28, 30), (29, 31), (32, 34), (33, 35), (36, 38), (37, 39), (40, 42), (41, 43), (44, 46), (45, 47), (48, 50), (49, 51), (52, 54), (53, 55), (56, 58), (57, 59), (60, 62), (61, 63)], + [(0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11), (12, 13), (14, 15), (16, 17), (18, 19), (20, 21), (22, 23), (24, 25), (26, 27), (28, 29), (30, 31), (32, 33), (34, 35), (36, 37), (38, 39), (40, 41), (42, 43), (44, 45), (46, 47), (48, 
49), (50, 51), (52, 53), (54, 55), (56, 57), (58, 59), (60, 61), (62, 63)], + [(0, 52), (1, 2), (3, 55), (4, 48), (5, 6), (7, 51), (8, 60), (9, 10), (11, 63), (12, 56), (13, 14), (15, 59), (16, 32), (17, 18), (19, 35), (20, 24), (21, 22), (23, 27), (25, 26), (28, 44), (29, 30), (31, 47), (33, 34), (36, 40), (37, 38), (39, 43), (41, 42), (45, 46), (49, 50), (53, 54), (57, 58), (61, 62)], + [(0, 20), (1, 53), (2, 54), (3, 23), (4, 28), (5, 49), (6, 50), (7, 31), (8, 36), (9, 61), (10, 62), (11, 39), (12, 16), (13, 57), (14, 58), (15, 19), (17, 33), (18, 34), (21, 25), (22, 26), (24, 52), (27, 55), (29, 45), (30, 46), (32, 56), (35, 59), (37, 41), (38, 42), (40, 60), (43, 63), (44, 48), (47, 51)], + [(0, 4), (1, 21), (2, 22), (3, 7), (5, 29), (6, 30), (8, 12), (9, 37), (10, 38), (11, 15), (13, 17), (14, 18), (16, 20), (19, 23), (24, 32), (25, 53), (26, 54), (27, 35), (28, 36), (31, 39), (33, 57), (34, 58), (40, 44), (41, 61), (42, 62), (43, 47), (45, 49), (46, 50), (48, 52), (51, 55), (56, 60), (59, 63)], + [(0, 8), (1, 5), (2, 6), (3, 11), (4, 12), (7, 15), (9, 13), (10, 14), (16, 40), (17, 21), (18, 22), (19, 43), (20, 44), (23, 47), (24, 28), (25, 33), (26, 34), (27, 31), (29, 37), (30, 38), (32, 36), (35, 39), (41, 45), (42, 46), (48, 56), (49, 53), (50, 54), (51, 59), (52, 60), (55, 63), (57, 61), (58, 62)], + [(1, 9), (2, 10), (4, 8), (5, 13), (6, 14), (7, 11), (12, 48), (15, 51), (16, 24), (17, 41), (18, 42), (19, 27), (20, 28), (21, 45), (22, 46), (23, 31), (25, 29), (26, 30), (32, 40), (33, 37), (34, 38), (35, 43), (36, 44), (39, 47), (49, 57), (50, 58), (52, 56), (53, 61), (54, 62), (55, 59)], + [(4, 16), (5, 9), (6, 10), (7, 19), (8, 24), (11, 27), (13, 49), (14, 50), (17, 25), (18, 26), (20, 32), (21, 29), (22, 30), (23, 35), (28, 40), (31, 43), (33, 41), (34, 42), (36, 52), (37, 45), (38, 46), (39, 55), (44, 56), (47, 59), (53, 57), (54, 58)], + [(1, 4), (5, 17), (6, 18), (8, 16), (9, 25), (10, 26), (11, 19), (12, 24), (15, 27), (21, 33), (22, 34), (29, 41), (30, 42), (36, 48), (37, 53), (38, 54), (39, 51), (44, 52), (45, 57), (46, 58), (47, 55), (59, 62)], + [(2, 8), (9, 17), (10, 18), (12, 20), (13, 25), (14, 26), (15, 23), (24, 32), (27, 35), (28, 36), (31, 39), (37, 49), (38, 50), (40, 48), (43, 51), (45, 53), (46, 54), (55, 61)], + [(2, 4), (12, 16), (13, 21), (14, 22), (15, 19), (20, 24), (23, 27), (25, 33), (26, 34), (28, 32), (29, 37), (30, 38), (31, 35), (36, 40), (39, 43), (41, 49), (42, 50), (44, 48), (47, 51), (59, 61)], + [(4, 16), (5, 20), (10, 40), (13, 17), (14, 18), (21, 25), (22, 26), (23, 53), (24, 28), (27, 31), (29, 33), (30, 34), (32, 36), (35, 39), (37, 41), (38, 42), (43, 58), (45, 49), (46, 50), (47, 59)], + [(3, 17), (6, 36), (7, 21), (8, 32), (9, 24), (11, 41), (13, 28), (14, 44), (15, 45), (18, 48), (19, 49), (22, 52), (25, 29), (26, 30), (27, 57), (31, 55), (33, 37), (34, 38), (35, 50), (39, 54), (42, 56), (46, 60)], + [(6, 20), (8, 16), (10, 24), (11, 25), (14, 28), (15, 29), (17, 33), (18, 32), (21, 37), (22, 36), (26, 42), (27, 41), (30, 46), (31, 45), (34, 48), (35, 49), (38, 52), (39, 53), (43, 57), (47, 55)], + [(3, 18), (5, 8), (6, 12), (7, 22), (15, 21), (17, 32), (19, 33), (23, 37), (26, 40), (30, 44), (31, 46), (41, 56), (42, 48), (45, 60), (51, 57), (55, 58)], + [(3, 16), (7, 20), (11, 26), (18, 24), (19, 25), (22, 28), (23, 29), (27, 33), (30, 36), (34, 40), (35, 41), (37, 52), (38, 44), (39, 45), (43, 56), (47, 60)], + [(3, 9), (7, 13), (10, 16), (11, 17), (14, 20), (15, 30), (19, 34), (21, 36), (23, 38), (25, 40), (26, 32), (27, 42), 
(29, 44), (31, 37), (33, 48), (43, 49), (46, 52), (47, 53), (50, 56), (54, 60)], + [(3, 8), (7, 10), (9, 12), (11, 18), (13, 14), (15, 24), (17, 22), (19, 28), (21, 26), (23, 25), (27, 34), (29, 36), (30, 32), (31, 33), (35, 44), (37, 42), (38, 40), (39, 48), (41, 46), (45, 52), (49, 50), (51, 54), (53, 56), (55, 60)], + [(3, 6), (7, 12), (11, 16), (15, 17), (18, 20), (19, 24), (21, 22), (23, 30), (25, 32), (26, 28), (27, 29), (31, 38), (33, 40), (34, 36), (35, 37), (39, 44), (41, 42), (43, 45), (46, 48), (47, 52), (51, 56), (57, 60)], + [(3, 5), (6, 8), (7, 9), (10, 12), (11, 13), (14, 16), (15, 18), (17, 20), (19, 21), (22, 24), (23, 26), (25, 28), (27, 30), (29, 32), (31, 34), (33, 36), (35, 38), (37, 40), (39, 41), (42, 44), (43, 46), (45, 48), (47, 49), (50, 52), (51, 53), (54, 56), (55, 57), (58, 60)], + [(3, 4), (7, 8), (11, 12), (13, 14), (15, 16), (17, 18), (19, 20), (21, 22), (23, 24), (25, 26), (27, 28), (29, 30), (31, 32), (33, 34), (35, 36), (37, 38), (39, 40), (41, 42), (43, 44), (45, 46), (47, 48), (49, 50), (51, 52), (55, 56), (59, 60)] + ], + +} + + +@cute.jit +def optimal_sort( + arr: cute.Tensor, + n: cutlass.Constexpr[int], + start: cutlass.Constexpr[int] = 0, + ascending: cutlass.Constexpr[bool] = True +) -> None: + """ + Optimal sorting network dispatcher. + + Args: + arr: Array to sort + n: Size of array (must be power of 2 and available in networks) + start: Starting index (default 0) + ascending: Sort in ascending order (default True) + + Source: https://bertdobbelaere.github.io/sorting_networks.html + """ + assert n in networks + for level in networks[n]: + for i, j in level: + compare_and_swap(arr, start + i, start + j, ascending) diff --git a/sonic-moe/torch-ext/sonicmoe/quack/sort/utils.py b/sonic-moe/torch-ext/sonicmoe/quack/sort/utils.py new file mode 100644 index 00000000..0237e88a --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/sort/utils.py @@ -0,0 +1,31 @@ +import cutlass.cute as cute +from cutlass import Float32, const_expr + +from .. import utils + + +@cute.jit +def compare_and_swap( + arr: cute.Tensor, i: int, j: int, ascending: bool = True, use_selection: bool = False +) -> None: + """Compare and swap elements at indices i and j in ascending or descending order.""" + if const_expr(use_selection): + a, b = arr[i], arr[j] + if (a > b) ^ (not ascending): + arr[i] = b + arr[j] = a + # if const_expr(ascending): + # if a > b: + # arr[i] = b + # arr[j] = a + # else: + # if a < b: + # arr[i] = b + # arr[j] = a + else: + min_fn = min if const_expr(arr.element_type != Float32) else utils.fmin + max_fn = max if const_expr(arr.element_type != Float32) else cute.arch.fmax + if const_expr(ascending): + arr[i], arr[j] = min_fn(arr[i], arr[j]), max_fn(arr[i], arr[j]) + else: + arr[i], arr[j] = max_fn(arr[i], arr[j]), min_fn(arr[i], arr[j]) diff --git a/sonic-moe/torch-ext/sonicmoe/quack/tensormap_manager.py b/sonic-moe/torch-ext/sonicmoe/quack/tensormap_manager.py new file mode 100644 index 00000000..9241f8cd --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/tensormap_manager.py @@ -0,0 +1,115 @@ +# Copyright (c) 2025, Tri Dao. 
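The `optimal_sort` dispatcher and `compare_and_swap` primitive above (in `sort/sorting_networks.py` and `sort/utils.py`) amount to walking the layers of the `networks` table and exchanging each out-of-order pair. A minimal pure-Python sketch of that dispatch, using the size-4 entry from the table; the scalar `compare_and_swap` here is a stand-in for the CuTe-DSL version, not the kernel itself:

```python
# Pure-Python sketch of the sorting-network dispatch (not the CuTe-DSL kernel).
def compare_and_swap(arr, i, j, ascending=True):
    # Same predicate as the CuTe version: swap when the pair is out of order
    # for the requested direction.
    if (arr[i] > arr[j]) ^ (not ascending):
        arr[i], arr[j] = arr[j], arr[i]

# Size-4 network from the `networks` table: 5 compare-exchanges, depth 3.
size4 = [[(0, 2), (1, 3)], [(0, 1), (2, 3)], [(1, 2)]]

def optimal_sort(arr, layers, start=0, ascending=True):
    # Comparisons within a layer are independent; in the generated kernel the
    # loops unroll at compile time because n is a Constexpr.
    for layer in layers:
        for i, j in layer:
            compare_and_swap(arr, start + i, start + j, ascending)

data = [3, 1, 4, 2]
optimal_sort(data, size4)
assert data == [1, 2, 3, 4]
```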
+ +from typing import Tuple +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass.cutlass_dsl import Boolean, const_expr, Int32 +from cutlass.utils import TensorMapUpdateMode, TensorMapManager +from cutlass._mlir.dialects import llvm + + +@dataclass(frozen=True) +class TensorMapManagerSm90(TensorMapManager): + """ + We have to subclass cutlass.utils.TensorMapManager bc it takes in warp_id and only + perform the operation if warp_id matches the current warp. + But for Hopper pingpong gemm we want to call it with warp_id 0 and 4. + So we take in a boolean `is_manager_warp` to determine whether to perform the operation or not. + """ + + @cute.jit + def init_tensormap_from_atom( + self, copy_atom: cute.CopyAtom, dst_ptr: cute.Pointer, is_manager_warp: Boolean + ) -> None: + if is_manager_warp: + with cute.arch.elect_one(): + cute.nvgpu.cpasync.copy_tensormap(copy_atom, dst_ptr) + cute.arch.sync_warp() + return + + @cute.jit + def update_tensormap( + self, + tensor_gmem: Tuple[cute.Tensor, ...], + tma_copy_atom: Tuple[cute.CopyAtom, ...], + tensormap_gmem_ptr: Tuple[cute.Pointer, ...], + is_manager_warp: Boolean, + tensormap_smem_ptr: Tuple[cute.Pointer, ...], + ) -> None: + # updates before touching tensormap in global memory + if is_manager_warp: + if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM): + for copy_atom, tensor, smem_ptr in zip( + tma_copy_atom, tensor_gmem, tensormap_smem_ptr + ): + cute.nvgpu.cpasync.update_tma_descriptor(copy_atom, tensor, smem_ptr) + # wait until it's safe to update tensormap in global memory + with cute.arch.elect_one(): + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + cute.arch.sync_warp() + # updates to tensormap in global memory + if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM): + for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, tensormap_smem_ptr): + cute.nvgpu.cpasync.cp_fence_tma_desc_release(gmem_ptr, smem_ptr) + else: + for copy_atom, tensor, gmem_ptr in zip( + tma_copy_atom, tensor_gmem, tensormap_gmem_ptr + ): + cute.nvgpu.cpasync.update_tma_descriptor(copy_atom, tensor, gmem_ptr) + cute.arch.sync_warp() + cute.nvgpu.cpasync.fence_tma_desc_release() + + @cute.jit + def update_tensormap_shape( + self, + tensormap_gmem_ptr: Tuple[cute.Pointer, ...], + is_manager_warp: Boolean, + tensormap_smem_ptr: Tuple[cute.Pointer, ...], + shapes: Tuple[Int32, ...], + orders: cutlass.Constexpr[Tuple[int, ...]], + ) -> None: + # updates before touching tensormap in global memory + if is_manager_warp: + if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM): + for smem_ptr, shape, order in zip(tensormap_smem_ptr, shapes, orders): + smem_ptr_i32 = smem_ptr.toint().ir_value() + llvm.inline_asm( + None, + [smem_ptr_i32, Int32(shape).ir_value(), Int32(order).ir_value()], + "{\n\t" + ".reg .b64 smem_ptr_i64;\n\t" + "cvt.u64.u32 smem_ptr_i64, $0;\n\t" + f"tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [smem_ptr_i64], {order}, $1;\n\t" + "}\n", + "r,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + # wait until it's safe to update tensormap in global memory + with cute.arch.elect_one(): + cute.arch.cp_async_bulk_commit_group() + cute.arch.cp_async_bulk_wait_group(0, read=True) + cute.arch.sync_warp() + # updates to tensormap in global memory + if const_expr(self.tensormap_update_mode == TensorMapUpdateMode.SMEM): + for gmem_ptr, smem_ptr in zip(tensormap_gmem_ptr, 
tensormap_smem_ptr): + cute.nvgpu.cpasync.cp_fence_tma_desc_release(gmem_ptr, smem_ptr) + else: + assert len(shapes) == len(orders) == len(tensormap_gmem_ptr) + for gmem_ptr, shape, order in zip(tensormap_gmem_ptr, shapes, orders): + gmem_ptr_i64 = gmem_ptr.toint().ir_value() + llvm.inline_asm( + None, + [gmem_ptr_i64, Int32(shape).ir_value(), Int32(order).ir_value()], + f"tensormap.replace.tile.global_dim.global.b1024.b32 [$0], {order}, $1;", + "l,r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + cute.arch.sync_warp() + cute.nvgpu.cpasync.fence_tma_desc_release() diff --git a/sonic-moe/torch-ext/sonicmoe/quack/tile_scheduler.py b/sonic-moe/torch-ext/sonicmoe/quack/tile_scheduler.py new file mode 100644 index 00000000..e14a146f --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/tile_scheduler.py @@ -0,0 +1,932 @@ +# Copyright (c) 2025, Tri Dao. + +from typing import Tuple, Optional +from dataclasses import dataclass +from enum import IntEnum + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Float32, Boolean, const_expr + +from . import utils as utils +from .fast_math import FastDivmod +from .pipeline import PipelineStateWAdvance +from .cute_dsl_utils import ArgumentsBase, ParamsBase + + +class RasterOrderOption(IntEnum): + AlongM = 0 + AlongN = 1 + Heuristic = 2 # Pick AlongM if tiles_n > tiles_m, else AlongN + + +class RasterOrder(IntEnum): + AlongM = 0 + AlongN = 1 + + +@cute.jit +def get_raster_order_from_option( + raster_order_option: RasterOrderOption, problem_shape_ncluster_mn: cute.Shape, group_size: Int32 +) -> RasterOrder: + raster_order = ( + RasterOrder.AlongM + if raster_order_option == RasterOrderOption.AlongM + else RasterOrder.AlongN + ) + if raster_order_option == RasterOrderOption.Heuristic: + problem_blocks_m = cute.round_up(problem_shape_ncluster_mn[0], group_size) + problem_blocks_n = cute.round_up(problem_shape_ncluster_mn[1], group_size) + raster_order = ( + RasterOrder.AlongM if problem_blocks_n > problem_blocks_m else RasterOrder.AlongN + ) + return raster_order + + +# Grouping arguments together that should be passed to __call__ +@dataclass +class TileSchedulerOptions(ArgumentsBase): + max_active_clusters: Int32 + raster_order: cutlass.Constexpr[RasterOrderOption] = RasterOrderOption.Heuristic + max_swizzle_size: Int32 = Int32(8) + tile_count_semaphore: Optional[cute.Pointer] = None + batch_idx_permute: Optional[cute.Tensor] = None + + +@dataclass +class TileSchedulerArguments(ArgumentsBase): + problem_shape_ntile_mnl: cute.Shape + raster_order: cutlass.Constexpr[RasterOrderOption] + group_size: Int32 + cluster_shape_mnk: cutlass.Constexpr[cute.Shape] + tile_count_semaphore: Optional[cute.Pointer] = None + batch_idx_permute: Optional[cute.Tensor] = None + is_persistent: cutlass.Constexpr[bool] = False + + +class TileScheduler: + @dataclass + class Params(ParamsBase): + problem_shape_ncluster_mnl: cute.Shape + raster_order: RasterOrder + num_clusters_per_problem_divmod: FastDivmod + num_groups_regular: Int32 + group_size_divmod: FastDivmod + group_size_tail_divmod: FastDivmod + num_clusters_in_group_divmod: FastDivmod + tile_count_semaphore: Optional[cute.Pointer] + batch_idx_permute: Optional[cute.Tensor] + cluster_shape_mn: cutlass.Constexpr[cute.Shape] + is_persistent: cutlass.Constexpr[bool] + + @staticmethod + @cute.jit + def create(args: TileSchedulerArguments, *, loc=None, ip=None) -> "TileScheduler.Params": + assert args.cluster_shape_mnk[2] == 1 + cluster_shape_mn = 
const_expr(cute.select(args.cluster_shape_mnk, mode=[0, 1])) + problem_shape_ntile_mn = cute.select(args.problem_shape_ntile_mnl, mode=[0, 1]) + problem_shape_ncluster_mn = cute.ceil_div(problem_shape_ntile_mn, cluster_shape_mn) + problem_shape_ncluster_mnl = problem_shape_ncluster_mn + ( + args.problem_shape_ntile_mnl[2], + ) + num_clusters_per_problem = cute.size(problem_shape_ncluster_mn) + raster_order = get_raster_order_from_option( + args.raster_order, problem_shape_ncluster_mn, args.group_size + ) + ncluster_fast = ( + problem_shape_ncluster_mn[0] + if raster_order == RasterOrder.AlongM + else problem_shape_ncluster_mn[1] + ) + ncluster_slow = ( + problem_shape_ncluster_mn[1] + if raster_order == RasterOrder.AlongM + else problem_shape_ncluster_mn[0] + ) + group_size = min(args.group_size, ncluster_fast) + group_size_tail = ncluster_fast % group_size + num_groups_regular = ncluster_fast // group_size + num_clusters_in_group = group_size * ncluster_slow + return TileScheduler.Params( + problem_shape_ncluster_mnl, + raster_order, + FastDivmod.create(num_clusters_per_problem), + num_groups_regular, + FastDivmod.create(group_size), + # Don't divide by 0 + FastDivmod.create(group_size_tail if group_size_tail > 0 else 1), + FastDivmod.create(num_clusters_in_group), + args.tile_count_semaphore if const_expr(args.is_persistent) else None, + args.batch_idx_permute, + cluster_shape_mn, + args.is_persistent, + ) + + def __init__( + self, + current_work_linear_idx: Int32, + num_tiles_executed: Int32, + tile_count: Optional[cute.Tensor], + scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync], + pipeline_state: PipelineStateWAdvance, + params: Params, + *, + loc=None, + ip=None, + ): + self._current_work_linear_idx = current_work_linear_idx + self.num_tiles_executed = num_tiles_executed + self._tile_count = tile_count + self._scheduler_pipeline = scheduler_pipeline + self._pipeline_state = pipeline_state + self.params = params + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + return TileScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + @cute.jit + def create( + params: Params, + tile_count: Optional[cute.Tensor] = None, + scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync] = None, + is_scheduler_warp: bool | Boolean = False, + *, + loc=None, + ip=None, + ) -> "TileScheduler": + """is_scheduler_warp should only be true for one warp in the whole cluster""" + stages = 0 + if const_expr(not params.is_persistent): + cidx, cidy, _ = cute.arch.cluster_idx() + cdimx, _, _ = cute.arch.cluster_dim() + cluster_id = cidx + cidy * cdimx + current_work_linear_idx = Int32(cluster_id) + else: + _, _, bidz = cute.arch.block_idx() + current_work_linear_idx = Int32(bidz) + if const_expr(params.tile_count_semaphore is not None): + assert tile_count is not None + assert scheduler_pipeline is not None + stages = const_expr(cute.size(tile_count)) + return TileScheduler( + current_work_linear_idx, + Int32(0), # num_tiles_executed + tile_count, + scheduler_pipeline, + PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)), + params, + loc=loc, + ip=ip, + ) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + max_active_clusters: Int32, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + num_ctas_mnl = tuple( + x * y for x, y in zip(params.problem_shape_ncluster_mnl, params.cluster_shape_mn) + ) + 
(params.problem_shape_ncluster_mnl[2],) + if const_expr(not params.is_persistent): + return num_ctas_mnl + else: + num_ctas_in_problem = cute.size(num_ctas_mnl, loc=loc, ip=ip) + num_ctas_per_cluster = cute.size(params.cluster_shape_mn, loc=loc, ip=ip) + # Total ctas that can run in one wave + num_ctas_per_wave = max_active_clusters * num_ctas_per_cluster + num_persistent_ctas = cutlass.min(num_ctas_in_problem, num_ctas_per_wave) + num_persistent_clusters = num_persistent_ctas // num_ctas_per_cluster + return (*params.cluster_shape_mn, num_persistent_clusters) + + @cute.jit + def _swizzle_cta( + self, cluster_id_in_problem: Int32, *, loc=None, ip=None + ) -> Tuple[Int32, Int32]: + # CTA Swizzle to promote L2 data reuse + params = self.params + group_id, id_in_group = params.num_clusters_in_group_divmod.divmod(cluster_id_in_problem) + cid_fast_in_group, cid_slow = Int32(0), Int32(0) + if group_id < params.num_groups_regular: + cid_slow, cid_fast_in_group = params.group_size_divmod.divmod(id_in_group) + else: # tail part + cid_slow, cid_fast_in_group = params.group_size_tail_divmod.divmod(id_in_group) + if group_id % 2 == 1: # serpentine order + ncluster_slow = ( + params.problem_shape_ncluster_mnl[1] + if params.raster_order == RasterOrder.AlongM + else params.problem_shape_ncluster_mnl[0] + ) + cid_slow = ncluster_slow - 1 - cid_slow + cid_fast = group_id * params.group_size_divmod.divisor + cid_fast_in_group + cid_m, cid_n = cid_fast, cid_slow + if params.raster_order == RasterOrder.AlongN: + cid_m, cid_n = cid_slow, cid_fast + return cid_m, cid_n + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + params = self.params + if const_expr(not params.is_persistent): + cluster_id_in_problem = self._current_work_linear_idx + _, _, bidz = cute.arch.block_idx() + else: + bidz, cluster_id_in_problem = params.num_clusters_per_problem_divmod.divmod( + self._current_work_linear_idx + ) + cid_m, cid_n = self._swizzle_cta(cluster_id_in_problem, loc=loc, ip=ip) + # Get the pid from cluster id + bidx_in_cluster = cute.arch.block_in_cluster_idx() + pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0] + pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1] + batch_idx = ( + bidz if const_expr(params.batch_idx_permute is None) else params.batch_idx_permute[bidz] + ) + tile_coord_mnkl = (pid_m, pid_n, None, batch_idx) + if const_expr(not params.is_persistent): + is_valid = self.num_tiles_executed == 0 + else: + is_valid = self._current_work_linear_idx < cute.size(params.problem_shape_ncluster_mnl) + return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid) + + def initial_work_tile_info(self, *, loc=None, ip=None): + return self.get_current_work(loc=loc, ip=ip) + + @cute.jit + def fetch_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None): + """is_scheduler_warp should only be true for one warp in the whole cluster""" + params = self.params + if const_expr(params.is_persistent and params.tile_count_semaphore is not None): + current_work_linear_idx = self._current_work_linear_idx + if is_scheduler_warp: + if cute.arch.lane_idx() == 0: + num_persistent_clusters = cute.arch.grid_dim()[2] + current_work_linear_idx = num_persistent_clusters + utils.atomic_inc_i32( + cute.size(params.problem_shape_ncluster_mnl) - 1, + params.tile_count_semaphore, + ) + # lane 0 already has the right tile_idx, just need to broadcast + current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0) + 
self._current_work_linear_idx = current_work_linear_idx + + @cute.jit + def advance_to_next_work( + self, + is_scheduler_warp: bool | Boolean = False, + *, + advance_count: int = 1, + loc=None, + ip=None, + ): + tidx = cute.arch.thread_idx()[0] + bidx = cute.arch.block_idx()[0] + bidz = cute.arch.block_idx()[2] + params = self.params + if const_expr(params.is_persistent): + num_persistent_clusters = cute.arch.grid_dim()[2] + if const_expr(params.tile_count_semaphore is None): # Static persistent + self._current_work_linear_idx += advance_count * Int32(num_persistent_clusters) + else: # Dynamic persistent + if const_expr(advance_count > 1): + self._pipeline_state.advance_iters(advance_count - 1) + current_work_linear_idx = self._current_work_linear_idx + if is_scheduler_warp: + self._scheduler_pipeline.producer_acquire(self._pipeline_state) + lane_idx = cute.arch.lane_idx() + if lane_idx < cute.size(params.cluster_shape_mn): + # cute.printf("Producer bidx = {}, bidz = {}, tidx = {}, after empty wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx) + if const_expr(cute.size(params.cluster_shape_mn) == 1): + self._tile_count[self._pipeline_state.index] = current_work_linear_idx + self._scheduler_pipeline.producer_commit(self._pipeline_state) + else: + peer_cta_rank_in_cluster = lane_idx + mbar_ptr = self._scheduler_pipeline.producer_get_barrier( + self._pipeline_state + ) + cute.arch.mbarrier_arrive_and_expect_tx( + mbar_ptr, 4, peer_cta_rank_in_cluster + ) + utils.store_shared_remote( + val=current_work_linear_idx, + smem_ptr=self._tile_count.iterator + self._pipeline_state.index, + mbar_ptr=mbar_ptr, + peer_cta_rank_in_cluster=peer_cta_rank_in_cluster, + ) + # cute.printf("Producer bidx = {}, bidz = {}, tidx = {}, after full arrive", bidx, bidz, tidx) + else: + # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, before full wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx) + self._scheduler_pipeline.consumer_wait(self._pipeline_state) + # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, after full wait, idx = {}", bidx, bidz, tidx, current_work_linear_idx) + current_work_linear_idx = self._tile_count[self._pipeline_state.index] + # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, after smem read, idx = {}", bidx, bidz, tidx, current_work_linear_idx) + # Need this fence since the STAS from the producer is using the async proxy. + # Without this, we get race condition / deadlock. 
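+                    # ("STAS" above presumably refers to the st.async shared-memory stores
+                    # issued by utils.store_shared_remote in the scheduler-warp branch.)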
+ if const_expr(cute.size(params.cluster_shape_mn) > 1): + cute.arch.fence_proxy( + cute.arch.ProxyKind.async_shared, space=cute.arch.SharedSpace.shared_cta + ) + cute.arch.sync_warp() + with cute.arch.elect_one(): + # if tidx % 32 == 0: cute.printf("bidx = {}, bidz = {}, tidx = {}, before empty arrive", bidx, bidz, tidx) + self._scheduler_pipeline.consumer_release(self._pipeline_state) + # if tidx == 320: cute.printf("bidx = {}, bidz = {}, tidx = {}, idx = {}, after empty arrive", bidx, bidz, tidx, current_work_linear_idx) + # if tidx == 320: cute.printf("bidx = {}, bidz = {}, tidx = {}, idx = {}, after empty arrive", bidx, bidz, tidx, current_work_linear_idx) + self._current_work_linear_idx = current_work_linear_idx + self._pipeline_state.advance() + self.num_tiles_executed += Int32(advance_count) + + def producer_tail(self): + if const_expr(self.params.is_persistent and self.params.tile_count_semaphore is not None): + self._scheduler_pipeline.producer_tail(self._pipeline_state) + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [ + self._current_work_linear_idx, + self.num_tiles_executed, + self._tile_count, + self._scheduler_pipeline, + self._pipeline_state, + self.params, + ]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [ + self._current_work_linear_idx, + self.num_tiles_executed, + self._tile_count, + self._scheduler_pipeline, + self._pipeline_state, + self.params, + ], + self._values_pos, + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return self.__class__(*(tuple(obj_list)), loc=self._loc) + + +@cute.jit +def triangular_idx_to_coord(idx: Int32) -> Tuple[Int32, Int32]: + """ + Convert a triangular index to 2D coordinates. + This is used to convert the linear index to 2D coordinates for triangular matrices. 
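+    For example, linear indices 0-5 map to (row, col) = (0, 0), (1, 0), (1, 1), (2, 0), (2, 1), (2, 2).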
+ """ + row = utils.ceil((utils.sqrt(2 * idx + 2.25) - 0.5)) - 1 + col = idx - (row * (row + 1)) // 2 + return row, col + + +class TriangularTileScheduler(TileScheduler): + """We assume the tile size per cluster is square (e.g., 128 x 256 per CTA, with cluster 2 x 1)""" + + @dataclass + class Params(ParamsBase): + problem_shape_ncluster_mnl: cute.Shape + num_clusters_per_problem_divmod: FastDivmod + group_size_inv_f32: Float32 + num_groups_regular: Int32 + group_size_divmod: FastDivmod + group_size_tail_divmod: FastDivmod + group_size_mul_group_size_divmod: FastDivmod + group_size_tail_mul_group_size_divmod: FastDivmod + tile_count_semaphore: Optional[cute.Pointer] + cluster_shape_mn: cutlass.Constexpr[cute.Shape] + is_persistent: cutlass.Constexpr[bool] + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "TriangularTileScheduler.Params": + assert args.cluster_shape_mnk[2] == 1 + cluster_shape_mn = const_expr(cute.select(args.cluster_shape_mnk, mode=[0, 1])) + problem_shape_ntile_mn = cute.select(args.problem_shape_ntile_mnl, mode=[0, 1]) + problem_shape_ncluster_mn = cute.ceil_div(problem_shape_ntile_mn, cluster_shape_mn) + problem_shape_ncluster_mnl = problem_shape_ncluster_mn + ( + args.problem_shape_ntile_mnl[2], + ) + cluster_m = problem_shape_ncluster_mn[0] + # Assume that each cluster is responsible for a square tile + num_clusters_per_problem = cluster_m * (cluster_m + 1) // 2 + group_size = min(args.group_size, cluster_m) + group_size_tail = cluster_m % group_size + num_groups_regular = cluster_m // group_size + return TriangularTileScheduler.Params( + problem_shape_ncluster_mnl, + FastDivmod.create(num_clusters_per_problem), + Float32(1.0 / group_size), + num_groups_regular, + FastDivmod.create(group_size), + # Don't divide by 0 + FastDivmod.create(group_size_tail if group_size_tail > 0 else 1), + FastDivmod.create(group_size * group_size), + FastDivmod.create((group_size_tail if group_size_tail > 0 else 1) * group_size), + args.tile_count_semaphore if const_expr(args.is_persistent) else None, + cluster_shape_mn, + args.is_persistent, + ) + + @staticmethod + def to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + return TriangularTileScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + @cute.jit + def create( + params: Params, + tile_count: Optional[cute.Tensor] = None, + scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync] = None, + is_scheduler_warp: bool | Boolean = False, + *, + loc=None, + ip=None, + ) -> "TriangularTileScheduler": + stages = 0 + if const_expr(not params.is_persistent): + cluster_id, _, _ = cute.arch.cluster_idx() + current_work_linear_idx = Int32(cluster_id) + else: + _, _, bidz = cute.arch.block_idx() + current_work_linear_idx = Int32(bidz) + if const_expr(params.tile_count_semaphore is not None): + assert tile_count is not None + assert scheduler_pipeline is not None + stages = const_expr(cute.size(tile_count)) + return TriangularTileScheduler( + current_work_linear_idx, + Int32(0), # num_tiles_executed + tile_count, + scheduler_pipeline, + PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)), + params, + loc=loc, + ip=ip, + ) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + max_active_clusters: Int32, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + clusters = ( + params.num_clusters_per_problem_divmod.divisor, + 1, + params.problem_shape_ncluster_mnl[2], + ) + 
num_ctas_mnl = tuple(x * y for x, y in zip(clusters, params.cluster_shape_mn)) + ( + params.problem_shape_ncluster_mnl[2], + ) + if const_expr(not params.is_persistent): + return num_ctas_mnl + else: + num_ctas_in_problem = cute.size(num_ctas_mnl, loc=loc, ip=ip) + num_ctas_per_cluster = cute.size(params.cluster_shape_mn, loc=loc, ip=ip) + # Total ctas that can run in one wave + num_ctas_per_wave = max_active_clusters * num_ctas_per_cluster + num_persistent_ctas = cutlass.min(num_ctas_in_problem, num_ctas_per_wave) + num_persistent_clusters = num_persistent_ctas // num_ctas_per_cluster + return (*params.cluster_shape_mn, num_persistent_clusters) + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + params = self.params + if const_expr(not params.is_persistent): + cluster_id_in_problem = self._current_work_linear_idx + _, _, bidz = cute.arch.block_idx() + else: + bidz, cluster_id_in_problem = params.num_clusters_per_problem_divmod.divmod( + self._current_work_linear_idx + ) + # CTA Swizzle to promote L2 data reuse + group_size = params.group_size_divmod.divisor + group_id = ( + utils.ceil( + (utils.sqrt(2 * cluster_id_in_problem + 2.25) - 0.5) * params.group_size_inv_f32 + ) + - 1 + ) + cid_m_start = group_id * group_size + id_in_group = cluster_id_in_problem - (cid_m_start * (cid_m_start + 1)) // 2 + group_size_actual = ( + group_size + if group_id < params.num_groups_regular + else params.group_size_tail_divmod.divisor + ) + group_col, group_remainder = Int32(0), Int32(0) + if group_id < params.num_groups_regular: + group_col, group_remainder = params.group_size_mul_group_size_divmod.divmod(id_in_group) + else: # tail part + group_col, group_remainder = params.group_size_tail_mul_group_size_divmod.divmod( + id_in_group + ) + cid_m_in_group, cid_n_in_group = Int32(0), Int32(0) + if id_in_group >= group_size_actual * group_size * group_id: # triangular tail + cid_m_in_group, cid_n_in_group = triangular_idx_to_coord(group_remainder) + else: + if group_id < params.num_groups_regular: + cid_n_in_group, cid_m_in_group = params.group_size_divmod.divmod(group_remainder) + else: + cid_n_in_group, cid_m_in_group = params.group_size_tail_divmod.divmod( + group_remainder + ) + cid_m = cid_m_start + cid_m_in_group + cid_n = group_col * group_size + cid_n_in_group + + # Get the pid from cluster id + bidx_in_cluster = cute.arch.block_in_cluster_idx() + pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0] + pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1] + tile_coord_mnkl = (pid_m, pid_n, None, bidz) + if const_expr(not params.is_persistent): + is_valid = self.num_tiles_executed == 0 + else: + is_valid = ( + self._current_work_linear_idx + < params.num_clusters_per_problem_divmod.divisor + * params.problem_shape_ncluster_mnl[2] + ) + # bidx, bidy, bidz = cute.arch.block_idx() + # tidx, _, _ = cute.arch.thread_idx() + # if tidx == 0: + # cute.printf("bidx = {}, bidy = {}, group_id = {}, id_in_group = {}, group_size_actual = {}, group_col = {}, group_remainder = {}, cid_n_in_group = {}, cid_m_in_group = {}, cid_m = {}, cid_n = {}, is_valid = {}", + # bidx, bidy, group_id, id_in_group, group_size_actual, group_col, group_remainder, cid_n_in_group, cid_m_in_group, cid_m, cid_n, is_valid) + return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid) + + +@dataclass +class VarlenMTileSchedulerArguments(ParamsBase): + problem_shape_ntile_mnl: cute.Shape + total_m: Int32 + cu_seqlens_m: cute.Tensor + raster_order: 
cutlass.Constexpr[RasterOrderOption] + group_size: Int32 + tile_shape_mn: cutlass.Constexpr[cute.Shape] + cluster_shape_mnk: cutlass.Constexpr[cute.Shape] + tile_count_semaphore: Optional[cute.Pointer] = None + is_persistent: cutlass.Constexpr[bool] = False + + +class VarlenMTileScheduler(TileScheduler): + @dataclass + class Params(ParamsBase): + problem_shape_ncluster_mnl: cute.Shape + total_m: Int32 + cu_seqlens_m: cute.Tensor + raster_order: cutlass.Constexpr[RasterOrder] + group_size: Int32 + group_size_divmod: Optional[FastDivmod] + group_size_tail_divmod: Optional[FastDivmod] + num_clusters_in_group_divmod: FastDivmod + tile_shape_mn: cutlass.Constexpr[cute.Shape] + tile_count_semaphore: Optional[cute.Pointer] + cluster_shape_mn: cutlass.Constexpr[cute.Shape] + is_persistent: cutlass.Constexpr[bool] + + @staticmethod + @cute.jit + def create( + args: TileSchedulerArguments, *, loc=None, ip=None + ) -> "VarlenMTileScheduler.Params": + assert args.cluster_shape_mnk[2] == 1 + cluster_shape_mn = const_expr(cute.select(args.cluster_shape_mnk, mode=[0, 1])) + # problem_shape_ntile_mnl[0] will be None for VarlenM + problem_shape_ntile_mn = cute.select(args.problem_shape_ntile_mnl, mode=[0, 1]) + problem_shape_ncluster_mn = ( + None, + cute.ceil_div(problem_shape_ntile_mn[1], cluster_shape_mn[1]), + ) + problem_shape_ncluster_mnl = problem_shape_ncluster_mn + ( + args.problem_shape_ntile_mnl[2], + ) + raster_order = const_expr( + RasterOrder.AlongM + if args.raster_order == RasterOrderOption.AlongM + else RasterOrder.AlongN # For Heuristic we also use AlongN + ) + ncluster_fast = ( + problem_shape_ncluster_mn[0] + if raster_order == RasterOrder.AlongM + else problem_shape_ncluster_mn[1] + ) + ncluster_slow = ( + problem_shape_ncluster_mn[1] + if raster_order == RasterOrder.AlongM + else problem_shape_ncluster_mn[0] + ) + if const_expr(ncluster_fast is not None): + group_size = min(args.group_size, ncluster_fast) + group_size_tail = ncluster_fast % group_size + else: + group_size, group_size_tail = args.group_size, None + if const_expr(ncluster_slow is not None): + num_clusters_in_group = group_size * ncluster_slow + else: + num_clusters_in_group = None + return VarlenMTileScheduler.Params( + problem_shape_ncluster_mnl, + args.total_m, + args.cu_seqlens_m, + raster_order, + group_size, + FastDivmod.create(group_size) if ncluster_fast is not None else None, + # Don't divide by 0 + FastDivmod.create(group_size_tail if group_size_tail > 0 else 1) + if group_size_tail is not None + else None, + FastDivmod.create(num_clusters_in_group) + if num_clusters_in_group is not None + else None, + args.tile_shape_mn, + args.tile_count_semaphore if const_expr(args.is_persistent) else None, + cluster_shape_mn, + args.is_persistent, + ) + + def __init__( + self, + current_work_linear_idx: Int32, + num_tiles_executed: Int32, + current_batch_idx: Int32, + num_work_idx_before_cur_batch: Int32, + tile_count: Optional[cute.Tensor], + scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync], + pipeline_state: PipelineStateWAdvance, + params: Params, + *, + loc=None, + ip=None, + ): + self._current_work_linear_idx = current_work_linear_idx + self.num_tiles_executed = num_tiles_executed + self._current_batch_idx = current_batch_idx + self._num_work_idx_before_cur_batch = num_work_idx_before_cur_batch + self._tile_count = tile_count + self._scheduler_pipeline = scheduler_pipeline + self._pipeline_state = pipeline_state + self.params = params + self._loc = loc + self._ip = ip + + @staticmethod + def 
to_underlying_arguments(args: TileSchedulerArguments, *, loc=None, ip=None) -> Params: + return VarlenMTileScheduler.Params.create(args, loc=loc, ip=ip) + + @staticmethod + @cute.jit + def create( + params: Params, + tile_count: Optional[cute.Tensor] = None, + scheduler_pipeline: Optional[cutlass.pipeline.PipelineAsync] = None, + is_scheduler_warp: bool | Boolean = False, + *, + loc=None, + ip=None, + ) -> "VarlenMTileScheduler": + stages = 0 + _, _, bidz = cute.arch.block_idx() + current_work_linear_idx = Int32(bidz) + if const_expr(params.tile_count_semaphore is not None): + assert tile_count is not None + assert scheduler_pipeline is not None + stages = const_expr(cute.size(tile_count)) + return VarlenMTileScheduler( + current_work_linear_idx, + Int32(0), # num_tiles_executed + Int32(0), # current_batch_idx + Int32(0), # num_work_idx_before_cur_batch + tile_count, + scheduler_pipeline, + PipelineStateWAdvance(stages, Int32(0), Int32(0), Int32(1 if is_scheduler_warp else 0)), + params, + loc=loc, + ip=ip, + ) + + # called by host + @staticmethod + def get_grid_shape( + params: Params, + max_active_clusters: Int32, + *, + loc=None, + ip=None, + ) -> Tuple[Int32, Int32, Int32]: + block_size = params.tile_shape_mn[0] * params.cluster_shape_mn[0] + num_batch = params.problem_shape_ncluster_mnl[2] + total_clusters_m_max = (params.total_m + num_batch * (block_size - 1)) // block_size + total_clusters_max = total_clusters_m_max * params.problem_shape_ncluster_mnl[1] + if const_expr(not params.is_persistent): + return (*params.cluster_shape_mn, total_clusters_max) + else: + num_persistent_clusters = cutlass.min(max_active_clusters, total_clusters_max) + return (*params.cluster_shape_mn, num_persistent_clusters) + + @cute.jit + def _get_num_m_blocks( + self, lane: Int32, bidb_start: Int32, block_size: cutlass.Constexpr[int] + ) -> Int32: + num_batch = self.params.problem_shape_ncluster_mnl[2] + batch_idx = lane + bidb_start + cur_cu_seqlen = Int32(0) + if batch_idx <= num_batch: + cur_cu_seqlen = self.params.cu_seqlens_m[batch_idx] + next_cu_seqlen = cute.arch.shuffle_sync_down(cur_cu_seqlen, offset=1) + seqlen = next_cu_seqlen - cur_cu_seqlen + return ( + cute.ceil_div(seqlen, block_size) + if batch_idx < num_batch and lane < cute.arch.WARP_SIZE - 1 + else Int32(0) + ) + + @cute.jit + def _swizzle_cta( + self, cluster_id_in_problem: Int32, num_clusters_m: Int32, *, loc=None, ip=None + ) -> Tuple[Int32, Int32]: + params = self.params + # CTA Swizzle to promote L2 data reuse + if const_expr(params.num_clusters_in_group_divmod is not None): + group_id, id_in_group = params.num_clusters_in_group_divmod.divmod( + cluster_id_in_problem + ) + num_clusters_in_group = params.num_clusters_in_group_divmod.divisor + else: + assert params.raster_order == RasterOrder.AlongN + num_clusters_in_group = params.group_size * num_clusters_m + group_id = cluster_id_in_problem // num_clusters_in_group + id_in_group = cluster_id_in_problem - group_id * num_clusters_in_group + cid_fast_in_group, cid_slow = Int32(0), Int32(0) + if const_expr( + params.group_size_divmod is not None and params.group_size_tail_divmod is not None + ): + num_clusters = num_clusters_m * params.problem_shape_ncluster_mnl[1] + if (group_id + 1) * num_clusters_in_group <= num_clusters: + cid_slow, cid_fast_in_group = params.group_size_divmod.divmod(id_in_group) + else: # tail part + cid_slow, cid_fast_in_group = params.group_size_tail_divmod.divmod(id_in_group) + else: + assert params.raster_order == RasterOrder.AlongM + group_size_actual = 
cutlass.min( + params.group_size, num_clusters_m - group_id * params.group_size + ) + cid_slow = id_in_group // group_size_actual + cid_fast_in_group = id_in_group - cid_slow * group_size_actual + if group_id % 2 == 1: # serpentine order + ncluster_slow = ( + params.problem_shape_ncluster_mnl[1] + if params.raster_order == RasterOrder.AlongM + else num_clusters_m + ) + cid_slow = ncluster_slow - 1 - cid_slow + cid_fast = group_id * params.group_size + cid_fast_in_group + cid_m, cid_n = cid_fast, cid_slow + if params.raster_order == RasterOrder.AlongN: + cid_m, cid_n = cid_slow, cid_fast + return cid_m, cid_n + + @cute.jit + def get_current_work(self, *, loc=None, ip=None) -> cutlass.utils.WorkTileInfo: + params = self.params + lane_idx = cute.arch.lane_idx() + num_batch = self.params.problem_shape_ncluster_mnl[2] + block_size = params.tile_shape_mn[0] * params.cluster_shape_mn[0] + batch_idx = self._current_batch_idx + num_clusters_m = self._get_num_m_blocks( + lane_idx, bidb_start=batch_idx, block_size=block_size + ) + num_clusters = num_clusters_m * params.problem_shape_ncluster_mnl[1] + num_clusters_cumulative = utils.warp_prefix_sum(num_clusters, lane_idx) + # Total number of blocks for the next 31 problems, same for all lanes + clusters_in_problems = cute.arch.shuffle_sync( + num_clusters_cumulative, cute.arch.WARP_SIZE - 1 + ) + problems_end_tile = self._num_work_idx_before_cur_batch + clusters_in_problems + # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, problems_end_tile = %d, num_clusters_m=%d, num_clusters_cumulative = %d, problems_end_tile = %d", self._tile_idx, problems_end_tile, num_clusters_m, num_clusters_cumulative, problems_end_tile) + cid_m, cid_n = Int32(0), Int32(0) + next_tile_idx = self._current_work_linear_idx + while problems_end_tile <= next_tile_idx: + batch_idx += cute.arch.WARP_SIZE - 1 + if batch_idx >= num_batch: + batch_idx = Int32(num_batch) + problems_end_tile = next_tile_idx + 1 + else: + num_clusters_m = self._get_num_m_blocks( + lane_idx, bidb_start=batch_idx, block_size=block_size + ) + num_clusters = num_clusters_m * params.problem_shape_ncluster_mnl[1] + num_clusters_cumulative = utils.warp_prefix_sum(num_clusters, lane_idx) + clusters_in_problems = cute.arch.shuffle_sync( + num_clusters_cumulative, cute.arch.WARP_SIZE - 1 + ) + problems_end_tile += clusters_in_problems + # Just a placeholer value in case batch_idx >= num_batch + num_work_idx_before_cur_batch = problems_end_tile - clusters_in_problems + if batch_idx >= num_batch: + cid_m, cid_n, batch_idx = Int32(0), Int32(0), Int32(num_batch) + else: + problems_start_tile = problems_end_tile - clusters_in_problems + # if cute.arch.thread_idx()[0] == 128 + 31: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, problems_end_tile = %d, num_clusters_m=%d, batch_idx = %d", self._tile_idx, problems_end_tile, num_clusters_m, batch_idx) + # The next problem to process is the first one that does not have ending tile position + # that is greater than or equal to tile index. 
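+            # Each lane holds the inclusive prefix sum of cluster counts for one batch in the
+            # current window, so problems_start_tile + num_clusters_cumulative is that batch's
+            # exclusive end tile. vote_ballot_sync sets a bit for every lane whose batch ends at
+            # or before next_tile_idx, and popc of the mask counts the whole batches to skip,
+            # i.e. the offset of the target batch within this window.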
+ batch_idx_in_problems = cute.arch.popc( + cute.arch.vote_ballot_sync( + problems_start_tile + num_clusters_cumulative <= next_tile_idx + ) + ) + batch_idx += batch_idx_in_problems + num_clusters_prev_lane = ( + 0 + if batch_idx_in_problems == 0 + else cute.arch.shuffle_sync(num_clusters_cumulative, batch_idx_in_problems - 1) + ) + num_clusters_m = cute.arch.shuffle_sync(num_clusters_m, batch_idx_in_problems) + num_work_idx_before_cur_batch = problems_start_tile + num_clusters_prev_lane + cluster_id_in_problem = next_tile_idx - num_work_idx_before_cur_batch + # cid_n = cluster_id_in_problem // num_clusters_m + # cid_m = cluster_id_in_problem - cid_n * num_clusters_m + # if cute.arch.thread_idx()[0] == 128: cute.printf("SingleTileVarlenScheduler: tile_idx=%d, batch_idx=%d, cid_n=%d, cid_m=%d, is_valid = %d", self._tile_idx, batch_idx, cid_n, cid_m, is_valid) + cid_m, cid_n = self._swizzle_cta(cluster_id_in_problem, num_clusters_m, loc=loc, ip=ip) + self._current_batch_idx = batch_idx + self._num_work_idx_before_cur_batch = num_work_idx_before_cur_batch + + # Get the pid from cluster id + bidx_in_cluster = cute.arch.block_in_cluster_idx() + pid_m = cid_m * params.cluster_shape_mn[0] + bidx_in_cluster[0] + pid_n = cid_n * params.cluster_shape_mn[1] + bidx_in_cluster[1] + tile_coord_mnkl = (pid_m, pid_n, None, batch_idx) + if const_expr(not params.is_persistent): + is_valid = self.num_tiles_executed == 0 and batch_idx < num_batch + else: + is_valid = batch_idx < num_batch + return cutlass.utils.WorkTileInfo(tile_coord_mnkl, is_valid) + + @cute.jit + def fetch_next_work(self, is_scheduler_warp: bool | Boolean = False, *, loc=None, ip=None): + """is_scheduler_warp should only be true for one warp in the whole cluster""" + if const_expr(self.params.tile_count_semaphore is not None): + params = self.params + current_work_linear_idx = self._current_work_linear_idx + if is_scheduler_warp: + if cute.arch.lane_idx() == 0: + # cute.printf("before atomicadd, tidx = {}, bidz = {}, idx = {}", cute.arch.thread_idx()[0], cute.arch.block_idx()[2], current_work_linear_idx) + num_persistent_clusters = cute.arch.grid_dim()[2] + current_work_linear_idx = num_persistent_clusters + utils.atomic_add_i32( + 1, params.tile_count_semaphore + ) + # cute.printf("after atomicadd, tidx = {}, bidz = {}, idx = {}", cute.arch.thread_idx()[0], cute.arch.block_idx()[2], current_work_linear_idx) + # lane 0 already has the right tile_idx, just need to broadcast + current_work_linear_idx = cute.arch.shuffle_sync(current_work_linear_idx, 0) + self._current_work_linear_idx = current_work_linear_idx + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [ + self._current_work_linear_idx, + self.num_tiles_executed, + self._current_batch_idx, + self._num_work_idx_before_cur_batch, + self._tile_count, + self._scheduler_pipeline, + self._pipeline_state, + self.params, + ]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [ + self._current_work_linear_idx, + self.num_tiles_executed, + self._current_batch_idx, + self._num_work_idx_before_cur_batch, + self._tile_count, + self._scheduler_pipeline, + self._pipeline_state, + self.params, + ], + self._values_pos, + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return self.__class__(*(tuple(obj_list)), loc=self._loc) diff --git 
a/sonic-moe/torch-ext/sonicmoe/quack/topk.py b/sonic-moe/torch-ext/sonicmoe/quack/topk.py new file mode 100644 index 00000000..dba25180 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/topk.py @@ -0,0 +1,552 @@ +from ._ops_compat import add_quack_op_namespace_prefix +# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Tri Dao. + +import math +from functools import partial +from typing import Type, Optional + +import torch + +import cuda.bindings.driver as cuda + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Float32, const_expr + +from . import utils as utils +from . import copy_utils as copy_utils +from .compile_utils import make_fake_tensor as fake_tensor +from .reduction_base import ReductionBase +from .reduce import row_reduce +from .cute_dsl_utils import torch2cute_dtype_map +from .sort.bitonic_sort import bitonic_topk + + +class TopK: + def __init__(self, dtype: Type[cutlass.Numeric], N: int, k: int, softmax: bool = False): + self.dtype = dtype + self.N = N + self.vecsize = 128 // dtype.width + self.k = k + self.softmax = softmax + assert N == 2 ** int(math.log2(N)), "N must be a power of 2" + assert k == 2 ** int(math.log2(k)), "N must be a power of 2" + assert k <= 128 + assert N <= 4096 + + def _threads_per_row(self): + # we want num_elems_per_thread >= self.k + # and each thread can handle at most 64 elements + N = self.N + num_threads_per_row = max(min(N // self.k, 32, N // 64), 1) + return num_threads_per_row + + def _get_tiled_copy(self): + N = self.N + vecsize = self.vecsize + num_threads = 128 if N <= 16384 else 256 + threads_per_row = self._threads_per_row() + cols_per_block = num_threads // threads_per_row + num_blocks_N = cute.ceil_div(min(N, 16384) // vecsize, threads_per_row) + tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row) + tiled_copy = copy_utils.tiled_copy_2d( + self.dtype, threads_per_row, num_threads, num_copy_elems=vecsize + ) + return tiled_copy, tiler_mn, threads_per_row + + @cute.jit + def __call__( + self, + mX: cute.Tensor, + mValues: cute.Tensor, + mIndices: cute.Tensor, + stream: cuda.CUstream, + ): + assert mX.element_type == self.dtype + assert mValues.element_type == self.dtype + assert mIndices.element_type == Int32 + tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy() + num_threads = tiled_copy.size + self.kernel(mX, mValues, mIndices, tiler_mn, tiled_copy, threads_per_row).launch( + grid=[cute.ceil_div(mX.shape[0], tiler_mn[0]), 1, 1], + block=[num_threads, 1, 1], + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mX: cute.Tensor, + mValues: cute.Tensor, + mIndices: cute.Tensor, + tiler_mn: cute.Shape, + tiled_copy: cute.TiledCopy, + threads_per_row: cutlass.Constexpr[int], + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + tv_layout = tiled_copy.layout_tv_tiled + + shape = mX.shape + idX = cute.make_identity_tensor(shape) + # slice for CTAs + gX, cX = [cute.local_tile(mT, tiler_mn, (bidx, 0)) for mT in (mX, idX)] + + thr_copy = tiled_copy.get_slice(tidx) + + tXgX = thr_copy.partition_S(gX) + tXcX = thr_copy.partition_S(cX)[(0, None), None, None] + tXrX = cute.make_fragment_like(tXgX) + + is_even_N = const_expr(shape[1] == tiler_mn[1]) + tXpX = ( + None if is_even_N else copy_utils.predicate_k(thr_copy.partition_S(cX), limit=shape[1]) + ) + copy = partial(copy_utils.copy, pred=tXpX) + + if tXcX[0][0] < shape[0]: + copy(tXgX, tXrX) + tXrX_f32 = cute.make_fragment(tXrX.shape, Float32) + tXrX_f32.store(tXrX.load().to(Float32)) + + # Encode the 
indices into the bottom bits of values. + log_N = int(math.log2(self.N)) + idx_mask = (1 << log_N) - 1 + vecsize = const_expr(cute.size(tv_layout.shape[1])) + tXrX_i32 = cute.recast_tensor(tXrX_f32, Int32) + # Encode indices into the last log_N bits of tXrX_i32 + for i in cutlass.range(cute.size(tXrX_i32), unroll_full=True): + # tXcX only keeps track of the indices for every @vecsize elements + col_idx = Int32(tXcX[i // vecsize][1] + i % vecsize) + # If positive, invert the bits of the index, so that if there's a tie, + # indices coming from a earlier column will win. + encoded_idx = ~col_idx if tXrX_f32[i] >= 0 else col_idx + # Mask to keep only the last log_N bits of the encoded index + encoded_idx = encoded_idx & idx_mask + # Clear the last log_N bits and set them to our encoded index + tXrX_i32[i] = (tXrX_i32[i] & ~idx_mask) | encoded_idx + + # Fill OOB values with -inf for top-k + if const_expr(not is_even_N): + utils.fill_oob(tXrX_f32, tXpX, -tXrX_f32.element_type.inf) + + topk_vals = bitonic_topk(tXrX_f32, self.k, warp_width=threads_per_row) + + # Thread 0 in each row contains all the top-k values, so we split those into multiple threads + vecsize_out = const_expr(min(self.k, vecsize, 128 // mIndices.element_type.width)) + assert self.k % vecsize_out == 0 + nvec_per_thread = const_expr(cute.ceil_div(self.k, vecsize_out * threads_per_row)) + # 1 -> 0b11111, 2 -> 0b11110, 4 -> 0b11100, 8 -> 0b11000, 16 -> 0b10000, 32 -> 0b00000 + mask = cute.arch.WARP_SIZE - threads_per_row + mask_and_clamp = mask << 8 | (cute.arch.WARP_SIZE - 1) + topk_vals_split = cute.make_fragment((vecsize_out, nvec_per_thread), Float32) + for i in cutlass.range(cute.ceil_div(self.k, vecsize_out), unroll_full=True): + should_receive = tidx % threads_per_row == i % threads_per_row + for v in cutlass.range(vecsize_out, unroll_full=True): + if const_expr(threads_per_row > 1): + if i * vecsize_out + v < self.k: + val = cute.arch.shuffle_sync( + topk_vals[i * vecsize_out + v], offset=0, mask_and_clamp=mask_and_clamp + ) + if should_receive: + topk_vals_split[v, i // threads_per_row] = val + else: + topk_vals_split[v, i // threads_per_row] = topk_vals[i * vecsize_out + v] + + # Extract indices and clean values + topk_vals_i32 = cute.recast_tensor(topk_vals_split, Int32) + topk_indices = cute.make_fragment(topk_vals_i32.shape, Int32) + for i in cutlass.range(cute.size(topk_vals_i32), unroll_full=True): + # Extract the encoded index from the last log_N bits + encoded_idx = topk_vals_i32[i] & idx_mask + # Check if original value was positive by looking at the cleaned value + topk_vals_i32[i] = topk_vals_i32[i] & ~idx_mask # Clear last log_N bits + # If positive, we need to invert the bits back to get original index + col_idx = ~encoded_idx if topk_vals[i] >= 0 else encoded_idx + topk_indices[i] = Int32(col_idx & idx_mask) + + # Compute softmax if requested + if const_expr(self.softmax): + # Need masking as some elements may be OOB + for i in cutlass.range(cute.size(topk_vals_split, mode=[1]), unroll_full=True): + col = i * threads_per_row + tidx % threads_per_row + if col >= self.k // vecsize_out: + for v in cutlass.range(vecsize_out, unroll_full=True): + topk_vals_split[v, i] = -Float32.inf + # Get max from thread 0 (topk_vals[0] is the max since sorted descending) + max_val = cute.arch.shuffle_sync(topk_vals[0], offset=0, mask_and_clamp=mask_and_clamp) + log2_e = math.log2(math.e) + exp_x = cute.math.exp2( + topk_vals_split.load() * log2_e - (max_val * log2_e), fastmath=True + ) + denom = cute.arch.warp_reduction_sum( + 
exp_x.reduce(cute.ReductionOp.ADD, init_val=0.0, reduction_profile=0), + threads_in_group=threads_per_row, + ) + topk_vals_split.store(exp_x * cute.arch.rcp_approx(denom)) + + # Convert cleaned values to output type + topk_vals_out = cute.make_fragment_like(topk_vals_split, mValues.element_type) + topk_vals_out.store(topk_vals_split.load().to(mValues.element_type)) + + row = tXcX[0][0] + # # Only the 1st thread in this row writes the top-k values and indices + # if row < shape[0] and tXcX[0][1] == 0: + # # for i in cutlass.range(self.k): + # # mValues[row, i] = topk_vals_out[i] + # # mIndices[row, i] = topk_indices[i] + # # Vectorized write + # elems_per_store = const_expr(math.gcd(vecsize, self.k)) + # mValues_store = cute.tiled_divide(mValues[row, None], (elems_per_store,)) + # mIndices_store = cute.tiled_divide(mIndices[row, None], (elems_per_store,)) + # topk_vals_out_store = cute.tiled_divide(topk_vals_out, (elems_per_store,)) + # topk_indices_store = cute.tiled_divide(topk_indices, (elems_per_store,)) + # for i in cutlass.range(cute.size(topk_vals_out_store.shape, [1]), unroll_full=True): + # cute.autovec_copy(topk_vals_out_store[None, i], mValues_store[None, i]) + # cute.autovec_copy(topk_indices_store[None, i], mIndices_store[None, i]) + if tiler_mn[0] == 0 or row < shape[0]: + # Vectorized write + mValues_store = cute.tiled_divide(mValues[row, None], (vecsize_out,)) + mIndices_store = cute.tiled_divide(mIndices[row, None], (vecsize_out,)) + for i in cutlass.range(cute.size(topk_vals_out.shape, [1]), unroll_full=True): + col = i * threads_per_row + tidx % threads_per_row + if col < self.k // vecsize_out: + cute.autovec_copy(topk_vals_out[None, i], mValues_store[None, col]) + cute.autovec_copy(topk_indices[None, i], mIndices_store[None, col]) + + +@torch.library.custom_op(add_quack_op_namespace_prefix("topk_fwd"), mutates_args={"values", "indices"}) +def _topk_fwd( + x: torch.Tensor, k: int, softmax: bool, values: torch.Tensor, indices: torch.Tensor +) -> None: + """Top-k forward pass. + Args: + x: Input tensor of shape (M, N) + k: Number of top elements to return + softmax: Whether to apply softmax to the top-k values + Returns: + Tuple of (values tensor of shape (M, k), indices tensor of shape (M, k)) + """ + assert x.dim() == 2, "Input must be 2D" + assert x.is_cuda, "Tensor must be on CUDA device" + assert x.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype" + assert k > 0 and k <= x.shape[1], "k must be positive and <= N" + + N = x.size(1) + dtype = torch2cute_dtype_map[x.dtype] + compile_key = (dtype, N, k, softmax) + if compile_key not in _topk_fwd.compile_cache: + batch_sym = cute.sym_int() + div = math.gcd(128 // dtype.width, N) + x_cute = fake_tensor(dtype, (batch_sym, N), div) + values_cute = fake_tensor(dtype, (batch_sym, k), div) + indices_cute = fake_tensor(Int32, (batch_sym, k), div) + topk_op = TopK(dtype, N, k, softmax=softmax) + _topk_fwd.compile_cache[compile_key] = cute.compile( + topk_op, + x_cute, + values_cute, + indices_cute, + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + _topk_fwd.compile_cache[compile_key](x, values, indices) + + +_topk_fwd.compile_cache = {} + + +def topk_fwd(x: torch.Tensor, k: int, softmax: bool = False): + """Top-k operation. 
+ + Args: + x: Input tensor of shape (M, N) + k: Number of top elements to return + softmax: Whether to apply softmax to the top-k values + + Returns: + Tuple of (values tensor of shape (M, k), indices tensor of shape (M, k)) + """ + M = x.size(0) + values = torch.empty((M, k), dtype=x.dtype, device=x.device) + indices = torch.empty((M, k), dtype=torch.int32, device=x.device) + _topk_fwd(x, k, softmax, values, indices) + return values, indices + + +class TopKBackward(ReductionBase): + def __init__(self, dtype: Type[cutlass.Numeric], N: int, k: int, softmax: bool = False): + super().__init__(dtype, N, stage=1, reduction_dtype=Float32) + self.dtype = dtype + self.N = N + self.k = k + self.softmax = softmax + assert k <= N + assert k <= 32768 + + def _num_threads(self): + return 128 if self.N <= 16384 else 256 + + def _get_tiled_copy(self, N: int, vecsize: Optional[int] = None): + if vecsize is None: + vecsize = min(N, 128 // self.dtype.width) + assert N % vecsize == 0, f"Input N {N} is not divisible by vector size {vecsize}" + num_threads = self._num_threads() + threads_per_row = min(N // vecsize, num_threads) + cols_per_block = num_threads // threads_per_row + num_blocks_N = cute.ceil_div(N // vecsize, threads_per_row) + tiler_mn = (cols_per_block, vecsize * num_blocks_N * threads_per_row) + tiled_copy = copy_utils.tiled_copy_2d( + self.dtype, threads_per_row, num_threads, num_copy_elems=vecsize + ) + return tiled_copy, tiler_mn, threads_per_row + + @cute.jit + def __call__( + self, + mdValues: cute.Tensor, # (M, k) + mValues: Optional[cute.Tensor], # (M, k) + mIndices: cute.Tensor, # (M, k) + mdX: cute.Tensor, # (M, N) + stream: cuda.CUstream, + ): + assert mdValues.element_type == self.dtype + if const_expr(mValues is not None): + assert mValues.element_type == self.dtype + assert mIndices.element_type == Int32 + self._set_cluster_n() + largest_dtype_width = const_expr( + max( + *(t.element_type.width for t in [mdValues, mValues, mIndices, mdX] if t is not None) + ) + ) + vecsize = math.gcd(self.N, 128 // largest_dtype_width) + tiled_copy, tiler_mn, threads_per_row = self._get_tiled_copy(self.N, vecsize=vecsize) + num_threads = tiled_copy.size + self.kernel( + mdValues, + mValues, + mIndices, + mdX, + tiler_mn, + tiled_copy, + threads_per_row, + ).launch( + grid=[cute.ceil_div(mdX.shape[0], tiler_mn[0]), 1, 1], + block=[num_threads, 1, 1], + stream=stream, + ) + + @cute.kernel + def kernel( + self, + mdValues: cute.Tensor, # (M, k) + mValues: Optional[cute.Tensor], # (M, k) + mIndices: cute.Tensor, # (M, k) + mdX: cute.Tensor, # (M, N) + tiler_mn: cute.Shape, + tiled_copy: cute.TiledCopy, + threads_per_row: cutlass.Constexpr[int], + ): + tidx, _, _ = cute.arch.thread_idx() + bidx, _, _ = cute.arch.block_idx() + + tv_layout = tiled_copy.layout_tv_tiled + shape = mdX.shape + idX = cute.make_identity_tensor(shape) + idTopK = cute.make_identity_tensor(mdValues.shape) + # slice for CTAs + gdX, cX = [cute.local_tile(mT, tiler_mn, (bidx, 0)) for mT in (mdX, idX)] + gdVals, gVals, gIdx, cTopK = [ + cute.local_tile(mT, tiler_mn, (bidx, 0)) if mT is not None else None + for mT in (mdValues, mValues, mIndices, idTopK) + ] + + # Allocate smem for output gradients + smem = cutlass.utils.SmemAllocator() + sdX = smem.allocate_tensor( + mdX.element_type, cute.make_ordered_layout(tiler_mn, order=(1, 0)), byte_alignment=16 + ) + reduction_buffer, mbar_ptr = self._allocate_reduction_buffer_and_mbar(smem, tv_layout) + + thr_copy = tiled_copy.get_slice(tidx) + + tXgdV = thr_copy.partition_S(gdVals) + tXgV = 
thr_copy.partition_S(gVals) if const_expr(gVals is not None) else None + tXgI = thr_copy.partition_S(gIdx) + tXrdV = cute.make_fragment_like(tXgdV) + tXrV = cute.make_fragment_like(tXgV) if const_expr(tXgV is not None) else None + tXrI = cute.make_fragment_like(tXgI) + tXrdV.fill(tXrdV.element_type.zero) + if const_expr(mValues is not None): + tXrV.fill(tXrV.element_type.zero) + tXrI.fill(0) + + tXsdX = thr_copy.partition_D(sdX) + tXgdX = thr_copy.partition_D(gdX) + tXcX = thr_copy.partition_S(cX)[(0, None), None, None] + tXrdX = cute.make_fragment_like(tXgdX) + + is_even_N = const_expr(shape[1] == tiler_mn[1]) + tXpV = copy_utils.predicate_k(thr_copy.partition_S(cTopK), limit=mdValues.shape[1]) + tXpX = ( + None if is_even_N else copy_utils.predicate_k(thr_copy.partition_S(cX), limit=shape[1]) + ) + copy_k = partial(copy_utils.copy, pred=tXpV) + copy_dx = partial(copy_utils.copy, pred=tXpX) + + row = tXcX[0][0] + tile_row_start = Int32(cute.arch.block_idx()[0] * tiler_mn[0]) + + # Zero out smem + utils.fill_oob(tXsdX, None, fill_value=mdX.element_type.zero) + + if row < shape[0]: + copy_k(tXgdV, tXrdV) + if const_expr(mValues is not None): + copy_k(tXgV, tXrV) + copy_k(tXgI, tXrI) + + cute.arch.barrier() + + dvals_f32 = tXrdV.load().to(Float32) + if const_expr(self.softmax): + vals_f32 = tXrV.load().to(Float32) + dot = row_reduce( + dvals_f32 * vals_f32, + cute.ReductionOp.ADD, + threads_per_row, + reduction_buffer[None, None, 0], + ) + grads = vals_f32 * (dvals_f32 - dot) + else: + grads = dvals_f32 + grad_cvt = cute.make_fragment(tXrdV.shape, mdX.element_type) + grad_cvt.store(grads.to(mdX.element_type)) + + # Scatter values to smem + if row < shape[0]: + for rest_v in cutlass.range(tXrdV.shape[0][1], unroll_full=True): + for n in cutlass.range(tXrdV.shape[2], unroll_full=True): + if tXpV[rest_v, 0, n]: + for v in cutlass.range(tXrdV.shape[0][0], unroll_full=True): + sdX[row - tile_row_start, tXrI[(v, rest_v), 0, n]] = grad_cvt[ + (v, rest_v), 0, n + ] + cute.arch.barrier() + + # Read from smem to rmem, then write to gmem + cute.autovec_copy(tXsdX, tXrdX) + if row < shape[0]: + copy_dx(tXrdX, tXgdX) + + +@torch.library.custom_op(add_quack_op_namespace_prefix("topk_bwd"), mutates_args={"dx"}) +def _topk_bwd( + dvalues: torch.Tensor, + values: Optional[torch.Tensor], + indices: torch.Tensor, + k: int, + softmax: bool, + dx: torch.Tensor, +) -> None: + """Top-k backward pass. 
+ Args: + dvalues: Upstream gradients tensor of shape (M, k) + values: Forward top-k values tensor of shape (M, k) + indices: Indices tensor of shape (M, k) from forward pass + k: Number of top elements + softmax: Whether softmax was applied in forward + dx: Output gradient tensor of shape (M, N) + """ + assert dvalues.dim() == 2, "dvalues must be 2D" + if values is not None: + assert values.dim() == 2, "values must be 2D" + assert indices.dim() == 2, "indices must be 2D" + assert dvalues.is_cuda and indices.is_cuda, "Tensors must be on CUDA device" + assert dvalues.dtype in [torch.float16, torch.bfloat16, torch.float32], "Unsupported dtype" + + N = dx.size(1) + dtype = torch2cute_dtype_map[dvalues.dtype] + val_dtype = torch2cute_dtype_map[values.dtype] if values is not None else dtype + dx_dtype = torch2cute_dtype_map[dx.dtype] + compile_key = (dtype, val_dtype, dx_dtype, N, k, softmax) + if compile_key not in _topk_bwd.compile_cache: + batch_sym = cute.sym_int() + div = math.gcd(128 // dtype.width, N) + dvalues_cute = fake_tensor(dtype, (batch_sym, k), div) + values_cute = fake_tensor(val_dtype, (batch_sym, k), div) if values is not None else None + indices_cute = fake_tensor(Int32, (batch_sym, k), div) + dx_cute = fake_tensor(dx_dtype, (batch_sym, N), div) + topk_bwd_op = TopKBackward(dtype, N, k, softmax=softmax) + _topk_bwd.compile_cache[compile_key] = cute.compile( + topk_bwd_op, + dvalues_cute, + values_cute, + indices_cute, + dx_cute, + cute.runtime.make_fake_stream(use_tvm_ffi_env_stream=True), + options="--enable-tvm-ffi", + ) + _topk_bwd.compile_cache[compile_key](dvalues, values, indices, dx) + + +_topk_bwd.compile_cache = {} + + +def topk_bwd( + dvalues: torch.Tensor, + values: Optional[torch.Tensor], + indices: torch.Tensor, + N: int, + softmax: bool = False, +) -> torch.Tensor: + """Top-k backward pass. + + Args: + dvalues: Upstream gradients tensor of shape (M, k) + values: Forward top-k values tensor of shape (M, k), required if softmax=True + indices: Indices tensor of shape (M, k) from forward pass + N: Size of the original input dimension + softmax: Whether softmax was applied in forward + + Returns: + Input gradients tensor of shape (M, N) + """ + M, k = dvalues.shape + dx = torch.zeros((M, N), dtype=dvalues.dtype, device=dvalues.device) + _topk_bwd(dvalues, values, indices, k, softmax, dx) + return dx + + +class TopKFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, x: torch.Tensor, k: int, softmax: bool = False): + values, indices = topk_fwd(x, k, softmax=softmax) + ctx.save_for_backward(values if softmax else None, indices) + ctx.k = k + ctx.N = x.shape[1] + ctx.softmax = softmax + ctx.mark_non_differentiable(indices) + ctx.set_materialize_grads(False) + return values, indices + + @staticmethod + def backward(ctx, dvalues: torch.Tensor, dindices_: Optional[torch.Tensor] = None): + values, indices = ctx.saved_tensors + dx = topk_bwd(dvalues, values, indices, N=ctx.N, softmax=ctx.softmax) + return dx, None, None + + +def topk(x: torch.Tensor, k: int, softmax: bool = False): + """Top-k operation. 
+ + Args: + x: Input tensor of shape (M, N) + k: Number of top elements to return + softmax: Whether to apply softmax to the top-k values + + Returns: + Tuple of (values tensor of shape (M, k), indices tensor of shape (M, k)) + """ + return TopKFunction.apply(x, k, softmax) diff --git a/sonic-moe/torch-ext/sonicmoe/quack/utils.py b/sonic-moe/torch-ext/sonicmoe/quack/utils.py new file mode 100644 index 00000000..a7b110ea --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/utils.py @@ -0,0 +1,223 @@ +# Copyright (c) 2025, Wentao Guo, Ted Zadouri, Tri Dao. + +import math +from functools import partial +from typing import Optional, Tuple, Union + +import cutlass +import cutlass.cute as cute + +from cutlass import Float32, Int32, const_expr +from cutlass.cutlass_dsl import T, dsl_user_op +from cutlass._mlir.dialects import llvm, nvvm, vector + + +# cute.arch.{fma,mul,add}_packed_f32x2 uses RZ rounding mode by default +fma_packed_f32x2 = partial(cute.arch.fma_packed_f32x2, rnd=nvvm.RoundingModeKind.RN) +mul_packed_f32x2 = partial(cute.arch.mul_packed_f32x2, rnd=nvvm.RoundingModeKind.RN) +add_packed_f32x2 = partial(cute.arch.add_packed_f32x2, rnd=nvvm.RoundingModeKind.RN) +sub_packed_f32x2 = partial( + cute.arch.calc_packed_f32x2_op, + src_c=None, + calc_func=nvvm.sub_packed_f32x2, + rnd=nvvm.RoundingModeKind.RN, +) + + +@dsl_user_op +def elem_pointer(x: cute.Tensor, coord: cute.Coord, *, loc=None, ip=None) -> cute.Pointer: + return x.iterator + cute.crd2idx(coord, x.layout, loc=loc, ip=ip) + + +@cute.jit +def load_scalar_or_pointer(x: Float32 | cute.Pointer) -> Float32: + if const_expr(isinstance(x, cute.Pointer)): + return Float32(cute.make_tensor(x, cute.make_layout(1))[0]) + else: + assert isinstance(x, Float32) + return x + + +@dsl_user_op +def set_block_rank( + smem_ptr: cute.Pointer, peer_cta_rank_in_cluster: Int32, *, loc=None, ip=None +) -> Int32: + """Map the given smem pointer to the address at another CTA rank in the cluster.""" + smem_ptr_i32 = smem_ptr.toint(loc=loc, ip=ip).ir_value() + return Int32( + llvm.inline_asm( + T.i32(), + [smem_ptr_i32, peer_cta_rank_in_cluster.ir_value()], + "mapa.shared::cluster.u32 $0, $1, $2;", + "=r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def store_shared_remote( + val: float | Float32 | Int32 | cutlass.Int64, + smem_ptr: cute.Pointer, + mbar_ptr: cute.Pointer, + peer_cta_rank_in_cluster: cute.typing.Int, + *, + loc=None, + ip=None, +) -> None: + remote_smem_ptr_i32 = set_block_rank( + smem_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + remote_mbar_ptr_i32 = set_block_rank( + mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip + ).ir_value() + if const_expr(isinstance(val, float)): + val = Float32(val) + assert isinstance(val, (Float32, Int32, cutlass.Int64)), "val must be Float32, Int32, or Int64" + suffix = {Float32: "f32", Int32: "s32", cutlass.Int64: "s64"}[type(val)] + constraint = {Float32: "f", Int32: "r", cutlass.Int64: "l"}[type(val)] + llvm.inline_asm( + None, + [remote_smem_ptr_i32, val.ir_value(loc=loc, ip=ip), remote_mbar_ptr_i32], + f"st.async.shared::cluster.mbarrier::complete_tx::bytes.{suffix} [$0], $1, [$2];", + f"r,{constraint},r", + has_side_effects=True, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + + +@dsl_user_op +def fmin(a: Union[float, Float32], b: Union[float, Float32], *, loc=None, ip=None) -> Float32: + return Float32( + nvvm.fmin( + T.f32(), + Float32(a).ir_value(loc=loc, ip=ip), + 
Float32(b).ir_value(loc=loc, ip=ip), + loc=loc, + ip=ip, + ) + ) + + +@dsl_user_op +def sqrt(a: float | Float32, *, loc=None, ip=None) -> Float32: + return Float32( + llvm.inline_asm( + T.f32(), + [Float32(a).ir_value(loc=loc, ip=ip)], + "sqrt.approx.f32 $0, $1;", + "=f,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def ceil(a: float | Float32, *, loc=None, ip=None) -> Int32: + return Int32( + llvm.inline_asm( + T.i32(), + [Float32(a).ir_value(loc=loc, ip=ip)], + "cvt.rpi.ftz.s32.f32 $0, $1;", + "=r,f", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@dsl_user_op +def prmt(a: int | Int32, b: int | Int32, c: int | Int32, *, loc=None, ip=None) -> Int32: + return Int32( + llvm.inline_asm( + T.i32(), + [ + Int32(a).ir_value(loc=loc, ip=ip), + Int32(b).ir_value(loc=loc, ip=ip), + Int32(c).ir_value(loc=loc, ip=ip), + ], + "prmt.b32 $0, $1, $2, $3;", + "=r,r,r,r", + has_side_effects=False, + is_align_stack=False, + asm_dialect=llvm.AsmDialect.AD_ATT, + ) + ) + + +@cute.jit +def fill_oob(tXsX: cute.Tensor, tXpX: Optional[cute.Tensor], fill_value: cute.Numeric) -> None: + """Fill out-of-bounds values in shared memory tensor. + + Args: + tXsX: Shared memory tensor to fill + tXpX: Predicate tensor indicating valid elements + fill_value: Value to fill OOB locations with + """ + tXrX_fill = cute.make_fragment_like(tXsX[(None, 0), None, 0]) + tXrX_fill.fill(fill_value) + for rest_v in cutlass.range_constexpr(tXsX.shape[0][1]): + for rest_k in cutlass.range_constexpr(tXsX.shape[2]): + if const_expr(tXpX is not None): + if not tXpX[rest_v, 0, rest_k]: + cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k]) + else: + cute.autovec_copy(tXrX_fill, tXsX[(None, rest_v), None, rest_k]) + + +@dsl_user_op +def f32x2_to_i64(a: Float32, b: Float32, *, loc=None, ip=None) -> cutlass.Int64: + vec_f32x2 = vector.from_elements( + T.vector(2, T.f32()), (a.ir_value(), b.ir_value()), loc=loc, ip=ip + ) + vec_i64x1 = vector.bitcast(T.vector(1, T.i64()), vec_f32x2) + res = cutlass.Int64( + vector.extract(vec_i64x1, dynamic_position=[], static_position=[0], loc=loc, ip=ip) + ) + return res + + +@dsl_user_op +def i64_to_f32x2(c: cutlass.Int64, *, loc=None, ip=None) -> Tuple[Float32, Float32]: + vec_i64x1 = vector.from_elements(T.vector(1, T.i64()), (c.ir_value(),), loc=loc, ip=ip) + vec_f32x2 = vector.bitcast(T.vector(2, T.f32()), vec_i64x1) + res0 = Float32( + vector.extract(vec_f32x2, dynamic_position=[], static_position=[0], loc=loc, ip=ip) + ) + res1 = Float32( + vector.extract(vec_f32x2, dynamic_position=[], static_position=[1], loc=loc, ip=ip) + ) + return res0, res1 + + +@cute.jit +def warp_prefix_sum(val: Int32, lane: Optional[Int32] = None) -> Int32: + if const_expr(lane is None): + lane = cute.arch.lane_idx() + for i in cutlass.range_constexpr(int(math.log2(cute.arch.WARP_SIZE))): + offset = 1 << i + # Very important that we set mask_and_clamp to 0 + partial_sum = cute.arch.shuffle_sync_up(val, offset=offset, mask_and_clamp=0) + if lane >= offset: + val += partial_sum + return val + + +@dsl_user_op +def atomic_add_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32: + return nvvm.atomicrmw( + res=T.i32(), op=nvvm.AtomicOpKind.ADD, ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value() + ) + + +@dsl_user_op +def atomic_inc_i32(a: int | Int32, gmem_ptr: cute.Pointer, *, loc=None, ip=None) -> Int32: + return nvvm.atomicrmw( + res=T.i32(), op=nvvm.AtomicOpKind.INC, 
ptr=gmem_ptr.llvm_ptr, a=Int32(a).ir_value() + ) diff --git a/sonic-moe/torch-ext/sonicmoe/quack/varlen_utils.py b/sonic-moe/torch-ext/sonicmoe/quack/varlen_utils.py new file mode 100644 index 00000000..b265cfbc --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack/varlen_utils.py @@ -0,0 +1,386 @@ +# Copyright (c) 2025, Tri Dao. + +from typing import Optional +from dataclasses import dataclass + +import cutlass +import cutlass.cute as cute +from cutlass import Int32, Boolean, const_expr +from cutlass.utils import LayoutEnum + +from .cute_dsl_utils import ArgumentsBase, ParamsBase +from .tensormap_manager import TensorMapManagerSm90 + + +# Grouping arguments together that should be passed to __call__ +@dataclass +class VarlenArguments(ArgumentsBase): + mCuSeqlensM: Optional[cute.Tensor] = None + mCuSeqlensK: Optional[cute.Tensor] = None + mTensormaps: Optional[cute.Tensor] = None + mAIdx: Optional[cute.Tensor] = None + + +class VarlenManager: + bytes_per_tensormap = 128 + + @dataclass + class Params(ParamsBase): + cu_seqlens_m: Optional[cute.Tensor] = None + cu_seqlens_k: Optional[cute.Tensor] = None + tensormaps: Optional[cute.Tensor] = None + mAIdx: Optional[cute.Tensor] = None + + @staticmethod + @cute.jit + def create(args: VarlenArguments, *, loc=None, ip=None) -> "VarlenManager.Params": + return VarlenManager.Params( + cu_seqlens_m=args.mCuSeqlensM, + cu_seqlens_k=args.mCuSeqlensK, + tensormaps=args.mTensormaps, + mAIdx=args.mAIdx, + ) + + def __init__( + self, + params: Params, + tensormap_manager: Optional[cutlass.utils.TensorMapManager], + tensormap_a_ptr: Optional[cute.Pointer], + tensormap_b_ptr: Optional[cute.Pointer], + tensormap_d_ptr: Optional[cute.Pointer], + tensormap_epi_ptrs: list[Optional[cute.Pointer]], + len_m_static: Int32, + len_k_static: Int32, + last_batch_idx: Int32 = Int32(-1), + is_group_changed: Boolean = Boolean(True), + *, + loc=None, + ip=None, + ): + self.params = params + self.tensormap_manager = tensormap_manager + self._tensormap_a_ptr = tensormap_a_ptr + self._tensormap_b_ptr = tensormap_b_ptr + self._tensormap_d_ptr = tensormap_d_ptr + self._tensormap_epi_ptrs = tensormap_epi_ptrs + self._len_m_static = len_m_static + self._len_k_static = len_k_static + self._last_batch_idx = last_batch_idx + self._is_group_changed = is_group_changed + self.varlen_m = const_expr(params.cu_seqlens_m is not None) + self.varlen_k = const_expr(params.cu_seqlens_k is not None) + self.gather_A = const_expr(params.mAIdx is not None) + self._loc = loc + self._ip = ip + + @staticmethod + def to_underlying_arguments(args: VarlenArguments, *, loc=None, ip=None) -> Params: + assert not (args.mCuSeqlensM is not None and args.mCuSeqlensK is not None), ( + "Only support either varlen_m or varlen_k" + ) + return VarlenManager.Params.create(args, loc=loc, ip=ip) + + @staticmethod + @cute.jit + def create( + params: Params, + has_D: bool, + num_epi_tensormaps: int, + len_m_static: Int32, + len_k_static: Int32, + pingpong: bool = False, + warp_idx: int | Int32 = 0, + *, + loc=None, + ip=None, + ) -> "VarlenManager": + tensormap_manager = None + tensormap_a_ptr, tensormap_b_ptr, tensormap_d_ptr = None, None, None + tensormap_epi_ptrs = [None] * num_epi_tensormaps + varlen_m = const_expr(params.cu_seqlens_m is not None) + varlen_k = const_expr(params.cu_seqlens_k is not None) + if const_expr(varlen_m or varlen_k): + tensormap_manager = TensorMapManagerSm90( + cutlass.utils.TensorMapUpdateMode.GMEM, VarlenManager.bytes_per_tensormap + ) + # equivalent to bidx + bidy * gridDim.x + bidxz * 
gridDim.x * gridDim.y + tensormap_workspace_idx = cute.make_layout(cute.arch.grid_dim())(cute.arch.block_idx()) + if const_expr(varlen_m): + tensormap_d_idx = warp_idx // 4 if const_expr(pingpong) else 0 + tensormap_epi_offset = tensormap_d_idx + if const_expr(has_D): + tensormap_d_ptr = tensormap_manager.get_tensormap_ptr( + params.tensormaps[tensormap_workspace_idx, tensormap_d_idx, None].iterator + ) + tensormap_epi_offset += 1 if not pingpong else 2 + tensormap_epi_ptrs = [ + tensormap_manager.get_tensormap_ptr( + params.tensormaps[ + tensormap_workspace_idx, + tensormap_epi_offset + i * (1 if not pingpong else 2), + None, + ].iterator + ) + for i in range(num_epi_tensormaps) + ] + else: + assert varlen_k + gather_A = const_expr(params.mAIdx is not None) + if const_expr(not gather_A): + tensormap_a_ptr = tensormap_manager.get_tensormap_ptr( + params.tensormaps[tensormap_workspace_idx, 0, None].iterator + ) + tensormap_b_ptr = tensormap_manager.get_tensormap_ptr( + params.tensormaps[ + tensormap_workspace_idx, 1 if not gather_A else 0, None + ].iterator + ) + return VarlenManager( + params, + tensormap_manager, + tensormap_a_ptr, + tensormap_b_ptr, + tensormap_d_ptr, + tensormap_epi_ptrs, + len_m_static=len_m_static, + len_k_static=len_k_static, + ) + + def len_m(self, batch_idx: Int32) -> Int32: + if const_expr(self.varlen_m): + return self.params.cu_seqlens_m[batch_idx + 1] - self.params.cu_seqlens_m[batch_idx] + else: + return self._len_m_static + + def len_k(self, batch_idx: Int32) -> Int32: + if const_expr(self.varlen_k): + return self.params.cu_seqlens_k[batch_idx + 1] - self.params.cu_seqlens_k[batch_idx] + else: + return self._len_k_static + + def offset_batch_A(self, mA_mkl: cute.Tensor, batch_idx: Int32) -> cute.Tensor: + params = self.params + if const_expr(self.varlen_m): + mA_mk = cute.domain_offset((params.cu_seqlens_m[batch_idx], 0), mA_mkl) + elif const_expr(self.varlen_k): + mA_mk = cute.domain_offset((0, params.cu_seqlens_k[batch_idx]), mA_mkl) + else: + mA_mk = mA_mkl[None, None, batch_idx] + return mA_mk + + def offset_batch_AIdx(self, batch_idx: Int32) -> cute.Tensor: + params = self.params + if const_expr(self.varlen_m): + mAIdx_mk = cute.domain_offset((params.cu_seqlens_m[batch_idx],), params.mAIdx) + elif const_expr(self.varlen_k): + mAIdx_mk = cute.domain_offset((params.cu_seqlens_k[batch_idx],), params.mAIdx) + else: + mAIdx_mk = params.mAIdx[None, batch_idx] + return mAIdx_mk + + def offset_batch_B(self, mB_nkl: cute.Tensor, batch_idx: Int32) -> cute.Tensor: + params = self.params + if const_expr(self.varlen_k): + mB_nk = cute.domain_offset((0, params.cu_seqlens_k[batch_idx]), mB_nkl) + else: + mB_nk = mB_nkl[None, None, batch_idx] + return mB_nk + + def offset_batch_epi(self, mD_mnl: cute.Tensor, batch_idx: Int32) -> cute.Tensor: + params = self.params + if const_expr(self.varlen_m): + mD_mn = cute.domain_offset((params.cu_seqlens_m[batch_idx], 0), mD_mnl) + else: + mD_mn = mD_mnl[None, None, batch_idx] + return mD_mn + + def init_tensormap_AB( + self, + tma_atom_a: Optional[cute.CopyAtom], + tma_atom_b: cute.CopyAtom, + is_manager_warp: bool | Boolean = True, + ) -> None: + if const_expr(self.varlen_k): + if const_expr(not self.gather_A): + self.tensormap_manager.init_tensormap_from_atom( + tma_atom_a, self._tensormap_a_ptr, is_manager_warp + ) + self.tensormap_manager.init_tensormap_from_atom( + tma_atom_b, self._tensormap_b_ptr, is_manager_warp + ) + + def init_tensormap_epi( + self, + tma_atom_d: Optional[cute.CopyAtom], + tma_atoms_epi: 
list[cute.CopyAtom], + is_manager_warp: bool | Boolean = True, + ) -> None: + if const_expr(self.varlen_m): + if const_expr(self._tensormap_d_ptr is not None): + self.tensormap_manager.init_tensormap_from_atom( + tma_atom_d, self._tensormap_d_ptr, is_manager_warp + ) + for tma_atom, tensormap_epi_ptr in zip(tma_atoms_epi, self._tensormap_epi_ptrs): + self.tensormap_manager.init_tensormap_from_atom( + tma_atom, tensormap_epi_ptr, is_manager_warp + ) + + def fence_tensormap_init(self) -> None: + self.tensormap_manager.fence_tensormap_initialization() + + @cute.jit + def update_tensormap_AB( + self, + batch_idx: Int32, + a_layout: LayoutEnum, + b_layout: LayoutEnum, + is_manager_warp: bool | Boolean = True, + ) -> None: + if const_expr(self.varlen_k): + self._is_group_changed = Boolean(batch_idx != self._last_batch_idx) + self._last_batch_idx = batch_idx + if self._is_group_changed: + # construct tensor A/B based on real address, shape and stride information + cu_seqlens_k = self.params.cu_seqlens_k + tensormap_ptrs = [self._tensormap_b_ptr] + shapes = [cu_seqlens_k[batch_idx + 1]] + orders = [0 if const_expr(b_layout == LayoutEnum.ROW_MAJOR) else 1] + if const_expr(not self.gather_A): + tensormap_ptrs.insert(0, self._tensormap_a_ptr) + shapes.insert(0, cu_seqlens_k[batch_idx + 1]) + orders.insert(0, 0 if const_expr(a_layout == LayoutEnum.ROW_MAJOR) else 1) + self.tensormap_manager.update_tensormap_shape( + tensormap_ptrs, + is_manager_warp=is_manager_warp, + shapes=shapes, + orders=orders, + tensormap_smem_ptr=None, + ) + + @cute.jit + def update_tensormap_epi( + self, + batch_idx: Int32, + d_layout: LayoutEnum, + epi_shapes: list[Int32], + epi_orders: list[int], + is_manager_warp: bool | Boolean = True, + ) -> None: + if const_expr(self.varlen_m): + self._is_group_changed = Boolean(batch_idx != self._last_batch_idx) + self._last_batch_idx = batch_idx + # Cute-DSL doesn't like this under if statement + order_d = ( + (0 if const_expr(d_layout.is_m_major_c()) else 1) if d_layout is not None else None + ) + if self._is_group_changed: + # construct tensor A/B based on real address, shape and stride information + cu_seqlens_m = self.params.cu_seqlens_m + # construct tensor D based on real address, shape and stride information + tensormap_ptrs, shapes, orders = [], [], [] + if const_expr(self._tensormap_d_ptr is not None): + tensormap_ptrs.append(self._tensormap_d_ptr) + shapes.append(cu_seqlens_m[batch_idx + 1]) + orders.append(order_d) + tensormap_ptrs.extend(self._tensormap_epi_ptrs) + shapes.extend(epi_shapes) + orders.extend(epi_orders) + self.tensormap_manager.update_tensormap_shape( + tensormap_ptrs, + is_manager_warp=is_manager_warp, + shapes=shapes, + orders=orders, + tensormap_smem_ptr=None, + ) + + @cute.jit + def fence_tensormap_update_AB(self, is_manager_warp: bool | Boolean = True) -> None: + if const_expr(self.varlen_k): + if self._is_group_changed and is_manager_warp: + if const_expr(not self.gather_A): + self.tensormap_manager.fence_tensormap_update(self._tensormap_a_ptr) + self.tensormap_manager.fence_tensormap_update(self._tensormap_b_ptr) + + @cute.jit + def fence_tensormap_update_epi(self, is_manager_warp: bool | Boolean = True) -> None: + if const_expr(self.varlen_m): + if self._is_group_changed and is_manager_warp: + if const_expr(self._tensormap_d_ptr is not None): + self.tensormap_manager.fence_tensormap_update(self._tensormap_d_ptr) + for tensormap_epi_ptr in self._tensormap_epi_ptrs: + if const_expr(tensormap_epi_ptr is not None): + 
self.tensormap_manager.fence_tensormap_update(tensormap_epi_ptr) + + def get_tma_desc_a_ptr(self) -> Optional[cute.Pointer]: + tma_desc_a_ptr = None + if const_expr(self.varlen_k and self._tensormap_a_ptr is not None): + tma_desc_a_ptr = self.tensormap_manager.get_tensormap_ptr( + self._tensormap_a_ptr, cute.AddressSpace.generic + ) + return tma_desc_a_ptr + + def get_tma_desc_b_ptr(self) -> Optional[cute.Pointer]: + tma_desc_b_ptr = None + if const_expr(self.varlen_k): + tma_desc_b_ptr = self.tensormap_manager.get_tensormap_ptr( + self._tensormap_b_ptr, cute.AddressSpace.generic + ) + return tma_desc_b_ptr + + def get_tma_desc_d_ptr(self) -> Optional[cute.Pointer]: + tma_desc_d_ptr = None + if const_expr(self.varlen_m and self._tensormap_d_ptr is not None): + tma_desc_d_ptr = self.tensormap_manager.get_tensormap_ptr( + self._tensormap_d_ptr, cute.AddressSpace.generic + ) + return tma_desc_d_ptr + + def get_tma_desc_epi_ptrs(self) -> list[Optional[cute.Pointer]]: + tma_desc_epi_ptrs = [None] * len(self._tensormap_epi_ptrs) + if const_expr(self.varlen_m): + for i, tensormap_epi_ptr in enumerate(self._tensormap_epi_ptrs): + if const_expr(tensormap_epi_ptr is not None): + tma_desc_epi_ptrs[i] = self.tensormap_manager.get_tensormap_ptr( + tensormap_epi_ptr, cute.AddressSpace.generic + ) + return tma_desc_epi_ptrs + + def __extract_mlir_values__(self): + values, self._values_pos = [], [] + for obj in [ + self.params, + self.tensormap_manager, + self._tensormap_a_ptr, + self._tensormap_b_ptr, + self._tensormap_d_ptr, + self._tensormap_epi_ptrs, + self._len_m_static, + self._len_k_static, + self._last_batch_idx, + self._is_group_changed, + ]: + obj_values = cutlass.extract_mlir_values(obj) + values += obj_values + self._values_pos.append(len(obj_values)) + return values + + def __new_from_mlir_values__(self, values): + obj_list = [] + for obj, n_items in zip( + [ + self.params, + self.tensormap_manager, + self._tensormap_a_ptr, + self._tensormap_b_ptr, + self._tensormap_d_ptr, + self._tensormap_epi_ptrs, + self._len_m_static, + self._len_k_static, + self._last_batch_idx, + self._is_group_changed, + ], + self._values_pos, + ): + obj_list.append(cutlass.new_from_mlir_values(obj, values[:n_items])) + values = values[n_items:] + return self.__class__(*(tuple(obj_list)), loc=self._loc) diff --git a/sonic-moe/torch-ext/sonicmoe/quack_utils/__init__.py b/sonic-moe/torch-ext/sonicmoe/quack_utils/__init__.py new file mode 100644 index 00000000..de3a2b20 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack_utils/__init__.py @@ -0,0 +1,5 @@ +# ******************************************************************************** +# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao +# ******************************************************************************** + +from .gemm_interface import gemm_dgated, gemm_gated diff --git a/sonic-moe/torch-ext/sonicmoe/quack_utils/gemm_dgated.py b/sonic-moe/torch-ext/sonicmoe/quack_utils/gemm_dgated.py new file mode 100644 index 00000000..d581c9b4 --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack_utils/gemm_dgated.py @@ -0,0 +1,501 @@ +# ******************************************************************************** +# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao +# ******************************************************************************** + +import operator +from dataclasses import dataclass +from functools import partial +from typing import Callable, Optional, Tuple, Type + +import cutlass 
+import cutlass.cute as cute +import cutlass.torch as cutlass_torch +import cutlass.utils.blackwell_helpers as sm100_utils +from ..quack import activation +from ..quack import layout_utils +from ..quack import sm90_utils +from ..quack import utils +import torch +from cutlass import Float32, Int32, const_expr +from cutlass.cute.runtime import from_dlpack +from ..quack.cute_dsl_utils import ( + ArgumentsBase, + ParamsBase, + get_device_capacity, + get_max_active_clusters, + torch2cute_dtype_map, +) +from ..quack.gemm_act import GemmActMixin +from ..quack.gemm_default_epi import GemmDefaultEpiMixin +from ..quack.gemm_sm90 import GemmSm90 +from ..quack.gemm_sm100 import GemmSm100 +from ..quack.gemm_wrapper_utils import GemmWrapperBase +from ..quack.sm90_utils import partition_for_epilogue +from ..quack.varlen_utils import VarlenManager +from torch import Tensor + + +class GemmDGatedMixin(GemmActMixin): + # Different from GemmActMixin, here act_bwd_fn must take in 3 arguments (x, y, dout) + # and return 3 arguments (dx, dy, out) + @dataclass + class EpilogueArguments(ArgumentsBase): + mPostAct: cute.Tensor + act_bwd_fn: cutlass.Constexpr[Callable] + implicit_dtype: Type[cutlass.Numeric] = cute.BFloat16 + # We don't use alpha, beta, mRowVecBroadcast for now + alpha: Optional[Float32 | cute.Tensor] = None + beta: Optional[Float32 | cute.Tensor] = None + mRowVecBroadcast: Optional[cute.Tensor] = None + mColVecBroadcast: Optional[cute.Tensor] = None + mColVecReduce: Optional[cute.Tensor] = None + + @dataclass + class EpilogueParams(ParamsBase): + tma_atom_postact: cute.CopyAtom + mPostAct_mnl: cute.Tensor + epi_postact_smem_layout_staged: cute.ComposedLayout + epi_tile_postact: cute.Tile + act_bwd_fn: cutlass.Constexpr[Callable] + implicit_dtype: Type[cutlass.Numeric] + alpha: Optional[Float32 | cute.Tensor] = None + beta: Optional[Float32 | cute.Tensor] = None + mRowVecBroadcast: Optional[cute.Tensor] = None + mColVecBroadcast: Optional[cute.Tensor] = None + mColVecReduce: Optional[cute.Tensor] = None + + def epi_to_underlying_arguments(self, args: EpilogueArguments, *, loc=None, ip=None) -> EpilogueParams: + self.postact_dtype = args.mPostAct.element_type + self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct) + # C and D are implicitly 2 16-bit elements packed into 32 bits, simply for the purpose + # for reusing the existing load/store code. 
+ assert args.implicit_dtype.width == 16, "GemmDGated only supports 16bit for now" + assert self.d_dtype.width == 32, "D storage type must be 32 bit" + assert self.c_dtype.width == 32, "C storage type must be 32 bit" + + self.cta_tile_shape_postact_mn = self.cta_tile_shape_mnk[:2] + epi_tile_postact = self.epi_tile + utils_cls = sm100_utils if self.arch == 100 else sm90_utils + epi_postact_smem_layout_staged = utils_cls.make_smem_layout_epi( + self.postact_dtype, self.postact_layout, epi_tile_postact, self.epi_stage + ) + tma_atom_postact, tma_tensor_postact = self._make_tma_epi_atoms_and_tensors( + args.mPostAct, + epi_postact_smem_layout_staged, + epi_tile_postact, + op_type="store", + ) + # Assume all strides are divisible by 32 bits except the last stride + new_stride = lambda t: tuple( + cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s for s in t.stride + ) + mRowVecBroadcast, mColVecBroadcast, mColVecReduce = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) if t is not None else None + for t in (args.mRowVecBroadcast, args.mColVecBroadcast, args.mColVecReduce) + ] + return self.EpilogueParams( + tma_atom_postact, + tma_tensor_postact, + epi_postact_smem_layout_staged, + epi_tile_postact, + args.act_bwd_fn, + args.implicit_dtype, + alpha=args.alpha, + beta=args.beta, + mRowVecBroadcast=mRowVecBroadcast, + mColVecBroadcast=mColVecBroadcast, + mColVecReduce=mColVecReduce, + ) + + @cute.jit + def epi_begin( + self, + params: EpilogueParams, + epi_smem_tensors: Tuple[cute.Tensor, ...], + epi_tile: cute.Tile, + tiled_copy_t2r: Optional[cute.TiledCopy], + tiled_copy_r2s: cute.TiledCopy, + tile_coord_mnkl: cute.Coord, + varlen_manager: VarlenManager, + epilogue_barrier: cutlass.pipeline.NamedBarrier, + tidx: Int32, + ) -> Tuple[cute.Tensor, ...]: + epi_tensors = GemmDefaultEpiMixin.epi_begin( + self, + params, + epi_smem_tensors, + epi_tile, + tiled_copy_t2r, + tiled_copy_r2s, + tile_coord_mnkl, + varlen_manager, + epilogue_barrier, + tidx, + ) + partition_for_epilogue_fn = partial( + partition_for_epilogue, + epi_tile=epi_tile, + tiled_copy=tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s, + tidx=tidx, + reference_src=tiled_copy_t2r is None, + ) + tDrColVecReduce = None + if const_expr(params.mColVecReduce is not None): + colvec_mma_layout = cute.make_layout(self.cta_tile_shape_mnk[:2], stride=(1, 0)) + tDrColVec_layout = partition_for_epilogue_fn(cute.make_rmem_tensor(colvec_mma_layout, Float32)).layout + tDrColVecReduce = cute.make_rmem_tensor(tDrColVec_layout, Float32) + cute.filter_zeros(tDrColVecReduce).fill(0.0) + return (*epi_tensors, tDrColVecReduce) + + def epi_begin_loop(self, params: EpilogueParams, epi_tensors, epi_coord: cute.Coord): + epi_tensors, tDrColVecReduce = epi_tensors[:-1], epi_tensors[-1] + epi_loop_tensors = super().epi_begin_loop(params, epi_tensors, epi_coord) + tDrColVecReduce_cur = None + if const_expr(tDrColVecReduce is not None): + tDrColVecReduce_cur = cute.group_modes(tDrColVecReduce, 3, cute.rank(tDrColVecReduce))[ + None, None, None, epi_coord + ] + return (*epi_loop_tensors, tDrColVecReduce_cur) + + @cute.jit + def epi_visit_subtile( + self, + params: EpilogueParams, + epi_loop_tensors: Tuple[cute.Tensor, ...], + tRS_rD: cute.Tensor, + tRS_rC: Optional[cute.Tensor] = None, + ) -> Optional[cute.Tensor]: + alpha, beta, tDrRowVec, tDrColVec, tDrColVecReduce = epi_loop_tensors + assert alpha is None and beta is None and tDrRowVec is None # We don't use these for now + assert 
tRS_rC is not None + implicit_dtype = params.implicit_dtype + assert implicit_dtype.width == 16, "GemmDGatedMixin only supports 16bit for now" + tRS_rXY_f16x2 = cute.recast_tensor(tRS_rC, implicit_dtype) + tRS_rXY_f32x2 = cute.make_rmem_tensor(tRS_rXY_f16x2.layout, Float32) + tRS_rXY_f32x2.store(tRS_rXY_f16x2.load().to(Float32)) + tRS_rdXY_f32x2 = cute.make_rmem_tensor_like(tRS_rXY_f32x2, Float32) + tRS_rOut = cute.make_rmem_tensor_like(tRS_rD, Float32) + tRS_rD_scaled = cute.make_rmem_tensor_like(tRS_rD) + if const_expr(tDrColVec is not None): # Scale D by colvec + if const_expr(self.arch < 100): + tRS_rD_scaled.store(tRS_rD.load() * tDrColVec.load().to(tRS_rD.element_type)) + else: + tDrColVec_mn = layout_utils.convert_layout_zero_stride(tDrColVec, tDrColVec.layout) + tRS_rD_mn = layout_utils.convert_layout_zero_stride(tRS_rD, tDrColVec.layout) + tRS_rD_scaled_mn = layout_utils.convert_layout_zero_stride(tRS_rD_scaled, tDrColVec.layout) + for m in cutlass.range(cute.size(tDrColVec_mn, mode=[0]), unroll_full=True): + for n in cutlass.range(cute.size(tDrColVec_mn, mode=[1]) // 2, unroll_full=True): + ( + tRS_rD_scaled_mn[m, 2 * n], + tRS_rD_scaled_mn[m, 2 * n + 1], + ) = cute.arch.mul_packed_f32x2( + (tRS_rD_mn[m, 2 * n], tRS_rD_mn[m, 2 * n + 1]), + (tDrColVec_mn[m, 0], tDrColVec_mn[m, 0]), + ) + else: + tRS_rD_scaled.store(tRS_rD.load()) + if const_expr(self.arch < 100): + for i in cutlass.range(cute.size(tRS_rD)): + ( + tRS_rdXY_f32x2[2 * i], + tRS_rdXY_f32x2[2 * i + 1], + tRS_rOut[i], + ) = params.act_bwd_fn(tRS_rXY_f32x2[2 * i], tRS_rXY_f32x2[2 * i + 1], tRS_rD_scaled[i]) + else: + for i in cutlass.range(cute.size(tRS_rD) // 2): + ( + (tRS_rdXY_f32x2[4 * i], tRS_rdXY_f32x2[4 * i + 2]), + (tRS_rdXY_f32x2[4 * i + 1], tRS_rdXY_f32x2[4 * i + 3]), + (tRS_rOut[2 * i], tRS_rOut[2 * i + 1]), + ) = params.act_bwd_fn( + (tRS_rXY_f32x2[4 * i], tRS_rXY_f32x2[4 * i + 2]), + (tRS_rXY_f32x2[4 * i + 1], tRS_rXY_f32x2[4 * i + 3]), + (tRS_rD_scaled[2 * i], tRS_rD_scaled[2 * i + 1]), + ) + if const_expr(tDrColVecReduce is not None): + # Need to multiply before D is scaled by colvec_scale + if const_expr(self.arch < 100): + for i in cutlass.range(cute.size(tDrColVecReduce), unroll_full=True): + tDrColVecReduce[i] += tRS_rOut[i] * tRS_rD[i] + else: + tDrColVecReduce_mn = layout_utils.convert_layout_zero_stride(tDrColVecReduce, tDrColVecReduce.layout) + tRS_rD_mn = layout_utils.convert_layout_zero_stride(tRS_rD, tDrColVecReduce.layout) + tRS_rOut_mn = layout_utils.convert_layout_zero_stride(tRS_rOut, tDrColVecReduce.layout) + for m in cutlass.range(cute.size(tDrColVecReduce_mn, mode=[0]), unroll_full=True): + row_sum = cute.arch.mul_packed_f32x2( + (tRS_rD_mn[m, 0], tRS_rD_mn[m, 1]), (tRS_rOut_mn[m, 0], tRS_rOut_mn[m, 1]) + ) + for n in cutlass.range(1, cute.size(tDrColVecReduce_mn, mode=[1]) // 2, unroll_full=True): + row_sum = utils.fma_packed_f32x2( + (tRS_rD_mn[m, 2 * n], tRS_rD_mn[m, 2 * n + 1]), + (tRS_rOut_mn[m, 2 * n], tRS_rOut_mn[m, 2 * n + 1]), + row_sum, + ) + tDrColVecReduce_mn[m, 0] += row_sum[0] + row_sum[1] + + if const_expr(tDrColVec is not None): # Scale Out by colvec + if const_expr(self.arch < 100): + tRS_rOut.store(tRS_rOut.load() * tDrColVec.load().to(tRS_rD.element_type)) + else: + tDrColVec_mn = layout_utils.convert_layout_zero_stride(tDrColVec, tDrColVec.layout) + tRS_rOut_mn = layout_utils.convert_layout_zero_stride(tRS_rOut, tDrColVec.layout) + for m in cutlass.range(cute.size(tDrColVec_mn, mode=[0]), unroll_full=True): + for n in cutlass.range(cute.size(tDrColVec_mn, mode=[1]) // 
2, unroll_full=True): + tRS_rOut_mn[m, 2 * n], tRS_rOut_mn[m, 2 * n + 1] = cute.arch.mul_packed_f32x2( + (tRS_rOut_mn[m, 2 * n], tRS_rOut_mn[m, 2 * n + 1]), + (tDrColVec_mn[m, 0], tDrColVec_mn[m, 0]), + ) + # Type conversion + tRS_rdXY_f16x2 = cute.make_rmem_tensor(tRS_rdXY_f32x2.layout, implicit_dtype) + tRS_rdXY_f16x2.store(tRS_rdXY_f32x2.load().to(implicit_dtype)) + tRS_rD.store(cute.recast_tensor(tRS_rdXY_f16x2, Float32).load()) + tRS_rOut_cvt = cute.make_rmem_tensor_like(tRS_rOut, self.postact_dtype) + tRS_rOut_cvt.store(tRS_rOut.load().to(self.postact_dtype)) + return tRS_rOut_cvt + + @cute.jit + def epi_end( + self, + params: EpilogueParams, + epi_tensors: Tuple[cute.Tensor, ...], + epi_tile: cute.Tile, + tiled_copy_t2r: Optional[cute.TiledCopy], + tiled_copy_r2s: cute.TiledCopy, + tile_coord_mnkl: cute.Coord, + varlen_manager: VarlenManager, + tidx: Int32, + ) -> None: + partition_for_epilogue_fn = partial( + partition_for_epilogue, + epi_tile=epi_tile, + tiled_copy=tiled_copy_t2r if tiled_copy_t2r is not None else tiled_copy_r2s, + tidx=tidx, + reference_src=tiled_copy_t2r is None, + ) + tDrColVecReduce = epi_tensors[-1] + tile_M, tile_N = self.cta_tile_shape_mnk[:2] + if const_expr(params.mColVecReduce is not None): + tDrCVR_flt = cute.filter_zeros(tDrColVecReduce) + if const_expr(self.arch != 100): + for i in cutlass.range(cute.size(tDrCVR_flt), unroll_full=True): + tDrCVR_flt[i] = cute.arch.warp_reduction(tDrCVR_flt[i], operator.add, threads_in_group=4) + else: + # Don't need warp_reduce since we load from tmem with one thread per row + assert self.d_layout.is_n_major_c(), "GemmDGated only supports n-major output for now" + batch_idx = tile_coord_mnkl[3] + limit_n = params.mColVecReduce.shape[2] if not varlen_manager.varlen_m else params.mColVecReduce.shape[1] + if tile_coord_mnkl[1] < limit_n: + if const_expr(not varlen_manager.varlen_m): + mColVec = params.mColVecReduce[batch_idx, None, tile_coord_mnkl[1]] + else: + mColVec = cute.domain_offset( + (varlen_manager.params.cu_seqlens_m[batch_idx],), + params.mColVecReduce[None, tile_coord_mnkl[1]], + ) + gColVec = cute.local_tile(mColVec, (tile_M,), (tile_coord_mnkl[0],)) + limit_m = min(varlen_manager.len_m(batch_idx) - tile_coord_mnkl[0] * tile_M, tile_M) + tDcCV = partition_for_epilogue_fn(cute.make_identity_tensor((tile_M, tile_N))) + tDrColVecReduce_m = layout_utils.convert_layout_zero_stride(tDrColVecReduce, tDrColVecReduce.layout)[ + None, 0 + ] + tDcCV_m = layout_utils.convert_layout_zero_stride(tDcCV, tDrColVecReduce.layout)[None, 0] + if tDcCV_m[0][1] == 0: + for m in cutlass.range(cute.size(tDcCV_m, mode=[0])): + row_idx = tDcCV_m[m][0] + if row_idx < limit_m: + gColVec[row_idx] = tDrColVecReduce_m[m] + + +class GemmDGatedSm90(GemmDGatedMixin, GemmSm90): + pass + + +class GemmDGatedSm100(GemmDGatedMixin, GemmSm100): + pass + + +# Use the vendored activation module imported above (`from ..quack import activation`); +# there is no top-level `quack` package in the vendored layout. +dgate_fn_map = { + "swiglu": activation.dswiglu, + "swiglu_oai": activation.dswiglu_oai, + "reglu": activation.dreglu, + "geglu": activation.dgeglu, + "glu": activation.dglu, +} + + +def gemm_dgated( + A: Tensor, # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m + B: Tensor, # (l, n, k) + Out: Tensor, # (l, m, 2*n) if n_major or (l, 2*m, n) if m_major, or (total_m, 2*n) if varlen_m + PreAct: Tensor, # (l, m, 2*n) if n_major or (l, 2*m, n) if m_major, or (total_m, 2*n) if varlen_m + PostAct: Tensor, # (l, m, n) or (total_m, n) if varlen_m + tile_count_semaphore: Optional[Tensor], # (1,) + activation: Optional[str], + tile_M: int, +
tile_N: int, + cluster_M: int, + cluster_N: int, + pingpong: bool = True, + persistent: bool = True, + max_swizzle_size: int = 8, + colvec_scale: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m + # (l, m, ceildiv(n, tile_n)), or (total_m, ceildiv(n, tile_n)) if varlen_m + colvec_reduce: Optional[Tensor] = None, + cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length + A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m +) -> None: + """If tile_count_semaphore is provided, it must already be zero'ed out.""" + if cu_seqlens_m is not None: + assert persistent, "varlen_m requires persistent=True" + assert A.stride(-1) == 1, "varlen_m requires A to be k-major" + assert Out.stride(-1) == 1, "varlen_m requires Out to be n-major" + assert PreAct.stride(-1) == 1, "varlen_m requires PreAct to be n-major" + assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major" + gather_A = A_idx is not None + if gather_A: + assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)" + assert cluster_N == 1, "gather_A requires cluster_N=1" + assert activation in dgate_fn_map, f"Unsupported activation {activation}" + + # Special handling for Out and PreAct + AB_swapped = not Out.stride(-1) == 1 + assert Out.dtype == PreAct.dtype + implicit_dtype = torch2cute_dtype_map[Out.dtype] + assert Out.element_size() == 2, "Out dtype must be fp16 or bf16" + assert PreAct.element_size() == 2, "Preact dtype must be fp16 or bf16" + # We pretend that Out is (M, N, L) of type fp32 instead of (M, 2N, L) of type f16. + # Similarly we pretend that PreAct is (M, N, L) of type fp32 instead of (M, 2N, L) of type f16 + if cu_seqlens_m is not None or not AB_swapped: + # varlen_m (always AB_swapped=False) or normal case with AB_swapped=False + Out = Out.view(torch.float32) + PreAct = PreAct.view(torch.float32) + else: + # Normal case with AB_swapped=True + Out = Out.mT.view(torch.float32).mT + PreAct = PreAct.mT.view(torch.float32).mT + + L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors( + A, + B, + Out, + PreAct, + additional_tensors={"PostAct": PostAct}, + cu_seqlens_m=cu_seqlens_m, + A_idx=A_idx, + ) + GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None) + GemmWrapperBase.extract_dtypes(tensor_infos) + major_configs = { + "A": ("m", "k", "l"), + "B": ("n", "k", "l"), + "D": ("m", "n", "l"), + "C": ("m", "n", "l"), + "PostAct": ("m", "n", "l"), + } + GemmWrapperBase.determine_major_orders(tensor_infos, major_configs) + + device_capacity = get_device_capacity(A.device) + assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported" + GemmCls = GemmDGatedSm100 if device_capacity[0] > 9 else GemmDGatedSm90 + + acc_dtype = Float32 + tile_shape_mn = (tile_M, tile_N) + cluster_shape_mnk = (cluster_M, cluster_N, 1) + if not GemmCls.is_valid_dtypes( + tensor_infos["A"].dtype, + tensor_infos["B"].dtype, + acc_dtype, + tensor_infos["D"].dtype, + tensor_infos["A"].major, + tensor_infos["B"].major, + ): + raise TypeError("Skipping due to unsupported combination of types and majors") + + max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0 + GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs) + act_fn = dgate_fn_map[activation] + epi_args = GemmCls.EpilogueArguments( + tensor_infos["PostAct"].cute_tensor, + act_fn, + implicit_dtype=implicit_dtype, + mColVecBroadcast=( + from_dlpack(colvec_scale.detach(), 
assumed_align=4).mark_layout_dynamic( + leading_dim=1 if cu_seqlens_m is None else 0 + ) + if colvec_scale is not None + else None + ), + mColVecReduce=( + from_dlpack(colvec_reduce.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=2 if cu_seqlens_m is None else 1 + ) + if colvec_reduce is not None + else None + ), + ) + scheduler_args = GemmWrapperBase.create_scheduler_args(max_active_clusters, tile_count_semaphore) + + # Create varlen arguments if needed (assumes persistent=True when varlen_m) + varlen_args = GemmWrapperBase.create_varlen_args( + cu_seqlens_m, + None, # cu_seqlens_k + A_idx, + max_active_clusters, + cluster_shape_mnk, + tensor_infos, + GemmCls.num_epi_tensormaps, + pingpong, + ) + + current_stream = cutlass_torch.current_stream() + compile_key = GemmWrapperBase.get_compile_key( + tensor_infos, + activation, + tile_shape_mn, + cluster_shape_mnk, + pingpong, + persistent, + tile_count_semaphore is not None, + device_capacity, + max_swizzle_size, + colvec_scale.dtype if colvec_scale is not None else None, + colvec_reduce.dtype if colvec_reduce is not None else None, + cu_seqlens_m is not None, + A_idx is not None, + key_tensor_names=("A", "B", "D", "PostAct", "C"), + ) + cache = gemm_dgated.compile_cache + if compile_key not in cache: + if device_capacity[0] == 9: + GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent) + gemm_obj = GemmCls( + acc_dtype, + tensor_infos["A"].dtype, + tile_shape_mn, + cluster_shape_mnk, + gather_A=gather_A, + ) + cache[compile_key] = cute.compile( + gemm_obj, + tensor_infos["A"].cute_tensor, + tensor_infos["B"].cute_tensor, + tensor_infos["D"].cute_tensor, # Out + tensor_infos["C"].cute_tensor, # PreAct + epi_args, + scheduler_args, + varlen_args, + current_stream, + ) + cache[compile_key]( + tensor_infos["A"].cute_tensor, + tensor_infos["B"].cute_tensor, + tensor_infos["D"].cute_tensor, # Out + tensor_infos["C"].cute_tensor, # PreAct + epi_args, + scheduler_args, + varlen_args, + current_stream, + ) + + +gemm_dgated.compile_cache = {} diff --git a/sonic-moe/torch-ext/sonicmoe/quack_utils/gemm_gated.py b/sonic-moe/torch-ext/sonicmoe/quack_utils/gemm_gated.py new file mode 100644 index 00000000..171707fa --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack_utils/gemm_gated.py @@ -0,0 +1,304 @@ +# ******************************************************************************** +# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao +# ******************************************************************************** + +from functools import partial +from typing import Optional, Tuple + +import cutlass +import cutlass.cute as cute +import cutlass.torch as cutlass_torch +import cutlass.utils.blackwell_helpers as sm100_utils +from ..quack import activation +from ..quack import sm90_utils +from cutlass import const_expr +from cutlass.cute.runtime import from_dlpack +from ..quack.cute_dsl_utils import get_device_capacity, get_max_active_clusters +from ..quack.gemm_act import GemmActMixin +from ..quack.gemm_default_epi import GemmDefaultEpiMixin +from ..quack.gemm_sm90 import GemmSm90 +from ..quack.gemm_sm100 import GemmSm100 +from ..quack.gemm_wrapper_utils import GemmTensorInfo, GemmWrapperBase +from ..quack.layout_utils import permute_gated_Cregs_b16 +from torch import Tensor + + +class GemmGatedMixin(GemmActMixin): + EpilogueArguments = GemmActMixin.EpilogueArguments + EpilogueParams = GemmActMixin.EpilogueParams + + def epi_to_underlying_arguments(self, args: EpilogueArguments, *, loc=None, 
ip=None) -> EpilogueParams: + self.postact_dtype = args.mPostAct.element_type + self.postact_layout = cutlass.utils.LayoutEnum.from_tensor(args.mPostAct) + assert self.postact_dtype.width == 16, "GemmGated only supports 16bit postact for now" + assert self.d_layout is None or self.d_layout.is_n_major_c() + assert self.postact_layout.is_n_major_c() + if self.arch == 90: + assert self.cta_tile_shape_mnk[1] % 32 == 0, "GemmGatedSm90 requires tileN to be divisible by 32" + + self.cta_tile_shape_postact_mn = ( + self.cta_tile_shape_mnk[0], + self.cta_tile_shape_mnk[1] // 2, + ) + if isinstance(self.epi_tile[1], cute.Layout): + epi_tile_postact_1 = cute.recast_layout(2, 1, self.epi_tile[1]) + else: + epi_tile_postact_1 = self.epi_tile[1] // 2 + epi_tile_postact = (self.epi_tile[0], epi_tile_postact_1) + utils_cls = sm100_utils if self.arch == 100 else sm90_utils + epi_postact_smem_layout_staged = utils_cls.make_smem_layout_epi( + self.postact_dtype, self.postact_layout, epi_tile_postact, self.epi_stage + ) + tma_atom_postact, tma_tensor_postact = self._make_tma_epi_atoms_and_tensors( + args.mPostAct, + epi_postact_smem_layout_staged, + epi_tile_postact, + op_type="store", + ) + # Assume all strides are divisible by 32 bits except the last stride + new_stride = lambda t: tuple( + cute.assume(s, divby=32 // t.element_type.width) if not cute.is_static(s) else s for s in t.stride + ) + mRowVecBroadcast, mColVecBroadcast = [ + cute.make_tensor(t.iterator, cute.make_layout(t.shape, stride=new_stride(t))) if t is not None else None + for t in (args.mRowVecBroadcast, args.mColVecBroadcast) + ] + return self.EpilogueParams( + tma_atom_postact, + tma_tensor_postact, + epi_postact_smem_layout_staged, + epi_tile_postact, + args.act_fn, + alpha=args.alpha, + beta=args.beta, + mRowVecBroadcast=mRowVecBroadcast, + mColVecBroadcast=mColVecBroadcast, + ) + + @staticmethod + def epi_smem_bytes_per_stage( + args: EpilogueArguments, cta_tile_shape_mnk: Tuple[int, int, int], epi_tile: cute.Tile + ) -> int: + postact_dtype = args.mPostAct.element_type + postact_bytes_per_stage = (cute.size(cute.shape(epi_tile)) // 2) * (postact_dtype.width // 8) + rowvec_colvec_bytes = GemmDefaultEpiMixin.epi_smem_bytes_per_stage(args, cta_tile_shape_mnk, epi_tile) + return postact_bytes_per_stage + rowvec_colvec_bytes + + @cute.jit + def epi_visit_subtile( + self, + params: EpilogueParams, + epi_loop_tensors: Tuple[cute.Tensor, ...], + tRS_rD: cute.Tensor, + tRS_rC: Optional[cute.Tensor] = None, + ) -> Optional[cute.Tensor]: + GemmDefaultEpiMixin.epi_visit_subtile(self, params, epi_loop_tensors, tRS_rD, tRS_rC) + tRS_rPostAct_layout = cute.recast_layout(2, 1, tRS_rD.layout) + # If we don't have .shape here, the compiler generates local stores and loads + tRS_rPostAct = cute.make_rmem_tensor(tRS_rPostAct_layout.shape, self.acc_dtype) + if const_expr(self.arch < 100): + for i in cutlass.range(cute.size(tRS_rPostAct), unroll_full=True): + tRS_rPostAct[i] = params.act_fn(tRS_rD[2 * i], tRS_rD[2 * i + 1]) + else: + for i in cutlass.range(cute.size(tRS_rPostAct) // 2, unroll_full=True): + tRS_rPostAct[2 * i], tRS_rPostAct[2 * i + 1] = params.act_fn( + (tRS_rD[4 * i], tRS_rD[4 * i + 2]), (tRS_rD[4 * i + 1], tRS_rD[4 * i + 3]) + ) + # Type conversion + tRS_rPostAct_out = cute.make_rmem_tensor_like(tRS_rPostAct, self.postact_dtype) + tRS_rPostAct_out.store(tRS_rPostAct.load().to(self.postact_dtype)) + if const_expr(self.arch == 90): + # Only need this if we're using STSM + permute_gated_Cregs_b16(tRS_rPostAct_out) + return tRS_rPostAct_out + 
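The epilogue above fuses the gating activation into the GEMM: each `PostAct` tile is produced from pairs of accumulator columns, so `PostAct` has half the N extent of the raw GEMM output. As a shapes-and-semantics reference only, here is a minimal plain-PyTorch sketch of the same math. It assumes the SwiGLU convention `swiglu(x, y) = x * sigmoid(x) * y` and an interleaved gate/value pairing of adjacent columns; the actual pairing and activation are determined by `gate_fn_map` and the register layouts, so this is not the kernel's implementation.

```python
import torch
import torch.nn.functional as F


def gemm_gated_reference(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Reference for a single batch entry: A is (m, k), B is (n, k) with n even.

    Returns PostAct of shape (m, n // 2), assuming interleaved SwiGLU gating.
    """
    D = A @ B.T                     # raw GEMM output, shape (m, n)
    x, y = D[:, 0::2], D[:, 1::2]   # assumed gate / value column pairing
    return F.silu(x) * y            # silu(x) = x * sigmoid(x)


# Tiny smoke test of the shapes: m=4, k=8, n=6 -> PostAct is (4, 3)
A = torch.randn(4, 8)
B = torch.randn(6, 8)
assert gemm_gated_reference(A, B).shape == (4, 3)
```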
+ +class GemmGatedSm90(GemmGatedMixin, GemmSm90): + pass + + +class GemmGatedSm100(GemmGatedMixin, GemmSm100): + pass + + +gate_fn_map = { + "swiglu": quack.activation.swiglu, + "swiglu_oai": quack.activation.swiglu_oai, + "reglu": quack.activation.reglu, + "geglu": quack.activation.geglu, + "glu": quack.activation.glu, +} + + +def gemm_gated( + A: Tensor, # (l, m, k) or (total_m, k) if varlen_m or (whatever, k) if gather_A with varlen_m + B: Tensor, # (l, n, k) + D: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m + C: Optional[Tensor], # (l, m, n) or (total_m, n) if varlen_m + PostAct: Tensor, # (l, m, n//2) or (total_m, n//2) if varlen_m + tile_count_semaphore: Optional[Tensor], # (1,) + activation: Optional[str], + tile_M: int, + tile_N: int, + cluster_M: int, + cluster_N: int, + pingpong: bool = False, + persistent: bool = True, + max_swizzle_size: int = 8, + rowvec_bias: Optional[Tensor] = None, # (l, n) + colvec_bias: Optional[Tensor] = None, # (l, m), or (total_m,) if varlen_m + cu_seqlens_m: Optional[Tensor] = None, # (l+1,) cumulative sum of m values for variable length + A_idx: Optional[Tensor] = None, # (total_m,) if gather_A with varlen_m +) -> None: + if cu_seqlens_m is not None: + assert persistent, "varlen_m requires persistent=True" + assert A.stride(-1) == 1, "varlen_m requires A to be k-major" + if D is not None: + assert D.stride(-1) == 1, "varlen_m requires D to be n-major" + assert PostAct.stride(-1) == 1, "varlen_m requires PostAct to be n-major" + gather_A = A_idx is not None + if gather_A: + assert cu_seqlens_m is not None, "gather_A requires varlen (cu_seqlens_m must be specified)" + assert cluster_N == 1, "gather_A requires cluster_N=1" + assert activation in gate_fn_map, f"Unsupported activation {activation}" + + # Special validation for PostAct shape + L, M, K, N, tensor_infos = GemmWrapperBase.validate_and_prepare_tensors( + A, B, D, C, cu_seqlens_m=cu_seqlens_m, A_idx=A_idx + ) + + # PostAct shape validation depends on varlen_m + if cu_seqlens_m is not None: + # varlen_m case: PostAct is 2D (total_m, n//2) + assert PostAct.dim() == 2 and PostAct.is_cuda, "PostAct must be a 2D CUDA tensor for varlen_m" + assert PostAct.shape == ( + M, + N // 2, + ), f"PostAct must have shape {(M, N // 2)}, got {PostAct.shape}" + else: + # Normal case: PostAct is 3D (l, m, n//2) + assert PostAct.dim() == 3 and PostAct.is_cuda, "PostAct must be a 3D CUDA tensor" + assert PostAct.shape == ( + L, + M, + N // 2, + ), f"PostAct must have shape {(L, M, N // 2)}, got {PostAct.shape}" + + tensor_infos["PostAct"] = GemmTensorInfo(PostAct) + GemmWrapperBase.permute_tensors(tensor_infos, varlen_m=cu_seqlens_m is not None) + GemmWrapperBase.extract_dtypes(tensor_infos) + major_configs = { + "A": ("m", "k", "l"), + "B": ("n", "k", "l"), + "D": ("m", "n", "l"), + "C": ("m", "n", "l"), + "PostAct": ("m", "n", "l"), # PostAct has shape (m, n//2, l) after permute + } + GemmWrapperBase.determine_major_orders(tensor_infos, major_configs) + + device_capacity = get_device_capacity(A.device) + assert device_capacity[0] in [9, 10], "Only SM90 and SM100 are supported" + GemmCls = GemmGatedSm100 if device_capacity[0] > 9 else GemmGatedSm90 + + acc_dtype = cutlass.Float32 + tile_shape_mn = (tile_M, tile_N) + cluster_shape_mnk = (cluster_M, cluster_N, 1) + if not GemmCls.is_valid_dtypes( + tensor_infos["A"].dtype, + tensor_infos["B"].dtype, + acc_dtype, + tensor_infos["D"].dtype, + tensor_infos["A"].major, + tensor_infos["B"].major, + ): + raise TypeError("Skipping due to unsupported combination 
of types and majors") + + max_active_clusters = get_max_active_clusters(cluster_M * cluster_N) if persistent else 0 + GemmWrapperBase.create_cute_tensors(tensor_infos, major_configs) + act_fn = gate_fn_map[activation] + epi_args = GemmCls.EpilogueArguments( + tensor_infos["PostAct"].cute_tensor, + act_fn, + mRowVecBroadcast=( + from_dlpack(rowvec_bias.detach(), assumed_align=4).mark_layout_dynamic(leading_dim=1) + if rowvec_bias is not None + else None + ), + mColVecBroadcast=( + from_dlpack(colvec_bias.detach(), assumed_align=4).mark_layout_dynamic( + leading_dim=1 if cu_seqlens_m is None else 0 + ) + if colvec_bias is not None + else None + ), + ) + scheduler_args = GemmWrapperBase.create_scheduler_args( + max_active_clusters, + tile_count_semaphore, + max_swizzle_size=max_swizzle_size, + ) + + # Create varlen arguments if needed (assumes persistent=True when varlen_m) + varlen_args = GemmWrapperBase.create_varlen_args( + cu_seqlens_m, + None, # cu_seqlens_k + A_idx, + max_active_clusters, + cluster_shape_mnk, + tensor_infos, + GemmCls.num_epi_tensormaps, + pingpong, + ) + + current_stream = cutlass_torch.current_stream() + compile_key = GemmWrapperBase.get_compile_key( + tensor_infos, + activation, + tile_shape_mn, + cluster_shape_mnk, + pingpong, + persistent, + tile_count_semaphore is not None, + device_capacity, + max_swizzle_size, + rowvec_bias.dtype if rowvec_bias is not None else None, + colvec_bias.dtype if colvec_bias is not None else None, + cu_seqlens_m is not None, + A_idx is not None, + key_tensor_names=("A", "B", "D", "PostAct", "C"), + ) + cache = gemm_gated.compile_cache + if compile_key not in cache: + if device_capacity[0] == 9: + GemmCls = partial(GemmCls, pingpong=pingpong, is_persistent=persistent) + gemm_obj = GemmCls( + acc_dtype, + tensor_infos["A"].dtype, + tile_shape_mn, + cluster_shape_mnk, + gather_A=gather_A, + ) + cache[compile_key] = cute.compile( + gemm_obj, + tensor_infos["A"].cute_tensor, + tensor_infos["B"].cute_tensor, + tensor_infos["D"].cute_tensor, + tensor_infos["C"].cute_tensor, + epi_args, + scheduler_args, + varlen_args, + current_stream, + ) + cache[compile_key]( + tensor_infos["A"].cute_tensor, + tensor_infos["B"].cute_tensor, + tensor_infos["D"].cute_tensor, + tensor_infos["C"].cute_tensor, + epi_args, + scheduler_args, + varlen_args, + current_stream, + ) + + +gemm_gated.compile_cache = {} diff --git a/sonic-moe/torch-ext/sonicmoe/quack_utils/gemm_interface.py b/sonic-moe/torch-ext/sonicmoe/quack_utils/gemm_interface.py new file mode 100644 index 00000000..8a1d7b8f --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/quack_utils/gemm_interface.py @@ -0,0 +1,385 @@ +# ******************************************************************************** +# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao +# ******************************************************************************** + +from functools import partial +from typing import Literal, Optional, Tuple + +import torch +from ..quack.autotuner import AutotuneConfig, autotune +from ..quack.cute_dsl_utils import get_device_capacity +from ..quack.gemm_config import GemmConfig, get_all_configs +from ..quack._ops_compat import add_quack_op_namespace_prefix +from ..quack.gemm_interface import default_config, prune_invalid_gemm_configs +from torch import Tensor + +from .gemm_dgated import gemm_dgated as gemm_dgated_sm90_sm100 +from .gemm_gated import gemm_gated as gemm_gated_sm90_sm100 + + +class _LazyDeviceCapacity: + """Defer torch.cuda.get_device_capability until 
first access so the + module can be imported in environments without a GPU (e.g. nix build).""" + _value = None + def __getitem__(self, idx): + if self._value is None: + if not torch.cuda.is_available(): + self._value = (9, 0) + else: + cap = get_device_capacity(torch.device("cuda")) + self._value = cap if cap[0] in (9, 10) else (9, 0) + return self._value[idx] + + +default_device_capacity = _LazyDeviceCapacity() + + +@autotune( + configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0], "gated")], + key=["activation", "dynamic_scheduler"], + prune_configs_by={"early_config_prune": prune_invalid_gemm_configs}, +) +def gemm_gated_tuned( + # (M, K) or or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m + A: Tensor, + B: Tensor, # (K, N) or (L, K, N) + # (M, N) or (L, M, N) or (total_M, N) if varlen_m - None if not storing preact + preact_out: Optional[Tensor], + postact_out: Tensor, # (M, N//2) or (L, M, N//2) or (total_M, N//2) if varlen_m + C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + bias: Optional[Tensor] = None, # (N,) or (L, N) + activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu", + cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32 + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + dynamic_scheduler: bool = False, + config: Optional[GemmConfig] = None, +) -> None: + if config is None: + config = default_config(A.device) + varlen_m = cu_seqlens_m is not None + if varlen_m: + assert not config.swap_ab, "Variable-length sequences not supported with swap_ab" + if A.ndim == 2 and not varlen_m: + A = A.unsqueeze(0) # (1, M, K) + B = B.mT # (N, K) or (L, N, K) + if B.ndim == 2: + B = B.unsqueeze(0) # (1, N, K) + if C is not None and C.ndim == 2 and not varlen_m: + C = C.unsqueeze(0) # (1, M, N) + if preact_out is not None and preact_out.ndim == 2 and not varlen_m: + D = preact_out.unsqueeze(0) + else: + D = preact_out + if postact_out.ndim == 2 and not varlen_m: + PostAct = postact_out.unsqueeze(0) + else: + PostAct = postact_out + if bias is not None and bias.ndim == 1: + bias = bias.unsqueeze(0) # (L, N) + tile_count_semaphore = torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None + gemm_gated_sm90_sm100( + A if not config.swap_ab else B, + B if not config.swap_ab else A, + (D if not config.swap_ab else D.mT) if D is not None else None, + (C if not config.swap_ab else C.mT) if C is not None else None, + PostAct if not config.swap_ab else PostAct.mT, + tile_count_semaphore, + activation, + config.tile_m, + config.tile_n, + config.cluster_m, + config.cluster_n, + config.pingpong, + persistent=True, + max_swizzle_size=config.max_swizzle_size, + rowvec_bias=bias if not config.swap_ab else None, + colvec_bias=bias if config.swap_ab else None, + cu_seqlens_m=cu_seqlens_m, + A_idx=A_idx, + ) + + +def prune_invalid_gemm_dgated_configs(configs, named_args: dict, **kwargs): + kwargs = named_args | kwargs + # if there's colvec_scale or colvec_reduce, don't swap_AB + if kwargs.get("colvec_scale", None) is not None or kwargs.get("colvec_reduce", False): + configs = [conf for conf in configs if not conf.kwargs["config"].swap_ab] + return prune_invalid_gemm_configs(configs, named_args, **kwargs) + + +@autotune( + configs=[AutotuneConfig(config=c) for c in get_all_configs(default_device_capacity[0], "dgated")], + key=["activation", "colvec_reduce", "dynamic_scheduler"], + prune_configs_by={"early_config_prune": 
prune_invalid_gemm_dgated_configs}, +) +def gemm_dgated_tuned( + # (M, K) or or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m + A: Tensor, + B: Tensor, # (K, N) or (L, K, N) + PreAct: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m + dx_out: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m + postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m + activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu", + # whether to do colvec reduction, returning (M,) or (L, M) or (total_M) if varlen_m + colvec_reduce: bool = False, + cu_seqlens_m: Optional[Tensor] = None, # (L+1), int32 + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + dynamic_scheduler: bool = True, + config: Optional[GemmConfig] = None, +) -> Optional[Tensor]: + if config is None: + config = default_config(A.device) + varlen_m = cu_seqlens_m is not None + if varlen_m: + assert not config.swap_ab, "Variable-length sequences not supported with swap_ab" + og_ndim_2 = A.ndim == 2 and not varlen_m + if A.ndim == 2 and not varlen_m: + A = A.unsqueeze(0) # (1, M, K) + B = B.mT # (N, K) or (L, N, K) + if B.ndim == 2: + B = B.unsqueeze(0) # (1, N, K) + if PreAct.ndim == 2 and not varlen_m: + PreAct = PreAct.unsqueeze(0) # (1, M, 2*N) + if dx_out.ndim == 2 and not varlen_m: + D = dx_out.unsqueeze(0) + else: + D = dx_out + if postact_out.ndim == 2 and not varlen_m: + PostAct = postact_out.unsqueeze(0) + else: + PostAct = postact_out + if colvec_scale is not None and colvec_scale.ndim == 1 and not varlen_m: + colvec_scale = colvec_scale.unsqueeze(0) # (L, N) + if colvec_scale is not None: + assert not config.swap_ab, "colvec_scale not supported with swap_ab" + if colvec_reduce: + tile_n = config.tile_n + shape_n = (B.shape[-2] + tile_n - 1) // tile_n + if varlen_m: + total_m = A_idx.shape[0] if A_idx is not None else A.shape[0] + colvec_shape = (total_m, shape_n) + else: + colvec_shape = (A.shape[0], A.shape[-2], shape_n) + colvec_reduce_partial = torch.empty(colvec_shape, dtype=torch.float32, device=A.device) + else: + colvec_reduce_partial = None + tile_count_semaphore = torch.zeros(1, dtype=torch.int32, device=A.device) if dynamic_scheduler else None + gemm_dgated_sm90_sm100( + A if not config.swap_ab else B, + B if not config.swap_ab else A, + D if not config.swap_ab else D.mT, + PreAct if not config.swap_ab else PreAct.mT, + PostAct if not config.swap_ab else PostAct.mT, + tile_count_semaphore, + activation, + config.tile_m, + config.tile_n, + config.cluster_m, + config.cluster_n, + config.pingpong, + persistent=True, + max_swizzle_size=config.max_swizzle_size, + colvec_scale=colvec_scale, + colvec_reduce=colvec_reduce_partial, + cu_seqlens_m=cu_seqlens_m, + A_idx=A_idx, + ) + if colvec_reduce: + colvec_reduce_final = colvec_reduce_partial.sum(dim=-1) + if og_ndim_2: + colvec_reduce_final = colvec_reduce_final.squeeze(0) + else: + colvec_reduce_final = None + return colvec_reduce_final + + +def gemm_gated( + A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m + B: Tensor, # (K, N) or (L, K, N) + C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + bias: Optional[Tensor] = None, # (N,) or (L, N) + activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu", + preact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) 
if varlen_m + postact_out: Optional[Tensor] = None, # (M, N//2) or (L, M, N//2) or (total_M, N//2) if varlen_m + out_dtype: Optional[torch.dtype] = None, + postact_dtype: Optional[torch.dtype] = None, + cu_seqlens_m: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + store_preact: bool = True, + dynamic_scheduler: bool = False, + tuned: bool = True, +) -> Tuple[Optional[Tensor], Tensor]: + """GEMM with gated activation and optional output tensors.""" + out_dtype = A.dtype if out_dtype is None else out_dtype + postact_dtype = A.dtype if postact_dtype is None else postact_dtype + varlen_m = cu_seqlens_m is not None + # Determine output shape based on gather_A + if varlen_m: + total_m = A_idx.shape[0] if A_idx is not None else A.shape[0] + out_shape = (total_m, B.shape[-1]) + elif A.ndim == 2: + out_shape = (A.shape[0], B.shape[-1]) + else: + out_shape = (A.shape[0], A.shape[-2], B.shape[-1]) + postact_shape = (*out_shape[:-1], out_shape[-1] // 2) + if preact_out is None and store_preact: + preact_out = torch.empty(out_shape, dtype=out_dtype, device=A.device) + if postact_out is None: + postact_out = torch.empty(postact_shape, dtype=postact_dtype, device=A.device) + gemm_gated_out( + A, + B, + preact_out, + postact_out, + C, + bias, + activation, + cu_seqlens_m, + A_idx, + dynamic_scheduler, + tuned, + ) + return preact_out, postact_out + + +@torch.library.custom_op( + add_quack_op_namespace_prefix("gemm_gated_out"), + mutates_args=("preact_out", "postact_out"), + device_types="cuda", + schema="(Tensor A, Tensor B, Tensor(a2!)? preact_out, Tensor(a3!) postact_out, Tensor? C=None, Tensor? bias=None, str activation='swiglu', Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=False, bool tuned=True) -> ()", +) +def gemm_gated_out( + A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m + B: Tensor, # (K, N) or (L, K, N) + preact_out: Optional[Tensor], # (M, N) or (L, M, N) or (total_M, N) if varlen_m + postact_out: Tensor, # (M, N//2) or (L, M, N//2) or (total_M, N//2) if varlen_m + C: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + bias: Optional[Tensor] = None, # (N,) or (L, N) + activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu", + cu_seqlens_m: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + dynamic_scheduler: bool = False, + tuned: bool = True, +) -> None: + """GEMM with gated activation and pre-allocated output tensors.""" + fn = gemm_gated_tuned if tuned else partial(gemm_gated_tuned.fn, config=None) + fn(A, B, preact_out, postact_out, C, bias, activation, cu_seqlens_m, A_idx, dynamic_scheduler) + + +def gemm_dgated( + A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m + B: Tensor, # (K, N) or (L, K, N) + PreAct: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m + colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m + activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu", + dx_out: Optional[Tensor] = None, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m + postact_out: Optional[Tensor] = None, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + out_dtype: Optional[torch.dtype] = None, + postact_dtype: Optional[torch.dtype] = None, + colvec_reduce: bool = False, + cu_seqlens_m: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # 
(total_M,) if gather_A with varlen_m + dynamic_scheduler: bool = True, + tuned: bool = True, +) -> Tuple[Tensor, Tensor]: + """GEMM with gated activation gradient and optional output tensors.""" + out_dtype = A.dtype if out_dtype is None else out_dtype + postact_dtype = PreAct.dtype if postact_dtype is None else postact_dtype + varlen_m = cu_seqlens_m is not None + # Determine output shape based on gather_A + if varlen_m: + total_m = A_idx.shape[0] if A_idx is not None else A.shape[0] + out_shape = (total_m, B.shape[-1] * 2) + elif A.ndim == 2: + out_shape = (A.shape[0], B.shape[-1] * 2) + else: + out_shape = (A.shape[0], A.shape[-2], B.shape[-1] * 2) + postact_shape = (*out_shape[:-1], out_shape[-1] // 2) + if dx_out is None: + dx_out = torch.empty(out_shape, dtype=out_dtype, device=A.device) + if postact_out is None: + postact_out = torch.empty(postact_shape, dtype=postact_dtype, device=A.device) + colvec_reduce_final = gemm_dgated_out( + A, + B, + PreAct, + dx_out, + postact_out, + colvec_scale, + activation, + colvec_reduce, + cu_seqlens_m, + A_idx, + dynamic_scheduler, + tuned, + ) + if not colvec_reduce: + return dx_out, postact_out + else: + return dx_out, postact_out, colvec_reduce_final + + +@torch.library.custom_op( + add_quack_op_namespace_prefix("gemm_dgated_out"), + mutates_args=("dx_out", "postact_out"), + device_types="cuda", + schema="(Tensor A, Tensor B, Tensor PreAct, Tensor(a3!) dx_out, Tensor(a4!) postact_out, Tensor? colvec_scale=None, str activation='swiglu', bool colvec_reduce=False, Tensor? cu_seqlens_m=None, Tensor? A_idx=None, bool dynamic_scheduler=True, bool tuned=True) -> Tensor?", +) +def gemm_dgated_out( + A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m + B: Tensor, # (K, N) or (L, K, N) + PreAct: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m + dx_out: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m + postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m + activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu", + colvec_reduce: bool = False, + cu_seqlens_m: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + dynamic_scheduler: bool = True, + tuned: bool = True, +) -> Optional[Tensor]: + """GEMM with gated activation gradient and pre-allocated output tensors.""" + fn = gemm_dgated_tuned if tuned else partial(gemm_dgated_tuned.fn, config=None) + return fn( + A, + B, + PreAct, + dx_out, + postact_out, + colvec_scale, + activation, + colvec_reduce, + cu_seqlens_m, + A_idx, + dynamic_scheduler, + ) + + +@torch.library.register_fake(add_quack_op_namespace_prefix("gemm_dgated_out")) +def gemm_dgated_out_fake( + A: Tensor, # (M, K) or (L, M, K) or (total_M, K) if varlen_m or (whatever, K) if gather_A with varlen_m + B: Tensor, # (K, N) or (L, K, N) + PreAct: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m + dx_out: Tensor, # (M, 2*N) or (L, M, 2*N) or (total_M, 2*N) if varlen_m + postact_out: Tensor, # (M, N) or (L, M, N) or (total_M, N) if varlen_m + colvec_scale: Optional[Tensor] = None, # (M,) or (L, M) or (total_M,) if varlen_m + activation: Literal["swiglu", "swiglu_oai", "reglu", "geglu", "glu"] = "swiglu", + colvec_reduce: bool = False, + cu_seqlens_m: Optional[Tensor] = None, + A_idx: Optional[Tensor] = None, # (total_M,) if gather_A with varlen_m + dynamic_scheduler: bool = True, + 
tuned: bool = True, +) -> Optional[Tensor]: + if not colvec_reduce: + return None + else: + if cu_seqlens_m is not None: + total_m = A_idx.shape[0] if A_idx is not None else A.shape[0] + out_shape = (total_m,) + elif A.ndim == 2: + out_shape = (A.shape[0],) + else: + out_shape = (A.shape[0], A.shape[-2]) + return torch.empty(out_shape, dtype=torch.float32, device=A.device) diff --git a/sonic-moe/torch-ext/sonicmoe/utils.py b/sonic-moe/torch-ext/sonicmoe/utils.py new file mode 100644 index 00000000..c5e3060d --- /dev/null +++ b/sonic-moe/torch-ext/sonicmoe/utils.py @@ -0,0 +1,123 @@ +# ******************************************************************************** +# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao +# ******************************************************************************** + +from typing import Any, Callable + +import cutlass +import cutlass.cute as cute +import torch +from cutlass.cute.runtime import from_dlpack +from cutlass.cutlass_dsl import dsl_user_op +from torch.utils._pytree import tree_map + + +def make_contiguous(x: Any) -> Any: + return x.contiguous() if isinstance(x, torch.Tensor) else x + + +def ensure_contiguous(func: Callable) -> Callable: + def inner(*args, **kwargs): + args = tree_map(make_contiguous, args) + kwargs = tree_map(make_contiguous, kwargs) + return func(*args, **kwargs) + + return inner + + +def ceil_divide(x: int, y: int) -> int: + return (x + y - 1) // y + + +def check_power_of_2(n: int) -> bool: + return n & (n - 1) == 0 and n != 0 + + +def get_powers_of_2(start: int, end: int) -> list[int]: + assert check_power_of_2(start), "start is not a power of 2" + assert check_power_of_2(end), "end is not a power of 2" + + output = [] + n = start + while n <= end: + output.append(n) + n = n << 1 + + return output + + +@dsl_user_op +def domain_offset_i64(coord: cute.Coord, tensor: cute.Tensor, *, loc=None, ip=None) -> cute.Tensor: + flat_coord_i64 = tuple(cutlass.Int64(c) for c in cute.flatten(coord)) + flat_stride = cute.flatten_to_tuple(tensor.stride) + assert len(flat_coord_i64) == len(flat_stride), "Coordinate and stride must have the same length" + offset = sum(c * s for c, s in zip(flat_coord_i64, flat_stride)) + assert isinstance(tensor.iterator, cute.Pointer) + # HACK: we assume that applying the offset does not change the pointer alignment + new_ptr = cute.make_ptr( + tensor.element_type, + tensor.iterator.toint() + offset * tensor.element_type.width // 8, + tensor.memspace, + assumed_align=tensor.iterator.max_alignment, + ) + return cute.make_tensor(new_ptr, tensor.layout) + + +def divide_if_divisible(dividend: int, divisor: int, msg: str = "") -> int: + assert dividend % divisor == 0, msg + return dividend // divisor + + +def get_next_power_of_2(x: int) -> int: + x -= 1 + x |= x >> 1 + x |= x >> 2 + x |= x >> 4 + x |= x >> 8 + x |= x >> 16 + x |= x >> 32 + x += 1 + return x + + +class _TensorWithStream: + """Wrapper to pass stream parameter to __dlpack__() for CUDA graph compatibility. + + This wrapper allows us to pass a stream parameter to the tensor's __dlpack__() method + when cutlass's from_dlpack() calls it, preventing cross-stream synchronization during + CUDA graph capture. 
+ """ + + def __init__(self, tensor: torch.Tensor, stream: int): + self._tensor = tensor + # Convert CUDA stream pointer to PyTorch's __dlpack__ convention: + # - stream=0 (null/default stream) -> use -1 to disable synchronization + # - stream=non-zero -> use the raw pointer value + # This prevents "unsupported stream on CUDA: 0" error + self._stream = -1 if stream == 0 else stream + + def __dlpack__(self, stream=None): # noqa: ARG002 + # Use the wrapped stream to prevent cross-stream synchronization + # The stream parameter is required by the DLPack protocol but ignored here + return self._tensor.__dlpack__(stream=self._stream) + + def __dlpack_device__(self): + return self._tensor.__dlpack_device__() + + +def convert_torch_tensor_to_cute_tensor( + x: torch.Tensor, + stride_order, + leading_dim: int, + alignment: int, + divisibility: int, + stream: int | None = None, +): + # Wrap tensor with stream if provided to prevent cross-stream synchronization during CUDA graph capture + tensor_input = _TensorWithStream(x, stream) if stream is not None else x + + return ( + from_dlpack(tensor_input, assumed_align=alignment) + .mark_layout_dynamic(leading_dim=leading_dim) + .mark_compact_shape_dynamic(mode=leading_dim, stride_order=stride_order, divisibility=divisibility) + )