Skip to content

feat(annotation): Torch-TensorRT annotation layer — custom_plugin (QDP/Triton/CuTile/CuTeDSL)#4147

Open
BowenFu wants to merge 18 commits intopytorch:mainfrom
BowenFu:feat/tta-custom-plugin
Open

feat(annotation): Torch-TensorRT annotation layer — custom_plugin (QDP/Triton/CuTile/CuTeDSL)#4147
BowenFu wants to merge 18 commits intopytorch:mainfrom
BowenFu:feat/tta-custom-plugin

Conversation

@BowenFu
Copy link
Copy Markdown

@BowenFu BowenFu commented Mar 30, 2026

Motivation

Torch-TRT compiles PyTorch models to TensorRT engines, but today there is no first-class path for users who want to replace a subgraph with their own Triton, CuTile, or CuTeDSL kernel inside the compiled engine. The typical workaround — writing a C++ TRT plugin and registering it manually — requires leaving the Python ecosystem, managing separate build systems, and wiring up the plugin registry by hand. This is a significant barrier for researchers and ML engineers who already have high-performance Python kernels.

TensorRT 10.x introduced Quick Deployable Plugins (QDP), which support AOT-compiled Python kernels (@trtp.aot_impl) that are embedded directly into the TRT engine with no Python required at runtime. This PR adds the descriptor and registration layer that lets users express a custom QDP plugin as a plain Python object and pass it to Torch-TRT — with no changes to any core compiler files.

What's in this PR

Public API (import torch_tensorrt.annotation as tta):

Factory Returns Description
tta.triton(launch_fn, configs) TritonSpec Triton kernel descriptor
tta.cutile(launch_fn, arch, configs) CuTileSpec CuTile kernel descriptor (Blackwell sm_100+)
tta.cutedsl(launch_fn, configs) CuTeDSLSpec CuTeDSL kernel descriptor
tta.custom_plugin(impl) CustomPluginSpec AOT QDP plugin descriptor wrapping one or more kernel specs

QDP registration (_custom_plugin/):

  • _descriptor.pyCustomPluginSpec dataclass + custom_plugin() factory; computes a deterministic op name from the kernel function identity + config hash; register_custom_plugin() registers @trtp.register / @trtp.autotune / @trtp.aot_impl with TRT's process-global QDP registry using double-checked locking for xdist safety.
  • _lowering.py — lowers a CustomPluginSpec to a TRT plugin layer via ctx.net.add_plugin(trtp.op.<ns>.<name>(*inputs), aot=True); injects weight tensors as add_constant layers.
  • _qdp_utils.py — deterministic op-name derivation, tactic table building, meta-tensor helpers for symbolic shape inference.
  • _symbolic.pySymbolicTensor abstraction for QDP shape/dtype descriptor registration.

AOT backends (_custom_plugin/_aot/):

  • _triton.py — Triton → PTX via triton.compile; per-config tactic entries.
  • _cutile.py — CuTile → cubin via tileiras; sm_100+ only.
  • _cutedsl.py — CuTeDSL → PTX/cubin via nvidia-cutlass-dsl.

Supporting modules:

  • _specs.pyTritonSpec, CuTileSpec, CuTeDSLSpec, KernelImplSpec frozen dataclasses; triton() / cutile() / cutedsl() factories; normalize_impl_to_spec().
  • _layer_metadata.pyset_tta_layer_metadata() helper for stamping TRT layer metadata; encode/decode round-trip for custom plugin attribution.
  • _recorders.py — launch-parameter recording for Triton/CuTile/CuTeDSL AOT backends.
  • _validation.py — spec and descriptor validation utilities.
  • _errors.pyTTADiagnosticError structured error type.

Tests (tests/py/annotation/unit/, CPU-only, 46 tests):

  • test_specs.py — kernel spec construction, validation, cache-key stability.
  • test_specs_custom_plugin.pyCustomPluginSpec and custom_plugin() factory.
  • test_layer_metadata.py — metadata encode/decode round-trip.

