-
Notifications
You must be signed in to change notification settings - Fork 26
sonic-moe: Add sonic-moe kernels #531
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
adarshxs
wants to merge
12
commits into
huggingface:main
Choose a base branch
from
adarshxs:kernel/sonic-moe
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 10 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
ef4c0a7
add sonic-moe
adarshxs 3fb97d4
fix
adarshxs 79a08f3
fix
adarshxs c81a44d
fix
adarshxs 7a0b1c6
Update check_kernel_freshness.py
adarshxs e421be4
upd nix.lock
357d14c
Update build.toml
adarshxs 336b5d5
Update __init__.py
adarshxs 52b9869
Update build.toml
adarshxs 38e6c2d
sonic-moe: handle non-Hopper GPUs in test
adarshxs 04756b6
fixes
adarshxs 7f62799
fixes
adarshxs File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,57 @@ | ||
| --- | ||
| tags: | ||
| - kernels | ||
| - moe | ||
| - cuda | ||
| --- | ||
|
|
||
| # SonicMoE | ||
|
|
||
| Accelerating Mixture-of-Experts with IO and Tile-aware Optimizations. | ||
|
|
||
| **SonicMoE** is a blazing-fast MoE implementation optimized for NVIDIA Hopper and Blackwell GPUs. | ||
| It leverages CuTe-DSL and Triton to deliver state-of-the-art performance through IO-aware optimizations. | ||
|
|
||
| - Paper: [arXiv:2512.14080](https://arxiv.org/abs/2512.14080) | ||
| - Source: [Dao-AILab/sonic-moe](https://github.com/Dao-AILab/sonic-moe) | ||
|
|
||
| ## Requirements | ||
|
|
||
| - NVIDIA Hopper GPUs (H100, H200) or Blackwell GPUs (GB200, B200) | ||
| - PyTorch >= 2.7 | ||
| - CUDA 12.9+ | ||
| - Python 3.12+ | ||
|
|
||
| ## Usage | ||
|
|
||
| ```python | ||
| import torch | ||
| from kernels import get_kernel | ||
|
|
||
| sonicmoe = get_kernel("kernels-community/sonic-moe") | ||
|
|
||
| from sonicmoe import MoE, KernelBackendMoE | ||
| from sonicmoe.enums import ActivationType | ||
|
|
||
| moe = MoE( | ||
| num_experts=128, | ||
| num_experts_per_tok=8, | ||
| hidden_size=4096, | ||
| intermediate_size=1536, | ||
| activation_function=ActivationType.SWIGLU, | ||
| add_bias=False, | ||
| std=0.02, | ||
| ).to(device="cuda", dtype=torch.bfloat16) | ||
|
|
||
| x = torch.randn(32768, 4096, device="cuda", dtype=torch.bfloat16) | ||
| output, aux_loss = moe(x, kernel_backend_moe=KernelBackendMoE.sonicmoe) | ||
| ``` | ||
|
|
||
| ## Vendored Dependencies | ||
|
|
||
| This kernel vendors [QuACK](https://github.com/Dao-AILab/quack) (quack-kernels) for CuTe-DSL | ||
| GEMM infrastructure. The vendored copy is located at `torch-ext/sonicmoe/_vendor/quack/`. | ||
|
|
||
| ## License | ||
|
|
||
| Apache-2.0 (SonicMoE and QuACK are both Apache-2.0 licensed) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,14 @@ | ||
| [general] | ||
| name = "sonicmoe" | ||
| license = "Apache-2.0" | ||
| backends = ["cuda"] | ||
| version = 1 | ||
|
|
||
| [general.hub] | ||
| repo-id = "kernels-community/sonic-moe" | ||
|
|
||
| [general.cuda] | ||
| minver = "12.8" | ||
| python-depends = ["nvidia-cutlass-dsl"] | ||
|
|
||
| [kernel] |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,17 @@ | ||
| { | ||
| description = "Flake for sonic-moe kernels"; | ||
|
|
||
| inputs = { | ||
| kernel-builder.url = "github:huggingface/kernels"; | ||
| }; | ||
|
|
||
| outputs = | ||
| { | ||
| self, | ||
| kernel-builder, | ||
| }: | ||
| kernel-builder.lib.genKernelFlakeOutputs { | ||
| inherit self; | ||
| path = ./.; | ||
| }; | ||
| } |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,120 @@ | ||
| # ******************************************************************************** | ||
| # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao | ||
| # ******************************************************************************** | ||
|
|
||
| import pytest | ||
| import random | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| from torch.testing import assert_close | ||
|
|
||
| if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] < 9: | ||
| pytest.skip("SonicMoE requires Hopper (SM90) or newer GPU", allow_module_level=True) | ||
|
|
||
| try: | ||
| from sonicmoe import KernelBackendMoE, MoE, enable_quack_gemm | ||
| from sonicmoe.enums import ActivationType | ||
| except ImportError as e: | ||
| pytest.skip(f"sonicmoe dependencies not available: {e}", allow_module_level=True) | ||
|
|
||
| _SEED = 42 | ||
|
|
||
|
|
||
| def set_seed(seed: int) -> None: | ||
| random.seed(seed) | ||
| np.random.seed(seed) | ||
| torch.manual_seed(seed) | ||
| torch.cuda.manual_seed_all(seed) | ||
|
|
||
|
|
||
| PROBLEM_SHAPES = [ | ||
| (8192, 768, 256, 128, 8), | ||
| (8192, 768, 512, 64, 4), | ||
| (8192, 4096, 512, 128, 8), | ||
| (8192, 4096, 1024, 64, 4), | ||
| ] | ||
|
|
||
|
|
||
| @pytest.mark.parametrize("problem_shape", PROBLEM_SHAPES) | ||
| @pytest.mark.parametrize("add_bias", [False, True]) | ||
| def test_moe_forward_backward(problem_shape, add_bias): | ||
| device = torch.device("cuda") | ||
| dtype = torch.bfloat16 | ||
|
|
||
| set_seed(_SEED) | ||
|
|
||
| T, H, I, E, K = problem_shape | ||
| with torch.device(device): | ||
| moe = MoE( | ||
| num_experts=E, | ||
| num_experts_per_tok=K, | ||
| hidden_size=H, | ||
| intermediate_size=I, | ||
| activation_function=ActivationType.SWIGLU, | ||
| add_bias=add_bias, | ||
| std=0.02, | ||
| ).to(dtype=dtype) | ||
|
|
||
| if add_bias: | ||
| torch.nn.init.normal_(moe.c_fc.bias, 0, 0.01) | ||
| torch.nn.init.normal_(moe.c_proj.bias, 0, 0.01) | ||
|
|
||
| torch.cuda.empty_cache() | ||
| x_torch = 0.02 * torch.randn(T, H, device=device, dtype=dtype, requires_grad=True) | ||
| x_kernel = x_torch.clone().detach().requires_grad_() | ||
|
|
||
| with torch.autocast(device.type, torch.float32): | ||
| y_kernel = moe(x_kernel, kernel_backend_moe=KernelBackendMoE.sonicmoe)[0] | ||
| y_torch = moe(x_torch, kernel_backend_moe=KernelBackendMoE.torch)[0] | ||
|
|
||
| assert_close(y_kernel.float(), y_torch.float(), atol=1.4e-2, rtol=2e-2) | ||
|
|
||
| dy = 0.02 * torch.randn(T, H, device=device, dtype=dtype) | ||
| W = list(moe.parameters()) | ||
|
|
||
| with torch.autocast(device.type, torch.float32): | ||
| kernel_grads = torch.autograd.grad(y_kernel, [x_kernel] + W, grad_outputs=dy, retain_graph=True) | ||
| torch_grads = torch.autograd.grad(y_torch, [x_torch] + W, grad_outputs=dy, retain_graph=True) | ||
|
|
||
| for tg, kg in zip(torch_grads, kernel_grads): | ||
| assert_close(kg.float(), tg.float(), atol=2e-2, rtol=2e-2) | ||
|
|
||
| torch.cuda.empty_cache() | ||
|
|
||
|
|
||
| @pytest.mark.parametrize( | ||
| "problem_shape", | ||
| [(8192, 4096, 512, 128, 8)], | ||
| ) | ||
| def test_moe_quack_gemm(problem_shape): | ||
| device = torch.device("cuda") | ||
| dtype = torch.bfloat16 | ||
|
|
||
| set_seed(_SEED) | ||
|
|
||
| T, H, I, E, K = problem_shape | ||
| with torch.device(device): | ||
| moe = MoE( | ||
| num_experts=E, | ||
| num_experts_per_tok=K, | ||
| hidden_size=H, | ||
| intermediate_size=I, | ||
| activation_function=ActivationType.SWIGLU, | ||
| add_bias=False, | ||
| std=0.02, | ||
| ).to(dtype=dtype) | ||
|
|
||
| torch.cuda.empty_cache() | ||
| x_torch = 0.02 * torch.randn(T, H, device=device, dtype=dtype, requires_grad=True) | ||
| x_kernel = x_torch.clone().detach().requires_grad_() | ||
|
|
||
| with torch.autocast(device.type, torch.float32): | ||
| with enable_quack_gemm(True): | ||
| y_kernel = moe(x_kernel, kernel_backend_moe=KernelBackendMoE.sonicmoe)[0] | ||
|
|
||
| y_torch = moe(x_torch, kernel_backend_moe=KernelBackendMoE.torch)[0] | ||
|
|
||
| assert_close(y_kernel.float(), y_torch.float(), atol=1.4e-2, rtol=2e-2) | ||
|
|
||
| torch.cuda.empty_cache() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,37 @@ | ||
| # ******************************************************************************** | ||
| # Copyright (c) 2025, Wentao Guo, Mayank Mishra, Xinle Cheng, Ion Stoica, Tri Dao | ||
| # ******************************************************************************** | ||
|
|
||
| import os as _os | ||
| import sys as _sys | ||
|
|
||
| # Inject vendored quack-kernels into the module search path so that | ||
| # `import quack` resolves to the bundled copy when no system install exists. | ||
| _vendor_dir = _os.path.join(_os.path.dirname(__file__), "_vendor") | ||
| if _vendor_dir not in _sys.path: | ||
| _sys.path.insert(0, _vendor_dir) | ||
|
|
||
| __version__ = "0.1.1" | ||
|
|
||
| # Lazy imports: defer heavy dependencies (cutlass, cuda, triton) so that | ||
| # `import sonicmoe` succeeds in environments without GPU libraries | ||
| # (e.g. the nix build sandbox get-kernel-check). | ||
adarshxs marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| from .enums import KernelBackendMoE | ||
|
|
||
| _LAZY_IMPORTS = { | ||
| "MoE": ".moe", | ||
| "enable_quack_gemm": ".functional", | ||
| "moe_general_routing_inputs": ".functional", | ||
| "moe_TC_softmax_topk_layer": ".functional", | ||
| } | ||
|
|
||
|
|
||
| def __getattr__(name): | ||
| if name in _LAZY_IMPORTS: | ||
| module_path = _LAZY_IMPORTS[name] | ||
| import importlib | ||
| mod = importlib.import_module(module_path, __name__) | ||
| val = getattr(mod, name) | ||
| globals()[name] = val | ||
adarshxs marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| return val | ||
| raise AttributeError(f"module {__name__!r} has no attribute {name!r}") | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,8 @@ | ||
| __version__ = "0.2.5" | ||
|
|
||
| import os | ||
|
|
||
| if os.environ.get("CUTE_DSL_PTXAS_PATH", None) is not None: | ||
| import quack.cute_dsl_ptxas # noqa: F401 | ||
|
|
||
| quack.cute_dsl_ptxas.patch() |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.