From 8c968b32ac13d64940918c8c9bc1e85e0cffed07 Mon Sep 17 00:00:00 2001 From: Bowen Fu Date: Fri, 10 Apr 2026 08:41:21 +0000 Subject: [PATCH] =?UTF-8?q?feat(annotation):=20add=20tta.custom=5Fplugin?= =?UTF-8?q?=20=E2=80=94=20Triton/CuTile/CuTeDSL=20QDP=20AOT=20integration?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduces the `torch_tensorrt.annotation` (tta) API for registering custom GPU kernels as TensorRT QDP plugins compiled ahead-of-time. Public API ---------- - `tta.triton(launch_fn, configs=..., input_formats=..., output_formats=...)` - `tta.cutile(launch_fn, configs=..., input_formats=..., output_formats=...)` - `tta.cutedsl(launch_fn, arch=..., configs=..., input_formats=..., output_formats=...)` - `tta.custom_plugin(spec_or_list, meta_impl=...)` → CustomPluginSpec Each spec is a frozen dataclass; `custom_plugin` validates inputs eagerly and derives a stable `op_name` (namespace `tta_custom::`) used for TRT plugin registration. Compilation pipeline -------------------- - Triton: sandboxed grid recording → `triton.compile` → PTX + KernelLaunchParams - CuTile: sandboxed arg recording → `cuda_tile.compile` → PTX - CuTeDSL: `@cute.jit` unwrap → sandboxed grid/block recording → `cute.compile` → PTX (temp dir via `TemporaryDirectory`; PTX string remains accessible after cleanup) Plugin generation ----------------- - `_compute_out_descs_symbolic`: ShapeEnv + lambdify for TRT shape descriptors; uses `output.ndim` (not first-input ndim) for correct multi-output support - `_format_token_for_tactic`: maps `trt.TensorFormat` to AutoTuneCombination tokens - Non-LINEAR format support (e.g. 
CHW32): TRT inserts reformatting nodes automatically Tests ----- - Integration suite (`tests/py/annotation/integration/`): 41 e2e tests covering Triton/CuTile/CuTeDSL backends, dynamic shapes, multi-output, cross-backend, non-LINEAR formats, BF16, 3D inputs, weights, and large shapes - Unit suite (`tests/py/annotation/unit/`): 10 tests for user-facing API validation contracts and op_name invariants not exercised by e2e - `conftest.py`: GPU routing (Blackwell vs pre-Blackwell), per-worker TRT timing cache, dynamic `-n` worker count for subprocess reruns --- .gitignore | 5 + py/torch_tensorrt/annotation/README.md | 163 ++ py/torch_tensorrt/annotation/__init__.py | 52 + .../annotation/_custom_plugin/__init__.py | 79 + .../_custom_plugin/_aot/__init__.py | 16 + .../_custom_plugin/_aot/_cutedsl.py | 266 +++ .../annotation/_custom_plugin/_aot/_cutile.py | 754 +++++++ .../annotation/_custom_plugin/_aot/_triton.py | 552 +++++ .../annotation/_custom_plugin/_descriptor.py | 951 +++++++++ .../annotation/_custom_plugin/_qdp_utils.py | 999 +++++++++ .../annotation/_custom_plugin/_symbolic.py | 266 +++ py/torch_tensorrt/annotation/_recorders.py | 207 ++ py/torch_tensorrt/annotation/_specs.py | 257 +++ .../dynamo/conversion/plugins/_custom_op.py | 143 +- .../conversion/plugins/_generate_plugin.py | 268 ++- .../plugins/_generate_plugin_converter.py | 4 +- tests/py/annotation/BUILD | 10 + tests/py/annotation/__init__.py | 1 + tests/py/annotation/conftest.py | 210 ++ tests/py/annotation/integration/__init__.py | 1 + .../test_custom_plugin_trt_plugins_e2e.py | 1797 +++++++++++++++++ tests/py/annotation/unit/__init__.py | 0 tests/py/annotation/unit/test_specs.py | 83 + 23 files changed, 7014 insertions(+), 70 deletions(-) create mode 100644 py/torch_tensorrt/annotation/README.md create mode 100644 py/torch_tensorrt/annotation/__init__.py create mode 100644 py/torch_tensorrt/annotation/_custom_plugin/__init__.py create mode 100644 
py/torch_tensorrt/annotation/_custom_plugin/_aot/__init__.py create mode 100644 py/torch_tensorrt/annotation/_custom_plugin/_aot/_cutedsl.py create mode 100644 py/torch_tensorrt/annotation/_custom_plugin/_aot/_cutile.py create mode 100644 py/torch_tensorrt/annotation/_custom_plugin/_aot/_triton.py create mode 100644 py/torch_tensorrt/annotation/_custom_plugin/_descriptor.py create mode 100644 py/torch_tensorrt/annotation/_custom_plugin/_qdp_utils.py create mode 100644 py/torch_tensorrt/annotation/_custom_plugin/_symbolic.py create mode 100644 py/torch_tensorrt/annotation/_recorders.py create mode 100644 py/torch_tensorrt/annotation/_specs.py create mode 100644 tests/py/annotation/BUILD create mode 100644 tests/py/annotation/__init__.py create mode 100644 tests/py/annotation/conftest.py create mode 100644 tests/py/annotation/integration/__init__.py create mode 100644 tests/py/annotation/integration/test_custom_plugin_trt_plugins_e2e.py create mode 100644 tests/py/annotation/unit/__init__.py create mode 100644 tests/py/annotation/unit/test_specs.py diff --git a/.gitignore b/.gitignore index f08d97d448..eeca5fb284 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,8 @@ +# Python virtual environments +.venv/ +venv/ +env/ + bazel bazel-bazel-test bazel-bin diff --git a/py/torch_tensorrt/annotation/README.md b/py/torch_tensorrt/annotation/README.md new file mode 100644 index 0000000000..36c2788664 --- /dev/null +++ b/py/torch_tensorrt/annotation/README.md @@ -0,0 +1,163 @@ +# torch_tensorrt.annotation — custom_plugin descriptors + +`torch_tensorrt.annotation` (aliased as `tta`) provides descriptor types and +factory functions for defining custom TensorRT AOT QDP plugins backed by +Triton, CuTile, or CuTeDSL kernels. + +```python +import torch_tensorrt.annotation as tta +``` + +This module is **descriptor-only**: it builds spec objects that describe how a +plugin should be compiled and registered. 
It does not patch `torch.export`, +add compilation hooks, or modify any torch-trt core path. + +--- + +## Table of contents + +1. [Quick start](#1-quick-start) +2. [Factory functions](#2-factory-functions) +3. [Spec types](#3-spec-types) +4. [QDP plugin flow](#4-qdp-plugin-flow) +5. [Running tests](#5-running-tests) + +--- + +## 1. Quick start + +```python +import torch_tensorrt.annotation as tta + +# Triton AOT plugin +spec = tta.custom_plugin(tta.triton(my_launch_fn, configs=[{"BLOCK_SIZE": 128}])) + +# CuTile plugin (Blackwell sm_100+) +spec = tta.custom_plugin(tta.cutile(my_cutile_kernel, arch=120)) + +# CuTeDSL plugin +spec = tta.custom_plugin(tta.cutedsl(my_cutedsl_kernel)) +``` + +--- + +## 2. Factory functions + +### `tta.triton(launch_fn, configs=None)` + +Wraps a Triton kernel launch function. + +- **`launch_fn`** — callable that launches the Triton kernel; + signature `(input0, ..., output, stream, **config)`. +- **`configs`** — list of `dict` tactic configs; each becomes a separate + QDP tactic. Pass `None` for a single default tactic. + +```python +@triton.jit +def _add_relu_kernel(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr): + i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK) + mask = i < n + tl.store(out_ptr + i, tl.maximum(0, tl.load(x_ptr+i, mask=mask) + tl.load(y_ptr+i, mask=mask)), mask=mask) + +def launch_add_relu(x, y, out, stream, BLOCK=256): + _add_relu_kernel[(triton.cdiv(x.numel(), BLOCK),)](x, y, out, x.numel(), BLOCK=BLOCK) + +spec = tta.custom_plugin(tta.triton(launch_add_relu, configs=[{"BLOCK": 128}, {"BLOCK": 256}])) +``` + +### `tta.cutile(launch_fn, arch=None, configs=None)` + +Wraps a CuTile (cuda-tile) kernel. Requires Blackwell (sm_100+) and the +`cuda-tile` package. + +- **`arch`** — SM architecture integer (e.g. `120` for sm_120). +- **`configs`** — list of tactic dicts. 
+ +```python +spec = tta.custom_plugin(tta.cutile(my_cutile_fn, arch=120, configs=[{"TILE_M": 64}])) +``` + +### `tta.cutedsl(launch_fn, configs=None)` + +Wraps a CuTeDSL kernel (`nvidia-cutlass-dsl`). + +```python +spec = tta.custom_plugin(tta.cutedsl(my_cutedsl_fn)) +``` + +### `tta.custom_plugin(impl)` + +Builds a `CustomPluginSpec` from a kernel spec (`TritonSpec`, `CuTileSpec`, or +`CuTeDSLSpec`). Computes a deterministic QDP `op_name` from the kernel +function identity and config hash. + +```python +spec = tta.custom_plugin(tta.triton(launch_fn, configs=[{"BLOCK": 256}])) +# spec.op_name — deterministic string like "tta::launch_fn_a3f2c1" +``` + +--- + +## 3. Spec types + +All spec types are plain frozen dataclasses — they carry no mutable state and +are safe to hash, compare, and cache. + +| Type | Returned by | Description | +|------|-------------|-------------| +| `CustomPluginSpec` | `custom_plugin()` | AOT QDP plugin descriptor; holds `impl` (`TritonSpec` \| `CuTileSpec` \| `CuTeDSLSpec`) and computed `op_name` | +| `TritonSpec` | `triton()` | Triton kernel launch function + tactic configs | +| `CuTileSpec` | `cutile()` | CuTile kernel + target arch + tactic configs | +| `CuTeDSLSpec` | `cutedsl()` | CuTeDSL kernel + tactic configs | + +--- + +## 4. QDP plugin flow + +`tta.custom_plugin` produces a descriptor. When you call +`register_custom_plugin(spec, inputs)` (from `_custom_plugin._descriptor`) the +module: + +1. Derives a deterministic `op_name` from the kernel function + config hash. +2. Registers `@trtp.register("tta::op_name")` — the shape/dtype descriptor + function derived symbolically from the kernel signature. +3. Registers `@trtp.aot_impl("tta::op_name")` — the AOT implementation + function that returns `(kernel_name, ptx_or_cubin, KernelLaunchParams, + SymIntExprs)`. +4. Uses a process-level lock + double-checked locking to prevent duplicate + registration in multi-threaded pytest-xdist workers. 
+ +The QDP AOT path means **no Python is needed at TRT engine runtime** — the +compiled kernel is embedded directly. + +``` +tta.triton(launch_fn, configs) + └─► TritonSpec + └─► tta.custom_plugin(spec) + └─► CustomPluginSpec(op_name, impl) + └─► register_custom_plugin(spec, inputs) + ├─► @trtp.register (shape descriptor) + └─► @trtp.aot_impl (PTX/cubin → TRT) +``` + +--- + +## 5. Running tests + +Unit tests are CPU-only (no GPU required) and live in +`tests/py/annotation/unit/`. + +```bash +# From inside the dev Docker container: +python -m pytest tests/py/annotation/unit/ -n 4 --tb=short -v +``` + +Test files: + +| File | What it covers | +|------|---------------| +| `test_specs.py` | `TritonSpec`, `CuTileSpec`, `CuTeDSLSpec` construction and hashing | +| `test_specs_custom_plugin.py` | `CustomPluginSpec` and `custom_plugin()` factory | +| `test_signature_binder.py` | TRT signature derivation and binding | +| `test_layer_metadata.py` | `AnnotationMetadata` encode/decode round-trip | +| `test_plugin_lowering.py` | QDP plugin lowering path | diff --git a/py/torch_tensorrt/annotation/__init__.py b/py/torch_tensorrt/annotation/__init__.py new file mode 100644 index 0000000000..a307b0074f --- /dev/null +++ b/py/torch_tensorrt/annotation/__init__.py @@ -0,0 +1,52 @@ +""" +Torch-TensorRT Annotation Layer (TTA) — custom_plugin API. + +Provides descriptor types and factory functions for defining custom TensorRT +AOT QDP plugins backed by Triton, CuTile, or CuTeDSL kernels. 
+ +Usage:: + + import torch_tensorrt.annotation as tta + + # Triton kernel descriptor + spec = tta.custom_plugin( + tta.triton(my_triton_kernel, configs=[{"BLOCK_SIZE": 128}]), + meta_impl=lambda x: x.new_empty(x.shape), + ) + + # CuTile kernel descriptor + spec = tta.custom_plugin( + tta.cutile(my_cutile_kernel), + meta_impl=lambda x: x.new_empty(x.shape), + ) + + # CuTeDSL kernel descriptor + spec = tta.custom_plugin( + tta.cutedsl(my_cutedsl_kernel), + meta_impl=lambda x: x.new_empty(x.shape), + ) +""" + +from ._specs import ( + CuTeDSLSpec, + CuTileSpec, + TritonSpec, + cutedsl, + cutile, + triton, +) + +from ._custom_plugin._descriptor import CustomPluginSpec, custom_plugin + +__all__ = [ + # Descriptor types + "TritonSpec", + "CuTileSpec", + "CuTeDSLSpec", + "CustomPluginSpec", + # Factory functions + "custom_plugin", + "triton", + "cutile", + "cutedsl", +] diff --git a/py/torch_tensorrt/annotation/_custom_plugin/__init__.py b/py/torch_tensorrt/annotation/_custom_plugin/__init__.py new file mode 100644 index 0000000000..ce67ac7737 --- /dev/null +++ b/py/torch_tensorrt/annotation/_custom_plugin/__init__.py @@ -0,0 +1,79 @@ +"""Custom plugin sub-package: QDP-backed plugin descriptor and lowering. + +This sub-package bridges user-supplied GPU kernels (Triton, cuTILE, CuTe DSL) +to TensorRT's Quickstart Dynamic Plugin (QDP) framework, enabling annotated +boundary ops to be lowered to first-class ``IPluginV3`` layers at TRT engine +compile time. + +Role in the compilation pipeline +--------------------------------- +1. **Annotation** — the user calls ``tta.custom_plugin(kernel, meta_impl=...)`` + (re-exported here as ``custom_plugin``), which returns a + ``CustomPluginSpec`` that is stored in the boundary op's + ``AnnotationMetadata``. +2. **Registration** — at compile time, ``register_custom_plugin`` calls + ``@trtp.register`` / ``@trtp.aot_impl`` on the descriptor, making the + plugin visible to TRT's global plugin registry. +3. 
**Lowering** — ``lower_custom_plugin_descriptor`` calls ``trtp.op..`` + to insert an ``IPluginV3`` layer into the ``INetworkDefinition``. + +Public surface +-------------- +``CustomPluginSpec`` + Dataclass returned by ``custom_plugin()``. Carries the op name, kernel + specs, meta-shape implementation, and optional tactic table. + +``custom_plugin`` + Factory that builds a ``CustomPluginSpec`` from a kernel spec, auto- + computing a deterministic QDP op name from the kernel fingerprint. + +``lower_custom_plugin_descriptor`` + Converts a ``CustomPluginSpec`` into a TRT ``IPluginV3`` layer and + returns the output ``trt.ITensor`` (or tuple thereof). + +``register_custom_plugin`` + Registers ``@trtp.register`` / ``@trtp.aot_impl`` handlers for a + descriptor's op name. Idempotent at the process level. + +``QDPRuntimeError`` + Raised when TRT's QDP framework encounters a runtime error (e.g. shape + mismatch, unsupported dtype) during plugin execution. + +``TTAPluginError`` + Raised for TTA-level plugin configuration errors (e.g. missing meta_impl, + invalid kernel spec). + +``SymbolicTensor`` / ``TensorRole`` + Proxy used during AOT kernel compilation to carry symbolic shape + expressions. ``TensorRole`` distinguishes input from output tensors so + that ``analyze_launch_args`` can reconstruct the correct QDP binding + indices. 
+""" + +# --------------------------------------------------------------------------- +# Re-exports from sub-modules +# --------------------------------------------------------------------------- + +from ._descriptor import ( + CustomPluginSpec, + custom_plugin, + lower_custom_plugin_descriptor, + register_custom_plugin, +) +from ._qdp_utils import QDPRuntimeError, TTAPluginError +from ._symbolic import SymbolicTensor, TensorRole + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + +__all__ = [ + "CustomPluginSpec", + "custom_plugin", + "lower_custom_plugin_descriptor", + "register_custom_plugin", + "QDPRuntimeError", + "TTAPluginError", + "SymbolicTensor", + "TensorRole", +] diff --git a/py/torch_tensorrt/annotation/_custom_plugin/_aot/__init__.py b/py/torch_tensorrt/annotation/_custom_plugin/_aot/__init__.py new file mode 100644 index 0000000000..ffc41fb0f8 --- /dev/null +++ b/py/torch_tensorrt/annotation/_custom_plugin/_aot/__init__.py @@ -0,0 +1,16 @@ +"""AOT kernel backend implementations for the TTA custom plugin system. + +Each sub-module implements the ``aot_impl_`` function that drives a +backend-specific compile pipeline and returns the QDP AOT 4-tuple:: + + (kernel_name, code_bytes, KernelLaunchParams, SymIntExprs) + +Backends +-------- +_triton — Triton JIT kernels compiled to PTX via triton.compile() +_cutile — cuTILE programs compiled to CUBIN via cuda.tile compile_tile() +_cutedsl — CuTe DSL @cute.jit kernels compiled via cutlass.cute.compile() + +All three backends share ``_qdp_utils`` helpers for sandboxing, launch-arg +analysis, artifact dumping, and the unified ``AOTMetadata`` result type. 
+""" diff --git a/py/torch_tensorrt/annotation/_custom_plugin/_aot/_cutedsl.py b/py/torch_tensorrt/annotation/_custom_plugin/_aot/_cutedsl.py new file mode 100644 index 0000000000..9f4d7471ad --- /dev/null +++ b/py/torch_tensorrt/annotation/_custom_plugin/_aot/_cutedsl.py @@ -0,0 +1,266 @@ +"""CuTe DSL AOT backend: @cute.jit sandbox → cute.compile → PTX → KernelLaunchParams. + +AOT compilation pipeline for CuTe DSL kernels +=============================================== +1. **Grid sandbox** — The user's ``launch_fn`` (a ``@cute.jit``-decorated function) + is unwrapped to its raw Python body via ``_get_jit_raw_fn``. The raw body is + executed in a sandbox with ``SymbolicTensor`` proxies and ``CuTeDSLKernelRecorder`` + objects injected over ``@cute.kernel`` callables in the module globals. This + captures the symbolic grid expression (``recorded_grid``) without running on a GPU. + If the sandbox fails (``strict=False``), the grid falls back to ``(1, 1, 1)`` and + TRT launches with a single CTA. + +2. **CUDA tensor construction** — Dummy ``torch.zeros`` tensors matching the shapes + and dtypes described by the ``inp_descs`` / ``out_descs`` TensorDescs are placed + on the GPU. They are converted to CuTe tensors via ``from_dlpack`` for the + ``cute.compile`` call. + +3. **AOT compilation** — ``cute.compile(launch_fn, *cute_tensors, options=..., **cfg)`` + compiles the kernel with config constants baked in as compile-time values. The + ``--dump-dir`` / ``--keep-ptx`` options cause PTX to be written to a temp directory + *and* attached to ``compiled.artifacts.PTX``. + +4. **PTX extraction** — PTX is read from ``compiled.artifacts.PTX``. The kernel name + is taken from ``compiled.kernel_info`` (first key). + +5. **KernelLaunchParams construction** — The recorded symbolic grid and block dimensions + are stored in a ``trtp.KernelLaunchParams`` so TRT evaluates grid size at runtime. + ``extra_args`` is empty (CuTe DSL encodes shape information differently from Triton). 
+ +The public entry-point is ``aot_impl_cutedsl``, which returns the QDP AOT 4-tuple +``(kernel_name, ptx_bytes, KernelLaunchParams, SymIntExprs)``. +``compile_cutedsl_kernel`` is a thin tactic-manager wrapper. +""" +from __future__ import annotations + +import tempfile +from typing import Any, Dict, List, Mapping, Optional, Tuple + +try: + import tensorrt as trt + import tensorrt.plugin as trtp +except ImportError as e: + raise ImportError( + "TensorRT with plugin support is required for CuTe DSL AOT compilation." + ) from e + +import torch + +from .._qdp_utils import ( + AOTMetadata, + TTAPluginError, + _assign_recorded_grid, + _launch_params_from_trt, + _safe_dim, + is_cute_kernel, + run_kernel_sandbox, + td_dtype_to_torch, +) +from ..._recorders import CuTeDSLKernelRecorder +from ..._specs import CuTeDSLSpec +from .._symbolic import SymbolicTensor, TensorRole + + +def _get_jit_raw_fn(fn: Any) -> Optional[Any]: + """Try to extract the raw Python function from a @cute.jit decorated object. + + @cute.jit may preserve the original function via __wrapped__ (functools.wraps) + or a backend-specific attribute. Check __wrapped__ BEFORE __code__ because + @cute.jit objects may have __code__ on the wrapper (calling it triggers JIT + compilation), while __wrapped__ points to the raw Python body. + """ + for attr in ("__wrapped__", "_fn", "fn", "_func", "func"): + raw = getattr(fn, attr, None) + if raw is not None and raw is not fn and callable(raw) and hasattr(raw, "__code__"): + return raw + if hasattr(fn, "__code__"): + return fn + return None + + +def aot_impl_cutedsl( + *, + qdp_symbol: str, + spec: CuTeDSLSpec, + cfg: Mapping[str, Any], + launch_fn: Any, + host_args: List[SymbolicTensor], + inp_descs: List[Any], + out_descs: List[Any], + attrs: Optional[Dict[str, Any]] = None, +) -> Tuple[str, bytes, Any, Any]: + """CuTe DSL AOT implementation: sandbox → record grid → compile → extract PTX. + + Steps: + 1. Find @cute.kernel objects in launch_fn's module. + 2. 
Run sandboxed launch_fn body with SymbolicTensor proxies to record the grid. + 3. Create dummy GPU tensors matching inp/out TensorDescs. + 4. Convert to CuTe tensors via from_dlpack. + 5. Call cute.compile(launch_fn, *cute_tensors) to get compiled object. + 6. Extract PTX from compiled.artifacts.PTX. + 7. Get kernel name from compiled.kernel_info. + 8. Return (kernel_name, ptx, KernelLaunchParams, SymIntExprs). + + Returns: + (kernel_name_bytes, ptx_bytes, KernelLaunchParams, SymIntExprs) + """ + backend = "cutedsl" + + try: + import cutlass.cute as cute + from cutlass.cute.runtime import from_dlpack + except ImportError as exc: + raise TTAPluginError( + op=qdp_symbol, + stage="aot_impl", + backend=backend, + msg=f"cutlass.cute not available: {exc}", + ) from exc + + # 1. Sandbox run: find @cute.kernel objects and record the symbolic grid. + # strict=False: if @cute.jit doesn't expose __code__ or the sandbox fails, + # we fall back to a (1, 1, 1) grid rather than raising. + recorded_grid: Optional[Tuple[Any, ...]] = None + recorded_block: Optional[Tuple[Any, ...]] = None + + raw_fn = _get_jit_raw_fn(launch_fn) + used_recorder, _ = run_kernel_sandbox( + launch_fn=launch_fn, + host_args=host_args, + is_kernel_fn=is_cute_kernel, + recorder_factory=lambda obj: CuTeDSLKernelRecorder(real_kernel=obj), + raw_fn=raw_fn, + host_kwargs=dict(cfg, **(attrs or {})) if cfg or attrs else None, + strict=False, + ) + if used_recorder is not None: + recorded_grid = used_recorder.grid + recorded_block = used_recorder.block + + # 2. Create dummy GPU tensors for cute.compile. + all_descs = list(inp_descs) + list(out_descs) + dummy_tensors = [] + for td in all_descs: + shape = [_safe_dim(d) for d in td.shape_expr] + torch_dtype = td_dtype_to_torch(td.dtype) + dummy_tensors.append(torch.zeros(shape, dtype=torch_dtype, device="cuda")) + + cute_tensors = [from_dlpack(t) for t in dummy_tensors] + + # 3. Compile. + # TemporaryDirectory cleans up the dump files on exit, even on exception. 
+ # compiled.artifacts.PTX is an in-memory string, so it remains accessible + # after the context manager exits. + with tempfile.TemporaryDirectory(prefix="cutedsl_aot_") as dump_dir: + compile_opts = f"--dump-dir={dump_dir} --keep-cubin --keep-ptx" + # Pass cfg as kwargs so config constants (e.g. BLOCK_SIZE) are compile-time + # values in the PTX, enabling nvvm.reqntid and correct per-tactic block sizes. + compile_kwargs = dict(cfg) if cfg else {} + # cute.compile can raise RuntimeError, subprocess.CalledProcessError, or + # arbitrary Exception subclasses from NVCC/NVVM; a broad catch is necessary + # to wrap all failures in a structured diagnostic. + try: + compiled = cute.compile( + launch_fn, *cute_tensors, options=compile_opts, **compile_kwargs + ) + except Exception as exc: + raise TTAPluginError( + op=qdp_symbol, + stage="aot_impl", + backend=backend, + msg=f"cute.compile failed for op '{qdp_symbol}': {exc}", + ) from exc + + if not hasattr(compiled, "artifacts") or not hasattr(compiled.artifacts, "PTX"): + raise TTAPluginError( + op=qdp_symbol, + stage="aot_impl", + backend=backend, + msg=( + f"compiled object (type {type(compiled).__name__}) has no " + f"artifacts.PTX attribute" + ), + ) + + ptx_str = compiled.artifacts.PTX + if not ptx_str or ".entry" not in ptx_str: + raise TTAPluginError( + op=qdp_symbol, + stage="aot_impl", + backend=backend, + msg="compiled PTX is empty or has no .entry kernel", + ) + + kernel_names = list(compiled.kernel_info.keys()) + if not kernel_names: + raise TTAPluginError( + op=qdp_symbol, + stage="aot_impl", + backend=backend, + msg="compiled object has no kernel names in kernel_info", + ) + kernel_name_str = kernel_names[0] + + # 4. Build KernelLaunchParams using the recorded (symbolic) grid. 
+ launch = trtp.KernelLaunchParams() + _assign_recorded_grid(launch, recorded_grid) + + if recorded_block is not None and len(recorded_block) >= 1: + launch.block_x = recorded_block[0] if isinstance(recorded_block[0], int) else 1 + launch.block_y = recorded_block[1] if len(recorded_block) > 1 and isinstance(recorded_block[1], int) else 1 + launch.block_z = recorded_block[2] if len(recorded_block) > 2 and isinstance(recorded_block[2], int) else 1 + else: + launch.block_x = 1 + launch.block_y = 1 + launch.block_z = 1 + launch.shared_mem = 0 + + extra_args = trtp.SymIntExprs(0) + + ptx_bytes = ptx_str.encode("utf-8") + return kernel_name_str, ptx_bytes, launch, extra_args + + +def compile_cutedsl_kernel( + spec: CuTeDSLSpec, config: Dict[str, Any] +) -> AOTMetadata: + """Compile a CuTeDSLSpec into unified AOTMetadata. + + This is the tactic-manager entry-point. It constructs synthetic 1-input / + 1-output TensorDesc stubs (shape [256], float32) to drive the sandbox run, + delegates to ``aot_impl_cutedsl`` for the full pipeline, and wraps the result + in ``AOTMetadata``. + + Args: + spec: CuTeDSLSpec carrying ``launch_fn`` (a ``@cute.jit`` function), + optional ``configs``, etc. + config: Single tactic configuration dict (e.g. ``{"BLOCK_SIZE": 128}``). + + Returns: + AOTMetadata with ``backend="cutedsl"``, compiled PTX bytes, and launch params. 
+ """ + inp_descs = [trtp.TensorDesc(dtype=trt.float32, shape_expr=[256])] + out_descs = [trtp.TensorDesc(dtype=trt.float32, shape_expr=[256])] + sym_inputs = [ + SymbolicTensor(td=inp_descs[0], role=TensorRole.INPUT, index=0) + ] + sym_outputs = [ + SymbolicTensor(td=out_descs[0], role=TensorRole.OUTPUT, index=0) + ] + host_args = sym_inputs + sym_outputs + kernel_name_bytes, ptx_bytes, launch, extra_args = aot_impl_cutedsl( + qdp_symbol="cutedsl_compile", + spec=spec, + cfg=config, + launch_fn=spec.launch_fn, + host_args=host_args, + inp_descs=inp_descs, + out_descs=out_descs, + ) + kernel_name = ( + kernel_name_bytes.decode("utf-8") + if isinstance(kernel_name_bytes, bytes) + else kernel_name_bytes + ) + launch_params = _launch_params_from_trt(launch, extra_args, num_inputs=1, num_outputs=1) + return AOTMetadata(binary=ptx_bytes, kernel_name=kernel_name, launch_params=launch_params, backend="cutedsl") diff --git a/py/torch_tensorrt/annotation/_custom_plugin/_aot/_cutile.py b/py/torch_tensorrt/annotation/_custom_plugin/_aot/_cutile.py new file mode 100644 index 0000000000..a780a5f098 --- /dev/null +++ b/py/torch_tensorrt/annotation/_custom_plugin/_aot/_cutile.py @@ -0,0 +1,754 @@ +"""cuTILE AOT backend: source → sandbox recording → CUBIN/PTX → KernelLaunchParams. + +AOT compilation pipeline for cuTILE kernels +============================================ +1. **Sandbox execution** — A patched ``cuda.tile.launch`` shim is injected alongside + ``CuTileLaunchRecorder`` proxies for each cuTILE program object found in the + user's module globals. ``launch_fn`` is called with ``SymbolicTensor`` arguments + so that the recorded grid and tensor bindings are captured symbolically. + +2. **Argument analysis** — ``analyze_launch_args`` splits recorded call arguments into + *pointer binding indices* and *scalar SymInt32 expressions* that TRT evaluates at + runtime. + +3. 
**CUBIN compilation** — ``compile_tile(pyfunc, example_tensors, CompilerOptions())`` + from ``cuda.tile._compile`` produces a CUBIN file containing an embedded PTX + debug section (``.nv_debug_ptx_txt``). + +4. **PTX extraction and parameter reordering** — ``_extract_ptx_from_cubin`` recovers + the PTX from the CUBIN. The cuTILE kernel ABI uses per-tensor groups of + ``(ptr, extents..., strides...)`` parameters in manual (kernel-source) order. + The QDP runtime passes ``(input_ptrs, extra_args, output_ptrs)`` in a different + order. ``_reorder_cutile_ptx_for_trt`` reorders the ``.param`` declarations to + match the runtime expectation and downgrades the ``.version`` directive if the + CUDA driver is older than the PTX version emitted by ``tileiras``. + +5. **Tactic uniquification** — Identical to the Triton backend: a config-derived + suffix is appended to the kernel name so TRT registers separate PTX/cubin entries + per tactic. + +The public entry-point is ``aot_impl_cutile``, which returns the QDP AOT 4-tuple +``(kernel_name_bytes, code_bytes, KernelLaunchParams, SymIntExprs)``. +``compile_cutile_program`` is a thin tactic-manager wrapper. +""" +from __future__ import annotations + +import logging +import os +import re +from typing import Any, Dict, List, Mapping, Optional, Tuple + +logger = logging.getLogger(__name__) + +try: + import tensorrt as trt + import tensorrt.plugin as trtp +except ImportError as e: + raise ImportError( + "TensorRT with plugin support is required for cuTILE AOT compilation." 
+ ) from e + +from .._qdp_utils import ( + AOTMetadata, + TTAPluginError, + _as_symint32, + _assign_recorded_grid, + _launch_params_from_trt, + _safe_dim, + _shape_expr_to_ints, + _sym_dim, + analyze_launch_args, + dump_code_artifact, + is_cutile_program, + run_kernel_sandbox, + td_dtype_to_torch, +) +from ..._recorders import CuTileLaunchRecorder +from ..._specs import CuTileSpec +from .._symbolic import SymbolicTensor, TensorRole + + +def _params_per_tensor_for_rank(rank: int) -> int: + """CuTile kernel ABI: ptr + extent per dim + stride per dim = 1 + 2*rank.""" + return 1 + 2 * max(1, rank) + + +def _cutile_trt_order_for_io( + num_inputs: int, + num_outputs: int, + num_scalars: int = 1, + rank: int = 1, +) -> Tuple[int, Tuple[int, ...]]: + """Compute (num_params, runtime_order) for CuTile kernel. + + Per-tensor params: ptr (1) + extent per dim (rank) + stride per dim (rank) = 1 + 2*rank. + Manual (kernel) order: for each input [ptr, ext..., str...]; for each output [ptr, ext..., str...]; scalars. + QDP runtime passes: input_ptrs, extra_args (extent/stride + scalars), output_ptrs. + Returns permutation so that reordered[physical_idx] = manual_param_index. 
+ """ + n_in = num_inputs + n_out = num_outputs + ppt = _params_per_tensor_for_rank(rank) + num_params = (n_in + n_out) * ppt + num_scalars + manual_ptr_in = [i * ppt for i in range(n_in)] + manual_ptr_out = [n_in * ppt + j * ppt for j in range(n_out)] + extent_stride_indices = list(range(1, ppt)) + manual_extent_stride_in = [ + i * ppt + k for i in range(n_in) for k in extent_stride_indices + ] + manual_extent_stride_out = [ + n_in * ppt + j * ppt + k + for j in range(n_out) + for k in extent_stride_indices + ] + manual_scalars = [(n_in + n_out) * ppt + s for s in range(num_scalars)] + trt_order = ( + list(manual_ptr_in) + + manual_extent_stride_in + + manual_extent_stride_out + + manual_scalars + + list(manual_ptr_out) + ) + return num_params, tuple(trt_order) + + +_ELF_MIN_SIZE = 64 + +_PTX_ENTRY_PARAM_RE = re.compile( + r"(\.visible\s+\.entry\s+(\w+)\s*\()([^)]*)\)", + re.DOTALL, +) + +_PTX_VERSION_RE = re.compile(r"\.version\s+(\d+)\.(\d+)") +_PTX_REQNTID_RE = re.compile(r"\.reqntid\s+(\d+)") +_PTX_MAX_SUPPORTED_VERSION = (9, 0) + + +def _extract_ptx_from_cubin(cubin_bytes: bytes) -> Optional[str]: + """Extract embedded PTX from CUBIN (.nv_debug_ptx_txt section). + + The CuTile compiler embeds PTX as null-separated strings in this debug section. + We scan for '.version' to find the start, then track brace depth to find the + matching closing '}' of the kernel function body. A simple find(b'}') is wrong + because PTX kernels that use vector-register syntax (e.g. mov.b64 {%r1, %r2}) + contain inline '}' characters before the actual function-closing '}'. + Returns formatted PTX string or None. + """ + if len(cubin_bytes) < _ELF_MIN_SIZE or cubin_bytes[:4] != b"\x7fELF": + return None + idx = cubin_bytes.find(b".version") + if idx < 0: + return None + # Find the opening '{' of the kernel function body (the first '{' after .version). 
+ open_brace = cubin_bytes.find(b"{", idx) + if open_brace < 0: + return None + # Track brace depth to find the matching closing '}', handling inline '{...}' + # pairs inside instructions (e.g. mov.b64 {%r1, %r2}, %rd0). + depth = 0 + end = -1 + scan_limit = min(open_brace + 500_000, len(cubin_bytes)) + for i in range(open_brace, scan_limit): + c = cubin_bytes[i] + if c == ord("{"): + depth += 1 + elif c == ord("}"): + depth -= 1 + if depth == 0: + end = i + break + if end < 0: + return None + raw = cubin_bytes[idx : end + 1] + ptx = raw.replace(b"\x00", b"\n").decode("utf-8", errors="replace") + lines = [l for l in ptx.splitlines() if l.strip()] + return "\n".join(lines) + "\n" + + +# LIMITATION (short-term workaround): PTX version downgrade is a best-effort +# text substitution. CUDA 13.1 tileiras emits .version 9.1 but the driver in +# our container (CUDA 13.0) only JIT-compiles up to PTX 9.0. Lowering the +# declared version by editing the .version line is safe only if the PTX body +# uses no ISA features introduced after the target version. This will silently +# produce incorrect PTX if a future tileiras version emits instructions that +# require a higher PTX version than the declared one after downgrade. +# Long-term fix: align the driver and CUDA toolkit versions so no downgrade is +# needed. +def _downgrade_ptx_version(ptx: str) -> str: + """Downgrade .version directive if it exceeds what the runtime driver supports. + + CUDA 13.1 tileiras emits .version 9.1 but a CUDA 13.0 driver only JIT-compiles + up to 9.0. Lowering the declared version is safe as long as the PTX body doesn't + use features introduced after the target version (CuTile-generated PTX doesn't). 
+ """ + m = _PTX_VERSION_RE.search(ptx) + if not m: + return ptx + major, minor = int(m.group(1)), int(m.group(2)) + max_major, max_minor = _PTX_MAX_SUPPORTED_VERSION + if (major, minor) <= (max_major, max_minor): + return ptx + replacement = f".version {max_major}.{max_minor}" + return ptx[: m.start()] + replacement + ptx[m.end() :] + + +def _parse_reqntid(ptx: str) -> Optional[int]: + """Extract the .reqntid value (required threads per CTA) from PTX. + + CuTile may vectorise (e.g. f32x2) and use fewer threads than the tile size. + The kernel MUST be launched with exactly this many threads. + """ + m = _PTX_REQNTID_RE.search(ptx) + return int(m.group(1)) if m else None + + +def _count_ptx_params(ptx: str) -> int: + """Count the number of .param arguments in the PTX .entry declaration.""" + m = _PTX_ENTRY_PARAM_RE.search(ptx) + if not m: + return 0 + param_block = m.group(3) + params_raw = [p.strip().rstrip(",") for p in param_block.split(",") if p.strip()] + return len(params_raw) + + +def _reorder_ptx_params(ptx: str, runtime_order: Tuple[int, ...]) -> str: + """Reorder .param declarations in PTX entry to match runtime_order permutation. + + runtime_order[i] = manual_param_index: physical slot i holds manual param runtime_order[i]. + So the new declaration order is: for each i in 0..n-1, place the param that was originally at index runtime_order[i]. 
+ """ + m = _PTX_ENTRY_PARAM_RE.search(ptx) + if not m: + return ptx + prefix = m.group(1) + param_block = m.group(3) + params_raw = [p.strip().rstrip(",") for p in param_block.split(",") if p.strip()] + if len(params_raw) != len(runtime_order): + return ptx + reordered = [params_raw[runtime_order[i]] for i in range(len(runtime_order))] + new_param_block = ",\n ".join(reordered) + new_entry = prefix + "\n " + new_param_block + "\n)" + return ptx[: m.start()] + new_entry + ptx[m.end() :] + + +# LIMITATION (fragile PTX rewrite): _reorder_cutile_ptx_for_trt extracts PTX +# embedded in the CuTile CUBIN via a regex on the raw ELF bytes, then rewrites +# .param declarations to match the TRT plugin runtime argument order (inputs, +# scalars, outputs). This is brittle for several reasons: +# 1. The PTX extraction scans raw ELF bytes for the PTX section header — it +# can silently return wrong bytes if the CUBIN internal layout changes. +# 2. The .param reorder uses a regex on the textual PTX entry signature; any +# change to how tileiras serialises the .param block (whitespace, line +# breaks, inline comments) will cause the reorder to silently no-op and +# return the original CUBIN, producing incorrect kernel argument binding. +# 3. If tileiras starts emitting multiple .entry kernels per CUBIN, the regex +# matches only the first one. +# The proper fix is for tileiras / TRT CuTile integration to agree on a +# canonical argument order without requiring a post-processing rewrite. +def _reorder_cutile_ptx_for_trt( + cubin_bytes: bytes, + kernel_name: str, + num_inputs: int = 1, + num_outputs: int = 1, + num_scalars: int = 1, + rank: int = 1, +) -> bytes: + """Extract PTX from CuTile CUBIN and reorder .param declarations to CuTile plugin runtime order. + + Returns reordered PTX as bytes (accepted by runtime and JIT-compiled to the + right arch), or original cubin_bytes unchanged on failure. 
+ """ + if len(cubin_bytes) < _ELF_MIN_SIZE or cubin_bytes[:4] != b"\x7fELF": + return cubin_bytes + ptx = _extract_ptx_from_cubin(cubin_bytes) + if not ptx: + return cubin_bytes + expected_n, runtime_order = _cutile_trt_order_for_io( + num_inputs, num_outputs, num_scalars, rank=rank + ) + if expected_n < 1: + return cubin_bytes + reordered_ptx = _reorder_ptx_params(ptx, runtime_order) + if reordered_ptx == ptx: + return cubin_bytes + reordered_ptx = _downgrade_ptx_version(reordered_ptx) + return reordered_ptx.encode("utf-8") + + +def _infer_cutile_extra_args( + inp_descs: List[Any], + out_descs: List[Any], + block_size: int, + scalar_symints: Optional[List[Any]] = None, + rank: int = 1, +) -> Any: + """Build SymIntExprs: (extent per dim, stride per dim) per tensor then scalar(s). + + Uses symbolic SymInt32 expressions for dynamic dimensions so that TRT can + evaluate them at runtime with actual input shapes. + """ + num_scalars = 1 + if scalar_symints: + num_scalars = len(scalar_symints) + n_tensors = len(inp_descs) + len(out_descs) + r = max(1, rank) + n = n_tensors * 2 * r + num_scalars + extra_args = trtp.SymIntExprs(n) + idx = 0 + for td in list(inp_descs) + list(out_descs): + try: + shape = list(td.shape_expr) # list of int or SymInt32 + if r == 1: + # numel = symbolic product of all dims; wrap result back to SymInt32 + numel: Any = trtp.SymInt32(1) + for d in shape: + numel = numel * _sym_dim(d) + extra_args[idx] = _as_symint32(numel) + extra_args[idx + 1] = trtp.SymInt32(1) + else: + use_shape = shape[:r] if len(shape) >= r else (shape + [1] * (r - len(shape))) + # extents (symbolic) + for i in range(r): + d = use_shape[i] if i < len(use_shape) else 1 + extra_args[idx + i] = _as_symint32(_sym_dim(d)) + # row-major strides (symbolic products) + for i in range(r): + s: Any = trtp.SymInt32(1) + for j in range(i + 1, len(use_shape)): + s = s * _sym_dim(use_shape[j] if j < len(use_shape) else 1) + extra_args[idx + r + i] = _as_symint32(s) + except 
(AttributeError, TypeError, RuntimeError): + # Broad shape_expr access may fail for mock/stub TensorDescs used in + # tests; fall back to SymInt32(1) for all extent/stride slots. + for i in range(2 * r): + extra_args[idx + i] = trtp.SymInt32(1) + idx += 2 * r + for i in range(num_scalars): + if scalar_symints and i < len(scalar_symints): + extra_args[idx + i] = _as_symint32(scalar_symints[i]) + else: + extra_args[idx + i] = trtp.SymInt32(block_size) + return extra_args + + +def aot_impl_cutile( + *, + qdp_symbol: str, + spec: CuTileSpec, + cfg: Mapping[str, Any], + launch_fn: Any, + host_args: List[SymbolicTensor], + inp_descs: List[Any], + out_descs: List[Any], + attrs: Optional[Dict[str, Any]] = None, +) -> Tuple[bytes, bytes, Any, Any]: + """cuTILE AOT implementation: sandbox → record → compile → CUBIN. + + Steps: + 1. Find cuTILE program objects in launch_fn's module. + 2. Replace with CuTileLaunchRecorder proxies. + 3. Run sandboxed launch_fn(*host_args, **merged_kwargs). + 4. Analyse recorded args → param_binding_indices + scalar SymInts. + 5. Compile to CUBIN using cuTILE internal APIs. + 6. Return (kernel_name, cubin, KernelLaunchParams, SymIntExprs). + + Returns: + (kernel_name_bytes, cubin_bytes, KernelLaunchParams, SymIntExprs) + """ + import torch + + backend = "cutile" + + try: + import cuda.tile # type: ignore[import] # noqa: F401 + except ImportError as exc: + raise TTAPluginError( + op=qdp_symbol, + stage="aot_impl", + backend=backend, + msg=f"cuTILE module (cuda.tile) not available for op '{qdp_symbol}': {exc}", + ) from exc + + # 1-2. Locate module, find cuTILE programs, sandbox and run. + # Build a patched ct.launch so that ct.launch(stream, grid, prog, args) + # correctly sets recorder.grid on a CuTileLaunchRecorder proxy. + # strict=True: raise on module-not-found and no-program-found. + import cuda.tile as _ct # type: ignore[import] + import types as _types + + # Sentinel list: closure trick to capture the recorder after creation. 
+ _recorder_ref: List[Any] = [] + + def _sandbox_launch(stream, grid, kernel, kernel_args): + if isinstance(kernel, CuTileLaunchRecorder): + kernel.grid = grid + kernel(*kernel_args) + else: + _ct.launch(stream, grid, kernel, kernel_args) + + _ct_sandbox = _types.ModuleType("cuda.tile.sandbox") + for _k, _v in vars(_ct).items(): + setattr(_ct_sandbox, _k, _v) + _ct_sandbox.launch = _sandbox_launch + + merged_kwargs = dict(cfg) + if attrs: + merged_kwargs.update(attrs) + + # Pass the patched ct module as extra_override; harmless if module doesn't use it. + # Broad catch is necessary: the sandbox executes arbitrary user launch_fn code + # and can raise any exception type (ImportError for missing deps, TypeError + # for shape mismatches, AttributeError from proxy gaps, etc.). + try: + used_recorder, prog_recorders = run_kernel_sandbox( + launch_fn=launch_fn, + host_args=host_args, + is_kernel_fn=is_cutile_program, + recorder_factory=lambda obj: CuTileLaunchRecorder(real_prog=obj), + host_kwargs=merged_kwargs, + extra_overrides={"ct": _ct_sandbox}, + strict=True, + op=qdp_symbol, + backend=backend, + ) + except TTAPluginError: + raise + except Exception as exc: + raise TTAPluginError( + op=qdp_symbol, + stage="aot_impl", + backend=backend, + msg=f"sandbox failed for op '{qdp_symbol}': {exc}", + ) from exc + + # 3. Exactly one cuTILE program must have been called. + used = [rec for rec in prog_recorders.values() if rec.args is not None] + if len(used) != 1: + raise TTAPluginError( + op=qdp_symbol, + stage="aot_impl", + backend=backend, + msg=f"expected exactly 1 cuTILE program launch; got {len(used)}", + ) + + recorder = used[0] + args = recorder.args # guaranteed non-None + + # 4. Analyze recorded args → pointer binding indices + scalar SymInts. 
+ num_inputs = len(inp_descs) + num_outputs = len(out_descs) + + _, scalar_symints = analyze_launch_args( + args=args, + num_inputs=num_inputs, + num_outputs=num_outputs, + op=qdp_symbol, + backend=backend, + ) + + def _symint_to_int(v: Any, default: int) -> int: + """Extract concrete int from a scalar that may be Python int or SymInt32.""" + if isinstance(v, int): + return v + # Try the same fallbacks as _shape_expr_to_ints + try: + return int(v) + except (TypeError, ValueError): + pass + if hasattr(v, "max"): + try: + return int(v.max()) + except (TypeError, ValueError, AttributeError): + pass + if hasattr(v, "constant_value") and not getattr(v, "is_fake", True): + try: + cv = v.constant_value + return int(cv() if callable(cv) else cv) + except (TypeError, ValueError, AttributeError, RuntimeError): + pass + return default + + block_size = int(cfg.get("BLOCK", 256)) + if scalar_symints: + block_size = _symint_to_int(scalar_symints[0], block_size) + + try: + from cuda.tile._compile import compile_tile # type: ignore[import] + from cuda.tile._compiler_options import CompilerOptions # type: ignore[import] + except ImportError as exc: + raise TTAPluginError( + op=qdp_symbol, + stage="aot_impl", + backend=backend, + msg=f"failed to import cuTILE internals (compile_tile/CompilerOptions) for op '{qdp_symbol}': {exc}", + ) from exc + + # Ensure tileiras is findable: add the Python-packaged nvidia/cu13/bin if needed. 
+ import shutil as _shutil + import sysconfig as _sysconfig + if not _shutil.which("tileiras"): + _site = _sysconfig.get_path("platlib") + _cu13_bin = os.path.join(_site, "nvidia", "cu13", "bin") + if os.path.isdir(_cu13_bin) and _cu13_bin not in os.environ.get("PATH", ""): + os.environ["PATH"] = _cu13_bin + os.pathsep + os.environ.get("PATH", "") + + all_descs = list(inp_descs) + list(out_descs) + num_inputs = len(inp_descs) + num_outputs = len(out_descs) + + def _make_example_tensors(flatten_all: bool) -> List[Any]: + out: List[Any] = [] + for td in all_descs: + dims = [_safe_dim(d) for d in td.shape_expr] + numel = 1 + for d in dims: + numel *= d + if flatten_all or len(dims) > 2: + shape = (numel,) + else: + shape = tuple(dims) + out.append( + torch.empty( + shape, + dtype=td_dtype_to_torch(td.dtype), + device="cuda", + ) + ) + return out + + example_tensors = _make_example_tensors(flatten_all=False) + if scalar_symints: + scalar_ints = [_symint_to_int(s, block_size) for s in scalar_symints] + pyfunc_args = tuple(example_tensors) + tuple(scalar_ints) + else: + scalar_ints = [block_size] + pyfunc_args = tuple(example_tensors) + (block_size,) + real_prog = recorder.real_prog + pyfunc = getattr(real_prog, "_pyfunc", real_prog) + + try: + result = compile_tile(pyfunc, pyfunc_args, CompilerOptions()) + except Exception as exc: + exc_str = str(exc) + if "Index size" in exc_str and "array rank" in exc_str: + example_tensors = _make_example_tensors(flatten_all=True) + pyfunc_args = tuple(example_tensors) + tuple(scalar_ints) + try: + result = compile_tile(pyfunc, pyfunc_args, CompilerOptions()) + except Exception as retry_exc: + raise TTAPluginError( + op=qdp_symbol, + stage="aot_impl", + backend=backend, + msg=f"cuTILE compilation failed (retry with flattened shapes): {retry_exc}", + ) from retry_exc + else: + raise TTAPluginError( + op=qdp_symbol, + stage="aot_impl", + backend=backend, + msg=f"cuTILE compilation failed: {exc}", + ) from exc + + if hasattr(result, 
"fname_cubin"): + with open(result.fname_cubin, "rb") as f: + cubin_bytes = f.read() + kernel_name_str = getattr(result, "func_name", None) or getattr(real_prog, "__name__", None) or "cutile_kernel" + elif isinstance(result, bytes): + cubin_bytes = result + kernel_name_str = getattr(real_prog, "kernel_name", None) or getattr(real_prog, "__name__", None) or "cutile_kernel" + else: + with open(result, "rb") as f: + cubin_bytes = f.read() + kernel_name_str = getattr(real_prog, "kernel_name", None) or getattr(real_prog, "__name__", None) or "cutile_kernel" + + num_scalars = len(scalar_symints) if scalar_symints else 1 + rank = max(1, max(len(td.shape_expr) for td in all_descs)) + ptx_raw = _extract_ptx_from_cubin(cubin_bytes) + num_ptx_params = _count_ptx_params(ptx_raw) if ptx_raw else 0 + n_io = num_inputs + num_outputs + if n_io > 0 and num_ptx_params >= num_scalars: + remainder = (num_ptx_params - num_scalars) % n_io + if remainder == 0: + ppt = (num_ptx_params - num_scalars) // n_io + if ppt >= 1: + effective_rank = (ppt - 1) // 2 + if effective_rank < 1: + effective_rank = 1 + else: + effective_rank = 1 + else: + effective_rank = 1 if rank > 2 else rank + else: + effective_rank = 1 if rank > 2 else rank + if rank > 1 and ptx_raw: + expected_n, _ = _cutile_trt_order_for_io( + num_inputs, num_outputs, num_scalars, rank=effective_rank + ) + # Dump diagnostic PTX for rank>1 kernels and log the path. 
+ diag_filename = f"cutile_rank_gt1_{kernel_name_str}_nparam_{num_ptx_params}_expected_{expected_n}.ptx" + dump_code_artifact("CUTILE_DUMP_DIR", diag_filename, ptx_raw, default_dir="/tmp/cutile_dump") + ptx_path = os.path.join( + os.environ.get("CUTILE_DUMP_DIR") or "/tmp/cutile_dump", diag_filename + ) + logger.info( + "cutile rank>1: kernel=%s num_ptx_params=%d expected_n=%d effective_rank=%d (num_inputs=%d num_outputs=%d num_scalars=%d) ptx_dumped=%s", + kernel_name_str, + num_ptx_params, + expected_n, + effective_rank, + num_inputs, + num_outputs, + num_scalars, + ptx_path, + ) + code_out = _reorder_cutile_ptx_for_trt( + cubin_bytes, + kernel_name_str, + num_inputs=num_inputs, + num_outputs=num_outputs, + num_scalars=num_scalars, + rank=effective_rank, + ) + + if ptx_raw is not None: + dump_code_artifact( + "CUTILE_DUMP_DIR", + f"cutile_rank{effective_rank}_{kernel_name_str}_orig.ptx", + ptx_raw, + default_dir="/tmp/cutile_dump", + ) + dump_code_artifact( + "CUTILE_DUMP_DIR", + f"cutile_rank{effective_rank}_{kernel_name_str}_orig.cubin", + cubin_bytes, + default_dir="/tmp/cutile_dump", + ) + if b".version" in code_out[:64]: + dump_code_artifact( + "CUTILE_DUMP_DIR", + f"cutile_rank{effective_rank}_{kernel_name_str}_reordered.ptx", + code_out.decode("utf-8", errors="replace"), + default_dir="/tmp/cutile_dump", + ) + + # CuTile may vectorise and use fewer threads than tile_size (.reqntid). + is_ptx = b".version" in code_out[:64] + if is_ptx: + ptx_str = code_out.decode("utf-8", errors="replace") + reqntid = _parse_reqntid(ptx_str) + actual_block = reqntid if reqntid else block_size + else: + actual_block = block_size + + # Follow the same pattern as _triton_aot.py: use recorder.grid directly. + # The grid recorded during sandbox execution is always correct — it's whatever + # the launch function computed from the input shapes (concrete ints for static + # shapes, SymInt32 expressions for dynamic shapes). 
+ launch = trtp.KernelLaunchParams() + recorded_grid = getattr(recorder, "grid", None) + if recorded_grid is not None and len(recorded_grid) >= 1: + _assign_recorded_grid(launch, recorded_grid) + else: + # Fallback: 1D grid from output numel / block_size. + # Broad catch is necessary: _shape_expr_to_ints raises RuntimeError for + # unbounded dynamic dims and may raise AttributeError for stub TensorDescs. + try: + out_ints = _shape_expr_to_ints(out_descs[0].shape_expr) + numel = 1 + for x in out_ints: + numel *= x + launch.grid_x = trtp.SymInt32((numel + block_size - 1) // block_size) + except (RuntimeError, AttributeError, TypeError): + launch.grid_x = trtp.SymInt32(1) + launch.grid_y = trtp.SymInt32(1) + launch.grid_z = trtp.SymInt32(1) + launch.block_x = actual_block + launch.shared_mem = 0 + + extra_args = _infer_cutile_extra_args( + inp_descs, out_descs, block_size, scalar_symints, rank=effective_rank + ) + + ext = ".ptx" if is_ptx else ".cubin" + dump_code_artifact( + "CUTILE_DUMP_DIR", + kernel_name_str + ext, + code_out, + default_dir="/tmp/cutile_dump", + ) + + # Uniquify kernel name per config so TRT registers separate PTX/cubin per tactic. + # Without this, two tactics with the same kernel function but different tile sizes + # (e.g. BLOCK=128 vs BLOCK=256) share the same name and TRT applies wrong launch + # params to one of them, causing ~50% element mismatches. + if cfg: + suffix = "_".join(f"{k}{v}" for k, v in sorted(cfg.items())) + unique_name = f"{kernel_name_str}_{suffix}" + # If code_out is still a cubin but we have ptx_raw available, convert to PTX + # so we can do a safe string-replace for the kernel name. 
+ if not is_ptx and ptx_raw is not None: + downgraded = _downgrade_ptx_version(ptx_raw) + code_out = downgraded.encode("utf-8") + is_ptx = True + if is_ptx: + ptx_str_out = code_out.decode("utf-8", errors="replace") + ptx_str_out = ptx_str_out.replace(kernel_name_str, unique_name) + code_out = ptx_str_out.encode("utf-8") + kernel_name_str = unique_name + else: + # Cubin without embedded PTX: patch null-terminated name in-place. + # Only safe when the new name fits in the original allocation. + old_b = kernel_name_str.encode("utf-8") + b"\x00" + new_b = unique_name.encode("utf-8") + b"\x00" + if len(new_b) <= len(old_b): + code_out = code_out.replace(old_b, new_b + b"\x00" * (len(old_b) - len(new_b))) + kernel_name_str = unique_name + else: + logger.warning( + "cutile aot: cannot uniquify cubin kernel name %r → %r (new name is longer); " + "tactic name collision may occur", + kernel_name_str, + unique_name, + ) + + return kernel_name_str.encode("utf-8"), code_out, launch, extra_args + + +def compile_cutile_program(spec: CuTileSpec, config: Dict[str, Any]) -> AOTMetadata: + """Compile a CuTileSpec into unified AOTMetadata. + + This is the tactic-manager entry-point. It constructs synthetic 1-input / + 1-output TensorDesc stubs (shape [256], float32) to drive the sandbox run, + delegates to ``aot_impl_cutile`` for the full pipeline, and wraps the result + in ``AOTMetadata``. + + Args: + spec: CuTileSpec carrying ``launch_fn``, optional ``configs``, etc. + config: Single tactic configuration dict (e.g. ``{"BLOCK": 256}``). + + Returns: + AOTMetadata with ``backend="cutile"``, compiled PTX/CUBIN bytes, and launch params. 
+ """ + inp_descs = [trtp.TensorDesc(dtype=trt.float32, shape_expr=[256])] + out_descs = [trtp.TensorDesc(dtype=trt.float32, shape_expr=[256])] + sym_inputs = [ + SymbolicTensor(td=inp_descs[0], role=TensorRole.INPUT, index=0) + ] + sym_outputs = [ + SymbolicTensor(td=out_descs[0], role=TensorRole.OUTPUT, index=0) + ] + host_args = sym_inputs + sym_outputs + kernel_name_bytes, code_bytes, launch, extra_args = aot_impl_cutile( + qdp_symbol="cutile_compile", + spec=spec, + cfg=config, + launch_fn=spec.launch_fn, + host_args=host_args, + inp_descs=inp_descs, + out_descs=out_descs, + ) + program_name = kernel_name_bytes.decode("utf-8") if isinstance(kernel_name_bytes, bytes) else kernel_name_bytes + launch_params = _launch_params_from_trt(launch, extra_args, num_inputs=1, num_outputs=1) + return AOTMetadata(binary=code_bytes, kernel_name=program_name, launch_params=launch_params, backend="cutile") diff --git a/py/torch_tensorrt/annotation/_custom_plugin/_aot/_triton.py b/py/torch_tensorrt/annotation/_custom_plugin/_aot/_triton.py new file mode 100644 index 0000000000..9ee1dfbdca --- /dev/null +++ b/py/torch_tensorrt/annotation/_custom_plugin/_aot/_triton.py @@ -0,0 +1,552 @@ +"""Triton AOT backend: source → sandbox recording → PTX → KernelLaunchParams. + +AOT compilation pipeline for Triton kernels +============================================ +1. **Sandbox execution** — The user's ``launch_fn`` is run with ``SymbolicTensor`` + proxies in place of real tensors and ``TritonLaunchRecorder`` objects injected + over the real ``@triton.jit`` kernels in the module globals. This lets us + capture the kernel call without executing on a GPU. + +2. **Argument analysis** — ``analyze_launch_args`` separates the recorded call + arguments into *pointer binding indices* (which TRT tensor buffer maps to + which kernel parameter) and *scalar SymInt32 expressions* (grid or shape + scalars that TRT evaluates at runtime). + +3. 
**PTX compilation** — ``triton.compile(ASTSource(fn, signature, constexprs))`` + produces both PTX and CUBIN. We extract the PTX from ``compiled.asm["ptx"]`` + because TRT's QDP runtime JIT-compiles PTX for the current GPU architecture, + matching the official ``aot_plugin`` example and avoiding cubin arch mismatch. + +4. **Parameter reordering** — Triton emits `.param` declarations in Python/kernel + order (pointers first, then scalars). The QDP runtime passes arguments in + *runtime order* (input pointers, scalars, output pointers). ``_fix_triton_ptx_for_trt`` + rewrites the `.param` block and all references in the PTX body to match. + +5. **Tactic uniquification** — When multiple tactics compile the same kernel + function with different ``constexprs`` (e.g. ``BLOCK_M=16`` vs ``BLOCK_M=32``), + a config-derived suffix is appended to the kernel name so TRT registers + separate PTX entries per tactic. + +The public entry-point for the descriptor system is ``aot_impl_triton``, which +returns the QDP AOT 4-tuple ``(kernel_name, ptx_bytes, KernelLaunchParams, SymIntExprs)``. +``compile_triton_kernel`` is a thin wrapper used by the tactic manager. +""" +from __future__ import annotations + +import logging +import re +from typing import Any, Dict, List, Mapping, Optional, Tuple + +logger = logging.getLogger(__name__) + +try: + import tensorrt as trt + import tensorrt.plugin as trtp +except ImportError as e: + raise ImportError( + "TensorRT with plugin support is required for Triton AOT compilation." 
+ ) from e + +from .._qdp_utils import ( + AOTMetadata, + TTAPluginError, + _as_symint32, + _assign_recorded_grid, + _launch_params_from_trt, + analyze_launch_args, + dump_code_artifact, + is_triton_kernel, + run_kernel_sandbox, +) +from ..._recorders import TritonLaunchRecorder +from ..._specs import TritonSpec +from .._symbolic import SymbolicTensor, TensorRole + + +def _trt_dtype_to_triton_ptr(trt_dtype: Any, qdp_symbol: str) -> str: + """Map a TensorRT DataType to a Triton pointer element-type string.""" + if trt_dtype == trt.float16: + return "fp16" + if trt_dtype == trt.bfloat16: + return "bf16" + if trt_dtype == trt.float32: + return "fp32" + if trt_dtype == trt.int32: + return "i32" + raise TTAPluginError( + op=qdp_symbol, + stage="aot_impl", + backend="triton", + msg=f"unsupported tensor dtype {trt_dtype} for Triton kernel signature", + ) + + +# LIMITATION (fragile PTX rewrite): the three _ptx_* helpers below rewrite .param +# declarations and their references inside the PTX body using line-by-line text +# scanning. Triton compiles parameters in Python/kernel order (pointers first, +# then constexprs), while the TRT plugin runtime expects (inputs, scalars, +# outputs). The rewrite is fragile because: +# 1. It identifies param references by matching the prefix ``{kernel_name}_param_`` +# as a plain string — any change to Triton's param naming convention silently +# produces incorrect PTX. +# 2. It scans for ``.entry {kernel_name}(`` as a literal string; multi-line or +# differently-formatted entry declarations will not be recognised. +# 3. Unused trailing params (constexprs not referenced in the body) are stripped +# by counting references — this is correct only if Triton doesn't reuse param +# indices in non-obvious ways. +# The proper fix is for TRT's QDP AOT API to expose a parameter-order remapping +# mechanism so that post-compilation PTX rewriting is not needed. 
+ + +def _ptx_downgrade_version(ptx: str) -> str: + """Downgrade the PTX ``.version`` line from 9.x to 9.0. + + Triton on CUDA 13.x emits ``.version 9.1`` (set by LLVM's NVPTX backend). + TRT's QDP PTX loader caps at 9.0 — kernels with a higher version silently + fail to load, producing a spurious ``onShapeChange`` error at runtime. + + Why not pin ``ptx_version=90`` in ``triton.compile``? Requesting 9.0 makes + LLVM emit ``.version 9.0``, which then fails Triton's ``make_cubin`` step + because the bundled ``ptxas`` (v8.7) cannot assemble PTX 9.0. Post-compilation + header patching is therefore the only viable workaround until TRT's PTX loader + is updated to accept 9.1. The ``.target`` line (e.g. ``sm_120a``) is unchanged. + """ + lines = ptx.split("\n") + result = [] + for line in lines: + if line.startswith(".version "): + line = re.sub(r"^(\.version\s+)9\.([1-9]\d*)", r"\g<1>9.0", line) + result.append(line) + return "\n".join(result) + + +def _ptx_reorder_and_strip_params( + ptx: str, + kernel_name: str, + trt_order: List[int], + num_used_params: int, +) -> str: + """Reorder ``.param`` declarations and body references; strip trailing params. + + Two rewrites combined into one pass because they both operate on the same + param-block region of the PTX: + + * **Reorder**: rearrange the ``.param`` declaration lines inside the + ``.entry`` block to match ``trt_order`` (TRT runtime order: + inputs → scalars → outputs). Body references (``{kernel}_param_N``) + are also renamed to match. + + * **Strip**: Triton appends internal params (``printf_buffer``, ``prevGrid``) + beyond the user-declared arguments. TRT passes exactly ``num_used_params`` + args; the extras are dropped and the trailing comma is fixed. + + Args: + ptx: PTX text (after version downgrade). + kernel_name: Entry-point name matching the ``.entry`` directive. + trt_order: Permutation list mapping new param index → original index. 
+ ``trt_order[i]`` is the original position of the param + that should appear at position ``i`` in the TRT call. + num_used_params: Number of params TRT will pass (= len(positional_names)). + """ + needs_reorder = trt_order != list(range(num_used_params)) + pfx = f"{kernel_name}_param_" + lines = ptx.split("\n") + result: List[str] = [] + in_entry = False + param_lines: List[str] = [] + + for line in lines: + if f".entry {kernel_name}(" in line: + in_entry = True + param_lines = [] + result.append(line) + continue + + if in_entry and ".param" in line and pfx in line: + param_lines.append(line) + continue + + if in_entry and ")" in line and ".param" not in line: + in_entry = False + reordered = ( + [param_lines[i] for i in trt_order if i < len(param_lines)] + if needs_reorder + else param_lines[:num_used_params] + ) + for i, pline in enumerate(reordered): + pline = pline.rstrip().rstrip(",") + if i < len(reordered) - 1: + pline += "," + result.append(pline) + result.append(line) + continue + + result.append(line) + + if needs_reorder: + old_to_new = {old: new for new, old in enumerate(trt_order)} + joined = "\n".join(result) + width = len(str(num_used_params - 1)) if num_used_params > 0 else 1 + for old_idx in range(num_used_params - 1, -1, -1): + joined = joined.replace(f"{pfx}{old_idx}", f"{pfx}TEMP{old_idx:0{width}d}") + for old_idx, new_idx in old_to_new.items(): + joined = joined.replace(f"{pfx}TEMP{old_idx:0{width}d}", f"{pfx}{new_idx}") + result = joined.split("\n") + + return "\n".join(result) + + +def _fix_triton_ptx_for_trt( + ptx: str, + kernel_name: str, + num_used_params: int, + param_binding_indices: List[int], + num_inputs: int, + num_scalars: int, +) -> str: + """Apply all PTX mutations needed for TRT QDP compatibility. + + Delegates each discrete rewrite to a named helper: + + 1. :func:`_ptx_downgrade_version` — cap ``.version`` at 9.0. + 2. 
:func:`_ptx_reorder_and_strip_params` — reorder param declarations and + body references to TRT runtime order (inputs → scalars → outputs) and + strip Triton's internal trailing params. + """ + # Compute TRT runtime parameter order from the recorded binding indices. + num_ptrs = len(param_binding_indices) + input_params = sorted( + ((binding, orig) for orig, binding in enumerate(param_binding_indices) if binding < num_inputs) + ) + output_params = sorted( + ((binding, orig) for orig, binding in enumerate(param_binding_indices) if binding >= num_inputs) + ) + scalar_params = list(range(num_ptrs, num_ptrs + num_scalars)) + trt_order = ( + [orig for _, orig in input_params] + + scalar_params + + [orig for _, orig in output_params] + ) + + ptx = _ptx_downgrade_version(ptx) + ptx = _ptx_reorder_and_strip_params(ptx, kernel_name, trt_order, num_used_params) + return ptx + + +def _specialize_ptx_kernel_name( + ptx: str, + kernel_name: str, + cfg: Mapping[str, Any], +) -> Tuple[str, str]: + """Append a config-derived suffix to the PTX kernel name. + + When two tactics compile the same ``@triton.jit`` function with different + constexprs (e.g. ``BLOCK_M=16`` vs ``BLOCK_M=32``), both return the same + ``kernel_name``. TRT identifies kernels by name, so it uses whichever PTX + was registered last for *all* tactics sharing that name — but still applies + each tactic's launch params, causing a mismatch (wrong grid dimensions for + the baked-in tile sizes). Append a short config suffix to give every tactic + a distinct name. + + Args: + ptx: Raw PTX text from ``compiled.asm["ptx"]``. + kernel_name: Current kernel entry-point name (from ``compiled.metadata.name``). + cfg: Tactic config dict (e.g. ``{"BLOCK_M": 32}``). Empty dict + means no suffix is needed. + + Returns: + ``(new_ptx, new_kernel_name)`` — the rewritten PTX and updated name. + If *cfg* is empty both values are returned unchanged. 
+ """ + if not cfg: + return ptx, kernel_name + suffix = "_".join(f"{k}{v}" for k, v in sorted(cfg.items())) + unique_name = f"{kernel_name}_{suffix}" + ptx = ptx.replace(kernel_name, unique_name) + return ptx, unique_name + + +def _make_triton_launch_params( + grid: Any, + num_warps: int, + shared_mem: int, + scalar_symints: List[Any], +) -> Tuple[Any, Any]: + """Build ``trtp.KernelLaunchParams`` and ``trtp.SymIntExprs`` for a Triton kernel. + + Centralises the repeated pattern of converting a recorded grid tuple into + ``KernelLaunchParams.grid_{x,y,z}`` SymInt32 values and populating + ``extra_args`` from the scalar SymInts captured during sandbox execution. + + Args: + grid: Grid value captured by :class:`TritonLaunchRecorder`. + May be a tuple of ints/SymInt32 or a single value. + num_warps: ``compiled.metadata.num_warps`` from ``triton.compile``. + shared_mem: ``compiled.metadata.shared`` from ``triton.compile``. + scalar_symints: Scalar SymInt expressions from ``analyze_launch_args``. + + Returns: + ``(launch, extra_args)`` — a ``trtp.KernelLaunchParams`` with grid/block/shared + populated and a ``trtp.SymIntExprs`` with one slot per scalar. + """ + launch = trtp.KernelLaunchParams() + _assign_recorded_grid(launch, grid) + launch.block_x = num_warps * 32 + launch.block_y = 1 + launch.block_z = 1 + launch.shared_mem = shared_mem + + extra_args = trtp.SymIntExprs(len(scalar_symints)) + for idx, val in enumerate(scalar_symints): + extra_args[idx] = _as_symint32(val) + + return launch, extra_args + + +def aot_impl_triton( + *, + qdp_symbol: str, + spec: TritonSpec, + cfg: Mapping[str, Any], + launch_fn: Any, + host_args: List[SymbolicTensor], + inp_descs: List[Any], + out_descs: List[Any], + attrs: Optional[Dict[str, Any]] = None, +) -> Tuple[str, bytes, Any, Any]: + """Triton AOT implementation: sandbox → record → compile → PTX. + + Steps: + 1. Find @triton.jit kernels in launch_fn's module. + 2. Replace with TritonLaunchRecorder proxies. + 3. 
Run sandboxed launch_fn(*host_args, **merged_kwargs). + 4. Analyse recorded args → param_binding_indices + scalar SymInts. + 5. Build per-tensor-dtype Triton signature, triton.compile() → PTX. + 6. Fix PTX: reorder params to runtime order, strip unused trailing params. + 7. Return (kernel_name, ptx, KernelLaunchParams, SymIntExprs). + + Returns: + (kernel_name, ptx, KernelLaunchParams, SymIntExprs) + """ + import triton + + backend = "triton" + + merged_kwargs = dict(cfg) + if attrs: + merged_kwargs.update(attrs) + + # 1-2. Locate module, find @triton.jit kernels, sandbox and run. + # strict=True: propagate module-not-found and no-kernel-found as errors. + # host_kwargs: cfg (tactic constexprs) + attrs (plugin compile-time constants). + # Broad catch is necessary: the sandbox executes arbitrary user launch_fn code + # and can raise any exception type (ImportError for missing deps, TypeError + # for shape mismatches, AttributeError from proxy gaps, etc.). + # + # WHY the sandbox recorder is needed before calling ASTSource: + # * Which kernel — launch_fn may import several @triton.jit functions; the + # recorder tells us exactly which one was launched (``used[0].real_kernel``). + # * Argument order — the recorder captures positional args in the order the + # launch_fn passes them. We use this to derive ``param_binding_indices`` + # (which arg index maps to which TRT input/output descriptor) and the + # Triton ``signature`` dict. ASTSource cannot provide this information. + # * Grid expression — ``recorder.grid`` captures the symbolic or concrete grid + # value to populate ``KernelLaunchParams.grid_{x,y,z}`` at plugin build time. 
+ try: + used_recorder, kernel_recorders = run_kernel_sandbox( + launch_fn=launch_fn, + host_args=host_args, + is_kernel_fn=is_triton_kernel, + recorder_factory=lambda obj: TritonLaunchRecorder(real_kernel=obj), + host_kwargs=merged_kwargs, + strict=True, + op=qdp_symbol, + backend=backend, + ) + except TTAPluginError: + raise + except Exception as exc: + raise TTAPluginError( + op=qdp_symbol, + stage="aot_impl", + backend=backend, + msg=f"sandbox failed for op '{qdp_symbol}': {exc}", + ) from exc + + # 3. Exactly one Triton kernel must have been launched. + used = [rec for rec in kernel_recorders.values() if rec.grid is not None] + if len(used) != 1: + raise TTAPluginError( + op=qdp_symbol, + stage="aot_impl", + backend=backend, + msg=f"expected exactly 1 Triton kernel launch; got {len(used)}", + ) + + recorder = used[0] + grid = recorder.grid + args = recorder.args + kwargs = recorder.kwargs or {} + + num_inputs = len(inp_descs) + num_outputs = len(out_descs) + + def _to_int(x: Any) -> Any: + try: + return int(x) + except (TypeError, ValueError): + return x + + constexprs = {k: _to_int(v) for k, v in cfg.items()} + if attrs: + constexprs.update(attrs) + kernel_arg_names = list(recorder.real_kernel.arg_names) + positional_names = [n for n in kernel_arg_names if n not in constexprs] + + full_args: List[Any] = [] + for i, name in enumerate(positional_names): + if i < len(args): + full_args.append(args[i]) + elif name in kwargs: + full_args.append(kwargs[name]) + else: + raise TTAPluginError( + op=qdp_symbol, + stage="aot_impl", + backend=backend, + msg=f"missing kernel argument '{name}' (pass positionally or by keyword)", + ) + + # 4. Analyze recorded args → pointer binding indices + scalar SymInts. 
+ param_binding_indices, scalar_symints = analyze_launch_args( + args=full_args, + num_inputs=num_inputs, + num_outputs=num_outputs, + op=qdp_symbol, + backend=backend, + ) + + if not param_binding_indices: + raise TTAPluginError( + op=qdp_symbol, + stage="aot_impl", + backend=backend, + msg="no pointer arguments recorded for Triton kernel", + ) + + # 5. Build Triton signature dict {param_name: type_str} for non-constexpr args. + all_descs = list(inp_descs) + list(out_descs) + + ptr_idx = 0 + scalar_idx = 0 + signature: Dict[str, str] = {} + for name in positional_names: + if ptr_idx < len(param_binding_indices): + b = param_binding_indices[ptr_idx] + dtype_str = _trt_dtype_to_triton_ptr(all_descs[b].dtype, qdp_symbol) + signature[name] = f"*{dtype_str}" + ptr_idx += 1 + else: + signature[name] = "i32" + scalar_idx += 1 + + expected_num_scalars = len(positional_names) - len(param_binding_indices) + if len(scalar_symints) != expected_num_scalars: + raise TTAPluginError( + op=qdp_symbol, + stage="aot_impl", + backend=backend, + msg=( + f"scalar count mismatch: launch passed {len(scalar_symints)} scalars " + f"but kernel has {expected_num_scalars} scalar args (depends on descriptor ranks for this invocation)" + ), + ) + + # triton.compile raises a mix of Exception subclasses (CompilationError, + # subprocess.CalledProcessError, RuntimeError) depending on the failure mode; + # a broad catch is necessary here to give a structured diagnostic in all cases. + try: + compiled = triton.compile( + triton.compiler.ASTSource( + fn=recorder.real_kernel, + signature=signature, + constexprs=constexprs, + ) + ) + except Exception as exc: + raise TTAPluginError( + op=qdp_symbol, + stage="aot_impl", + backend=backend, + msg=f"triton.compile failed for kernel '{recorder.real_kernel.__name__}': {exc}", + ) from exc + + # 6. Build KernelLaunchParams (grid can be symbolic; QDP evaluates at runtime). 
+ launch, extra_args = _make_triton_launch_params( + grid=grid, + num_warps=compiled.metadata.num_warps, + shared_mem=compiled.metadata.shared, + scalar_symints=scalar_symints, + ) + + # 7. Extract PTX, reorder params to runtime order, strip unused trailing params. + kernel_name_str: str = compiled.metadata.name + ptx: str = compiled.asm["ptx"] + if isinstance(ptx, bytes): + ptx = ptx.decode("utf-8") + + dump_code_artifact("TTA_DUMP_TRITON_PTX", f"{kernel_name_str}_raw.ptx", ptx) + + ptx, kernel_name_str = _specialize_ptx_kernel_name(ptx, kernel_name_str, cfg) + + num_used = len(positional_names) + ptx = _fix_triton_ptx_for_trt( + ptx=ptx, + kernel_name=kernel_name_str, + num_used_params=num_used, + param_binding_indices=param_binding_indices, + num_inputs=num_inputs, + num_scalars=len(scalar_symints), + ) + ptx_bytes = ptx.encode("utf-8") + + dump_code_artifact("TTA_DUMP_TRITON_PTX", f"{kernel_name_str}_fixed.ptx", ptx) + + return kernel_name_str, ptx_bytes, launch, extra_args + + +def compile_triton_kernel(spec: TritonSpec, config: Dict[str, Any]) -> AOTMetadata: + """Compile a TritonSpec into unified AOTMetadata. + + This is the tactic-manager entry-point. It constructs synthetic 1-input / + 1-output TensorDesc stubs (shape [256], float32) to drive the sandbox run, + delegates to ``aot_impl_triton`` for the full pipeline, and wraps the result + in ``AOTMetadata``. + + Args: + spec: TritonSpec carrying ``launch_fn``, optional ``configs``, etc. + config: Single tactic configuration dict (e.g. ``{"BLOCK_M": 32}``). + + Returns: + AOTMetadata with ``backend="triton"``, compiled PTX bytes, and launch params. 
+ """ + inp_descs = [trtp.TensorDesc(dtype=trt.float32, shape_expr=[256])] + out_descs = [trtp.TensorDesc(dtype=trt.float32, shape_expr=[256])] + sym_inputs = [ + SymbolicTensor(td=inp_descs[0], role=TensorRole.INPUT, index=0) + ] + sym_outputs = [ + SymbolicTensor(td=out_descs[0], role=TensorRole.OUTPUT, index=0) + ] + host_args = sym_inputs + sym_outputs + kernel_name_str, ptx_bytes, launch, extra_args = aot_impl_triton( + qdp_symbol="triton_compile", + spec=spec, + cfg=config, + launch_fn=spec.launch_fn, + host_args=host_args, + inp_descs=inp_descs, + out_descs=out_descs, + ) + launch_params = _launch_params_from_trt(launch, extra_args, num_inputs=1, num_outputs=1) + return AOTMetadata(binary=ptx_bytes, kernel_name=kernel_name_str, launch_params=launch_params, backend="triton") diff --git a/py/torch_tensorrt/annotation/_custom_plugin/_descriptor.py b/py/torch_tensorrt/annotation/_custom_plugin/_descriptor.py new file mode 100644 index 0000000000..9e82e98206 --- /dev/null +++ b/py/torch_tensorrt/annotation/_custom_plugin/_descriptor.py @@ -0,0 +1,951 @@ +"""tta.custom_plugin: CustomPluginSpec, factory, and QDP registration. + +This module provides: +- CustomPluginSpec: the descriptor returned by tta.custom_plugin(...) +- custom_plugin(): factory that computes a deterministic QDP op_name +- register_custom_plugin(): registers @trtp.register / @trtp.autotune / @trtp.aot_impl +- lower_custom_plugin_descriptor(): lowers a CustomPluginSpec to a + TRT plugin layer via trtp.op + +Backend-specific AOT logic lives in _triton_aot / _cutile_aot / +_cutedsl_aot. 
+""" +from __future__ import annotations + +import inspect +import logging +import threading +import typing + +import torch +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +logger = logging.getLogger(__name__) + +try: + import tensorrt as trt + import tensorrt.plugin as trtp + + _TRT_AVAILABLE = True +except ImportError: + _TRT_AVAILABLE = False + trt = None # type: ignore[assignment] + trtp = None # type: ignore[assignment] + +from ._qdp_utils import ( + QDPRuntimeError, + TacticEntry, + build_tactic_table, + derive_impl_id, + dtype_token, + format_token, + make_qdp_symbol, +) +from .._specs import CuTeDSLSpec, CuTileSpec, TritonSpec +from ._symbolic import SymbolicTensor, TensorRole + +KernelSpec = Union[TritonSpec, CuTileSpec, CuTeDSLSpec] + +# --------------------------------------------------------------------------- +# Process-level QDP registration registry +# --------------------------------------------------------------------------- +# TRT's internal QDP registry is process-global, so we track registered op_names +# at the process level rather than using thread-local storage. Thread-local storage +# would allow two concurrent threads (e.g. pytest-xdist workers sharing a process) +# to each attempt QDP registration for the same op_name, causing TRT to raise an +# "already registered" error on the second attempt. +# +# Threading contract for _qdp_registered_ops: +# WRITE: Always done under _qdp_registration_lock. The set is only grown, never +# shrunk, so writes are monotonic. The final add() and the return from +# register_custom_plugin() are both inside the lock. +# READ (under lock): Always safe; used inside register_custom_plugin() after +# acquiring _qdp_registration_lock for the definitive TOCTOU-safe check. +# READ (without lock, fast path): Safe because the set only grows. 
A thread +# that observes op_name IN the set can safely skip registration — the +# worst outcome of a race is a redundant lock acquisition on the slow +# path, which is also safe. A thread that observes op_name NOT IN the +# set proceeds to acquire the lock and re-checks inside (double-checked +# locking pattern). This avoids lock contention on the common post- +# registration path without risking double-registration. +_qdp_registered_ops: set = set() +_qdp_registration_lock = threading.Lock() + +# Thread-local cache for _aot_fn builders. Building the closure is cheap, +# but we avoid redundant work within a single thread. +_tls = threading.local() + + +def _get_aot_fn_cache() -> Dict[Tuple[str, int], Callable[..., Any]]: + """Return the thread-local cache mapping (op_name, num_inputs) to built _aot_fn. + + The cache is never cleared because TRT's QDP registry is also process-persistent, + so the two remain in sync without any explicit eviction. + """ + if not hasattr(_tls, "aot_fn_cache"): + _tls.aot_fn_cache = {} + return _tls.aot_fn_cache + + +# --------------------------------------------------------------------------- +# Public descriptor +# --------------------------------------------------------------------------- + + +@dataclass(frozen=True) +class CustomPluginSpec: + """Descriptor returned by ``tta.custom_plugin(...)``. + + Lifecycle + --------- + 1. **Creation** — ``custom_plugin()`` constructs a ``CustomPluginSpec`` + by splitting kwargs into weights/attrs, hashing the kernel specs + to derive a deterministic ``op_name``, and probing ``meta_impl`` with dummy + meta tensors to count outputs. No TRT objects are touched at this stage. + + 2. 
**QDP registration** — ``register_custom_plugin()`` (called lazily from
+ ``lower_custom_plugin_descriptor()``) registers three QDP callbacks with
+ TRT's process-global plugin registry under ``op_name``:
+
+ * ``@trtp.register`` — shape/dtype descriptor (uses ``meta_impl`` or identity)
+ * ``@trtp.autotune`` — enumerates (dtype, format, tactic) combinations
+ * ``@trtp.aot_impl`` — AOT kernel dispatch (Triton / CuTile / CuTeDSL)
+
+ Registration is idempotent: a process-level set (``_qdp_registered_ops``)
+ guards against double-registration across threads.
+
+ 3. **Use in lowering** — ``lower_custom_plugin_descriptor()`` calls
+ ``ctx.net.add_plugin(trtp.op.<namespace>.<name>(*trt_inputs), aot=True)`` to add
+ the plugin layer to the TRT network. Weight tensors declared in kwargs are
+ injected as ``trt.add_constant`` layers and appended to ``trt_inputs`` before
+ the plugin call so the launch_fn receives
+ ``(*activations, *weights_in_order, *outputs, ...)``.
+
+ Attributes:
+ op_name: Unique QDP symbol, e.g. ``"tta_custom::host_kernel_a1b2c3d4"``.
+ Deterministically derived from the kernel specs and attrs so that
+ re-creating the same descriptor in a different process produces
+ the same name.
+ specs: Non-empty list of kernel specs (TritonSpec, CuTileSpec,
+ CuTeDSLSpec). Multiple specs provide alternative tactics; TRT's
+ autotune selects the fastest one at engine-build time.
+ meta_impl: PyTorch meta function for shape/dtype inference. Receives meta
+ tensors (one per plugin input) and must return a single
+ ``torch.Tensor`` or a tuple/list of ``torch.Tensor`` s.
+ num_outputs: *Deprecated internal field — do not rely on this value.*
+ Retained for backward compatibility only. The QDP
+ registration path re-infers the output count from real
+ ``trt.ITensor`` input ranks at lowering time.
+ :meth:`auto_register_torch_op` infers it independently
+ at schema-registration time.
+ attrs: Scalar kwargs baked into kernel PTX at AOT time
+ (e.g. ``addend=1.0``). 
NOT forwarded as TRT plugin fields. + weights: Tensor kwargs bound at plugin creation time. At lowering, each + weight is added to the TRT network as a constant tensor and + appended to the dynamic activation inputs. The launch_fn must + accept dynamic inputs first, then weights (in declaration order), + then outputs. Named so debug messages can identify each weight. + """ + + op_name: str + specs: List[KernelSpec] + meta_impl: Callable[..., Any] + num_outputs: int = 1 + attrs: Dict[str, Any] = field(default_factory=dict) + # Tensor weights are excluded from __hash__ and __eq__ because torch.Tensor + # is not hashable. Identity / equality of a descriptor is captured by + # op_name (which already encodes num_weights via derive_impl_id). + weights: Dict[str, "torch.Tensor"] = field( + default_factory=dict, hash=False, compare=False + ) + + def lower_to_trt( + self, + ctx: Any, + trt_inputs: List[Any], + name: str, + qdp_name: Optional[str] = None, + ) -> Any: + """Lower this spec to a TRT ``IPluginV3`` layer via :func:`lower_custom_plugin_descriptor`. + + Shared entry-point used by both the native TTA lowering pass and the + Dynamo integration path (``trt_plugins.custom_op(impl=...)``). Using + this method ensures both paths go through the same code: weight + injection, TTA layer metadata, and ``aot=True`` semantics. + + When ``qdp_name`` is supplied the plugin is looked up under that name + (the torch op name, e.g. ``"ns::my_op"``) rather than the auto-derived + TTA fingerprint in ``self.op_name``. This is required when the plugin + was registered under the torch op name rather than the TTA fingerprint + (e.g. via :func:`register_custom_plugin` with an explicit ``qdp_name``). + + Args: + ctx: Torch-TRT ``ConversionContext`` (carries ``ctx.net``). + trt_inputs: Ordered ``trt.ITensor`` activation inputs. + name: Layer name for TRT debugging/profiling. + qdp_name: Optional QDP name override (e.g. torch op name). 
+ + Returns: + A single ``trt.ITensor`` or a tuple of ``trt.ITensor`` s. + """ + import dataclasses + desc = dataclasses.replace(self, op_name=qdp_name) if qdp_name is not None else self + return lower_custom_plugin_descriptor(ctx, desc, trt_inputs, name) + + def auto_register_torch_op(self, op_name: str) -> None: + """Auto-register ``torch.library.custom_op`` and ``register_fake`` for ``op_name``. + + Eliminates the boilerplate of writing ``@torch.library.custom_op`` and + ``@torch.library.register_fake`` by hand when a :class:`CustomPluginSpec` + already carries ``meta_impl`` and kernel specs. + + The eager implementation calls the first spec's ``launch_fn`` with the + first available config. The fake implementation delegates to + ``meta_impl`` for shape/dtype inference. + + If the op is already registered in ``torch.ops``, the call is a no-op. + + Args: + op_name: torch op name in ``"namespace::name"`` form. + + Raises: + ValueError: If ``meta_impl`` is ``None`` (required for shape inference). + """ + if self.meta_impl is None: + raise ValueError( + f"auto_register_torch_op: meta_impl is required to auto-register " + f"'{op_name}'; set meta_impl in tta.custom_plugin(..., meta_impl=...)" + ) + + namespace, name = op_name.split("::", 1) + + # Skip if the op is already registered. + ns_obj = getattr(torch.ops, namespace, None) + if ns_obj is not None and hasattr(ns_obj, name): + return + + meta_sig = inspect.signature(self.meta_impl) + param_names = list(meta_sig.parameters.keys()) + + first_spec = self.specs[0] + first_config = first_spec.configs[0] if getattr(first_spec, "configs", None) else {} + _launch = first_spec.launch_fn + _meta = self.meta_impl + + # Infer num_outputs for the torch.library schema. 
+ from torch_tensorrt.dynamo.conversion.plugins._generate_plugin import ( + _probe_num_outputs_from_callable, + ) + _n_args = len(inspect.signature(_meta).parameters) + _num_outputs = _probe_num_outputs_from_callable(_meta, _n_args) + + # CuTeDSL launch functions are @cute.jit decorated, which means they only + # accept cute.Tensor arguments (not torch.Tensor). The eager path receives + # plain torch.Tensor inputs/outputs from PyTorch dispatch, so we wrap the + # launch function with a DLPack bridge that converts torch.Tensor → cute.Tensor + # before forwarding to the jit-compiled kernel. + # + # Why not put this conversion inside the user's launch_fn? Because + # cute.compile() (called in the AOT path) requires its argument to be + # @cute.jit decorated. A regular Python wrapper around a @cute.jit function + # is NOT itself @cute.jit, so cute.compile() would raise: + # DSLRuntimeError: Function <...> is not decorated with jit decorator. + # Therefore TTA must own the eager-path conversion — the user's launch_fn + # must remain a pure @cute.jit function. + if isinstance(first_spec, CuTeDSLSpec): + _jit_launch = _launch + + def _launch(*args: Any, **kwargs: Any) -> Any: + from cutlass.cute.runtime import from_dlpack as _from_dlpack + + converted = tuple( + _from_dlpack(a.contiguous()) if isinstance(a, torch.Tensor) else a + for a in args + ) + return _jit_launch(*converted, **kwargs) + + def _eager_body(*args: torch.Tensor) -> Any: + meta_outs = _meta(*args) + if not isinstance(meta_outs, (tuple, list)): + meta_outs = (meta_outs,) + outs = [ + torch.empty(o.shape, dtype=o.dtype, device=args[0].device) + for o in meta_outs + ] + _launch(*args, *outs, **first_config) + # Count dynamically from meta_impl result — no _num_outputs needed here. 
+ return outs[0] if len(outs) == 1 else list(outs) + + # Reuse the same __signature__ trick as _build_desc_fn: attach a custom + # inspect.Signature so torch.library's schema inference sees real named + # Tensor parameters without exec()-generated source code. + # Multi-output ops use List[torch.Tensor] (→ schema "Tensor[]") so that + # torch.library registers the correct return type. + sig_params = [ + inspect.Parameter(p, inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=torch.Tensor) + for p in param_names + ] + ret_annotation = torch.Tensor if _num_outputs == 1 else List[torch.Tensor] + sig = inspect.Signature(sig_params, return_annotation=ret_annotation) + _eager_body.__signature__ = sig + torch.library.custom_op(op_name, mutates_args=())(_eager_body) + + # Fake impl: forward to meta_impl for shape/dtype inference during tracing. + # For multi-output ops, return a list to match the "Tensor[]" schema. + # Count dynamically from result — real-ranked tensors are available here. + def _fake_body(*args: torch.Tensor) -> Any: + result = _meta(*args) + if isinstance(result, tuple): + return list(result) + return result + + _fake_body.__signature__ = sig + torch.library.register_fake(op_name)(_fake_body) + + + + +def custom_plugin( + kernel: Union[KernelSpec, List[KernelSpec]], + meta_impl: Callable[..., Any], + **kwargs: Any, +) -> CustomPluginSpec: + """Create a :class:`CustomPluginSpec` for one or more kernel specs. + + This is the primary entry point for the ``tta.custom_plugin`` API. The + returned descriptor is used as the ``impl=`` argument of ``@tta.export_as``. + + Args: + kernel: Single kernel spec or a non-empty list of specs. Multiple specs + provide alternative tactics; TRT's autotune benchmarks all of them + at engine-build time and selects the fastest. + meta_impl: Required. PyTorch meta function used by QDP for shape/dtype + inference. 
Receives meta tensors (one per plugin input) and + must return a single ``torch.Tensor`` or a tuple/list of + ``torch.Tensor`` s. Each tensor's ``.shape`` and ``.dtype`` + define the output ``TensorDesc`` s. For the descriptor we use + only the first returned tensor; its shape and dtype must match + the plugin's actual output so TRT gets correct types and shapes. + **kwargs: Plugin-level keyword arguments, split by type at creation time: + + * ``torch.Tensor`` values → **weights**: frozen tensors bound + to this plugin. At TRT lowering each weight is added to the + network as a ``trt.add_constant`` layer and appended to the + dynamic activation inputs before calling the plugin. The + launch_fn must accept ``(*activations, *weights, + *outputs, ...)``. The count of weights is included in the + ``op_name`` fingerprint so the same kernel spec can be + registered with different input arities without stale- + registration bugs. + + * All other values → **attrs**: scalar compile-time constants + (e.g. ``addend=1.0``, ``scale=2``). These are baked into + the kernel PTX at AOT time and are NOT forwarded as TRT + plugin fields. + + Returns: + A :class:`CustomPluginSpec` with an auto-computed ``op_name``. + + Raises: + ValueError: If ``kernel`` is an empty list, or if ``meta_impl`` is ``None``. + TypeError: If any element of ``kernel`` is not a valid ``KernelSpec``, + or if ``meta_impl`` is not callable. 
+ """ + specs: List[KernelSpec] = kernel if isinstance(kernel, list) else [kernel] + if not specs: + raise ValueError("custom_plugin: kernel list cannot be empty") + for s in specs: + if not isinstance(s, (TritonSpec, CuTileSpec, CuTeDSLSpec)): + raise TypeError( + f"custom_plugin: kernel must be TritonSpec, CuTileSpec, or " + f"CuTeDSLSpec, got {type(s).__name__!r}" + ) + if meta_impl is None: + raise ValueError( + "custom_plugin: meta_impl is required and cannot be None; " + "provide a PyTorch meta function for shape/dtype inference" + ) + if not callable(meta_impl): + raise TypeError( + f"custom_plugin: meta_impl must be callable, " + f"got {type(meta_impl).__name__!r}" + ) + + # Split kwargs by value type: + # torch.Tensor → weights (TRT constant layers injected at lowering) + # everything else → attrs (scalar compile-time constants) + weights: Dict[str, torch.Tensor] = {} + attrs: Dict[str, Any] = {} + for k, v in kwargs.items(): + if isinstance(v, torch.Tensor): + weights[k] = v + else: + attrs[k] = v + + impl_id = derive_impl_id(specs, attrs=attrs, num_weights=len(weights)) + op_name = make_qdp_symbol(impl_id) + return CustomPluginSpec( + op_name=op_name, specs=specs, meta_impl=meta_impl, + attrs=attrs, weights=weights, + ) + + +# --------------------------------------------------------------------------- +# Shared parameter-list helpers +# --------------------------------------------------------------------------- + + +def _build_input_params(num_inputs: int, annotation: Any) -> List[inspect.Parameter]: + """Build a positional ``inspect.Parameter`` list for TRT descriptor functions. + + Args: + num_inputs: Number of input parameters to generate (named ``inp0`` … ``inpN``). + annotation: Type annotation attached to each parameter (typically + ``trtp.TensorDesc``). + + Returns: + List of ``inspect.Parameter`` objects suitable for use with + ``inspect.Signature``. 
+ """ + return [ + inspect.Parameter(f"inp{i}", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=annotation) + for i in range(num_inputs) + ] + + +# --------------------------------------------------------------------------- +# Descriptor function builder (shape / dtype via @trtp.register) +# --------------------------------------------------------------------------- + + +def _build_desc_fn( + descriptor: CustomPluginSpec, + num_inputs: int, + num_outputs: int = 1, +) -> Callable[..., Any]: + """Build the ``@trtp.register`` meta handler for a :class:`CustomPluginSpec`. + + Uses ``meta_impl`` (if provided) to infer output ``TensorDesc`` s; falls back + to mirroring ``inp0`` (same shape/dtype) when ``meta_impl`` is ``None``. + + We use a ``*args`` closure and attach a custom ``inspect.Signature`` so that + TRT's ``@trtp.register`` validation (``issubclass(param.annotation, TensorDesc)``) + sees real ``TensorDesc`` class objects — not strings that would result from + exec()-ing code in a module with ``from __future__ import annotations``. + + For ``num_outputs > 1``, the function returns a tuple of ``TensorDesc`` s (one + per output), and ``meta_impl`` must return a tuple/list of that many tensors. + + Note: attrs are intentionally excluded from the descriptor signature. + LIMITATION (workaround for NumPy 1.25+ incompatibility): TRT stores plugin + field values as numpy arrays and calls ``attr_type_annot(f.data)`` to + convert them back. In NumPy 1.25+ this raises + ``"only 0-dimensional arrays can be converted to Python scalars"`` for 1-D + arrays. Until TRT's plugin API is updated to handle NumPy 1.25+, we avoid + the problem entirely by not registering any attrs. Plugins that need to + pass scalar hyperparameters (e.g. epsilon, axis) currently have no way to + do so via the attrs mechanism and must bake constants into the kernel or + pass them as additional tensor inputs. + + Args: + descriptor: The :class:`CustomPluginSpec` being registered. 
+ num_inputs: Number of input ``TensorDesc`` positional parameters. + num_outputs: Number of output ``TensorDesc`` s to produce (default 1). + + Returns: + A callable with the correct ``inspect.Signature`` for ``@trtp.register``. + """ + tensor_desc_cls = trtp.TensorDesc + + # meta_impl expects only the dynamic activation inputs, not the weight + # inputs (which are TRT constant layers appended after activations). + num_dynamic = num_inputs - len(descriptor.weights) + from torch_tensorrt.dynamo.conversion.plugins._generate_plugin import ( + _build_symbolic_desc_fn, + ) + _desc = _build_symbolic_desc_fn(descriptor.meta_impl, num_dynamic, num_outputs) + + sig_params = _build_input_params(num_inputs, tensor_desc_cls) + if num_outputs == 1: + return_annotation = tensor_desc_cls + else: + # Build Tuple[TensorDesc, TensorDesc, ...] (n elements). + # Subscripting typing.Tuple with a plain Python tuple of type args + # unpacks them as separate type parameters (identical to writing + # typing.Tuple[X, X] for n=2). Do NOT use a list — that would + # produce Tuple[List[X]] instead of Tuple[X, X]. + return_annotation = typing.Tuple[tuple([tensor_desc_cls] * num_outputs)] + _desc.__signature__ = inspect.Signature( + sig_params, return_annotation=return_annotation + ) + return _desc + + +# --------------------------------------------------------------------------- +# Autotune function builder (format / tactic combinations via @trtp.autotune) +# --------------------------------------------------------------------------- + + +def _format_token_for_tactic( + entry: TacticEntry, + specs: List[KernelSpec], +) -> str: + """Return the ``AutoTuneCombination`` format string for a single tactic entry. + + Looks up the first ``input_formats`` entry on the kernel spec selected by + ``entry.spec_idx``. Falls back to ``"LINEAR"`` if ``input_formats`` is + absent or empty, or if the format is not in the known token map. 
+ + Args: + entry: The tactic table entry identifying the spec and config indices. + specs: The full list of kernel specs from the descriptor. + + Returns: + A format token string such as ``"LINEAR"`` or ``"CHW32"``. + """ + spec = specs[entry.spec_idx] + input_fmts = getattr(spec, "input_formats", None) + if input_fmts: + return format_token(input_fmts[0]) + return "LINEAR" + + +def _build_autotune_fn( + descriptor: CustomPluginSpec, + num_inputs: int, + num_outputs: int, + tactic_table: List[TacticEntry], +) -> Optional[Callable[..., Any]]: + """Build a ``@trtp.autotune`` function that registers dtype/format/tactic combinations. + + Returns ``None`` if the QDP autotune API (``trtp.autotune`` / + ``trtp.AutoTuneCombination``) is not available in this TRT build. + + TRT calls this callback during engine build to enumerate valid + (dtype, format, tactic) combinations. It then benchmarks all valid + combinations and passes the winning 1-based tactic ID to ``aot_impl``. + + ``AutoTuneCombination`` uses the string constructor:: + + AutoTuneCombination(dtype_str, format_str, tactic_ids) + + where: + + * ``dtype_str`` — comma-separated per-I/O dtype options, + e.g. ``"FP32|FP16, FP32|FP16"`` + * ``format_str`` — memory format, e.g. ``"LINEAR"`` + * ``tactic_ids`` — list of 1-based integer tactic IDs (0 is reserved by TRT) + + The autotune function signature must match:: + + (inp0: TensorDesc, ..., inpN: TensorDesc, + outputs: Tuple[TensorDesc]) -> List[AutoTuneCombination] + + Args: + descriptor: The :class:`CustomPluginSpec` being registered. + num_inputs: Number of input ``TensorDesc`` positional parameters. + num_outputs: Number of output tensors (used to iterate ``outputs``). + tactic_table: Pre-built list of :class:`TacticEntry` objects. + + Returns: + A callable with the correct ``inspect.Signature`` for ``@trtp.autotune``, + or ``None`` if the autotune API is unavailable. 
+ """ + if not ( + _TRT_AVAILABLE + and hasattr(trtp, "autotune") + and hasattr(trtp, "AutoTuneCombination") + ): + return None + + n_tactics = len(tactic_table) + # Tactic IDs are 1-based: TRT reserves 0 as the "no-autotune" default. + tactic_ids = list(range(1, n_tactics + 1)) + + # Pre-compute format string per tactic so the closure doesn't need to + # re-evaluate on every call to _autotune_fn. + _tactic_formats = [ + _format_token_for_tactic(entry, descriptor.specs) + for entry in tactic_table + ] + + tensor_desc_cls = trtp.TensorDesc + + def _autotune_fn(*args: Any) -> List[Any]: + inp_descs = list(args[:num_inputs]) + out_descs_raw = args[num_inputs] + out_descs = list(out_descs_raw) if hasattr(out_descs_raw, "__iter__") else [out_descs_raw] + parts = [ + dtype_token(td) + for td in inp_descs + out_descs + ] + dtype_str = ", ".join(parts) + return [ + trtp.AutoTuneCombination(dtype_str, fmt, [tid]) + for tid, fmt in zip(tactic_ids, _tactic_formats) + ] + + sig_params = (_build_input_params(num_inputs, tensor_desc_cls) + + [inspect.Parameter("outputs", inspect.Parameter.POSITIONAL_OR_KEYWORD)]) + _autotune_fn.__signature__ = inspect.Signature(sig_params) + return _autotune_fn + + +# --------------------------------------------------------------------------- +# AOT impl builder (dispatches to backend) +# --------------------------------------------------------------------------- + +# LIMITATION (TensorRT 10.14 bug — Blackwell only): Mixed Triton+CuTile or +# Triton+CuTeDSL plugins that produce SymIntExprs of different lengths crash +# with CUDA error 700 (illegal memory access) at execute_async_v3 on Blackwell +# (Myelin QUICKAOT path). TRT picks a single SymIntExprs length for the whole +# plugin and passes that many extra int32s to every tactic's kernel; if the +# selected kernel has fewer .param declarations, the extras corrupt unrelated +# memory. Workaround: do not mix Triton and CuTeDSL tactics in the same +# CustomPluginSpec on Blackwell. 
# LIMITATION (TensorRT 10.14 bug — Blackwell only): Mixed Triton+CuTile or
# Triton+CuTeDSL plugins that produce SymIntExprs of different lengths crash
# with CUDA error 700 (illegal memory access) at execute_async_v3 on Blackwell
# (Myelin QUICKAOT path). TRT picks a single SymIntExprs length for the whole
# plugin and passes that many extra int32s to every tactic's kernel; if the
# selected kernel has fewer .param declarations, the extras corrupt unrelated
# memory. Workaround: do not mix Triton and CuTeDSL tactics in the same
# CustomPluginSpec on Blackwell. Fix expected in a future TRT release.
def _build_aot_fn(
    descriptor: CustomPluginSpec,
    num_inputs: int,
    tactic_table: List[TacticEntry],
) -> Callable[..., Any]:
    """Build a ``@trtp.aot_impl`` function that dispatches to the right backend.

    TRT calls this once with the winning tactic index (chosen by autotune).
    The first ``num_inputs`` positional args are ``TensorDesc`` inputs; then
    ``outputs`` (sequence of ``TensorDesc``); then ``tactic`` (int-like).

    The built function is cached in the thread-local ``_aot_fn_cache`` keyed by
    ``(op_name, num_inputs)`` to avoid redundant closure construction within a
    single thread.

    .. note:: **TensorRT 10.14 bug on Blackwell (Myelin QUICKAOT)**

        When tactics return ``SymIntExprs`` of different lengths, TRT/Myelin
        crashes with CUDA error 700 (illegal memory access) at
        ``execute_async_v3``. Mixed Triton+CuTeDSL plugins trigger this because
        Triton returns ``SymIntExprs(N)`` (N scalar kernel args, e.g.
        ``n_elements``) while CuTeDSL returns ``SymIntExprs(0)`` (shape lives in
        CuTe descriptors). On Blackwell, TRT picks one ``SymIntExprs`` length
        for the entire plugin and passes that many extra ``int32`` s to every
        tactic's kernel at launch; if the selected kernel's PTX has fewer
        ``.param`` declarations than expected, the extras land in unrelated
        memory → crash. On pre-Blackwell (non-Myelin), the mismatch is silently
        tolerated.

        No workaround is applied here; the annotation layer routes
        mixed-backend tests to pre-Blackwell GPUs (see
        ``tests/py/annotation/conftest.py``). Filed as a TensorRT bug; repro:
        ``tests/py/annotation/repro_blackwell_extra_len_mismatch.py``

    Args:
        descriptor: The :class:`CustomPluginSpec` being registered.
        num_inputs: Number of input ``TensorDesc`` positional arguments.
        tactic_table: Pre-built list of :class:`TacticEntry` objects.

    Returns:
        A callable with the correct ``inspect.Signature`` for ``@trtp.aot_impl``.

    Raises:
        :class:`QDPRuntimeError`: If the tactic index is out of range or the
            kernel spec type is not supported.
    """
    # Cache is thread-local: one closure per (op_name, num_inputs) per thread.
    cache = _get_aot_fn_cache()
    cache_key = (descriptor.op_name, num_inputs)
    if cache_key in cache:
        return cache[cache_key]

    op_name = descriptor.op_name
    tensor_desc_cls = trtp.TensorDesc

    def _aot_fn(*args: Any) -> Any:
        inp_descs = list(args[:num_inputs])
        if len(args) <= num_inputs:
            raise IndexError(
                f"AOT impl for {op_name!r} called with {len(args)} args but "
                f"expected at least {num_inputs + 1} "
                f"(inputs + outputs). TRT may not have passed the outputs "
                f"argument."
            )
        out_descs = list(args[num_inputs])  # outputs is a sequence
        # TRT passes the winning 1-based tactic ID; default to 1 (first tactic)
        # when no tactic argument is supplied.
        tactic_id = int(args[num_inputs + 1]) if len(args) > num_inputs + 1 else 1
        tactic_idx = tactic_id - 1

        if tactic_idx < 0 or tactic_idx >= len(tactic_table):
            raise QDPRuntimeError(
                op=op_name,
                stage="aot_impl",
                backend="custom_plugin",
                msg=(
                    f"tactic ID {tactic_id} (index {tactic_idx}) out of range "
                    f"[0, {len(tactic_table)}) for op {op_name!r}"
                ),
            )
        entry = tactic_table[tactic_idx]
        spec = descriptor.specs[entry.spec_idx]
        # A spec with no configs dispatches with a single empty config, mirroring
        # build_tactic_table's [{}] placeholder.
        configs = spec.configs if spec.configs else [{}]
        cfg = configs[entry.config_idx]
        launch_fn = spec.launch_fn

        # Wrap TensorDescs so backend AOT impls see role- and index-tagged
        # symbolic tensors rather than raw TRT descriptors.
        sym_inputs = [
            SymbolicTensor(td=td, role=TensorRole.INPUT, index=i)
            for i, td in enumerate(inp_descs)
        ]
        sym_outputs = [
            SymbolicTensor(td=td, role=TensorRole.OUTPUT, index=j)
            for j, td in enumerate(out_descs)
        ]
        host_args = sym_inputs + sym_outputs

        plugin_attrs = descriptor.attrs

        # Backend imports are deferred so that importing this module does not
        # require triton/cuda_tile/cutlass to be installed.
        if isinstance(spec, TritonSpec):
            from ._aot._triton import aot_impl_triton
            return aot_impl_triton(
                qdp_symbol=op_name, spec=spec, cfg=cfg, launch_fn=launch_fn,
                host_args=host_args, inp_descs=inp_descs, out_descs=out_descs,
                attrs=plugin_attrs,
            )
        elif isinstance(spec, CuTileSpec):
            from ._aot._cutile import aot_impl_cutile
            return aot_impl_cutile(
                qdp_symbol=op_name, spec=spec, cfg=cfg, launch_fn=launch_fn,
                host_args=host_args, inp_descs=inp_descs, out_descs=out_descs,
                attrs=plugin_attrs,
            )
        elif isinstance(spec, CuTeDSLSpec):
            from ._aot._cutedsl import aot_impl_cutedsl
            return aot_impl_cutedsl(
                qdp_symbol=op_name, spec=spec, cfg=cfg, launch_fn=launch_fn,
                host_args=host_args, inp_descs=inp_descs, out_descs=out_descs,
                attrs=plugin_attrs,
            )
        else:
            raise QDPRuntimeError(
                op=op_name, stage="aot_impl", backend="custom_plugin",
                msg=f"unsupported kernel spec type {type(spec).__name__!r} for op {op_name!r}",
            )

    # @trtp.aot_impl introspects the signature; synthesize one with the exact
    # positional layout (inputs..., outputs, tactic) and the 4-tuple return
    # annotation (kernel name, binary, launch params, symbolic int exprs).
    sig_params = (_build_input_params(num_inputs, tensor_desc_cls)
                  + [inspect.Parameter("outputs", inspect.Parameter.POSITIONAL_OR_KEYWORD),
                     inspect.Parameter("tactic", inspect.Parameter.POSITIONAL_OR_KEYWORD)])
    _aot_fn.__signature__ = inspect.Signature(
        sig_params,
        return_annotation=typing.Tuple[
            typing.Union[str, bytes],
            typing.Union[str, bytes],
            trtp.KernelLaunchParams,
            trtp.SymIntExprs,
        ],
    )

    cache[cache_key] = _aot_fn
    return _aot_fn
def register_custom_plugin(
    descriptor: CustomPluginSpec,
    num_inputs: int,
    num_outputs: int = 1,
    qdp_name: Optional[str] = None,
) -> None:
    """Register a :class:`CustomPluginSpec` with TRT's QDP plugin registry.

    Registers three QDP callbacks under ``qdp_name`` (if provided) or
    ``descriptor.op_name``:

    * ``@trtp.register`` — shape/dtype descriptor (uses ``meta_impl`` or identity)
    * ``@trtp.autotune`` — format/tactic combinations per I/O position
    * ``@trtp.aot_impl`` — AOT kernel dispatch (Triton/CuTile/CuTeDSL backend)

    **Idempotency**: repeated calls for the same ``op_name`` are no-ops. The
    function uses a double-checked locking pattern against
    ``_qdp_registered_ops`` (a process-global set) to guard against concurrent
    registration from multiple threads. See the module-level comment on
    ``_qdp_registered_ops`` for the full threading contract.

    If TRT's own registry reports that an op is "already registered" (e.g.
    because a prior in-process call registered it before ``_qdp_registered_ops``
    was updated), the function catches the error, logs a debug message, and
    marks the op as registered so future calls are no-ops.

    Args:
        descriptor: The :class:`CustomPluginSpec` to register.
        num_inputs: Number of TRT input tensors for this op.
        num_outputs: Number of TRT output tensors (default 1).
        qdp_name: Optional override for the QDP registration name. When
            provided the plugin is registered under this name instead of
            ``descriptor.op_name``. Use this when wiring a
            :class:`CustomPluginSpec` to a ``torch.library`` op whose name
            differs from the TTA fingerprint name (e.g. when called from
            ``trt_plugins.custom_op``).

    Raises:
        Exception: Any exception raised by TRT's ``trtp.register`` that is
            **not** an "already registered" error is re-raised verbatim.
        Exception: Any exception raised by TRT's ``trtp.aot_impl`` that is
            **not** an "already registered" error is re-raised verbatim.
    """
    op_name = qdp_name if qdp_name is not None else descriptor.op_name

    # Fast path: check without the lock first (common case after first registration).
    # Safe because _qdp_registered_ops only grows — see module-level threading contract.
    if op_name in _qdp_registered_ops:
        return

    with _qdp_registration_lock:
        # Re-check inside the lock to handle the race between the fast-path
        # check above and acquiring the lock (TOCTOU).
        if op_name in _qdp_registered_ops:
            return

        # Deferred import to avoid a circular import with the plugin subsystem.
        from torch_tensorrt.dynamo.conversion.plugins._generate_plugin import (
            register_plugin_with_aot,
        )

        tactic_table = build_tactic_table(descriptor.specs)

        # Build the three QDP callbacks, then delegate the raw trtp calls to
        # register_plugin_with_aot so that all plugin registration (JIT and AOT)
        # converges on the same path in the existing plugin system.
        desc_fn = _build_desc_fn(descriptor, num_inputs, num_outputs)
        autotune_fn = _build_autotune_fn(descriptor, num_inputs, num_outputs, tactic_table)
        aot_fn = _build_aot_fn(descriptor, num_inputs, tactic_table)

        # TRT's internal QDP registry is process-global; if TRT already knows the
        # op (from a prior call in this process), catch "already has a definition"
        # and mark it registered without re-registering.
        try:
            register_plugin_with_aot(op_name, desc_fn, autotune_fn, aot_fn)
        except Exception as exc:  # noqa: BLE001
            # Substring match is deliberate: TRT's "already registered" error
            # text varies across builds. Anything else is a real failure.
            if "already" in str(exc).lower():
                logger.debug(
                    "QDP op %r already registered in TRT registry, skipping: %s",
                    op_name, exc,
                )
                _qdp_registered_ops.add(op_name)
                return
            raise

        _qdp_registered_ops.add(op_name)
        logger.debug(
            "Registered QDP plugin %r (num_inputs=%d, num_outputs=%d, tactics=%d)",
            op_name,
            num_inputs,
            num_outputs,
            len(tactic_table),
        )
def lower_custom_plugin_descriptor(
    ctx: Any,
    descriptor: CustomPluginSpec,
    trt_inputs: List[Any],
    name: str,
) -> Union[Any, Sequence[Any]]:
    """Lower a :class:`CustomPluginSpec` to a TRT ``IPluginV3`` layer.

    This is the final step in the :class:`CustomPluginSpec` lifecycle (see
    class docstring). It:

    1. Injects weight tensors as ``trt.add_constant`` layers, appending them to
       ``trt_inputs`` so the plugin sees ``(*activations, *weights)`` as inputs.
    2. Calls :func:`register_custom_plugin` (idempotent) to ensure the QDP
       callbacks are registered before the network refers to the op.
    3. Resolves ``trtp.op.<namespace>.<plugin>`` and calls it with
       ``trt_inputs`` to obtain the plugin handle.
    4. Adds the plugin layer via ``ctx.net.add_plugin(..., aot=True)`` and
       attaches TTA layer metadata for debugging.
    5. Returns a single ``trt.ITensor`` (for single-output plugins) or a tuple
       of ``trt.ITensor`` s.

    Args:
        ctx: Torch-TRT ``ConversionContext`` providing ``ctx.net``.
        descriptor: :class:`CustomPluginSpec` for this op.
        trt_inputs: List of ``trt.ITensor`` dynamic activation inputs.
        name: Layer name used for TRT network debugging.

    Returns:
        A single ``trt.ITensor`` if the plugin has one output, or a ``tuple``
        of ``trt.ITensor`` s for multi-output plugins.

    Raises:
        :class:`QDPRuntimeError`: Wraps any non-QDP exception raised during
            lowering (registration, plugin construction, or layer addition).
    """
    op_name = descriptor.op_name

    try:
        # Re-infer num_outputs from real trt_inputs ranks (most accurate path).
        # trt_inputs at this point are the activation inputs only (weights are
        # appended below), so their ranks exactly match what meta_impl expects.
        from torch_tensorrt.dynamo.conversion.plugins._generate_plugin import (
            _probe_num_outputs_from_callable,
        )
        _hint = len(trt_inputs[0].shape) if trt_inputs else None
        num_outputs = _probe_num_outputs_from_callable(
            descriptor.meta_impl, len(trt_inputs), preferred_rank=_hint
        )

        # Weight binding: tensor-valued kwargs declared in custom_plugin() are injected
        # here as TRT constant layers, appended after the dynamic activation inputs.
        # The annotated function only receives the activations (no weight args), so the
        # eager body is unchanged. The launch_fn contract is:
        #   (*activations, *weights_in_declaration_order, *outputs, ...)
        if descriptor.weights:
            import numpy as np
            weight_trt_tensors = []
            for wname, wtensor in descriptor.weights.items():
                # contiguous() guarantees the numpy view matches TRT's expected
                # row-major weight layout.
                np_arr = wtensor.detach().cpu().contiguous().numpy()
                trt_weights = trt.Weights(np_arr)
                const_layer = ctx.net.add_constant(tuple(np_arr.shape), trt_weights)
                const_layer.name = f"{name}_weight_{wname}"
                weight_trt_tensors.append(const_layer.get_output(0))
            trt_inputs = list(trt_inputs) + weight_trt_tensors

        num_inputs = len(trt_inputs)
        register_custom_plugin(descriptor, num_inputs, num_outputs)

        # Parse namespace and plugin name from op_name (format: "ns::op").
        if "::" in op_name:
            namespace, plugin_name = op_name.split("::", 1)
        else:
            namespace = "tta_custom"
            plugin_name = op_name

        ns_module = getattr(trtp.op, namespace)
        plugin_fn = getattr(ns_module, plugin_name)

        # Attrs are baked into the kernel PTX at AOT time; do not pass as TRT
        # plugin fields (TRT stores them as numpy arrays and the round-trip
        # float(np.array([v])) fails in NumPy 1.25+).
        plugin_layer = ctx.net.add_plugin(plugin_fn(*trt_inputs), aot=True)
        plugin_layer.name = name
        if plugin_layer.num_outputs == 1:
            return plugin_layer.get_output(0)
        return tuple(
            plugin_layer.get_output(i) for i in range(plugin_layer.num_outputs)
        )

    except Exception as exc:
        # Preserve structured QDP errors; wrap everything else with op/layer
        # context so failures are attributable in build logs.
        if isinstance(exc, QDPRuntimeError):
            raise
        raise QDPRuntimeError(
            op=op_name,
            stage="compile",
            backend="custom_plugin",
            msg=f"lowering failed for op {op_name!r} (layer {name!r}): {exc}",
        ) from exc
+""" +from __future__ import annotations + +import hashlib +import importlib +import inspect +import types +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +import torch + +try: + import tensorrt as trt + import tensorrt.plugin as trtp + + _TRT_AVAILABLE = True +except ImportError: + _TRT_AVAILABLE = False + trt = None # type: ignore[assignment] + trtp = None # type: ignore[assignment] + + +# ---- SymInt patching ---- + +# LIMITATION (experimental shim): Monkey-patches a private TRT binding class +# (tensorrt_bindings.plugin._tensor.SymInt32) to add a missing _is_dummy +# property required by TRT's get_launch_params() on Blackwell sm_120 builds. +# This patches an internal, undocumented attribute of TRT's Python bindings. +# It will break silently if TRT renames or removes _tensor.SymInt32, or if a +# future TRT release adds _is_dummy natively (in which case the guard +# ``isinstance(..., property)`` will skip the patch automatically). +# Remove once TRT exposes _is_dummy on SymInt32 in its public plugin API. +# +# TRT's get_launch_params() reads _is_dummy on SymInt32 to determine whether +# a grid dimension is symbolic. SymInt32 doesn't expose this attribute by +# default, causing AttributeError on some TRT builds (e.g. Blackwell sm_120). +# We add it as a property with a no-op setter so that: +# (a) TRT can read _is_dummy without AttributeError, +# (b) ShapeExpr.__init__ (which sets self._is_dummy = False) doesn't crash. 
try:
    import tensorrt_bindings.plugin._tensor as _trtp_tensor  # type: ignore[import]
    _SymInt32 = _trtp_tensor.SymInt32
    # Only patch when TRT has not already defined _is_dummy as a property;
    # this makes the shim self-disabling on future TRT releases.
    if not isinstance(getattr(_SymInt32, "_is_dummy", None), property):
        _SymInt32._is_dummy = property(
            # A SymInt32 without an underlying _int_expr is treated as a dummy.
            lambda self: getattr(self, "_int_expr", None) is None,
            lambda self, v: None,  # no-op setter: lets ShapeExpr.__init__ assign _is_dummy = False
        )
    del _SymInt32, _trtp_tensor
