Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 87 additions & 3 deletions docs/learn/export.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ The `export()` method accepts several parameters to customize the export process
| `force` | `False` | Deprecated and ignored. |
| `shape` | `None` | Input shape as tuple `(height, width)`. Must be divisible by 14. If not provided, uses the model's default resolution. |
| `batch_size` | `1` | Batch size for the exported model. |
| `tensorrt` | `False` | When `True`, convert the ONNX model to a TensorRT `.engine` file. Requires TensorRT (`trtexec`) to be installed. |

## Advanced Export Examples

Expand Down Expand Up @@ -118,6 +119,20 @@ If you want lower latency on NVIDIA GPUs, you can convert the exported ONNX mode
- Install TensorRT (`trtexec` must be available in your `PATH`)
- Export an ONNX model first (for example: `output/inference_model.onnx`)

### Export Directly to TensorRT

Pass `tensorrt=True` to `export()` to export ONNX and convert to a TensorRT engine in one step:

```python
from rfdetr import RFDETRMedium

model = RFDETRMedium(pretrain_weights="<path/to/checkpoint.pth>")

model.export(tensorrt=True)
```

This exports `output/inference_model.onnx` first and then produces `output/inference_model.engine`.

### Python API Conversion

```python
Expand All @@ -131,10 +146,78 @@ args = Namespace(
dry_run=False,
)

trtexec("output/inference_model.onnx", args)
engine_path = trtexec("output/inference_model.onnx", args)
```

`trtexec` returns the path to the generated `.engine` file. If `profile=True`, it also writes an Nsight Systems report (`.nsys-rep`).

## Run Inference with `inference-models`

