5 changes: 5 additions & 0 deletions .gitignore
@@ -1,3 +1,8 @@
# Python virtual environments
.venv/
venv/
env/

bazel
bazel-bazel-test
bazel-bin
163 changes: 163 additions & 0 deletions py/torch_tensorrt/annotation/README.md
@@ -0,0 +1,163 @@
# torch_tensorrt.annotation — custom_plugin descriptors

`torch_tensorrt.annotation` (aliased as `tta`) provides descriptor types and
factory functions for defining custom TensorRT AOT QDP plugins backed by
Triton, CuTile, or CuTeDSL kernels.

```python
import torch_tensorrt.annotation as tta
```

This module is **descriptor-only**: it builds spec objects that describe how a
plugin should be compiled and registered. It does not patch `torch.export`,
add compilation hooks, or modify any torch-trt core path.

---

## Table of contents

1. [Quick start](#1-quick-start)
2. [Factory functions](#2-factory-functions)
3. [Spec types](#3-spec-types)
4. [QDP plugin flow](#4-qdp-plugin-flow)
5. [Running tests](#5-running-tests)

---

## 1. Quick start

```python
import torch_tensorrt.annotation as tta

# Triton AOT plugin
spec = tta.custom_plugin(tta.triton(my_launch_fn, configs=[{"BLOCK_SIZE": 128}]))

# CuTile plugin (Blackwell sm_100+)
spec = tta.custom_plugin(tta.cutile(my_cutile_kernel, arch=120))

# CuTeDSL plugin
spec = tta.custom_plugin(tta.cutedsl(my_cutedsl_kernel))
```

---

## 2. Factory functions

### `tta.triton(launch_fn, configs=None)`

Wraps a Triton kernel launch function.

- **`launch_fn`** — callable that launches the Triton kernel;
signature `(input0, ..., output, stream, **config)`.
- **`configs`** — list of `dict` tactic configs; each becomes a separate
QDP tactic. Pass `None` for a single default tactic.

```python
import triton
import triton.language as tl

@triton.jit
def _add_relu_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = i < n
    x = tl.load(x_ptr + i, mask=mask)
    y = tl.load(y_ptr + i, mask=mask)
    tl.store(out_ptr + i, tl.maximum(x + y, 0), mask=mask)

def launch_add_relu(x, y, out, stream, BLOCK=256):
    grid = (triton.cdiv(x.numel(), BLOCK),)
    _add_relu_kernel[grid](x, y, out, x.numel(), BLOCK=BLOCK)

spec = tta.custom_plugin(tta.triton(launch_add_relu, configs=[{"BLOCK": 128}, {"BLOCK": 256}]))
```

### `tta.cutile(launch_fn, arch=None, configs=None)`

Wraps a CuTile (cuda-tile) kernel. Requires Blackwell (sm_100+) and the
`cuda-tile` package.

- **`arch`** — SM architecture integer (e.g. `120` for sm_120).
- **`configs`** — list of tactic dicts.

```python
spec = tta.custom_plugin(tta.cutile(my_cutile_fn, arch=120, configs=[{"TILE_M": 64}]))
```

### `tta.cutedsl(launch_fn, configs=None)`

Wraps a CuTeDSL kernel (`nvidia-cutlass-dsl`).

```python
spec = tta.custom_plugin(tta.cutedsl(my_cutedsl_fn))
```

### `tta.custom_plugin(impl)`

Builds a `CustomPluginSpec` from a kernel spec (`TritonSpec`, `CuTileSpec`, or
`CuTeDSLSpec`). Computes a deterministic QDP `op_name` from the kernel
function identity and config hash.

```python
spec = tta.custom_plugin(tta.triton(launch_fn, configs=[{"BLOCK": 256}]))
# spec.op_name — deterministic string like "tta::launch_fn_a3f2c1"
```
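The exact naming scheme lives in `_custom_plugin._descriptor`; as a rough illustration of how a deterministic name can be derived from function identity plus a config hash, here is a minimal sketch (the helper name, fingerprint format, and 6-character suffix are assumptions, not the real implementation):

```python
import hashlib
import json


def derive_op_name(launch_fn, configs):
    """Illustrative only: stable op name from kernel identity + config hash."""
    # Canonical JSON (sorted keys) makes the config portion deterministic.
    payload = json.dumps(configs or [], sort_keys=True)
    fingerprint = f"{launch_fn.__module__}.{launch_fn.__qualname__}|{payload}"
    digest = hashlib.sha256(fingerprint.encode()).hexdigest()[:6]
    return f"tta::{launch_fn.__name__}_{digest}"


def launch_fn(x, y, out, stream, BLOCK=256):
    ...


# Same kernel + same configs always yield the same name; changing the
# configs changes the hash suffix.
name_a = derive_op_name(launch_fn, [{"BLOCK": 256}])
name_b = derive_op_name(launch_fn, [{"BLOCK": 256}])
name_c = derive_op_name(launch_fn, [{"BLOCK": 128}])
assert name_a == name_b and name_a != name_c
```

Determinism matters here because the same spec may be rebuilt in multiple processes (e.g. pytest-xdist workers) and must resolve to the same QDP op.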

---

## 3. Spec types

All spec types are plain frozen dataclasses — they carry no mutable state and
are safe to hash, compare, and cache.

| Type | Returned by | Description |
|------|-------------|-------------|
| `CustomPluginSpec` | `custom_plugin()` | AOT QDP plugin descriptor; holds `impl` (`TritonSpec` \| `CuTileSpec` \| `CuTeDSLSpec`) and computed `op_name` |
| `TritonSpec` | `triton()` | Triton kernel launch function + tactic configs |
| `CuTileSpec` | `cutile()` | CuTile kernel + target arch + tactic configs |
| `CuTeDSLSpec` | `cutedsl()` | CuTeDSL kernel + tactic configs |
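
To illustrate why frozen specs are safe to hash, compare, and cache, here is a simplified stand-in (field names, the `from_dicts` helper, and the tuple-of-tuples config encoding are assumptions for the sketch, not the real `TritonSpec`):

```python
from dataclasses import dataclass
from typing import Any, Callable, Tuple


@dataclass(frozen=True)
class MiniTritonSpec:
    """Simplified stand-in: immutable, so hash/eq follow value semantics."""
    launch_fn: Callable[..., Any]
    # Plain dicts are unhashable; store configs as sorted key/value tuples.
    configs: Tuple[Tuple[Tuple[str, Any], ...], ...] = ()

    @staticmethod
    def from_dicts(launch_fn, configs):
        frozen = tuple(tuple(sorted(c.items())) for c in (configs or []))
        return MiniTritonSpec(launch_fn, frozen)


def launch(x, y, out, stream, BLOCK=256):
    ...


a = MiniTritonSpec.from_dicts(launch, [{"BLOCK": 128}])
b = MiniTritonSpec.from_dicts(launch, [{"BLOCK": 128}])
assert a == b and hash(a) == hash(b)  # equal specs: usable as cache keys
```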

---

## 4. QDP plugin flow

`tta.custom_plugin` produces a descriptor. When you call
`register_custom_plugin(spec, inputs)` (from `_custom_plugin._descriptor`),
the module:

1. Derives a deterministic `op_name` from the kernel function + config hash.
2. Registers `@trtp.register("tta::op_name")` — the shape/dtype descriptor
function derived symbolically from the kernel signature.
3. Registers `@trtp.aot_impl("tta::op_name")` — the AOT implementation
function that returns `(kernel_name, ptx_or_cubin, KernelLaunchParams,
SymIntExprs)`.
4. Uses a process-level lock + double-checked locking to prevent duplicate
registration in multi-threaded pytest-xdist workers.
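
The idempotent registration in step 4 follows the standard double-checked locking pattern; a minimal sketch under assumed names (the real lock and registry live in `_descriptor`):

```python
import threading

_registered: set = set()           # process-level set of registered op names
_registry_lock = threading.Lock()


def register_once(op_name, do_register):
    """Run do_register() at most once per op_name, even across threads."""
    if op_name in _registered:       # fast path: skip the lock entirely
        return False
    with _registry_lock:
        if op_name in _registered:   # re-check under the lock (double-check)
            return False
        do_register()                # e.g. trtp.register + trtp.aot_impl
        _registered.add(op_name)
        return True


calls = []
register_once("tta::add_relu", lambda: calls.append("registered"))
register_once("tta::add_relu", lambda: calls.append("registered"))
assert calls == ["registered"]  # second call is a no-op
```

The unlocked first check keeps the common already-registered path cheap; the second check under the lock prevents two workers racing through the fast path from both registering.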

The QDP AOT path means **no Python is needed at TRT engine runtime** — the
compiled kernel is embedded directly.

```
tta.triton(launch_fn, configs)
└─► TritonSpec
└─► tta.custom_plugin(spec)
└─► CustomPluginSpec(op_name, impl)
└─► register_custom_plugin(spec, inputs)
├─► @trtp.register (shape descriptor)
└─► @trtp.aot_impl (PTX/cubin → TRT)
```

---

## 5. Running tests

Unit tests are CPU-only (no GPU required) and live in
`tests/py/annotation/unit/`.

```bash
# From inside the dev Docker container:
python -m pytest tests/py/annotation/unit/ -n 4 --tb=short -v
```

Test files:

| File | What it covers |
|------|---------------|
| `test_specs.py` | `TritonSpec`, `CuTileSpec`, `CuTeDSLSpec` construction and hashing |
| `test_specs_custom_plugin.py` | `CustomPluginSpec` and `custom_plugin()` factory |
| `test_signature_binder.py` | TRT signature derivation and binding |
| `test_layer_metadata.py` | `AnnotationMetadata` encode/decode round-trip |
| `test_plugin_lowering.py` | QDP plugin lowering path |
52 changes: 52 additions & 0 deletions py/torch_tensorrt/annotation/__init__.py
@@ -0,0 +1,52 @@
"""
Torch-TensorRT Annotation Layer (TTA) — custom_plugin API.

Provides descriptor types and factory functions for defining custom TensorRT
AOT QDP plugins backed by Triton, CuTile, or CuTeDSL kernels.

Usage::

import torch_tensorrt.annotation as tta

# Triton kernel descriptor
spec = tta.custom_plugin(
tta.triton(my_triton_kernel, configs=[{"BLOCK_SIZE": 128}]),
meta_impl=lambda x: x.new_empty(x.shape),
)

# CuTile kernel descriptor
spec = tta.custom_plugin(
tta.cutile(my_cutile_kernel),
meta_impl=lambda x: x.new_empty(x.shape),
)

# CuTeDSL kernel descriptor
spec = tta.custom_plugin(
tta.cutedsl(my_cutedsl_kernel),
meta_impl=lambda x: x.new_empty(x.shape),
)
"""

from ._specs import (
CuTeDSLSpec,
CuTileSpec,
TritonSpec,
cutedsl,
cutile,
triton,
)

from ._custom_plugin._descriptor import CustomPluginSpec, custom_plugin

__all__ = [
# Descriptor types
"TritonSpec",
"CuTileSpec",
"CuTeDSLSpec",
"CustomPluginSpec",
# Factory functions
"custom_plugin",
"triton",
"cutile",
"cutedsl",
]
79 changes: 79 additions & 0 deletions py/torch_tensorrt/annotation/_custom_plugin/__init__.py
@@ -0,0 +1,79 @@
"""Custom plugin sub-package: QDP-backed plugin descriptor and lowering.

This sub-package bridges user-supplied GPU kernels (Triton, cuTILE, CuTe DSL)
to TensorRT's Quickly Deployable Plugin (QDP) framework, enabling annotated
boundary ops to be lowered to first-class ``IPluginV3`` layers at TRT engine
compile time.

Role in the compilation pipeline
--------------------------------
1. **Annotation** — the user calls ``tta.custom_plugin(kernel, meta_impl=...)``
(re-exported here as ``custom_plugin``), which returns a
``CustomPluginSpec`` that is stored in the boundary op's
``AnnotationMetadata``.
2. **Registration** — at compile time, ``register_custom_plugin`` calls
``@trtp.register`` / ``@trtp.aot_impl`` on the descriptor, making the
plugin visible to TRT's global plugin registry.
3. **Lowering** — ``lower_custom_plugin_descriptor`` calls ``trtp.op.<ns>.<name>``
to insert an ``IPluginV3`` layer into the ``INetworkDefinition``.

Public surface
--------------
``CustomPluginSpec``
Dataclass returned by ``custom_plugin()``. Carries the op name, kernel
specs, meta-shape implementation, and optional tactic table.
``custom_plugin``
Factory that builds a ``CustomPluginSpec`` from a kernel spec, auto-
computing a deterministic QDP op name from the kernel fingerprint.
``lower_custom_plugin_descriptor``
Converts a ``CustomPluginSpec`` into a TRT ``IPluginV3`` layer and
returns the output ``trt.ITensor`` (or tuple thereof).
``register_custom_plugin``
Registers ``@trtp.register`` / ``@trtp.aot_impl`` handlers for a
descriptor's op name. Idempotent at the process level.
``QDPRuntimeError``
Raised when TRT's QDP framework encounters a runtime error (e.g. shape
mismatch, unsupported dtype) during plugin execution.
``TTAPluginError``
Raised for TTA-level plugin configuration errors (e.g. missing meta_impl,
invalid kernel spec).
``SymbolicTensor`` / ``TensorRole``
Proxy used during AOT kernel compilation to carry symbolic shape
expressions. ``TensorRole`` distinguishes input from output tensors so
that ``analyze_launch_args`` can reconstruct the correct QDP binding
indices.
"""

# ---------------------------------------------------------------------------
# Re-exports from sub-modules
# ---------------------------------------------------------------------------

from ._descriptor import (
CustomPluginSpec,
custom_plugin,
lower_custom_plugin_descriptor,
register_custom_plugin,
)
from ._qdp_utils import QDPRuntimeError, TTAPluginError
from ._symbolic import SymbolicTensor, TensorRole

# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------

__all__ = [
"CustomPluginSpec",
"custom_plugin",
"lower_custom_plugin_descriptor",
"register_custom_plugin",
"QDPRuntimeError",
"TTAPluginError",
"SymbolicTensor",
"TensorRole",
]
16 changes: 16 additions & 0 deletions py/torch_tensorrt/annotation/_custom_plugin/_aot/__init__.py
@@ -0,0 +1,16 @@
"""AOT kernel backend implementations for the TTA custom plugin system.

Each sub-module implements the ``aot_impl_<backend>`` function that drives a
backend-specific compile pipeline and returns the QDP AOT 4-tuple::

    (kernel_name, code_bytes, KernelLaunchParams, SymIntExprs)

Backends
--------
_triton  — Triton JIT kernels compiled to PTX via ``triton.compile()``
_cutile  — cuTILE programs compiled to CUBIN via ``cuda.tile`` ``compile_tile()``
_cutedsl — CuTe DSL ``@cute.jit`` kernels compiled via ``cutlass.cute.compile()``

All three backends share ``_qdp_utils`` helpers for sandboxing, launch-arg
analysis, artifact dumping, and the unified ``AOTMetadata`` result type.
"""