except (ImportError, AttributeError):
    # Broad catch is intentional: this is an optional compatibility shim.
    # If tensorrt_bindings is unavailable, or the internal _tensor module has
    # changed its layout, we skip the patch silently. The missing _is_dummy
    # attribute will only manifest at runtime on affected TRT builds, at which
    # point the QDP call will raise a clear AttributeError from TRT itself.
    pass


from ._symbolic import SymbolicTensor, TensorRole


# ---- Error type ----


class QDPRuntimeError(RuntimeError):
    """Structured error for QDP plugin runtime failures.

    Distinct from TTAPluginError (in _plugin_lowering.py), which is used for
    FX-graph-level lowering failures. QDPRuntimeError covers TRT-side errors
    in AOT impl, kernel dispatch, and plugin compilation.

    The formatted message has the fixed shape ``"<op>: [<stage>] [<backend>] <msg>"``
    so log scrapers can parse failures by stage/backend.

    Args:
        op: QDP op name or annotated function name.
        stage: Processing stage (e.g. "aot_impl", "compile", "sandbox").
        backend: Backend name ("triton", "cutile", "cutedsl", "custom_plugin").
        msg: Human-readable error message.
    """

    def __init__(self, op: str, stage: str, backend: str, msg: str) -> None:
        super().__init__(f"{op}: [{stage}] [{backend}] {msg}")
        # Structured fields preserved for programmatic inspection by callers.
        self.op = op
        self.stage = stage
        self.backend = backend
        self.msg = msg