Design notes

  • Descriptor-only, zero core impact. CustomPluginSpec is a plain frozen dataclass. No hooks into _compiler.py, _TRTInterpreter.py, or any other existing file. The integration point (passing a descriptor to a converter) is left for a follow-up PR.
  • Deterministic op naming. The QDP op_name is derived from a hash of the kernel function identity, config set, and weight count. The same descriptor created in two different processes produces the same name, making engine caching safe.
  • Process-global registration with xdist safety. TRT's QDP registry is process-global. A double-checked locking pattern over a process-level set prevents duplicate registration when pytest-xdist workers share a process.
  • Multi-tactic autotuning. Multiple configs dicts produce multiple QDP tactics; TRT's autotuner benchmarks all of them at engine-build time.

trt_plugins.custom_op integration — a torch_tensorrt.dynamo.conversion.plugins API that wires a CustomPluginSpec directly to a registered torch.library custom op, so the TRT lowering path is set up with no manual converter code:

import torch_tensorrt.annotation as tta
from torch_tensorrt.dynamo.conversion import plugins as trt_plugins

def my_op_meta(x, y):
    return x.new_empty(x.shape)

def launch_triton(x, y, out, *, BLOCK: int):
    triton_kernel[...](x, y, out, BLOCK=BLOCK)

def launch_cutile(x, y, out, *, TILE: int):
    cutile_kernel.run(x, y, out, TILE=TILE)

triton_spec = tta.triton(launch_fn=launch_triton, configs=[{"BLOCK": 128}, {"BLOCK": 256}])
cutile_spec = tta.cutile(launch_fn=launch_cutile, configs=[{"TILE": 64}, {"TILE": 128}])

MY_IMPL = tta.custom_plugin(
    kernel=[triton_spec, cutile_spec],
    meta_impl=my_op_meta,
)

trt_plugins.custom_op(
    "torchtrt_ex::my_custom_op",
    impl=MY_IMPL,
)

TRT's autotuner benchmarks all tactics across both backends and selects the fastest for the target GPU.

Future work

This PR establishes the descriptor and registration layer. The follow-up work:

  1. Converter integration — wire CustomPluginSpec into the _TRTInterpreter converter dispatch so that annotated subgraphs are lowered to the registered QDP op during torch_tensorrt.compile.
  2. Region annotation API — a tta.lower_as(impl=..., name=...) context manager that tags subgraph regions during torch.export for targeted lowering to custom plugins. The intended end-to-end usage looks like:
import torch
import torch_tensorrt
import torch_tensorrt.annotation as tta

@triton.jit
def _fused_add_relu(x_ptr, y_ptr, out_ptr, n, BLOCK: tl.constexpr):
    i = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = i < n
    tl.store(out_ptr + i, tl.maximum(0, tl.load(x_ptr+i, mask=mask) + tl.load(y_ptr+i, mask=mask)), mask=mask)

def launch_fused_add_relu(x, y, out, stream, BLOCK=256):
    _fused_add_relu[(triton.cdiv(x.numel(), BLOCK),)](x, y, out, x.numel(), BLOCK=BLOCK)

fused_add_relu = tta.custom_plugin(tta.triton(launch_fused_add_relu, configs=[{"BLOCK": 128}, {"BLOCK": 256}]))

class ResidualBlock(nn.Module):
    def forward(self, x, y):
        with tta.lower_as(impl=fused_add_relu):
            return torch.relu(x + y)

model = ResidualBlock().cuda().eval()
x = torch.randn(4, 1024, device="cuda")
trt_model = torch_tensorrt.compile(model, inputs=[x, x])
  1. End-to-end tests — GPU tests exercising the full compile → run path for Triton, CuTile, and CuTeDSL backends on representative kernels (fused add+ReLU, RMSNorm, attention).

Test plan

docker exec torch_tensorrt_dev \
  python -m pytest tests/py/annotation/unit/ -n 4 --tb=short -v
# 46 passed

@meta-cla
Copy link
Copy Markdown

meta-cla bot commented Mar 30, 2026

Hi @BowenFu!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at cla@meta.com. Thanks!

