Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion core/runtime/BUILD
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
load("@rules_cc//cc:defs.bzl", "cc_library")
load("@rules_pkg//:pkg.bzl", "pkg_tar")
load("@rules_pkg//pkg:mappings.bzl", "pkg_files")

package(default_visibility = ["//visibility:public"])

config_setting(
Expand Down Expand Up @@ -58,6 +59,13 @@ config_setting(
],
)

# Matches builds invoked with `--define target_lang=python`; used by the
# runtime cc_library below to pull in the pybind11 dependency only when
# the Python-enabled runtime is being built.
config_setting(
name = "python_core",
values = {
"define": "target_lang=python",
},
)

cc_library(
name = "runtime",
srcs = [
Expand Down Expand Up @@ -96,6 +104,9 @@ cc_library(
":use_torch_whl": ["@torch_whl//:libtorch"],
":windows": ["@libtorch_win//:libtorch"],
"//conditions:default": ["@libtorch"],
}) + select({
":python_core": ["@libtorch//:pybind11"],
"//conditions:default": [],
}),
alwayslink = True,
)
Expand All @@ -121,6 +132,6 @@ pkg_tar(
# Packages the runtime public headers under the installed include prefix.
# Fix: the `visibility` attribute appeared twice (once before and once after
# `prefix`), which is a duplicate-keyword error in Starlark; keep a single
# occurrence, placed after `prefix` per the surrounding file's convention.
pkg_files(
name = "include_pkg_files",
srcs = [":include_files"],
prefix = "include/torch_tensorrt/core/runtime/",
visibility = ["//visibility:public"],
)
28 changes: 28 additions & 0 deletions core/runtime/register_jit_hooks.cpp
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
#include <codecvt>

#include <pybind11/pybind11.h>
#include "core/runtime/Platform.h"
#include "core/runtime/runtime.h"
#include "core/util/macros.h"

namespace py = pybind11;

namespace torch::jit {
// Minimal TorchBind holder that carries an arbitrary Python object through
// the TorchScript type system (registered as
// __torch__.torch.classes.aten.OpaqueObject in the schemas below). Here the
// payload is a py::capsule wrapping a heap-allocated
// c10::intrusive_ptr<TRTEngine> (see _wrap_engine in this file).
// NOTE(review): payload_ is a py::object, so destroying an OpaqueObject
// decrements a CPython refcount — presumably every destruction site must
// hold the GIL; confirm, since intrusive_ptr release can happen on
// arbitrary threads.
struct OpaqueObject : public CustomClassHolder {
OpaqueObject(py::object payload) : payload_(std::move(payload)) {}
py::object payload_;
};
} // namespace torch::jit

namespace torch_tensorrt {
namespace core {
namespace runtime {
Expand Down Expand Up @@ -122,6 +132,8 @@ static auto TORCHTRT_UNUSED TRTEngineTSRegistrtion =

TORCH_LIBRARY(tensorrt, m) {
m.def("execute_engine(Tensor[] input_tensors, __torch__.torch.classes.tensorrt.Engine engine) -> Tensor[]");
m.def("execute_engine.opaque(Tensor[] input_tensors, __torch__.torch.classes.aten.OpaqueObject engine) -> Tensor[]");
m.def("_wrap_engine(__torch__.torch.classes.tensorrt.Engine engine) -> __torch__.torch.classes.aten.OpaqueObject");
m.def("SERIALIZED_ENGINE_BINDING_DELIM", []() -> std::string { return std::string(1, TRTEngine::BINDING_DELIM); });
m.def("SERIALIZED_RT_DEVICE_DELIM", []() -> std::string { return DEVICE_INFO_DELIM; });
m.def("ABI_VERSION", []() -> std::string { return ABI_VERSION; });
Expand Down Expand Up @@ -174,6 +186,22 @@ TORCH_LIBRARY(tensorrt, m) {

// Kernel registrations for the `tensorrt` library on the
// CompositeExplicitAutograd key: the plain execute_engine entry point plus
// the opaque-object wrap/unwrap pair used by the Python runtime path.
TORCH_LIBRARY_IMPL(tensorrt, CompositeExplicitAutograd, m) {
m.impl("execute_engine", execute_engine);
// Opaque overload: recover the TRTEngine from the py::capsule stored in
// the OpaqueObject payload, then delegate to the regular execute_engine.
m.impl(
"execute_engine.opaque",
[](std::vector<at::Tensor> inputs, c10::intrusive_ptr<torch::jit::OpaqueObject> opaque_engine) {
// GIL is required to touch payload_ (a py::object) and the capsule.
py::gil_scoped_acquire gil;
auto capsule = py::cast<py::capsule>(opaque_engine->payload_);
// The capsule pointer is a heap-allocated intrusive_ptr<TRTEngine>
// created by _wrap_engine below; dereference to get a strong ref.
auto* engine_ptr = static_cast<c10::intrusive_ptr<TRTEngine>*>(capsule.get_pointer());
return execute_engine(std::move(inputs), *engine_ptr);
});
// Wraps a TRTEngine into an OpaqueObject. Ownership: the intrusive_ptr is
// copied onto the heap and owned by the capsule; the capsule destructor
// (named "TRTEngine") deletes it, releasing the engine reference.
// NOTE(review): that destructor runs when the capsule's refcount drops —
// presumably under the GIL since it is CPython-driven, but confirm for
// teardown-order edge cases (interpreter finalization).
m.impl("_wrap_engine", [](c10::intrusive_ptr<TRTEngine> engine) -> c10::intrusive_ptr<torch::jit::OpaqueObject> {
py::gil_scoped_acquire gil;
auto* holder = new c10::intrusive_ptr<TRTEngine>(std::move(engine));
py::capsule capsule(holder, "TRTEngine", [](PyObject* o) {
delete static_cast<c10::intrusive_ptr<TRTEngine>*>(PyCapsule_GetPointer(o, "TRTEngine"));
});
return c10::make_intrusive<torch::jit::OpaqueObject>(py::object(std::move(capsule)));
});
}

} // namespace
Expand Down
5 changes: 2 additions & 3 deletions docsrc/contributors/complex_number_support.rst
Original file line number Diff line number Diff line change
Expand Up @@ -128,9 +128,8 @@ runtime modules handle the conversion:
* ``prepare_inputs`` (``dynamo/utils.py``) — builds the ``Input`` spec with the
``view_as_real`` shape/dtype but retains the original complex tensor in
``inp.torch_tensor`` for tracing.
* ``_PythonTorchTensorRTModule.forward`` — applies ``torch.view_as_real(i).contiguous()``
for each complex input before feeding it to the engine.
* ``_TorchTensorRTModule.forward`` — same ``view_as_real`` conversion.
* ``TorchTensorRTModule.forward`` — applies ``torch.view_as_real(i).contiguous()``
for each complex input before feeding tensors to ``execute_engine`` / ``execute_engine_python``.

Key Implementation Invariants
-------------------------------
Expand Down
4 changes: 2 additions & 2 deletions docsrc/contributors/cuda_graphs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ Subsequent inference launches the instantiated graph instead of calling
Graph Storage
^^^^^^^^^^^^^

Each runtime module (both C++ ``TorchTensorRTModule`` and Python
``PythonTorchTensorRTModule``) stores a ``cudaGraphExec_t`` instance. When
``TorchTensorRTModule`` (C++ or Python execution path) may record a CUDA graph for
engine execution when CUDA graphs are enabled at runtime. When
``use_cuda_graph=True`` is set at compile time the runtime records one graph
per engine for the first input shape encountered.

Expand Down
12 changes: 7 additions & 5 deletions docsrc/debugging/troubleshooting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,10 @@ Runtime Errors
the engine. Upgrade TRT or rebuild with ``version_compatible=True``.
* The GPU compute capability is lower than on the build machine. Rebuild with
``hardware_compatible=True`` (requires Ampere or newer).
* The ``.ep`` file was generated with ``use_python_runtime=True`` which is not
serializable. Rebuild with the default C++ runtime.
* The ``.ep`` export path does not support your compiled module layout (e.g. mixed
Python-runtime subgraphs in a specific exporter version). Try the default C++ path
at compile time or use ``torch_tensorrt`` module save/load APIs that preserve
``TorchTensorRTModule`` state.

**Shape mismatch at runtime / "Invalid input shape"**

Expand All @@ -153,9 +155,9 @@ Runtime Errors
The model contains data-dependent-shape ops (``nonzero``, ``unique``,
``masked_select``, etc.) which require TRT's output allocator.

* Use ``PythonTorchTensorRTModule`` (``use_python_runtime=True``) — it
activates the dynamic output allocator automatically via
``requires_output_allocator=True``.
* Use :func:`~torch_tensorrt.runtime.set_runtime_backend` with ``"python"`` or use a module with
``requires_output_allocator=True`` so the runtime can use TRT's output allocator
on the Python execution path when needed.
* See :ref:`cuda_graphs` for ``DynamicOutputAllocator`` details.

----
Expand Down
16 changes: 16 additions & 0 deletions docsrc/py_api/runtime.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,29 @@ Functions

.. autofunction:: enable_output_allocator

Runtime backend selection
-------------------------

.. autofunction:: torch_tensorrt.runtime.get_runtime_backend

.. autofunction:: torch_tensorrt.runtime.set_runtime_backend

Classes
---------

.. autoclass:: TorchTensorRTModule
:members:
:special-members: __init__
:show-inheritance:

Single runtime module for TensorRT engines. Dispatches to the C++ or Python execution
implementation based on :func:`~torch_tensorrt.runtime.get_runtime_backend` /
:func:`~torch_tensorrt.runtime.set_runtime_backend`. See :ref:`python_runtime`.

.. autoclass:: PythonTorchTensorRTModule
:members:
:special-members: __init__
:show-inheritance:

Subclass of ``TorchTensorRTModule`` that **pins** the Python engine path. Prefer
``TorchTensorRTModule`` plus compile flags unless you need this guarantee. See :ref:`python_runtime`.
4 changes: 2 additions & 2 deletions docsrc/tutorials/runtime_opt/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@ Runtime Optimization
=====================

Optimize inference throughput and latency: CUDA Graphs for kernel-replay,
pre-allocated output buffers, and the Python runtime module.
pre-allocated output buffers, and choosing the Python vs C++ TRT execution path.

.. toctree::
:maxdepth: 1

cuda_graphs
Example: Torch Export with Cudagraphs <../_rendered_examples/dynamo/torch_export_cudagraphs>
Example: Pre-allocated output buffer <../_rendered_examples/dynamo/pre_allocated_output_example>
python_runtime
Python vs C++ runtime <python_runtime>
152 changes: 77 additions & 75 deletions docsrc/tutorials/runtime_opt/python_runtime.rst
Original file line number Diff line number Diff line change
@@ -1,96 +1,103 @@
.. _python_runtime:

Python Runtime
==============
Python vs C++ runtime
=====================

Torch-TensorRT provides two runtime backends for executing compiled TRT engines
inside a PyTorch graph:
Torch-TensorRT uses a single module type, :class:`~torch_tensorrt.runtime.TorchTensorRTModule`,
to run TensorRT engines inside PyTorch. The **execution path** (which code actually drives
``execute_async``) is selected at runtime:

* **C++ runtime** (default) — ``TorchTensorRTModule`` backed by a C++ TorchBind class.
Fully serializable, supports CUDAGraphs, multi-device safe.
* **Python runtime** — ``PythonTorchTensorRTModule`` backed entirely by the TRT Python
API. Simpler to instrument for debugging but **not serializable** to
``ExportedProgram``.
* **C++ path (default)** — ``torch.classes.tensorrt.Engine`` and ``torch.ops.tensorrt.execute_engine``.
Preferred for production when the Torch-TensorRT C++ extension is available: TorchScript-friendly,
and integrates with the full C++ runtime stack.
* **Python path** — uses the internal ``TRTEngine`` plus
``torch.ops.tensorrt.execute_engine`` (registered from Python). Useful when the C++
extension is absent, or when you want easier Python-level debugging and instrumentation.

:class:`~torch_tensorrt.runtime.PythonTorchTensorRTModule` is a **thin subclass** of
``TorchTensorRTModule`` that **pins** the Python path (same constructor and behavior, but always
resolves to the Python engine). Prefer ``TorchTensorRTModule`` plus the global backend APIs below
when you do not need that pin.

----

When to Use the Python Runtime
--------------------------------
When to use the Python path
---------------------------

Use ``use_python_runtime=True`` when:
Use :func:`~torch_tensorrt.runtime.set_runtime_backend` (typically as a context manager) when:

* You need to run on a machine where the C++ Torch-TensorRT library is not installed
(e.g., a minimal CI container with only the Python wheel).
* You want to attach Python-level callbacks to the engine execution (via
:ref:`observer`) for debugging or profiling without building the C++ extension.
* You are debugging a conversion issue and want to step through TRT execution in Python.
* The C++ Torch-TensorRT library is not installed (e.g. a minimal environment with only the Python pieces).
* You want Python-level hooks (e.g. :ref:`observer`) without relying on the C++ extension.
* You are debugging conversion or execution and want to break inside the Python TRT wrapper.

Use the default C++ runtime in all other cases, especially:
Prefer the C++ path when:

* When saving a compiled module to disk (``torch_tensorrt.save()``).
* When using CUDAGraphs for low-latency inference.
* In production deployments.
* You rely on the default Torch-TensorRT deployment story and maximum parity with TorchScript export.
* You use whole-graph CUDAGraph wrappers that assume the C++ runtime (see :ref:`cuda_graphs`).

----

Enabling the Python Runtime
-----------------------------
Enabling the Python path
------------------------

**Process-wide default (context manager)**

.. code-block:: python

import torch_tensorrt
import torch_tensorrt as tt

trt_gm = torch_tensorrt.dynamo.compile(
exported_program,
arg_inputs=inputs,
use_python_runtime=True,
)
with tt.runtime.set_runtime_backend("python"):
trt_gm = tt.dynamo.compile(exported_program, inputs)

Or via ``torch.compile``:
**``torch.compile``** (same context manager around compile / first run)

.. code-block:: python

trt_model = torch.compile(
model,
backend="tensorrt",
options={"use_python_runtime": True},
)
import torch_tensorrt as tt

----
with tt.runtime.set_runtime_backend("python"):
trt_model = torch.compile(model, backend="tensorrt", options={})

Limitations
-----------
The context manager does **not** replace :class:`~torch_tensorrt.runtime.PythonTorchTensorRTModule`,
which always requests the Python path via a class-level pin.

* **Not serializable**: ``PythonTorchTensorRTModule`` cannot be saved via
``torch_tensorrt.save()`` as an ``ExportedProgram`` or loaded back. The module is
Python-only in-process.
----

.. code-block:: python
Serialization
---------------

# This will raise an error with use_python_runtime=True:
torch_tensorrt.save(trt_gm, "model.ep", arg_inputs=inputs)
Module state records which backend was used (``runtime_backend`` in packed metadata). After load,
``TorchTensorRTModule`` reconstructs either the C++ engine or the Python engine wrapper
as appropriate. Some **export** workflows (e.g. certain ``ExportedProgram`` save paths) may still
assume a C++-only graph; validate your deployment path if you mix Python execution with AOT export.

* **No C++ deployment**: The compiled module cannot be exported to AOTInductor or used
in a C++ application without re-compiling with the C++ runtime.
----

* **CUDAGraphs**: Whole-graph CUDAGraphs work with the Python runtime, but the
per-submodule CUDAGraph recording in ``CudaGraphsTorchTensorRTModule`` is
only available with the C++ runtime.
Limitations
-----------

* **C++ deployment**: A module that executed on the Python path still needs TensorRT and the
Torch-TensorRT Python pieces available in-process unless you recompile targeting the C++ path.
* **CUDAGraphs**: Whole-graph CUDAGraph wrappers may assume the C++ runtime for some configurations;
see :ref:`cuda_graphs`.
* **Explicit allocator engines**: Engines with data-dependent outputs may set
``requires_output_allocator=True``; the unified module supports the output-allocator execution
mode on the Python path. See :ref:`cuda_graphs` for interaction with CUDA graphs.

----

``PythonTorchTensorRTModule`` Direct Instantiation
----------------------------------------------------
``PythonTorchTensorRTModule`` direct instantiation
--------------------------------------------------

You can instantiate ``PythonTorchTensorRTModule`` directly from raw engine bytes,
for example when integrating a TRT engine built outside of Torch-TensorRT:
You can instantiate :class:`~torch_tensorrt.runtime.PythonTorchTensorRTModule` from raw engine bytes
when you need a **guaranteed** Python execution path (e.g. integrating an engine built outside
Torch-TensorRT):

.. code-block:: python

from torch_tensorrt.dynamo.runtime import PythonTorchTensorRTModule
from torch_tensorrt.dynamo._settings import CompilationSettings

# Load raw engine bytes (e.g., from trtexec output or torch_tensorrt.dynamo.convert_*)
with open("model.engine", "rb") as f:
engine_bytes = f.read()

Expand All @@ -104,37 +111,32 @@ for example when integrating a TRT engine built outside of Torch-TensorRT:

output = module(torch.randn(1, 3, 224, 224).cuda())

**Constructor arguments:**
**Constructor arguments** (same as ``TorchTensorRTModule``):

``serialized_engine`` (``bytes``)
The raw serialized TRT engine bytes.

``input_binding_names`` (``List[str]``)
TRT input binding names in the order they are passed to ``forward()``.
Raw serialized TRT engine.

``output_binding_names`` (``List[str]``)
TRT output binding names in the order they should be returned.
``input_binding_names`` / ``output_binding_names`` (``List[str]``)
Binding names in ``forward`` order.

``name`` (``str``, optional)
Human-readable name for the module (used in logging).
Name for logging and serialization.

``settings`` (``CompilationSettings``, optional)
The compilation settings used to build the engine. Used to determine device
placement and other runtime behaviors.
``settings`` (:class:`~torch_tensorrt.dynamo._settings.CompilationSettings`, optional)
Device and runtime options (must match how the engine was built).

``weight_name_map`` (``dict``, optional)
Mapping of TRT weight names to PyTorch state dict names. Required for refit
support via :func:`~torch_tensorrt.dynamo.refit_module_weights`.
For refit workflows; see :func:`~torch_tensorrt.dynamo.refit_module_weights`.

``requires_output_allocator`` (``bool``, default ``False``)
Set to ``True`` if the engine contains data-dependent-shape ops (``nonzero``,
``unique``, etc.) that require TRT's output allocator.
``requires_output_allocator`` (``bool``)
Set ``True`` for data-dependent-shape ops that need TRT's output allocator.

----

Runtime Selection Logic
------------------------
Runtime selection summary
-------------------------

When ``use_python_runtime`` is ``None`` (auto-select), Torch-TensorRT tries to import
the C++ TorchBind class. If the C++ extension is not available it silently falls back to
the Python runtime. Pass ``True`` or ``False`` to force a specific runtime.
* :func:`~torch_tensorrt.runtime.get_runtime_backend` / :func:`~torch_tensorrt.runtime.set_runtime_backend`
— process default for newly created ``TorchTensorRTModule`` instances (unless a subclass pins a backend).
Use ``set_runtime_backend`` as a context manager to scope C++ vs Python for compile and forward.
* If the C++ extension is **not** built, only the Python path is available.
Loading
Loading