# Backward-compatible alias: other modules in this package import TTAPluginError
# from _qdp_utils; keep the name available until those files are updated.
TTAPluginError = QDPRuntimeError


# ---- AOT metadata ----


@dataclass
class AOTMetadata:
    """Unified AOT compilation result, shared by all backends.

    Attributes:
        binary: PTX or CUBIN bytes ready for TRT to JIT-compile.
        kernel_name: Entry point name matching the .entry directive in PTX.
        launch_params: SimpleNamespace from _launch_params_from_trt (grid, block,
            shared_mem, param_binding_indices, sym_int_exprs).
        backend: One of "triton", "cutile", or "cutedsl".
    """

    binary: bytes
    kernel_name: str
    launch_params: Any  # types.SimpleNamespace from _launch_params_from_trt
    backend: str  # "triton", "cutile", or "cutedsl"
@dataclass
class TacticEntry:
    """One tactic: a (kernel spec, launch config) pair addressed by index.

    Attributes:
        spec_idx: Index into the kernel-spec list passed to build_tactic_table.
        config_idx: Index into that spec's configs list (0 when configs is empty).
    """

    spec_idx: int
    config_idx: int


def build_tactic_table(specs: Sequence[Any]) -> List[TacticEntry]:
    """Flatten kernel specs into one :class:`TacticEntry` per (spec, config) pair.

    A spec with no configs contributes exactly one entry (config_idx 0),
    mirroring the single implicit empty config used at dispatch time.
    """
    entries: List[TacticEntry] = []
    for spec_idx, spec in enumerate(specs):
        n_configs = len(spec.configs) if spec.configs else 1
        entries.extend(
            TacticEntry(spec_idx=spec_idx, config_idx=config_idx)
            for config_idx in range(n_configs)
        )
    return entries