[`inference-models`](https://github.com/roboflow/inference/tree/main/inference_models) is the
recommended library for running RF-DETR inference. It supports multiple backends — PyTorch,
ONNX, and TensorRT — with automatic backend selection and a unified API.

### Installation

```bash
# CPU / PyTorch only
pip install inference-models

# With TensorRT support (NVIDIA GPU required)
pip install "inference-models[trt-cu12]" # CUDA 12.x
```

See the [inference-models installation guide](https://inference-models.roboflow.com/getting-started/installation/)
for all installation options including Jetson and CUDA 11.x.

### Load a Pre-trained RF-DETR Model

```python
import cv2
from inference_models import AutoModel

# Automatically selects the best available backend for your environment
model = AutoModel.from_pretrained("rfdetr-base")

image = cv2.imread("image.jpg")
predictions = model(image)

# Convert to supervision Detections
detections = predictions[0].to_supervision()
print(detections)
```

### Load a Local RF-DETR Checkpoint

```python
import cv2
from inference_models import AutoModel

# Load from a local .pth checkpoint (same file used by rfdetr for training)
model = AutoModel.from_pretrained(
"/path/to/checkpoint.pth",
model_type="rfdetr-base", # specify the architecture variant
)

image = cv2.imread("image.jpg")
predictions = model(image)
```

### Force TensorRT Backend

```python
import cv2
from inference_models import AutoModel, BackendType

# Explicitly request TensorRT — requires TRT to be installed
model = AutoModel.from_pretrained("rfdetr-base", backend=BackendType.TRT)

image = cv2.imread("image.jpg")
predictions = model(image)
```

`AutoModel.from_pretrained` accepts `backend="onnx"`, `backend="torch"`, or
`backend="trt"` to override automatic backend selection.

## Using the Exported Model

Expand Down Expand Up @@ -174,5 +257,6 @@ boxes, labels = outputs
After exporting your model, you may want to:

- [Deploy to Roboflow](deploy.md) for cloud-based inference and workflow integration
- Use the ONNX model with TensorRT for optimized GPU inference
- Use [`inference-models`](https://github.com/roboflow/inference/tree/main/inference_models) for
multi-backend inference (PyTorch, ONNX, TensorRT) with automatic backend selection
- Integrate with edge deployment frameworks like ONNX Runtime or OpenVINO
16 changes: 15 additions & 1 deletion src/rfdetr/detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,9 +554,10 @@ def export(
batch_size: int = 1,
dynamic_batch: bool = False,
patch_size: int | None = None,
tensorrt: bool = False,
**kwargs,
) -> None:
"""Export the trained model to ONNX format.
"""Export the trained model to ONNX format, and optionally to TensorRT.

See the `ONNX export documentation <https://rfdetr.roboflow.com/learn/export/>`_
for more information.
Expand All @@ -578,6 +579,9 @@ def export(
``model_config.patch_size`` (typically 14 or 16). When provided
explicitly it must match the instantiated model's patch size.
Shape divisibility is validated against ``patch_size * num_windows``.
tensorrt: When ``True``, convert the exported ONNX model to a TensorRT
``.engine`` file using ``trtexec``. Requires TensorRT to be installed
and ``trtexec`` available in ``PATH``.
**kwargs: Additional keyword arguments forwarded to export_onnx.

"""
Expand Down Expand Up @@ -666,6 +670,16 @@ def export(

logger.info(f"Successfully exported ONNX model to: {output_file}")

if tensorrt:
from argparse import Namespace

from rfdetr.export.tensorrt import trtexec

logger.info("Converting ONNX model to TensorRT engine")
trt_args = Namespace(verbose=verbose, profile=False, dry_run=False)
engine_file = trtexec(output_file, trt_args)
logger.info(f"Successfully exported TensorRT engine to: {engine_file}")

logger.info("ONNX export completed successfully")
self.model.model = self.model.model.to(device)

Expand Down
22 changes: 21 additions & 1 deletion src/rfdetr/export/tensorrt.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,15 @@

"""
TensorRT export helpers: trtexec invocation and output parsing.

For TensorRT inference, use the `inference-models` library which provides
multi-backend RF-DETR support (PyTorch, ONNX, TensorRT) with automatic backend
selection:

from inference_models import AutoModel
model = AutoModel.from_pretrained("rfdetr-base")

See https://github.com/roboflow/inference/tree/main/inference_models for details.
"""

import os
Expand All @@ -32,7 +41,17 @@ def run_command_shell(command, dry_run: bool = False) -> subprocess.CompletedPro
raise


def trtexec(onnx_dir: str, args) -> None:
def trtexec(onnx_dir: str, args) -> str:
"""Convert an ONNX model to a TensorRT engine using trtexec.

Args:
onnx_dir: Path to the input ONNX file.
args: Namespace with ``verbose`` (bool), ``profile`` (bool), and
``dry_run`` (bool) attributes.

Returns:
Path to the generated ``.engine`` file.
"""
engine_dir = onnx_dir.replace(".onnx", ".engine")

# Base trtexec command
Expand All @@ -59,6 +78,7 @@ def trtexec(onnx_dir: str, args) -> None:

output = run_command_shell(command, args.dry_run)
parse_trtexec_output(output.stdout)
return engine_dir


def parse_trtexec_output(output_text):
Expand Down
19 changes: 19 additions & 0 deletions tests/export/test_tensorrt_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,22 @@ def _fake_run(command, shell, capture_output, text, check):

assert result.returncode == 0
assert any("CUDA_VISIBLE_DEVICES=" in message for message in logged_messages)


def test_trtexec_returns_engine_path(monkeypatch) -> None:
    """trtexec should return the .engine file path derived from the .onnx path."""
    from argparse import Namespace

    # Stub the shell invocation so no real trtexec binary is required.
    completed = subprocess.CompletedProcess("cmd", 0, stdout="", stderr="")
    monkeypatch.setattr(
        tensorrt_export,
        "run_command_shell",
        lambda command, dry_run: completed,
    )
    # Output parsing is irrelevant to this test; make it a no-op.
    monkeypatch.setattr(tensorrt_export, "parse_trtexec_output", lambda text: {})

    trt_args = Namespace(verbose=False, profile=False, dry_run=False)
    engine_path = tensorrt_export.trtexec("output/inference_model.onnx", trt_args)

    assert engine_path == "output/inference_model.engine"
74 changes: 74 additions & 0 deletions tests/models/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,80 @@ def test_dynamic_batch_forwards_dynamic_axes(
f"expected keys {expected_names}, got {set(dynamic_axes.keys())}"
)

def test_tensorrt_flag_calls_trtexec(self, output_dir: str) -> None:
    """When tensorrt=True, main() must call trtexec with the ONNX output path."""
    recorded: list[str] = []

    def record_and_convert(onnx_path: str, args) -> str:
        # Record the call and return the path trtexec would produce.
        recorded.append(onnx_path)
        return onnx_path.replace(".onnx", ".engine")

    args = self._make_args(output_dir=output_dir, tensorrt=True)
    onnx_output = str(args.output_dir) + "/inference_model.onnx"

    # Minimal stand-in for the detection model: every sub-module reports an
    # empty parameter list and the forward pass yields zeroed predictions.
    mock_model = MagicMock()
    backbone_item = mock_model.backbone.__getitem__.return_value
    for module in (
        mock_model,
        mock_model.backbone,
        backbone_item.projector,
        backbone_item.encoder,
        mock_model.transformer,
    ):
        module.parameters.return_value = []
    for passthrough in ("to", "cpu", "eval"):
        getattr(mock_model, passthrough).return_value = mock_model
    mock_model.return_value = {
        "pred_boxes": torch.zeros(1, 300, 4),
        "pred_logits": torch.zeros(1, 300, 90),
    }
    mock_tensor = MagicMock()
    mock_tensor.to.return_value = mock_tensor
    mock_tensor.cpu.return_value = mock_tensor

    with (
        patch.object(_cli_export_module, "build_model", return_value=(mock_model, MagicMock(), MagicMock())),
        patch.object(_cli_export_module, "make_infer_image", return_value=mock_tensor),
        patch.object(_cli_export_module, "export_onnx", return_value=onnx_output),
        patch.object(_cli_export_module, "trtexec", side_effect=record_and_convert),
        patch.object(_cli_export_module, "get_rank", return_value=0),
    ):
        _cli_export_module.main(args)

    assert len(recorded) == 1, "trtexec should be called exactly once"
    assert recorded[0] == onnx_output, f"trtexec called with {recorded[0]!r}, expected {onnx_output!r}"

def test_tensorrt_false_does_not_call_trtexec(self, output_dir: str) -> None:
    """When tensorrt=False (default), main() must not call trtexec."""
    recorded: list[str] = []

    def record_and_convert(onnx_path: str, args) -> str:
        # Any call lands here; the test asserts this list stays empty.
        recorded.append(onnx_path)
        return onnx_path.replace(".onnx", ".engine")

    args = self._make_args(output_dir=output_dir, tensorrt=False)

    # Minimal stand-in for the detection model: every sub-module reports an
    # empty parameter list; to/cpu/eval all return the mock itself.
    mock_model = MagicMock()
    backbone_item = mock_model.backbone.__getitem__.return_value
    for module in (
        mock_model,
        mock_model.backbone,
        backbone_item.projector,
        backbone_item.encoder,
        mock_model.transformer,
    ):
        module.parameters.return_value = []
    for passthrough in ("to", "cpu", "eval"):
        getattr(mock_model, passthrough).return_value = mock_model
    mock_tensor = MagicMock()
    mock_tensor.to.return_value = mock_tensor
    mock_tensor.cpu.return_value = mock_tensor

    with (
        patch.object(_cli_export_module, "build_model", return_value=(mock_model, MagicMock(), MagicMock())),
        patch.object(_cli_export_module, "make_infer_image", return_value=mock_tensor),
        patch.object(_cli_export_module, "export_onnx", return_value=str(args.output_dir) + "/model.onnx"),
        patch.object(_cli_export_module, "trtexec", side_effect=record_and_convert),
        patch.object(_cli_export_module, "get_rank", return_value=0),
    ):
        _cli_export_module.main(args)

    assert len(recorded) == 0, "trtexec must not be called when tensorrt=False"


class TestExportPatchSize:
"""RFDETR.export() patch_size validation and shape-divisibility tests."""
Expand Down