Skip to content
8 changes: 7 additions & 1 deletion src/rfdetr/detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ def export(
force: bool = False,
shape: tuple = None,
batch_size: int = 1,
dynamic_batch: bool = False,
**kwargs,
) -> None:
"""Export the trained model to ONNX format.
Expand All @@ -297,6 +298,8 @@ def export(
force: Deprecated and ignored.
shape: ``(height, width)`` tuple; defaults to square at model resolution.
batch_size: Static batch size to bake into the ONNX graph.
dynamic_batch: If True, export with a dynamic batch dimension
so the ONNX model accepts variable batch sizes at runtime.
**kwargs: Additional keyword arguments forwarded to export_onnx.
"""
logger.info("Exporting model to ONNX format")
Expand Down Expand Up @@ -330,7 +333,10 @@ def export(
else:
output_names = ["dets", "labels"]

dynamic_axes = None
if dynamic_batch:
dynamic_axes = {name: {0: "batch"} for name in input_names + output_names}
else:
dynamic_axes = None
model.eval()
with torch.no_grad():
if backbone_only:
Expand Down
5 changes: 4 additions & 1 deletion src/rfdetr/export/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,10 @@ def main(args):
output_names = ["dets", "labels", "masks"]
else:
output_names = ["dets", "labels"]
dynamic_axes = None
if getattr(args, "dynamic_batch", False):
dynamic_axes = {name: {0: "batch"} for name in input_names + output_names}
else:
dynamic_axes = None
# Run model inference in pytorch mode
model.eval().to("cuda")
input_tensors = input_tensors.to("cuda")
Expand Down
2 changes: 1 addition & 1 deletion src/rfdetr/models/backbone/projector.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def forward(self, x):
TODO: this is a hack to avoid overflow when using fp16
"""
x = x.permute(0, 2, 3, 1)
x = F.layer_norm(x, (x.size(3),), self.weight, self.bias, self.eps)
x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
x = x.permute(0, 3, 1, 2)
return x

Expand Down
28 changes: 17 additions & 11 deletions src/rfdetr/models/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,14 @@ def gen_encoder_output_proposals(memory, memory_padding_mask, spatial_shapes, un
valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
else:
valid_H = torch.tensor([H_ for _ in range(N_)], device=memory.device)
valid_W = torch.tensor([W_ for _ in range(N_)], device=memory.device)
if isinstance(H_, torch.Tensor):
valid_H = H_.expand(N_).to(dtype=torch.long, device=memory.device)
else:
valid_H = torch.full((N_,), H_, dtype=torch.long, device=memory.device)
if isinstance(W_, torch.Tensor):
valid_W = W_.expand(N_).to(dtype=torch.long, device=memory.device)
else:
valid_W = torch.full((N_,), W_, dtype=torch.long, device=memory.device)

grid_y, grid_x = torch.meshgrid(
torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device),
Expand Down Expand Up @@ -224,12 +230,18 @@ def forward(self, srcs, masks, pos_embeds, refpoint_embed, query_feat):
src_flatten = []
mask_flatten = [] if masks is not None else None
lvl_pos_embed_flatten = []
spatial_shapes = []
# Build spatial_shapes as a tensor directly so that the ONNX tracer
# can track h/w symbolically instead of baking them in as constants.
spatial_shapes = torch.empty((len(srcs), 2), device=srcs[0].device, dtype=torch.long)
# Keep Python int pairs for gen_encoder_output_proposals — its loop uses h/w
# as slice indices and linspace steps, which require Python ints, not tensors.
spatial_shapes_hw: list[tuple[int, int]] = []
valid_ratios = [] if masks is not None else None
for lvl, (src, pos_embed) in enumerate(zip(srcs, pos_embeds)):
bs, c, h, w = src.shape
spatial_shape = (h, w)
spatial_shapes.append(spatial_shape)
spatial_shapes[lvl, 0] = h
spatial_shapes[lvl, 1] = w
spatial_shapes_hw.append((h, w))

src = src.flatten(2).transpose(1, 2) # bs, hw, c
pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c
Expand All @@ -243,12 +255,6 @@ def forward(self, srcs, masks, pos_embeds, refpoint_embed, query_feat):
mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw}
valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1)
lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c
# Keep a plain Python list of (H, W) int pairs for gen_encoder_output_proposals.
# The tensor form is only needed by ms_deform_attn and level_start_index.
# Passing Python ints avoids .item() calls inside the function, which would
# cause torch.compile graph breaks on loop-accumulated slice indices.
spatial_shapes_hw = list(spatial_shapes)
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=memory.device)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))

if self.two_stage:
Expand Down
121 changes: 121 additions & 0 deletions tests/models/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,76 @@ def _fake_export_onnx(*_args, **_kwargs):
return model, export_called


@pytest.mark.parametrize(
    "dynamic_batch, segmentation_head",
    [
        pytest.param(True, False, id="detection_dynamic"),
        pytest.param(True, True, id="segmentation_dynamic"),
        pytest.param(False, False, id="detection_static"),
    ],
)
def test_rfdetr_export_dynamic_batch_forwards_dynamic_axes(
    monkeypatch: pytest.MonkeyPatch,
    tmp_path: Path,
    dynamic_batch: bool,
    segmentation_head: bool,
) -> None:
    """`RFDETR.export(..., dynamic_batch=True)` must pass a non-None `dynamic_axes` dict
    to `export_onnx`; `dynamic_batch=False` must pass `None`.
    """

    class _DummyCoreModel:
        """Minimal stand-in for the wrapped torch module used by RFDETR.export."""

        def to(self, *_args, **_kwargs):
            return self

        def eval(self):
            return self

        def cpu(self):
            return self

        def __call__(self, *_args, **_kwargs):
            # Mirror the real model's output dict; a segmentation head adds "pred_masks".
            if segmentation_head:
                return {
                    "pred_boxes": torch.zeros(1, 1, 4),
                    "pred_logits": torch.zeros(1, 1, 2),
                    "pred_masks": torch.zeros(1, 1, 2, 2),
                }
            return {"pred_boxes": torch.zeros(1, 1, 4), "pred_logits": torch.zeros(1, 1, 2)}

    model = types.SimpleNamespace(
        model=types.SimpleNamespace(model=_DummyCoreModel(), device="cpu", resolution=14),
        model_config=types.SimpleNamespace(segmentation_head=segmentation_head),
    )

    captured: dict = {}

    def _fake_make_infer_image(*_args, **_kwargs):
        return torch.zeros(1, 3, 14, 14)

    def _fake_export_onnx(*_args, dynamic_axes=None, **_kw):
        captured["dynamic_axes"] = dynamic_axes
        return str(tmp_path / "inference_model.onnx")

    monkeypatch.setattr("rfdetr.export.main.make_infer_image", _fake_make_infer_image)
    monkeypatch.setattr("rfdetr.export.main.export_onnx", _fake_export_onnx)
    monkeypatch.setattr("rfdetr.detr.deepcopy", lambda x: x)

    _detr_module.RFDETR.export(model, output_dir=str(tmp_path), dynamic_batch=dynamic_batch, shape=(14, 14))

    # Guard against a vacuous pass: with `captured.get(...)`, a run in which
    # export_onnx was never reached would return None and satisfy the static
    # branch below. Require the key to be present first.
    assert "dynamic_axes" in captured, "export_onnx was never called"
    dynamic_axes = captured["dynamic_axes"]
    if not dynamic_batch:
        assert dynamic_axes is None, f"expected None for static export, got {dynamic_axes!r}"
        return

    assert isinstance(dynamic_axes, dict), f"expected dict, got {dynamic_axes!r}"
    for name, axes in dynamic_axes.items():
        assert axes == {0: "batch"}, f"axis spec for {name!r} should be {{0: 'batch'}}, got {axes!r}"

    expected_names = {"input", "dets", "labels", "masks"} if segmentation_head else {"input", "dets", "labels"}
    assert set(dynamic_axes.keys()) == expected_names, f"expected keys {expected_names}, got {set(dynamic_axes.keys())}"


def test_export_simplify_flag_is_ignored_with_deprecation_warning(_detr_export_scaffold: tuple, tmp_path: Path) -> None:
"""`simplify=True` should not run ONNX simplification and should emit a deprecation warning."""
model, export_called = _detr_export_scaffold
Expand Down Expand Up @@ -243,6 +313,7 @@ def _make_args(
opset_version: int = 17,
simplify: bool = False,
tensorrt: bool = False,
dynamic_batch: bool = False,
) -> types.SimpleNamespace:
return types.SimpleNamespace(
device="cpu",
Expand All @@ -259,6 +330,7 @@ def _make_args(
opset_version=opset_version,
simplify=simplify,
tensorrt=tensorrt,
dynamic_batch=dynamic_batch,
)

@staticmethod
Expand Down Expand Up @@ -313,6 +385,7 @@ def fake_export_onnx(output_dir, model, input_names, input_tensors, output_names
export_onnx_captured["output_dir"] = output_dir
export_onnx_captured["model"] = model
export_onnx_captured["output_names"] = output_names
export_onnx_captured["dynamic_axes"] = dynamic_axes
export_onnx_captured["kwargs"] = kwargs
return str(args.output_dir) + "/inference_model.onnx"

Expand Down Expand Up @@ -454,3 +527,51 @@ def fake_export_onnx(*_args, **_kwargs):
mock_logger.warning.assert_called_once()
assert "simplify" in mock_logger.warning.call_args[0][0].lower()
assert export_onnx_called["value"] is True, "export_onnx should still be called with simplify=True"

@pytest.mark.parametrize(
    "dynamic_batch, segmentation_head, backbone_only",
    [
        pytest.param(True, False, False, id="detection_dynamic"),
        pytest.param(True, True, False, id="segmentation_dynamic"),
        pytest.param(True, False, True, id="backbone_only_dynamic"),
        pytest.param(False, False, False, id="detection_static"),
    ],
)
def test_dynamic_batch_forwards_dynamic_axes(
    self,
    output_dir: str,
    dynamic_batch: bool,
    segmentation_head: bool,
    backbone_only: bool,
) -> None:
    """CLI --dynamic_batch=True must pass {name: {0: 'batch'}} for every I/O name.

    When dynamic_batch=False, dynamic_axes must be None (static export).
    """
    args = self._make_args(
        output_dir=output_dir,
        dynamic_batch=dynamic_batch,
        segmentation_head=segmentation_head,
        backbone_only=backbone_only,
    )
    _, captured = self._run(args)

    # Fail loudly if export_onnx was never reached; `captured.get(...)` would
    # silently yield None and let the static case pass vacuously.
    assert "dynamic_axes" in captured, "export_onnx was never called"
    dynamic_axes = captured["dynamic_axes"]
    if not dynamic_batch:
        assert dynamic_axes is None, f"expected None for static export, got {dynamic_axes!r}"
        return

    assert isinstance(dynamic_axes, dict), f"expected dict, got {dynamic_axes!r}"
    for name, axes in dynamic_axes.items():
        assert axes == {0: "batch"}, f"axis spec for {name!r} should be {{0: 'batch'}}, got {axes!r}"

    # Every input/output name must have an entry
    if backbone_only:
        expected_names = {"input", "features"}
    elif segmentation_head:
        expected_names = {"input", "dets", "labels", "masks"}
    else:
        expected_names = {"input", "dets", "labels"}
    assert set(dynamic_axes.keys()) == expected_names, (
        f"expected keys {expected_names}, got {set(dynamic_axes.keys())}"
    )
22 changes: 22 additions & 0 deletions tests/models/test_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,25 @@ def _meshgrid_with_indexing_assertion(*args, **kwargs):
assert call_count == 1
assert output_memory.shape == memory.shape
assert output_proposals.shape == (1, 4, 4)


def test_gen_encoder_output_proposals_accepts_int_tuple_spatial_shapes() -> None:
    """Regression: spatial_shapes given as list[tuple[int, int]] with masks=None must work.

    Transformer.forward() hands gen_encoder_output_proposals plain Python int
    pairs (unpacked from src.shape). With memory_padding_mask=None the function
    takes the else branch, which used to call H_.expand(N_) and raised an
    AttributeError because a Python int has no .expand.
    """
    n_batch, height, width, dim = 2, 3, 4, 8
    feats = torch.randn(n_batch, height * width, dim)

    out_mem, out_props = gen_encoder_output_proposals(
        memory=feats,
        memory_padding_mask=None,
        spatial_shapes=[(height, width)],  # plain int pairs, as forward() produces
        unsigmoid=True,
    )

    assert out_mem.shape == feats.shape
    assert out_props.shape == (n_batch, height * width, 4)
Loading