# ---- Sandboxing ----


def make_sandboxed_host(
    launch_fn: Callable[..., Any],
    overrides: Dict[str, Any],
) -> Callable[..., Any]:
    """Clone *launch_fn* with a copied globals dict extended by *overrides*.

    The clone shares the original code object, defaults, and closure, but
    resolves global names through the patched dict. This lets us swap real
    @triton.jit kernels (or cuTILE programs) for recorder proxies without
    mutating the kernel's home module.
    """
    module = inspect.getmodule(launch_fn)
    if module is None:
        raise RuntimeError(f"cannot find module for launch_fn {launch_fn}")

    sandbox_globals = {**module.__dict__, **overrides}
    return types.FunctionType(
        launch_fn.__code__,
        sandbox_globals,
        name=launch_fn.__name__,
        argdefs=launch_fn.__defaults__,
        closure=launch_fn.__closure__,
    )
+ """ + mod = inspect.getmodule(launch_fn) + if mod is None: + raise RuntimeError(f"cannot find module for launch_fn {launch_fn}") + + new_globals = dict(mod.__dict__) + new_globals.update(overrides) + + return types.FunctionType( + launch_fn.__code__, + new_globals, + name=launch_fn.__name__, + argdefs=launch_fn.__defaults__, + closure=launch_fn.__closure__, + ) + + +# ---- Fingerprinting ---- + + +def fingerprint_fn(fn: Callable[..., Any]) -> bytes: + """Return a stable byte fingerprint for a Python function. + + The fingerprint combines the function's module path, qualified name, and + a SHA-1 hash of its raw bytecode. This makes the fingerprint stable + across interpreter restarts (assuming the source has not changed) while + still changing whenever the function body is edited. + """ + mod = getattr(fn, "__module__", "") + qn = getattr(fn, "__qualname__", getattr(fn, "__name__", "")) + code = getattr(fn, "__code__", None) + code_bytes = getattr(code, "co_code", b"") if code is not None else b"" + blob = { + "module": mod, + "qualname": qn, + "code_sha1": hashlib.sha1(code_bytes).hexdigest(), + } + return str(blob).encode("utf-8") + + +def derive_impl_id( + specs: Sequence[Any], + attrs: Optional[Dict[str, Any]] = None, + num_weights: int = 0, +) -> str: + """Derive a deterministic ID string from a list of kernel specs, plugin attrs, and weight count. + + ``num_weights`` is included in the hash so that the same kernel spec used + with different numbers of declared weight inputs produces a distinct op_name + and QDP registration. Without this, a plugin registered with num_inputs=1 + (activation only) would be incorrectly reused when called with num_inputs=2 + (activation + weight), causing silent stale-registration bugs. 
+ """ + h = hashlib.sha1() + for spec in specs: + h.update(fingerprint_fn(spec.launch_fn)) + configs = spec.configs if spec.configs else [{}] + h.update(str(configs).encode("utf-8")) + in_fmts = getattr(spec, "input_formats", None) + out_fmts = getattr(spec, "output_formats", None) + h.update(str([int(f) for f in in_fmts] if in_fmts else []).encode("utf-8")) + h.update(str([int(f) for f in out_fmts] if out_fmts else []).encode("utf-8")) + if attrs: + h.update(str(sorted(attrs.items())).encode("utf-8")) + if num_weights: + h.update(f"num_weights:{num_weights}".encode("utf-8")) + return h.hexdigest() + + +def make_qdp_symbol(impl_id: str) -> str: + """Derive a unique QDP namespace::name from an impl_id.""" + return f"tta_custom::host_kernel_{impl_id[:8]}" + + +# ---- Backend-object detection heuristics ---- + + +def is_triton_kernel(obj: Any) -> bool: + """Heuristic: Triton kernels implement __getitem__ for [grid] indexing. + + Check both obj.__module__ (may be the user module) and + type(obj).__module__ (always "triton.runtime.jit" for JITFunction) + so that kernels defined in user modules are detected correctly. + """ + if not hasattr(obj, "__getitem__"): + return False + return ( + "triton" in getattr(obj, "__module__", "") + or "triton" in getattr(type(obj), "__module__", "") + ) + + +def is_cutile_program(obj: Any) -> bool: + """Heuristic: cuTILE programs are callables from cuda.tile.* modules.""" + modname = getattr(obj, "__module__", "") + return callable(obj) and ("cuda.tile" in modname or "cutile" in modname) + + +def is_cute_kernel(obj: Any) -> bool: + """Heuristic: @cute.kernel objects from CUTLASS CuTe DSL. + + @cute.kernel objects have type(obj).__module__ == 'builtins', so we check + the _dsl_cls attribute which points to the CuTe DSL kernel class. 
+ """ + if not callable(obj) or isinstance(obj, type): + return False + dsl_cls = getattr(obj, "_dsl_cls", None) + if dsl_cls is None: + return False + cls_mod = getattr(dsl_cls, "__module__", "") or "" + return "cutlass" in cls_mod or "cute" in cls_mod + + +# ---- SymInt helpers ---- + + +def _sym_dim(d: Any) -> Any: + """Convert a shape element (int or SymInt32) to SymInt32 for symbolic arithmetic.""" + if _TRT_AVAILABLE and isinstance(d, trtp.SymInt32): + return d + return trtp.SymInt32(int(d)) + + +def _as_symint32(v: Any) -> Any: + """Ensure v is a trtp.SymInt32 (not SymIntExpr). + + SymInt32 arithmetic returns SymIntExpr objects, but TRT's KernelLaunchParams.grid_* + and extra_args slots require SymInt32. SymIntExpr wraps IDimensionExpr in ._expr; + SymInt32 accepts IDimensionExpr directly. + """ + if _TRT_AVAILABLE and isinstance(v, trtp.SymInt32): + return v + expr = getattr(v, "_expr", None) + if expr is not None: + return trtp.SymInt32(expr) + try: + return trtp.SymInt32(int(v)) + except (TypeError, ValueError): + return trtp.SymInt32(1) + + +def _assign_recorded_grid(launch: Any, recorded_grid: Any) -> None: + """Assign ``grid_x``/``grid_y``/``grid_z`` on *launch* from a recorded grid value. + + Centralises the repeated pattern used by all three AOT backends (Triton, + cuTILE, CuTe DSL) that capture a grid tuple during sandbox execution and + must populate a ``trtp.KernelLaunchParams`` object. + + Args: + launch: A ``trtp.KernelLaunchParams`` instance to modify in-place. + recorded_grid: The grid value captured during sandbox execution. May be: + + * ``None`` — all three dimensions are set to ``SymInt32(1)``. + * A non-tuple scalar (int / SymInt32) — ``grid_x`` is set to that + value; ``grid_y`` and ``grid_z`` default to 1. This covers the + Triton path where the launch_fn passes a plain integer grid. + * A tuple of 1–3 elements (int or SymInt32/SymIntExpr) — each + element is converted with :func:`_as_symint32`; missing dimensions + default to 1. 
This covers the cuTILE and CuTe DSL paths. + """ + if recorded_grid is None: + launch.grid_x = trtp.SymInt32(1) + launch.grid_y = trtp.SymInt32(1) + launch.grid_z = trtp.SymInt32(1) + return + if not isinstance(recorded_grid, tuple): + launch.grid_x = _as_symint32(recorded_grid) + launch.grid_y = trtp.SymInt32(1) + launch.grid_z = trtp.SymInt32(1) + return + launch.grid_x = _as_symint32(recorded_grid[0]) if len(recorded_grid) >= 1 else trtp.SymInt32(1) + launch.grid_y = _as_symint32(recorded_grid[1]) if len(recorded_grid) >= 2 else trtp.SymInt32(1) + launch.grid_z = _as_symint32(recorded_grid[2]) if len(recorded_grid) >= 3 else trtp.SymInt32(1) + + +def _safe_dim(d: Any, default: int = 1) -> int: + """Extract a concrete int from a shape element safely. + + For TRT SymInt32 elements (dynamic shapes), calling int() directly does NOT raise + but returns a garbage pointer-like value (~470 TB), causing OOM. Check + _int_expr.is_constant() first; return *default* for dynamic dims. + + The default is 1 (minimum valid tensor dimension) rather than a larger value + so that dummy tensors constructed for kernel compilation use the most compact + shape. CuTeDSL kernels bake the tensor layout into their type; using 1 for + dynamic dims gives a [1, static_dim, ...] dummy whose row-major offset formula + (offset = idx) is valid for any larger batch size at runtime. + """ + if isinstance(d, int): + return d + ie = getattr(d, "_int_expr", None) + if ie is not None: + try: + if ie.is_constant(): + return int(ie.get_constant_value()) + except (AttributeError, RuntimeError): + # is_constant() or get_constant_value() may not be available on all + # TRT builds; fall through to return the safe default. + pass + return default + try: + return int(d) + except (TypeError, ValueError): + return default + + +# ---- Dtype / format conversion ---- + + +def td_dtype_to_torch(td_dtype: Any) -> torch.dtype: + """Convert a TensorRT DataType to a torch.dtype. 
+ + Args: + td_dtype: A ``trt.DataType`` enum value. + + Raises: + RuntimeError: If *td_dtype* has no torch equivalent known to this function. + """ + if td_dtype == trt.float32: + return torch.float32 + if td_dtype == trt.float16: + return torch.float16 + if td_dtype == trt.bfloat16: + return torch.bfloat16 + if td_dtype == trt.int32: + return torch.int32 + raise RuntimeError( + f"unsupported TRT dtype {td_dtype!r}; supported: float32, float16, bfloat16, int32" + ) + + +def torch_dtype_to_trt(dtype: torch.dtype) -> Any: + """Convert a torch.dtype to a TensorRT DataType. + + Args: + dtype: A ``torch.dtype`` value. + + Raises: + RuntimeError: If *dtype* has no TRT equivalent known to this function. + """ + if dtype == torch.float32: + return trt.float32 + if dtype == torch.float16: + return trt.float16 + if dtype == torch.bfloat16: + return trt.bfloat16 + if dtype == torch.int32: + return trt.int32 + raise RuntimeError( + f"unsupported torch dtype {dtype!r}; supported: float32, float16, bfloat16, int32" + ) + + +_MAX_DIM = 2**31 - 1 +_MIN_VALID_DIM = 1 + + +def dtype_token(td: Any) -> str: + """Return the AutoTuneCombination string token for a TensorDesc dtype. + + Used by ``_build_autotune_fn`` to construct the dtype string passed to + ``trtp.AutoTuneCombination``. Falls back to ``"FP32"`` for unrecognised + dtypes so that autotune registration does not fail silently. 
+ """ + if td.dtype == trt.float16: + return "FP16" + if td.dtype == trt.bfloat16: + return "BF16" + if td.dtype == trt.float32: + return "FP32" + if td.dtype == trt.int32: + return "INT32" + if td.dtype == trt.int8: + return "INT8" + return "FP32" + + +def format_token(tf: Any) -> str: + """Return a short string token for a trt.TensorFormat.""" + if tf == trt.TensorFormat.LINEAR: + return "LINEAR" + if tf == trt.TensorFormat.CHW2: + return "CHW2" + if tf == trt.TensorFormat.HWC8: + return "HWC8" + if tf == trt.TensorFormat.CHW4: + return "CHW4" + if tf == trt.TensorFormat.CHW16: + return "CHW16" + if tf == trt.TensorFormat.CHW32: + return "CHW32" + if tf == trt.TensorFormat.DHWC8: + return "DHWC8" + if tf == trt.TensorFormat.CDHW32: + return "CDHW32" + if tf == trt.TensorFormat.HWC: + return "HWC" + if tf == trt.TensorFormat.DLA_LINEAR: + return "DLA_LINEAR" + if tf == trt.TensorFormat.DLA_HWC4: + return "DLA_HWC4" + if tf == trt.TensorFormat.HWC16: + return "HWC16" + if tf == trt.TensorFormat.DHWC: + return "DHWC" + return str(tf) + + +# ---- Shape expression utilities ---- + + +def _collect_shape_var_bindings(shape_expr: Any, bindings: Dict[int, int]) -> None: + """Recursively find free/fake/dynamic shape vars and assign them the minimum valid value. + + Walks *shape_expr* recursively (handling nested tensors with a `.shape_expr` + attribute and plain list/tuple containers) and populates *bindings* with a + mapping from ``id(var)`` → 1 for every element that is not a plain ``int`` + and is either: + - marked as fake (``is_fake == True``), or + - a non-constant symbolic expression (``is_constant == False``), or + - not directly convertible to int via ``int()``. + + The value 1 is the minimum positive integer accepted as a tensor dimension. 
def _shape_elem_to_int(d: Any, bindings: Dict[int, int]) -> int:
    """Convert a single shape dimension to a concrete int.

    Resolution order:
    1. Check *bindings* (fake/symbolic vars already assigned a substitute value).
    2. Plain ``int`` — returned directly.
    3. ``int(d)`` — works for SymInt-like objects that support __int__.
    4. ``d.max()`` — upper bound from a symbolic range.
    5. ``d.value`` — direct value attribute (callable or plain).
    6. ``d.__index__()`` — integer protocol.
    7. ``d.constant_value`` — static constant (skipped if ``is_fake`` is True).
    8. ``d.eval(bindings)`` — explicit evaluation with the binding map.

    Raises:
        RuntimeError: If no resolution path succeeds, including the dimension
            type, so callers can identify which shape caused the failure.
    """
    # Bindings are keyed by object identity — symbolic vars are not hashable
    # by value, and the same var must map to the same substitute everywhere.
    vid = id(d)
    if vid in bindings:
        v = bindings[vid]
    elif isinstance(d, int):
        v = d
    else:
        v = None
        try:
            v = int(d)
        except (TypeError, ValueError):
            pass
        if v is None and hasattr(d, "max"):
            try:
                v = int(d.max())
            except (TypeError, ValueError):
                pass
        if v is None and hasattr(d, "value"):
            val = getattr(d, "value")
            try:
                v = int(val()) if callable(val) else int(val)
            except (TypeError, ValueError):
                pass
        if v is None and hasattr(d, "__index__"):
            try:
                v = d.__index__()
            except (TypeError, ValueError):
                pass
        # NOTE(review): the is_fake default here is True, so elements WITHOUT an
        # is_fake attribute also skip the constant_value path — presumably a
        # deliberate conservatism, but confirm against TRT's ShapeExpr API.
        if v is None and hasattr(d, "constant_value") and not getattr(d, "is_fake", True):
            try:
                cv = d.constant_value
                raw = cv() if callable(cv) else cv
                v = int(raw)
            except (TypeError, ValueError, AttributeError, RuntimeError):
                pass
        if v is None and hasattr(d, "eval"):
            try:
                ev = d.eval(bindings)
                v = int(ev)
            except (TypeError, ValueError, AttributeError, RuntimeError):
                pass
    if v is None:
        raise RuntimeError(
            f"Cannot convert shape_expr element (type {type(d).__name__!r}) to int"
        )
    # Range check also catches the garbage values that int() on a dynamic
    # SymInt32 can yield (see _safe_dim's docstring).
    if v < 0 or v > _MAX_DIM:
        raise RuntimeError(
            f"shape_expr dimension {v} out of range [0, {_MAX_DIM}]"
            f" (element type {type(d).__name__!r})"
        )
    return v
