Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions scripts/check_kernel_freshness.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
"rwkv": "https://github.com/BlinkDL/RWKV-LM",
"scattermoe": "https://github.com/shawntan/scattermoe",
"sgl-flash-attn3": "https://github.com/sgl-project/sgl-flash-attn",
"sonic-moe": "https://github.com/Dao-AILab/sonic-moe",
"tinygrad-rms": "https://github.com/tinygrad/tinygrad",
"trimul-gpumode": "https://github.com/davidberard98/gpumode-trimul",
"triton-kernels": "https://github.com/triton-lang/triton.git",
Expand Down
57 changes: 57 additions & 0 deletions sonic-moe/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
---
tags:
- kernels
- moe
- cuda
---

# SonicMoE

Accelerating Mixture-of-Experts with IO and Tile-aware Optimizations.

**SonicMoE** is a blazing-fast MoE implementation optimized for NVIDIA Hopper and Blackwell GPUs.
It leverages CuTe-DSL and Triton to deliver state-of-the-art performance through IO-aware optimizations.

- Paper: [arXiv:2512.14080](https://arxiv.org/abs/2512.14080)
- Source: [Dao-AILab/sonic-moe](https://github.com/Dao-AILab/sonic-moe)

## Requirements

- NVIDIA Hopper GPUs (H100, H200) or Blackwell GPUs (GB200, B200)
- PyTorch >= 2.7
- CUDA 12.9+
- Python 3.12+

## Usage

```python
import torch
from kernels import get_kernel

sonicmoe = get_kernel("kernels-community/sonic-moe")

from sonicmoe import MoE, KernelBackendMoE
from sonicmoe.enums import ActivationType

moe = MoE(
num_experts=128,
num_experts_per_tok=8,
hidden_size=4096,
intermediate_size=1536,
activation_function=ActivationType.SWIGLU,
add_bias=False,
std=0.02,
).to(device="cuda", dtype=torch.bfloat16)

x = torch.randn(32768, 4096, device="cuda", dtype=torch.bfloat16)
output, aux_loss = moe(x, kernel_backend_moe=KernelBackendMoE.sonicmoe)
```

## Vendored Dependencies

This kernel vendors [QuACK](https://github.com/Dao-AILab/quack) (quack-kernels) for CuTe-DSL
GEMM infrastructure. The vendored copy is located at `torch-ext/sonicmoe/quack/`.

## License

Apache-2.0 (SonicMoE and QuACK are both Apache-2.0 licensed)
14 changes: 14 additions & 0 deletions sonic-moe/build.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Build metadata consumed by Hugging Face's kernel-builder.
[general]
name = "sonicmoe"
license = "Apache-2.0"
backends = ["cuda"]
version = 1

# Hub repository the built kernel is published under.
[general.hub]
repo-id = "kernels-community/sonic-moe"

[general.cuda]
# NOTE(review): the README lists "CUDA 12.9+" — confirm whether minver
# should be "12.9" or the README should say 12.8.
minver = "12.8"
# Python-level build/runtime dependency for the CuTe-DSL kernels.
python-depends = ["nvidia-cutlass-dsl"]

[kernel]
117 changes: 117 additions & 0 deletions sonic-moe/flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions sonic-moe/flake.nix
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
{
  # Nix flake wiring the sonic-moe kernel into the kernel-builder tooling.
  description = "Flake for sonic-moe kernels";

  inputs = {
    # Supplies genKernelFlakeOutputs used below.
    kernel-builder.url = "github:huggingface/kernels";
  };

  outputs =
    {
      self,
      kernel-builder,
    }:
    # Generate the standard kernel flake outputs for the kernel rooted here.
    kernel-builder.lib.genKernelFlakeOutputs {
      inherit self;
      path = ./.;
    };
}
120 changes: 120 additions & 0 deletions sonic-moe/tests/test_moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
# ********************************************************************************
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
# ********************************************************************************

import pytest
import random

import numpy as np
import torch
from torch.testing import assert_close

# Skip the whole module on pre-Hopper GPUs (compute capability < 9, i.e. < SM90).
# NOTE(review): hosts with no GPU at all are not skipped by this check —
# presumably the import below fails there instead; confirm.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9:
    pytest.skip("SonicMoE requires Hopper (SM90) or newer GPU", allow_module_level=True)

try:
    from sonicmoe import KernelBackendMoE, MoE, enable_quack_gemm
    from sonicmoe.enums import ActivationType
except ImportError as e:
    # Skip cleanly (instead of erroring at collection) when sonicmoe or its
    # dependencies are not installed in the test environment.
    pytest.skip(f"sonicmoe dependencies not available: {e}", allow_module_level=True)

# Fixed seed so module init and test tensors are reproducible across runs.
_SEED = 42


def set_seed(seed: int) -> None:
    """Seed every RNG source used by the tests (python, numpy, torch CPU and CUDA)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)


# Problem shapes as (T, H, I, E, K) — unpacked in the tests as
# T=tokens, H=hidden_size, I=intermediate_size, E=num_experts,
# K=num_experts_per_tok (see the MoE(...) construction below).
PROBLEM_SHAPES = [
    (8192, 768, 256, 128, 8),
    (8192, 768, 512, 64, 4),
    (8192, 4096, 512, 128, 8),
    (8192, 4096, 1024, 64, 4),
]


@pytest.mark.parametrize("problem_shape", PROBLEM_SHAPES)
@pytest.mark.parametrize("add_bias", [False, True])
def test_moe_forward_backward(problem_shape, add_bias):
    """Compare the sonicmoe kernel backend against the torch reference backend
    for both the forward output and the input/parameter gradients."""
    device = torch.device("cuda")
    dtype = torch.bfloat16

    set_seed(_SEED)

    tokens, hidden, inter, n_experts, top_k = problem_shape
    with torch.device(device):
        moe = MoE(
            num_experts=n_experts,
            num_experts_per_tok=top_k,
            hidden_size=hidden,
            intermediate_size=inter,
            activation_function=ActivationType.SWIGLU,
            add_bias=add_bias,
            std=0.02,
        ).to(dtype=dtype)

    if add_bias:
        # Biases get their own (small) init so the bias path is exercised.
        torch.nn.init.normal_(moe.c_fc.bias, 0, 0.01)
        torch.nn.init.normal_(moe.c_proj.bias, 0, 0.01)

    torch.cuda.empty_cache()
    # Two leaf copies of the same input: one per backend, each with its own grad.
    x_ref = 0.02 * torch.randn(tokens, hidden, device=device, dtype=dtype, requires_grad=True)
    x_fast = x_ref.clone().detach().requires_grad_()

    with torch.autocast(device.type, torch.float32):
        y_fast = moe(x_fast, kernel_backend_moe=KernelBackendMoE.sonicmoe)[0]
        y_ref = moe(x_ref, kernel_backend_moe=KernelBackendMoE.torch)[0]

    assert_close(y_fast.float(), y_ref.float(), atol=1.4e-2, rtol=2e-2)

    grad_out = 0.02 * torch.randn(tokens, hidden, device=device, dtype=dtype)
    params = list(moe.parameters())

    with torch.autocast(device.type, torch.float32):
        fast_grads = torch.autograd.grad(y_fast, [x_fast] + params, grad_outputs=grad_out, retain_graph=True)
        ref_grads = torch.autograd.grad(y_ref, [x_ref] + params, grad_outputs=grad_out, retain_graph=True)

    for ref_g, fast_g in zip(ref_grads, fast_grads):
        assert_close(fast_g.float(), ref_g.float(), atol=2e-2, rtol=2e-2)

    torch.cuda.empty_cache()


@pytest.mark.parametrize(
    "problem_shape",
    [(8192, 4096, 512, 128, 8)],
)
def test_moe_quack_gemm(problem_shape):
    """Forward pass with QuACK GEMM enabled must match the torch reference backend."""
    device = torch.device("cuda")
    dtype = torch.bfloat16

    set_seed(_SEED)

    tokens, hidden, inter, n_experts, top_k = problem_shape
    with torch.device(device):
        moe = MoE(
            num_experts=n_experts,
            num_experts_per_tok=top_k,
            hidden_size=hidden,
            intermediate_size=inter,
            activation_function=ActivationType.SWIGLU,
            add_bias=False,
            std=0.02,
        ).to(dtype=dtype)

    torch.cuda.empty_cache()
    # Same input for both backends; separate leaves so grads stay independent.
    x_ref = 0.02 * torch.randn(tokens, hidden, device=device, dtype=dtype, requires_grad=True)
    x_fast = x_ref.clone().detach().requires_grad_()

    with torch.autocast(device.type, torch.float32):
        # QuACK GEMM only wraps the sonicmoe forward; the reference runs outside it.
        with enable_quack_gemm(True):
            y_fast = moe(x_fast, kernel_backend_moe=KernelBackendMoE.sonicmoe)[0]

        y_ref = moe(x_ref, kernel_backend_moe=KernelBackendMoE.torch)[0]

    assert_close(y_fast.float(), y_ref.float(), atol=1.4e-2, rtol=2e-2)

    torch.cuda.empty_cache()
36 changes: 36 additions & 0 deletions sonic-moe/torch-ext/sonicmoe/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# ********************************************************************************
# Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao
# ********************************************************************************

from functools import lru_cache

__version__ = "0.1.1"

from .enums import KernelBackendMoE

# Lazily-exported attribute name -> relative module that defines it.
# These are resolved on first access via the module-level __getattr__,
# so the submodules are not imported when the package itself is imported.
_LAZY_IMPORTS = {
    "MoE": ".moe",
    "enable_quack_gemm": ".functional",
    "moe_general_routing_inputs": ".functional",
    "moe_TC_softmax_topk_layer": ".functional",
}

@lru_cache(maxsize=None)
def _load_attr(name: str):
    """Import the module registered for *name* and return the attribute (cached)."""
    from importlib import import_module

    relative_module = _LAZY_IMPORTS[name]
    module = import_module(relative_module, __name__)
    return getattr(module, name)

def __getattr__(name):
    """PEP 562 module hook: resolve lazily-exported attributes on first access."""
    if name not in _LAZY_IMPORTS:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return _load_attr(name)

# Public API: the eagerly-imported KernelBackendMoE plus every lazily-loaded
# name registered in _LAZY_IMPORTS.
__all__ = [
    "KernelBackendMoE",
    "MoE",
    "enable_quack_gemm",
    "moe_general_routing_inputs",
    "moe_TC_softmax_topk_layer",
]
10 changes: 10 additions & 0 deletions sonic-moe/torch-ext/sonicmoe/_ops_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
"""Compatibility helpers for op namespacing in source and built layouts."""

try:
from ._ops import add_op_namespace_prefix as _generated_add_op_namespace_prefix
except ImportError:
def _generated_add_op_namespace_prefix(name: str) -> str:
return name if "::" in name else f"sonicmoe::{name}"

def add_op_namespace_prefix(name: str) -> str:
return _generated_add_op_namespace_prefix(name)
Loading
Loading