
Fix Triton MoE GEMM shared memory exhaustion by reducing stage count #2723

Open

nidal567 wants to merge 2 commits into ROCm:main from nidal567:golden-tot-testing

Conversation

@nidal567
Contributor

  • Reduce num_stages in kernel configs
  • Lower LDS usage to avoid shared memory OOR
  • Fix triton.runtime.errors.OutOfResources errors in MoE GEMM kernels

Motivation

Several Triton kernels were failing with:

  • triton.runtime.errors.OutOfResources: shared memory
  • See here for my up-to-date GitHub issue showing the results of the Triton kernel test suite

This was observed across multiple test cases:

  • test_ff_a16w16_fused.py
  • test_fused_gemm_afp4wfp4_a16w16.py
  • test_moe_gemm_a4w4.py #current PR
  • test_moe_gemm_a8w4.py #current PR
  • test_moe_gemm_a8w8.py #current PR

The issue became prominent in testing after async copy was enabled, which increased LDS (shared memory) usage.

This PR addresses only the MoE GEMMs labeled #current PR; fixes for the other two kernels are in progress.

Technical Details

The root cause is increased LDS pressure due to:

  • Double buffering introduced when num_stages > 1
  • Increased tile residency in shared memory
  • Kernel configurations originally tuned for register staging, which now exceed LDS limits
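The pressure from double buffering can be estimated with simple arithmetic. The sketch below is illustrative only: the tile sizes, element widths, and the 64 KiB per-workgroup LDS budget are assumptions for the example, not the actual kernel configs.

```python
# Rough LDS estimate for a software-pipelined GEMM tile (illustrative only;
# tile sizes and the 64 KiB budget are assumptions, not the actual configs).
LDS_LIMIT_BYTES = 64 * 1024  # typical per-workgroup LDS budget on AMD GPUs

def lds_bytes(block_m, block_n, block_k, a_bytes, b_bytes, num_stages):
    """Each pipeline stage keeps one A tile and one B tile resident in LDS."""
    a_tile = block_m * block_k * a_bytes
    b_tile = block_k * block_n * b_bytes
    return num_stages * (a_tile + b_tile)

# With 1-byte operands and a large tile, two stages can overflow the budget
# while a single stage still fits:
two_stage = lds_bytes(256, 256, 128, 1, 1, num_stages=2)  # 131072 bytes
one_stage = lds_bytes(256, 256, 128, 1, 1, num_stages=1)  # 65536 bytes
print(two_stage > LDS_LIMIT_BYTES, one_stage <= LDS_LIMIT_BYTES)
```

This is the shape of the failure: async copy keeps full tiles staged in LDS per pipeline stage, so halving num_stages roughly halves the footprint.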

To address this:

  • Reduced num_stages from 2 -> 1 in affected Triton kernel configs:
    • moe_op_gemm_a4w4.py
    • moe_op_gemm_a8w4.py
    • moe_op_gemm_a8w8.py
  • This lowers shared memory (LDS) usage and avoids exceeding hardware limits

No other changes were introduced beyond parameter adjustments in the config files.
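A minimal sketch of the kind of parameter adjustment involved is below; the config fields and values are hypothetical placeholders, not the actual entries in the three config files.

```python
# Hypothetical kernel-config table; only num_stages is touched by this kind of fix.
configs = [
    {"BLOCK_M": 128, "BLOCK_N": 128, "BLOCK_K": 64, "num_warps": 8, "num_stages": 2},
    {"BLOCK_M": 256, "BLOCK_N": 128, "BLOCK_K": 64, "num_warps": 8, "num_stages": 2},
]

def cap_stages(configs, max_stages=1):
    """Return configs with num_stages capped, leaving every other field intact."""
    return [dict(c, num_stages=min(c["num_stages"], max_stages)) for c in configs]

capped = cap_stages(configs)
assert all(c["num_stages"] == 1 for c in capped)
```

Capping rather than hard-coding keeps configs that were already at num_stages=1 unchanged.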

Test Plan

  • Ran the following affected test cases:
    • test_moe_gemm_a4w4.py
    • test_moe_gemm_a8w4.py
    • test_moe_gemm_a8w8.py
  • Executed tests multiple times to ensure stability and correctness.

Test Results

  • The following previously failing tests now pass:
    • test_moe_gemm_a4w4.py
    • test_moe_gemm_a8w4.py
    • test_moe_gemm_a8w8.py
  • Results are consistent across multiple runs

Submission Checklist

- Reduced num_stages in kernel configs
- Lowered LDS usage to avoid shared memory OOR
- Fixed triton.runtime.errors.OutOfResources errors in MoE GEMM kernels
@nidal567 nidal567 self-assigned this Apr 13, 2026
@nidal567 nidal567 requested a review from a team April 13, 2026 17:54
@github-actions
Contributor

🏷️ CI Guide

Runs automatically on every PR:

  • ✅ Pre-checks (submodule verification, code formatting)
  • ✅ Aiter op tests (gfx942 + gfx950)
  • ✅ Triton tests (only when aiter/ops/triton/** or related paths are changed)

Extended tests (opt-in via labels):

| Label | Tests |
| --- | --- |
| ci:triton-355 | Run Triton tests on MI355 in addition to MI325 |
| ci:sglang | SGLang integration tests |
| ci:atom | ATOM benchmark (DeepSeek-R1 + GPT-OSS) |
| ci:vllm | vLLM benchmark |
| ci:all | All of the above |

Add labels via the sidebar or gh pr edit 2723 --add-label <label>

Contributor

@brunomazzottiamd brunomazzottiamd left a comment


LGTM! Let's wait for CI to pass.

@nidal567
Contributor Author

To further support this:

gfx950 benchmark results, collected with aiter/op_tests/op_benchmarks/triton/model_benchmarking_tool/bench_models.py:

MoE Kernel Benchmark Summary (with TRITON_HIP_USE_ASYNC_COPY=1)

With num_stages=1 previously forced but now applied conditionally, all tested MoE kernels work except moe_op_gemm_a8w8_blockscale (used by DeepSeek models). That kernel still fails with out-of-LDS errors, but it is untouched by this PR, so it can be investigated, along with its performance, in a separate issue.

  • DeepSeek-R1: moe_op_gemm_a4w4 (experts=256, moe_dim1=7168, topk=8)
  • GPT-OSS 120B: moe_op_gemm_a8w4 (experts=128, moe_dim1=2880, topk=4)
  • Llama4 Maverick: moe_op_gemm_a8w8 (experts=128, moe_dim1=5120, topk=1)
  • Qwen3-235B-A22B: moe_op_gemm_a8w8 (experts=128, moe_dim1=4096, topk=8)

All values below are throughput in TFLOPs.

TP  M      DeepSeek-R1  GPT-OSS 120B  Llama4 Maverick  Qwen3-235B-A22B
1 512 285.8 164.0 27.75 283.4
1 1536 838.3 448.2 71.21 780.8
1 2560 829.7 181.4 251.10 712.5
1 3584 1052.0 219.8 303.80 949.6
1 4608 971.1 177.2 382.30 904.9
1 5632 1164.0 209.7 473.50 1059.0
1 6656 1246.0 246.3 510.50 994.4
1 7680 1342.0 247.7 535.80 1077.0
1 8704 1233.0 232.0 640.90 1091.0
1 9728 1295.0 240.9 713.00 1147.0
1 10752 1390.0 256.4 785.10 1140.0
1 11776 1446.0 267.3 856.20 1178.0
1 12800 1351.0 259.2 891.50 1183.0
1 13824 1357.0 250.3 954.40 1202.0
1 14848 1467.0 266.0 918.80 1202.0
1 15872 1471.0 270.9 860.90 1224.0
2 512 264.1 156.0 28.69 277.1
2 1536 661.4 407.4 78.79 686.6
2 2560 996.3 179.4 245.20 653.2
2 3584 1173.0 205.1 299.50 833.1
2 4608 1006.0 172.7 381.40 823.9
2 5632 1219.0 206.8 442.50 934.6
2 6656 1433.0 223.6 503.30 893.7
2 7680 1452.0 242.4 502.30 943.4
2 8704 1319.0 230.2 616.60 969.8
2 9728 1407.0 239.2 685.20 1019.0
2 10752 1507.0 254.4 752.90 1027.0
2 11776 1536.0 264.9 819.10 1063.0
2 12800 1504.0 256.1 882.70 1068.0
2 13824 1482.0 256.6 891.80 1095.0
2 14848 1534.0 255.8 856.00 1082.0
2 15872 1608.0 271.4 815.20 1094.0
4 512 252.2 121.4 30.20 235.5
4 1536 542.7 151.7 80.72 506.1
4 2560 859.6 105.1 238.80 524.9
4 3584 925.2 132.1 295.50 621.7
4 4608 833.9 111.9 360.70 626.9
4 5632 1011.0 135.2 434.80 736.1
4 6656 1057.0 140.4 469.50 705.6
4 7680 1165.0 153.6 462.80 746.5
4 8704 1095.0 144.5 574.40 764.6
4 9728 1172.0 155.8 635.80 808.4
4 10752 1218.0 165.1 696.90 810.1
4 11776 1268.0 165.2 756.60 825.4
4 12800 1197.0 167.2 762.80 840.3
4 13824 1242.0 171.4 809.30 861.4
4 14848 1260.0 172.0 811.60 851.1
4 15872 1315.0 174.9 783.50 859.2
8 512 241.2 87.84 32.70 177.5
8 1536 398.7 122.20 82.01 302.5
8 2560 650.5 78.46 210.70 365.4
8 3584 632.0 84.19 257.80 419.0
8 4608 623.3 79.48 333.50 461.9
8 5632 743.8 91.44 402.80 487.7
8 6656 733.6 96.01 419.60 501.2
8 7680 816.2 103.80 453.60 509.3
8 8704 806.6 103.80 537.10 533.5
8 9728 855.2 107.90 592.90 545.2
8 10752 861.5 110.70 647.80 561.4
8 11776 891.2 114.90 614.10 564.0
8 12800 891.9 115.10 663.10 577.9
8 13824 896.1 114.50 704.50 584.9
8 14848 906.0 116.40 734.40 581.7
8 15872 931.9 120.00 699.90 575.9


Labels

bug, ci:triton-355, triton