+ """ + vid = id(d) + if vid in bindings: + v = bindings[vid] + elif isinstance(d, int): + v = d + else: + v = None + try: + v = int(d) + except (TypeError, ValueError): + pass + if v is None and hasattr(d, "max"): + try: + v = int(d.max()) + except (TypeError, ValueError): + pass + if v is None and hasattr(d, "value"): + val = getattr(d, "value") + try: + v = int(val()) if callable(val) else int(val) + except (TypeError, ValueError): + pass + if v is None and hasattr(d, "__index__"): + try: + v = d.__index__() + except (TypeError, ValueError): + pass + if v is None and hasattr(d, "constant_value") and not getattr(d, "is_fake", True): + try: + cv = d.constant_value + raw = cv() if callable(cv) else cv + v = int(raw) + except (TypeError, ValueError, AttributeError, RuntimeError): + pass + if v is None and hasattr(d, "eval"): + try: + ev = d.eval(bindings) + v = int(ev) + except (TypeError, ValueError, AttributeError, RuntimeError): + pass + if v is None: + raise RuntimeError( + f"Cannot convert shape_expr element (type {type(d).__name__!r}) to int" + ) + if v < 0 or v > _MAX_DIM: + raise RuntimeError( + f"shape_expr dimension {v} out of range [0, {_MAX_DIM}]" + f" (element type {type(d).__name__!r})" + ) + return v + + +def _shape_elem_concrete_without_fake(d: Any) -> bool: + """Return True if this dimension can be resolved to int without a fake-variable substitution. + + "Concrete without fake" means the dimension has a knowable integer value + that does not rely on an artificial placeholder assigned to a free symbolic + variable (i.e. a variable whose ``is_fake`` attribute is True). Such fake + variables arise from TRT's dynamic-shape mechanism: during the first-call + meta_impl pass, shape variables that are not yet bound to any concrete size + appear as fake SymInt32 objects. + + Concreteness is tested in the following order: + 1. Plain ``int`` — always concrete. + 2. ``int(d)`` succeeds — the object can convert itself to an integer. + 3. 
``d.is_fake`` is True — explicitly marked as fake, so **not** concrete. + 4. ``d.constant_value`` resolves without error — a static symbolic constant. + 5. ``d.max()`` resolves without error — a bounded symbolic dim's upper bound. + + If none of those paths succeeds the dimension is considered not concrete. + """ + if isinstance(d, int): + return True + try: + int(d) + return True + except (TypeError, ValueError): + pass + if hasattr(d, "is_fake") and getattr(d, "is_fake"): + return False + if hasattr(d, "constant_value"): + try: + cv = d.constant_value + raw = cv() if callable(cv) else cv + int(raw) + return True + except (TypeError, ValueError, AttributeError, RuntimeError): + pass + if hasattr(d, "max"): + try: + int(d.max()) + return True + except (TypeError, ValueError): + pass + return False + + +def _shape_expr_is_concrete(shape_expr: Any) -> bool: + """Return True if every element can be resolved to int without fake/symbolic substitution. + + Recursively descends into nested tensors (via `.shape_expr`) and + list/tuple containers. + """ + for d in shape_expr: + if hasattr(d, "shape_expr"): + if not _shape_expr_is_concrete(d.shape_expr): + return False + elif isinstance(d, (list, tuple)): + if not _shape_expr_is_concrete(d): + return False + elif not _shape_elem_concrete_without_fake(d): + return False + return True + + +def _shape_expr_to_ints(shape_expr: Any) -> List[int]: + """Convert a TRT shape_expr to a list of concrete ints. + + Free/fake shape variables are assigned the minimum valid value (1) via + ``_collect_shape_var_bindings`` before evaluation so that the result is + always a valid concrete shape even in dynamic-shape contexts. 
    """
    bindings: Dict[int, int] = {}
    _collect_shape_var_bindings(shape_expr, bindings)

    def eval_rec(expr: Any) -> List[int]:
        # Flatten nested tensors (.shape_expr) and list/tuple containers into a
        # single list of concrete ints, resolving each element via the bindings.
        result: List[int] = []
        for d in expr:
            if hasattr(d, "shape_expr"):
                result.extend(eval_rec(d.shape_expr))
            elif isinstance(d, (list, tuple)):
                result.extend(eval_rec(d))
            else:
                result.append(_shape_elem_to_int(d, bindings))
        return result

    return eval_rec(shape_expr)


def _shape_expr_to_meta_shape(shape_expr: Any) -> Tuple[int, ...]:
    """Convert a TRT shape_expr to a shape tuple of ints for meta tensors.

    Fake/symbolic dims get unique placeholder ints (1, 2, 3, ...). The counter
    is local to each call so values are always 1, 2, 3, ... regardless of prior
    calls — making the output fully deterministic across repeated runs of the
    same program.

    This differs from ``_shape_expr_to_ints`` in that distinct fake/symbolic
    dims receive distinct placeholder values, which preserves shape-broadcasting
    semantics when creating meta tensors for first-call shape inference.
    """
    bindings: Dict[int, int] = {}
    _collect_shape_var_bindings(shape_expr, bindings)
    placeholder_cache: Dict[int, int] = {}
    _counter = [1]  # mutable cell; local to this call

    def eval_rec(expr: Any) -> List[int]:
        result: List[int] = []
        for d in expr:
            if hasattr(d, "shape_expr"):
                result.extend(eval_rec(d.shape_expr))
            elif isinstance(d, (list, tuple)):
                result.extend(eval_rec(d))
            elif _shape_elem_concrete_without_fake(d):
                result.append(_shape_elem_to_int(d, bindings))
            else:
                # NOTE(review): placeholders restart at 1 on every call, so a
                # placeholder value can coincide with a genuine concrete dim of
                # the same value — presumably acceptable for meta-shape
                # inference; confirm against broadcasting-sensitive callers.
                key = id(d)
                if key not in placeholder_cache:
                    placeholder_cache[key] = _counter[0]
                    _counter[0] += 1
                result.append(placeholder_cache[key])
        return result

    return tuple(eval_rec(shape_expr))


# ---- TensorDesc helpers ----


# ---- Artifact dump helper ----


def dump_code_artifact(
    env_key: str,
    filename: str,
    content: Union[str, bytes],
    default_dir: str = "/tmp/tta_dump",
) -> None:
    """Write content to $env_key/, silently ignoring all errors.

    Justified as shared: Triton and CuTile both have near-identical
    try: makedirs; write to env-dir; except: pass blocks. CuTeDSL uses
    cute.compile(options="--dump-dir=...") so it does not need this helper.

    Args:
        env_key: Environment variable naming the target directory.
        filename: File name (not path) to write inside that directory.
        content: str or bytes content to write.
        default_dir: Fallback directory if the env var is unset or empty.
    """
    import os as _os

    dump_dir = _os.environ.get(env_key) or default_dir
    # Reachable only when default_dir itself is empty: passing "" disables
    # dumping entirely (env var unset/empty falls back to default_dir first).
    if not dump_dir:
        return
    try:
        _os.makedirs(dump_dir, exist_ok=True)
        path = _os.path.join(dump_dir, filename)
        # Binary content gets "wb" so PTX/cubin bytes round-trip unmodified.
        mode = "wb" if isinstance(content, bytes) else "w"
        with open(path, mode) as f:
            f.write(content)
    except (OSError, IOError):
        # Broad OS/IO catch is intentional: this is a best-effort debug dump.
        # Failures (e.g. read-only filesystem, missing env var dir, permission
        # denied) must never propagate to the caller and abort compilation.
+ pass + + +# ---- Generic sandbox runner ---- + + +def run_kernel_sandbox( + launch_fn: Any, + host_args: List[Any], + is_kernel_fn: Callable[[Any], bool], + recorder_factory: Callable[[Any], Any], + raw_fn: Any = None, + host_kwargs: Optional[Dict[str, Any]] = None, + extra_overrides: Optional[Dict[str, Any]] = None, + strict: bool = False, + op: str = "", + backend: str = "unknown", +) -> Tuple[Optional[Any], Dict[str, Any]]: + """Run launch_fn in a sandbox, replacing kernel objects with recorder proxies. + + All three backends share the same module-discovery + proxy-injection + + sandbox-run pattern. Each backend provides: + - is_kernel_fn: predicate to detect backend kernel objects in module globals. + - recorder_factory: callable(kernel_obj) → recorder proxy. + - raw_fn: (CuTeDSL only) the unwrapped Python function to sandbox instead of + launch_fn (which may be a @cute.jit wrapper without __code__). + - host_kwargs: extra keyword args passed to the sandboxed function call + (Triton passes cfg kwargs; CuTile passes cfg kwargs; CuTeDSL passes none). + - extra_overrides: additional name→value overrides injected into the sandbox + module globals alongside recorder proxies (CuTile uses this for its + patched `ct` module). + - strict: if True, raise QDPRuntimeError when module cannot be found or no + kernel objects exist; if False (default), return (None, {}) instead. + + Returns: + (used_recorder_or_None, all_recorders_dict) + used_recorder_or_None is the first recorder whose grid/args was set. + all_recorders_dict is keyed by module-global name. + + Note on silent-exception policy: + For CuTeDSL the @cute.jit wrapper may not expose __code__, so raw_fn + could be None. When strict=False and module discovery fails, or when + the sandbox run raises, we return (None, {}) — the caller falls back + to a (1,1,1) grid. Triton and CuTile use strict=True so that + missing-module and no-kernel errors are surfaced clearly. 
+ """ + fn_to_sandbox = raw_fn if raw_fn is not None else launch_fn + host_mod_name = getattr(fn_to_sandbox, "__module__", None) + module_obj = None + if host_mod_name: + try: + module_obj = importlib.import_module(host_mod_name) + except ImportError: + pass + + if module_obj is None: + if strict: + raise QDPRuntimeError( + op=op, + stage="aot_impl", + backend=backend, + msg=f"cannot import module for launch_fn {launch_fn}", + ) + return None, {} + + recorders: Dict[str, Any] = { + name: recorder_factory(obj) + for name, obj in vars(module_obj).items() + if is_kernel_fn(obj) + } + + if not recorders: + if strict: + raise QDPRuntimeError( + op=op, + stage="aot_impl", + backend=backend, + msg=f"no kernel objects matching predicate found in module {host_mod_name}", + ) + return None, {} + + overrides: Dict[str, Any] = dict(recorders) + if extra_overrides: + overrides.update(extra_overrides) + + try: + sandboxed = make_sandboxed_host(fn_to_sandbox, overrides) + sandboxed(*host_args, **(host_kwargs or {})) + except Exception: + # Broad catch is intentional: the sandbox runs user-supplied launch + # functions with fake/symbolic tensor proxies. Any exception — from + # wrong kernel argument types to missing attributes on proxy objects — + # is a sandbox-execution failure, not a programming error in this + # module. When strict=True the original exception propagates so the + # caller can surface it; otherwise we return (None, {}) and let the + # caller fall back to a default (1,1,1) grid. 
+ if strict: + raise + return None, {} + + used = [ + r for r in recorders.values() + if getattr(r, "grid", None) is not None or getattr(r, "args", None) is not None + ] + used_recorder = used[0] if used else None + return used_recorder, recorders + + +# ---- Launch params helper ---- + + +def _launch_params_from_trt( + launch: Any, + extra_args: Any, + num_inputs: int = 1, + num_outputs: int = 1, +) -> types.SimpleNamespace: + """Build a SimpleNamespace from a trtp.KernelLaunchParams and SymIntExprs. + + Justified as shared: all three backends produce a KernelLaunchParams object + with identical grid_x/y/z and block_x/y/z attributes, and the conversion + to SimpleNamespace is identical in all three AOT files (only the default + block_x fallback differs, but block_x is always set explicitly at the call + site before this function is called, so the default is never exercised). + + Args: + launch: A ``trtp.KernelLaunchParams`` instance with grid_x/y/z and + block_x/y/z attributes (missing attributes default to 1). + extra_args: SymIntExprs list to forward as ``sym_int_exprs``. + num_inputs: Number of input tensor bindings; used to build + ``param_binding_indices`` as ``[0..num_inputs+num_outputs)``. + num_outputs: Number of output tensor bindings. + + Returns: + A ``types.SimpleNamespace`` with fields: grid, block, shared_mem, + param_binding_indices, sym_int_exprs. 
+ """ + grid = ( + getattr(launch, "grid_x", 1), + getattr(launch, "grid_y", 1), + getattr(launch, "grid_z", 1), + ) + block = ( + getattr(launch, "block_x", 1), + getattr(launch, "block_y", 1), + getattr(launch, "block_z", 1), + ) + return types.SimpleNamespace( + grid=grid, + block=block, + shared_mem=getattr(launch, "shared_mem", 0), + param_binding_indices=list(range(num_inputs)) + + list(range(num_inputs, num_inputs + num_outputs)), + sym_int_exprs=extra_args, + ) + + +# ---- Launch-arg analysis ---- + + +def analyze_launch_args( + *, + args: Sequence[Any], + num_inputs: int, + num_outputs: int, + op: str, + backend: str, +) -> Tuple[List[int], List[Any]]: + """Split recorded kernel call arguments into pointer bindings and scalar SymInts. + + Enforces the contract: all tensor pointer arguments must appear before + all scalar arguments in the backend kernel call. + + Args: + args: Positional arguments recorded from the kernel call, each + either a ``SymbolicTensor`` (pointer binding) or a scalar + (``trtp.SymInt32``, ``int``, or SymIntExpr with ``._expr``). + num_inputs: Number of input tensor bindings declared for this op. + num_outputs: Number of output tensor bindings declared for this op. + op: Op name included in error messages for diagnostics. + backend: Backend name included in error messages for diagnostics. + + Returns: + (param_binding_indices, scalar_symints) + - param_binding_indices[k] = b means kernel pointer parameter k + maps to exported tensor binding b in [inputs..., outputs...]. + - scalar_symints is a list of trtp.SymInt32 for scalar kernel args. + + Raises: + QDPRuntimeError: If a tensor argument follows a scalar argument, if an + input/output index is out of range, if a tensor has an unknown role, + or if an argument type is not supported. 
+ """ + + def compute_binding_index(st: SymbolicTensor) -> int: + if st.role is TensorRole.INPUT: + if st.index < 0 or st.index >= num_inputs: + raise QDPRuntimeError( + op=op, + stage="aot_impl", + backend=backend, + msg=f"invalid input index {st.index} for {num_inputs} inputs", + ) + return st.index + if st.role is TensorRole.OUTPUT: + if st.index < 0 or st.index >= num_outputs: + raise QDPRuntimeError( + op=op, + stage="aot_impl", + backend=backend, + msg=f"invalid output index {st.index} for {num_outputs} outputs", + ) + return num_inputs + st.index + raise QDPRuntimeError( + op=op, + stage="aot_impl", + backend=backend, + msg=f"unexpected tensor role {st.role}", + ) + + param_binding_indices: List[int] = [] + scalar_symints: List[Any] = [] + seen_scalar = False + + for a in args: + if isinstance(a, SymbolicTensor): + if seen_scalar: + raise QDPRuntimeError( + op=op, + stage="aot_impl", + backend=backend, + msg=( + "backend kernel arguments must be ordered as " + "[all tensor pointers..., then all scalars...]; " + "found a tensor argument after scalar arguments" + ), + ) + param_binding_indices.append(compute_binding_index(a)) + elif _TRT_AVAILABLE and isinstance(a, trtp.SymInt32): + scalar_symints.append(a) + seen_scalar = True + elif isinstance(a, int): + scalar_symints.append(a) + seen_scalar = True + elif _TRT_AVAILABLE and hasattr(a, "_expr"): + scalar_symints.append(a) + seen_scalar = True + else: + raise QDPRuntimeError( + op=op, + stage="aot_impl", + backend=backend, + msg=( + f"unsupported launch argument type {type(a)!r}; " + "only SymbolicTensor, SymInt32, and int scalars are allowed" + ), + ) + + return param_binding_indices, scalar_symints diff --git a/py/torch_tensorrt/annotation/_custom_plugin/_symbolic.py b/py/torch_tensorrt/annotation/_custom_plugin/_symbolic.py new file mode 100644 index 0000000000..07b174e1e0 --- /dev/null +++ b/py/torch_tensorrt/annotation/_custom_plugin/_symbolic.py @@ -0,0 +1,266 @@ +"""Symbolic tensor proxies used during 
AOT kernel compilation (shape expression bindings). + +During the AOT compilation pipeline each backend (Triton, cuTILE, CuTe DSL) needs +to run the user's *launch function* without a real GPU so that it can: + + 1. Intercept the kernel call and record which tensors / scalars are passed. + 2. Capture the symbolic launch grid (grid_x/y/z) as trtp.SymInt32 expressions + so that TRT can evaluate them at engine-run time with actual input shapes. + +``SymbolicTensor`` is the proxy object injected in place of real ``torch.Tensor`` +arguments. It wraps a QDP ``TensorDesc`` and exposes: + +* ``shape`` / ``shape_expr`` — per-dimension *SymInt32 expressions* that bind to + the input dimension at runtime (e.g. ``shape[0]`` evaluates to the batch size + that TRT resolves during engine execution). +* ``stride`` — row-major symbolic strides, or format-specific strides for packed + channel layouts (HWC, DHWC, etc.). +* ``numel()`` — symbolic product of all dimensions. + +The ``TensorRole`` enum distinguishes input tensors from output tensors so that +``analyze_launch_args`` in ``_qdp_utils.py`` can reconstruct the correct +``param_binding_indices`` that TRT uses to identify which runtime buffer to pass +for each kernel pointer parameter. +""" +from __future__ import annotations + +import logging +from dataclasses import dataclass +from enum import Enum, auto +from typing import Any, Tuple + +logger = logging.getLogger(__name__) + +try: + import tensorrt as trt + import tensorrt.plugin as trtp + + _TRT_AVAILABLE = True +except ImportError: + _TRT_AVAILABLE = False + trt = None # type: ignore[assignment] + trtp = None # type: ignore[assignment] + + +class TensorRole(Enum): + """Identifies whether a SymbolicTensor represents an input or output tensor. + + Used by ``analyze_launch_args`` to map each kernel pointer parameter back to + the correct TRT binding index (inputs occupy indices 0..num_inputs-1, outputs + occupy indices num_inputs..num_inputs+num_outputs-1). 
+ """ + + INPUT = auto() + OUTPUT = auto() + + +def _to_sym_int(x: Any) -> Any: + """Convert a shape/stride element to SymInt32; accept int or SymInt32-like.""" + if _TRT_AVAILABLE and isinstance(x, trtp.SymInt32): + return x + if isinstance(x, int): + return trtp.SymInt32(x) if _TRT_AVAILABLE else x + if _TRT_AVAILABLE: + return trtp.SymInt32(x) + return x + + +def _sym_to_int_if_const(x: Any) -> Any: + """Return Python int if x is concrete (Python int or constant SymInt32), else return x unchanged. + + In TRT's aot_impl context, shape_expr elements may be SymInt32 objects backed by + a trt.IDimensionExpr constant (e.g. for a static-shape tensor). We can extract + the concrete value via _int_expr.is_constant() / .get_constant_value(). + This lets downstream code use isinstance(v, int) checks correctly so that grid + computation in user launch_fn works even with SymbolicTensor arguments. + """ + if isinstance(x, int): + return x + # Try via TRT IDimensionExpr (available when _exprBuilder was set during build). + # Broad catch is necessary: IDimensionExpr attribute access raises different + # exception types across TRT versions (AttributeError, RuntimeError, TypeError). + try: + ie = getattr(x, "_int_expr", None) + if ie is not None and ie.is_constant(): + return int(ie.get_constant_value()) + except (AttributeError, RuntimeError, TypeError, ValueError) as _e: + logger.debug("_sym_to_int_if_const: failed to extract constant from %r: %s", x, _e) + return x + + +class _ShapeDim: + """Wrapper for a symbolic shape dimension supporting arithmetic in launch fns. + + math.prod(iterable) starts with 1 (int) and multiplies left-to-right. + The first step is ``1 * element``, which Python evaluates as: + 1. int.__mul__(element) → NotImplemented + 2. element.__rmul__(1) ← this method + + Standard trtp.SymInt32 has no __rmul__, so math.prod fails for dynamic + dims. 
_ShapeDim wraps a SymInt32 (or int) and delegates all arithmetic to + self._v so that expressions like math.prod(x.shape) and ceiling division + (M + BX - 1) // BX work transparently in @cute.jit launch functions. + + _expr is exposed so that SymIntExpr._op and _as_symint32 can extract the + underlying IDimensionExpr without knowing the concrete wrapper type. + """ + + def __init__(self, v: Any) -> None: + self._v = v # int or trtp.SymInt32 + # Cache the underlying IDimensionExpr so SymIntExpr._op and _as_symint32 + # can extract it without knowing the concrete wrapper type. + self._expr = getattr(v, "_expr", getattr(v, "_int_expr", None)) + + def __add__(self, other: Any) -> Any: + v = other._v if isinstance(other, _ShapeDim) else other + return self._v + v + + def __radd__(self, other: Any) -> Any: + return self._v + other + + def __sub__(self, other: Any) -> Any: + v = other._v if isinstance(other, _ShapeDim) else other + return self._v - v + + def __rsub__(self, other: Any) -> Any: + return other - self._v + + def __mul__(self, other: Any) -> Any: + v = other._v if isinstance(other, _ShapeDim) else other + return self._v * v + + def __rmul__(self, other: Any) -> Any: + """Called as: other * self (e.g., int(1) * _ShapeDim from math.prod).""" + return self._v * other # commutative: SymInt32.__mul__(int) handles it + + def __floordiv__(self, other: Any) -> Any: + v = other._v if isinstance(other, _ShapeDim) else other + return self._v // v + + def __rfloordiv__(self, other: Any) -> Any: + return other // self._v + + def __int__(self) -> int: + return int(self._v) + + def __repr__(self) -> str: + return f"_ShapeDim({self._v!r})" + + +def _strides_from_td(td: Any) -> Tuple[Any, ...]: + """Return strides for *td*, sourced from TRT runtime where possible. + + Two sources, probed in priority order: + + 1. ``td.strides`` — physical strides on a ``trtp.Tensor`` object (the + runtime tensor type passed to JIT ``enqueue``). 
When a + ``SymbolicTensor`` wraps a ``trtp.Tensor``, this gives the actual + strides TRT has assigned to the buffer, correctly handling any layout + (e.g. row-padded LINEAR from Myelin matmul fusion). + + 2. Row-major fallback from ``shape_expr`` — correct for contiguous tensors. + Used when wrapping a ``TensorDesc`` in the AOT path (which carries only + logical shape, not physical strides). + + Analytical per-format stride reconstruction is intentionally absent: it + replicates TRT's internal layout logic and is fragile. Correct physical + strides must come from TRT itself. + """ + strides = getattr(td, "strides", None) + if strides is not None: + return tuple(_to_sym_int(s) for s in strides) + + # Fallback: logical row-major strides from shape_expr. + shape_expr = td.shape_expr + prod = trtp.SymInt32(1) if _TRT_AVAILABLE else 1 + strides_list: list = [] + for i in range(len(shape_expr) - 1, -1, -1): + strides_list.insert(0, prod) + prod = prod * _to_sym_int(shape_expr[i]) + return tuple(_to_sym_int(s) for s in strides_list) + + +@dataclass +class SymbolicTensor: + """Symbolic view over a QDP TensorDesc with role metadata. + + Attributes: + td: TensorDesc from QDP. + role: TensorRole.INPUT or TensorRole.OUTPUT. + index: role-local index (0-based, within inputs or outputs). + """ + + td: Any # trtp.TensorDesc + role: TensorRole + index: int + + def __post_init__(self) -> None: + # Pre-compute shape dims once: Python int for static, _ShapeDim for dynamic. + # _ShapeDim adds __rmul__ so math.prod(x.shape) works in @cute.jit kernels. + # Guard against mock/test TDs that lack shape_expr. + shape_expr = getattr(self.td, "shape_expr", None) + _shape = [] + if shape_expr is not None: + for d in shape_expr: + concrete = _sym_to_int_if_const(d) + _shape.append(concrete if isinstance(concrete, int) else _ShapeDim(concrete)) + self._shape: Tuple[Any, ...] = tuple(_shape) + + self._stride: Tuple[Any, ...] 
= _strides_from_td(self.td) if shape_expr is not None else () + + # Pre-compute numel: Python int for fully-static shapes, SymInt32 for dynamic. + if shape_expr is None: + self._numel: Any = 0 + else: + concrete_dims = [_sym_to_int_if_const(d) for d in shape_expr] + if all(isinstance(v, int) for v in concrete_dims): + result = 1 + for v in concrete_dims: + result *= v + self._numel = result + else: + n = trtp.SymInt32(1) if _TRT_AVAILABLE else 1 + for d in shape_expr: + n = n * _to_sym_int(d) + self._numel = n + + @property + def shape(self) -> Tuple[Any, ...]: + return self._shape + + @property + def shape_expr(self) -> Tuple[Any, ...]: + """Symbolic shape dimensions (same as .shape); use for grid or extra_args.""" + return self.shape + + def size(self, dim: int | None = None): + """PyTorch-style alias for shape / shape[dim].""" + if dim is None: + return self.shape + return self.shape[dim] + + def shape_dim(self, dim: int) -> Any: + """Return the raw SymInt32 for dimension *dim* from the underlying TensorDesc.""" + return _to_sym_int(self.td.shape_expr[dim]) + + def stride(self, dim: int | None = None): + """PyTorch-style stride API: stride() or stride(dim).""" + if dim is None: + return self._stride + return self._stride[dim] + + def numel(self) -> Any: + """Total element count: Python int for static shapes, SymInt32 for dynamic.""" + return self._numel + + @property + def is_cuda(self) -> bool: + return True + + +def cdiv(a: Any, b: int) -> Any: + """Ceiling division on a SymInt32 by a Python int divisor. + + Only the pattern SymInt32 // int is required by the contract. + """ + return (a + (b - 1)) // b diff --git a/py/torch_tensorrt/annotation/_recorders.py b/py/torch_tensorrt/annotation/_recorders.py new file mode 100644 index 0000000000..540761f07f --- /dev/null +++ b/py/torch_tensorrt/annotation/_recorders.py @@ -0,0 +1,207 @@ +"""Backend launch recorders for Triton, cuTILE, and CuTe DSL. 
+ +These recorder classes are used exclusively inside the *autotune sandbox* — +the dry-run environment in which the TTA autotune pass executes a user's +``launch_fn`` without actually dispatching GPU work. The sandbox replaces +real kernel objects with recorder proxies so that launch parameters (grid, +block, args) can be captured and later used to build TRT QDP plugin +descriptors, without touching the GPU or requiring CUDA to be initialised. + +Each recorder mirrors the call protocol of its target backend: + +Backend | Call protocol | Recorder class +----------|--------------------------------------------|-------------------- +Triton | ``kernel[grid](*args, **kwargs)`` | TritonLaunchRecorder +cuTILE | ``prog(*args, **kwargs)`` | CuTileLaunchRecorder +CuTe DSL | ``kernel(*args)(...).launch(grid, block)`` | CuTeDSLKernelRecorder + +After the sandbox ``launch_fn`` returns, the autotune pass inspects the +populated fields on the recorder instance to retrieve the captured +parameters. +""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Dict, Optional, Tuple + + +@dataclass +class TritonLaunchRecorder: + """Proxy that mimics a Triton kernel, recording launches instead of running. + + Used in the autotune sandbox to capture the parameters that a user's + ``launch_fn`` passes to a Triton kernel without executing any GPU work. + + The Triton call protocol is ``kernel[grid](*args, **kwargs)``. This + recorder implements ``__getitem__`` to capture the grid and returns a + closure that captures the positional and keyword arguments when called. + + Attributes: + real_kernel: The original Triton kernel object being proxied. Not + invoked by the recorder; held for reference only (e.g. to read + kernel metadata outside the sandbox). + grid: Grid tuple captured from ``kernel[grid]``. ``None`` until + a launch is recorded. + args: Positional arguments captured from the launcher call. ``None`` + until a launch is recorded. 
+ kwargs: Keyword arguments captured from the launcher call. ``None`` + until a launch is recorded. + """ + + real_kernel: Any + grid: Optional[Any] = None + args: Optional[Tuple[Any, ...]] = None + kwargs: Optional[Dict[str, Any]] = None + + def __getitem__(self, grid: Any) -> Any: + """Capture ``kernel[grid]`` and return a launcher closure. + + Args: + grid: The grid value (e.g. a tuple or a lambda) passed inside + ``kernel[grid]``. + + Returns: + A callable that, when called with ``(*args, **kwargs)``, stores + those values on this recorder. + """ + + def _launcher(*args: Any, **kwargs: Any) -> None: + self.grid = grid + self.args = args + self.kwargs = kwargs + + return _launcher + + +@dataclass +class CuTileLaunchRecorder: + """Proxy that mimics a cuTILE program object, recording calls instead of running. + + Used in the autotune sandbox to capture the arguments passed to a cuTILE + compiled program without executing any GPU work. + + The cuTILE call protocol is ``prog(*args, **kwargs)`` where *prog* is an + already-compiled ``cuda.tile`` program object. This recorder implements + ``__call__`` to capture positional and keyword arguments. + + Note: cuTILE programs do not expose a separate ``[grid]`` subscript; the + grid is typically embedded in the program configuration or passed as a + keyword argument. The ``grid`` field is reserved for future use when + cuTILE exposes per-launch grid control. + + Attributes: + real_prog: The original cuTILE program object being proxied. Not + invoked by the recorder; held for reference only. + args: Positional arguments captured from the call. ``None`` until + a launch is recorded. + kwargs: Keyword arguments captured from the call. ``None`` until + a launch is recorded. + grid: Reserved for future cuTILE grid capture. Always ``None`` + in the current implementation. 
    """

    real_prog: Any
    args: Optional[Tuple[Any, ...]] = None
    kwargs: Optional[Dict[str, Any]] = None
    grid: Optional[Tuple[int, ...]] = None

    def __call__(self, *args: Any, **kwargs: Any) -> None:
        """Capture ``prog(*args, **kwargs)`` without running the program.

        Args:
            *args: Positional arguments forwarded to the cuTILE program.
            **kwargs: Keyword arguments forwarded to the cuTILE program.
        """
        self.args = args
        self.kwargs = kwargs


class _CuTeDSLLaunchProxy:
    """Intermediate object returned by ``CuTeDSLKernelRecorder.__call__``.

    Represents the result of calling a ``@cute.kernel`` object (the
    compiled artifact) and exposes a ``.launch(grid, block)`` method that
    the autotune sandbox intercepts to record grid and block dimensions.

    This class is an implementation detail of ``CuTeDSLKernelRecorder`` and
    is not intended to be instantiated directly.
    """

    def __init__(self, recorder: "CuTeDSLKernelRecorder") -> None:
        self._recorder = recorder

    def launch(
        self,
        grid: Any = (1, 1, 1),
        block: Any = (1, 1, 1),
        **kwargs: Any,
    ) -> None:
        """Capture the grid and block dimensions from a ``.launch()`` call.

        Args:
            grid: Grid dimensions tuple (x, y, z). Non-tuple values are
                coerced to a tuple.
            block: Block dimensions tuple (x, y, z). Non-tuple values are
                coerced to a tuple.
            **kwargs: Additional keyword arguments are accepted but ignored;
                they are not recorded because they are not needed for TRT
                QDP descriptor construction.
        """
        # NOTE(review): tuple(grid) on a non-iterable (e.g. a bare int) raises
        # TypeError here, before the ValueError length checks below — confirm
        # callers always pass sequences for grid/block.
        grid_t = tuple(grid) if not isinstance(grid, tuple) else grid
        block_t = tuple(block) if not isinstance(block, tuple) else block
        if len(grid_t) != 3:
            raise ValueError(
                f"_CuTeDSLLaunchProxy.launch: grid must have exactly 3 elements (x, y, z), "
                f"got {len(grid_t)}: {grid_t!r}"
            )
        if len(block_t) != 3:
            raise ValueError(
                f"_CuTeDSLLaunchProxy.launch: block must have exactly 3 elements (x, y, z), "
                f"got {len(block_t)}: {block_t!r}"
            )
        # Recorded values are read back by the autotune pass after the sandboxed
        # launch_fn returns.
        self._recorder.grid = grid_t
        self._recorder.block = block_t


@dataclass
class CuTeDSLKernelRecorder:
    """Proxy for ``@cute.kernel`` objects that records grid/block on ``.launch()``.

    Used in the autotune sandbox to capture the launch configuration
    (grid and block dimensions) of a CuTe DSL kernel without executing GPU
    work.

    The CuTe DSL call protocol for ``@cute.kernel`` decorated functions is::

        result = kernel(*args, **kwargs)  # compile / specialize
        result.launch(grid=..., block=...)  # dispatch

    This recorder implements ``__call__`` to return a ``_CuTeDSLLaunchProxy``
    that records the subsequent ``.launch()`` call, storing ``grid`` and
    ``block`` on this instance.

    Attributes:
        real_kernel: The original ``@cute.kernel`` object being proxied.
            Not invoked by the recorder; held for reference only.
        grid: Grid dimensions captured from ``.launch(grid=...)``. ``None``
            until a launch is recorded.
        block: Block dimensions captured from ``.launch(block=...)``.
            ``None`` until a launch is recorded.
    """

    real_kernel: Any
    grid: Optional[Tuple[Any, ...]] = None
    block: Optional[Tuple[Any, ...]] = None

    def __call__(self, *args: Any, **kwargs: Any) -> _CuTeDSLLaunchProxy:
        """Return a launch proxy that will record grid/block on ``.launch()``.

        Args:
            *args: Positional arguments passed to the kernel (ignored;
                not needed for TRT descriptor construction).
            **kwargs: Keyword arguments passed to the kernel (ignored).
+ + Returns: + A ``_CuTeDSLLaunchProxy`` bound to this recorder. + """ + return _CuTeDSLLaunchProxy(self) diff --git a/py/torch_tensorrt/annotation/_specs.py b/py/torch_tensorrt/annotation/_specs.py new file mode 100644 index 0000000000..6939f2e28d --- /dev/null +++ b/py/torch_tensorrt/annotation/_specs.py @@ -0,0 +1,257 @@ +""" +TTA Spec Types and Annotation Metadata +======================================= + +This module defines the kernel spec type hierarchy used by the TTA annotation +layer to describe how a custom QDP plugin should be compiled and registered. + +Spec hierarchy +-------------- + +:: + + CustomPluginSpec — AOT QDP plugin descriptor built by ``tta.custom_plugin()`` + ├── TritonSpec — kernel implemented with Triton + ├── CuTileSpec — kernel implemented with NVIDIA CuTile (cuda-tile) + ├── CuTeDSLSpec — kernel implemented with NVIDIA CuTe DSL + └── TvmFfiSpec — kernel compiled via TVM FFI (planned; blocked on QDP TVM FFI support) + +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Sequence + + +# ── Shared helpers ──────────────────────────────────────────────────────────── + + +def _dict_to_stable_tuple(d: Dict[str, Any]) -> tuple: + """Return a deterministic, hashable tuple representation of a dict by sorting keys.""" + return tuple(sorted(d.items())) + + +def _validate_kernel_spec_fields(spec_name: str, launch_fn: Any, configs: Any) -> None: + """Validate the common fields shared by TritonSpec, CuTileSpec, and CuTeDSLSpec.""" + if not callable(launch_fn): + raise TypeError( + f"{spec_name}: launch_fn must be callable, got {type(launch_fn).__name__!r}" + ) + if configs is not None and not isinstance(configs, list): + raise TypeError( + f"{spec_name}: configs must be a list or None, got {type(configs).__name__!r}" + ) + + +# ── Custom kernel specs (Triton / CuTile / CuTeDSL / TvmFfi) ───────────────── +# TvmFfiSpec (tta.tvmffi) is planned as a fourth 
# ── Custom kernel specs (Triton / CuTile / CuTeDSL / TvmFfi) ────────────────
# TvmFfiSpec (tta.tvmffi) is planned as a fourth backend on par with the three
# below, but is blocked on TVM FFI support landing in QDP so that compiled
# kernels can be handed off to TRT's plugin runtime without a live Python
# interpreter.


def _configs_key(configs: Optional[List[Dict[str, Any]]]) -> tuple:
    """Deterministic tuple form of an autotune-config list (``None`` → ``()``)."""
    return tuple(tuple(sorted(cfg.items())) for cfg in (configs or []))


def _formats_key(formats: Optional[Sequence[int]]) -> tuple:
    """Deterministic tuple of ints for a format sequence (falsy → ``()``)."""
    return tuple(int(f) for f in formats) if formats else ()


def _validate_kernel_spec_fields(spec_name: str, launch_fn: Any, configs: Any) -> None:
    """Validate the fields shared by all kernel specs (colocated with them).

    Raises:
        TypeError: when ``launch_fn`` is not callable, or ``configs`` is
            neither a ``list`` nor ``None``.
    """
    if not callable(launch_fn):
        raise TypeError(
            f"{spec_name}: launch_fn must be callable, got {type(launch_fn).__name__!r}"
        )
    if configs is not None and not isinstance(configs, list):
        raise TypeError(
            f"{spec_name}: configs must be a list or None, got {type(configs).__name__!r}"
        )


@dataclass(frozen=True)
class TritonSpec:
    """Specification for a custom plugin implemented with a Triton kernel.

    Attributes:
        launch_fn: The Triton kernel entry-point.
        configs: List of launch-parameter dicts for autotuning candidates.
            ``None`` means no explicit configs.
        input_formats: Optional tensor layout descriptors for input tensors.
        output_formats: Optional tensor layout descriptors for output tensors.
        kwargs: Additional Triton-specific parameters.
    """

    launch_fn: Callable
    configs: Optional[List[Dict[str, Any]]] = None
    input_formats: Optional[Sequence[int]] = None
    output_formats: Optional[Sequence[int]] = None
    kwargs: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        _validate_kernel_spec_fields("TritonSpec", self.launch_fn, self.configs)

    def to_cache_key(self) -> tuple:
        """Deterministic, hashable identity tuple used for compilation caching.

        NOTE: ``id(self.launch_fn)`` is stable only while this spec keeps the
        function alive; keys derived from garbage-collected specs may collide
        with keys for newly created functions.
        """
        return (
            "triton",
            id(self.launch_fn),
            _configs_key(self.configs),
            tuple(sorted(self.kwargs.items())),
            _formats_key(self.input_formats),
            _formats_key(self.output_formats),
        )


@dataclass(frozen=True)
class CuTileSpec:
    """Specification for a custom plugin implemented with an NVIDIA CuTile kernel.

    Attributes:
        launch_fn: The CuTile kernel entry-point.
        configs: List of launch-parameter dicts for autotuning candidates.
        input_formats: Optional tensor layout descriptors for input tensors.
        output_formats: Optional tensor layout descriptors for output tensors.
        kwargs: Additional CuTile-specific parameters.
    """

    launch_fn: Callable
    configs: Optional[List[Dict[str, Any]]] = None
    input_formats: Optional[Sequence[int]] = None
    output_formats: Optional[Sequence[int]] = None
    kwargs: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        _validate_kernel_spec_fields("CuTileSpec", self.launch_fn, self.configs)

    def to_cache_key(self) -> tuple:
        """Deterministic, hashable identity tuple used for compilation caching."""
        return (
            "cutile",
            id(self.launch_fn),
            _configs_key(self.configs),
            tuple(sorted(self.kwargs.items())),
            _formats_key(self.input_formats),
            _formats_key(self.output_formats),
        )


@dataclass(frozen=True)
class CuTeDSLSpec:
    """Specification for a custom plugin implemented with the NVIDIA CuTe DSL.

    Attributes:
        launch_fn: The CuTe DSL kernel entry-point.
        configs: List of launch-parameter dicts for autotuning candidates.
        arch: Optional target GPU architecture string (e.g. ``"sm_80"``).
        input_formats: Optional tensor layout descriptors for input tensors.
        output_formats: Optional tensor layout descriptors for output tensors.
        kwargs: Additional CuTe DSL-specific parameters.
    """

    launch_fn: Callable
    configs: Optional[List[Dict[str, Any]]] = None
    arch: Optional[str] = None
    input_formats: Optional[Sequence[int]] = None
    output_formats: Optional[Sequence[int]] = None
    kwargs: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        _validate_kernel_spec_fields("CuTeDSLSpec", self.launch_fn, self.configs)

    def to_cache_key(self) -> tuple:
        """Deterministic, hashable identity tuple used for compilation caching.

        ``arch`` participates in the key because the same kernel compiled for
        different architectures yields different PTX.
        """
        return (
            "cutedsl",
            id(self.launch_fn),
            self.arch,
            _configs_key(self.configs),
            tuple(sorted(self.kwargs.items())),
            _formats_key(self.input_formats),
            _formats_key(self.output_formats),
        )


# ── Factory functions ───────────────────────────────────────────────────────


def triton(
    launch_fn: Callable,
    configs: Optional[List[Dict[str, Any]]] = None,
    input_formats: Optional[Sequence[int]] = None,
    output_formats: Optional[Sequence[int]] = None,
    **kwargs: Any,
) -> TritonSpec:
    """Create a :class:`TritonSpec` for a Triton kernel custom plugin.

    Args:
        launch_fn: Triton kernel function.
        configs: List of autotuning config dicts. Pass ``None`` for no configs.
        input_formats: Optional tensor layout descriptors for input tensors.
        output_formats: Optional tensor layout descriptors for output tensors.
        **kwargs: Additional Triton-specific parameters, stored verbatim.

    Returns:
        A :class:`TritonSpec` instance.

    Raises:
        TypeError: when ``launch_fn`` is not callable or ``configs`` is not
            a list/``None`` (validated eagerly by the spec).
    """
    return TritonSpec(
        launch_fn=launch_fn,
        configs=configs,
        input_formats=input_formats,
        output_formats=output_formats,
        kwargs=kwargs,
    )


def cutile(
    launch_fn: Callable,
    configs: Optional[List[Dict[str, Any]]] = None,
    input_formats: Optional[Sequence[int]] = None,
    output_formats: Optional[Sequence[int]] = None,
    **kwargs: Any,
) -> CuTileSpec:
    """Create a :class:`CuTileSpec` for a CuTile kernel custom plugin.

    Args:
        launch_fn: CuTile kernel function.
        configs: List of autotuning config dicts. Pass ``None`` for no configs.
        input_formats: Optional tensor layout descriptors for input tensors.
        output_formats: Optional tensor layout descriptors for output tensors.
        **kwargs: Additional CuTile-specific parameters, stored verbatim.

    Returns:
        A :class:`CuTileSpec` instance.
    """
    return CuTileSpec(
        launch_fn=launch_fn,
        configs=configs,
        input_formats=input_formats,
        output_formats=output_formats,
        kwargs=kwargs,
    )


