diff --git a/src/rfdetr/detr.py b/src/rfdetr/detr.py index 931b1be7c..42ffac6b0 100644 --- a/src/rfdetr/detr.py +++ b/src/rfdetr/detr.py @@ -280,6 +280,7 @@ def export( force: bool = False, shape: tuple = None, batch_size: int = 1, + dynamic_batch: bool = False, **kwargs, ) -> None: """Export the trained model to ONNX format. @@ -297,6 +298,8 @@ def export( force: Deprecated and ignored. shape: ``(height, width)`` tuple; defaults to square at model resolution. batch_size: Static batch size to bake into the ONNX graph. + dynamic_batch: If True, export with a dynamic batch dimension + so the ONNX model accepts variable batch sizes at runtime. **kwargs: Additional keyword arguments forwarded to export_onnx. """ logger.info("Exporting model to ONNX format") @@ -330,7 +333,10 @@ def export( else: output_names = ["dets", "labels"] - dynamic_axes = None + if dynamic_batch: + dynamic_axes = {name: {0: "batch"} for name in input_names + output_names} + else: + dynamic_axes = None model.eval() with torch.no_grad(): if backbone_only: diff --git a/src/rfdetr/export/main.py b/src/rfdetr/export/main.py index ccd5dc101..8a3c275bf 100644 --- a/src/rfdetr/export/main.py +++ b/src/rfdetr/export/main.py @@ -118,7 +118,10 @@ def main(args): output_names = ["dets", "labels", "masks"] else: output_names = ["dets", "labels"] - dynamic_axes = None + if getattr(args, "dynamic_batch", False): + dynamic_axes = {name: {0: "batch"} for name in input_names + output_names} + else: + dynamic_axes = None # Run model inference in pytorch mode model.eval().to("cuda") input_tensors = input_tensors.to("cuda") diff --git a/src/rfdetr/models/backbone/projector.py b/src/rfdetr/models/backbone/projector.py index f04cf2207..06dbc64e5 100644 --- a/src/rfdetr/models/backbone/projector.py +++ b/src/rfdetr/models/backbone/projector.py @@ -43,7 +43,7 @@ def forward(self, x): TODO: this is a hack to avoid overflow when using fp16 """ x = x.permute(0, 2, 3, 1) - x = F.layer_norm(x, (x.size(3),), self.weight, 
self.bias, self.eps) + x = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps) x = x.permute(0, 3, 1, 2) return x diff --git a/src/rfdetr/models/transformer.py b/src/rfdetr/models/transformer.py index a05ff6323..aab495e2e 100644 --- a/src/rfdetr/models/transformer.py +++ b/src/rfdetr/models/transformer.py @@ -90,8 +90,14 @@ def gen_encoder_output_proposals(memory, memory_padding_mask, spatial_shapes, un valid_H = torch.sum(~mask_flatten_[:, :, 0, 0], 1) valid_W = torch.sum(~mask_flatten_[:, 0, :, 0], 1) else: - valid_H = torch.tensor([H_ for _ in range(N_)], device=memory.device) - valid_W = torch.tensor([W_ for _ in range(N_)], device=memory.device) + if isinstance(H_, torch.Tensor): + valid_H = H_.expand(N_).to(dtype=torch.long, device=memory.device) + else: + valid_H = torch.full((N_,), H_, dtype=torch.long, device=memory.device) + if isinstance(W_, torch.Tensor): + valid_W = W_.expand(N_).to(dtype=torch.long, device=memory.device) + else: + valid_W = torch.full((N_,), W_, dtype=torch.long, device=memory.device) grid_y, grid_x = torch.meshgrid( torch.linspace(0, H_ - 1, H_, dtype=torch.float32, device=memory.device), @@ -224,12 +230,18 @@ def forward(self, srcs, masks, pos_embeds, refpoint_embed, query_feat): src_flatten = [] mask_flatten = [] if masks is not None else None lvl_pos_embed_flatten = [] - spatial_shapes = [] + # Build spatial_shapes as a tensor directly so that the ONNX tracer + # can track h/w symbolically instead of baking them in as constants. + spatial_shapes = torch.empty((len(srcs), 2), device=srcs[0].device, dtype=torch.long) + # Also keep the raw (h, w) pairs for gen_encoder_output_proposals — plain Python + # ints in eager mode, traced size tensors under ONNX tracing (both are handled). 
+ spatial_shapes_hw: list[tuple[int, int]] = [] valid_ratios = [] if masks is not None else None for lvl, (src, pos_embed) in enumerate(zip(srcs, pos_embeds)): bs, c, h, w = src.shape - spatial_shape = (h, w) - spatial_shapes.append(spatial_shape) + spatial_shapes[lvl, 0] = h + spatial_shapes[lvl, 1] = w + spatial_shapes_hw.append((h, w)) src = src.flatten(2).transpose(1, 2) # bs, hw, c pos_embed = pos_embed.flatten(2).transpose(1, 2) # bs, hw, c @@ -243,12 +255,6 @@ def forward(self, srcs, masks, pos_embeds, refpoint_embed, query_feat): mask_flatten = torch.cat(mask_flatten, 1) # bs, \sum{hxw} valid_ratios = torch.stack([self.get_valid_ratio(m) for m in masks], 1) lvl_pos_embed_flatten = torch.cat(lvl_pos_embed_flatten, 1) # bs, \sum{hxw}, c - # Keep a plain Python list of (H, W) int pairs for gen_encoder_output_proposals. - # The tensor form is only needed by ms_deform_attn and level_start_index. - # Passing Python ints avoids .item() calls inside the function, which would - # cause torch.compile graph breaks on loop-accumulated slice indices. 
- spatial_shapes_hw = list(spatial_shapes) - spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=memory.device) level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1])) if self.two_stage: diff --git a/tests/models/test_export.py b/tests/models/test_export.py index df297ae02..caeae7b37 100644 --- a/tests/models/test_export.py +++ b/tests/models/test_export.py @@ -168,6 +168,76 @@ def _fake_export_onnx(*_args, **_kwargs): return model, export_called +@pytest.mark.parametrize( + "dynamic_batch, segmentation_head", + [ + pytest.param(True, False, id="detection_dynamic"), + pytest.param(True, True, id="segmentation_dynamic"), + pytest.param(False, False, id="detection_static"), + ], +) +def test_rfdetr_export_dynamic_batch_forwards_dynamic_axes( + monkeypatch: pytest.MonkeyPatch, + tmp_path: Path, + dynamic_batch: bool, + segmentation_head: bool, +) -> None: + """`RFDETR.export(..., dynamic_batch=True)` must pass a non-None `dynamic_axes` dict + to `export_onnx`; `dynamic_batch=False` must pass `None`. 
+ """ + + class _DummyCoreModel: + def to(self, *_args, **_kwargs): + return self + + def eval(self): + return self + + def cpu(self): + return self + + def __call__(self, *_args, **_kwargs): + if segmentation_head: + return { + "pred_boxes": torch.zeros(1, 1, 4), + "pred_logits": torch.zeros(1, 1, 2), + "pred_masks": torch.zeros(1, 1, 2, 2), + } + return {"pred_boxes": torch.zeros(1, 1, 4), "pred_logits": torch.zeros(1, 1, 2)} + + model = types.SimpleNamespace( + model=types.SimpleNamespace(model=_DummyCoreModel(), device="cpu", resolution=14), + model_config=types.SimpleNamespace(segmentation_head=segmentation_head), + ) + + captured: dict = {} + + def _fake_make_infer_image(*_args, **_kwargs): + return torch.zeros(1, 3, 14, 14) + + def _fake_export_onnx(*_args, dynamic_axes=None, **_kw): + captured["dynamic_axes"] = dynamic_axes + return str(tmp_path / "inference_model.onnx") + + monkeypatch.setattr("rfdetr.export.main.make_infer_image", _fake_make_infer_image) + monkeypatch.setattr("rfdetr.export.main.export_onnx", _fake_export_onnx) + monkeypatch.setattr("rfdetr.detr.deepcopy", lambda x: x) + + _detr_module.RFDETR.export(model, output_dir=str(tmp_path), dynamic_batch=dynamic_batch, shape=(14, 14)) + + dynamic_axes = captured.get("dynamic_axes") + if not dynamic_batch: + assert dynamic_axes is None, f"expected None for static export, got {dynamic_axes!r}" + return + + assert isinstance(dynamic_axes, dict), f"expected dict, got {dynamic_axes!r}" + for name, axes in dynamic_axes.items(): + assert axes == {0: "batch"}, f"axis spec for {name!r} should be {{0: 'batch'}}, got {axes!r}" + + expected_names = {"input", "dets", "labels", "masks"} if segmentation_head else {"input", "dets", "labels"} + assert set(dynamic_axes.keys()) == expected_names, f"expected keys {expected_names}, got {set(dynamic_axes.keys())}" + + def test_export_simplify_flag_is_ignored_with_deprecation_warning(_detr_export_scaffold: tuple, tmp_path: Path) -> None: """`simplify=True` should not 
run ONNX simplification and should emit a deprecation warning.""" model, export_called = _detr_export_scaffold @@ -243,6 +313,7 @@ def _make_args( opset_version: int = 17, simplify: bool = False, tensorrt: bool = False, + dynamic_batch: bool = False, ) -> types.SimpleNamespace: return types.SimpleNamespace( device="cpu", @@ -259,6 +330,7 @@ def _make_args( opset_version=opset_version, simplify=simplify, tensorrt=tensorrt, + dynamic_batch=dynamic_batch, ) @staticmethod @@ -313,6 +385,7 @@ def fake_export_onnx(output_dir, model, input_names, input_tensors, output_names export_onnx_captured["output_dir"] = output_dir export_onnx_captured["model"] = model export_onnx_captured["output_names"] = output_names + export_onnx_captured["dynamic_axes"] = dynamic_axes export_onnx_captured["kwargs"] = kwargs return str(args.output_dir) + "/inference_model.onnx" @@ -454,3 +527,51 @@ def fake_export_onnx(*_args, **_kwargs): mock_logger.warning.assert_called_once() assert "simplify" in mock_logger.warning.call_args[0][0].lower() assert export_onnx_called["value"] is True, "export_onnx should still be called with simplify=True" + + @pytest.mark.parametrize( + "dynamic_batch, segmentation_head, backbone_only", + [ + pytest.param(True, False, False, id="detection_dynamic"), + pytest.param(True, True, False, id="segmentation_dynamic"), + pytest.param(True, False, True, id="backbone_only_dynamic"), + pytest.param(False, False, False, id="detection_static"), + ], + ) + def test_dynamic_batch_forwards_dynamic_axes( + self, + output_dir: str, + dynamic_batch: bool, + segmentation_head: bool, + backbone_only: bool, + ) -> None: + """CLI --dynamic_batch=True must pass {name: {0: 'batch'}} for every I/O name. + + When dynamic_batch=False, dynamic_axes must be None (static export). 
+ """ + args = self._make_args( + output_dir=output_dir, + dynamic_batch=dynamic_batch, + segmentation_head=segmentation_head, + backbone_only=backbone_only, + ) + _, captured = self._run(args) + + dynamic_axes = captured.get("dynamic_axes") + if not dynamic_batch: + assert dynamic_axes is None, f"expected None for static export, got {dynamic_axes!r}" + return + + assert isinstance(dynamic_axes, dict), f"expected dict, got {dynamic_axes!r}" + for name, axes in dynamic_axes.items(): + assert axes == {0: "batch"}, f"axis spec for {name!r} should be {{0: 'batch'}}, got {axes!r}" + + # Every input/output name must have an entry + if backbone_only: + expected_names = {"input", "features"} + elif segmentation_head: + expected_names = {"input", "dets", "labels", "masks"} + else: + expected_names = {"input", "dets", "labels"} + assert set(dynamic_axes.keys()) == expected_names, ( + f"expected keys {expected_names}, got {set(dynamic_axes.keys())}" + ) diff --git a/tests/models/test_transformer.py b/tests/models/test_transformer.py index a17d9a04b..13236fcee 100644 --- a/tests/models/test_transformer.py +++ b/tests/models/test_transformer.py @@ -36,3 +36,25 @@ def _meshgrid_with_indexing_assertion(*args, **kwargs): assert call_count == 1 assert output_memory.shape == memory.shape assert output_proposals.shape == (1, 4, 4) + + +def test_gen_encoder_output_proposals_accepts_int_tuple_spatial_shapes() -> None: + """Regression: spatial_shapes as list[tuple[int, int]] with masks=None must not crash. + + Transformer.forward() passes Python int pairs (from bs, c, h, w = src.shape) to + gen_encoder_output_proposals. The export path (masks=None) triggers the else branch + which previously called H_.expand(N_) — failing with AttributeError on a Python int. 
+ """ + batch, h, w, d = 2, 3, 4, 8 + memory = torch.randn(batch, h * w, d) + spatial_shapes = [(h, w)] # Python int pairs, as produced by Transformer.forward() + + output_memory, output_proposals = gen_encoder_output_proposals( + memory=memory, + memory_padding_mask=None, + spatial_shapes=spatial_shapes, + unsigmoid=True, + ) + + assert output_memory.shape == memory.shape + assert output_proposals.shape == (batch, h * w, 4)