Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/rfdetr/detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -721,6 +721,7 @@ def export(
backbone_only=backbone_only,
verbose=verbose,
opset_version=opset_version,
variant_name=getattr(self, "size", None),
)

logger.info(f"Successfully exported ONNX model to: {output_file}")
Expand Down
9 changes: 8 additions & 1 deletion src/rfdetr/export/_onnx/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ def export_onnx(
backbone_only: bool = False,
verbose: bool = True,
opset_version: int = 17,
variant_name: str | None = None,
) -> str:
"""Export a model to ONNX.

Expand All @@ -74,11 +75,17 @@ def export_onnx(
backbone_only: Whether to export backbone-only graph naming.
verbose: Whether ONNX exporter should emit verbose logs.
opset_version: ONNX opset version.
variant_name: Model variant identifier (e.g. ``"rfdetr-medium"``).
When provided, the exported file is named ``{variant_name}.onnx``
instead of the generic ``inference_model.onnx``.
Comment on lines +79 to +80
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variant_name docstring says the export will be named {variant_name}.onnx, but when backbone_only=True the actual filename becomes {variant_name}-backbone.onnx. Update the docstring to reflect the backbone-only naming behavior so callers aren’t surprised.

Suggested change
When provided, the exported file is named ``{variant_name}.onnx``
instead of the generic ``inference_model.onnx``.
When provided, the exported file is named ``{variant_name}.onnx`` or
``{variant_name}-backbone.onnx`` (when ``backbone_only=True``) instead
of the generic ``inference_model.onnx`` or ``backbone_model.onnx``.

Copilot uses AI. Check for mistakes.

Returns:
Path to the exported ONNX model.
"""
export_name = "backbone_model" if backbone_only else "inference_model"
if variant_name:
export_name = f"{variant_name}-backbone" if backbone_only else variant_name
else:
export_name = "backbone_model" if backbone_only else "inference_model"
output_file = os.path.join(output_dir, f"{export_name}.onnx")

Comment on lines +85 to 90
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

variant_name is interpolated directly into the filename. If it contains path separators (e.g. "foo/bar"), an absolute path (e.g. "/tmp/x"), or a drive prefix on Windows, os.path.join(output_dir, ...) can write outside output_dir or fail due to missing intermediate directories. Consider sanitizing to a safe stem (e.g. Path(variant_name).name/.stem, stripping a trailing .onnx) and rejecting values that aren’t simple filenames.

Copilot uses AI. Check for mistakes.
# Prepare model for export
Expand Down
1 change: 1 addition & 0 deletions src/rfdetr/export/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def main(args):
backbone_only=args.backbone_only,
verbose=args.verbose,
opset_version=args.opset_version,
variant_name=getattr(args, "variant_name", None),
)

if args.simplify:
Expand Down
176 changes: 176 additions & 0 deletions tests/models/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def __call__(self, *_args, **_kwargs):
resolution=14,
),
model_config=types.SimpleNamespace(segmentation_head=False),
size=None,
)

export_called: dict[str, bool] = {"value": False}
Expand Down Expand Up @@ -208,6 +209,7 @@ def __call__(self, *_args, **_kwargs):
model = types.SimpleNamespace(
model=types.SimpleNamespace(model=_DummyCoreModel(), device="cpu", resolution=14),
model_config=types.SimpleNamespace(segmentation_head=segmentation_head),
size=None,
)

captured: dict = {}
Expand Down Expand Up @@ -610,6 +612,7 @@ def __call__(self, *_a, **_kw):
patch_size=patch_size,
num_windows=num_windows,
),
size=None,
)

def _fake_make_infer_image(*_a, **_kw):
Expand Down Expand Up @@ -742,3 +745,176 @@ def test_make_infer_image_produces_correct_rectangular_shape() -> None:
h, w, b = 112, 224, 2
tensor = make_infer_image(infer_dir=None, shape=(h, w), batch_size=b, device="cpu")
assert tensor.shape == (b, 3, h, w), f"Expected shape ({b}, 3, {h}, {w}), got {tensor.shape}"


# ---------------------------------------------------------------------------
# ONNX export variant naming (#issue)
Copy link

Copilot AI Apr 2, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The section header comment says # ONNX export variant naming (#issue) which looks like a placeholder. Replace it with the actual issue reference (e.g. #905) or remove the parenthetical to avoid stale/incorrect references in the test file.

Suggested change
# ONNX export variant naming (#issue)
# ONNX export variant naming

Copilot uses AI. Check for mistakes.
# ---------------------------------------------------------------------------


class TestExportOnnxVariantNaming:
"""Verify that export_onnx uses variant_name in the output filename."""