def cutedsl(
    launch_fn: Callable,
    configs: Optional[List[Dict[str, Any]]] = None,
    arch: Optional[str] = None,
    input_formats: Optional[Sequence[int]] = None,
    output_formats: Optional[Sequence[int]] = None,
    **kwargs: Any,
) -> CuTeDSLSpec:
    """Create a :class:`CuTeDSLSpec` for a CuTe DSL kernel custom plugin.

    Args:
        launch_fn: CuTe DSL kernel function.
        configs: List of autotuning config dicts. Pass ``None`` for no configs.
        arch: Target GPU architecture string (e.g. ``"sm_80"``).
        input_formats: Optional tensor layout descriptors for input tensors.
        output_formats: Optional tensor layout descriptors for output tensors.
        **kwargs: Additional CuTe DSL-specific parameters, stored verbatim.

    Returns:
        A :class:`CuTeDSLSpec` instance.
    """
    return CuTeDSLSpec(
        launch_fn=launch_fn,
        configs=configs,
        arch=arch,
        input_formats=input_formats,
        output_formats=output_formats,
        kwargs=kwargs,
    )
# (diff residue) closing ``return CuTeDSLSpec(...)`` of the ``cutedsl()``
# factory; the function definition begins above this range.

# (diff residue) --- py/torch_tensorrt/dynamo/conversion/plugins/_custom_op.py ---
# Original top-of-file imports of this module:
#   from __future__ import annotations
#   import uuid
#   from typing import TYPE_CHECKING, Any, Callable, Optional
#   import torch
#   from torch.fx.node import Node
#   from torch_tensorrt.dynamo._settings import CompilationSettings
#   from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterPriority
#   from torch_tensorrt.dynamo.conversion.plugins._generate_plugin import (
#       _probe_num_outputs, generate_plugin)
#   from torch_tensorrt.dynamo.conversion.plugins._generate_plugin_converter import (
#       generate_plugin_converter)
#   if TYPE_CHECKING:
#       from torch_tensorrt.annotation._custom_plugin._descriptor import CustomPluginSpec


def custom_op(
    op_name: str,
    impl: Optional["CustomPluginSpec"] = None,
    capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None,
    priority: ConverterPriority = ConverterPriority.STANDARD,
    supports_dynamic_shapes: bool = False,
    requires_output_allocator: bool = False,
) -> None:
    """
    Generate the Plugin and corresponding Plugin Converter using external kernels
    and TensorRT Quick Deployable Plugin APIs.

    Args:
        op_name: the plugin name in ``"namespace::name"`` form. A matching
            ``torch.library.custom_op`` must already exist (or be auto-created
            via ``impl``).
        impl: optional ``tta.CustomPluginSpec`` from ``tta.custom_plugin(...)``.
            When provided:

            1. The torch op is auto-registered (``impl.auto_register_torch_op``).
            2. TRT QDP descriptors are registered via ``register_custom_plugin``
               which handles both activation and weight tensor inputs, then adds
               ``@trtp.autotune`` and ``@trtp.aot_impl`` for AOT kernel dispatch.
            3. A Dynamo converter is registered via ``generate_plugin_converter``
               (the same path as the JIT plugin) when there are no weights.
               For weighted plugins, a minimal custom converter is registered
               that injects weight tensors as ``trt.add_constant`` layers before
               delegating to ``impl.lower_to_trt``.

            When ``None`` (default) the existing JIT path (``generate_plugin``
            + ``generate_plugin_converter``) is used unchanged.
        capability_validator: optional node capability predicate.
        priority: converter registry priority.
        supports_dynamic_shapes: whether the converter supports dynamic shapes.
        requires_output_allocator: whether the converter requires an output
            allocator (e.g. data-dependent operators).

    Raises:
        ValueError: when ``op_name`` is not in ``"namespace::name"`` form.
    """
    # Fail fast on a malformed op name instead of surfacing an opaque tuple
    # unpacking error (or a deep TRT registration error) later in the pipeline.
    if "::" not in op_name:
        raise ValueError(
            f"custom_op: op_name must be in 'namespace::name' form, got {op_name!r}"
        )

    if impl is not None:
        from torch_tensorrt.annotation._custom_plugin._descriptor import (
            register_custom_plugin,
        )

        impl.auto_register_torch_op(op_name)

        namespace, op_local_name = op_name.split("::")
        torch_op = getattr(getattr(torch.ops, namespace), op_local_name)
        schema = torch_op._schemas[""]
        n_tensor_inputs = sum(
            1
            for a in schema.arguments
            if a.type.isSubtypeOf(torch._C.TensorType.get())
        )

        # QDP registration: includes weight tensors in the input count so TRT
        # can receive them as trt.add_constant outputs.
        # Probe torch_op (already registered by auto_register_torch_op above)
        # to get the correct output count — same mechanism as generate_plugin.
        num_outputs = _probe_num_outputs(torch_op, schema)
        register_custom_plugin(
            impl,
            num_inputs=n_tensor_inputs + len(impl.weights),
            num_outputs=num_outputs,
            qdp_name=op_name,
        )

        if not impl.weights:
            # No weights: reuse generate_plugin_converter, the same path as the
            # JIT plugin. After register_custom_plugin, the op is in QDP_REGISTRY
            # with aot_impl registered, so generate_plugin_converter detects that
            # and sets aot=True automatically — no separate converter code needed.
            generate_plugin_converter(
                op_name,
                capability_validator,
                priority,
                supports_dynamic_shapes,
                requires_output_allocator,
            )
        else:
            # Weighted plugins need a custom converter: generate_plugin_converter
            # cannot inject weight tensors as trt.add_constant before the plugin
            # call because it reads input count from plugin.input_tensor_names,
            # which includes weights, but the FX node only carries the activation
            # args. lower_to_trt handles the weight injection via
            # lower_custom_plugin_descriptor.
            from torch_tensorrt.dynamo.conversion._ConverterRegistry import (
                dynamo_tensorrt_converter,
            )
            from torch_tensorrt.dynamo.conversion.converter_utils import (
                get_trt_tensor,
            )

            # Plain attribute access: ``default`` is a fixed overload name.
            torch_overload = torch_op.default
            _impl = impl
            _n_act = n_tensor_inputs
            _qdp_name = op_name

            def _impl_converter(
                ctx: Any,
                target: Any,
                args: Any,
                kwargs: Any,
                name: str,
            ) -> Any:
                # Unique suffix keeps ITensor names collision-free when the
                # same op appears multiple times in one graph.
                unique_id = uuid.uuid4()
                itensor_args = [
                    get_trt_tensor(ctx, t, f"inp{i}_{unique_id}")
                    for i, t in enumerate(args[:_n_act])
                ]
                return _impl.lower_to_trt(ctx, itensor_args, name, qdp_name=_qdp_name)

            dynamo_tensorrt_converter(
                torch_overload,
                capability_validator=capability_validator,
                priority=priority,
                supports_dynamic_shapes=supports_dynamic_shapes,
                requires_output_allocator=requires_output_allocator,
            )(_impl_converter)
    else:
        generate_plugin(op_name)
        generate_plugin_converter(
            op_name,
            capability_validator,
            priority,
            supports_dynamic_shapes,
            requires_output_allocator,
        )
+ Tries ranks 1–4 with size-2 FakeTensors; returns 1 if all probes fail. + + Args: + meta_impl: The shape-inference callable (typically + ``CustomPluginSpec.meta_impl``). + n_args: Number of tensor arguments to pass. + preferred_rank: If provided, this rank is tried first before the + default sweep (1–4). Pass the actual input rank when + it is known (e.g. from live ``trt.ITensor`` inputs in + the lowering path) for a more accurate probe. + """ + ranks = [preferred_rank] if preferred_rank is not None else [] + ranks += [r for r in range(1, 5) if r != preferred_rank] + for rank in ranks: + try: + with FakeTensorMode(): + dummy = torch.randn([2] * rank) + result = meta_impl(*([dummy] * max(n_args, 1))) + return len(result) if isinstance(result, (tuple, list)) else 1 + except Exception: + continue + return 1 + + +def _probe_num_outputs(torch_op: Callable[[Any], Any], schema: Any) -> int: + """Probe ``torch_op`` with rank-2 FakeTensors to count its outputs. + + Non-tensor scalar arguments (int, float, bool, str) receive neutral + defaults (0, 0.0, False, "") which are sufficient for shape inference. + Tries ranks 1–4 with size-2 extents; returns 1 if all probes fail. + """ + # torch._C type singletons are not hashable, so use a list of (type, default) pairs. 
+ _scalar_defaults = [ + (torch._C.IntType.get(), 0), + (torch._C.FloatType.get(), 0.0), + (torch._C.BoolType.get(), False), + (torch._C.StringType.get(), ""), + ] + + for rank in range(1, 5): + try: + probe_args = [] + with FakeTensorMode(): + for arg in schema.arguments: + if arg.type.isSubtypeOf(torch._C.TensorType.get()): + probe_args.append(torch.randn([2] * rank)) + else: + for scalar_type, default in _scalar_defaults: + if arg.type.isSubtypeOf(scalar_type): + probe_args.append(default) + break + result = torch_op(*probe_args) + if isinstance(result, torch.Tensor): + return 1 + return len(result) + except Exception: + continue + return 1 + + +def _compute_out_descs_symbolic( + tensor_args: Any, + callable_impl: Callable[[Any], Any], +) -> list: + """Core ShapeEnv + lambdify computation shared by the JIT and TTA AOT paths. + + Builds symbolic FakeTensors (one per TensorDesc in ``tensor_args``), runs + ``callable_impl`` to obtain output shapes, lambdifies those shapes into + sympy-backed formulas, and returns a list of ``TensorDesc`` s with the + formulas assigned to ``shape_expr``. + + Args: + tensor_args: Sequence of input ``TensorDesc`` objects (TRT native type). + callable_impl: Shape-inference callable. Receives one ``FakeTensor`` + per element of ``tensor_args``; must return a + ``torch.Tensor`` or a sequence of ``torch.Tensor`` s. + + Returns: + List of output ``TensorDesc`` objects (one per output tensor). 
+ """ + shape_env = ShapeEnv() + syms_args = [] + for tensor_arg in tensor_args: + sample = {f"{i}": 5 for i in range(tensor_arg.ndim)} + syms_arg = [ + mksym(shape_env, v, LocalSource(k), DimDynamic.DYNAMIC) + for k, v in sample.items() + ] + syms_args.append(syms_arg) + + with FakeTensorMode(shape_env=shape_env): + fake_args = [torch.randn(syms_arg) for syms_arg in syms_args] + raw_output = callable_impl(*fake_args) + + outputs = ( + [raw_output] if isinstance(raw_output, torch.Tensor) else list(raw_output) + ) + + input_node_expr = list( + itertools.chain.from_iterable( + [sym.node.expr for sym in syms_arg] for syms_arg in syms_args + ) + ) + + out_descs = [] + for output in outputs: + shape_calc_fns = [ + lambdify(tuple(input_node_expr), output.shape[i].node.expr, "math") + for i in range(output.ndim) + ] + out_desc = tensor_args[0].like() + input_shape_expr = list( + itertools.chain.from_iterable(td.shape_expr for td in tensor_args) + ) + for i in range(output.ndim): + if output.shape[i].node.expr is None: + raise ValueError(f"output.shape[{i}].node.expr cannot be None") + out_desc.shape_expr[i] = shape_calc_fns[i](*input_shape_expr) # type: ignore[misc] + out_descs.append(out_desc) + + return out_descs + + +def _build_symbolic_desc_fn( + callable_impl: Callable[[Any], Any], + num_inputs: int, + num_outputs: int = 1, +) -> Callable[..., Any]: + """Build a TRT descriptor callable backed by FakeTensorMode + ShapeEnv + lambdify. + + Used by the TTA AOT path (``_build_desc_fn`` in + ``annotation/_custom_plugin/_descriptor.py``) where ``callable_impl`` is + the user-supplied ``meta_impl``. The JIT path (``_generic_plugin_desc`` + inside ``_generate_plugin``) uses the same ``_compute_out_descs_symbolic`` + core but wraps it in a closure that filters mixed tensor/scalar ``*args``. + + Args: + callable_impl: Shape-inference callable. Receives one ``FakeTensor`` + per input TensorDesc; must return a ``torch.Tensor`` or + a sequence of ``torch.Tensor`` s. 
+ num_inputs: Number of leading ``TensorDesc`` positional arguments. + num_outputs: Expected output count. Returns a single ``TensorDesc`` + when 1, a ``tuple`` otherwise. + """ + _fn = callable_impl + _n_in = num_inputs + _n_out = num_outputs + + def _desc(*args: Any) -> Any: + out_descs = _compute_out_descs_symbolic(args[:_n_in], _fn) + return out_descs[0] if _n_out == 1 else tuple(out_descs) + + return _desc + + def _generate_plugin(plugin_name: str) -> None: try: import tensorrt.plugin as trtp @@ -40,12 +204,14 @@ def _generate_plugin(plugin_name: str) -> None: # retrieve the corresponding torch operation using the passed in string torch_op = getattr(getattr(torch.ops, namespace), name) + schema = torch_op._schemas[""] + num_outputs = _probe_num_outputs(torch_op, schema) + # helper function that generates the required signature based on the torch operation def generate_signature( torch_op: Callable[[Any], Any], + num_outputs: int, ) -> Tuple[str, str, str, dict[str, Any], dict[str, Any]]: - schema = torch_op._schemas[""] - arg_list = [] register_func_annotation = {} @@ -92,7 +258,9 @@ def generate_signature( plugin_impl_input = ", ".join(plugin_impl_arg_list) plugin_impl_signature = f"def add_plugin_impl({plugin_impl_input}):" - register_func_annotation["return"] = Tuple[trtp.TensorDesc] + # Return annotation must encode the exact number of outputs so TRT + # allocates the right number of output ports on the plugin node. 
# (diff residue) Hunk inside ``generate_signature`` (function begins above):
# the return annotation is built from the probed output count so TRT allocates
# the right number of output ports on the plugin node:
#   register_func_annotation["return"] = Tuple[tuple([trtp.TensorDesc] * num_outputs)]
#   impl_func_annotation["outputs"] = Tuple[trtp.Tensor]
#   impl_func_annotation["stream"] = int
# and the call site becomes ``generate_signature(torch_op, num_outputs)``.


# NOTE: in the real file this helper is nested inside ``_generate_plugin`` and
# closes over ``trtp`` (tensorrt.plugin) and ``torch_op``.
def _generic_plugin_desc(*args: Any, **kwargs: Any) -> "Tuple[trtp.TensorDesc, ...]":
    """Shape-descriptor callback: keep only the ``TensorDesc`` positional
    arguments (scalars are irrelevant for shape inference) and delegate to the
    shared symbolic core."""
    tensor_args = [elem for elem in args if isinstance(elem, trtp.TensorDesc)]
    return tuple(_compute_out_descs_symbolic(tensor_args, torch_op))


# (diff residue) ``codegen_plugin = f""" {plugin_signature} ..."`` and the rest
# of ``_generate_plugin`` are unchanged context omitted by the patch hunk.


def generate_plugin(plugin_name: str) -> None:
    """Generate the plugin based on the plugin name.

    Args:
        plugin_name: the plugin name that is used to generate the plugin
            automatically. There should be existing kernels and pytorch custom
            operation for this plugin name.
    """
    _generate_plugin(plugin_name)


def register_plugin_with_aot(
    plugin_name: str,
    desc_fn: Any,
    autotune_fn: Optional[Any] = None,
    aot_fn: Optional[Any] = None,
) -> None:
    """Register a QDP plugin's descriptor, autotune, and AOT-impl callbacks with TRT.

    Centralises the three ``trtp`` registration calls so that the automatic
    JIT path (``generate_plugin``) and the TTA AOT path
    (``register_custom_plugin``) converge on one registration surface.
    Callers own idempotency and error handling — this function makes the raw
    ``trtp`` calls and raises whatever TRT raises.

    Args:
        plugin_name: TRT QDP op name in ``"namespace::name"`` form.
        desc_fn: Callable registered via ``@trtp.register``. Must carry a
            correct ``inspect.Signature`` with ``TensorDesc``-annotated
            parameters.
        autotune_fn: Optional callable for ``@trtp.autotune``; ``None`` skips
            autotune registration (TRT uses a single default tactic).
        aot_fn: Optional callable for ``@trtp.aot_impl``; ``None`` skips AOT
            registration (TRT falls back to a separately registered
            ``@trtp.impl`` JIT path, if any).

    Raises:
        RuntimeError: when ``tensorrt.plugin`` cannot be imported.
    """
    try:
        import tensorrt.plugin as trtp
    except ImportError as exc:
        raise RuntimeError(
            "TensorRT with plugin support is required for AOT plugin registration."
        ) from exc

    # Mandatory descriptor registration first, then the two optional hooks
    # in the order TRT expects them to be attached.
    trtp.register(plugin_name)(desc_fn)
    for decorator, callback in ((trtp.autotune, autotune_fn), (trtp.aot_impl, aot_fn)):
        if callback is not None:
            decorator(plugin_name)(callback)


