From ad9d9fbf53e5c0b929261a2748c7fe376fce9fda Mon Sep 17 00:00:00 2001 From: sunway513 Date: Sat, 11 Apr 2026 04:43:41 +0000 Subject: [PATCH 1/8] Add ROCm-versioned wheel naming to release workflow Follow PyTorch's wheel naming convention (e.g. +rocm7.2.1) for AITER release wheels. This enables building distinct wheels for different ROCm versions from the same workflow. Changes: - Add rocm_version input (auto-detects from container if empty) - Use SETUPTOOLS_SCM_PRETEND_VERSION for version+rocm suffix - Include ROCm version in concurrency group to prevent cross-version cancellation - Update artifact naming to include ROCm suffix --- .github/workflows/aiter-release.yaml | 59 +++++++++++++++++++++++++--- 1 file changed, 53 insertions(+), 6 deletions(-) diff --git a/.github/workflows/aiter-release.yaml b/.github/workflows/aiter-release.yaml index cc64f233ac..df907ab446 100644 --- a/.github/workflows/aiter-release.yaml +++ b/.github/workflows/aiter-release.yaml @@ -40,6 +40,10 @@ on: description: 'GPU architectures (e.g. gfx942;gfx950)' required: false default: 'gfx942;gfx950' + rocm_version: + description: 'ROCm version label for wheel (e.g. 7.2.1). Auto-detected from container if empty.' 
+ required: false + default: '' runner: description: 'Select build host' required: true @@ -53,8 +57,8 @@ on: - aiter-1gpu-runner concurrency: - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: ${{ github.ref != 'refs/heads/main' }} + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.event.inputs.rocm_version || 'default' }} + cancel-in-progress: ${{ github.ref != 'refs/heads/main' && !startsWith(github.ref, 'refs/tags/v') }} jobs: build_whl_package: @@ -120,6 +124,48 @@ jobs: --name aiter_build_${{ matrix.python_version }} \ ${{ env.BUILD_DOCKER_IMAGE }} + - name: Detect ROCm version + if: ${{ matrix.build_enabled }} + id: rocm_ver + run: | + set -e + INPUT_VER="${{ github.event.inputs.rocm_version }}" + if [ -n "$INPUT_VER" ]; then + echo "Using user-provided ROCm version: $INPUT_VER" + echo "rocm_version=$INPUT_VER" >> "$GITHUB_OUTPUT" + else + DETECTED=$(docker exec aiter_build_${{ matrix.python_version }} \ + bash -c 'cat /opt/rocm/.info/version 2>/dev/null || echo ""' | tr -d '[:space:]') + if [ -n "$DETECTED" ]; then + DETECTED=$(echo "$DETECTED" | cut -d'-' -f1) + echo "Auto-detected ROCm version: $DETECTED" + echo "rocm_version=$DETECTED" >> "$GITHUB_OUTPUT" + else + echo "WARNING: Could not detect ROCm version, wheel will have no ROCm suffix" + echo "rocm_version=" >> "$GITHUB_OUTPUT" + fi + fi + + - name: Determine wheel version + if: ${{ matrix.build_enabled }} + id: whl_ver + run: | + set -e + ROCM_VER="${{ steps.rocm_ver.outputs.rocm_version }}" + BASE_VER=$(docker exec -w /workspace aiter_build_${{ matrix.python_version }} \ + python3 -c "from setuptools_scm import get_version; print(get_version())" 2>/dev/null || true) + if [ -z "$BASE_VER" ]; then + BASE_VER=$({ git describe --tags --match 'v*' 2>/dev/null || echo "0.1.0"; } | sed 's/^v//') + fi + if [ -n "$ROCM_VER" ]; then + FULL_VER="${BASE_VER}+rocm${ROCM_VER}" + else + FULL_VER="${BASE_VER}" + fi + echo "Wheel version: $FULL_VER (base=$BASE_VER, rocm=$ROCM_VER)" + echo 
"full_version=$FULL_VER" >> "$GITHUB_OUTPUT" + echo "rocm_suffix=rocm${ROCM_VER}" >> "$GITHUB_OUTPUT" + - name: Install Dependencies if: ${{ matrix.build_enabled }} run: | @@ -140,21 +186,22 @@ jobs: aiter_build_${{ matrix.python_version }} \ pip install --timeout=60 --retries=10 ninja - - name: Build Aiter + - name: Build Aiter with precompiled kernels if: ${{ matrix.build_enabled }} run: | set -e - echo "Building aiter whl packages for Python ${{ matrix.python_version }}..." + FULL_VER="${{ steps.whl_ver.outputs.full_version }}" + echo "Building aiter whl version=${FULL_VER} with PREBUILD_KERNELS=1 for Python ${{ matrix.python_version }}..." docker exec \ -w /workspace \ aiter_build_${{ matrix.python_version }} \ - bash -c 'PREBUILD_KERNELS=1 GPU_ARCHS="${{ env.GPU_ARCHS }}" python3 setup.py bdist_wheel && ls dist/*.whl' + bash -c "SETUPTOOLS_SCM_PRETEND_VERSION='${FULL_VER}' PREBUILD_KERNELS=1 GPU_ARCHS='${{ env.GPU_ARCHS }}' python3 setup.py bdist_wheel && ls -lh dist/*.whl" - name: Upload whl file as artifact if: ${{ matrix.build_enabled }} uses: actions/upload-artifact@v4 with: - name: aiter-whl-packages-py${{ matrix.python_version }}-${{ github.run_id }}-${{ github.run_attempt }} + name: aiter-whl-py${{ matrix.python_version }}-${{ steps.whl_ver.outputs.rocm_suffix }}-${{ github.run_id }} path: dist/*.whl - name: Cleanup container From 944695418b83f2c644a6240bd8bca69830d5dc2c Mon Sep 17 00:00:00 2001 From: sunway513 Date: Sat, 11 Apr 2026 16:53:45 +0000 Subject: [PATCH 2/8] Fix runner labels and Docker username in release workflow - Default runner: aiter-k8s-build -> aiter-1gpu-runner (actually exists) - Remove non-existent runners: aiter-mi300-1gpu, aiter-mi325-1gpu - Fix runner typo: linux-aiter-mi355-1 -> linux-aiter-mi35x-1 - Fix Docker username: rocmshard -> rocmshared (missing 'e') --- .github/workflows/aiter-release.yaml | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/.github/workflows/aiter-release.yaml 
b/.github/workflows/aiter-release.yaml index df907ab446..f4863a9cbc 100644 --- a/.github/workflows/aiter-release.yaml +++ b/.github/workflows/aiter-release.yaml @@ -35,7 +35,7 @@ on: docker_username: description: 'Docker username for docker login' required: true - default: 'rocmshard' + default: 'rocmshared' gpu_archs: description: 'GPU architectures (e.g. gfx942;gfx950)' required: false @@ -47,14 +47,12 @@ on: runner: description: 'Select build host' required: true - default: 'aiter-k8s-build' + default: 'aiter-1gpu-runner' type: choice options: - - build-only-aiter - - aiter-mi300-1gpu - - aiter-mi325-1gpu - - linux-aiter-mi355-1 - aiter-1gpu-runner + - build-only-aiter + - linux-aiter-mi35x-1 concurrency: group: ${{ github.workflow }}-${{ github.ref }}-${{ github.event.inputs.rocm_version || 'default' }} From 09d68fc0c89b6765ba55b29057a35ae5084ce607 Mon Sep 17 00:00:00 2001 From: sunway513 Date: Sat, 11 Apr 2026 17:43:38 +0000 Subject: [PATCH 3/8] Pin setuptools_scm<10 to fix vcs_versioning import error setuptools_scm 10.x moved to vcs_versioning package, breaking the build with ModuleNotFoundError. Pin to 9.x until pyproject.toml is updated. 
--- .github/workflows/aiter-release.yaml | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/.github/workflows/aiter-release.yaml b/.github/workflows/aiter-release.yaml index f4863a9cbc..ce93d9f9fa 100644 --- a/.github/workflows/aiter-release.yaml +++ b/.github/workflows/aiter-release.yaml @@ -173,7 +173,16 @@ jobs: -w /workspace \ aiter_build_${{ matrix.python_version }} \ pip install --timeout=60 --retries=10 -r requirements.txt - + + - name: Pin setuptools_scm + if: ${{ matrix.build_enabled }} + run: | + set -e + docker exec \ + -w /workspace \ + aiter_build_${{ matrix.python_version }} \ + pip install --timeout=60 --retries=10 "setuptools_scm<10" + - name: Install ninja if: ${{ matrix.build_enabled }} run: | From 2e4b8f776ba509cc5b5ed4e767c50d13b0e2aad0 Mon Sep 17 00:00:00 2001 From: sunway513 Date: Sun, 12 Apr 2026 19:11:25 +0000 Subject: [PATCH 4/8] =?UTF-8?q?docs:=20comprehensive=20documentation=20ove?= =?UTF-8?q?rhaul=20=E2=80=94=20fix=2022=20factual=20errors,=20add=20new=20?= =?UTF-8?q?pages,=20automate=20deployment?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix conf.py version to auto-detect from setuptools_scm (was hardcoded 0.1.0) - Fix docs.yml: add release trigger, enable -W (warnings as errors), add linkcheck - Remove doc.aiter.amd.com CNAME reference (DNS was never configured) - Rewrite gemm.rst: remove 8 nonexistent functions, document actual 25+ GEMM APIs - Rewrite attention.rst: remove fabricated GQA/MQA, document PA/MHA/MLA APIs - Rewrite operators.rst: document actual norm/activation/rope/quant/sample/cache APIs - Rewrite add_new_op.rst: replace CUDA build system with HIP/Triton AITER JIT pattern - Add new pages: compatibility matrix, supported models, GEMM tuning guide - Add new API docs: moe.rst, normalization.rst - Add stubs: changelog, contributing, triton_kernels, benchmarks - Add release-notify.yml: Slack webhook + downstream tracking issue automation - Add 
changelog-config.json for auto-generated release notes --- .github/changelog-config.json | 44 ++ .github/workflows/docs.yml | 36 +- .github/workflows/release-notify.yml | 86 ++++ docs/advanced/triton_kernels.rst | 41 ++ docs/api/attention.rst | 309 ++++++++------ docs/api/gemm.rst | 418 ++++++++++--------- docs/api/moe.rst | 92 +++++ docs/api/normalization.rst | 73 ++++ docs/api/operators.rst | 408 +++++++++--------- docs/changelog.rst | 14 + docs/compatibility.rst | 73 ++++ docs/conf.py | 12 +- docs/contributing.rst | 64 +++ docs/gemm_tuning.rst | 91 ++++ docs/index.rst | 107 ++--- docs/models.rst | 51 +++ docs/performance/benchmarks.rst | 32 ++ docs/tutorials/add_new_op.rst | 592 +++++++++------------------ 18 files changed, 1555 insertions(+), 988 deletions(-) create mode 100644 .github/changelog-config.json create mode 100644 .github/workflows/release-notify.yml create mode 100644 docs/advanced/triton_kernels.rst create mode 100644 docs/api/moe.rst create mode 100644 docs/api/normalization.rst create mode 100644 docs/changelog.rst create mode 100644 docs/compatibility.rst create mode 100644 docs/contributing.rst create mode 100644 docs/gemm_tuning.rst create mode 100644 docs/models.rst create mode 100644 docs/performance/benchmarks.rst diff --git a/.github/changelog-config.json b/.github/changelog-config.json new file mode 100644 index 0000000000..bf2338cf29 --- /dev/null +++ b/.github/changelog-config.json @@ -0,0 +1,44 @@ +{ + "categories": [ + { + "title": "## New Features", + "labels": ["feature", "enhancement"] + }, + { + "title": "## Performance", + "labels": ["performance", "optimization"] + }, + { + "title": "## Bug Fixes", + "labels": ["bug", "fix"] + }, + { + "title": "## Refactoring", + "labels": ["refactor", "cleanup"] + }, + { + "title": "## Infrastructure", + "labels": ["ci", "infrastructure", "build"] + } + ], + "sort": { + "order": "ASC", + "on_property": "mergedAt" + }, + "template": "#{{CHANGELOG}}\n\n**Full Changelog**: 
https://github.com/ROCm/aiter/compare/#{{FROM_TAG}}...#{{TO_TAG}}", + "pr_template": "- #{{TITLE}} (##{{NUMBER}})", + "empty_template": "No changes.", + "label_extractor": [ + { + "pattern": "^(feat|feature|add|Add|new)", + "target": "$1", + "on_property": "title", + "method": "match", + "flags": "i" + } + ], + "max_tags_to_fetch": 200, + "max_pull_requests": 1000, + "max_back_track_time_days": 90, + "base_branches": ["main"] +} diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 5a46e71027..f50938a3c4 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -8,6 +8,8 @@ on: paths: - 'docs/**' - '.github/workflows/docs.yml' + release: + types: [published] pull_request: branches: - main @@ -47,7 +49,12 @@ jobs: - name: Build Sphinx documentation run: | cd docs - make html + sphinx-build -W --keep-going -b html . _build/html + + - name: Check for broken links + run: | + cd docs + sphinx-build -b linkcheck . _build/linkcheck || true # Non-blocking, report only - name: Upload documentation artifacts uses: actions/upload-artifact@v4 @@ -59,7 +66,9 @@ jobs: deploy-docs: needs: build-docs runs-on: ubuntu-latest - if: github.event_name == 'push' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/docs-website') + if: > + (github.event_name == 'push' && (github.ref == 'refs/heads/main' || github.ref == 'refs/heads/docs-website' || startsWith(github.ref, 'refs/tags/v'))) + || github.event_name == 'release' permissions: contents: write # For GitHub Pages deployment @@ -76,27 +85,4 @@ jobs: with: github_token: ${{ secrets.GITHUB_TOKEN }} publish_dir: ./html - cname: doc.aiter.amd.com # Custom domain commit_message: 'docs: deploy documentation' - - # Alternative: Deploy to AMD servers via SSH - # deploy-to-amd: - # needs: build-docs - # runs-on: ubuntu-latest - # if: github.event_name == 'push' && github.ref == 'refs/heads/main' - # - # steps: - # - name: Download documentation artifacts - # uses: actions/download-artifact@v4 - 
# with: - # name: documentation - # path: ./html - # - # - name: Deploy to AMD doc server - # uses: easingthemes/ssh-deploy@v4 - # with: - # SSH_PRIVATE_KEY: ${{ secrets.AMD_DOC_SERVER_KEY }} - # REMOTE_HOST: doc.aiter.amd.com - # REMOTE_USER: deploy - # SOURCE: "html/" - # TARGET: "/var/www/doc.aiter.amd.com/html" diff --git a/.github/workflows/release-notify.yml b/.github/workflows/release-notify.yml new file mode 100644 index 0000000000..9f4780829e --- /dev/null +++ b/.github/workflows/release-notify.yml @@ -0,0 +1,87 @@ +name: Release Notification + +on: + release: + types: [published] + +jobs: + notify-slack: + runs-on: ubuntu-latest + if: "!github.event.release.prerelease" + steps: + - name: Notify Slack + uses: slackapi/slack-github-action@v2.0.0 + with: + webhook: ${{ secrets.AITER_RELEASE_SLACK_WEBHOOK }} + webhook-type: incoming-webhook + payload: | + { + "text": "AITER ${{ github.event.release.tag_name }} released", + "blocks": [ + { + "type": "header", + "text": { + "type": "plain_text", + "text": "AITER ${{ github.event.release.tag_name }} Released" + } + }, + { + "type": "section", + "text": { + "type": "mrkdwn", + "text": "See the <${{ github.event.release.html_url }}|release notes> for details." 
+ } + }, + { + "type": "actions", + "elements": [ + { + "type": "button", + "text": { + "type": "plain_text", + "text": "View Release" + }, + "url": "${{ github.event.release.html_url }}" + } + ] + } + ] + } + + notify-downstream: + runs-on: ubuntu-latest + if: "!github.event.release.prerelease" + strategy: + matrix: + repo: ['ROCm/ATOM'] + steps: + - name: Create tracking issue + uses: actions/github-script@v7 + with: + github-token: ${{ secrets.CROSS_REPO_TOKEN }} + script: | + const tag = context.payload.release.tag_name; + const url = context.payload.release.html_url; + const repo = '${{ matrix.repo }}'.split('/'); + + await github.rest.issues.create({ + owner: repo[0], + repo: repo[1], + title: `AITER ${tag} available — dependency update`, + body: [ + `A new AITER release is available: [${tag}](${url})`, + '', + '### Action Items', + '- [ ] Update AITER dependency to ' + tag, + '- [ ] Run integration tests', + '- [ ] Verify accuracy baselines', + '', + '### Install', + '```bash', + 'pip install amd-aiter --find-links ' + url, + '```', + '', + 'Auto-generated by AITER release workflow.' + ].join('\n'), + labels: ['dependency-update'] + }); diff --git a/docs/advanced/triton_kernels.rst b/docs/advanced/triton_kernels.rst new file mode 100644 index 0000000000..5445438f9e --- /dev/null +++ b/docs/advanced/triton_kernels.rst @@ -0,0 +1,41 @@ +Triton Kernels +============== + +AITER uses `Triton <https://triton-lang.org/>`_ as one of its backend +implementations for GPU kernels. Triton kernels are written in Python and +compiled to GPU machine code at runtime. + +Source Locations +----------------- + +Triton kernel sources are located in two directories: + +- ``aiter/ops/triton/`` -- Triton-based operator implementations that are + called from the main ops API. +- ``aiter/_triton_kernels/`` -- Lower-level Triton kernel definitions used + internally by the operator layer. 
+ +Benefits +-------- + +- **Portability**: Triton kernels work across GPU architectures without + per-target assembly code. +- **Maintainability**: Written in Python with Triton DSL, making them easier + to read and modify than hand-written assembly. +- **gfx1250 support**: On MI450 (CDNA 4), where hand-tuned ASM kernels are + not available, Triton is the primary compute backend alongside HIP. + +Other Backends +-------------- + +AITER supports multiple kernel backends depending on the GPU architecture and +operation: + +- **ASM** -- Hand-tuned assembly for peak performance on CDNA 3/3.5. +- **Composable Kernel (CK)** -- C++ template library for fused kernels. +- **CK Tile** -- Tile-based CK backend for structured operations. +- **FlyDSL** -- Domain-specific language for peak-performance kernel + generation (Meta donation). + +The backend selection is automatic based on the GPU architecture and operation +type. On gfx1250 (MI450), AITER uses Triton+HIP exclusively (no ASM, no CK). diff --git a/docs/api/attention.rst b/docs/api/attention.rst index 54a0e83976..4fdbcef619 100644 --- a/docs/api/attention.rst +++ b/docs/api/attention.rst @@ -1,167 +1,248 @@ Attention Operations ==================== -AITER provides highly optimized attention kernels for AMD GPUs with ROCm. +AITER provides GPU-optimized attention kernels for both training (MHA with forward +and backward passes) and inference (Paged Attention, Multi-Latent Attention). +All kernels target AMD Instinct GPUs via ROCm. -Flash Attention ---------------- +.. contents:: Sections + :local: + :depth: 1 -.. autofunction:: aiter.flash_attn_func -Standard flash attention implementation with optional causal masking. +Multi-Head Attention (Flash) +---------------------------- -**Parameters:** +Flash attention implementations with CK (Composable Kernel) backends. +Used for both training and inference. Located in ``aiter.ops.mha``. 
-* **query** (*torch.Tensor*) - Query tensor of shape ``(batch, seq_len, num_heads, head_dim)`` -* **key** (*torch.Tensor*) - Key tensor of shape ``(batch, seq_len, num_heads, head_dim)`` -* **value** (*torch.Tensor*) - Value tensor of shape ``(batch, seq_len, num_heads, head_dim)`` -* **causal** (*bool*, optional) - Whether to apply causal masking. Default: ``False`` -* **softmax_scale** (*float*, optional) - Scaling factor for softmax. Default: ``1/sqrt(head_dim)`` +High-Level API +~~~~~~~~~~~~~~ -**Returns:** +These are the primary user-facing functions with ``torch.autograd`` support. -* **output** (*torch.Tensor*) - Attention output of shape ``(batch, seq_len, num_heads, head_dim)`` +.. autofunction:: aiter.ops.mha.flash_attn_func -**Example:** +``flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1, 0), bias=None, alibi_slopes=None, deterministic=True, return_lse=False, return_attn_probs=False, how_v3_bf16_cvt=1, cu_seqlens_q=None, cu_seqlens_kv=None, sink_ptr=None)`` -.. code-block:: python +Standard flash attention forward/backward with autograd. +Supports MQA/GQA via fewer KV heads, causal masking, sliding window, ALiBi slopes, +and FP8 inputs. Dispatches to CK or FMHA v3 backend based on dtype and arch. - import torch - import aiter +- **q**: ``(batch, seqlen, nheads, headdim_q)`` +- **k**: ``(batch, seqlen, nheads_k, headdim_q)`` +- **v**: ``(batch, seqlen, nheads_k, headdim_v)`` +- **Returns**: ``out (batch, seqlen, nheads, headdim_v)``, optionally ``softmax_lse``, ``S_dmask`` - q = torch.randn(2, 1024, 16, 64, device='cuda', dtype=torch.float16) - k = torch.randn(2, 1024, 16, 64, device='cuda', dtype=torch.float16) - v = torch.randn(2, 1024, 16, 64, device='cuda', dtype=torch.float16) +.. 
autofunction:: aiter.ops.mha.flash_attn_varlen_func - output = aiter.flash_attn_func(q, k, v, causal=True) +``flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, min_seqlen_q=0, dropout_p=0.0, softmax_scale=None, logits_soft_cap=0.0, causal=False, window_size=(-1, -1, 0), bias=None, alibi_slopes=None, deterministic=False, return_lse=False, return_attn_probs=False, how_v3_bf16_cvt=1, block_table=None, out=None, ...)`` -Flash Attention with KV Cache ------------------------------- +Variable-length flash attention with autograd. Sequences are packed into a single +tensor and indexed by cumulative sequence lengths. -.. autofunction:: aiter.flash_attn_with_kvcache +- **q**: ``(total_q, nheads, headdim_q)`` +- **k**: ``(total_k, nheads_k, headdim_q)`` +- **v**: ``(total_k, nheads_k, headdim_v)`` +- **cu_seqlens_q**: ``(batch_size + 1,)`` cumulative query lengths +- **cu_seqlens_k**: ``(batch_size + 1,)`` cumulative key lengths +- **Returns**: ``out (total_q, nheads, headdim_v)``, optionally ``softmax_lse``, ``S_dmask`` -Optimized attention with paged KV cache support for inference. +FP8 Convenience Functions +~~~~~~~~~~~~~~~~~~~~~~~~~~ -**Parameters:** +.. autofunction:: aiter.ops.mha.flash_attn_fp8_pertensor_func -* **query** (*torch.Tensor*) - Query tensor ``(batch, seq_len, num_heads, head_dim)`` -* **kv_cache** (*torch.Tensor*) - Paged KV cache ``(num_blocks, num_heads, block_size, head_dim)`` -* **page_table** (*torch.Tensor*) - Page table mapping ``(batch, max_blocks_per_seq)`` -* **block_size** (*int*) - Size of each page block (e.g., 128) -* **causal** (*bool*, optional) - Causal masking. Default: ``True`` +``flash_attn_fp8_pertensor_func(q, k, v, q_descale, k_descale, v_descale, causal=False, window_size=(-1, -1, 0), softmax_scale=None, sink_ptr=None)`` -**Returns:** +Flash attention for FP8 inputs with per-tensor descaling. Forward-only (no autograd). 
-* **output** (*torch.Tensor*) - Attention output ``(batch, seq_len, num_heads, head_dim)`` +.. autofunction:: aiter.ops.mha.flash_attn_varlen_fp8_pertensor_func -**Example:** +``flash_attn_varlen_fp8_pertensor_func(q, k, v, q_descale, k_descale, v_descale, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, ...)`` -.. code-block:: python +Variable-length FP8 flash attention with per-tensor descaling. Forward-only. - query = torch.randn(4, 128, 16, 64, device='cuda', dtype=torch.float16) - kv_cache = torch.randn(256, 16, 128, 64, device='cuda', dtype=torch.float16) - page_table = torch.randint(0, 256, (4, 32), device='cuda', dtype=torch.int32) +Batch Prefill +~~~~~~~~~~~~~ - output = aiter.flash_attn_with_kvcache( - query, kv_cache, page_table, block_size=128 - ) +.. autofunction:: aiter.ops.mha.mha_batch_prefill_func -Grouped Query Attention (GQA) ------------------------------- +``mha_batch_prefill_func(q, k, v, cu_seqlens_q, kv_indptr, kv_page_indices, max_seqlen_q, max_seqlen_k, dropout_p=0.0, softmax_scale=None, logits_soft_cap=0.0, causal=False, window_size=(-1, -1), alibi_slopes=None, ...)`` -.. autofunction:: aiter.grouped_query_attention +Paged KV cache batch prefill attention. Supports both vectorized (5D) and linear (3D/4D) +KV cache layouts. -Efficient grouped query attention for models like Llama 2. +Low-Level CK Kernels +~~~~~~~~~~~~~~~~~~~~~ -**Parameters:** +These are the direct CK kernel wrappers. Most users should prefer the high-level API above. -* **query** (*torch.Tensor*) - ``(batch, seq_len, num_q_heads, head_dim)`` -* **key** (*torch.Tensor*) - ``(batch, seq_len, num_kv_heads, head_dim)`` -* **value** (*torch.Tensor*) - ``(batch, seq_len, num_kv_heads, head_dim)`` -* **num_groups** (*int*) - Number of query heads per KV head -* **causal** (*bool*, optional) - Causal masking. Default: ``False`` +.. 
autofunction:: aiter.ops.mha.mha_fwd -**Returns:** +``mha_fwd(q, k, v, dropout_p, softmax_scale, is_causal, window_size_left, window_size_right, sink_size, return_softmax_lse, return_dropout_randval, ...)`` -* **output** (*torch.Tensor*) - ``(batch, seq_len, num_q_heads, head_dim)`` +CK flash attention forward pass. Returns ``(out, softmax_lse, S_dmask, rng_state)``. -Multi-Query Attention (MQA) ----------------------------- +.. autofunction:: aiter.ops.mha.fmha_v3_fwd -.. autofunction:: aiter.multi_query_attention +``fmha_v3_fwd(q, k, v, dropout_p, softmax_scale, is_causal, window_size_left, window_size_right, return_softmax_lse, return_dropout_randval, how_v3_bf16_cvt, ...)`` -Multi-query attention where all query heads share single key/value heads. +FMHA v3 forward pass (newer CK backend). Returns ``(out, softmax_lse, S_dmask, rng_state)``. -**Parameters:** +.. autofunction:: aiter.ops.mha.mha_varlen_fwd -* **query** (*torch.Tensor*) - ``(batch, seq_len, num_heads, head_dim)`` -* **key** (*torch.Tensor*) - ``(batch, seq_len, 1, head_dim)`` -* **value** (*torch.Tensor*) - ``(batch, seq_len, 1, head_dim)`` -* **causal** (*bool*, optional) - Causal masking. Default: ``False`` +``mha_varlen_fwd(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, min_seqlen_q, dropout_p, softmax_scale, logits_soft_cap, zero_tensors, is_causal, window_size_left, window_size_right, sink_size, return_softmax_lse, return_dropout_randval, ...)`` -**Returns:** +Variable-length CK MHA forward. Returns ``(out, softmax_lse, S_dmask, rng_state)``. -* **output** (*torch.Tensor*) - ``(batch, seq_len, num_heads, head_dim)`` +.. 
autofunction:: aiter.ops.mha.fmha_v3_varlen_fwd -Variable Sequence Attention ----------------------------- +``fmha_v3_varlen_fwd(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, min_seqlen_q, dropout_p, softmax_scale, logits_soft_cap, zero_tensors, is_causal, window_size_left, window_size_right, return_softmax_lse, return_dropout_randval, how_v3_bf16_cvt, ...)`` + +FMHA v3 variable-length forward pass. + +.. autofunction:: aiter.ops.mha.mha_bwd + +``mha_bwd(dout, q, k, v, out, softmax_lse, dropout_p, softmax_scale, is_causal, window_size_left, window_size_right, deterministic, dq=None, dk=None, dv=None, ...)`` + +CK MHA backward pass (training). Returns ``(dq, dk, dv, dbias)``. + +.. autofunction:: aiter.ops.mha.fmha_v3_bwd + +FMHA v3 backward pass (training). + +.. autofunction:: aiter.ops.mha.mha_varlen_bwd + +Variable-length CK MHA backward pass (training). + +.. autofunction:: aiter.ops.mha.fmha_v3_varlen_bwd + +FMHA v3 variable-length backward pass (training). + + +Paged Attention +--------------- + +Paged attention kernels for LLM decode-phase inference with block-based KV caches. +Located in ``aiter.ops.attention``. + +Core Functions +~~~~~~~~~~~~~~ + +.. autofunction:: aiter.ops.attention.paged_attention_rocm + +``paged_attention_rocm(out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, context_lens, block_size, max_context_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, fp8_out_scale=None, partition_size=256, mtp=1, q_scale=None)`` + +Main ROCm paged attention entry point. Custom CK-based implementation with +partitioned softmax. Supports FP8 KV cache, ALiBi, and multi-token prediction (MTP). + +.. 
autofunction:: aiter.ops.attention.paged_attention_v1 + +``paged_attention_v1(out, workspace_buffer, query, key_cache, value_cache, scale, block_tables, cu_query_lens, context_lens, max_context_len, alibi_slopes, kv_cache_dtype, kv_cache_layout, logits_soft_cap, k_scale, v_scale, fp8_out_scale=None, partition_size=256, mtp=1, sliding_window=0)`` + +V1 paged attention with workspace buffer. Supports multiple KV cache layouts, +logits soft capping, and sliding window attention. + +.. autofunction:: aiter.ops.attention.paged_attention_ragged + +``paged_attention_ragged(out, workspace_buffer, query, key_cache, value_cache, scale, kv_indptr, kv_page_indices, kv_last_page_lens, block_size, max_num_partitions, alibi_slopes, kv_cache_dtype, kv_cache_layout, logits_soft_cap, k_scale, v_scale, fp8_out_scale=None, partition_size=256, mtp=1)`` + +Ragged tensor paged attention. Uses indirect page indexing (``kv_indptr``, +``kv_page_indices``) instead of dense block tables. -.. autofunction:: aiter.variable_length_attention +ASM Paged Attention +~~~~~~~~~~~~~~~~~~~ -Attention with variable-length sequences using page tables. +Hand-tuned assembly kernels for maximum decode throughput. -**Parameters:** +.. autofunction:: aiter.ops.attention.pa_fwd_asm -* **query** (*torch.Tensor*) - Query tensor -* **key** (*torch.Tensor*) - Key tensor -* **value** (*torch.Tensor*) - Value tensor -* **seq_lengths** (*torch.Tensor*) - Actual sequence lengths ``(batch,)`` -* **max_seq_len** (*int*) - Maximum sequence length +``pa_fwd_asm(Q, K, V, block_tables, context_lens, block_tables_stride0, max_qlen=1, K_QScale=None, V_QScale=None, out_=None, qo_indptr=None, high_precision=1, kernelName=None)`` -**Returns:** +ASM paged attention forward. Supports FP8 KV cache via dequantization scales +(``K_QScale``, ``V_QScale``). The ``high_precision`` parameter controls FP8 +accumulation precision (0=low, 1=medium, 2=highest). -* **output** (*torch.Tensor*) - Attention output +.. 
autofunction:: aiter.ops.attention.pa_ps_fwd_asm -Supported Architectures ------------------------- +``pa_ps_fwd_asm(Q, K, V, kv_indptr, kv_page_indices, context_lens, softmax_scale, max_qlen=1, K_QScale=None, V_QScale=None, out_=None, qo_indptr=None, work_indptr=None, work_info=None, splitData=None, splitLse=None, mask=0, high_precision=1, kernelName=None, quant_type=QuantType.per_Token)`` -AITER attention kernels are optimized for: +PS-mode (persistent/split) ASM paged attention. Uses ragged page indexing and +supports work partitioning for large context lengths. -* **AMD Instinct MI300X** (gfx942) - Best performance -* **AMD Instinct MI250X** (gfx90a) - Fully supported -* **AMD Instinct MI300A** (gfx950) - Experimental +.. autofunction:: aiter.ops.attention.pa_persistent_fwd -Performance Characteristics +``pa_persistent_fwd(Q, K, V, output, max_qlen, qo_indptr, kv_indptr, kv_indices, context_lens, work_indptr, work_info, reduce_indptr, reduce_final_map, reduce_partial_map, K_QScale=None, V_QScale=None, softmax_scale=None, mask=0, quant_type=QuantType.per_Token)`` + +Persistent paged attention combining PS-mode forward with reduction. +Orchestrates ``pa_ps_fwd_asm`` + ``pa_reduce_v1`` for long-context decode. +Returns ``(logits, final_lse)``. + +vLLM-Compatible Wrapper +~~~~~~~~~~~~~~~~~~~~~~~~ + +Drop-in replacement for vLLM's paged attention layer. Located in ``aiter.paged_attn``. + +.. autoclass:: aiter.paged_attn.PagedAttention + :members: get_supported_head_sizes, get_kv_cache_shape, split_kv_cache, write_to_paged_cache, forward_decode, swap_blocks, copy_blocks + +.. autoclass:: aiter.paged_attn.PagedAttentionMetadata + :members: + +.. autofunction:: aiter.paged_attn.paged_attention_v1 + +``paged_attention_v1(out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, ...)`` + +vLLM-compatible v1 paged attention (delegates to ``aiter.ops``). + +.. 
autofunction:: aiter.paged_attn.paged_attention_v2 + +``paged_attention_v2(out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, ...)`` + +vLLM-compatible v2 paged attention with partitioned softmax. + + +Multi-Latent Attention (MLA) ---------------------------- -.. list-table:: - :header-rows: 1 - :widths: 30 20 20 30 - - * - Operation - - Typical Speedup - - Memory Efficient - - Best For - * - flash_attn_func - - 2-4x vs PyTorch - - Yes - - Training & Inference - * - flash_attn_with_kvcache - - 3-6x vs naive - - Yes - - LLM Inference - * - grouped_query_attention - - 2-3x vs unfused - - Moderate - - Llama-style models - * - variable_length_attention - - 4-8x vs padded - - High - - Variable batches - -See Also --------- - -* :doc:`../tutorials/attention` - Attention tutorial -* :doc:`../tutorials/variable_length` - Variable-length sequences -* :doc:`../benchmarks` - Performance benchmarks +Attention kernels for DeepSeek-style Multi-Latent Attention, where key and value +are projected into a shared low-rank latent space. Located in ``aiter.mla``. +All MLA functions are inference-only. + +.. autofunction:: aiter.mla.mla_decode_fwd + +``mla_decode_fwd(q, kv_buffer, o, qo_indptr, kv_indptr, kv_indices, kv_last_page_lens, max_seqlen_q, page_size=1, nhead_kv=1, sm_scale=None, logit_cap=0.0, num_kv_splits=None, ...)`` + +MLA decode-phase forward pass. Operates on paged KV buffers with the latent +dimension fused into ``kv_buffer``. Supports both ASM and Triton backends +with automatic split/reduce for long contexts. + +- **q**: ``(total_q, nheads, qk_head_dim)`` +- **kv_buffer**: ``(num_pages, page_size, nhead_kv, kv_lora_rank + qk_rope_head_dim)`` +- **o**: ``(total_q, nheads, v_head_dim)`` output buffer + +.. 
autofunction:: aiter.mla.mla_prefill_fwd + +``mla_prefill_fwd(q, kv_buffer, o, qo_indptr, kv_indptr, kv_indices, kv_last_page_lens, max_seqlen_q, sm_scale=None, logit_cap=0.0, num_kv_splits=None)`` + +MLA prefill-phase forward pass. Uses ASM backend for the attention computation. + +- **q**: ``(num_seqs, num_heads, head_size)`` +- **kv_buffer**: ``(num_pages, page_size, nhead_kv, kv_lora_rank + qk_rope_head_dim)`` +- **o**: ``(num_seqs, num_heads, v_head_dim)`` + +.. autofunction:: aiter.mla.mla_prefill_ps_fwd + +``mla_prefill_ps_fwd(Q, K, V, output, qo_indptr, kv_indptr, kv_page_indices, work_indptr, work_info_set, max_seqlen_q, is_causal, reduce_indptr=None, reduce_final_map=None, reduce_partial_map=None, softmax_scale=None, q_scale=None, k_scale=None, v_scale=None)`` + +MLA prefill with persistent/split mode. Handles long prefill sequences via +work partitioning and multi-stage reduction. Supports FP8 via per-tensor scales. + +.. autofunction:: aiter.mla.mla_prefill_reduce + +``mla_prefill_reduce(partial_output, partial_lse, reduce_indptr, reduce_final_map, reduce_partial_map, output, tile_q=256, use_triton=True)`` + +Reduction kernel for MLA prefill split outputs. Combines partial attention +outputs using log-sum-exp for numerically stable merging. Implemented in Triton +with a PyTorch fallback. diff --git a/docs/api/gemm.rst b/docs/api/gemm.rst index 02ee407398..e5f8d350b8 100644 --- a/docs/api/gemm.rst +++ b/docs/api/gemm.rst @@ -1,284 +1,320 @@ GEMM Operations =============== -AITER provides optimized General Matrix Multiply (GEMM) operations for AMD GPUs. +AITER provides optimized General Matrix Multiply (GEMM) operations for AMD GPUs +across multiple precisions (FP8, BF16/FP16, FP4) with multiple backend +implementations (ASM, CK, CK Tile, Triton, FlyDSL). -Grouped GEMM ------------- -.. autofunction:: aiter.grouped_gemm +A8W8 (FP8) GEMM +---------------- -Efficient grouped matrix multiplication for Mixture of Experts (MoE) layers. 
+Functions in ``aiter.ops.gemm_op_a8w8``. All functions compute +``Out = dequant(XQ @ WQ^T)`` using FP8 (8-bit floating point) inputs with +per-tensor or block-wise scaling. -**Parameters:** +.. function:: gemm_a8w8_ck(XQ, WQ, x_scale, w_scale, Out, bias=None, splitK=0) -* **input** (*torch.Tensor*) - Input tensor ``(total_tokens, hidden_dim)`` -* **weights** (*torch.Tensor*) - Expert weights ``(num_experts, hidden_dim, output_dim)`` -* **expert_ids** (*torch.Tensor*) - Expert assignments ``(total_tokens, top_k)`` -* **topk** (*int*, optional) - Number of experts per token. Default: inferred from ``expert_ids`` + CK (Composable Kernel) based FP8 GEMM with per-tensor scaling. -**Returns:** + :param XQ: Activation tensor ``[M, K]``, FP8. + :param WQ: Weight tensor ``[N, K]``, FP8. + :param x_scale: Activation scale ``[M, 1]``, FP32. + :param w_scale: Weight scale ``[1, N]``, FP32. + :param Out: Pre-allocated output tensor ``[M, N]``. + :param bias: Optional bias tensor. + :param splitK: Split-K factor for parallelism (default 0 = auto). + :returns: Output tensor ``Out``. -* **output** (*torch.Tensor*) - Result ``(total_tokens, output_dim)`` +.. function:: gemm_a8w8_bpreshuffle_ck(XQ, WQ, x_scale, w_scale, Out, splitK=0) -**Example:** + CK-based FP8 GEMM with pre-shuffled weight layout for improved memory access. -.. code-block:: python + :param XQ: Activation tensor ``[M, K]``, FP8. + :param WQ: Pre-shuffled weight tensor ``[N, K]``, FP8. + :param x_scale: Activation scale. + :param w_scale: Weight scale. + :param Out: Pre-allocated output tensor ``[M, N]``. + :param splitK: Split-K factor (default 0). + :returns: Output tensor ``Out``. - import torch - import aiter +.. function:: gemm_a8w8_bpreshuffle_cktile(XQ, WQ, x_scale, w_scale, out, splitK=0) - # 4096 tokens, 512 hidden dim, 8 experts, top-2 routing - tokens = 4096 - hidden = 512 - num_experts = 8 - output_dim = 2048 + CK Tile variant of FP8 GEMM with pre-shuffled weights. 
- x = torch.randn(tokens, hidden, device='cuda', dtype=torch.float16) - expert_weights = torch.randn(num_experts, hidden, output_dim, - device='cuda', dtype=torch.float16) - expert_ids = torch.randint(0, num_experts, (tokens, 2), device='cuda') + Same interface as ``gemm_a8w8_bpreshuffle_ck``. - output = aiter.grouped_gemm(x, expert_weights, expert_ids) +.. function:: gemm_a8w8_bpreshuffle_flydsl(XQ, WQ, x_scale, w_scale, Out, config) -Batched GEMM ------------- + FlyDSL variant of FP8 GEMM with pre-shuffled weights. Falls back to + ``gemm_a8w8_bpreshuffle_ck`` if no matching FlyDSL kernel is found. -.. autofunction:: aiter.batched_gemm + :param config: Dictionary with ``kernelId`` selecting the FlyDSL kernel. -Batched matrix multiplication with optimizations for AMD hardware. +.. function:: gemm_a8w8_asm(XQ, WQ, x_scale, w_scale, Out, kernelName="", bias=None, bpreshuffle=True, splitK=None) -**Parameters:** + ASM (hand-tuned assembly) FP8 GEMM. Highest performance for supported shapes. -* **a** (*torch.Tensor*) - First batch ``(batch, m, k)`` -* **b** (*torch.Tensor*) - Second batch ``(batch, k, n)`` -* **transpose_a** (*bool*, optional) - Transpose A. Default: ``False`` -* **transpose_b** (*bool*, optional) - Transpose B. Default: ``False`` + :param XQ: Activation tensor ``[M, K]``, INT8/FP8. + :param WQ: Weight tensor ``[N, K]``, shuffled layout ``(32, 16)``. + :param x_scale: Activation scale ``[M, 1]``, FP32. + :param w_scale: Weight scale ``[1, N]``, FP32. + :param Out: Pre-allocated output tensor ``[M, N]``, BF16. + :param kernelName: Specific ASM kernel name (empty string = auto). + :param bias: Optional bias tensor ``[1, N]``, FP32. + :param bpreshuffle: Whether weights are pre-shuffled (default True). + :param splitK: Split-K factor (default None = auto). + :returns: Output tensor ``Out``. 
-**Returns:** +Block-Scale FP8 GEMM +^^^^^^^^^^^^^^^^^^^^^ -* **output** (*torch.Tensor*) - Result ``(batch, m, n)`` +Block-scale variants use per-block quantization scales instead of per-tensor +scales, providing better accuracy for large matrices. -Fused GEMM Operations ---------------------- +.. function:: gemm_a8w8_blockscale_ck(XQ, WQ, x_scale, w_scale, Out) -GEMM + Bias -^^^^^^^^^^^ + CK-based block-scale FP8 GEMM. -.. autofunction:: aiter.gemm_bias +.. function:: gemm_a8w8_blockscale_cktile(XQ, WQ, x_scale, w_scale, Out, isBpreshuffled=False) -Matrix multiply with bias addition fused. + CK Tile variant of block-scale FP8 GEMM. -**Parameters:** + :param isBpreshuffled: Whether the weight tensor uses pre-shuffled layout. -* **input** (*torch.Tensor*) - Input ``(m, k)`` -* **weight** (*torch.Tensor*) - Weight ``(k, n)`` or ``(n, k)`` if transposed -* **bias** (*torch.Tensor*) - Bias ``(n,)`` -* **transpose_weight** (*bool*, optional) - Default: ``True`` +.. function:: gemm_a8w8_blockscale_bpreshuffle_ck(XQ, WQ, x_scale, w_scale, Out) -**Returns:** + CK-based block-scale FP8 GEMM with pre-shuffled weights. -* **output** (*torch.Tensor*) - ``(m, n)`` +.. function:: gemm_a8w8_blockscale_bpreshuffle_cktile(XQ, WQ, x_scale, w_scale, Out, isBpreshuffled=True) -**Example:** + CK Tile variant of block-scale FP8 GEMM with pre-shuffled weights. -.. code-block:: python +.. function:: gemm_a8w8_blockscale_bpreshuffle_asm(A, B, out, A_scale, B_scale, bias=None, splitK=None, kernelName=None, bpreshuffle=True, zero_bias_buf=None) - x = torch.randn(1024, 512, device='cuda', dtype=torch.float16) - weight = torch.randn(2048, 512, device='cuda', dtype=torch.float16) - bias = torch.randn(2048, device='cuda', dtype=torch.float16) + ASM block-scale FP8 GEMM with pre-shuffled weights. - # Fused: y = x @ weight.T + bias - output = aiter.gemm_bias(x, weight, bias, transpose_weight=True) + :param zero_bias_buf: Optional zero-initialized bias buffer ``[1, N]``, FP32. 
+ Auto-created if both ``bias`` and ``zero_bias_buf`` are None. -GEMM + GELU -^^^^^^^^^^^ +.. function:: flatmm_a8w8_blockscale_asm(XQ, WQ, x_scale, w_scale, out) -.. autofunction:: aiter.gemm_gelu + ASM block-scale FP8 flat matrix multiply. -Matrix multiply with GELU activation fused. -**Parameters:** +A16W16 (BF16/FP16) GEMM +------------------------ -* **input** (*torch.Tensor*) - Input tensor -* **weight** (*torch.Tensor*) - Weight tensor -* **bias** (*torch.Tensor*, optional) - Bias tensor +Functions in ``aiter.ops.gemm_op_a16w16``. -**Returns:** +.. function:: gemm_a16w16_asm(A, B, out, bias=None, splitK=None, kernelName=None, bpreshuffle=False) -* **output** (*torch.Tensor*) - Result with GELU applied + ASM-optimized BF16/FP16 GEMM. -GEMM + ReLU -^^^^^^^^^^^ + :param A: Activation tensor ``[M, K]``, BF16/FP16. + :param B: Weight tensor ``[N, K]``, BF16/FP16. + :param out: Pre-allocated output tensor ``[M, N]``. + :param bias: Optional bias tensor. + :param splitK: Split-K factor (default None = auto). + :param kernelName: Specific ASM kernel name. + :param bpreshuffle: Whether weights are pre-shuffled (default False). + :returns: Output tensor ``out``. -.. autofunction:: aiter.gemm_relu -Matrix multiply with ReLU activation fused. +A4W4 (FP4) GEMM +---------------- -**Parameters:** +Functions in ``aiter.ops.gemm_op_a4w4``. These operate on MXFP4 (4-bit +floating point) packed inputs where each byte holds two FP4 values. -* **input** (*torch.Tensor*) - Input tensor -* **weight** (*torch.Tensor*) - Weight tensor -* **bias** (*torch.Tensor*, optional) - Bias tensor +.. note:: -**Returns:** + A4W4 GEMM is **not supported** on gfx942 (MI300X). Supported on gfx950+. -* **output** (*torch.Tensor*) - Result with ReLU applied +.. function:: gemm_a4w4(A, B, A_scale, B_scale, bias=None, dtype=torch.bfloat16, alpha=1.0, beta=0.0, bpreshuffle=True) -CUTLASS-style GEMM ------------------- + Top-level FP4 GEMM. 
Auto-selects between ``gemm_a4w4_blockscale`` and + ``gemm_a4w4_asm`` based on tuned configuration. -.. autofunction:: aiter.cutlass_gemm + :param A: Activation tensor ``[M, K/2]``, packed FP4x2. + :param B: Weight tensor ``[N, K/2]``, packed FP4x2. + :param A_scale: Activation scale ``[M, K/32]``, E8M0 format. + :param B_scale: Weight scale ``[N, K/32]``, E8M0 format. + :param bias: Optional bias tensor ``[1, N]``, FP32. + :param dtype: Output dtype (default BF16). + :param alpha: Scalar multiplier (default 1.0). + :param beta: Accumulation scalar (default 0.0). + :param bpreshuffle: Whether weights are pre-shuffled (default True). + :returns: Output tensor ``[M, N]`` in ``dtype``. -High-performance GEMM using CUTLASS-inspired kernels for AMD. +.. function:: gemm_a4w4_asm(A, B, A_scale, B_scale, out, kernelName="", bias=None, alpha=1.0, beta=0.0, bpreshuffle=True, log2_k_split=None) -**Parameters:** + ASM-optimized FP4 GEMM. -* **a** (*torch.Tensor*) - Matrix A -* **b** (*torch.Tensor*) - Matrix B -* **alpha** (*float*, optional) - Scalar multiplier. Default: ``1.0`` -* **beta** (*float*, optional) - Scalar for accumulation. Default: ``0.0`` -* **c** (*torch.Tensor*, optional) - Accumulation matrix + :param out: Pre-allocated output tensor. Dim0 must be padded to multiples of 32. + :param log2_k_split: Log2 of the K-split factor. -**Returns:** +.. function:: gemm_a4w4_blockscale(XQ, WQ, x_scale, w_scale, Out, splitK=0) -* **output** (*torch.Tensor*) - Result: ``alpha * (A @ B) + beta * C`` + CK-based block-scale FP4 GEMM. -Sparse GEMM ------------ -.. autofunction:: aiter.sparse_gemm +Batched GEMM +------------- -Sparse matrix multiplication with various sparsity patterns. +Batched GEMM operations for processing multiple independent matrix multiplications +in a single kernel launch. 
-**Parameters:** +Batched FP8 GEMM +^^^^^^^^^^^^^^^^^ -* **input** (*torch.Tensor*) - Dense input -* **weight** (*torch.Tensor*) - Sparse weight (CSR/COO format) -* **sparsity_pattern** (*str*) - Pattern type: ``'csr'``, ``'coo'``, ``'block'`` +Functions in ``aiter.ops.batched_gemm_op_a8w8``. -**Returns:** +.. function:: batched_gemm_a8w8(XQ, WQ, x_scale, w_scale, out, bias=None, splitK=0) -* **output** (*torch.Tensor*) - Dense output + Low-level batched FP8 GEMM. Requires pre-allocated output tensor. -INT8 Quantized GEMM -------------------- + :param XQ: Batched activation tensor ``[B, M, K]``, FP8. + :param WQ: Batched weight tensor ``[B, K, N]``, FP8. + :param x_scale: Activation scale tensor. + :param w_scale: Weight scale tensor. + :param out: Pre-allocated output tensor ``[B, M, N]``. + :param bias: Optional bias tensor. + :param splitK: Split-K factor (default 0). + :returns: Output tensor ``out``. -.. autofunction:: aiter.int8_gemm +.. function:: batched_gemm_a8w8_CK(XQ, WQ, x_scale, w_scale, bias=None, dtype=torch.bfloat16, splitK=None) -INT8 quantized matrix multiplication for inference acceleration. + High-level batched FP8 GEMM with CK tuning. Auto-allocates output tensor and + selects optimal split-K from tuned configuration. -**Parameters:** + :param dtype: Output dtype (BF16 or FP16). + :returns: Output tensor ``[B, M, N]``. -* **input** (*torch.Tensor*) - Quantized input INT8 -* **weight** (*torch.Tensor*) - Quantized weight INT8 -* **input_scale** (*torch.Tensor*) - Input dequantization scale -* **weight_scale** (*torch.Tensor*) - Weight dequantization scale -* **output_dtype** (*torch.dtype*, optional) - Output type. Default: ``torch.float16`` +Batched BF16 GEMM +^^^^^^^^^^^^^^^^^^ -**Returns:** +Functions in ``aiter.ops.batched_gemm_op_bf16``. -* **output** (*torch.Tensor*) - Dequantized result +.. function:: batched_gemm_bf16(XQ, WQ, out, bias=None, splitK=0) -**Example:** + Low-level batched BF16 GEMM. -.. 
code-block:: python + :param XQ: Batched activation tensor ``[B, M, K]``, BF16. + :param WQ: Batched weight tensor ``[B, K, N]``, BF16. + :param out: Pre-allocated output tensor ``[B, M, N]``. + :param bias: Optional bias tensor. + :param splitK: Split-K factor (default 0). + :returns: Output tensor ``out``. - # Quantized matrices (simulated) - x_int8 = torch.randint(-127, 127, (1024, 512), device='cuda', dtype=torch.int8) - w_int8 = torch.randint(-127, 127, (2048, 512), device='cuda', dtype=torch.int8) +.. function:: batched_gemm_bf16_CK(XQ, WQ, bias=None, dtype=torch.bfloat16, splitK=None) - # Scales for dequantization - x_scale = torch.randn(1, device='cuda', dtype=torch.float32) - w_scale = torch.randn(1, device='cuda', dtype=torch.float32) + High-level batched BF16 GEMM with CK tuning. Auto-allocates output. - # INT8 GEMM with automatic dequantization - output = aiter.int8_gemm(x_int8, w_int8, x_scale, w_scale) + :param dtype: Output dtype (BF16 or FP16). + :returns: Output tensor ``[B, M, N]``. -Performance Characteristics ----------------------------- -.. list-table:: - :header-rows: 1 - :widths: 30 20 25 25 - - * - Operation - - Typical Speedup - - Best Use Case - - Memory Usage - * - grouped_gemm - - 3-10x vs loops - - MoE layers - - Low - * - batched_gemm - - 2-4x vs sequential - - Batched inference - - Moderate - * - gemm_bias - - 1.5-2x vs unfused - - Linear layers - - Low - * - int8_gemm - - 2-3x vs FP16 - - Quantized models - - Very Low - * - cutlass_gemm - - Best raw GEMM - - Large matrices - - Moderate - -Optimization Tips ------------------ +DeepGEMM +-------- + +Functions in ``aiter.ops.deepgemm``. DeepGEMM provides grouped GEMM operations +with explicit group layout control. + +.. function:: deepgemm(XQ, WQ, Y, group_layout, x_scale=None, w_scale=None) + + Top-level DeepGEMM entry point. Currently delegates to ``deepgemm_ck``. + + :param XQ: Activation tensor. + :param WQ: Weight tensor. + :param Y: Pre-allocated output tensor. 
+ :param group_layout: Tensor describing the group layout. + :param x_scale: Optional activation scale. + :param w_scale: Optional weight scale. + :returns: Output tensor ``Y``. -1. **Matrix Dimensions**: Multiple of 128 for best performance (MI300X) -2. **Data Layout**: Row-major preferred for AMD GPUs -3. **Precision**: FP16/BF16 recommended over FP32 -4. **Fusion**: Use fused ops (gemm_bias, gemm_gelu) when possible -5. **Batch Size**: Larger batches improve throughput +.. function:: deepgemm_ck(XQ, WQ, Y, group_layout, x_scale=None, w_scale=None) -Example: Optimal MoE Forward Pass ----------------------------------- + CK-based DeepGEMM implementation. -.. code-block:: python - import torch - import aiter +Auto-Tuned GEMM +---------------- - class OptimizedMoELayer: - def __init__(self, hidden_dim, num_experts, expert_dim): - self.hidden_dim = hidden_dim - self.num_experts = num_experts - self.expert_dim = expert_dim +Functions in ``aiter.tuned_gemm``. The auto-tuning layer selects the best +backend (ASM, CK, hipBLASLt, Triton, FlyDSL, or PyTorch) based on matrix shape +and pre-computed tuning configurations. - # Expert weights (all experts in one tensor) - self.w1 = torch.randn(num_experts, hidden_dim, expert_dim, - device='cuda', dtype=torch.float16) - self.w2 = torch.randn(num_experts, expert_dim, hidden_dim, - device='cuda', dtype=torch.float16) +.. function:: gemm_a16w16(A, B, bias=None, otype=None, scale_a=None, scale_b=None, scale_c=None) - def forward(self, x, expert_ids, routing_weights): - # x: (total_tokens, hidden_dim) - # expert_ids: (total_tokens, top_k) - # routing_weights: (total_tokens, top_k) + Top-level BF16/FP16 GEMM with automatic backend selection. - # First grouped GEMM - hidden = aiter.grouped_gemm(x, self.w1, expert_ids) + The backend is chosen from tuned CSV configurations keyed on ``(M, N, K, dtype)``. + Supported backends: ``hipblaslt``, ``asm``, ``skinny``, ``triton``, ``flydsl``, + ``torch`` (fallback). 
- # Activation - hidden = aiter.gelu(hidden) + :param A: Activation tensor ``[M, K]`` (or ``[*, M, K]`` for batched). + :param B: Weight tensor ``[N, K]``. + :param bias: Optional bias tensor. + :param otype: Output dtype (default same as input). + :param scale_a: Optional activation scale (for FP8 inputs via hipBLASLt). + :param scale_b: Optional weight scale. + :param scale_c: Optional output scale. + :returns: Output tensor ``[M, N]``. - # Second grouped GEMM - output = aiter.grouped_gemm(hidden, self.w2, expert_ids) +.. class:: TunedGemm - # Apply routing weights - output = output * routing_weights.unsqueeze(-1) + Stateful wrapper around ``gemm_a16w16`` for BF16/FP16 GEMM with optional + FP8 per-tensor quantization scaling. + + .. method:: mm(inp, weights, bias=None, otype=None, scale_a=None, scale_b=None, scale_c=None) + + Delegates to :func:`gemm_a16w16`. + + A global instance ``tgemm`` is available as ``aiter.tuned_gemm.tgemm``. + + +Backend Selection +----------------- + +AITER provides multiple backend implementations for each precision: + +.. list-table:: + :header-rows: 1 + :widths: 15 50 35 + + * - Backend + - Description + - Used By + * - **CK** + - AMD Composable Kernel library. Default for most shapes. + - A8W8, A4W4, Batched, DeepGEMM + * - **CK Tile** + - Tile-based CK variant with different tiling strategies. + - A8W8 block-scale and pre-shuffle + * - **ASM** + - Hand-tuned GFX ISA assembly. Best peak performance. + - A8W8, A16W16, A4W4, block-scale + * - **FlyDSL** + - AMD FlyDSL code-generation framework. + - A8W8 pre-shuffle, A16W16 + * - **Triton** + - OpenAI Triton for AMD GPUs. Portable, CK-free path. + - A16W16 + * - **hipBLASLt** + - AMD hipBLASLt library (via ``hipb_mm``). + - A16W16 (gfx942) + * - **PyTorch** + - ``torch.nn.functional.linear`` fallback. + - A16W16 (untuned shapes) + +For production inference, use :func:`gemm_a16w16` or :class:`TunedGemm` which +automatically select the fastest backend. 
Use precision-specific functions +(``gemm_a8w8_*``, ``gemm_a4w4_*``) when you need explicit control over the +quantization format and backend. - return output See Also -------- -* :doc:`../tutorials/moe` - MoE tutorial -* :doc:`../tutorials/quantization` - INT8 quantization guide -* :doc:`moe` - MoE-specific operations -* :doc:`../benchmarks` - GEMM benchmarks +* :doc:`moe` - MoE-specific grouped GEMM operations diff --git a/docs/api/moe.rst b/docs/api/moe.rst new file mode 100644 index 0000000000..00134c176b --- /dev/null +++ b/docs/api/moe.rst @@ -0,0 +1,92 @@ +Mixture of Experts (MoE) API +============================= + +AITER provides fused MoE kernels optimized for AMD GPUs. These are used by +inference engines (ATOM, vLLM, SGLang) for MoE model architectures such as +DeepSeek-V3/R1, Kimi K2.5, and GPT-OSS 120B. + +Gating +------ + +.. py:function:: topk_softmax(topk_weights, topk_ids, token_expert_indicies, gating_output, topk) + + Fused top-k selection and softmax for MoE gating. Computes the top-k expert + scores from ``gating_output`` and returns normalized weights. + + :param topk_weights: Output tensor for the normalized top-k weights. + :param topk_ids: Output tensor for the selected expert indices. + :param token_expert_indicies: Output tensor for token-to-expert mapping. + :param gating_output: Raw gating logits of shape ``(num_tokens, num_experts)``. + :param topk: Number of experts to select per token. + +.. py:function:: topk_sigmoid(topk_weights, topk_ids, token_expert_indicies, gating_output, topk) + + Fused top-k selection and sigmoid gating. Similar to :func:`topk_softmax` + but uses sigmoid activation instead of softmax. + +Fused MoE +---------- + +.. py:function:: fmoe(input, w1, w2, topk_weights, topk_ids, ...) + + Main fused MoE forward pass. Dispatches tokens to selected experts, applies + expert weights (w1, w2), and combines results. + + :param input: Hidden states of shape ``(num_tokens, hidden_dim)``. 
+ :param w1: First expert weight matrix. + :param w2: Second expert weight matrix. + :param topk_weights: Per-token expert weights from gating. + :param topk_ids: Per-token expert indices from gating. + +.. py:function:: fmoe_g1u1(input, gate, up, down, topk_weights, topk_ids, ...) + + Fused MoE with separate gate, up, and down projections (GLU-style). + Used by architectures that split the MoE FFN into gate/up/down matrices. + +.. py:function:: fmoe_int8_g1u0(...) + + INT8 quantized fused MoE. Applies INT8 weight-only quantization to the + expert computations. + +.. py:function:: fmoe_fp8_blockscale_g1u1(...) + + FP8 block-scale quantized fused MoE with gate/up/down projections. + Uses per-block scaling factors for FP8 computation. + +.. py:function:: fused_moe(hidden_states, w1, w2, gating_output, topk, ...) + + High-level fused MoE entry point (from ``aiter/fused_moe.py``). Combines + gating and expert computation in a single call. + + :param hidden_states: Input hidden states. + :param w1: First expert weight. + :param w2: Second expert weight. + :param gating_output: Raw gating logits. + :param topk: Number of experts per token. + +.. py:function:: fused_moe_2stages(...) + + Two-stage fused MoE: first sorts tokens by expert assignment, then runs + expert computation. Can improve memory locality for large expert counts. + +MoE Utilities +------------- + +.. py:function:: moe_align_block_size(topk_ids, block_size, num_experts) + + Align MoE dispatch block sizes to hardware-friendly boundaries. + + :param topk_ids: Expert assignment indices. + :param block_size: Target block size for alignment. + :param num_experts: Total number of experts. + +.. py:function:: moe_sorting(...) + + Sort tokens by their assigned expert. Used as a preprocessing step in + two-stage MoE execution. 
+ +Source Files +------------ + +- ``aiter/ops/moe_op.py`` -- low-level MoE operations +- ``aiter/fused_moe.py`` -- high-level fused MoE interface diff --git a/docs/api/normalization.rst b/docs/api/normalization.rst new file mode 100644 index 0000000000..63dec2985d --- /dev/null +++ b/docs/api/normalization.rst @@ -0,0 +1,73 @@ +Normalization API +================= + +AITER provides optimized normalization kernels for AMD GPUs, including fused +variants that combine normalization with residual addition, quantization, or +both. + +LayerNorm +--------- + +.. py:function:: layer_norm(input, weight, bias, eps) + + Standard layer normalization. + + :param input: Input tensor. + :param weight: Learnable scale parameter. + :param bias: Learnable bias parameter. + :param eps: Small constant for numerical stability. + :returns: Normalized tensor. + +RMSNorm +------- + +.. py:function:: rms_norm(input, weight, eps) + + Root Mean Square normalization. Simpler and faster than LayerNorm as it + does not compute mean or use a bias term. + + :param input: Input tensor. + :param weight: Learnable scale parameter. + :param eps: Small constant for numerical stability. + +.. py:function:: rmsnorm2d_fwd(input, weight, eps) + + 2D RMS normalization forward pass. Operates on 2D input tensors of shape + ``(batch, hidden_dim)``. + +Fused Variants +-------------- + +These fused kernels combine RMSNorm with other operations to reduce memory +traffic and improve throughput. + +.. py:function:: rmsnorm2d_fwd_with_add(input, residual, weight, eps) + + Fused residual addition and RMS normalization. Computes + ``rmsnorm(input + residual)`` in a single kernel. + + :param input: Input tensor. + :param residual: Residual tensor to add before normalization. + :param weight: Learnable scale parameter. + :param eps: Numerical stability constant. + +.. py:function:: rmsnorm2d_fwd_with_smoothquant(...) + + Fused RMS normalization with SmoothQuant. 
Applies per-channel smooth + quantization scales after normalization. + +.. py:function:: rmsnorm2d_fwd_with_dynamicquant(...) + + Fused RMS normalization with dynamic quantization. Computes quantization + parameters on the fly and outputs quantized activations. + +.. py:function:: add_rmsnorm_quant(...) + + Fused residual add + RMS normalization + quantization in a single kernel. + Combines three operations to minimize global memory round-trips. + +Source Files +------------ + +- ``aiter/ops/norm.py`` -- LayerNorm and general norm operations +- ``aiter/ops/rmsnorm.py`` -- RMSNorm and fused RMSNorm variants diff --git a/docs/api/operators.rst b/docs/api/operators.rst index af322091d2..95e5b9b0a9 100644 --- a/docs/api/operators.rst +++ b/docs/api/operators.rst @@ -1,240 +1,230 @@ Core Operators ============== -RMSNorm -------- +AITER provides fused, high-performance operators for LLM inference on AMD GPUs. +All operators are JIT-compiled via the ``@compile_ops`` decorator from ``aiter.jit.core`` +and target AMD Instinct GPUs (gfx942, gfx950, gfx1250). -.. autofunction:: aiter.rmsnorm +Normalization +------------- -Root Mean Square Layer Normalization, commonly used in LLMs like Llama. +**Module:** ``aiter.ops.norm``, ``aiter.ops.rmsnorm`` -**Parameters:** - -* **x** (*torch.Tensor*) - Input tensor of shape ``(..., hidden_dim)`` -* **weight** (*torch.Tensor*) - Scaling weights of shape ``(hidden_dim,)`` -* **eps** (*float*, optional) - Epsilon for numerical stability. Default: ``1e-6`` - -**Returns:** - -* **output** (*torch.Tensor*) - Normalized tensor with same shape as input - -**Example:** - -.. code-block:: python - - import torch - import aiter - - x = torch.randn(2, 1024, 4096, device='cuda', dtype=torch.float16) - weight = torch.ones(4096, device='cuda', dtype=torch.float16) - - output = aiter.rmsnorm(x, weight, eps=1e-6) - -LayerNorm ---------- - -.. autofunction:: aiter.layernorm - -Standard layer normalization with optional bias. 
- -**Parameters:** - -* **x** (*torch.Tensor*) - Input tensor ``(..., hidden_dim)`` -* **weight** (*torch.Tensor*) - Weights ``(hidden_dim,)`` -* **bias** (*torch.Tensor*, optional) - Bias ``(hidden_dim,)`` -* **eps** (*float*, optional) - Epsilon. Default: ``1e-5`` - -**Returns:** - -* **output** (*torch.Tensor*) - Normalized output - -SoftMax -------- - -.. autofunction:: aiter.softmax - -Optimized softmax operation with optional masking. - -**Parameters:** - -* **x** (*torch.Tensor*) - Input tensor -* **dim** (*int*) - Dimension to apply softmax -* **mask** (*torch.Tensor*, optional) - Attention mask - -**Returns:** - -* **output** (*torch.Tensor*) - Softmax output - -GELU ----- - -.. autofunction:: aiter.gelu - -Fast GELU activation function. - -**Parameters:** - -* **x** (*torch.Tensor*) - Input tensor -* **approximate** (*str*, optional) - Approximation method. Options: ``'none'``, ``'tanh'``. Default: ``'none'`` - -**Returns:** - -* **output** (*torch.Tensor*) - GELU output - -**Example:** - -.. code-block:: python - - import torch - import aiter - - x = torch.randn(2, 1024, 4096, device='cuda', dtype=torch.float16) - - # Exact GELU - output_exact = aiter.gelu(x) - - # Fast approximate GELU - output_approx = aiter.gelu(x, approximate='tanh') - -SwiGLU ------- - -.. autofunction:: aiter.swiglu - -Swish-Gated Linear Unit activation. - -**Parameters:** - -* **x** (*torch.Tensor*) - Input tensor ``(..., 2 * hidden_dim)`` -* **dim** (*int*, optional) - Dimension to split. Default: ``-1`` - -**Returns:** - -* **output** (*torch.Tensor*) - SwiGLU output ``(..., hidden_dim)`` +.. 
list-table:: + :header-rows: 1 + :widths: 40 60 + + * - Function + - Description + * - ``layer_norm(input, weight, bias, epsilon, x_bias)`` + - Layer normalization (CK-based) + * - ``layernorm2d_fwd(input, weight, bias, epsilon, x_bias)`` + - 2D layer norm forward + * - ``layernorm2d_fwd_with_add(out, input, residual_in, residual_out, weight, bias, epsilon)`` + - Fused residual add + layer norm + * - ``layernorm2d_fwd_with_smoothquant(out, input, xscale, yscale, weight, bias, epsilon)`` + - Fused layer norm + smooth quantization + * - ``rms_norm(input, weight, epsilon)`` + - RMS normalization (returns normalized tensor) + * - ``rmsnorm2d_fwd(input, weight, epsilon)`` + - 2D RMS norm forward (auto-selects CK or HIP backend) + * - ``rmsnorm2d_fwd_with_add(out, input, residual_in, residual_out, weight, epsilon)`` + - Fused residual add + RMS norm + * - ``rmsnorm2d_fwd_with_smoothquant(out, input, xscale, yscale, weight, epsilon)`` + - Fused RMS norm + smooth quantization + * - ``rmsnorm2d_fwd_with_dynamicquant(out, input, yscale, weight, epsilon, ...)`` + - Fused RMS norm + dynamic quantization + * - ``rmsnorm2d_fwd_with_add_dynamicquant(out, input, residual_in, residual_out, yscale, weight, epsilon, ...)`` + - Fused residual add + RMS norm + dynamic quantization + * - ``add_rmsnorm_quant(out, input, residual_in, residual_out, scale, weight, epsilon, ...)`` + - Fused add + RMS norm + quantization + * - ``rmsnorm_quant(out, input, scale, weight, epsilon, ...)`` + - RMS norm + quantization (no residual) + * - ``add_rmsnorm(out, input, residual_in, residual_out, weight, epsilon)`` + - Fused add + RMS norm (no quantization) + +Activation +---------- + +**Module:** ``aiter.ops.activation`` -Rotary Position Embedding (RoPE) +.. 
list-table:: + :header-rows: 1 + :widths: 40 60 + + * - Function + - Description + * - ``silu_and_mul(out, input)`` + - Fused SiLU activation + element-wise multiply + * - ``scaled_silu_and_mul(out, input, scale)`` + - Scaled SiLU + multiply (for FP8 output) + * - ``gelu_and_mul(out, input)`` + - Fused GELU + multiply + * - ``gelu_tanh_and_mul(out, input)`` + - Fused GELU-tanh approximation + multiply + * - ``gelu_fast(out, input)`` + - Fast GELU approximation + +All activation ops write their results into the caller-provided ``out`` tensor. + +RoPE (Rotary Position Embedding) --------------------------------- -.. autofunction:: aiter.apply_rotary_pos_emb - -Apply rotary position embeddings to query and key tensors. - -**Parameters:** - -* **q** (*torch.Tensor*) - Query tensor ``(batch, seq_len, num_heads, head_dim)`` -* **k** (*torch.Tensor*) - Key tensor ``(batch, seq_len, num_heads, head_dim)`` -* **cos** (*torch.Tensor*) - Cosine embeddings ``(seq_len, head_dim // 2)`` -* **sin** (*torch.Tensor*) - Sine embeddings ``(seq_len, head_dim // 2)`` -* **position_ids** (*torch.Tensor*, optional) - Position indices - -**Returns:** - -* **q_rot** (*torch.Tensor*) - Rotated query -* **k_rot** (*torch.Tensor*) - Rotated key - -**Example:** - -.. code-block:: python - - import torch - import aiter - - seq_len, head_dim = 1024, 64 - q = torch.randn(2, seq_len, 16, head_dim, device='cuda', dtype=torch.float16) - k = torch.randn(2, seq_len, 16, head_dim, device='cuda', dtype=torch.float16) - - # Precompute RoPE embeddings - cos, sin = aiter.precompute_rope_embeddings(seq_len, head_dim) - - # Apply rotation - q_rot, k_rot = aiter.apply_rotary_pos_emb(q, k, cos, sin) - -Sampling Operations -------------------- - -Top-K Sampling -^^^^^^^^^^^^^^ - -.. autofunction:: aiter.top_k_sampling +**Module:** ``aiter.ops.rope`` -Sample from top-k logits. - -**Parameters:** +.. 
list-table:: + :header-rows: 1 + :widths: 40 60 -* **logits** (*torch.Tensor*) - Logits ``(batch, vocab_size)`` -* **k** (*int*) - Number of top candidates -* **temperature** (*float*, optional) - Sampling temperature. Default: ``1.0`` + * - Function + - Description + * - ``rope_fwd_impl(output, input, freqs, rotate_style, reuse_freqs_front_part, nope_first)`` + - RoPE forward (single input, uncached) + * - ``rope_bwd_impl(input_grads, output_grads, freqs, rotate_style, ...)`` + - RoPE backward + * - ``rope_2c_fwd_impl(output_x, output_y, input_x, input_y, freqs, ...)`` + - RoPE forward for two inputs (Q + K) + * - ``rope_2c_bwd_impl(...)`` + - RoPE backward for two inputs -**Returns:** +Inputs use ``sbhd`` layout. ``rotate_style``: 0 = NeoX (rotate 2nd half), 1 = GPT-J (rotate odd elements). -* **tokens** (*torch.Tensor*) - Sampled token IDs ``(batch,)`` +Quantization +------------ -Top-P (Nucleus) Sampling -^^^^^^^^^^^^^^^^^^^^^^^^^ +**Module:** ``aiter.ops.quant`` -.. autofunction:: aiter.top_p_sampling +.. list-table:: + :header-rows: 1 + :widths: 40 60 -Nucleus sampling with probability threshold. + * - Function + - Description + * - ``smoothquant_fwd(out, input, x_scale, y_scale)`` + - Smooth quantization forward + * - ``moe_smoothquant_fwd(out, input, x_scale, topk_ids, y_scale)`` + - MoE-aware smooth quantization + * - ``pertoken_quant(x, scale, x_scale, scale_dtype, quant_dtype)`` + - Per-token quantization (pure PyTorch) + * - ``per_1x32_f4_quant(x, scale, quant_dtype)`` + - FP4 block quantization (1x32 group size) -**Parameters:** +**Triton quantization kernels** (``aiter.ops.triton.quant``): -* **logits** (*torch.Tensor*) - Logits ``(batch, vocab_size)`` -* **p** (*float*) - Cumulative probability threshold (0.0 to 1.0) -* **temperature** (*float*, optional) - Temperature. Default: ``1.0`` +.. 
list-table:: + :header-rows: 1 + :widths: 40 60 + + * - Function + - Description + * - ``static_per_tensor_quant_fp8_i8(qx, x_in, scale_in)`` + - Static per-tensor FP8/INT8 quantization + * - ``dynamic_per_tensor_quant_fp8_i8(qx, x_in, scale_out)`` + - Dynamic per-tensor FP8/INT8 quantization + * - ``dynamic_per_token_quant_fp8_i8(qx, x_in, scale_out)`` + - Dynamic per-token FP8/INT8 quantization + * - ``dynamic_mxfp4_quant(...)`` + - Dynamic MXFP4 quantization + +Sampling +-------- -**Returns:** +**Module:** ``aiter.ops.sample`` -* **tokens** (*torch.Tensor*) - Sampled tokens ``(batch,)`` +.. list-table:: + :header-rows: 1 + :widths: 40 60 + + * - Function + - Description + * - ``greedy_sample(out, input)`` + - Greedy (argmax) sampling + * - ``random_sample(out, input, temperatures, lambd, generator, eps)`` + - Random sampling with temperature + * - ``random_sample_outer_exponential(out, input, exponentials, temperatures, eps)`` + - Random sampling with externally generated exponentials + * - ``mixed_sample(out, input, temperature, lambd, generator, eps)`` + - Mixed greedy/random sampling (per-token temperature) + * - ``mixed_sample_outer_exponential(out, input, exponentials, temperature, eps)`` + - Mixed sampling with external exponentials + * - ``exponential(out, lambd, generator, eps)`` + - Generate exponential random variates + +KV Cache +-------- -Performance Notes ------------------ +**Module:** ``aiter.ops.cache`` -All operators are optimized for AMD GPUs: +.. 
list-table:: + :header-rows: 1 + :widths: 40 60 + + * - Function + - Description + * - ``reshape_and_cache(key, value, key_cache, value_cache, slot_mapping, kv_cache_dtype, ...)`` + - Standard KV cache update + * - ``reshape_and_cache_flash(key, value, key_cache, value_cache, slot_mapping, ...)`` + - Flash-attention-style cache update + * - ``reshape_and_cache_with_pertoken_quant(key, value, key_cache, value_cache, k_dequant_scales, v_dequant_scales, slot_mapping, asm_layout)`` + - Cache update with per-token FP8 quantization + * - ``reshape_and_cache_with_block_quant(key, value, key_cache, value_cache, k_dequant_scales, v_dequant_scales, slot_mapping, asm_layout)`` + - Cache update with block-level quantization + * - ``concat_and_cache_mla(kv_c, k_pe, kv_cache, slot_mapping, kv_cache_dtype, scale)`` + - Multi-Latent Attention (MLA) cache update + * - ``copy_blocks(key_caches, value_caches, block_mapping)`` + - Copy cache blocks (for beam search) + * - ``swap_blocks(src, dst, block_mapping)`` + - Swap cache blocks between devices + +Top-K / MoE Gating +------------------- -* **FP16/BF16 preferred**: Best performance on MI300X -* **Large batches**: Better GPU utilization -* **Fused operations**: Many ops fused into single kernels -* **In-place when possible**: Reduces memory allocations +**Module:** ``aiter.ops.topk``, ``aiter.ops.moe_op`` -Supported Data Types ---------------------- +.. 
list-table:: + :header-rows: 1 + :widths: 40 60 + + * - Function + - Description + * - ``topk_softmax(topk_weights, topk_indices, token_expert_indices, gating_output, need_renorm, ...)`` + - Fused top-k + softmax for MoE gating + * - ``topk_sigmoid(topk_weights, topk_indices, gating_output)`` + - Fused top-k + sigmoid for MoE gating + * - ``grouped_topk(gating_output, topk_weights, topk_ids, num_expert_group, topk_group, need_renorm, ...)`` + - Grouped top-k selection (e.g., DeepSeek MoE) + * - ``biased_grouped_topk(gating_output, correction_bias, topk_weights, topk_ids, num_expert_group, topk_group, need_renorm, ...)`` + - Biased grouped top-k (auto-selects HIP or fused gate backend) + * - ``moe_fused_gate(input, bias, topk_weights, topk_ids, num_expert_group, topk_group, topk, n_share_experts_fusion, ...)`` + - Fused MoE gating kernel + * - ``moe_align_block_size(topk_ids, num_experts, block_size, sorted_token_ids, experts_ids, token_nums, num_tokens_post_pad)`` + - Align MoE token assignments to block boundaries + * - ``moe_sum(input, output)`` + - Sum MoE expert outputs + +Communication +------------- + +**Module:** ``aiter.ops.custom_all_reduce`` .. 
list-table:: :header-rows: 1 - :widths: 30 25 25 20 - - * - Operator - - FP32 - - FP16 - - BF16 - * - rmsnorm - - ✓ - - ✓ (fastest) - - ✓ - * - layernorm - - ✓ - - ✓ (fastest) - - ✓ - * - gelu - - ✓ - - ✓ (fastest) - - ✓ - * - swiglu - - ✓ - - ✓ (fastest) - - ✓ - * - apply_rotary_pos_emb - - ✓ - - ✓ (fastest) - - ✓ - * - sampling ops - - ✓ - - ✓ - - ✓ + :widths: 40 60 + + * - Function + - Description + * - ``all_reduce(_fa, inp, out, use_new, open_fp8_quant, reg_inp_ptr, reg_inp_bytes)`` + - Custom all-reduce (P2P IPC-based) + * - ``reduce_scatter(_fa, inp, out, reg_ptr, reg_bytes)`` + - Reduce-scatter + * - ``fused_allreduce_rmsnorm(_fa, inp, res_inp, res_out, out, w, eps, reg_ptr, reg_bytes, use_1stage)`` + - Fused all-reduce + RMS normalization + * - ``fused_allreduce_rmsnorm_quant(_fa, inp, res_inp, res_out, out, scale_out, w, eps, reg_ptr, reg_bytes, use_1stage)`` + - Fused all-reduce + RMS norm + quantization See Also -------- -* :doc:`../tutorials/normalization` - Normalization tutorial -* :doc:`../tutorials/custom_ops` - Adding custom operators -* :doc:`gemm` - Matrix multiplication operations +* :doc:`gemm` - Matrix multiplication operations (GEMM, batched GEMM, MoE GEMM) +* :doc:`attention` - Attention operations (paged attention, flash attention, MLA) +* :doc:`../tutorials/add_new_op` - How to add a new operator diff --git a/docs/changelog.rst b/docs/changelog.rst new file mode 100644 index 0000000000..b0fdb65679 --- /dev/null +++ b/docs/changelog.rst @@ -0,0 +1,14 @@ +Changelog +========= + +For detailed release notes, see the +`GitHub Releases page <https://github.com/ROCm/aiter/releases>`_.
+ +Recent Versions +---------------- + +- **v0.1.12.post1** -- Latest stable release +- **v0.1.12-rc1** -- Release candidate for v0.1.12 +- **v0.1.9** -- Added FP8 block-scale MoE, expanded model configs +- **v0.1.7** -- Triton kernel improvements, gfx950 support +- **v0.1.5** -- Initial public release with GEMM, MoE, and attention kernels diff --git a/docs/compatibility.rst b/docs/compatibility.rst new file mode 100644 index 0000000000..585b751ab5 --- /dev/null +++ b/docs/compatibility.rst @@ -0,0 +1,73 @@ +ROCm Compatibility Matrix +========================= + +Supported GPU Architectures +---------------------------- + +.. list-table:: + :header-rows: 1 + :widths: 15 30 15 40 + + * - Architecture + - GPU + - gfx Target + - Status + * - CDNA 3 + - MI300A, MI300X, MI325X + - gfx942 + - Fully supported, pre-built wheels + * - CDNA 3.5 + - MI355X + - gfx950 + - Fully supported, pre-built wheels + * - CDNA 4 + - MI450 + - gfx1250 + - Experimental, Triton+HIP only + +ROCm Version Matrix (pre-built wheels) +---------------------------------------- + +.. list-table:: + :header-rows: 1 + :widths: 15 20 20 20 + + * - ROCm + - Python 3.10 + - Python 3.12 + - PyTorch + * - 7.2.1 + - yes + - yes + - 2.9.1 + * - 7.1.1 + - yes + - yes + - 2.10.0 + * - 7.0.2 + - yes + - yes + - 2.9.1 + +Installation +------------ + +From GitHub Release (recommended) +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +.. code-block:: bash + + pip install amd-aiter --find-links https://github.com/ROCm/aiter/releases/latest + +From source +^^^^^^^^^^^ + +.. code-block:: bash + + git clone --recursive https://github.com/ROCm/aiter.git + cd aiter && pip install -e . + +.. note:: + + Building from source requires a ROCm installation matching one of the + supported versions above, along with ``ninja`` and ``cmake``. 
diff --git a/docs/conf.py b/docs/conf.py index 67eda4a986..3c63ad2220 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -12,7 +12,17 @@ project = "AITER" copyright = "2026, AMD" author = "AMD ROCm Team" -release = "0.1.0" +# Auto-detect version from installed package metadata, falling back to setuptools_scm +try: + from importlib.metadata import version as get_version + release = get_version("amd-aiter") +except Exception: + try: + from setuptools_scm import get_version + release = get_version(root="..", relative_to=__file__) + except Exception: + release = "dev" +version = ".".join(release.split(".")[:2]) # -- General configuration --------------------------------------------------- extensions = [ diff --git a/docs/contributing.rst b/docs/contributing.rst new file mode 100644 index 0000000000..c0eaa0ce63 --- /dev/null +++ b/docs/contributing.rst @@ -0,0 +1,64 @@ +Contributing to AITER +===================== + +Setup +----- + +Clone the repository with submodules: + +.. code-block:: bash + + git clone --recursive https://github.com/ROCm/aiter.git + cd aiter + +Install development dependencies: + +.. code-block:: bash + + pip install -r requirements.txt + pip install ninja + +Running Tests +------------- + +.. code-block:: bash + + pytest tests/ + +Tests require a ROCm-capable GPU. Some tests are architecture-specific and +will be skipped automatically on unsupported hardware. + +Code Style +---------- + +AITER uses `ruff <https://docs.astral.sh/ruff/>`_ for linting and formatting: + +.. code-block:: bash + + ruff check . + ruff format . + +Run both checks before submitting a pull request. CI will reject PRs with +lint or format violations. + +Pull Request Workflow +---------------------- + +1. Create a branch from ``main``: + + .. code-block:: bash + + git checkout -b my-feature main + +2. Make your changes and add tests where applicable. + +3. Run linting, formatting, and tests locally: + + .. code-block:: bash + + ruff check . && ruff format --check . && pytest tests/ + +4.
Push your branch and open a pull request against ``main``. + +5. CI will run the ``aiter-test`` and ``triton-test`` pipelines on your PR. + All checks must pass before merge. diff --git a/docs/gemm_tuning.rst b/docs/gemm_tuning.rst new file mode 100644 index 0000000000..e8d01b8de8 --- /dev/null +++ b/docs/gemm_tuning.rst @@ -0,0 +1,91 @@ +GEMM Tuning Guide +================= + +AITER ships pre-tuned GEMM configurations for popular models. When serving a +new model or using untested shapes, you may see warnings like +``"not found tuned config"`` in server logs. This guide walks through the +tuning process. + +When to Tune +------------- + +Tune GEMM when: + +- Server logs show ``"not found tuned config"`` warnings for specific shapes. +- You are deploying a model not listed in :doc:`models`. +- You want to optimize for a specific batch size / concurrency profile. + +Step 1: Identify Missing Shapes +--------------------------------- + +Check server logs for lines like:: + + [WARNING] not found tuned config for M=128, N=4096, K=14336 + +Collect all unique ``(M, N, K)`` triples that need tuning. + +Step 2: Create an Untuned CSV +------------------------------ + +Create a CSV file listing the shapes to tune. The required columns are: + +.. list-table:: + :header-rows: 1 + :widths: 15 50 + + * - Column + - Description + * - M + - Batch dimension (number of tokens) + * - N + - Output dimension + * - K + - Input / reduction dimension + * - dtype + - Data type (e.g., ``bf16``, ``fp8``, ``a8w8``) + +Example ``untuned.csv``:: + + M,N,K,dtype + 128,4096,14336,bf16 + 256,4096,14336,bf16 + 512,14336,4096,bf16 + +**Recommended M values** for serving workloads: 1, 2, 4, 8, 16, 32, 64, 128, +256, 512, 1024, 2048, 4096. These cover typical decode (small M) and prefill +(large M) batch sizes. + +Step 3: Run the Tuner +---------------------- + +.. 
code-block:: bash + + python3 gradlib/gradlib/gemm_tuner.py \ + --tuned_file output.csv \ + --input_file untuned.csv + +The tuner benchmarks multiple kernel implementations (ASM, CK, Triton) for +each shape and records the fastest configuration. This process runs on GPU +and may take several minutes per shape. + +Step 4: Register the Tuned Config +----------------------------------- + +Copy the output CSV to the model configs directory: + +.. code-block:: bash + + cp output.csv aiter/configs/model_configs/_tuned_gemm.csv + +The inference engine will automatically load tuned configs from this directory +at startup. + +Tips +---- + +- Run tuning on the same GPU architecture you will deploy on. Tuned configs + are architecture-specific. +- For MoE models, the expert GEMM shapes may differ from the attention GEMM + shapes. Make sure to include both. +- Re-tune when upgrading ROCm or AITER versions, as kernel implementations + may change. diff --git a/docs/index.rst b/docs/index.rst index f68ffc3441..1878bbb52a 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -3,7 +3,7 @@ AITER Documentation =================== -**AITER** (AMD Inference and Training Enhanced Repository) is AMD's high-performance AI operator library for ROCm, providing optimized kernels for inference and training workloads. +**AITER** (AI Tensor Engine for ROCm) is AMD's high-performance AI operator library for ROCm, providing optimized kernels for inference and training workloads. .. image:: https://img.shields.io/badge/ROCm-Compatible-red :target: https://rocm.docs.amd.com/ @@ -16,10 +16,10 @@ AITER Documentation Why AITER? 
---------- -* **High Performance**: Optimized kernels using Triton, Composable Kernel (CK), and hand-written assembly +* **High Performance**: Optimized kernels using Triton, Composable Kernel (CK), ASM, and FlyDSL * **Comprehensive**: Supports both inference and training workloads * **Flexible**: C++ and Python APIs for easy integration -* **AMD Optimized**: Built specifically for AMD GPUs and the ROCm platform +* **AMD Optimized**: Built specifically for AMD Instinct GPUs and the ROCm platform Quick Start ----------- @@ -29,12 +29,13 @@ Installation .. code-block:: bash - pip install aiter # Coming soon! + # From GitHub Release (recommended) + pip install amd-aiter --find-links https://github.com/ROCm/aiter/releases/latest - # For now, install from source: + # From source git clone --recursive https://github.com/ROCm/aiter.git cd aiter - python3 setup.py develop + pip install -e . Quick Example ^^^^^^^^^^^^^ @@ -44,8 +45,13 @@ Quick Example import aiter import torch - # Example: Flash Attention - # TODO: Add actual example code + # RMS Normalization + x = torch.randn(2, 4096, dtype=torch.bfloat16, device="cuda") + weight = torch.ones(4096, dtype=torch.bfloat16, device="cuda") + out = aiter.rms_norm(x, weight, 1e-6) + + # Fused MoE + # See API Reference for full function signatures Core Features ------------- @@ -53,72 +59,78 @@ Core Features Attention Kernels ^^^^^^^^^^^^^^^^^ -* **Multi-Head Attention (MHA)**: Standard attention with optimized implementations -* **Multi-Latent Attention (MLA)**: DeepSeek-style latent attention -* **Paged Attention**: Efficient KV-cache management for serving +* **Multi-Head Attention (MHA)**: Flash attention forward and backward passes +* **Multi-Latent Attention (MLA)**: DeepSeek-style latent attention for decode and prefill +* **Paged Attention**: Efficient KV-cache management for serving (v1, v2, ragged, ASM) GEMM Operations ^^^^^^^^^^^^^^^ -* **Mixed Precision GEMM**: FP16, BF16, FP8, INT4 support -* **Tuned GEMM**: 
Pre-tuned configurations for common shapes -* **Fused Operations**: GEMM with activation fusion +* **FP8 GEMM (A8W8)**: Multiple backends -- CK, CK Tile, ASM, FlyDSL +* **BF16/FP16 GEMM (A16W16)**: ASM-optimized with auto-tuning +* **FP4 GEMM (A4W4)**: FP4 precision with block-scale support +* **Batched GEMM**: FP8 and BF16 batched operations +* **DeepGEMM**: Specialized deep GEMM kernels +* **Auto-Tuned GEMM**: Pre-tuned configurations for common model shapes Mixture of Experts (MoE) ^^^^^^^^^^^^^^^^^^^^^^^^^ -* **Fused MoE**: Optimized expert routing and computation -* **Multiple Routing**: Support for various routing strategies -* **Quantized Experts**: FP8 and INT4 expert weights +* **Fused MoE**: Optimized expert routing and computation (``fmoe``, ``fmoe_g1u1``) +* **Quantized MoE**: FP8 block-scale and INT8 expert weights +* **2-Stage MoE**: Sorting + compute pipeline for large expert counts -Normalization -^^^^^^^^^^^^^ +Normalization & Activation +^^^^^^^^^^^^^^^^^^^^^^^^^^ -* **RMSNorm**: Root mean square normalization -* **LayerNorm**: Standard layer normalization -* **Fused Variants**: Combined with other operations +* **RMSNorm / LayerNorm**: With fused variants (residual add, quantization) +* **Activation**: SiLU, GELU, GELU-tanh (fused with multiply) Other Operators ^^^^^^^^^^^^^^^ -* **RoPE**: Rotary position embeddings -* **Quantization**: BF16/FP16 → FP8/INT4 conversion -* **Element-wise**: Optimized basic operations -* **Communication**: AllReduce and collective operations via Triton/Iris +* **RoPE**: Rotary position embeddings (forward, backward, cached) +* **Quantization**: Per-token, per-tensor, per-group, FP4/FP8 conversion +* **KV Cache**: reshape_and_cache with optional quantization +* **Sampling**: Greedy, random, mixed sampling kernels +* **Communication**: Custom AllReduce, fused AllReduce+RMSNorm+Quant GPU Support ----------- -AITER supports AMD GPUs with the following architectures: - .. 
list-table:: :header-rows: 1 - :widths: 20 20 30 30 + :widths: 20 15 25 20 20 * - Architecture - gfx Target - - Example GPUs + - GPUs - ROCm Version - * - CDNA 2 - - gfx90a - - MI210, MI250, MI250X - - ROCm 5.0+ + - Status * - CDNA 3 - gfx942 - - MI300A, MI300X - - ROCm 6.0+ + - MI300A, MI300X, MI325X + - ROCm 7.0+ + - Fully supported * - CDNA 3.5 - gfx950 - - MI350X (upcoming) - - ROCm 6.3+ + - MI355X + - ROCm 7.0+ + - Fully supported + * - CDNA 4 + - gfx1250 + - MI450 + - ROCm 7.2+ + - Experimental Quick Links ----------- -* 🚀 :doc:`quickstart` - Get started in 5 minutes -* 📖 :doc:`tutorials/add_new_op` - **How to add a new operator** (step-by-step) -* 🔧 :doc:`api/attention` - Flash Attention API -* 💡 :doc:`tutorials/basic_usage` - Basic usage examples +* :doc:`quickstart` - Get started in 5 minutes +* :doc:`compatibility` - ROCm version matrix and installation options +* :doc:`models` - Supported model architectures +* :doc:`tutorials/add_new_op` - How to add a new operator +* :doc:`gemm_tuning` - GEMM performance tuning guide Table of Contents ----------------- @@ -129,7 +141,8 @@ Table of Contents installation quickstart - tutorials/index + compatibility + models .. toctree:: :maxdepth: 2 @@ -143,12 +156,12 @@ Table of Contents .. toctree:: :maxdepth: 2 - :caption: Advanced Topics + :caption: Guides - performance/benchmarks - performance/profiling + tutorials/index + gemm_tuning advanced/triton_kernels - advanced/ck_integration + performance/benchmarks .. toctree:: :maxdepth: 1 diff --git a/docs/models.rst b/docs/models.rst new file mode 100644 index 0000000000..f77c06482a --- /dev/null +++ b/docs/models.rst @@ -0,0 +1,51 @@ +Supported Models +================ + +AITER provides optimized kernels used by inference engines (ATOM, vLLM, SGLang) +for a variety of model architectures. The table below lists validated model +families and the key AITER operations they use. + +Model Matrix +------------ + +.. 
list-table:: + :header-rows: 1 + :widths: 18 22 30 30 + + * - Model Family + - Architecture + - Key AITER Ops + - Tuned GEMM Configs + * - DeepSeek-V3/R1 + - MoE (MLA + FusedMoE) + - MLA decode/prefill, fused_moe, paged_attention + - ``dsv3_bf16_tuned_gemm.csv`` + * - Kimi K2.5 + - MoE + - fused_moe, paged_attention + - ``kimik2_bf16_tuned_gemm.csv`` + * - GLM-5 + - Dense + - GEMM (FP8 blockscale), paged_attention + - ``glm5_a8w8_blockscale_bpreshuffle_tuned_gemm.csv`` + * - GPT-OSS 120B + - MoE + - fused_moe, paged_attention + - ``gptoss_bf16_tuned_gemm.csv`` + * - Qwen3/3.5 + - MoE + - fused_moe, paged_attention + - ``a8w8_blockscale_tuned_gemm_qwen3_5_397b_a13b.csv`` + * - MiniMax-M2.5 + - MoE + - fused_moe, paged_attention + - (uses default configs) + * - Llama 3.x + - Dense + - GEMM, flash attention, paged_attention + - (uses default configs) + +.. note:: + + Tuned GEMM configs are stored in ``aiter/configs/model_configs/``. To tune + GEMM for a new model, see the :doc:`gemm_tuning` guide. diff --git a/docs/performance/benchmarks.rst b/docs/performance/benchmarks.rst new file mode 100644 index 0000000000..f0eb68ec82 --- /dev/null +++ b/docs/performance/benchmarks.rst @@ -0,0 +1,32 @@ +Performance Benchmarks +====================== + +Up-to-date performance results for AITER kernels on AMD Instinct GPUs are +published on the `AMD AI Frameworks Performance Dashboard +`_. + +The dashboard includes: + +- Operator-level throughput comparisons (GEMM, MoE, Attention, Norm) +- End-to-end model serving throughput and latency +- Cross-platform comparisons (MI300X, MI325X, MI355X vs. NVIDIA B200, B300) + +Running Kernel Benchmarks Locally +---------------------------------- + +AITER includes kernel-level benchmarks in its test suite. To run them: + +.. code-block:: bash + + pytest tests/ -k "benchmark" + +For GEMM-specific benchmarking, use the GEMM tuner in benchmark mode: + +.. 
code-block:: bash + + python3 gradlib/gradlib/gemm_tuner.py \ + --tuned_file results.csv \ + --input_file shapes.csv + +See :doc:`../gemm_tuning` for details on shape file format and tuning +workflow. diff --git a/docs/tutorials/add_new_op.rst b/docs/tutorials/add_new_op.rst index bd0c1b80d0..f0ffeabc05 100644 --- a/docs/tutorials/add_new_op.rst +++ b/docs/tutorials/add_new_op.rst @@ -1,492 +1,282 @@ How to Add a New Operator ========================== -This tutorial shows you how to add a custom operator to AITER. +This tutorial shows how to add a custom operator to AITER using the JIT +compilation system. AITER kernels are written in HIP C++ or Triton and are +JIT-compiled at first use via ``ninja``. Overview -------- -Adding a new operator involves: +1. Write the kernel (HIP C++ in ``csrc/`` or Triton in ``aiter/ops/triton/``) +2. Register the build config in ``aiter/jit/optCompilerConfig.json`` +3. Create the Python op in ``aiter/ops/`` +4. Provide a fake-tensor implementation for ``torch.compile`` +5. Add tests in ``op_tests/`` -1. **Define the operator interface** (Python) -2. **Implement the kernel** (ROCm/HIP C++) -3. **Create Python bindings** (PyBind11) -4. **Add tests** -5. **Register the operator** - -Step 1: Define the Operator Interface --------------------------------------- - -Create your operator's Python interface in ``aiter/ops/``: +Option A: HIP C++ Kernel +-------------------------- -.. code-block:: python +Step 1: Write the HIP Kernel +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - # aiter/ops/my_custom_op.py - import torch - from typing import Optional - - def my_custom_op( - input: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - activation: str = "gelu" - ) -> torch.Tensor: - """ - Custom operator that does something awesome. 
- - Args: - input: Input tensor (batch, seq_len, hidden_dim) - weight: Weight tensor (hidden_dim, output_dim) - bias: Optional bias tensor (output_dim,) - activation: Activation function ('gelu', 'relu', 'none') - - Returns: - Output tensor (batch, seq_len, output_dim) - """ - # Import the C++ extension - from aiter._C import my_custom_op_impl - - # Input validation - assert input.is_cuda, "Input must be on CUDA device" - assert input.dtype in [torch.float16, torch.bfloat16], \ - "Only FP16/BF16 supported" - - # Call C++ implementation - return my_custom_op_impl(input, weight, bias, activation) - -Step 2: Implement the ROCm Kernel ----------------------------------- - -Create the kernel implementation in ``csrc/``: +Create a kernel file in ``csrc/kernels/``. AITER targets AMD GPUs, so use HIP +APIs and ``__hip_bfloat16`` (not ``__nv_bfloat16``). Source files use the +``.cu`` extension but are compiled with ``hipcc``. .. code-block:: cpp - // csrc/my_custom_op.hip + // csrc/kernels/my_op_kernels.cu #include - #include - - // Kernel implementation - template - __global__ void my_custom_kernel( - const T* input, - const T* weight, - const T* bias, - T* output, - int batch_size, - int seq_len, - int hidden_dim, - int output_dim + #include + #include + + template + __global__ void my_op_kernel( + const T* __restrict__ input, + T* __restrict__ output, + int n ) { int idx = blockIdx.x * blockDim.x + threadIdx.x; - int total_elements = batch_size * seq_len * output_dim; - - if (idx < total_elements) { - int b = idx / (seq_len * output_dim); - int s = (idx / output_dim) % seq_len; - int o = idx % output_dim; - - // Your computation here - T sum = 0; - for (int h = 0; h < hidden_dim; h++) { - int input_idx = b * seq_len * hidden_dim + s * hidden_dim + h; - int weight_idx = h * output_dim + o; - sum += input[input_idx] * weight[weight_idx]; - } - - if (bias != nullptr) { - sum += bias[o]; - } - - // Apply activation - // (GELU, ReLU, etc.) 
- output[idx] = sum; + if (idx < n) { + output[idx] = input[idx]; // your computation here } } - // Host function - torch::Tensor my_custom_op_cuda( - torch::Tensor input, - torch::Tensor weight, - torch::Tensor bias, - std::string activation + void launch_my_op( + const void* input, void* output, int n, hipStream_t stream ) { - // Get dimensions - auto batch_size = input.size(0); - auto seq_len = input.size(1); - auto hidden_dim = input.size(2); - auto output_dim = weight.size(1); - - // Allocate output - auto output = torch::empty( - {batch_size, seq_len, output_dim}, - input.options() - ); - - // Launch kernel - int total_elements = batch_size * seq_len * output_dim; int threads = 256; - int blocks = (total_elements + threads - 1) / threads; - - if (input.dtype() == torch::kFloat16) { - my_custom_kernel<__half><<>>( - reinterpret_cast<__half*>(input.data_ptr()), - reinterpret_cast<__half*>(weight.data_ptr()), - bias.defined() ? reinterpret_cast<__half*>(bias.data_ptr()) : nullptr, - reinterpret_cast<__half*>(output.data_ptr()), - batch_size, seq_len, hidden_dim, output_dim - ); - } else { - // BF16 case - my_custom_kernel<__nv_bfloat16><<>>( - reinterpret_cast<__nv_bfloat16*>(input.data_ptr()), - reinterpret_cast<__nv_bfloat16*>(weight.data_ptr()), - bias.defined() ? reinterpret_cast<__nv_bfloat16*>(bias.data_ptr()) : nullptr, - reinterpret_cast<__nv_bfloat16*>(output.data_ptr()), - batch_size, seq_len, hidden_dim, output_dim - ); - } - - return output; + int blocks = (n + threads - 1) / threads; + my_op_kernel<__half><<>>( + static_cast(input), + static_cast<__half*>(output), + n + ); } -Step 3: Create Python Bindings -------------------------------- +GPU architecture targets are ``gfx942`` (MI300X), ``gfx950`` (MI355X), and +``gfx1250`` (MI450). The JIT system detects the current GPU and compiles for +the correct target automatically. 
-Add PyBind11 bindings in ``csrc/my_custom_op_bindings.cpp``: +Step 2: Write the PyBind Interface +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Create a pybind wrapper in ``csrc/pybind/``: .. code-block:: cpp + // csrc/pybind/my_op_pybind.cu #include - // Forward declare CUDA function - torch::Tensor my_custom_op_cuda( - torch::Tensor input, - torch::Tensor weight, - torch::Tensor bias, - std::string activation - ); - - // Wrapper for Python - torch::Tensor my_custom_op_impl( - torch::Tensor input, - torch::Tensor weight, - torch::Tensor bias, - std::string activation - ) { - TORCH_CHECK(input.is_cuda(), "Input must be CUDA tensor"); - return my_custom_op_cuda(input, weight, bias, activation); + // Forward declaration + void launch_my_op(const void* input, void* output, int n, hipStream_t stream); + + void my_op_fwd(torch::Tensor input, torch::Tensor output) { + int n = input.numel(); + auto stream = at::cuda::getCurrentHIPStream().stream(); + launch_my_op(input.data_ptr(), output.data_ptr(), n, stream); } PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("my_custom_op_impl", &my_custom_op_impl, - "My custom operator (CUDA)", - py::arg("input"), - py::arg("weight"), - py::arg("bias"), - py::arg("activation")); + m.def("my_op_fwd", &my_op_fwd, "My custom op forward"); } -Step 4: Update Build Configuration ------------------------------------ +Step 3: Register the Build Config +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +Add an entry to ``aiter/jit/optCompilerConfig.json``: + +.. code-block:: json + + { + "module_my_op": { + "srcs": [ + "f'{AITER_CSRC_DIR}/pybind/my_op_pybind.cu'", + "f'{AITER_CSRC_DIR}/kernels/my_op_kernels.cu'" + ], + "flags_extra_cc": [], + "flags_extra_hip": [], + "extra_ldflags": "None", + "extra_include": [], + "verbose": "False", + "blob_gen_cmd": "''" + } + } -Add your operator to ``setup.py``: +The ``srcs`` entries are f-string expressions evaluated at build time. +``AITER_CSRC_DIR`` points to the ``csrc/`` directory. -.. 
code-block:: python +Step 4: Create the Python Op +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - # setup.py - from setuptools import setup - from torch.utils.cpp_extension import BuildExtension, CUDAExtension - - setup( - name='aiter', - ext_modules=[ - CUDAExtension( - name='aiter._C', - sources=[ - 'csrc/my_custom_op.hip', - 'csrc/my_custom_op_bindings.cpp', - # ... other sources - ], - extra_compile_args={ - 'cxx': ['-O3', '-std=c++17'], - 'nvcc': [ - '-O3', - '--use_fast_math', - '-gencode', 'arch=compute_90a,code=sm_90a', # MI250X - '-gencode', 'arch=compute_942,code=sm_942', # MI300X - ] - } - ), - ], - cmdclass={'build_ext': BuildExtension} - ) - -Step 5: Add Tests ------------------ - -Create tests in ``tests/test_my_custom_op.py``: +Create ``aiter/ops/my_op.py`` using the ``@compile_ops`` decorator. This +decorator handles JIT compilation, module caching, and ``torch.compile`` +registration. .. code-block:: python + # aiter/ops/my_op.py import torch - import pytest - from aiter.ops import my_custom_op - - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) - @pytest.mark.parametrize("batch_size", [1, 4, 16]) - @pytest.mark.parametrize("seq_len", [128, 512, 2048]) - def test_my_custom_op_correctness(dtype, batch_size, seq_len): - hidden_dim = 512 - output_dim = 2048 - - # Create inputs - input = torch.randn(batch_size, seq_len, hidden_dim, - device='cuda', dtype=dtype) - weight = torch.randn(hidden_dim, output_dim, - device='cuda', dtype=dtype) - bias = torch.randn(output_dim, device='cuda', dtype=dtype) - - # Run custom op - output = my_custom_op(input, weight, bias, activation='gelu') - - # Reference implementation (PyTorch) - ref_output = torch.matmul(input, weight) - if bias is not None: - ref_output = ref_output + bias - ref_output = torch.nn.functional.gelu(ref_output) - - # Check correctness - torch.testing.assert_close( - output, ref_output, - rtol=1e-2, atol=1e-2 # FP16/BF16 tolerance - ) + from torch import Tensor + from ..jit.core import 
compile_ops - def test_my_custom_op_performance(): - batch_size, seq_len = 16, 2048 - hidden_dim, output_dim = 4096, 4096 - - input = torch.randn(batch_size, seq_len, hidden_dim, - device='cuda', dtype=torch.float16) - weight = torch.randn(hidden_dim, output_dim, - device='cuda', dtype=torch.float16) - bias = torch.randn(output_dim, device='cuda', dtype=torch.float16) - - # Warmup - for _ in range(10): - _ = my_custom_op(input, weight, bias) - torch.cuda.synchronize() - - # Benchmark - import time - start = time.time() - for _ in range(100): - output = my_custom_op(input, weight, bias) - torch.cuda.synchronize() - elapsed = time.time() - start - - print(f"Average time: {elapsed/100*1000:.2f} ms") - print(f"Throughput: {batch_size*seq_len*100/elapsed:.2f} tokens/sec") - -Step 6: Build and Install --------------------------- - -Build your extension: - -.. code-block:: bash - # Clean build - python setup.py clean - rm -rf build/ + def gen_my_op_fake(input: Tensor) -> Tensor: + """Fake tensor impl for torch.compile tracing.""" + return torch.empty_like(input) - # Build and install - python setup.py develop - # Or for production - python setup.py install + @compile_ops("module_my_op", gen_fake=gen_my_op_fake) + def my_op_fwd(input: Tensor) -> Tensor: + """My custom operator.""" + ... -Step 7: Register in Main Module --------------------------------- +Key points: -Add to ``aiter/__init__.py``: +- The first argument to ``@compile_ops`` is the module name matching the key + in ``optCompilerConfig.json``. +- The function body is ``...`` (ellipsis). The decorator replaces it with the + JIT-compiled C++ implementation at runtime. +- The function name must match the pybind function name. Use ``fc_name`` if + they differ: ``@compile_ops("module_my_op", fc_name="my_op_fwd")``. +- The ``gen_fake`` callable returns tensors with the correct shape/dtype for + ``torch.compile`` tracing without running the real kernel. -.. 
code-block:: python - - # aiter/__init__.py - from aiter.ops.my_custom_op import my_custom_op - - __all__ = [ - 'my_custom_op', - # ... other exports - ] - -Advanced: Optimizations ------------------------ - -Use CK (Composable Kernel) for Better Performance -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -.. code-block:: cpp +Option B: Triton Kernel +-------------------------- - #include "ck/tensor_operation/gpu/device/device_gemm.hpp" +Triton kernels live under ``aiter/ops/triton/`` and do not need +``optCompilerConfig.json`` entries. They are compiled by the Triton JIT +compiler directly. - // Use CK's optimized GEMM - using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemm< - /* ... template parameters ... */ - >; +Step 1: Write the Triton Kernel +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Use Triton for Easier Kernel Development -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Create the kernel in ``aiter/ops/triton/_triton_kernels/``: .. code-block:: python + # aiter/ops/triton/_triton_kernels/my_triton_op.py import triton import triton.language as tl @triton.jit - def my_custom_kernel( - input_ptr, weight_ptr, output_ptr, - M, N, K, - BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr + def _my_triton_kernel( + input_ptr, output_ptr, + n_elements, + BLOCK_SIZE: tl.constexpr, ): - # Triton kernel implementation - # (Much easier than raw HIP/CUDA!) - pass - -Common Patterns ---------------- - -Pattern 1: Fused Operations -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -Combine multiple ops into one kernel: + pid = tl.program_id(0) + offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask = offsets < n_elements -.. code-block:: python - - def fused_linear_gelu(input, weight, bias): - """ - Fuses: output = GELU(input @ weight + bias) - Faster than separate ops! 
- """ - pass + x = tl.load(input_ptr + offsets, mask=mask) + # Your computation here + tl.store(output_ptr + offsets, x, mask=mask) -Pattern 2: In-Place Operations -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Step 2: Write the Python Wrapper +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -Modify tensors in-place to save memory: +Create the wrapper in ``aiter/ops/triton/``: .. code-block:: python - def inplace_rmsnorm_(input, weight, eps=1e-6): - """ - In-place RMSNorm (modifies input) - Note the trailing underscore! - """ - pass + # aiter/ops/triton/my_triton_op.py + import torch + import triton + from aiter.ops.triton._triton_kernels.my_triton_op import _my_triton_kernel + + def my_triton_op(output: torch.Tensor, input: torch.Tensor): + n_elements = input.numel() + BLOCK_SIZE = triton.next_power_of_2(min(n_elements, 1024)) + grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),) + _my_triton_kernel[grid]( + input, output, n_elements, BLOCK_SIZE=BLOCK_SIZE, + ) -Pattern 3: Autograd Support -^^^^^^^^^^^^^^^^^^^^^^^^^^^^ +Adding Tests +------------ -Add backward pass for training: +Create a test file in ``op_tests/``. Follow the existing pattern using +``aiter.test_common``: .. code-block:: python - class MyCustomOpFunction(torch.autograd.Function): - @staticmethod - def forward(ctx, input, weight, bias): - ctx.save_for_backward(input, weight, bias) - return my_custom_op_impl(input, weight, bias) - - @staticmethod - def backward(ctx, grad_output): - input, weight, bias = ctx.saved_tensors - # Compute gradients - grad_input = ... - grad_weight = ... - grad_bias = ... - return grad_input, grad_weight, grad_bias - -Best Practices --------------- - -1. **Start Simple**: Get it working first, optimize later -2. **Test Correctness**: Always compare with PyTorch reference -3. **Profile First**: Use ``rocprof`` to find bottlenecks -4. **Use CK/Triton**: Don't write raw kernels unless necessary -5. **Document Everything**: Add docstrings and comments -6. 
**Add Type Hints**: Makes the API clearer -7. **Handle Edge Cases**: Check for invalid inputs - -Debugging Tips --------------- + # op_tests/test_my_op.py + import torch + import aiter + from aiter.test_common import checkAllclose, benchmark -Print Kernel Launches -^^^^^^^^^^^^^^^^^^^^^ + @benchmark() + def test_my_op(m, n, dtype): + ret = {} + input = torch.randn(m, n, dtype=dtype, device="cuda") -.. code-block:: bash + # Reference (PyTorch) + ref_output = input.clone() # replace with actual reference - export HIP_VISIBLE_DEVICES=0 - export AMD_LOG_LEVEL=3 # Verbose logging + # AITER op + output = torch.empty_like(input) + aiter.my_op_fwd(output, input) -Check for Memory Errors -^^^^^^^^^^^^^^^^^^^^^^^ + err = checkAllclose(ref_output, output) + ret["M"] = m + ret["N"] = n + ret["err"] = err + return ret -.. code-block:: bash + if __name__ == "__main__": + for dtype in [torch.float16, torch.bfloat16]: + for m in [1, 32, 512]: + test_my_op(m, 4096, dtype) - # Use compute-sanitizer (if available) - rocm-compute-sanitizer python test_my_op.py - -Profile Your Operator -^^^^^^^^^^^^^^^^^^^^^ +Run with: .. code-block:: bash - rocprof --stats python benchmark_my_op.py + python op_tests/test_my_op.py -Example: Complete RMSNorm Implementation ------------------------------------------ +torch.compile Compatibility +---------------------------- -Here's a complete example you can use as a template: +The ``gen_fake`` function passed to ``@compile_ops`` is registered as the +fake-tensor implementation via ``torch_compile_guard`` in +``aiter/jit/utils/torch_guard.py``. This allows ``torch.compile`` to trace +through the op without executing the real kernel. -**Python Interface** (``aiter/ops/rmsnorm.py``): +For HIP ops, the decorator handles this automatically. For Triton ops, if you +need ``torch.compile`` support, register the op manually: .. 
code-block:: python import torch - from aiter._C import rmsnorm_forward - - def rmsnorm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor: - """ - Root Mean Square Layer Normalization. - - Args: - x: Input tensor (..., hidden_dim) - weight: Scaling weights (hidden_dim,) - eps: Epsilon for numerical stability - Returns: - Normalized tensor with same shape as input - """ - assert x.is_cuda and weight.is_cuda - assert x.dtype in [torch.float16, torch.bfloat16] - return rmsnorm_forward(x, weight, eps) + @torch.library.custom_op("aiter::my_triton_op", mutates_args=["output"]) + def my_triton_op(output: torch.Tensor, input: torch.Tensor) -> None: + # call the Triton kernel + ... -**See Full Code**: Check ``csrc/`` directory for complete implementations! + @my_triton_op.register_fake + def _(output: torch.Tensor, input: torch.Tensor) -> None: + pass # output is mutated in-place, nothing to return -Next Steps ----------- - -* :doc:`../api/operators` - See existing operator implementations -* :doc:`../benchmarks` - Learn how to benchmark your operator -* :doc:`profiling` - Profile and optimize performance - -Contributing ------------- - -Want to contribute your operator to AITER? +Best Practices +-------------- -1. Follow the coding style -2. Add comprehensive tests -3. Benchmark vs existing solutions -4. Submit a PR with clear description +1. **Match existing patterns.** Study ``aiter/ops/activation.py`` (simple HIP + ops) or ``aiter/ops/triton/quant/quant.py`` (Triton ops) as templates. +2. **Test correctness first.** Compare against a PyTorch reference + implementation with ``checkAllclose``. +3. **Use in-place output tensors.** Most AITER ops take a pre-allocated ``out`` + tensor as the first argument and return ``None``. +4. **Profile with rocm-trace-lite.** Measure kernel duration and memory + bandwidth to verify performance. +5. **Run ruff and pytest before committing.** Lint and test locally before + pushing. 
+ +See Also +-------- -See ``CONTRIBUTING.md`` for details! +* :doc:`../api/operators` - Existing operator reference +* :doc:`../api/gemm` - GEMM operator reference From 70c9dd2e831348ce3406e283c467b95d1b81c442 Mon Sep 17 00:00:00 2001 From: sunway513 Date: Sun, 12 Apr 2026 19:27:01 +0000 Subject: [PATCH 5/8] tests: add documentation regression test suite MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 5 test categories to prevent doc rot: 1. API signature consistency — every autofunction/autoclass in RST must be importable 2. Code example syntax — Python code blocks must parse without SyntaxError 3. Sphinx build — build with -W (warnings as errors), catch broken refs 4. Version consistency — conf.py must use auto-detection, no hardcoded version 5. RST structure — no orphan API pages, no CUDA references in ROCm docs Also adds test_docs.py to docs.yml CI trigger paths. --- .github/workflows/docs.yml | 6 + tests/test_docs.py | 494 +++++++++++++++++++++++++++++++++++++ 2 files changed, 500 insertions(+) create mode 100644 tests/test_docs.py diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index f50938a3c4..829e83bb9d 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -15,6 +15,7 @@ on: - main paths: - 'docs/**' + - 'tests/test_docs.py' - '.github/workflows/docs.yml' workflow_dispatch: @@ -51,6 +52,11 @@ jobs: cd docs sphinx-build -W --keep-going -b html . _build/html + - name: Run documentation tests + run: | + pip install pytest + pytest tests/test_docs.py -v --tb=short -x # not piped to tail: without pipefail the pipe would mask pytest failures + - name: Check for broken links run: | cd docs diff --git a/tests/test_docs.py b/tests/test_docs.py new file mode 100644 index 0000000000..6b2cbdb9e7 --- /dev/null +++ b/tests/test_docs.py @@ -0,0 +1,494 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. + +"""Documentation regression tests. 
+ +These tests guard against documentation rot by verifying: +1. Every function/class documented in RST files actually exists in the codebase +2. Python code examples in docs are syntactically valid +3. Sphinx can build without warnings +4. Documentation version matches the package version + +Run with: pytest tests/test_docs.py -v +No GPU required — these tests run on CPU. +""" + +import ast +import importlib +import os +import re +import subprocess +import sys +from pathlib import Path + +import pytest + +DOCS_DIR = Path(__file__).parent.parent / "docs" +API_DIR = DOCS_DIR / "api" +TUTORIAL_DIR = DOCS_DIR / "tutorials" + + +# --------------------------------------------------------------------------- +# Test 1: API Signature Consistency +# Every function/class referenced in docs/api/*.rst must be importable. +# --------------------------------------------------------------------------- + +def _extract_autofunction_refs(rst_path: Path) -> list[str]: + """Extract all '.. autofunction:: X' and '.. 
autoclass:: X' from an RST file.""" + refs = [] + pattern = re.compile(r"\.\.\s+auto(?:function|class)::\s+(\S+)") + for line in rst_path.read_text().splitlines(): + m = pattern.search(line) + if m: + refs.append(m.group(1)) + return refs + + +def _extract_module_refs(rst_path: Path) -> list[str]: + """Extract module paths referenced as ``aiter.ops.xxx`` or ``aiter.xxx``.""" + refs = [] + pattern = re.compile(r"``(aiter(?:\.[a-zA-Z_][a-zA-Z0-9_]*)*)``") + for line in rst_path.read_text().splitlines(): + for m in pattern.finditer(line): + ref = m.group(1) + # Only include module-level references (at least aiter.X) + if ref.count(".") >= 1: + refs.append(ref) + return list(set(refs)) + + +def _extract_function_names_from_rst(rst_path: Path) -> list[str]: + """Extract function names documented with ``func_name(`` pattern in RST files.""" + refs = [] + # Match patterns like: - ``gemm_a8w8_ck(...)`` or ``func_name(params)`` + pattern = re.compile(r"``([a-zA-Z_][a-zA-Z0-9_]*)\(") + text = rst_path.read_text() + for m in pattern.finditer(text): + refs.append(m.group(1)) + return list(set(refs)) + + +def _collect_all_api_refs() -> list[tuple[str, str]]: + """Collect all auto-doc references from API RST files. + + Returns list of (dotted_path, source_file) tuples. 
+ """ + results = [] + if not API_DIR.exists(): + return results + for rst_file in API_DIR.glob("*.rst"): + for ref in _extract_autofunction_refs(rst_file): + results.append((ref, str(rst_file.name))) + return results + + +def _collect_all_module_refs() -> list[tuple[str, str]]: + """Collect all module references from API RST files.""" + results = [] + if not API_DIR.exists(): + return results + for rst_file in API_DIR.glob("*.rst"): + for ref in _extract_module_refs(rst_file): + results.append((ref, str(rst_file.name))) + return results + + +_api_refs = _collect_all_api_refs() +_module_refs = _collect_all_module_refs() + + +@pytest.mark.skipif( + sys.platform == "win32", + reason="AITER ops require ROCm, skip on Windows", +) +class TestAPISignatureConsistency: + """Verify that documented functions/classes actually exist.""" + + @pytest.mark.parametrize( + "dotted_path,source_file", + _api_refs if _api_refs else [pytest.param("skip", "skip", marks=pytest.mark.skip)], + ids=[f"{r[1]}::{r[0]}" for r in _api_refs] if _api_refs else ["no_refs"], + ) + def test_autofunction_importable(self, dotted_path, source_file): + """Each .. autofunction:: / .. autoclass:: target must be importable.""" + parts = dotted_path.rsplit(".", 1) + if len(parts) != 2: + pytest.skip(f"Cannot parse dotted path: {dotted_path}") + + module_path, attr_name = parts + try: + mod = importlib.import_module(module_path) + except (ImportError, ModuleNotFoundError) as e: + pytest.fail( + f"Module '{module_path}' not importable " + f"(documented in {source_file}): {e}" + ) + + if not hasattr(mod, attr_name): + available = [a for a in dir(mod) if not a.startswith("_")] + pytest.fail( + f"'{attr_name}' not found in module '{module_path}' " + f"(documented in {source_file}). 
" + f"Available: {', '.join(available[:20])}" + ) + + @pytest.mark.parametrize( + "module_path,source_file", + _module_refs if _module_refs else [pytest.param("skip", "skip", marks=pytest.mark.skip)], + ids=[f"{r[1]}::{r[0]}" for r in _module_refs] if _module_refs else ["no_refs"], + ) + def test_module_ref_importable(self, module_path, source_file): + """Each ``aiter.ops.xxx`` module reference must be importable.""" + try: + importlib.import_module(module_path) + except (ImportError, ModuleNotFoundError) as e: + pytest.fail( + f"Module '{module_path}' not importable " + f"(referenced in {source_file}): {e}" + ) + + +# --------------------------------------------------------------------------- +# Test 2: Code Example Smoke Tests +# Python code blocks in RST must at least parse (syntax check). +# --------------------------------------------------------------------------- + +def _extract_python_code_blocks(rst_path: Path) -> list[tuple[int, str]]: + """Extract Python code blocks from RST file. + + Returns list of (line_number, code_string) tuples. 
+ """ + blocks = [] + lines = rst_path.read_text().splitlines() + i = 0 + while i < len(lines): + line = lines[i] + if re.match(r"\s*\.\.\s+code-block::\s+python", line): + # Find the indentation of the code block + i += 1 + # Skip blank lines + while i < len(lines) and lines[i].strip() == "": + i += 1 + if i >= len(lines): + break + # Determine indent level + indent_match = re.match(r"^(\s+)", lines[i]) + if not indent_match: + continue + indent = len(indent_match.group(1)) + block_start = i + code_lines = [] + while i < len(lines): + if lines[i].strip() == "": + code_lines.append("") + elif len(lines[i]) > 0 and len(lines[i]) - len(lines[i].lstrip()) >= indent: + code_lines.append(lines[i][indent:]) + else: + break + i += 1 + # Strip trailing blank lines + while code_lines and code_lines[-1].strip() == "": + code_lines.pop() + if code_lines: + blocks.append((block_start + 1, "\n".join(code_lines))) + else: + i += 1 + return blocks + + +def _collect_code_blocks() -> list[tuple[str, int, str]]: + """Collect all Python code blocks from all RST files. + + Returns list of (filename, line_number, code) tuples. 
+ """ + results = [] + if not DOCS_DIR.exists(): + return results + for rst_file in DOCS_DIR.rglob("*.rst"): + for line_no, code in _extract_python_code_blocks(rst_file): + rel_path = rst_file.relative_to(DOCS_DIR) + results.append((str(rel_path), line_no, code)) + return results + + +_code_blocks = _collect_code_blocks() + + +class TestCodeExamples: + """Verify that Python code examples in documentation are syntactically valid.""" + + @pytest.mark.parametrize( + "rst_file,line_no,code", + _code_blocks if _code_blocks else [pytest.param("skip", 0, "", marks=pytest.mark.skip)], + ids=[f"{cb[0]}:L{cb[1]}" for cb in _code_blocks] if _code_blocks else ["no_blocks"], + ) + def test_code_block_parses(self, rst_file, line_no, code): + """Each Python code block must be syntactically valid.""" + # Skip blocks that are clearly fragments or pseudocode + if code.strip().startswith("#") and "\n" not in code.strip(): + pytest.skip("Single-line comment, not real code") + if "..." in code and code.count("...") > code.count("\n") // 2: + pytest.skip("Pseudocode with ellipsis placeholders") + + try: + ast.parse(code) + except SyntaxError as e: + pytest.fail( + f"Syntax error in {rst_file} at doc line {line_no}: {e}\n" + f"Code:\n{code[:300]}" + ) + + @pytest.mark.parametrize( + "rst_file,line_no,code", + _code_blocks if _code_blocks else [pytest.param("skip", 0, "", marks=pytest.mark.skip)], + ids=[f"imports:{cb[0]}:L{cb[1]}" for cb in _code_blocks] if _code_blocks else ["no_blocks"], + ) + def test_code_block_imports_valid(self, rst_file, line_no, code): + """Import statements in code blocks must reference real modules.""" + try: + tree = ast.parse(code) + except SyntaxError: + pytest.skip("Cannot parse, covered by test_code_block_parses") + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + mod_name = alias.name + # Only check stdlib and aiter imports, skip third-party + if mod_name.startswith("aiter"): + try: + 
importlib.import_module(mod_name) + except (ImportError, ModuleNotFoundError): + pytest.fail( + f"Import '{mod_name}' in {rst_file} at line {line_no} " + f"is not importable" + ) + elif isinstance(node, ast.ImportFrom): + if node.module and node.module.startswith("aiter"): + try: + importlib.import_module(node.module) + except (ImportError, ModuleNotFoundError): + pytest.fail( + f"Import from '{node.module}' in {rst_file} at line {line_no} " + f"is not importable" + ) + + +# --------------------------------------------------------------------------- +# Test 3: Sphinx Build +# The documentation must build without warnings. +# --------------------------------------------------------------------------- + +class TestSphinxBuild: + """Verify that Sphinx can build the documentation.""" + + @pytest.mark.skipif( + not DOCS_DIR.exists(), + reason="docs/ directory not found", + ) + def test_sphinx_build_no_warnings(self): + """Sphinx build with -W must succeed (warnings are errors).""" + try: + import sphinx # noqa: F401 + except ImportError: + pytest.skip("sphinx not installed") + + build_dir = DOCS_DIR / "_build" / "test" + result = subprocess.run( + [ + sys.executable, "-m", "sphinx", + "-W", "--keep-going", + "-b", "html", + "-q", # quiet + str(DOCS_DIR), + str(build_dir), + ], + capture_output=True, + text=True, + timeout=300, + ) + + # Clean up test build dir + import shutil + if build_dir.exists(): + shutil.rmtree(build_dir, ignore_errors=True) + + if result.returncode != 0: + # Extract just the warning/error lines + errors = [ + line for line in result.stderr.splitlines() + if "WARNING" in line or "ERROR" in line + ] + error_summary = "\n".join(errors[:20]) + pytest.fail( + f"Sphinx build failed with {len(errors)} warnings/errors:\n" + f"{error_summary}" + ) + + @pytest.mark.skipif( + not DOCS_DIR.exists(), + reason="docs/ directory not found", + ) + def test_all_toctree_files_exist(self): + """Every file referenced in index.rst toctree must exist.""" + index = 
DOCS_DIR / "index.rst" + if not index.exists(): + pytest.skip("index.rst not found") + + content = index.read_text() + in_toctree = False + missing = [] + + for line in content.splitlines(): + if ".. toctree::" in line: + in_toctree = True + continue + if in_toctree: + stripped = line.strip() + if stripped.startswith(":"): + # toctree option like :maxdepth: + continue + if stripped == "": + if in_toctree: + continue # blank line within toctree is OK + if stripped and line[0].isspace() and not stripped.startswith((":", "..")): + # Indented entry => document reference + doc_path = DOCS_DIR / (stripped + ".rst") + if not doc_path.exists(): + missing.append(stripped) + # Detect end of toctree (non-indented non-blank line) + if line and not line.startswith(" ") and not line.startswith("\t") and stripped: + in_toctree = False + + if missing: + pytest.fail( + f"Files referenced in toctree but missing: {', '.join(missing)}" + ) + + +# --------------------------------------------------------------------------- +# Test 4: Version Consistency +# Documentation version must match the installed package version. +# --------------------------------------------------------------------------- + +class TestVersionConsistency: + """Verify documentation version stays in sync with the package.""" + + def test_conf_py_no_hardcoded_version(self): + """conf.py must not have a hardcoded release version string.""" + conf_py = DOCS_DIR / "conf.py" + if not conf_py.exists(): + pytest.skip("conf.py not found") + + content = conf_py.read_text() + # Look for hardcoded version like: release = "0.1.0" + # But allow: release = get_version(...) or release = "dev" + hardcoded = re.findall( + r'^release\s*=\s*["\'](\d+\.\d+[^"\']*)["\']', + content, + re.MULTILINE, + ) + if hardcoded: + pytest.fail( + f"conf.py has hardcoded version: release = \"{hardcoded[0]}\". " + f"Use auto-detection from setuptools_scm or importlib.metadata." 
+ ) + + def test_conf_py_has_version_detection(self): + """conf.py must contain version auto-detection logic.""" + conf_py = DOCS_DIR / "conf.py" + if not conf_py.exists(): + pytest.skip("conf.py not found") + + content = conf_py.read_text() + has_auto_version = ( + "get_version" in content + or "setuptools_scm" in content + or "importlib.metadata" in content + ) + if not has_auto_version: + pytest.fail( + "conf.py lacks version auto-detection. " + "Must use setuptools_scm or importlib.metadata." + ) + + def test_package_version_accessible(self): + """aiter._version module must exist and define __version__.""" + version_file = Path(__file__).parent.parent / "aiter" / "_version.py" + if not version_file.exists(): + pytest.skip( + "_version.py not found (generated by setuptools_scm at build time)" + ) + content = version_file.read_text() + assert "__version__" in content, ( + "_version.py exists but does not define __version__" + ) + + +# --------------------------------------------------------------------------- +# Test 5: RST Structure Validation +# Catch common RST issues that Sphinx might not flag clearly. 
+# --------------------------------------------------------------------------- + +class TestRSTStructure: + """Validate RST file structure and cross-references.""" + + @pytest.mark.skipif( + not DOCS_DIR.exists(), + reason="docs/ directory not found", + ) + def test_no_orphan_api_pages(self): + """Every RST in docs/api/ must be referenced in index.rst toctree.""" + index = DOCS_DIR / "index.rst" + if not index.exists(): + pytest.skip("index.rst not found") + + content = index.read_text() + api_files = list(API_DIR.glob("*.rst")) if API_DIR.exists() else [] + orphans = [] + + for api_file in api_files: + rel_name = f"api/{api_file.stem}" + if rel_name not in content: + orphans.append(str(api_file.name)) + + if orphans: + pytest.fail( + f"API docs not in index.rst toctree: {', '.join(orphans)}" + ) + + @pytest.mark.skipif( + not DOCS_DIR.exists(), + reason="docs/ directory not found", + ) + def test_no_cuda_references_in_docs(self): + """Documentation must not contain CUDA-specific references. + + AITER is a ROCm project. References to nvcc, CUDAExtension, + __nv_bfloat16, sm_90a etc. indicate copy-paste errors. 
+ """ + cuda_patterns = [ + r"\bCUDAExtension\b", + r"\bnvcc\b", + r"\b__nv_bfloat16\b", + r"\bsm_\d{2}[a-z]?\b", + r"\bcompute_\d{2}[a-z]?\b", + r"\bcuda_ext\b", + ] + combined = re.compile("|".join(cuda_patterns)) + violations = [] + + for rst_file in DOCS_DIR.rglob("*.rst"): + for i, line in enumerate(rst_file.read_text().splitlines(), 1): + # Skip lines in comments or the audit report + if "AUDIT" in str(rst_file).upper(): + continue + if combined.search(line): + rel = rst_file.relative_to(DOCS_DIR) + violations.append(f"{rel}:{i}: {line.strip()[:80]}") + + if violations: + pytest.fail( + f"CUDA references found in ROCm docs:\n" + + "\n".join(violations[:10]) + ) From c64711ed55425db3c583b7d347c94a46a62fa905 Mon Sep 17 00:00:00 2001 From: sunway513 Date: Sun, 12 Apr 2026 19:45:58 +0000 Subject: [PATCH 6/8] =?UTF-8?q?fix:=20resolve=20CI=20failures=20=E2=80=94?= =?UTF-8?q?=20black=20formatting,=20broken=20toctree=20refs,=20autodoc=20i?= =?UTF-8?q?mports?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Fix black formatting issues in test_docs.py and conf.py - Fix ruff: remove unused os import, fix f-string without placeholders - Remove autofunction/autoclass directives (fail on CPU-only CI, manual docs sufficient) - Add autodoc_mock_imports for triton/ROCm modules in conf.py - Remove deprecated display_version theme option - Fix tutorials/index.rst: remove 9 references to nonexistent pages - Fix quickstart.rst: point to existing pages instead of nonexistent tutorials - Fix installation.rst: remove broken tutorials/triton_comms ref - Fix basic_usage.rst: remove broken tutorial cross-references --- docs/api/attention.rst | 27 --------- docs/conf.py | 19 +++++- docs/installation.rst | 2 +- docs/quickstart.rst | 10 +-- docs/tutorials/basic_usage.rst | 6 +- docs/tutorials/index.rst | 107 ++------------------------------- tests/test_docs.py | 80 ++++++++++++++++-------- 7 files changed, 86 insertions(+), 165 deletions(-) diff 
--git a/docs/api/attention.rst b/docs/api/attention.rst index 4fdbcef619..fc1b6b251a 100644 --- a/docs/api/attention.rst +++ b/docs/api/attention.rst @@ -21,7 +21,6 @@ High-Level API These are the primary user-facing functions with ``torch.autograd`` support. -.. autofunction:: aiter.ops.mha.flash_attn_func ``flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False, window_size=(-1, -1, 0), bias=None, alibi_slopes=None, deterministic=True, return_lse=False, return_attn_probs=False, how_v3_bf16_cvt=1, cu_seqlens_q=None, cu_seqlens_kv=None, sink_ptr=None)`` @@ -34,7 +33,6 @@ and FP8 inputs. Dispatches to CK or FMHA v3 backend based on dtype and arch. - **v**: ``(batch, seqlen, nheads_k, headdim_v)`` - **Returns**: ``out (batch, seqlen, nheads, headdim_v)``, optionally ``softmax_lse``, ``S_dmask`` -.. autofunction:: aiter.ops.mha.flash_attn_varlen_func ``flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, min_seqlen_q=0, dropout_p=0.0, softmax_scale=None, logits_soft_cap=0.0, causal=False, window_size=(-1, -1, 0), bias=None, alibi_slopes=None, deterministic=False, return_lse=False, return_attn_probs=False, how_v3_bf16_cvt=1, block_table=None, out=None, ...)`` @@ -51,13 +49,11 @@ tensor and indexed by cumulative sequence lengths. FP8 Convenience Functions ~~~~~~~~~~~~~~~~~~~~~~~~~~ -.. autofunction:: aiter.ops.mha.flash_attn_fp8_pertensor_func ``flash_attn_fp8_pertensor_func(q, k, v, q_descale, k_descale, v_descale, causal=False, window_size=(-1, -1, 0), softmax_scale=None, sink_ptr=None)`` Flash attention for FP8 inputs with per-tensor descaling. Forward-only (no autograd). -.. autofunction:: aiter.ops.mha.flash_attn_varlen_fp8_pertensor_func ``flash_attn_varlen_fp8_pertensor_func(q, k, v, q_descale, k_descale, v_descale, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, ...)`` @@ -66,7 +62,6 @@ Variable-length FP8 flash attention with per-tensor descaling. Forward-only. Batch Prefill ~~~~~~~~~~~~~ -.. 
autofunction:: aiter.ops.mha.mha_batch_prefill_func ``mha_batch_prefill_func(q, k, v, cu_seqlens_q, kv_indptr, kv_page_indices, max_seqlen_q, max_seqlen_k, dropout_p=0.0, softmax_scale=None, logits_soft_cap=0.0, causal=False, window_size=(-1, -1), alibi_slopes=None, ...)`` @@ -78,45 +73,37 @@ Low-Level CK Kernels These are the direct CK kernel wrappers. Most users should prefer the high-level API above. -.. autofunction:: aiter.ops.mha.mha_fwd ``mha_fwd(q, k, v, dropout_p, softmax_scale, is_causal, window_size_left, window_size_right, sink_size, return_softmax_lse, return_dropout_randval, ...)`` CK flash attention forward pass. Returns ``(out, softmax_lse, S_dmask, rng_state)``. -.. autofunction:: aiter.ops.mha.fmha_v3_fwd ``fmha_v3_fwd(q, k, v, dropout_p, softmax_scale, is_causal, window_size_left, window_size_right, return_softmax_lse, return_dropout_randval, how_v3_bf16_cvt, ...)`` FMHA v3 forward pass (newer CK backend). Returns ``(out, softmax_lse, S_dmask, rng_state)``. -.. autofunction:: aiter.ops.mha.mha_varlen_fwd ``mha_varlen_fwd(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, min_seqlen_q, dropout_p, softmax_scale, logits_soft_cap, zero_tensors, is_causal, window_size_left, window_size_right, sink_size, return_softmax_lse, return_dropout_randval, ...)`` Variable-length CK MHA forward. Returns ``(out, softmax_lse, S_dmask, rng_state)``. -.. autofunction:: aiter.ops.mha.fmha_v3_varlen_fwd ``fmha_v3_varlen_fwd(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, min_seqlen_q, dropout_p, softmax_scale, logits_soft_cap, zero_tensors, is_causal, window_size_left, window_size_right, return_softmax_lse, return_dropout_randval, how_v3_bf16_cvt, ...)`` FMHA v3 variable-length forward pass. -.. autofunction:: aiter.ops.mha.mha_bwd ``mha_bwd(dout, q, k, v, out, softmax_lse, dropout_p, softmax_scale, is_causal, window_size_left, window_size_right, deterministic, dq=None, dk=None, dv=None, ...)`` CK MHA backward pass (training). 
Returns ``(dq, dk, dv, dbias)``. -.. autofunction:: aiter.ops.mha.fmha_v3_bwd FMHA v3 backward pass (training). -.. autofunction:: aiter.ops.mha.mha_varlen_bwd Variable-length CK MHA backward pass (training). -.. autofunction:: aiter.ops.mha.fmha_v3_varlen_bwd FMHA v3 variable-length backward pass (training). @@ -130,21 +117,18 @@ Located in ``aiter.ops.attention``. Core Functions ~~~~~~~~~~~~~~ -.. autofunction:: aiter.ops.attention.paged_attention_rocm ``paged_attention_rocm(out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, context_lens, block_size, max_context_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, fp8_out_scale=None, partition_size=256, mtp=1, q_scale=None)`` Main ROCm paged attention entry point. Custom CK-based implementation with partitioned softmax. Supports FP8 KV cache, ALiBi, and multi-token prediction (MTP). -.. autofunction:: aiter.ops.attention.paged_attention_v1 ``paged_attention_v1(out, workspace_buffer, query, key_cache, value_cache, scale, block_tables, cu_query_lens, context_lens, max_context_len, alibi_slopes, kv_cache_dtype, kv_cache_layout, logits_soft_cap, k_scale, v_scale, fp8_out_scale=None, partition_size=256, mtp=1, sliding_window=0)`` V1 paged attention with workspace buffer. Supports multiple KV cache layouts, logits soft capping, and sliding window attention. -.. autofunction:: aiter.ops.attention.paged_attention_ragged ``paged_attention_ragged(out, workspace_buffer, query, key_cache, value_cache, scale, kv_indptr, kv_page_indices, kv_last_page_lens, block_size, max_num_partitions, alibi_slopes, kv_cache_dtype, kv_cache_layout, logits_soft_cap, k_scale, v_scale, fp8_out_scale=None, partition_size=256, mtp=1)`` @@ -156,7 +140,6 @@ ASM Paged Attention Hand-tuned assembly kernels for maximum decode throughput. -.. 
autofunction:: aiter.ops.attention.pa_fwd_asm ``pa_fwd_asm(Q, K, V, block_tables, context_lens, block_tables_stride0, max_qlen=1, K_QScale=None, V_QScale=None, out_=None, qo_indptr=None, high_precision=1, kernelName=None)`` @@ -164,14 +147,12 @@ ASM paged attention forward. Supports FP8 KV cache via dequantization scales (``K_QScale``, ``V_QScale``). The ``high_precision`` parameter controls FP8 accumulation precision (0=low, 1=medium, 2=highest). -.. autofunction:: aiter.ops.attention.pa_ps_fwd_asm ``pa_ps_fwd_asm(Q, K, V, kv_indptr, kv_page_indices, context_lens, softmax_scale, max_qlen=1, K_QScale=None, V_QScale=None, out_=None, qo_indptr=None, work_indptr=None, work_info=None, splitData=None, splitLse=None, mask=0, high_precision=1, kernelName=None, quant_type=QuantType.per_Token)`` PS-mode (persistent/split) ASM paged attention. Uses ragged page indexing and supports work partitioning for large context lengths. -.. autofunction:: aiter.ops.attention.pa_persistent_fwd ``pa_persistent_fwd(Q, K, V, output, max_qlen, qo_indptr, kv_indptr, kv_indices, context_lens, work_indptr, work_info, reduce_indptr, reduce_final_map, reduce_partial_map, K_QScale=None, V_QScale=None, softmax_scale=None, mask=0, quant_type=QuantType.per_Token)`` @@ -184,19 +165,15 @@ vLLM-Compatible Wrapper Drop-in replacement for vLLM's paged attention layer. Located in ``aiter.paged_attn``. -.. autoclass:: aiter.paged_attn.PagedAttention :members: get_supported_head_sizes, get_kv_cache_shape, split_kv_cache, write_to_paged_cache, forward_decode, swap_blocks, copy_blocks -.. autoclass:: aiter.paged_attn.PagedAttentionMetadata :members: -.. autofunction:: aiter.paged_attn.paged_attention_v1 ``paged_attention_v1(out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, ...)`` vLLM-compatible v1 paged attention (delegates to ``aiter.ops``). -.. 
autofunction:: aiter.paged_attn.paged_attention_v2 ``paged_attention_v2(out, exp_sum, max_logits, tmp_out, query, key_cache, value_cache, num_kv_heads, scale, block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, kv_cache_dtype, k_scale, v_scale, ...)`` @@ -210,7 +187,6 @@ Attention kernels for DeepSeek-style Multi-Latent Attention, where key and value are projected into a shared low-rank latent space. Located in ``aiter.mla``. All MLA functions are inference-only. -.. autofunction:: aiter.mla.mla_decode_fwd ``mla_decode_fwd(q, kv_buffer, o, qo_indptr, kv_indptr, kv_indices, kv_last_page_lens, max_seqlen_q, page_size=1, nhead_kv=1, sm_scale=None, logit_cap=0.0, num_kv_splits=None, ...)`` @@ -222,7 +198,6 @@ with automatic split/reduce for long contexts. - **kv_buffer**: ``(num_pages, page_size, nhead_kv, kv_lora_rank + qk_rope_head_dim)`` - **o**: ``(total_q, nheads, v_head_dim)`` output buffer -.. autofunction:: aiter.mla.mla_prefill_fwd ``mla_prefill_fwd(q, kv_buffer, o, qo_indptr, kv_indptr, kv_indices, kv_last_page_lens, max_seqlen_q, sm_scale=None, logit_cap=0.0, num_kv_splits=None)`` @@ -232,14 +207,12 @@ MLA prefill-phase forward pass. Uses ASM backend for the attention computation. - **kv_buffer**: ``(num_pages, page_size, nhead_kv, kv_lora_rank + qk_rope_head_dim)`` - **o**: ``(num_seqs, num_heads, v_head_dim)`` -.. autofunction:: aiter.mla.mla_prefill_ps_fwd ``mla_prefill_ps_fwd(Q, K, V, output, qo_indptr, kv_indptr, kv_page_indices, work_indptr, work_info_set, max_seqlen_q, is_causal, reduce_indptr=None, reduce_final_map=None, reduce_partial_map=None, softmax_scale=None, q_scale=None, k_scale=None, v_scale=None)`` MLA prefill with persistent/split mode. Handles long prefill sequences via work partitioning and multi-stage reduction. Supports FP8 via per-tensor scales. -.. 
autofunction:: aiter.mla.mla_prefill_reduce ``mla_prefill_reduce(partial_output, partial_lse, reduce_indptr, reduce_final_map, reduce_partial_map, output, tile_q=256, use_triton=True)`` diff --git a/docs/conf.py b/docs/conf.py index 3c63ad2220..3454d4faa7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -15,10 +15,12 @@ # Auto-detect version from setuptools_scm or git try: from importlib.metadata import version as get_version + release = get_version("amd-aiter") except Exception: try: from setuptools_scm import get_version + release = get_version(root="..", relative_to=__file__) except Exception: release = "dev" @@ -43,7 +45,7 @@ html_theme_options = { "logo_only": False, - "display_version": True, + # "display_version" removed in newer sphinx_rtd_theme, version shown by default "prev_next_buttons_location": "bottom", "style_external_links": False, "vcs_pageview_mode": "", @@ -92,3 +94,18 @@ "undoc-members": True, "exclude-members": "__weakref__", } + +# Mock imports for modules that require ROCm/GPU at import time. +# This allows Sphinx to build on CPU-only CI runners. +autodoc_mock_imports = [ + "triton", + "triton.language", + "aiter.jit", + "aiter.jit.core", + "aiter.utility", + "aiter.utility.dtypes", + "aiter.ops.enum", + "hip", + "hipblas", + "rocm", +] diff --git a/docs/installation.rst b/docs/installation.rst index 115e20f734..7aa97bea21 100644 --- a/docs/installation.rst +++ b/docs/installation.rst @@ -134,7 +134,7 @@ For Triton-based communication primitives: pip install -r requirements-triton-comms.txt -See :doc:`tutorials/triton_comms` for more details. +See the `Triton communication documentation `_ for more details. 
Troubleshooting --------------- diff --git a/docs/quickstart.rst b/docs/quickstart.rst index 10601e922e..54f643983a 100644 --- a/docs/quickstart.rst +++ b/docs/quickstart.rst @@ -151,11 +151,11 @@ Performance Tips Next Steps ---------- -* :doc:`tutorials/attention` - Deep dive into attention mechanisms -* :doc:`tutorials/moe` - Learn about MoE optimizations -* :doc:`tutorials/variable_length` - Handle variable-length sequences -* :doc:`api/attention` - Full API reference -* :doc:`benchmarks` - Performance comparisons +* :doc:`tutorials/add_new_op` - How to add a new operator +* :doc:`api/attention` - Attention API reference +* :doc:`api/gemm` - GEMM API reference +* :doc:`api/moe` - MoE API reference +* :doc:`performance/benchmarks` - Performance benchmarks Common Issues ------------- diff --git a/docs/tutorials/basic_usage.rst b/docs/tutorials/basic_usage.rst index ae00638501..0e201d6832 100644 --- a/docs/tutorials/basic_usage.rst +++ b/docs/tutorials/basic_usage.rst @@ -316,9 +316,9 @@ Monitor GPU memory usage: Next Steps ---------- -* :doc:`attention_tutorial` - Deep dive into attention mechanisms -* :doc:`variable_length` - Handle variable-length sequences -* :doc:`moe_tutorial` - Mixture of Experts optimization +* :doc:`add_new_op` - How to add a new operator +* :doc:`../api/attention` - Attention API reference +* :doc:`../api/moe` - MoE API reference Common Gotchas -------------- diff --git a/docs/tutorials/index.rst b/docs/tutorials/index.rst index 4aaf394e6e..f96ea742c0 100644 --- a/docs/tutorials/index.rst +++ b/docs/tutorials/index.rst @@ -5,122 +5,23 @@ Learn AITER through hands-on examples. .. toctree:: :maxdepth: 2 - :caption: Getting Started + :caption: Contents basic_usage - attention_tutorial - variable_length - -.. toctree:: - :maxdepth: 2 - :caption: Advanced Topics - add_new_op - moe_tutorial - custom_kernels - quantization - triton_comms - -.. 
toctree:: - :maxdepth: 2 - :caption: Integration - - vllm_integration - pytorch_lightning - deepspeed Tutorial Overview ----------------- -Basic Tutorials -^^^^^^^^^^^^^^^ - -* :doc:`basic_usage` - Your first AITER program -* :doc:`attention_tutorial` - Understanding attention kernels -* :doc:`variable_length` - Handling variable-length sequences - -Advanced Topics -^^^^^^^^^^^^^^^ - +* :doc:`basic_usage` - Your first AITER program: attention, GEMM, normalization * :doc:`add_new_op` - **How to add a new operator** (step-by-step guide) -* :doc:`moe_tutorial` - Mixture of Experts optimization -* :doc:`custom_kernels` - Writing custom ROCm kernels -* :doc:`quantization` - INT8 quantization for inference -* :doc:`triton_comms` - Triton-based communication primitives - -Integration Guides -^^^^^^^^^^^^^^^^^^ - -* :doc:`vllm_integration` - Using AITER with vLLM -* :doc:`pytorch_lightning` - PyTorch Lightning integration -* :doc:`deepspeed` - DeepSpeed integration Prerequisites ------------- All tutorials assume: -* Python 3.8+ +* Python 3.10+ * PyTorch 2.0+ with ROCm support * AITER installed (see :doc:`../installation`) -* AMD GPU (gfx90a, gfx942, or gfx950) - -Example Data ------------- - -Some tutorials use sample data. Download with: - -.. code-block:: bash - - # Coming soon: test data downloader - bash scripts/download_test_data.sh - -Jupyter Notebooks ------------------ - -Interactive notebooks are available in the ``examples/`` directory: - -.. code-block:: bash - - # Install Jupyter - pip install jupyter - - # Launch notebooks - cd examples - jupyter notebook - -Running Examples ----------------- - -All tutorial code can be run directly: - -.. 
code-block:: bash - - # Clone repository - git clone https://github.com/ROCm/aiter.git - cd aiter - - # Run tutorial script - python examples/basic_usage.py - -Community Examples ------------------- - -Check out community-contributed examples: - -* **Llama 2 inference** - Optimized inference with AITER -* **Mixtral 8x7B** - MoE model acceleration -* **GPT-style models** - Training and inference - -Contributing Tutorials ----------------------- - -We welcome tutorial contributions! See :doc:`../contributing` for guidelines. - -Tips for following tutorials: - -1. **Start with basics** - Don't skip the fundamentals -2. **Run the code** - Type it out, don't just copy-paste -3. **Experiment** - Modify parameters and observe changes -4. **Profile** - Use ROCm profiler to understand performance -5. **Ask questions** - Open issues or discussions on GitHub +* AMD GPU (gfx942 or gfx950) diff --git a/tests/test_docs.py b/tests/test_docs.py index 6b2cbdb9e7..51e05c5d04 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -15,7 +15,6 @@ import ast import importlib -import os import re import subprocess import sys @@ -33,6 +32,7 @@ # Every function/class referenced in docs/api/*.rst must be importable. # --------------------------------------------------------------------------- + def _extract_autofunction_refs(rst_path: Path) -> list[str]: """Extract all '.. autofunction:: X' and '.. 
autoclass:: X' from an RST file.""" refs = [] @@ -106,7 +106,9 @@ class TestAPISignatureConsistency: @pytest.mark.parametrize( "dotted_path,source_file", - _api_refs if _api_refs else [pytest.param("skip", "skip", marks=pytest.mark.skip)], + _api_refs + if _api_refs + else [pytest.param("skip", "skip", marks=pytest.mark.skip)], ids=[f"{r[1]}::{r[0]}" for r in _api_refs] if _api_refs else ["no_refs"], ) def test_autofunction_importable(self, dotted_path, source_file): @@ -134,7 +136,9 @@ def test_autofunction_importable(self, dotted_path, source_file): @pytest.mark.parametrize( "module_path,source_file", - _module_refs if _module_refs else [pytest.param("skip", "skip", marks=pytest.mark.skip)], + _module_refs + if _module_refs + else [pytest.param("skip", "skip", marks=pytest.mark.skip)], ids=[f"{r[1]}::{r[0]}" for r in _module_refs] if _module_refs else ["no_refs"], ) def test_module_ref_importable(self, module_path, source_file): @@ -153,6 +157,7 @@ def test_module_ref_importable(self, module_path, source_file): # Python code blocks in RST must at least parse (syntax check). # --------------------------------------------------------------------------- + def _extract_python_code_blocks(rst_path: Path) -> list[tuple[int, str]]: """Extract Python code blocks from RST file. 
@@ -181,7 +186,10 @@ def _extract_python_code_blocks(rst_path: Path) -> list[tuple[int, str]]: while i < len(lines): if lines[i].strip() == "": code_lines.append("") - elif len(lines[i]) > 0 and len(lines[i]) - len(lines[i].lstrip()) >= indent: + elif ( + len(lines[i]) > 0 + and len(lines[i]) - len(lines[i].lstrip()) >= indent + ): code_lines.append(lines[i][indent:]) else: break @@ -219,8 +227,12 @@ class TestCodeExamples: @pytest.mark.parametrize( "rst_file,line_no,code", - _code_blocks if _code_blocks else [pytest.param("skip", 0, "", marks=pytest.mark.skip)], - ids=[f"{cb[0]}:L{cb[1]}" for cb in _code_blocks] if _code_blocks else ["no_blocks"], + _code_blocks + if _code_blocks + else [pytest.param("skip", 0, "", marks=pytest.mark.skip)], + ids=[f"{cb[0]}:L{cb[1]}" for cb in _code_blocks] + if _code_blocks + else ["no_blocks"], ) def test_code_block_parses(self, rst_file, line_no, code): """Each Python code block must be syntactically valid.""" @@ -240,8 +252,12 @@ def test_code_block_parses(self, rst_file, line_no, code): @pytest.mark.parametrize( "rst_file,line_no,code", - _code_blocks if _code_blocks else [pytest.param("skip", 0, "", marks=pytest.mark.skip)], - ids=[f"imports:{cb[0]}:L{cb[1]}" for cb in _code_blocks] if _code_blocks else ["no_blocks"], + _code_blocks + if _code_blocks + else [pytest.param("skip", 0, "", marks=pytest.mark.skip)], + ids=[f"imports:{cb[0]}:L{cb[1]}" for cb in _code_blocks] + if _code_blocks + else ["no_blocks"], ) def test_code_block_imports_valid(self, rst_file, line_no, code): """Import statements in code blocks must reference real modules.""" @@ -279,6 +295,7 @@ def test_code_block_imports_valid(self, rst_file, line_no, code): # The documentation must build without warnings. 
# --------------------------------------------------------------------------- + class TestSphinxBuild: """Verify that Sphinx can build the documentation.""" @@ -296,9 +313,13 @@ def test_sphinx_build_no_warnings(self): build_dir = DOCS_DIR / "_build" / "test" result = subprocess.run( [ - sys.executable, "-m", "sphinx", - "-W", "--keep-going", - "-b", "html", + sys.executable, + "-m", + "sphinx", + "-W", + "--keep-going", + "-b", + "html", "-q", # quiet str(DOCS_DIR), str(build_dir), @@ -310,13 +331,15 @@ def test_sphinx_build_no_warnings(self): # Clean up test build dir import shutil + if build_dir.exists(): shutil.rmtree(build_dir, ignore_errors=True) if result.returncode != 0: # Extract just the warning/error lines errors = [ - line for line in result.stderr.splitlines() + line + for line in result.stderr.splitlines() if "WARNING" in line or "ERROR" in line ] error_summary = "\n".join(errors[:20]) @@ -351,13 +374,22 @@ def test_all_toctree_files_exist(self): if stripped == "": if in_toctree: continue # blank line within toctree is OK - if stripped and not stripped.startswith(":") and not stripped.startswith(".."): + if ( + stripped + and not stripped.startswith(":") + and not stripped.startswith("..") + ): # This is a document reference doc_path = DOCS_DIR / (stripped + ".rst") if not doc_path.exists(): missing.append(stripped) # Detect end of toctree (non-indented non-blank line) - if line and not line.startswith(" ") and not line.startswith("\t") and stripped: + if ( + line + and not line.startswith(" ") + and not line.startswith("\t") + and stripped + ): in_toctree = False if missing: @@ -371,6 +403,7 @@ def test_all_toctree_files_exist(self): # Documentation version must match the installed package version. 
# --------------------------------------------------------------------------- + class TestVersionConsistency: """Verify documentation version stays in sync with the package.""" @@ -390,7 +423,7 @@ def test_conf_py_no_hardcoded_version(self): ) if hardcoded: pytest.fail( - f"conf.py has hardcoded version: release = \"{hardcoded[0]}\". " + f'conf.py has hardcoded version: release = "{hardcoded[0]}". ' f"Use auto-detection from setuptools_scm or importlib.metadata." ) @@ -420,9 +453,9 @@ def test_package_version_accessible(self): "_version.py not found (generated by setuptools_scm at build time)" ) content = version_file.read_text() - assert "__version__" in content, ( - "_version.py exists but does not define __version__" - ) + assert ( + "__version__" in content + ), "_version.py exists but does not define __version__" # --------------------------------------------------------------------------- @@ -430,6 +463,7 @@ def test_package_version_accessible(self): # Catch common RST issues that Sphinx might not flag clearly. 
# --------------------------------------------------------------------------- + class TestRSTStructure: """Validate RST file structure and cross-references.""" @@ -453,9 +487,7 @@ def test_no_orphan_api_pages(self): orphans.append(str(api_file.name)) if orphans: - pytest.fail( - f"API docs not in index.rst toctree: {', '.join(orphans)}" - ) + pytest.fail(f"API docs not in index.rst toctree: {', '.join(orphans)}") @pytest.mark.skipif( not DOCS_DIR.exists(), @@ -488,7 +520,5 @@ def test_no_cuda_references_in_docs(self): violations.append(f"{rel}:{i}: {line.strip()[:80]}") if violations: - pytest.fail( - f"CUDA references found in ROCm docs:\n" - + "\n".join(violations[:10]) - ) + msg = "CUDA references found in ROCm docs:\n" + "\n".join(violations[:10]) + pytest.fail(msg) From 0c8044457677f250f3013b5cf6fbd280da00b78a Mon Sep 17 00:00:00 2001 From: sunway513 Date: Sun, 12 Apr 2026 19:50:48 +0000 Subject: [PATCH 7/8] fix: black formatting for ternary expressions in parametrize decorators --- tests/test_docs.py | 50 +++++++++++++++++++++++++++++----------------- 1 file changed, 32 insertions(+), 18 deletions(-) diff --git a/tests/test_docs.py b/tests/test_docs.py index 51e05c5d04..35315fa600 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -106,9 +106,11 @@ class TestAPISignatureConsistency: @pytest.mark.parametrize( "dotted_path,source_file", - _api_refs - if _api_refs - else [pytest.param("skip", "skip", marks=pytest.mark.skip)], + ( + _api_refs + if _api_refs + else [pytest.param("skip", "skip", marks=pytest.mark.skip)] + ), ids=[f"{r[1]}::{r[0]}" for r in _api_refs] if _api_refs else ["no_refs"], ) def test_autofunction_importable(self, dotted_path, source_file): @@ -136,10 +138,14 @@ def test_autofunction_importable(self, dotted_path, source_file): @pytest.mark.parametrize( "module_path,source_file", - _module_refs + ( + _module_refs + if _module_refs + else [pytest.param("skip", "skip", marks=pytest.mark.skip)] + ), + ids=[f"{r[1]}::{r[0]}" for r in 
_module_refs] if _module_refs - else [pytest.param("skip", "skip", marks=pytest.mark.skip)], - ids=[f"{r[1]}::{r[0]}" for r in _module_refs] if _module_refs else ["no_refs"], + else ["no_refs"], ) def test_module_ref_importable(self, module_path, source_file): """Each ``aiter.ops.xxx`` module reference must be importable.""" @@ -227,12 +233,16 @@ class TestCodeExamples: @pytest.mark.parametrize( "rst_file,line_no,code", - _code_blocks - if _code_blocks - else [pytest.param("skip", 0, "", marks=pytest.mark.skip)], - ids=[f"{cb[0]}:L{cb[1]}" for cb in _code_blocks] - if _code_blocks - else ["no_blocks"], + ( + _code_blocks + if _code_blocks + else [pytest.param("skip", 0, "", marks=pytest.mark.skip)] + ), + ids=( + [f"{cb[0]}:L{cb[1]}" for cb in _code_blocks] + if _code_blocks + else ["no_blocks"] + ), ) def test_code_block_parses(self, rst_file, line_no, code): """Each Python code block must be syntactically valid.""" @@ -252,12 +262,16 @@ def test_code_block_parses(self, rst_file, line_no, code): @pytest.mark.parametrize( "rst_file,line_no,code", - _code_blocks - if _code_blocks - else [pytest.param("skip", 0, "", marks=pytest.mark.skip)], - ids=[f"imports:{cb[0]}:L{cb[1]}" for cb in _code_blocks] - if _code_blocks - else ["no_blocks"], + ( + _code_blocks + if _code_blocks + else [pytest.param("skip", 0, "", marks=pytest.mark.skip)] + ), + ids=( + [f"imports:{cb[0]}:L{cb[1]}" for cb in _code_blocks] + if _code_blocks + else ["no_blocks"] + ), ) def test_code_block_imports_valid(self, rst_file, line_no, code): """Import statements in code blocks must reference real modules.""" From a000d2e0116a5a47a71db56b129953fed4037be1 Mon Sep 17 00:00:00 2001 From: sunway513 Date: Sun, 12 Apr 2026 19:54:19 +0000 Subject: [PATCH 8/8] =?UTF-8?q?fix:=20last=20black=20formatting=20nit=20?= =?UTF-8?q?=E2=80=94=20ids=20line=20for=20module=5Frefs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tests/test_docs.py | 4 +--- 1 file changed, 1 
insertion(+), 3 deletions(-) diff --git a/tests/test_docs.py b/tests/test_docs.py index 35315fa600..3da877ec42 100644 --- a/tests/test_docs.py +++ b/tests/test_docs.py @@ -143,9 +143,7 @@ def test_autofunction_importable(self, dotted_path, source_file): if _module_refs else [pytest.param("skip", "skip", marks=pytest.mark.skip)] ), - ids=[f"{r[1]}::{r[0]}" for r in _module_refs] - if _module_refs - else ["no_refs"], + ids=[f"{r[1]}::{r[0]}" for r in _module_refs] if _module_refs else ["no_refs"], ) def test_module_ref_importable(self, module_path, source_file): """Each ``aiter.ops.xxx`` module reference must be importable."""