@github-actions github-actions bot added component: tests Issues re: Tests component: conversion Issues re: Conversion stage component: core Issues re: The core compiler component: api [Python] Issues re: Python API component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels Mar 30, 2026
@BowenFu BowenFu changed the title feat(annotation): Torch-TensorRT annotation layer — export_as, lower_as, and custom_plugin (QDP/Triton/CuTile/CuTeDSL) draft: feat(annotation): Torch-TensorRT annotation layer — export_as, lower_as, and custom_plugin (QDP/Triton/CuTile/CuTeDSL) Mar 30, 2026
@BowenFu BowenFu marked this pull request as draft March 30, 2026 14:40
@narendasan narendasan requested a review from bowang007 March 30, 2026 17:12
Copy link
Copy Markdown
Collaborator

@narendasan narendasan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@BowenFu can we split this out in to a PR stack? lets put tta.custom_plugin at the bottom. What I want to focus on is lets say a user already has implemented a custom operator in PyTorch backed by one of these kernels. We want to enable the AOT QDP launch of that kernel without a bunch of boilerplate: Basically this example https://docs.pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/auto_generate_plugins.html with AOT QDP. There is already facilities for the converter generation and plugin registration from PyTorch Meta Kernel. Once we have that we have a solid base to look at region labeling / manual fusion and other advance usecases.

# ---------------------------------------------------------------------------