# (diff residue) Hunk in _generate_plugin_converter.py: the converter now
# returns all plugin outputs —
#   if layer.num_outputs == 1: return layer.get_output(0)
#   return tuple(layer.get_output(i) for i in range(layer.num_outputs))
# (diff residue) New file tests/py/annotation/BUILD: a py_test target whose
# srcs list ``test_stage1_basic.py`` — NOTE(review): that file is NOT added by
# this patch (see the diffstat), so the Bazel target looks stale.
# (diff residue) New file tests/py/annotation/__init__.py and the opening of
# tests/py/annotation/conftest.py (its module docstring continues below).
On Blackwell/Myelin, TRT's multi-tactic plugin +handling produces incorrect results (same root cause as the SymIntExprs +length-mismatch bug filed against TRT 10.14). + +When the full annotation suite is run without ``CUDA_VISIBLE_DEVICES``: + + 1. Main process → Blackwell GPU (if available). ``requires_pre_bw`` + tests are *skipped* in this pass. + 2. ``pytest_sessionfinish`` reruns all ``requires_pre_bw`` test files + in a fresh subprocess with ``CUDA_VISIBLE_DEVICES`` pointing at + the first pre-Blackwell GPU found via nvidia-smi. + +CI / direct invocation +----------------------- +To run only pre-Blackwell tests:: + + CUDA_VISIBLE_DEVICES= pytest -m requires_pre_bw tests/py/annotation/ + +To run only Blackwell tests:: + + CUDA_VISIBLE_DEVICES= pytest -m "not requires_pre_bw" tests/py/annotation/ +""" + +import os +import subprocess +import sys + +import pytest + +_repo_root = os.path.abspath( + os.path.join(os.path.dirname(__file__), os.pardir, os.pardir, os.pardir) +) +_py_dir = os.path.join(_repo_root, "py") +if os.path.isdir(_py_dir) and _py_dir not in sys.path: + sys.path.insert(0, _py_dir) + +def _ensure_cublas_on_ld_path(): + for p in sys.path: + if "site-packages" not in str(p): + continue + for root, _, files in os.walk(p): + for f in files: + if f.startswith("libcublas") and ".so" in f: + lp = os.environ.get("LD_LIBRARY_PATH", "") + if root not in lp: + os.environ["LD_LIBRARY_PATH"] = root + (":" + lp if lp else "") + return + for base in ( + "/root/.pyenv/versions/*/lib/python*/site-packages/nvidia/cu*/lib", + "/usr/local/lib/python*/site-packages/nvidia/cu*/lib", + ): + import glob + for path in glob.glob(base): + if os.path.isdir(path) and ( + os.path.isfile(os.path.join(path, "libcublas.so.13")) + or os.path.isfile(os.path.join(path, "libcublas.so.12")) + or os.path.isfile(os.path.join(path, "libcublas.so")) + ): + lp = os.environ.get("LD_LIBRARY_PATH", "") + if path not in lp: + os.environ["LD_LIBRARY_PATH"] = path + (":" + lp if lp else "") + 
return + + +def _gpu_map_nvsmi(): + """Return (blackwell_uuids, pre_blackwell_uuids) via nvidia-smi.""" + bw_uuids, pre_uuids = [], [] + try: + out = subprocess.check_output( + [ + "nvidia-smi", + "--query-gpu=compute_cap,uuid", + "--format=csv,noheader,nounits", + ], + text=True, + stderr=subprocess.DEVNULL, + ) + for line in out.strip().splitlines(): + parts = line.split(",") + if len(parts) < 2: + continue + cc_major = int(parts[0].strip().split(".")[0]) + uuid = parts[1].strip() + if cc_major >= 10: + bw_uuids.append(uuid) + else: + pre_uuids.append(uuid) + except Exception: + pass + return bw_uuids, pre_uuids + + +_IS_PRE_BW_SUBPROCESS = os.environ.get("_TTA_PRE_BW_SUBPROCESS") == "1" + +_BW_GPUS, _PRE_BW_GPUS = _gpu_map_nvsmi() + +_ensure_cublas_on_ld_path() + +collect_ignore: list = [] +if _IS_PRE_BW_SUBPROCESS: + # Pin to a pre-Blackwell GPU if the caller did not already set one. + if "CUDA_VISIBLE_DEVICES" not in os.environ and _PRE_BW_GPUS: + os.environ["CUDA_VISIBLE_DEVICES"] = _PRE_BW_GPUS[0] +elif _BW_GPUS: + # Main run on Blackwell: pin to Blackwell GPUs. + os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(_BW_GPUS) + + +# --------------------------------------------------------------------------- +# Mark-based GPU routing +# --------------------------------------------------------------------------- + +def pytest_runtest_setup(item): + """Skip requires_pre_bw tests when running on Blackwell (main pass). + + They will be re-executed on a pre-Blackwell GPU in pytest_sessionfinish. + """ + if _IS_PRE_BW_SUBPROCESS: + return + if not _BW_GPUS: + return # no Blackwell GPU — run everything here + if item.get_closest_marker("requires_pre_bw"): + pytest.skip( + "requires_pre_bw: skipped on Blackwell; " + "runs automatically in a pre-Blackwell subprocess after the main pass. 
" + "To run directly: CUDA_VISIBLE_DEVICES= pytest -m requires_pre_bw" + ) + + +@pytest.fixture(autouse=True) +def clear_aot_caches(): + """No-op fixture retained for API compatibility.""" + yield + + +# --------------------------------------------------------------------------- +# Per-worker timing cache — speeds up repeated TRT builds within one session +# --------------------------------------------------------------------------- + +# Persistent per-worker timing cache directory. Using ~/.cache (not /tmp) +# so the cache survives reboots: the first run pays the full tactic-selection +# cost; subsequent runs reuse previously measured tactics and are significantly +# faster (especially for INT8/FP8 builds). +_TC_DIR: str = os.path.join(os.path.expanduser("~"), ".cache", "tta_test_timing_cache") +os.makedirs(_TC_DIR, exist_ok=True) + + +@pytest.fixture(autouse=True, scope="session") +def _session_timing_cache(): + """Inject a per-worker TRT timing cache into every ``torch_tensorrt.compile`` call. + + With ``editable_timing_cache=True``, TRT writes each newly selected tactic + to the cache file so subsequent builds of the same op+shape+dtype skip + tactic selection entirely. + + Each worker process uses its own cache file (keyed by ``PYTEST_XDIST_WORKER`` + id) to avoid concurrent-write corruption under ``-n 12``. The cache dir is + **not** deleted at session end so the entries accumulate across runs. 
+ """ + import torch_tensorrt.dynamo + + worker_id = os.environ.get("PYTEST_XDIST_WORKER", "main") + cache_path = os.path.join(_TC_DIR, f"tc_{worker_id}.bin") + + _orig_compile = torch_tensorrt.dynamo.compile + + def _patched_compile(*args, **kwargs): + kwargs.setdefault("timing_cache_path", cache_path) + kwargs.setdefault("editable_timing_cache", True) + return _orig_compile(*args, **kwargs) + + torch_tensorrt.dynamo.compile = _patched_compile + yield + torch_tensorrt.dynamo.compile = _orig_compile + + +def pytest_sessionfinish(session, exitstatus): + if _IS_PRE_BW_SUBPROCESS or not _PRE_BW_GPUS: + return + + # --- requires_pre_bw tests --- + # Collect the unique test-file paths for all requires_pre_bw items. + pre_bw_files: list = [] + seen: set = set() + for item in session.items: + if item.get_closest_marker("requires_pre_bw"): + path = str(item.fspath) + if path not in seen: + seen.add(path) + pre_bw_files.append(path) + + if pre_bw_files: + env = os.environ.copy() + env["CUDA_VISIBLE_DEVICES"] = _PRE_BW_GPUS[0] + env["_TTA_PRE_BW_SUBPROCESS"] = "1" + n_workers = str(session.config.getoption("numprocesses", default=4)) + cmd = ( + [sys.executable, "-m", "pytest", "-v", "--tb=short", "-n", n_workers, + "-m", "requires_pre_bw"] + + pre_bw_files + ) + result = subprocess.run(cmd, env=env, cwd=str(session.config.rootpath)) + if result.returncode != 0 and exitstatus == 0: + session.exitstatus = result.returncode diff --git a/tests/py/annotation/integration/__init__.py b/tests/py/annotation/integration/__init__.py new file mode 100644 index 0000000000..20a7c55161 --- /dev/null +++ b/tests/py/annotation/integration/__init__.py @@ -0,0 +1 @@ +# TTA integration tests diff --git a/tests/py/annotation/integration/test_custom_plugin_trt_plugins_e2e.py b/tests/py/annotation/integration/test_custom_plugin_trt_plugins_e2e.py new file mode 100644 index 0000000000..e6f2df9bb7 --- /dev/null +++ b/tests/py/annotation/integration/test_custom_plugin_trt_plugins_e2e.py @@ -0,0 
+1,1797 @@ +"""End-to-end tests for trt_plugins.custom_op(impl=tta.custom_plugin(...)). + +Each test compiles the model with torch_tensorrt and verifies: + 1. The custom op lowers into exactly one TRT engine. + 2. That engine contains a PluginV3 layer for the custom op. + 3. The engine output matches a pure-PyTorch reference. + +Backends and operations +----------------------- + Backend | Unary | Binary + ----------|------------------|---------------------------------- + Triton | silu (x·σ(x)) | swiglu (silu(gate)·up) + CuTile | relu (max(x,0)) | reglu (relu(gate)·up) + CuTeDSL | silu (x·σ(x)) | hadamard (x·y) + +Triton and CuTile also register multi-config variants (BLOCK_SIZE ∈ +{64, 128, 256}) so TRT's tactic-selection autotuner is exercised. + +Test semantics (shared across all backends via _BackendE2ETests mixin) +---------------------------------------------------------------------- + test_unary_activation : standalone activation on [seq, hidden] + test_binary_gating : standalone gating on [seq, hidden] pairs + test_llm_hidden_unary : activation at LLM hidden dim [batch=8, hidden=512] + test_llm_hidden_binary : gating at LLM hidden dim [batch=8, hidden=512] + test_dynamic_batch : dynamic batch dimension, hidden=256 + test_gated_ffn_llm : LLM-ratio FFN (hidden=256, inter=512, 2× expansion) + test_gated_ffn_block : shared-input FFN (xfail — TRT mergeMatmulLayers bug) + test_gated_ffn_block_contiguous : separate-input FFN workaround + test_chained_silu_gate : two chained PluginV3 ops in one engine + +Multi-config semantics (Triton + CuTile only, via _MultiConfigTests mixin) +-------------------------------------------------------------------------- + test_multi_config_unary : unary op with 3 tile-size configs, TRT picks best + test_multi_config_binary : binary op with 3 tile-size configs, TRT picks best + +Cross-backend semantics (TestCrossBackendE2E) +--------------------------------------------- + Tests that mix plugins from different backends in a single TRT 
engine,
verifying that multiple QDP PluginV3 layers coexist and compile correctly.
All shapes are LLM-domain: [batch=8, hidden=512].
"""

import math
import unittest

import tensorrt as trt
import torch
import torch.nn as nn
import torch_tensorrt
import torch_tensorrt.annotation as tta
import torch_tensorrt.dynamo.conversion.plugins as trt_plugins
import triton
import triton.language as tl
import cuda.tile as ct
import cutlass.cute as cute
from torch_tensorrt.dynamo.runtime._PythonTorchTensorRTModule import PythonTorchTensorRTModule
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import TorchTensorRTModule


# ---------------------------------------------------------------------------
# Test infrastructure
# ---------------------------------------------------------------------------

def _get_trt_engines(model):
    """Return all TRT engine submodules in a compiled model."""
    return [
        m for m in model.modules()
        if isinstance(m, (PythonTorchTensorRTModule, TorchTensorRTModule))
    ]


def _engine_has_pluginv3_layer(engine, op_name):
    """Return True iff the engine contains a PluginV3 layer for *op_name*.

    *op_name* is 'namespace::name' (e.g. 'torchtrt_e2e_triton::swiglu').

    Two conditions:
      1. The op's creator in the TRT registry is _TemplatePluginCreator (QDP /
         PluginV3), not the legacy IPluginCreator (V2).
      2. A layer matching the ``/<name>`` prefix appears in the engine.
         Two naming conventions are supported:
           - Direct: the layer name starts with '/' (e.g. when set via
             lower_to_trt which uses ``plugin_layer.name = name``).
           - Wrapped: the layer name contains '-[/' (e.g. when set via
             generate_plugin_converter which uses
             ``layer.name = f"[{target}]-[{name}]"``).
         TRT's Myelin optimizer may also append suffixes such as _myl0_0, so
         substring matching is used rather than exact string equality.
    """
    namespace, name = op_name.split("::", 1)

    # Condition 1: a legacy V2 creator means the op was NOT registered via QDP.
    reg = trt.get_plugin_registry()
    creator = reg.get_creator(name, "1", namespace)
    if creator is None or type(creator).__name__ == "IPluginCreator":
        return False

    # Condition 2: scan all layer names reported by the engine inspector.
    insp = engine.create_engine_inspector()
    layer_names = [
        insp.get_layer_information(i, trt.LayerInformationFormat.ONELINE).strip()
        for i in range(engine.num_layers)
    ]
    prefix = f"/{name}"
    return any(
        ln == prefix or ln.startswith(prefix + "_") or f"-[{prefix}" in ln
        for ln in layer_names
    )


def _engine_all_layer_names(engine) -> list:
    """Return all layer name strings from the TRT engine inspector."""
    insp = engine.create_engine_inspector()
    return [
        insp.get_layer_information(i, trt.LayerInformationFormat.ONELINE).strip()
        for i in range(engine.num_layers)
    ]


def _trt_compile(model, inputs):
    """Compile *model* to TRT for *inputs* with a fresh TorchDynamo state.

    torch._dynamo.reset() clears cached compilations so earlier tests cannot
    leak graphs into this one. min_block_size=1 is passed so that even a
    single custom-op node is lowered into its own TRT engine segment
    (presumably — confirm against torch_tensorrt partitioning docs).
    """
    torch._dynamo.reset()
    return torch_tensorrt.compile(model, inputs=inputs, min_block_size=1)


def _check(test, model, inputs, ref_fn, op_name, rtol=1e-3, atol=1e-3):
    """Compile model to TRT and assert output matches *ref_fn(*inputs)*.

    Also asserts:
      - exactly one TRT engine segment is produced
      - that engine contains a PluginV3 layer for *op_name*

    *rtol* / *atol* are forwarded to assert_close. Tests involving chained
    linear layers may pass looser tolerances due to FP32 accumulation across
    multiple matrix multiplications.
    """
    # Compute the reference BEFORE compiling: compilation must not be able
    # to perturb the inputs the reference is computed from.
    ref = ref_fn(*inputs)
    compiled = _trt_compile(model, inputs)
    engines = _get_trt_engines(compiled)
    test.assertEqual(len(engines), 1,
                     "Expected the custom op to lower into exactly one TRT engine")
    test.assertTrue(
        _engine_has_pluginv3_layer(engines[0].engine, op_name),
        f"Expected a PluginV3 layer for '{op_name}' in the TRT engine",
    )
    trt_out = compiled(*inputs)
    torch.testing.assert_close(trt_out, ref, rtol=rtol, atol=atol)


def _check_multi_op(test, model, inputs, ref_fn, op_names, rtol=1e-3, atol=1e-3):
    """Like _check but verifies multiple PluginV3 layers coexist in one engine."""
    ref = ref_fn(*inputs)
    compiled = _trt_compile(model, inputs)
    engines = _get_trt_engines(compiled)
    test.assertEqual(len(engines), 1,
                     "Expected all custom ops to lower into exactly one TRT engine")
    for op_name in op_names:
        test.assertTrue(
            _engine_has_pluginv3_layer(engines[0].engine, op_name),
            f"Expected a PluginV3 layer for '{op_name}' in the TRT engine",
        )
    trt_out = compiled(*inputs)
    torch.testing.assert_close(trt_out, ref, rtol=rtol, atol=atol)


# ---------------------------------------------------------------------------
# Reusable nn.Module wrappers
# ---------------------------------------------------------------------------

class _UnaryOp(nn.Module):
    """Minimal module wrapping a single one-argument op for tracing/compiling."""

    def __init__(self, op):
        super().__init__()
        self._op = op

    def forward(self, x):
        return self._op(x)


class _BinaryOp(nn.Module):
    """Minimal module wrapping a single two-argument op for tracing/compiling."""

    def __init__(self, op):
        super().__init__()
        self._op = op

    def forward(self, x, y):
        return self._op(x, y)


# ---------------------------------------------------------------------------
# Pure-PyTorch reference implementations
# ---------------------------------------------------------------------------

def _ref_silu(x):
    """SiLU / Swish: x · σ(x). Used in LLaMA, Mistral, Phi, etc."""
    return x * torch.sigmoid(x)


def _ref_swiglu(gate, up):
    """SwiGLU gate: silu(gate) · up.
    Used in LLaMA / Mistral FFN blocks."""
    return gate * torch.sigmoid(gate) * up


def _ref_relu(x):
    """ReLU: max(x, 0)."""
    return torch.relu(x)


def _ref_reglu(gate, up):
    """ReGLU gate: relu(gate) · up. A simpler gated activation."""
    return torch.relu(gate) * up


def _ref_hadamard(x, y):
    """Element-wise product. Used in attention masking, LoRA updates, etc."""
    return x * y


# ---------------------------------------------------------------------------
# Triton kernels
#
# Implements SiLU (x·σ(x)) and SwiGLU (silu(gate)·up).
# These are the activation and gating ops used verbatim in the LLaMA-series
# feed-forward blocks. tl.sigmoid is a first-class Triton intrinsic.
# ---------------------------------------------------------------------------

@triton.jit
def _triton_silu_kernel(x_ptr, out_ptr, n_cols,
                        x_stride0, x_stride1, out_stride0, out_stride1,
                        BLOCK_SIZE: tl.constexpr):
    """Stride-aware SiLU: 2D grid, fully general per-dim strides.

    grid = (n_rows, cdiv(n_cols, BLOCK_SIZE)).
    program_id(0) = row; program_id(1) = column-block.
    Pointer offset: row * stride0 + col * stride1.
    Both stride dimensions are passed so the kernel handles any tensor layout
    (contiguous, row-padded LINEAR, or any stride(dim) value).
    """
    row = tl.program_id(0)
    col_pid = tl.program_id(1)
    col_offsets = col_pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    # Mask the tail block so partial columns never read/write out of bounds.
    mask = col_offsets < n_cols
    x = tl.load(x_ptr + row * x_stride0 + col_offsets * x_stride1, mask=mask)
    tl.store(out_ptr + row * out_stride0 + col_offsets * out_stride1,
             x * tl.sigmoid(x), mask=mask)


def _triton_launch_silu(x, out, BLOCK_SIZE=128):
    """Launch the SiLU kernel with one program per (row, column-block)."""
    n_rows, n_cols = x.shape[0], x.shape[1]
    _triton_silu_kernel[(n_rows, triton.cdiv(n_cols, BLOCK_SIZE))](
        x, out, n_cols,
        x.stride(0), x.stride(1), out.stride(0), out.stride(1),
        BLOCK_SIZE=BLOCK_SIZE,
    )


@triton.jit
def _triton_swiglu_kernel(gate_ptr, up_ptr, out_ptr, n_cols,
                          gate_stride0, gate_stride1,
                          up_stride0, up_stride1,
                          out_stride0, out_stride1,
                          BLOCK_SIZE: tl.constexpr):
    """Stride-aware SwiGLU: 2D grid, fully general per-dim strides.

    grid = (n_rows, cdiv(n_cols, BLOCK_SIZE)).
    program_id(0) = row; program_id(1) = column-block.
    Pointer offset: row * stride0 + col * stride1 for each tensor.
    Both stride dimensions are passed so the kernel handles any tensor layout
    (contiguous, row-padded LINEAR, or any stride(dim) value).
    """
    row = tl.program_id(0)
    col_pid = tl.program_id(1)
    col_offsets = col_pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    gate = tl.load(gate_ptr + row * gate_stride0 + col_offsets * gate_stride1, mask=mask)
    up = tl.load(up_ptr + row * up_stride0 + col_offsets * up_stride1, mask=mask)
    tl.store(out_ptr + row * out_stride0 + col_offsets * out_stride1,
             gate * tl.sigmoid(gate) * up, mask=mask)


def _triton_launch_swiglu(gate, up, out, BLOCK_SIZE=128):
    """Launch the SwiGLU kernel with one program per (row, column-block)."""
    n_rows, n_cols = gate.shape[0], gate.shape[1]
    _triton_swiglu_kernel[(n_rows, triton.cdiv(n_cols, BLOCK_SIZE))](
        gate, up, out, n_cols,
        gate.stride(0), gate.stride(1),
        up.stride(0), up.stride(1),
        out.stride(0), out.stride(1),
        BLOCK_SIZE=BLOCK_SIZE,
    )


trt_plugins.custom_op(
    "torchtrt_e2e_triton::silu",
    impl=tta.custom_plugin(
        tta.triton(_triton_launch_silu, configs=[{"BLOCK_SIZE": 128}]),
        meta_impl=lambda x: x.new_empty(x.shape),
    ),
    supports_dynamic_shapes=True,
)

trt_plugins.custom_op(
    "torchtrt_e2e_triton::swiglu",
    impl=tta.custom_plugin(
        tta.triton(_triton_launch_swiglu, configs=[{"BLOCK_SIZE": 128}]),
        # meta_impl: output shape/dtype mirrors the gate input.
        meta_impl=lambda gate, up: gate.new_empty(gate.shape),
    ),
    supports_dynamic_shapes=True,
)

# Multi-config: TRT benchmarks BLOCK_SIZE ∈ {64, 128, 256} and picks the
# fastest tactic for the target GPU automatically at engine-build time.
trt_plugins.custom_op(
    "torchtrt_e2e_triton::silu_mc",
    impl=tta.custom_plugin(
        tta.triton(_triton_launch_silu,
                   configs=[{"BLOCK_SIZE": 64}, {"BLOCK_SIZE": 128}, {"BLOCK_SIZE": 256}]),
        meta_impl=lambda x: x.new_empty(x.shape),
    ),
    supports_dynamic_shapes=True,
)

trt_plugins.custom_op(
    "torchtrt_e2e_triton::swiglu_mc",
    impl=tta.custom_plugin(
        tta.triton(_triton_launch_swiglu,
                   configs=[{"BLOCK_SIZE": 64}, {"BLOCK_SIZE": 128}, {"BLOCK_SIZE": 256}]),
        meta_impl=lambda gate, up: gate.new_empty(gate.shape),
    ),
    supports_dynamic_shapes=True,
)


# ---------------------------------------------------------------------------
# CuTile kernels
#
# Implements ReLU (max(x, 0)) and ReGLU (relu(gate)·up).
# ReGLU is a simpler gated variant of the SwiGLU pattern; it appears in
# ablation studies of transformer architectures and in quantized deployments
# where ReLU is preferred for sparsity.
# ---------------------------------------------------------------------------

@ct.kernel
def _ct_relu_kernel(x, out, tile_size: ct.Constant[int]):
    # One 1D tile of tile_size elements per block id.
    pid = ct.bid(0)
    x_tile = ct.load(x, index=(pid,), shape=(tile_size,))
    ct.store(out, index=(pid,), tile=ct.maximum(x_tile, 0.0))


def _ct_launch_relu(x, out, BLOCK=128):
    """Launch the 1D-tiled CuTile ReLU over the flattened input."""
    n = x.numel()
    stream = torch.cuda.current_stream().cuda_stream
    # CuTile kernels use a 1D tile index. In the AOT sandbox path tensors are
    # SymbolicTensor; skip reshape/contiguous there — the sandbox only needs
    # the grid tuple and never executes the kernel body.
    if isinstance(x, torch.Tensor):
        x = x.contiguous().reshape(-1)
        out = out.reshape(-1)
    ct.launch(stream, (ct.cdiv(n, BLOCK), 1, 1), _ct_relu_kernel, (x, out, BLOCK))


@ct.kernel
def _ct_reglu_kernel(gate, up, out, tile_size: ct.Constant[int]):
    pid = ct.bid(0)
    g_tile = ct.load(gate, index=(pid,), shape=(tile_size,))
    u_tile = ct.load(up, index=(pid,), shape=(tile_size,))
    ct.store(out, index=(pid,), tile=ct.maximum(g_tile, 0.0) * u_tile)


def _ct_launch_reglu(gate, up, out, BLOCK=128):
    """Launch the 1D-tiled CuTile ReGLU over the flattened input pair."""
    n = gate.numel()
    stream = torch.cuda.current_stream().cuda_stream
    # Same SymbolicTensor caveat as _ct_launch_relu above.
    if isinstance(gate, torch.Tensor):
        gate = gate.contiguous().reshape(-1)
        up = up.contiguous().reshape(-1)
        out = out.reshape(-1)
    ct.launch(stream, (ct.cdiv(n, BLOCK), 1, 1), _ct_reglu_kernel, (gate, up, out, BLOCK))


trt_plugins.custom_op(
    "torchtrt_e2e_cutile::relu",
    impl=tta.custom_plugin(
        tta.cutile(_ct_launch_relu, configs=[{"BLOCK": 128}]),
        meta_impl=lambda x: x.new_empty(x.shape),
    ),
    supports_dynamic_shapes=True,
)

trt_plugins.custom_op(
    "torchtrt_e2e_cutile::reglu",
    impl=tta.custom_plugin(
        tta.cutile(_ct_launch_reglu, configs=[{"BLOCK": 128}]),
        meta_impl=lambda gate, up: gate.new_empty(gate.shape),
    ),
    supports_dynamic_shapes=True,
)

trt_plugins.custom_op(
    "torchtrt_e2e_cutile::relu_mc",
    impl=tta.custom_plugin(
        tta.cutile(_ct_launch_relu,
                   configs=[{"BLOCK": 64}, {"BLOCK": 128}, {"BLOCK": 256}]),
        meta_impl=lambda x: x.new_empty(x.shape),
    ),
    supports_dynamic_shapes=True,
)

trt_plugins.custom_op(
    "torchtrt_e2e_cutile::reglu_mc",
    impl=tta.custom_plugin(
        tta.cutile(_ct_launch_reglu,
                   configs=[{"BLOCK": 64}, {"BLOCK": 128}, {"BLOCK": 256}]),
        meta_impl=lambda gate, up: gate.new_empty(gate.shape),
    ),
    supports_dynamic_shapes=True,
)


# ---------------------------------------------------------------------------
# CuTeDSL kernels
#
# Implements SiLU (x·σ(x)) and Hadamard product (x·y).
# Sigmoid is computed as 1 / (1 + exp(-x)) via cute.arch.exp, which
# lowers to the CUDA expf intrinsic in device code.
# ---------------------------------------------------------------------------

@cute.kernel
def _cute_silu_kernel(x: cute.Tensor, out: cute.Tensor):
    idx = cute.arch.block_idx()[0]
    xi = x[idx]
    one = x.element_type(1.0)
    sigmoid_xi = one / (one + cute.arch.exp(x.element_type(-1.0) * xi))
    out[idx] = xi * sigmoid_xi


@cute.jit
def _cute_launch_silu(x: cute.Tensor, out: cute.Tensor):
    # One single-thread block per element: simple, exercises the AOT path.
    _cute_silu_kernel(x, out).launch(grid=(math.prod(x.shape), 1, 1), block=(1, 1, 1))


@cute.kernel
def _cute_hadamard_kernel(x: cute.Tensor, y: cute.Tensor, out: cute.Tensor):
    idx = cute.arch.block_idx()[0]
    out[idx] = x[idx] * y[idx]


@cute.jit
def _cute_launch_hadamard(x: cute.Tensor, y: cute.Tensor, out: cute.Tensor):
    _cute_hadamard_kernel(x, y, out).launch(grid=(math.prod(x.shape), 1, 1), block=(1, 1, 1))


trt_plugins.custom_op(
    "torchtrt_e2e_cutedsl::silu",
    impl=tta.custom_plugin(
        tta.cutedsl(_cute_launch_silu),
        meta_impl=lambda x: x.new_empty(x.shape),
    ),
    supports_dynamic_shapes=True,
)

trt_plugins.custom_op(
    "torchtrt_e2e_cutedsl::hadamard",
    impl=tta.custom_plugin(
        tta.cutedsl(_cute_launch_hadamard),
        meta_impl=lambda x, y: x.new_empty(x.shape),
    ),
    supports_dynamic_shapes=True,
)


# ---------------------------------------------------------------------------
# Plugin attrs: Triton SiLU with a compile-time SCALE factor passed via attrs
#
# SCALE is declared as tl.constexpr so Triton bakes it into the PTX.
# Passing SCALE=2.0 directly as a kwarg to tta.custom_plugin() stores it in
# CustomPluginSpec.attrs; the aot_impl path merges attrs into the sandbox
# kwargs so the kernel is compiled with the correct constant.
# ---------------------------------------------------------------------------

@triton.jit
def _triton_scaled_silu_kernel(x_ptr, out_ptr, n_cols,
                               x_stride0, x_stride1, out_stride0, out_stride1,
                               SCALE: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    """SiLU scaled by a compile-time constant: out = SCALE * x * σ(x)."""
    row = tl.program_id(0)
    col_pid = tl.program_id(1)
    col_offsets = col_pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    x = tl.load(x_ptr + row * x_stride0 + col_offsets * x_stride1, mask=mask)
    tl.store(out_ptr + row * out_stride0 + col_offsets * out_stride1,
             x * tl.sigmoid(x) * SCALE, mask=mask)


def _triton_launch_scaled_silu(x, out, BLOCK_SIZE=128, SCALE=1.0):
    """Launch the scaled-SiLU kernel.

    SCALE=1.0 here is only a signature default; the registered plugin below
    overrides it with the value stored in CustomPluginSpec.attrs (SCALE=2.0).
    """
    n_rows, n_cols = x.shape[0], x.shape[1]
    _triton_scaled_silu_kernel[(n_rows, triton.cdiv(n_cols, BLOCK_SIZE))](
        x, out, n_cols,
        x.stride(0), x.stride(1), out.stride(0), out.stride(1),
        SCALE=SCALE, BLOCK_SIZE=BLOCK_SIZE,
    )


trt_plugins.custom_op(
    "torchtrt_e2e_triton::silu_scaled",
    impl=tta.custom_plugin(
        tta.triton(_triton_launch_scaled_silu, configs=[{"BLOCK_SIZE": 128}]),
        meta_impl=lambda x: x.new_empty(x.shape),
        SCALE=2.0,  # direct kwarg: baked into PTX as tl.constexpr
    ),
    supports_dynamic_shapes=True,
)


# ---------------------------------------------------------------------------
# Multiple outputs: Triton kernel that returns two tensors from one op.
#
# The op computes both SiLU and ReLU activations in a single fused kernel,
# returning them as a (silu_out, relu_out) pair. This exercises the
# multi-output QDP path:
#   meta_impl returning tuple → _infer_num_outputs=2
#   → @trtp.register with Tuple[TensorDesc, TensorDesc] return type
#   → TRT PluginV3 layer with num_outputs=2
# ---------------------------------------------------------------------------

@triton.jit
def _triton_dual_kernel(x_ptr, silu_ptr, relu_ptr, n_cols,
                        x_s0, x_s1, silu_s0, silu_s1, relu_s0, relu_s1,
                        BLOCK_SIZE: tl.constexpr):
    """Fused SiLU + ReLU: silu_out = x·σ(x), relu_out = max(x, 0)."""
    row = tl.program_id(0)
    col_pid = tl.program_id(1)
    col_offsets = col_pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    # x is loaded once and feeds both output stores.
    x = tl.load(x_ptr + row * x_s0 + col_offsets * x_s1, mask=mask)
    tl.store(silu_ptr + row * silu_s0 + col_offsets * silu_s1,
             x * tl.sigmoid(x), mask=mask)
    tl.store(relu_ptr + row * relu_s0 + col_offsets * relu_s1,
             tl.maximum(x, 0.0), mask=mask)


def _triton_launch_dual(x, silu_out, relu_out, BLOCK_SIZE=128):
    """Launch the fused dual-output kernel (two output buffers, one pass)."""
    n_rows, n_cols = x.shape[0], x.shape[1]
    _triton_dual_kernel[(n_rows, triton.cdiv(n_cols, BLOCK_SIZE))](
        x, silu_out, relu_out, n_cols,
        x.stride(0), x.stride(1),
        silu_out.stride(0), silu_out.stride(1),
        relu_out.stride(0), relu_out.stride(1),
        BLOCK_SIZE=BLOCK_SIZE,
    )


trt_plugins.custom_op(
    "torchtrt_e2e_triton::dual_activation",
    impl=tta.custom_plugin(
        tta.triton(_triton_launch_dual, configs=[{"BLOCK_SIZE": 128}]),
        # Tuple-returning meta_impl is what marks this op as two-output.
        meta_impl=lambda x: (x.new_empty(x.shape), x.new_empty(x.shape)),
    ),
    supports_dynamic_shapes=True,
)


# ---------------------------------------------------------------------------
# supports_dynamic_shapes=False: static-only engine
#
# This op is identical to silu but registered with supports_dynamic_shapes=False.
# ---------------------------------------------------------------------------

trt_plugins.custom_op(
    "torchtrt_e2e_triton::silu_static",
    impl=tta.custom_plugin(
        tta.triton(_triton_launch_silu, configs=[{"BLOCK_SIZE": 128}]),
        meta_impl=lambda x: x.new_empty(x.shape),
    ),
    supports_dynamic_shapes=False,
)


# ---------------------------------------------------------------------------
# Tensor format: explicit input_formats / output_formats on a Triton spec
# ---------------------------------------------------------------------------

trt_plugins.custom_op(
    "torchtrt_e2e_triton::silu_linear_fmt",
    impl=tta.custom_plugin(
        tta.triton(
            _triton_launch_silu,
            configs=[{"BLOCK_SIZE": 128}],
            input_formats=[trt.TensorFormat.LINEAR],
            output_formats=[trt.TensorFormat.LINEAR],
        ),
        meta_impl=lambda x: x.new_empty(x.shape),
    ),
    supports_dynamic_shapes=True,
)


# ---------------------------------------------------------------------------
# Non-LINEAR format: CHW32 Triton plugin
#
# A flat-indexed (layout-agnostic) elementwise ReLU registered with
# input_formats=[trt.TensorFormat.CHW32]. TRT will negotiate CHW32 format
# and insert reformatting nodes (LINEAR→CHW32 before, CHW32→LINEAR after)
# around the plugin layer.
#
# The kernel uses flat pointer + offset arithmetic so it is correct for any
# contiguous memory layout, including CHW32. A [4, 32, 8, 8] input uses
# 32 channels to satisfy CHW32's channel-alignment requirement.
# ---------------------------------------------------------------------------


@triton.jit
def _triton_flat_relu_kernel(x_ptr, out_ptr, n: tl.constexpr, BLOCK: tl.constexpr):
    """Elementwise ReLU via flat pointer offset — layout-agnostic."""
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    tl.store(out_ptr + offs, tl.maximum(x, 0.0), mask=mask)


def _triton_launch_flat_relu(x, out, BLOCK=128):
    """Launch the flat ReLU kernel: 1D grid over all n elements.

    n is a tl.constexpr, so it is baked into the compiled kernel — fine here
    because this op is registered with supports_dynamic_shapes=False.
    """
    n = x.numel()
    # Use integer triton.cdiv like every other launch in this file, instead
    # of math.ceil(n / BLOCK), which round-trips through float division.
    _triton_flat_relu_kernel[(triton.cdiv(n, BLOCK),)](x, out, n=n, BLOCK=BLOCK)


trt_plugins.custom_op(
    "torchtrt_e2e_triton::relu_chw32",
    impl=tta.custom_plugin(
        tta.triton(
            _triton_launch_flat_relu,
            configs=[{"BLOCK": 128}],
            input_formats=[trt.TensorFormat.CHW32],
            output_formats=[trt.TensorFormat.CHW32],
        ),
        meta_impl=lambda x: x.new_empty(x.shape),
    ),
    supports_dynamic_shapes=False,
)


# ---------------------------------------------------------------------------
# Float16 dtype: Triton SiLU with explicit fp32 promotion for AOT compatibility
#
# Triton's AOT compiler can fail for tl.sigmoid with *fp16 pointer types in
# some builds. This kernel promotes fp16→fp32 for sigmoid, then casts back.
# ---------------------------------------------------------------------------

@triton.jit
def _triton_silu_f16_kernel(x_ptr, out_ptr, n_cols,
                            x_stride0, x_stride1, out_stride0, out_stride1,
                            BLOCK_SIZE: tl.constexpr):
    """FP16 SiLU: load fp16, promote to fp32 for sigmoid, store fp16."""
    row = tl.program_id(0)
    col_pid = tl.program_id(1)
    col_offsets = col_pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    x = tl.load(x_ptr + row * x_stride0 + col_offsets * x_stride1, mask=mask)
    # Explicit fp32 round-trip: sigmoid in fp32, result cast back to fp16.
    xf = x.to(tl.float32)
    result = (xf * tl.sigmoid(xf)).to(tl.float16)
    tl.store(out_ptr + row * out_stride0 + col_offsets * out_stride1, result, mask=mask)


def _triton_launch_silu_f16(x, out, BLOCK_SIZE=128):
    """Launch the fp16 SiLU kernel with one program per (row, column-block)."""
    n_rows, n_cols = x.shape[0], x.shape[1]
    _triton_silu_f16_kernel[(n_rows, triton.cdiv(n_cols, BLOCK_SIZE))](
        x, out, n_cols,
        x.stride(0), x.stride(1), out.stride(0), out.stride(1),
        BLOCK_SIZE=BLOCK_SIZE,
    )


trt_plugins.custom_op(
    "torchtrt_e2e_triton::silu_f16",
    impl=tta.custom_plugin(
        tta.triton(_triton_launch_silu_f16, configs=[{"BLOCK_SIZE": 128}]),
        meta_impl=lambda x: x.new_empty(x.shape),
    ),
    supports_dynamic_shapes=True,
)


# ---------------------------------------------------------------------------
# Weights: Triton SiLU with a constant column-scale weight tensor.
#
# W is a [hidden_dim] weight baked into the TRT engine as a trt.add_constant
# layer. The plugin receives (x, W) as inputs; out[i,j] = W[j] * silu(x[i,j]).
# This exercises the full custom_plugin(W=tensor) → weights injection path.
#
# _W_COLUMN_SCALE is a module-level constant so its values are known and the
# test reference can be computed independently without calling the torch op.
# ---------------------------------------------------------------------------

_LLM_B, _LLM_H = 8, 512  # LLM-domain shape used for cross-backend and weight-injection tests

_W_COLUMN_SCALE = torch.full((_LLM_H,), 3.0)  # CPU tensor; lowered via trt.add_constant


@triton.jit
def _triton_weighted_silu_kernel(x_ptr, w_ptr, out_ptr, n_cols,
                                 x_s0, x_s1, w_s0, out_s0, out_s1,
                                 BLOCK_SIZE: tl.constexpr):
    """Elementwise W[j] * silu(x[i,j]) for a fixed column-scale weight W."""
    row = tl.program_id(0)
    col_pid = tl.program_id(1)
    col_offsets = col_pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    x = tl.load(x_ptr + row * x_s0 + col_offsets * x_s1, mask=mask)
    # W is 1D and indexed by column only — the same row of weights for
    # every input row.
    w = tl.load(w_ptr + col_offsets * w_s0, mask=mask)
    tl.store(out_ptr + row * out_s0 + col_offsets * out_s1,
             x * tl.sigmoid(x) * w, mask=mask)


def _triton_launch_weighted_silu(x, W, out, BLOCK_SIZE=128):
    """Launch the weighted-SiLU kernel; W is the column-scale weight tensor."""
    n_rows, n_cols = x.shape[0], x.shape[1]
    _triton_weighted_silu_kernel[(n_rows, triton.cdiv(n_cols, BLOCK_SIZE))](
        x, W, out, n_cols,
        x.stride(0), x.stride(1), W.stride(0), out.stride(0), out.stride(1),
        BLOCK_SIZE=BLOCK_SIZE,
    )


trt_plugins.custom_op(
    "torchtrt_e2e_triton::weighted_silu",
    impl=tta.custom_plugin(
        tta.triton(_triton_launch_weighted_silu, configs=[{"BLOCK_SIZE": 128}]),
        meta_impl=lambda x: x.new_empty(x.shape),
        W=_W_COLUMN_SCALE,  # tensor kwarg → injected as trt.add_constant
    ),
    supports_dynamic_shapes=True,
)


# ---------------------------------------------------------------------------
# Flat (rank-agnostic) Triton SiLU: handles any-rank input via numel().
#
# The 2D silu kernel assumes rank-2 input. This flat variant uses a single
# 1D grid over all elements, making it compatible with 3D inputs [B, S, H].
# The launch function reshapes real tensors to 1D before calling the kernel;
# for SymbolicTensors (sandbox path) reshape is skipped since the kernel
# recording only needs the grid shape and pointer types.
# ---------------------------------------------------------------------------

@triton.jit
def _triton_silu_flat_kernel(x_ptr, out_ptr, n_total, BLOCK_SIZE: tl.constexpr):
    """Flat SiLU over n_total elements; grid = (cdiv(n_total, BLOCK_SIZE),)."""
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_total
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * tl.sigmoid(x), mask=mask)


def _triton_launch_silu_flat(x, out, BLOCK_SIZE=128):
    """Launch the rank-agnostic SiLU kernel over all elements of *x*."""
    n = x.numel()
    # Real tensors are flattened to 1D; SymbolicTensor (AOT sandbox) is not,
    # since only the grid tuple and pointer types are recorded there.
    if isinstance(x, torch.Tensor):
        x = x.contiguous().reshape(-1)
        out = out.reshape(-1)
    _triton_silu_flat_kernel[(triton.cdiv(n, BLOCK_SIZE),)](
        x, out, n, BLOCK_SIZE=BLOCK_SIZE,
    )


trt_plugins.custom_op(
    "torchtrt_e2e_triton::silu_flat",
    impl=tta.custom_plugin(
        tta.triton(_triton_launch_silu_flat, configs=[{"BLOCK_SIZE": 128}]),
        meta_impl=lambda x: x.new_empty(x.shape),
    ),
    supports_dynamic_shapes=True,
)


# ---------------------------------------------------------------------------
# BFloat16 dtype: Triton SiLU with explicit fp32 promotion.
#
# Mirrors the fp16 variant: loads bf16, promotes to fp32 for sigmoid, stores bf16.
# ---------------------------------------------------------------------------

@triton.jit
def _triton_silu_bf16_kernel(x_ptr, out_ptr, n_cols,
                             x_stride0, x_stride1, out_stride0, out_stride1,
                             BLOCK_SIZE: tl.constexpr):
    """BF16 SiLU: load bf16, promote to fp32 for sigmoid, store bf16."""
    row = tl.program_id(0)
    col_pid = tl.program_id(1)
    col_offsets = col_pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    x = tl.load(x_ptr + row * x_stride0 + col_offsets * x_stride1, mask=mask)
    xf = x.to(tl.float32)
    result = (xf * tl.sigmoid(xf)).to(tl.bfloat16)
    tl.store(out_ptr + row * out_stride0 + col_offsets * out_stride1, result, mask=mask)


def _triton_launch_silu_bf16(x, out, BLOCK_SIZE=128):
    """Launch the bf16 SiLU kernel with one program per (row, column-block)."""
    n_rows, n_cols = x.shape[0], x.shape[1]
    _triton_silu_bf16_kernel[(n_rows, triton.cdiv(n_cols, BLOCK_SIZE))](
        x, out, n_cols,
        x.stride(0), x.stride(1), out.stride(0), out.stride(1),
        BLOCK_SIZE=BLOCK_SIZE,
    )


trt_plugins.custom_op(
    "torchtrt_e2e_triton::silu_bf16",
    impl=tta.custom_plugin(
        tta.triton(_triton_launch_silu_bf16, configs=[{"BLOCK_SIZE": 128}]),
        meta_impl=lambda x: x.new_empty(x.shape),
    ),
    supports_dynamic_shapes=True,
)


# ---------------------------------------------------------------------------
# Multiple attrs: Triton SiLU with two compile-time constants SCALE and BIAS.
#
# Both are declared tl.constexpr and baked into the PTX at AOT time.
# This exercises the full multi-attr path:
#   custom_plugin(SCALE=3.0, BIAS=1.0) → attrs={"SCALE":3.0,"BIAS":1.0}
#   → sandbox merged kwargs → constexpr values in PTX
# ---------------------------------------------------------------------------

@triton.jit
def _triton_scale_bias_kernel(x_ptr, out_ptr, n_cols,
                              x_stride0, x_stride1, out_stride0, out_stride1,
                              SCALE: tl.constexpr, BIAS: tl.constexpr,
                              BLOCK_SIZE: tl.constexpr):
    """SiLU scaled and shifted: out = SCALE * x * σ(x) + BIAS."""
    row = tl.program_id(0)
    col_pid = tl.program_id(1)
    col_offsets = col_pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = col_offsets < n_cols
    x = tl.load(x_ptr + row * x_stride0 + col_offsets * x_stride1, mask=mask)
    tl.store(out_ptr + row * out_stride0 + col_offsets * out_stride1,
             x * tl.sigmoid(x) * SCALE + BIAS, mask=mask)


def _triton_launch_scale_bias(x, out, BLOCK_SIZE=128, SCALE=1.0, BIAS=0.0):
    """Launch the scale+bias SiLU kernel.

    SCALE/BIAS signature defaults are placeholders; the registration below
    supplies the real values (SCALE=3.0, BIAS=1.0) via custom_plugin attrs.
    """
    n_rows, n_cols = x.shape[0], x.shape[1]
    _triton_scale_bias_kernel[(n_rows, triton.cdiv(n_cols, BLOCK_SIZE))](
        x, out, n_cols,
        x.stride(0), x.stride(1), out.stride(0), out.stride(1),
        SCALE=SCALE, BIAS=BIAS, BLOCK_SIZE=BLOCK_SIZE,
    )


trt_plugins.custom_op(
    "torchtrt_e2e_triton::silu_scale_bias",
    impl=tta.custom_plugin(
        tta.triton(_triton_launch_scale_bias, configs=[{"BLOCK_SIZE": 128}]),
        meta_impl=lambda x: x.new_empty(x.shape),
        SCALE=3.0, BIAS=1.0,
    ),
    supports_dynamic_shapes=True,
)


# ---------------------------------------------------------------------------
# Mixin: shared test semantics across all backends
# ---------------------------------------------------------------------------

class _BackendE2ETests:
    """Four core E2E semantics shared by every backend.

    Concrete subclasses must set:
      _NS : op namespace (e.g. "torchtrt_e2e_triton")
      _UNARY_OP : unary op name (e.g. "silu")
      _BINARY_OP : binary op name (e.g.
"swiglu") + _UNARY_REF : staticmethod — pure-PyTorch reference for the unary op + _BINARY_REF : staticmethod — pure-PyTorch reference for the binary op + """ + + _NS: str + _UNARY_OP: str + _BINARY_OP: str + _UNARY_REF: callable + _BINARY_REF: callable + + def _op(self, name): + return getattr(getattr(torch.ops, self._NS), name).default + + def test_unary_activation(self): + """Standalone activation on [seq_len=16, hidden=256] — typical LLM hidden state. + + Exercises the op in isolation: no surrounding aten ops, single custom + plugin layer in the engine. + """ + m = _UnaryOp(self._op(self._UNARY_OP)).eval().cuda() + inputs = [torch.randn(16, 256, device="cuda")] + _check(self, m, inputs, self._UNARY_REF, f"{self._NS}::{self._UNARY_OP}") + + def test_binary_gating(self): + """Standalone gating on [seq_len=16, hidden=256] tensor pairs. + + Exercises the binary custom op in isolation with matched input shapes, + reproducing the element-wise gate × value combination used in attention + and gated linear units before the down-projection. + """ + m = _BinaryOp(self._op(self._BINARY_OP)).eval().cuda() + inputs = [torch.randn(16, 256, device="cuda"), torch.randn(16, 256, device="cuda")] + _check(self, m, inputs, self._BINARY_REF, f"{self._NS}::{self._BINARY_OP}") + + # TRT mergeMatmulLayers delivers non-contiguous sub-region buffers to + # IPluginV3::enqueue without inserting a reformat copy, violating the + # LINEAR stride contract. Expected to fail until TRT fixes this. + @unittest.expectedFailure + def test_gated_ffn_block(self): + """Full gated FFN block with shared input — expected to fail due to TRT bug. 
+ + This is the exact feed-forward structure used in LLaMA, Mistral, Qwen, + and other SwiGLU / ReGLU-based transformers: + + gate = fc_gate(x) # [batch, hidden] → [batch, intermediate] + up = fc_up(x) # [batch, hidden] → [batch, intermediate] + h = gate_fn(gate, up) # custom PluginV3 — fused gate computation + out = fc_down(h) # [batch, intermediate] → [batch, hidden] + + Both fc_gate and fc_up share the same input ITensor x. TRT's + mergeMatmulLayers optimizer fuses them into a single [batch, 2*intermediate] + GEMM and presents each half as a [batch, intermediate] sub-region with + physical row-stride 2*intermediate while still tagging the format as LINEAR. + This violates the LINEAR stride contract: IPluginV3::enqueue receives + non-contiguous buffers but PluginTensorDesc has no physical stride field, + so the plugin infers stride = intermediate (logical) instead of 2*intermediate. + + See test_gated_ffn_block_contiguous for the workaround (separate inputs). + TRT bug filed: mergeMatmulLayers delivers non-contiguous sub-region buffers + to IPluginV3::enqueue without inserting a reformat copy. 
+ """ + hidden_dim, intermediate_dim = 64, 128 + gate_op = self._op(self._BINARY_OP) + _binary_ref = self._BINARY_REF + + class GatedFFN(nn.Module): + def __init__(self): + super().__init__() + self.fc_gate = nn.Linear(hidden_dim, intermediate_dim, bias=False) + self.fc_up = nn.Linear(hidden_dim, intermediate_dim, bias=False) + self.fc_down = nn.Linear(intermediate_dim, hidden_dim, bias=False) + self._gate = gate_op + + def forward(self, x): + gate = self.fc_gate(x) + up = self.fc_up(x) + return self.fc_down(self._gate(gate, up)) + + m = GatedFFN().eval().cuda() + inputs = [torch.randn(8, hidden_dim, device="cuda")] + + def _ref(x): + with torch.no_grad(): + return m.fc_down(_binary_ref(m.fc_gate(x), m.fc_up(x))) + + _check(self, m, inputs, _ref, f"{self._NS}::{self._BINARY_OP}", rtol=1e-3, atol=1e-3) + + def test_gated_ffn_block_contiguous(self): + """Full gated FFN block with separate inputs — workaround for TRT mergeMatmulLayers bug. + + Same four-op structure as test_gated_ffn_block but fc_gate and fc_up + receive separate input ITensors (x and x_up), which prevents TRT's + mergeMatmulLayers from fusing the two matmuls. Each matmul produces its + own contiguous [batch, intermediate] buffer; the plugin receives + contiguous inputs and the stride-aware kernel reads correct data. + + The model forward takes (x, x_up) where the caller passes the same + tensor for both; in the TRT graph they are distinct network inputs + so canMatmulBeHorizontallyMerged returns false. 
+ """ + hidden_dim, intermediate_dim = 64, 128 + gate_op = self._op(self._BINARY_OP) + _binary_ref = self._BINARY_REF + + class GatedFFNContiguous(nn.Module): + def __init__(self): + super().__init__() + self.fc_gate = nn.Linear(hidden_dim, intermediate_dim, bias=False) + self.fc_up = nn.Linear(hidden_dim, intermediate_dim, bias=False) + self.fc_down = nn.Linear(intermediate_dim, hidden_dim, bias=False) + self._gate = gate_op + + def forward(self, x, x_up): + gate = self.fc_gate(x) + up = self.fc_up(x_up) + return self.fc_down(self._gate(gate, up)) + + m = GatedFFNContiguous().eval().cuda() + x = torch.randn(8, hidden_dim, device="cuda") + inputs = [x, x] + + def _ref(x, x_up): + with torch.no_grad(): + return m.fc_down(_binary_ref(m.fc_gate(x), m.fc_up(x_up))) + + _check(self, m, inputs, _ref, f"{self._NS}::{self._BINARY_OP}", rtol=1e-3, atol=1e-3) + + def test_llm_hidden_unary(self): + """Standalone activation at LLM hidden dimension [batch=8, hidden=512]. + + Tests the activation at the scale of a small LLM hidden state, exercising + tactic selection at tensor sizes representative of production inference. + """ + m = _UnaryOp(self._op(self._UNARY_OP)).eval().cuda() + inputs = [torch.randn(8, 512, device="cuda")] + _check(self, m, inputs, self._UNARY_REF, f"{self._NS}::{self._UNARY_OP}") + + def test_llm_hidden_binary(self): + """Standalone gating at LLM hidden dimension [batch=8, hidden=512]. + + Tests the binary gating op at the scale of a small LLM hidden state, + covering the element-wise gate × value step in real SwiGLU / ReGLU FFNs. + """ + m = _BinaryOp(self._op(self._BINARY_OP)).eval().cuda() + inputs = [torch.randn(8, 512, device="cuda"), torch.randn(8, 512, device="cuda")] + _check(self, m, inputs, self._BINARY_REF, f"{self._NS}::{self._BINARY_OP}") + + def test_dynamic_batch(self): + """Unary op with dynamic batch dimension, hidden=256. + + Uses torch_tensorrt.Input with min/opt/max batch sizes to exercise TRT's + dynamic shape path. 
Compilation uses opt_shape=8; inference runs at batch=8. + """ + hidden = 256 + m = _UnaryOp(self._op(self._UNARY_OP)).eval().cuda() + torch._dynamo.reset() + compiled = torch_tensorrt.compile( + m, + inputs=[torch_tensorrt.Input( + min_shape=(1, hidden), opt_shape=(8, hidden), max_shape=(32, hidden), + dtype=torch.float32, + )], + min_block_size=1, + ) + engines = _get_trt_engines(compiled) + self.assertEqual(len(engines), 1, + "Expected the custom op to lower into exactly one TRT engine") + self.assertTrue( + _engine_has_pluginv3_layer(engines[0].engine, f"{self._NS}::{self._UNARY_OP}"), + f"Expected a PluginV3 layer for '{self._NS}::{self._UNARY_OP}' in the TRT engine", + ) + x = torch.randn(8, hidden, device="cuda") + trt_out = compiled(x) + torch.testing.assert_close(trt_out, self._UNARY_REF(x), rtol=1e-3, atol=1e-3) + + def test_gated_ffn_llm(self): + """Full gated FFN at LLM expansion ratio: hidden=256, inter=512 (2×). + + Same four-op structure as test_gated_ffn_block_contiguous but at LLM + scale: hidden=256, intermediate=512 reproduces the 2× expansion ratio + used in LLaMA / Mistral FFN blocks. Two distinct inputs prevent + mergeMatmulLayers fusion; both matmuls produce contiguous buffers. 
+ """ + hidden_dim, intermediate_dim = 256, 512 + gate_op = self._op(self._BINARY_OP) + _binary_ref = self._BINARY_REF + + class GatedFFNLLM(nn.Module): + def __init__(self): + super().__init__() + self.fc_gate = nn.Linear(hidden_dim, intermediate_dim, bias=False) + self.fc_up = nn.Linear(hidden_dim, intermediate_dim, bias=False) + self.fc_down = nn.Linear(intermediate_dim, hidden_dim, bias=False) + self._gate = gate_op + + def forward(self, x, x_up): + gate = self.fc_gate(x) + up = self.fc_up(x_up) + return self.fc_down(self._gate(gate, up)) + + m = GatedFFNLLM().eval().cuda() + x = torch.randn(8, hidden_dim, device="cuda") + inputs = [x, x] + + def _ref(x, x_up): + with torch.no_grad(): + return m.fc_down(_binary_ref(m.fc_gate(x), m.fc_up(x_up))) + + _check(self, m, inputs, _ref, f"{self._NS}::{self._BINARY_OP}", rtol=1e-3, atol=1e-3) + + def test_chained_silu_gate(self): + """Chained custom ops reproducing the core of SwiGLU / ReGLU gating. + + The two-op chain: + activated = unary_act(x) ← first PluginV3 + out = binary_gate(activated, y) ← second PluginV3 + + models the split where x is the pre-activated gate branch and y is + the up branch that has been computed separately (e.g. by a separate + linear projection not shown here). Both ops must land in one engine. 
+ """ + act_op = self._op(self._UNARY_OP) + gate_op = self._op(self._BINARY_OP) + + class ChainedGate(nn.Module): + def __init__(self): + super().__init__() + self._act = act_op + self._gate = gate_op + + def forward(self, x, y): + return self._gate(self._act(x), y) + + inputs = [torch.randn(16, 256, device="cuda"), torch.randn(16, 256, device="cuda")] + ref_fn = lambda x, y: self._BINARY_REF(self._UNARY_REF(x), y) + _check(self, ChainedGate().eval().cuda(), inputs, ref_fn, + f"{self._NS}::{self._BINARY_OP}") + + +# --------------------------------------------------------------------------- +# Mixin: multi-config tactic selection (Triton + CuTile only) +# --------------------------------------------------------------------------- + +class _MultiConfigTests: + """Two additional tests for backends that register multi-config variants. + + TRT's autotuner benchmarks each tile-size configuration + (e.g. BLOCK_SIZE ∈ {64, 128, 256}) and picks the fastest tactic for the + target GPU at engine-build time. Correctness against the reference is + still verified. + + Multi-config op names follow the convention '_mc'. + Shape [16, 512] gives all three tile sizes non-trivial work. 
+    """
+
+    def test_multi_config_unary(self):
+        """Unary op with 3 tile-size configs; TRT picks the best tactic."""
+        op_name = f"{self._UNARY_OP}_mc"
+        m = _UnaryOp(self._op(op_name)).eval().cuda()
+        inputs = [torch.randn(16, 512, device="cuda")]
+        _check(self, m, inputs, self._UNARY_REF, f"{self._NS}::{op_name}")
+
+    def test_multi_config_binary(self):
+        """Binary op with 3 tile-size configs; TRT picks the best tactic."""
+        op_name = f"{self._BINARY_OP}_mc"
+        m = _BinaryOp(self._op(op_name)).eval().cuda()
+        inputs = [torch.randn(16, 512, device="cuda"), torch.randn(16, 512, device="cuda")]
+        _check(self, m, inputs, self._BINARY_REF, f"{self._NS}::{op_name}")
+
+
+# ---------------------------------------------------------------------------
+# Concrete test classes — one per backend
+# ---------------------------------------------------------------------------
+
+class TestTritonE2E(_MultiConfigTests, _BackendE2ETests, unittest.TestCase):
+    """Triton backend: SiLU activation + SwiGLU gating (11 tests).
+
+    SiLU and SwiGLU are the canonical activation ops in the LLaMA / Mistral
+    family. The SwiGLU FFN block tested here is structurally identical to
+    what ships in those models.
+    """
+
+    _NS = "torchtrt_e2e_triton"
+    _UNARY_OP = "silu"
+    _BINARY_OP = "swiglu"
+    _UNARY_REF = staticmethod(_ref_silu)
+    _BINARY_REF = staticmethod(_ref_swiglu)
+
+
+class TestCuTileE2E(_MultiConfigTests, _BackendE2ETests, unittest.TestCase):
+    """CuTile backend: ReLU activation + ReGLU gating (11 tests).
+
+    ReGLU (relu(gate) · up) is a simpler gated variant used in ablation
+    studies and in quantized deployments where ReLU sparsity is desirable.
+    """
+
+    _NS = "torchtrt_e2e_cutile"
+    _UNARY_OP = "relu"
+    _BINARY_OP = "reglu"
+    _UNARY_REF = staticmethod(_ref_relu)
+    _BINARY_REF = staticmethod(_ref_reglu)
+
+
+class TestCuTeDSLE2E(_BackendE2ETests, unittest.TestCase):
+    """CuTeDSL backend: SiLU activation + Hadamard gating (9 tests).
+ + SiLU is implemented via cute.arch.exp (lowers to CUDA expf intrinsic). + Hadamard product models the element-wise combination step in attention + masking, LoRA weight updates, and gated linear units. + """ + + _NS = "torchtrt_e2e_cutedsl" + _UNARY_OP = "silu" + _BINARY_OP = "hadamard" + _UNARY_REF = staticmethod(_ref_silu) + _BINARY_REF = staticmethod(_ref_hadamard) + + +# --------------------------------------------------------------------------- +# Cross-backend tests: multiple PluginV3 ops from different backends in one engine +# --------------------------------------------------------------------------- + +# Convenience references to ops registered above. +_triton_silu = lambda: torch.ops.torchtrt_e2e_triton.silu.default +_triton_swiglu = lambda: torch.ops.torchtrt_e2e_triton.swiglu.default +_cutile_relu = lambda: torch.ops.torchtrt_e2e_cutile.relu.default +_cutile_reglu = lambda: torch.ops.torchtrt_e2e_cutile.reglu.default +_cutedsl_silu = lambda: torch.ops.torchtrt_e2e_cutedsl.silu.default +_cutedsl_had = lambda: torch.ops.torchtrt_e2e_cutedsl.hadamard.default + + + +class TestCrossBackendE2E(unittest.TestCase): + """Cross-backend tests: two or three PluginV3 ops from different backends in one engine. + + All shapes are LLM-domain [batch=8, hidden=512]. Each test uses + _check_multi_op to assert that every named PluginV3 layer is present in the + single compiled TRT engine. 
+ """ + + def test_triton_then_cutile_unary(self): + """Triton SiLU → CuTile ReLU: two sequential unary activations.""" + class Model(nn.Module): + def forward(self, x): + return _cutile_relu()(_triton_silu()(x)) + + inputs = [torch.randn(_LLM_B, _LLM_H, device="cuda")] + ref_fn = lambda x: _ref_relu(_ref_silu(x)) + _check_multi_op(self, Model().eval().cuda(), inputs, ref_fn, + ["torchtrt_e2e_triton::silu", "torchtrt_e2e_cutile::relu"]) + + def test_triton_then_cutedsl_unary(self): + """Triton SiLU → CuTeDSL SiLU: two sequential SiLU activations, different backends.""" + class Model(nn.Module): + def forward(self, x): + return _cutedsl_silu()(_triton_silu()(x)) + + inputs = [torch.randn(_LLM_B, _LLM_H, device="cuda")] + ref_fn = lambda x: _ref_silu(_ref_silu(x)) + _check_multi_op(self, Model().eval().cuda(), inputs, ref_fn, + ["torchtrt_e2e_triton::silu", "torchtrt_e2e_cutedsl::silu"]) + + def test_cutile_then_cutedsl_unary(self): + """CuTile ReLU → CuTeDSL SiLU: two sequential unary ops, different backends.""" + class Model(nn.Module): + def forward(self, x): + return _cutedsl_silu()(_cutile_relu()(x)) + + inputs = [torch.randn(_LLM_B, _LLM_H, device="cuda")] + ref_fn = lambda x: _ref_silu(_ref_relu(x)) + _check_multi_op(self, Model().eval().cuda(), inputs, ref_fn, + ["torchtrt_e2e_cutile::relu", "torchtrt_e2e_cutedsl::silu"]) + + def test_triton_unary_then_cutile_binary(self): + """Triton SiLU activation followed by CuTile ReGLU gating. + + Models the pattern: activated_gate = silu(x); out = reglu(activated_gate, up) + where the gate branch passes through a Triton activation before the CuTile + binary gating op. 
+ """ + class Model(nn.Module): + def forward(self, x, y): + return _cutile_reglu()(_triton_silu()(x), y) + + inputs = [torch.randn(_LLM_B, _LLM_H, device="cuda"), + torch.randn(_LLM_B, _LLM_H, device="cuda")] + ref_fn = lambda x, y: _ref_reglu(_ref_silu(x), y) + _check_multi_op(self, Model().eval().cuda(), inputs, ref_fn, + ["torchtrt_e2e_triton::silu", "torchtrt_e2e_cutile::reglu"]) + + def test_cutile_unary_then_triton_binary(self): + """CuTile ReLU activation followed by Triton SwiGLU gating.""" + class Model(nn.Module): + def forward(self, x, y): + return _triton_swiglu()(_cutile_relu()(x), y) + + inputs = [torch.randn(_LLM_B, _LLM_H, device="cuda"), + torch.randn(_LLM_B, _LLM_H, device="cuda")] + ref_fn = lambda x, y: _ref_swiglu(_ref_relu(x), y) + _check_multi_op(self, Model().eval().cuda(), inputs, ref_fn, + ["torchtrt_e2e_cutile::relu", "torchtrt_e2e_triton::swiglu"]) + + def test_cutedsl_unary_then_triton_binary(self): + """CuTeDSL SiLU activation followed by Triton SwiGLU gating.""" + class Model(nn.Module): + def forward(self, x, y): + return _triton_swiglu()(_cutedsl_silu()(x), y) + + inputs = [torch.randn(_LLM_B, _LLM_H, device="cuda"), + torch.randn(_LLM_B, _LLM_H, device="cuda")] + ref_fn = lambda x, y: _ref_swiglu(_ref_silu(x), y) + _check_multi_op(self, Model().eval().cuda(), inputs, ref_fn, + ["torchtrt_e2e_cutedsl::silu", "torchtrt_e2e_triton::swiglu"]) + + def test_triton_swiglu_then_cutedsl_hadamard(self): + """Triton SwiGLU gating followed by CuTeDSL Hadamard masking. + + Models a two-stage gate computation: swiglu merges gate + up, then + hadamard applies a learned mask (e.g. LoRA adapter residual scaling). 
+ """ + class Model(nn.Module): + def forward(self, gate, up, mask): + h = _triton_swiglu()(gate, up) + return _cutedsl_had()(h, mask) + + inputs = [torch.randn(_LLM_B, _LLM_H, device="cuda"), + torch.randn(_LLM_B, _LLM_H, device="cuda"), + torch.randn(_LLM_B, _LLM_H, device="cuda")] + ref_fn = lambda gate, up, mask: _ref_hadamard(_ref_swiglu(gate, up), mask) + _check_multi_op(self, Model().eval().cuda(), inputs, ref_fn, + ["torchtrt_e2e_triton::swiglu", "torchtrt_e2e_cutedsl::hadamard"]) + + def test_cutile_reglu_then_cutedsl_hadamard(self): + """CuTile ReGLU gating followed by CuTeDSL Hadamard masking.""" + class Model(nn.Module): + def forward(self, gate, up, mask): + h = _cutile_reglu()(gate, up) + return _cutedsl_had()(h, mask) + + inputs = [torch.randn(_LLM_B, _LLM_H, device="cuda"), + torch.randn(_LLM_B, _LLM_H, device="cuda"), + torch.randn(_LLM_B, _LLM_H, device="cuda")] + ref_fn = lambda gate, up, mask: _ref_hadamard(_ref_reglu(gate, up), mask) + _check_multi_op(self, Model().eval().cuda(), inputs, ref_fn, + ["torchtrt_e2e_cutile::reglu", "torchtrt_e2e_cutedsl::hadamard"]) + + def test_all_three_unary_chain(self): + """Triton SiLU → CuTile ReLU → CuTeDSL SiLU: three backends in one engine. + + Verifies that three PluginV3 layers from three different backends all + coexist in a single TRT engine and that the data flows correctly through + each activation in sequence. 
+ """ + class Model(nn.Module): + def forward(self, x): + x = _triton_silu()(x) + x = _cutile_relu()(x) + x = _cutedsl_silu()(x) + return x + + inputs = [torch.randn(_LLM_B, _LLM_H, device="cuda")] + ref_fn = lambda x: _ref_silu(_ref_relu(_ref_silu(x))) + _check_multi_op(self, Model().eval().cuda(), inputs, ref_fn, + ["torchtrt_e2e_triton::silu", + "torchtrt_e2e_cutile::relu", + "torchtrt_e2e_cutedsl::silu"]) + + +# --------------------------------------------------------------------------- +# Plugin attrs E2E +# --------------------------------------------------------------------------- + +class TestAttrsE2E(unittest.TestCase): + """Verify attrs flow from tta.custom_plugin(SCALE=2.0) through TRT to the kernel. + + The op 'torchtrt_e2e_triton::silu_scaled' is compiled with SCALE=2.0 + baked in as a tl.constexpr. Output must equal 2.0 * silu(x). + """ + + def test_triton_attrs_scale_factor(self): + """SCALE attr is baked into PTX as constexpr; output is 2·x·σ(x).""" + m = _UnaryOp(torch.ops.torchtrt_e2e_triton.silu_scaled.default).eval().cuda() + inputs = [torch.randn(8, 512, device="cuda")] + ref_fn = lambda x: 2.0 * x * torch.sigmoid(x) + _check(self, m, inputs, ref_fn, "torchtrt_e2e_triton::silu_scaled") + + +# --------------------------------------------------------------------------- +# Multiple outputs E2E +# --------------------------------------------------------------------------- + +class TestMultiOutputE2E(unittest.TestCase): + """Verify a PluginV3 op with two output tensors compiles and runs correctly. + + The op 'torchtrt_e2e_triton::dual_activation' returns (silu(x), relu(x)). + The test combines them as silu(x) + relu(x) for a single-output comparison. 
+ """ + + def test_triton_dual_output_plugin(self): + """Dual-output PluginV3: engine has one layer, two ITensor outputs.""" + op = torch.ops.torchtrt_e2e_triton.dual_activation.default + + class DualOutputModel(nn.Module): + def forward(self, x): + silu_out, relu_out = op(x) + return silu_out + relu_out + + inputs = [torch.randn(8, 512, device="cuda")] + ref_fn = lambda x: x * torch.sigmoid(x) + torch.relu(x) + _check(self, DualOutputModel().eval().cuda(), inputs, ref_fn, + "torchtrt_e2e_triton::dual_activation") + + +# --------------------------------------------------------------------------- +# Non-float32 dtype E2E +# --------------------------------------------------------------------------- + +class TestDtypeE2E(unittest.TestCase): + """Verify custom plugins handle non-float32 tensor dtypes end-to-end.""" + + def test_triton_silu_float16(self): + """SiLU op with float16 input/output using a dedicated fp16-safe kernel.""" + m = _UnaryOp(torch.ops.torchtrt_e2e_triton.silu_f16.default).eval().cuda() + inputs = [torch.randn(8, 512, dtype=torch.float16, device="cuda")] + ref_fn = lambda x: (x.float() * torch.sigmoid(x.float())).half() + _check(self, m, inputs, ref_fn, "torchtrt_e2e_triton::silu_f16", + rtol=1e-3, atol=1e-3) + + def test_cutile_relu_float16(self): + """CuTile ReLU with float16 input/output.""" + m = _UnaryOp(torch.ops.torchtrt_e2e_cutile.relu.default).eval().cuda() + inputs = [torch.randn(8, 512, dtype=torch.float16, device="cuda")] + ref_fn = lambda x: torch.relu(x) + _check(self, m, inputs, ref_fn, "torchtrt_e2e_cutile::relu", + rtol=1e-3, atol=1e-3) + + +# --------------------------------------------------------------------------- +# supports_dynamic_shapes=False E2E +# --------------------------------------------------------------------------- + +class TestStaticShapeE2E(unittest.TestCase): + """Verify an op registered with supports_dynamic_shapes=False compiles correctly.""" + + def test_static_shape_engine(self): + """Fixed-shape input 
compiles to a valid TRT engine.""" + m = _UnaryOp(torch.ops.torchtrt_e2e_triton.silu_static.default).eval().cuda() + inputs = [torch.randn(8, 512, device="cuda")] + ref_fn = lambda x: x * torch.sigmoid(x) + _check(self, m, inputs, ref_fn, "torchtrt_e2e_triton::silu_static") + + def test_static_shape_different_batch(self): + """Different fixed-shape input also compiles independently.""" + m = _UnaryOp(torch.ops.torchtrt_e2e_triton.silu_static.default).eval().cuda() + inputs = [torch.randn(4, 256, device="cuda")] + ref_fn = lambda x: x * torch.sigmoid(x) + _check(self, m, inputs, ref_fn, "torchtrt_e2e_triton::silu_static") + + +# --------------------------------------------------------------------------- +# Tensor format E2E +# --------------------------------------------------------------------------- + +class TestTensorFormatE2E(unittest.TestCase): + """Verify explicit input_formats/output_formats flow through to TRT autotune.""" + + def test_explicit_linear_format(self): + """Explicit LINEAR format spec compiles and produces correct output.""" + m = _UnaryOp(torch.ops.torchtrt_e2e_triton.silu_linear_fmt.default).eval().cuda() + inputs = [torch.randn(8, 512, device="cuda")] + ref_fn = lambda x: x * torch.sigmoid(x) + _check(self, m, inputs, ref_fn, "torchtrt_e2e_triton::silu_linear_fmt") + + def test_explicit_linear_format_with_llm_shape(self): + """Explicit LINEAR format spec at LLM scale [batch=8, hidden=512].""" + m = _UnaryOp(torch.ops.torchtrt_e2e_triton.silu_linear_fmt.default).eval().cuda() + inputs = [torch.randn(8, 512, device="cuda")] + ref_fn = lambda x: x * torch.sigmoid(x) + _check(self, m, inputs, ref_fn, "torchtrt_e2e_triton::silu_linear_fmt") + + +# --------------------------------------------------------------------------- +# Weights E2E +# --------------------------------------------------------------------------- + +class TestWeightsE2E(unittest.TestCase): + """Verify tensor weights injected via custom_plugin(W=tensor) are baked into + the engine 
as trt.add_constant layers and reach the kernel correctly. + + Engine check: the output equals W * silu(x) where W is the module-level + _W_COLUMN_SCALE tensor (all-3.0). A plain silu(x) reference would differ + by 3×, so any mismatch in weight injection is immediately caught by the + accuracy assertion. + """ + + def test_triton_column_scale_weight(self): + """Weight tensor W[j] scales column j of silu(x); engine bakes W as constant.""" + m = _UnaryOp(torch.ops.torchtrt_e2e_triton.weighted_silu.default).eval().cuda() + inputs = [torch.randn(8, 512, device="cuda")] + W_dev = _W_COLUMN_SCALE.cuda() + ref_fn = lambda x: W_dev * x * torch.sigmoid(x) + _check(self, m, inputs, ref_fn, "torchtrt_e2e_triton::weighted_silu") + + +# --------------------------------------------------------------------------- +# BFloat16 dtype E2E +# --------------------------------------------------------------------------- + +class TestBF16DtypeE2E(unittest.TestCase): + """Verify custom plugins compile and produce numerically correct results + for bfloat16 inputs. Two engine-level checks per test: + 1. The TRT engine contains a PluginV3 layer for the op (via _check). + 2. The output tensor dtype is torch.bfloat16 (explicit assertion). + """ + + def test_triton_silu_bfloat16(self): + """BF16 Triton SiLU: engine runs in bf16; output dtype is bfloat16.""" + m = _UnaryOp(torch.ops.torchtrt_e2e_triton.silu_bf16.default).eval().cuda() + inputs = [torch.randn(8, 512, dtype=torch.bfloat16, device="cuda")] + ref_fn = lambda x: (x.float() * torch.sigmoid(x.float())).bfloat16() + compiled = _trt_compile(m, inputs) + engines = _get_trt_engines(compiled) + self.assertEqual(len(engines), 1) + self.assertTrue(_engine_has_pluginv3_layer(engines[0].engine, + "torchtrt_e2e_triton::silu_bf16")) + trt_out = compiled(*inputs) + # Engine-level dtype check: output must be bf16, not fp32. 
+ self.assertEqual(trt_out.dtype, torch.bfloat16, + "Expected engine output dtype to be bfloat16") + torch.testing.assert_close(trt_out, ref_fn(*inputs), rtol=2e-2, atol=2e-2) + + def test_cutile_relu_bfloat16(self): + """BF16 CuTile ReLU: engine runs in bf16; output dtype is bfloat16.""" + m = _UnaryOp(torch.ops.torchtrt_e2e_cutile.relu.default).eval().cuda() + inputs = [torch.randn(8, 512, dtype=torch.bfloat16, device="cuda")] + ref_fn = lambda x: torch.relu(x) + compiled = _trt_compile(m, inputs) + engines = _get_trt_engines(compiled) + self.assertEqual(len(engines), 1) + self.assertTrue(_engine_has_pluginv3_layer(engines[0].engine, + "torchtrt_e2e_cutile::relu")) + trt_out = compiled(*inputs) + self.assertEqual(trt_out.dtype, torch.bfloat16, + "Expected engine output dtype to be bfloat16") + torch.testing.assert_close(trt_out, ref_fn(*inputs), rtol=2e-2, atol=2e-2) + + +# --------------------------------------------------------------------------- +# 3D input E2E +# --------------------------------------------------------------------------- + +class Test3DInputE2E(unittest.TestCase): + """Verify custom plugins compile and run correctly for rank-3 inputs + [batch, seq_len, hidden]. + + Engine checks: + 1. PluginV3 layer present (via _engine_has_pluginv3_layer). + 2. Output shape equals input shape (rank preserved through engine). + """ + + def test_triton_silu_3d_input(self): + """Triton flat-SiLU on [2, 16, 256]: output shape and values correct.""" + m = _UnaryOp(torch.ops.torchtrt_e2e_triton.silu_flat.default).eval().cuda() + inputs = [torch.randn(2, 16, 256, device="cuda")] + ref_fn = lambda x: x * torch.sigmoid(x) + compiled = _trt_compile(m, inputs) + engines = _get_trt_engines(compiled) + self.assertEqual(len(engines), 1) + self.assertTrue(_engine_has_pluginv3_layer(engines[0].engine, + "torchtrt_e2e_triton::silu_flat")) + trt_out = compiled(*inputs) + # Engine-level shape check: rank-3 shape must be preserved end-to-end. 
+ self.assertEqual(tuple(trt_out.shape), tuple(inputs[0].shape), + f"Expected output shape {inputs[0].shape}, got {trt_out.shape}") + torch.testing.assert_close(trt_out, ref_fn(*inputs), rtol=1e-3, atol=1e-3) + + def test_cutile_relu_3d_input(self): + """CuTile ReLU on [2, 16, 256]: output shape and values correct.""" + m = _UnaryOp(torch.ops.torchtrt_e2e_cutile.relu.default).eval().cuda() + inputs = [torch.randn(2, 16, 256, device="cuda")] + ref_fn = lambda x: torch.relu(x) + compiled = _trt_compile(m, inputs) + engines = _get_trt_engines(compiled) + self.assertEqual(len(engines), 1) + self.assertTrue(_engine_has_pluginv3_layer(engines[0].engine, + "torchtrt_e2e_cutile::relu")) + trt_out = compiled(*inputs) + self.assertEqual(tuple(trt_out.shape), tuple(inputs[0].shape), + f"Expected output shape {inputs[0].shape}, got {trt_out.shape}") + torch.testing.assert_close(trt_out, ref_fn(*inputs), rtol=1e-3, atol=1e-3) + + +# --------------------------------------------------------------------------- +# Multiple attrs E2E +# --------------------------------------------------------------------------- + +class TestMultiAttrsE2E(unittest.TestCase): + """Verify that multiple scalar attrs (SCALE, BIAS) are all baked into PTX + as separate tl.constexpr values. + + Engine check: output equals SCALE * silu(x) + BIAS with both values set + independently. A single-attr kernel with only SCALE would produce + SCALE * silu(x) without the BIAS offset, failing the assertion. 
+ """ + + def test_triton_scale_and_bias_attrs(self): + """SCALE=3.0 and BIAS=1.0 both baked as constexprs; output verified.""" + m = _UnaryOp(torch.ops.torchtrt_e2e_triton.silu_scale_bias.default).eval().cuda() + inputs = [torch.randn(8, 512, device="cuda")] + ref_fn = lambda x: 3.0 * x * torch.sigmoid(x) + 1.0 + _check(self, m, inputs, ref_fn, "torchtrt_e2e_triton::silu_scale_bias") + + +# --------------------------------------------------------------------------- +# Multi-output + dynamic batch E2E +# --------------------------------------------------------------------------- + +class TestMultiOutputDynamicE2E(unittest.TestCase): + """Verify a multi-output plugin compiles with dynamic batch and produces + correct outputs at three different batch sizes. + + Engine checks per batch size: + 1. PluginV3 layer present (verified at compile time via single compile). + 2. Both silu and relu outputs numerically correct (two allclose checks). + """ + + def test_dual_output_dynamic_batch(self): + """dual_activation with dynamic batch: both outputs correct at 1, 8, 32.""" + op = torch.ops.torchtrt_e2e_triton.dual_activation.default + + class DualModel(nn.Module): + def forward(self, x): + silu_out, relu_out = op(x) + return silu_out + relu_out + + torch._dynamo.reset() + compiled = torch_tensorrt.compile( + DualModel().eval().cuda(), + inputs=[torch_tensorrt.Input( + min_shape=(1, 512), opt_shape=(8, 512), max_shape=(32, 512), + dtype=torch.float32, + )], + min_block_size=1, + ) + engines = _get_trt_engines(compiled) + self.assertEqual(len(engines), 1) + self.assertTrue(_engine_has_pluginv3_layer(engines[0].engine, + "torchtrt_e2e_triton::dual_activation")) + for batch in (1, 8, 32): + x = torch.randn(batch, 512, device="cuda") + trt_out = compiled(x) + ref = x * torch.sigmoid(x) + torch.relu(x) + torch.testing.assert_close(trt_out, ref, rtol=1e-3, atol=1e-3, + msg=f"Mismatch at batch={batch}") + + +# 
--------------------------------------------------------------------------- +# Dynamic hidden dimension E2E +# --------------------------------------------------------------------------- + +class TestDynamicHiddenE2E(unittest.TestCase): + """Verify a plugin compiled with a dynamic hidden dimension (dim-1, not dim-0) + produces correct outputs across the full hidden range. + + Engine check: PluginV3 present + outputs correct at hidden = 128, 256, 512. + This is distinct from test_dynamic_batch (which varies dim-0 only). + """ + + def test_triton_dynamic_hidden_dim(self): + """SiLU with dynamic hidden; correct at min=128, opt=256, max=512.""" + m = _UnaryOp(torch.ops.torchtrt_e2e_triton.silu.default).eval().cuda() + torch._dynamo.reset() + compiled = torch_tensorrt.compile( + m, + inputs=[torch_tensorrt.Input( + min_shape=(8, 128), opt_shape=(8, 256), max_shape=(8, 512), + dtype=torch.float32, + )], + min_block_size=1, + ) + engines = _get_trt_engines(compiled) + self.assertEqual(len(engines), 1) + self.assertTrue(_engine_has_pluginv3_layer(engines[0].engine, + "torchtrt_e2e_triton::silu")) + for hidden in (128, 256, 512): + x = torch.randn(8, hidden, device="cuda") + trt_out = compiled(x) + torch.testing.assert_close(trt_out, x * torch.sigmoid(x), + rtol=1e-3, atol=1e-3, + msg=f"Mismatch at hidden={hidden}") + + +# --------------------------------------------------------------------------- +# Production scale E2E +# --------------------------------------------------------------------------- + +class TestLargeShapeE2E(unittest.TestCase): + """Verify the plugin pipeline handles production-scale LLM shapes. + + [32, 4096] matches GPT-2 large / LLaMA-7B hidden_size=4096, batch=32. + Engine check: PluginV3 present + accuracy within tolerance at this scale. 
+ """ + + def test_production_scale(self): + """Triton SiLU at [32, 4096] — production LLM inference scale.""" + m = _UnaryOp(torch.ops.torchtrt_e2e_triton.silu.default).eval().cuda() + inputs = [torch.randn(32, 4096, device="cuda")] + ref_fn = lambda x: x * torch.sigmoid(x) + _check(self, m, inputs, ref_fn, "torchtrt_e2e_triton::silu") + + +# --------------------------------------------------------------------------- +# Cross-backend + dynamic batch E2E +# --------------------------------------------------------------------------- + +class TestCrossBackendDynamicE2E(unittest.TestCase): + """Verify that mixing Triton and CuTile plugins in one engine still works + under dynamic batch shapes. + + Engine checks: + 1. Both PluginV3 layers present in a single engine. + 2. Outputs correct at batch = 1, 8, and 32. + """ + + def test_dynamic_batch_cross_backend(self): + """Triton SiLU → CuTile ReLU in one engine; correct at 3 batch sizes.""" + class ChainModel(nn.Module): + def forward(self, x): + x = torch.ops.torchtrt_e2e_triton.silu.default(x) + return torch.ops.torchtrt_e2e_cutile.relu.default(x) + + torch._dynamo.reset() + compiled = torch_tensorrt.compile( + ChainModel().eval().cuda(), + inputs=[torch_tensorrt.Input( + min_shape=(1, 256), opt_shape=(8, 256), max_shape=(32, 256), + dtype=torch.float32, + )], + min_block_size=1, + ) + engines = _get_trt_engines(compiled) + self.assertEqual(len(engines), 1, "Expected both plugins in one TRT engine") + self.assertTrue(_engine_has_pluginv3_layer(engines[0].engine, + "torchtrt_e2e_triton::silu")) + self.assertTrue(_engine_has_pluginv3_layer(engines[0].engine, + "torchtrt_e2e_cutile::relu")) + for batch in (1, 8, 32): + x = torch.randn(batch, 256, device="cuda") + trt_out = compiled(x) + ref = torch.relu(x * torch.sigmoid(x)) + torch.testing.assert_close(trt_out, ref, rtol=1e-3, atol=1e-3, + msg=f"Mismatch at batch={batch}") + + +# --------------------------------------------------------------------------- +# Non-LINEAR 
tensor format E2E +# --------------------------------------------------------------------------- + + +class TestNonLinearFormatsE2E(unittest.TestCase): + """Verify that non-LINEAR ``input_formats`` / ``output_formats`` on a + TritonSpec are correctly propagated through the QDP autotune registration + and result in TRT inserting format-conversion (reformat) layers around the + PluginV3 layer in the compiled engine. + + The observable signal for format negotiation is the presence of + ``"Reformatting CopyNode"`` layer names in the engine inspector output. + When a plugin declares CHW32, TRT inserts: + - a LINEAR → CHW32 reformat before the plugin + - a CHW32 → LINEAR reformat after the plugin + so the engine has 3 layers total instead of 1. + + Input shape [4, 32, 8, 8] is used because CHW32 requires the channel + dimension (dim 1) to be a multiple of 32. + + The flat-indexed ReLU kernel (_triton_launch_flat_relu) uses contiguous + pointer arithmetic so it is correct for any memory layout, including CHW32. + """ + + _OP = "torchtrt_e2e_triton::relu_chw32" + _SHAPE = (4, 32, 8, 8) + + def _compile(self): + m = _UnaryOp(torch.ops.torchtrt_e2e_triton.relu_chw32.default).eval().cuda() + torch._dynamo.reset() + return torch_tensorrt.compile( + m, + inputs=[torch.zeros(self._SHAPE, device="cuda")], + min_block_size=1, + ) + + def test_pluginv3_present(self): + """CHW32 plugin compiles into a TRT engine with a PluginV3 layer.""" + compiled = self._compile() + engines = _get_trt_engines(compiled) + self.assertEqual(len(engines), 1) + self.assertTrue( + _engine_has_pluginv3_layer(engines[0].engine, self._OP), + f"Expected PluginV3 layer for {self._OP!r} in engine", + ) + + def test_reformat_nodes_present(self): + """TRT inserts reformatting nodes around the CHW32 plugin. + + This confirms the CHW32 format token from ``input_formats`` was + correctly propagated into the ``AutoTuneCombination`` and TRT + negotiated the non-LINEAR format for the plugin layer. 
+ + When TRT selects CHW32, it inserts format-conversion layers before and + after the plugin. These may appear as ``"Reformatting CopyNode"`` + entries in the raw inspector output, or as Myelin-fused ``__mye*`` / + ``__myl*`` nodes after graph optimization. Either way the engine has + more than one layer, unlike a LINEAR plugin which needs no reformats. + """ + compiled = self._compile() + engines = _get_trt_engines(compiled) + layer_names = _engine_all_layer_names(engines[0].engine) + # Find layers that are NOT the plugin itself — these are the reformat + # (or Myelin-fused reformat) nodes inserted by TRT for CHW32. + non_plugin = [ + ln for ln in layer_names + if "relu_chw32" not in ln + ] + self.assertGreater( + len(non_plugin), 0, + f"Expected extra reformat layers for CHW32 format negotiation, " + f"but found none. Layer names: {layer_names}", + ) + + def test_output_correct(self): + """Output values are correct despite CHW32 format repack.""" + compiled = self._compile() + x = torch.randn(self._SHAPE, device="cuda") + trt_out = compiled(x) + ref = torch.relu(x) + torch.testing.assert_close(trt_out, ref, rtol=1e-5, atol=1e-5) + diff --git a/tests/py/annotation/unit/__init__.py b/tests/py/annotation/unit/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/py/annotation/unit/test_specs.py b/tests/py/annotation/unit/test_specs.py new file mode 100644 index 0000000000..f84e7640a2 --- /dev/null +++ b/tests/py/annotation/unit/test_specs.py @@ -0,0 +1,83 @@ +"""Unit tests for CustomPluginSpec op_name invariants. + +These two properties have no equivalent in the integration suite: e2e tests +only register one plugin at a time and never inspect op_name directly. A +broken op_name causes TRT to silently fail to find or alias plugins, which +is extremely hard to diagnose from e2e output alone. 
+""" + +import unittest + +import torch_tensorrt.annotation as tta + + +class TestKernelSpecValidation(unittest.TestCase): + """Negative tests for TritonSpec / CuTileSpec / CuTeDSLSpec construction.""" + + def test_rejects_non_callable_launch_fn(self): + with self.assertRaises(TypeError): + tta.TritonSpec(launch_fn="not_callable") + + def test_rejects_non_list_configs(self): + def kernel(x, out): pass + with self.assertRaises(TypeError): + tta.TritonSpec(launch_fn=kernel, configs={"BLOCK": 128}) + + +class TestCustomPluginSpecValidation(unittest.TestCase): + """Negative tests for user-facing API errors not exercised by integration tests.""" + + def test_rejects_missing_meta_impl(self): + def kernel(x, out): pass + with self.assertRaises((TypeError, ValueError)): + tta.custom_plugin(tta.triton(kernel)) + + def test_rejects_none_meta_impl(self): + def kernel(x, out): pass + with self.assertRaises(ValueError) as ctx: + tta.custom_plugin(tta.triton(kernel), meta_impl=None) + self.assertIn("meta_impl", str(ctx.exception)) + + def test_rejects_non_callable_meta_impl(self): + def kernel(x, out): pass + with self.assertRaises(TypeError) as ctx: + tta.custom_plugin(tta.triton(kernel), meta_impl="not_callable") + self.assertIn("callable", str(ctx.exception)) + + def test_rejects_invalid_spec_type_lists_all_backends(self): + with self.assertRaises(TypeError) as ctx: + tta.custom_plugin("not_a_spec", meta_impl=lambda x: x.new_empty(x.shape)) + msg = str(ctx.exception) + self.assertIn("TritonSpec", msg) + self.assertIn("CuTileSpec", msg) + self.assertIn("CuTeDSLSpec", msg) + + def test_rejects_invalid_element_in_list(self): + def kernel(x, out): pass + with self.assertRaises(TypeError): + tta.custom_plugin([tta.triton(kernel), "invalid"], meta_impl=lambda x: x.new_empty(x.shape)) + + def test_rejects_empty_spec_list(self): + with self.assertRaises(ValueError) as ctx: + tta.custom_plugin([], meta_impl=lambda x: x.new_empty(x.shape)) + self.assertIn("empty", str(ctx.exception)) + 


class TestCustomPluginSpecOpName(unittest.TestCase):
    """Positive tests for the derived op_name: stable namespace, no collisions."""

    def test_op_name_uses_tta_custom_namespace(self):
        def kernel(x, out):
            pass

        def meta(x):
            return x.new_empty(x.shape)

        spec = tta.custom_plugin(tta.triton(kernel), meta_impl=meta)
        # op_name is "<namespace>::<local_name>"; namespace must be the
        # reserved tta_custom namespace and the local name non-empty.
        namespace, local_name = spec.op_name.split("::", 1)
        self.assertEqual(namespace, "tta_custom")
        self.assertGreater(len(local_name), 0)

    def test_op_name_differs_across_kernel_functions(self):
        """Different kernels must get different op_names — a collision causes TRT
        to silently execute the wrong plugin."""

        def kernel_a(x, out):
            pass

        def kernel_b(x, out):
            pass

        def meta(x):
            return x.new_empty(x.shape)

        name_a = tta.custom_plugin(tta.triton(kernel_a), meta_impl=meta).op_name
        name_b = tta.custom_plugin(tta.triton(kernel_b), meta_impl=meta).op_name
        self.assertNotEqual(name_a, name_b)