def test_variant_name_in_filename(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
"""When variant_name is provided, the ONNX file is named after the variant."""
captured: dict = {}

def _fake_onnx_export(*args, **kwargs) -> None:
captured["output_file"] = args[2] # 3rd positional arg is output_file

monkeypatch.setattr(_cli_export_module.torch.onnx, "export", _fake_onnx_export)

_cli_export_module.export_onnx(
output_dir=str(tmp_path),
model=torch.nn.Identity(),
input_names=["input"],
input_tensors=torch.randn(1, 3, 8, 8),
output_names=["dets"],
dynamic_axes=None,
verbose=False,
variant_name="rfdetr-medium",
)

assert captured["output_file"].endswith("rfdetr-medium.onnx")

def test_variant_name_with_backbone(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
"""backbone_only + variant_name produces '{variant}-backbone.onnx'."""
captured: dict = {}

def _fake_onnx_export(*args, **kwargs) -> None:
captured["output_file"] = args[2]

monkeypatch.setattr(_cli_export_module.torch.onnx, "export", _fake_onnx_export)

_cli_export_module.export_onnx(
output_dir=str(tmp_path),
model=torch.nn.Identity(),
input_names=["input"],
input_tensors=torch.randn(1, 3, 8, 8),
output_names=["features"],
dynamic_axes=None,
backbone_only=True,
verbose=False,
variant_name="rfdetr-nano",
)

assert captured["output_file"].endswith("rfdetr-nano-backbone.onnx")

def test_default_name_without_variant(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
"""Without variant_name, falls back to 'inference_model.onnx'."""
captured: dict = {}

def _fake_onnx_export(*args, **kwargs) -> None:
captured["output_file"] = args[2]

monkeypatch.setattr(_cli_export_module.torch.onnx, "export", _fake_onnx_export)

_cli_export_module.export_onnx(
output_dir=str(tmp_path),
model=torch.nn.Identity(),
input_names=["input"],
input_tensors=torch.randn(1, 3, 8, 8),
output_names=["dets"],
dynamic_axes=None,
verbose=False,
)

assert captured["output_file"].endswith("inference_model.onnx")

def test_default_backbone_name_without_variant(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
"""Without variant_name + backbone_only, falls back to 'backbone_model.onnx'."""
captured: dict = {}

def _fake_onnx_export(*args, **kwargs) -> None:
captured["output_file"] = args[2]

monkeypatch.setattr(_cli_export_module.torch.onnx, "export", _fake_onnx_export)

_cli_export_module.export_onnx(
output_dir=str(tmp_path),
model=torch.nn.Identity(),
input_names=["input"],
input_tensors=torch.randn(1, 3, 8, 8),
output_names=["features"],
dynamic_axes=None,
backbone_only=True,
verbose=False,
)

assert captured["output_file"].endswith("backbone_model.onnx")

def test_rfdetr_export_passes_variant_name(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
"""RFDETR.export() passes self.size as variant_name to export_onnx."""
captured: dict = {}

class _DummyCoreModel:
def to(self, *_args, **_kwargs):
return self

def eval(self):
return self

def cpu(self):
return self

def __call__(self, *_args, **_kwargs):
return {"pred_boxes": torch.zeros(1, 1, 4), "pred_logits": torch.zeros(1, 1, 2)}

model = types.SimpleNamespace(
model=types.SimpleNamespace(model=_DummyCoreModel(), device="cpu", resolution=14),
model_config=types.SimpleNamespace(segmentation_head=False),
size="rfdetr-medium",
)

def _fake_make_infer_image(*_args, **_kwargs):
return torch.zeros(1, 3, 14, 14)

def _fake_export_onnx(*_args, variant_name=None, **_kw):
captured["variant_name"] = variant_name
return str(tmp_path / "rfdetr-medium.onnx")

monkeypatch.setattr("rfdetr.export.main.make_infer_image", _fake_make_infer_image)
monkeypatch.setattr("rfdetr.export.main.export_onnx", _fake_export_onnx)
monkeypatch.setattr("rfdetr.detr.deepcopy", lambda x: x)

_detr_module.RFDETR.export(model, output_dir=str(tmp_path), shape=(14, 14))

assert captured["variant_name"] == "rfdetr-medium"

def test_rfdetr_export_passes_none_when_size_not_set(self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> None:
"""Base RFDETR (size=None) passes None as variant_name."""
captured: dict = {}

class _DummyCoreModel:
def to(self, *_args, **_kwargs):
return self

def eval(self):
return self

def cpu(self):
return self

def __call__(self, *_args, **_kwargs):
return {"pred_boxes": torch.zeros(1, 1, 4), "pred_logits": torch.zeros(1, 1, 2)}

model = types.SimpleNamespace(
model=types.SimpleNamespace(model=_DummyCoreModel(), device="cpu", resolution=14),
model_config=types.SimpleNamespace(segmentation_head=False),
size=None,
)

def _fake_make_infer_image(*_args, **_kwargs):
return torch.zeros(1, 3, 14, 14)

def _fake_export_onnx(*_args, variant_name=None, **_kw):
captured["variant_name"] = variant_name
return str(tmp_path / "inference_model.onnx")

monkeypatch.setattr("rfdetr.export.main.make_infer_image", _fake_make_infer_image)
monkeypatch.setattr("rfdetr.export.main.export_onnx", _fake_export_onnx)
monkeypatch.setattr("rfdetr.detr.deepcopy", lambda x: x)

_detr_module.RFDETR.export(model, output_dir=str(tmp_path), shape=(14, 14))

assert captured["variant_name"] is None
Loading