def lower_custom_plugin(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we merge this stuff with already existing plugin autogeneration in torch_tensorrt.dynamo.conversion.plugin? Like we already automate converter generation key'ed on operator name

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just cleaned up all tta.lower_as/tta.export_as related codes from this PR.

return lower_custom_plugin_descriptor(ctx, descriptor, trt_inputs, name)


def register_custom_plugin_qdp(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same here, these facilities already exist, they should be extended not duplicated

@narendasan
Copy link
Copy Markdown
Collaborator

Also @BowenFu please follow the instructions I sent for how to get added to the CLA

@BowenFu BowenFu force-pushed the feat/tta-custom-plugin branch from cc46c6c to 29abd38 Compare March 31, 2026 02:29
@BowenFu
Copy link
Copy Markdown
Author

BowenFu commented Mar 31, 2026

@BowenFu can we split this out in to a PR stack? lets put tta.custom_plugin at the bottom. What I want to focus on is lets say a user already has implemented a custom operator in PyTorch backed by one of these kernels. We want to enable the AOT QDP launch of that kernel without a bunch of boilerplate: Basically this example https://docs.pytorch.org/TensorRT/tutorials/_rendered_examples/dynamo/auto_generate_plugins.html with AOT QDP. There is already facilities for the converter generation and plugin registration from PyTorch Meta Kernel. Once we have that we have a solid base to look at region labeling / manual fusion and other advance usecases.

Sure. Will include only tta.custom_plugin in this PR.

@BowenFu BowenFu changed the title draft: feat(annotation): Torch-TensorRT annotation layer — export_as, lower_as, and custom_plugin (QDP/Triton/CuTile/CuTeDSL) draft: feat(annotation): Torch-TensorRT annotation layer — custom_plugin (QDP/Triton/CuTile/CuTeDSL) Mar 31, 2026
@BowenFu BowenFu force-pushed the feat/tta-custom-plugin branch from 5d036ae to cf485d2 Compare March 31, 2026 12:04
@BowenFu BowenFu marked this pull request as ready for review March 31, 2026 13:04
@BowenFu BowenFu changed the title draft: feat(annotation): Torch-TensorRT annotation layer — custom_plugin (QDP/Triton/CuTile/CuTeDSL) feat(annotation): Torch-TensorRT annotation layer — custom_plugin (QDP/Triton/CuTile/CuTeDSL) Mar 31, 2026
@BowenFu BowenFu force-pushed the feat/tta-custom-plugin branch 3 times, most recently from 493cf05 to e7b6d3b Compare April 2, 2026 08:53
@BowenFu BowenFu requested a review from narendasan April 2, 2026 09:09
@BowenFu BowenFu force-pushed the feat/tta-custom-plugin branch from d1a1767 to 1ac78b7 Compare April 2, 2026 09:35
Copy link
Copy Markdown
Collaborator

@narendasan narendasan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Marked a bunch of immediate stuff that stood out. the TL;DR is theres a ton of re-implementation here and some stuff (particularly the triton stuff) seems hacky.

My general recommendation is focus on adding aot_impl and perhaps autotune (@bowang007 did you look at autotune at all?) to the existing plugin system which should handle the rest, rather than essentially making a whole second version. If there are limitations you are running into with what is there then I believe that is where the technical discussion should be centered. Like perhaps the locking system (cc: @bowang007)

I would also recommend trying to make the systems for defining launch parameters more generically applicable, then we dont need to do as much work to say add support for pallas or nvrtc kernels.

Also I would recommend that all the sort of kernel encapsulation stuff would nicely fit in a namespace called torch_tensorrt.kernels or torch_tensorrt.dynamo.kernels It would be immediately obvious what you would use the namespace for then.

requires_output_allocator,
)
if impl is not None:
impl.register_dynamo_plugin(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We dont need a complete second code path for this, lets reuse what we have. For example creating the converter is likely the only user of capability_validator, priority, supports_dynamic_shapes, requires_output_allocator. generate_plugin_converter already creates this converter key'ed on name,

Really I would expect the code to look like:

generate_plugin(op_name) # Generates JIT QDP Plugin 
generate_plugin_converter(op_name, capability_validator, priority, supports_dynamic_shapes, requires_output_allocator) # Generates the converter that inserts the QDP plugin 

if impl: #this should be kernel_impl probably 
    impl.generate_plugin_aot_impl() 

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated as suggested.

return getattr(fn, _ANNOTATION_METADATA_ATTR, None)


# ── Custom kernel specs (Triton / CuTile / CuTeDSL) ──────────────────────────
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have one for NVRTC?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@narendasan I would expect one TVM path later after TensorRT actually supports that. NVRTC is too flexible which we need to add lots of constraints on it or bridge work to make it work with the existing QDP path.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are users of the plugin system who are explicitly using NVRTC which is why I ask

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can support that in the future. For now, they can register to trt plugin registery manually, and use the incoming tta.plugin annotation (on par with tta.custom_plugin) to refer to it.



@dataclass
class AnnotationMetadata:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed in this PR?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pushed the Metadata related changes to a future PR.

# ---------------------------------------------------------------------------


def set_tta_layer_metadata(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this?

@@ -0,0 +1,546 @@
"""TTA metadata stored on TensorRT ``ILayer`` objects as a plain string.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed in this PR or one of the higher level ones in the PR stack?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

layer metadata is not necessary for the functionality. Will defer it to a separate PR. It is used to map back TRT engine layers back to torch codes, which would be used by some other high level annotations (which will come later).

default_factory=dict, hash=False, compare=False
)

def register_dynamo_plugin(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is mostly redundant, all we need is kernel PTX -> AOT IMPL

_meta = self.meta_impl
_num_outputs = self.num_outputs

# CuTeDSL @cute.jit functions expect cute.Tensor, not torch.Tensor.
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we need this only if users are using the kernel in TensorRT and not in PyTorch as custom op, I dont expect this to be very common. Even in eager execution of manually fused regions in future PRs, I would think the expectation is that there should be some custom op that we generate a fx pass to insert

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Naren, cutedsl kernels only accept cute.Tensor, so we need to convert from torch.Tensor to it even for capturing the launch params.

# causing a downstream shape mismatch at TRT engine build time. Users with
# shape-sensitive or rank-sensitive meta_impl functions must pass num_outputs
# explicitly to custom_plugin() to avoid this.
def _infer_num_outputs(meta_impl: Callable[..., Any]) -> int:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have this already for custom ops

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to reuse and enhance the existing version.

return False


def _build_meta_impl_desc_fn(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here, see:

with FakeTensorMode(shape_env=shape_env) as fake_mode:

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated to reuse and enhance the existing version.


# Mapping from TRT dtype enum values to the string tokens accepted by
# AutoTuneCombination. Populated lazily only when TRT is available.
_TRT_DTYPE_TOKEN: Dict[Any, str] = (
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just add autotune support to the existing plugin system

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done

BowenFu pushed a commit to BowenFu/Torch-TensorRT that referenced this pull request Apr 3, 2026
…verter infrastructure

Address PR review comment (pytorch#4147): register_dynamo_plugin
created a parallel converter registration path duplicating the logic already
in generate_plugin_converter.

Changes:
- custom_op(impl=None): unchanged — calls generate_plugin + generate_plugin_converter
- custom_op(impl=...): replaces impl.register_dynamo_plugin() with an inline
  converter that calls register_custom_plugin (QDP reg with weight support) +
  lower_to_trt (weight injection via trt.add_constant), registered via the
  same dynamo_tensorrt_converter decorator used by generate_plugin_converter
- _generate_plugin_converter: return tuple for multi-output plugins

All 181 tests pass (126 unit + 55 e2e, 3 xfailed).
BowenFu pushed a commit to BowenFu/Torch-TensorRT that referenced this pull request Apr 3, 2026
…llow-up PR

Address PR review comments pytorch#4147:
- Comment 4: Remove AnnotationMetadata / attach_annotation_metadata /
  get_annotation_metadata from _specs.py and __init__.py — unused in this PR
  (designed for @tta.export_as which is deferred).
- Comment 5: Remove _layer_metadata.py and its set_tta_layer_metadata call
  in lower_custom_plugin_descriptor — diagnostic-only (TRT engine inspector),
  non-fatal by design, out of scope for this PR.

Both modules are preserved in the backup branch and will be reintroduced in
a higher-level diagnostics / export_as PR.
tta.normalize_impl_to_spec(123)


if __name__ == "__main__":
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This test spec seems a little opaque to me. I have no idea what it tests without diving into the code and understand that the spec includes difference specification for each type of kernel. We could make it more straightforward.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Update the file to make it more straightfoward.

pass

normalized = tta.normalize_impl_to_spec(tta.cutedsl(kernel))
self.assertIsInstance(normalized, tta.CustomPluginSpec)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I understand correctly, this won't include the runtime test right?
How can we make sure that the kernel produces correct output with different config?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unit tests does not test accuracy. Please check e2e tests in integration folder.

@BowenFu BowenFu force-pushed the feat/tta-custom-plugin branch from 584f349 to 2b7bd95 Compare April 9, 2026 08:43
…OT integration

Adds the `torch_tensorrt.annotation` (tta) module, which lets users register
custom QDP plugins backed by Triton, CuTile, or CuTe DSL kernels and compile
them AOT into TensorRT engines — no Python interpreter required at inference
time.

Public API surface
------------------
* `tta.triton(launch_fn, configs=...)` → TritonSpec
* `tta.cutile(launch_fn, configs=...)` → CuTileSpec
* `tta.cutedsl(launch_fn, configs=..., arch=...)` → CuTeDSLSpec
* `tta.custom_plugin(spec_or_list, meta_impl=...)` → CustomPluginSpec
* `tta.normalize_impl_to_spec(impl)` — coerce bare specs to CustomPluginSpec

Integration path
----------------
CustomPluginSpec plugs into the existing `trt_plugins.custom_op()` /
`generate_plugin_converter()` infrastructure: it registers the QDP
descriptor, autotune callbacks, and AOT impl, then hands the resulting
converter to the standard Dynamo lowering pipeline.

AOT compilation backends
------------------------
* **Triton** — kernel compiled to PTX via `triton.compile`; PTX header
  patched to match TRT 10.16's expected `.version` before handoff.
* **CuTile** — compiled via `cuda_tile.compile`.
* **CuTe DSL** — compiled via `cutlass.cute.compile` with optional arch
  override.

TvmFfiSpec (fourth backend) is planned but blocked on TVM FFI support
landing in QDP.

Key implementation details
--------------------------
* `FakeTensorMode` used for `meta_impl` shape/dtype inference — avoids
  real tensor allocation during registration.
* `_assign_recorded_grid` shared helper centralises the grid assignment
  step reused by all three AOT backends.
* `_lowering.py` and `_impl.py` removed — both were dead code superseded
  by `_descriptor.py` and the dynamo plugin pipeline.
* 3 `test_gated_ffn_block` tests marked `@expectedFailure`: TRT's
  `mergeMatmulLayers` pass delivers non-contiguous sub-region buffers to
  `IPluginV3::enqueue`, violating the LINEAR stride contract.  The
  `_contiguous` variants (separate inputs, no fusion) pass as workaround.
@BowenFu BowenFu force-pushed the feat/tta-custom-plugin branch from 2b7bd95 to d4083cc Compare April 9, 2026 08:44
…ti-rank probe + real-rank at lowering

_infer_num_outputs now tries three strategies in order:
1. Return-type annotation (-> torch.Tensor, -> tuple[Tensor, Tensor])
2. Multi-rank probing (ranks 1-4, size=2 extents to avoid size-1 traps)
3. Raise TypeError with a clear message pointing to num_outputs= kwarg

_eager_body and _fake_body in auto_register_torch_op no longer depend on
_num_outputs — they count outputs dynamically from meta_impl's return value
at call time, when real-ranked tensors are available.

lower_custom_plugin_descriptor re-infers num_outputs from real trt.ITensor
input ranks (the definitive source of truth) before calling
register_custom_plugin.  Falls back to descriptor.num_outputs only if
that inference itself fails.
Bowen Fu added 16 commits April 9, 2026 09:12
…plugin_desc

_generate_plugin hardcoded single-output: it always built one TensorDesc
and returned (out_desc,) regardless of how many tensors torch_op produced.

Changes:
- _probe_num_outputs(): probes torch_op with rank-1..4 FakeTensors (size=2)
  and neutral scalar defaults to determine output count before registration.
- generate_signature(): accepts num_outputs and sets the @trtp.register
  return annotation to Tuple[TensorDesc * N] so TRT allocates N output ports.
- _generic_plugin_desc(): normalises torch_op output to a list, builds a
  TensorDesc with correct symbolic shape exprs for each output, returns a
  tuple of all of them.
…outputs

num_outputs no longer needs to be stored on CustomPluginSpec or inferred
at custom_plugin() time for the QDP registration path.

Changes:
- _custom_op.py: replace impl.num_outputs (always 1 after removing early
  inference) with _probe_num_outputs(torch_op, schema) — called after
  auto_register_torch_op so torch_op is already registered and probeable.
- _generate_plugin.py: fix _probe_num_outputs to use a list of
  (type, default) pairs instead of a dict — torch._C type singletons are
  not hashable.
- _descriptor.py: remove _infer_num_outputs call from custom_plugin();
  auto_register_torch_op now calls _infer_num_outputs(meta_impl) inline
  (annotation + multi-rank probe) for the torch.library schema only;
  CustomPluginSpec.num_outputs retained as deprecated field (default 1).
…ts_from_callable

Move the inline FakeTensor probe logic from auto_register_torch_op into a
named function _probe_num_outputs_from_callable in _generate_plugin.py so
both the JIT and AOT paths share a single implementation.  Remove the
standalone _count_outputs_from_annotation helper from _descriptor.py —
annotation-based inference is now subsumed by the probe (which handles
unannotated multi-output lambdas correctly without requiring type hints).
…_custom_plugin_descriptor

Replace the inline FakeTensor probe in lower_custom_plugin_descriptor with
a call to _probe_num_outputs_from_callable (enhanced with preferred_rank).
When real trt.ITensor inputs are available their rank is passed as the
preferred probe rank, giving a more accurate result without reinventing the
pattern in the annotation module.
…_symbolic

Extract the ShapeEnv + lambdify core from _generic_plugin_desc into a
module-level _compute_out_descs_symbolic helper, and add _build_symbolic_desc_fn
as a factory wrapper for the TTA AOT path.

- _generic_plugin_desc now delegates to _compute_out_descs_symbolic (same
  behaviour, ~50 lines removed from the closure)
- _build_desc_fn in _descriptor.py calls _build_symbolic_desc_fn instead of
  the old _build_meta_impl_desc_fn, eliminating ~100 lines of duplicated
  FakeTensor / meta-tensor logic from the annotation module
- Remove _build_meta_impl_desc_fn and the four make_meta_tensor_from_td*
  imports from _descriptor.py
…plicates

_descriptor.py maintained its own _TRT_DTYPE_TOKEN and _TRT_FORMAT_TOKEN
dicts that duplicated the dtype_token() and format_token() functions already
in _qdp_utils.py.  Remove the dicts, import the functions, and update both
call sites.  Also extend dtype_token() with int8 support and align its
fallback to "FP32" to match the behaviour of the removed dict.
Remove functions that have no callers anywhere in the codebase:

- _build_attr_params in _descriptor.py: attrs are intentionally excluded
  from the descriptor signature (NumPy 1.25+ workaround); this builder
  was never called.

- collect_allowed_formats_for_io in _qdp_utils.py: the LIMITATION comment
  already documented that _build_autotune_fn ignores its result and always
  passes "LINEAR"; no call sites existed.

- make_meta_tensor_from_td, make_meta_tensor_from_td_symbolic,
  make_td_from_meta, make_td_from_meta_using_template,
  make_td_from_meta_using_template_symbolic, _output_shape_to_shape_expr
  in _qdp_utils.py: all became dead after _build_meta_impl_desc_fn was
  replaced by _build_symbolic_desc_fn from _generate_plugin.py.
…lImplSpec, normalize_impl_to_spec

These three symbols were exported in __init__.py and had dedicated tests
but were never raised or instantiated in production code paths.

- TTADiagnosticError: re-exported but never raised; all errors go through
  QDPRuntimeError. Delete _errors.py and test_errors.py.
- KernelImplSpec: exported but never constructed outside tests;
  normalize_impl_to_spec (its only production reference) was also
  unused outside tests. Remove both and their test coverage.
Adds TestNonLinearFormatsE2E to verify that input_formats/output_formats
fields on TritonSpec are correctly propagated through the QDP autotune
path and result in TRT negotiating a non-LINEAR (CHW32) memory format for
the plugin layer.

Three assertions per test:
- PluginV3 layer present in the compiled engine
- Extra (reformat) layers present beyond the plugin itself, confirming
  TRT inserted LINEAR<->CHW32 conversion nodes around the plugin
- Output values correct despite format repacking (max error = 0)

Also adds:
- _triton_flat_relu_kernel / _triton_launch_flat_relu: flat-indexed
  layout-agnostic elementwise ReLU, registered as relu_chw32 with
  input_formats=[trt.TensorFormat.CHW32]
- _engine_all_layer_names(): helper returning all ONELINE layer names
  from the TRT engine inspector

Also adds .venv/ to .gitignore to prevent accidental commits of
the Python virtual environment.
_generate_plugin.py:
- _compute_out_descs_symbolic: loop over output.ndim not out_desc.ndim
  (out_desc mirrors the first input, so ndim would mismatch for ops that
  change dimensionality)

_aot/_cutedsl.py:
- Replace mkdtemp with TemporaryDirectory to clean up dump files on exit
- Remove unused _as_symint32 import
- Fix return type annotation: first element is str, not bytes

_aot/_triton.py:
- Fix return type annotation: first element is str, not bytes

_aot/_cutile.py:
- Drop unused param_binding_indices (only scalar_symints is consumed)
- Remove dead extra_args = SymIntExprs(0) assignment (overwritten later)

_descriptor.py:
- Remove redundant local `import inspect` (already imported at module level)

__init__.py:
- Fix API example: cutile() has no arch param; remove it from the docstring

_specs.py:
- Tighten factory function signatures: Optional[Any] -> Optional[Sequence[int]]
  for input_formats/output_formats to match the dataclass field types
Delete test_recorders.py (internal recorder details covered implicitly
by integration tests) and trim test_specs.py from 31 tests to 11,
keeping only user-facing validation rules and op_name invariants that
the integration suite does not exercise.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

component: api [Python] Issues re: Python API component: conversion Issues re: Conversion stage component: core Issues re: The core compiler component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: tests Issues re: Tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants