diff --git a/src/rfdetr/detr.py b/src/rfdetr/detr.py index 949489e8..bbe2b8cd 100644 --- a/src/rfdetr/detr.py +++ b/src/rfdetr/detr.py @@ -434,9 +434,20 @@ def train(self, **kwargs): """Train an RF-DETR model via the PyTorch Lightning stack. All keyword arguments are forwarded to :meth:`get_train_config` to build - a :class:`~rfdetr.config.TrainConfig`. Several legacy kwargs are absorbed - so existing call-sites do not break: - + a :class:`~rfdetr.config.TrainConfig`. Several kwargs are absorbed and + handled specially so that existing call-sites do not break: + + * ``resolution`` — updates the model's input resolution by mutating + :attr:`model_config.resolution` in place before the train config is + built. This change persists on :attr:`model_config` after + :meth:`train` returns. The value must be a positive integer divisible + by ``patch_size * num_windows`` for the model variant; a + :class:`ValueError` is raised otherwise. + :attr:`model_config.positional_encoding_size` is also updated when + the config derives it formulaically (``PE == resolution // + patch_size``); configs with a pretrained-specific PE value (e.g. + ``RFDETRBase`` keeps DINOv2's pretrained PE=37, i.e. 518 // 14, while training at 560 px) are left unchanged to + preserve checkpoint compatibility. * ``device`` — normalized via :class:`torch.device` and mapped to PyTorch Lightning trainer arguments. ``"cpu"`` becomes ``accelerator="cpu"``; ``"cuda"`` and ``"cuda:N"`` become ``accelerator="gpu"`` and optionally @@ -457,6 +468,8 @@ def train(self, **kwargs): Raises: ImportError: If training dependencies are not installed. Install with ``pip install "rfdetr[train,loggers]"``. + ValueError: If ``resolution`` is not a positive integer or is not + divisible by ``patch_size * num_windows`` for the model variant. """ # Both imports are grouped in a single try block because they both live in @@ -508,6 +521,54 @@ def train(self, **kwargs): stacklevel=2, ) + # Apply resolution override to model_config before building the train config. 
+ # resolution is a ModelConfig field, not a TrainConfig field, so we pop it + # here to avoid it being silently ignored by TrainConfig. + _resolution = kwargs.pop("resolution", None) + if _resolution is not None: + if isinstance(_resolution, bool): + raise ValueError("resolution must be a positive integer") + try: + _resolution = operator.index(_resolution) + except TypeError as error: + raise ValueError("resolution must be a positive integer") from error + if _resolution <= 0: + raise ValueError("resolution must be a positive integer") + block_size = self.model_config.patch_size * self.model_config.num_windows + if _resolution % block_size != 0: + raise ValueError( + f"resolution={_resolution} is not divisible by " + f"patch_size ({self.model_config.patch_size}) * num_windows " + f"({self.model_config.num_windows}) = {block_size}. " + f"Choose a resolution that is a multiple of {block_size}." + ) + # Smart PE update: only recompute positional_encoding_size when the + # current config derives it formulaically (PE == resolution // patch_size). + # Configs with a pretrained-specific PE (e.g. RFDETRBase uses DINOv2's + # PE=37 at 518 px, training at 560 px) must not have PE silently changed + # — doing so causes shape mismatches when loading pretrained checkpoints. + _current_pe = self.model_config.positional_encoding_size + _derived_pe = self.model_config.resolution // self.model_config.patch_size + if _current_pe == _derived_pe: + # Formula-derived: update PE proportionally to the new resolution. + new_pe = _resolution // self.model_config.patch_size + self.model_config.positional_encoding_size = new_pe + else: + # Pretrained-specific PE; leave it unchanged. + new_pe = _current_pe + self.model_config.resolution = _resolution + + # Keep the cached inference/export context in sync with model_config so + # predict()/export()/deployment all see the same resolution metadata. 
+ if hasattr(self, "model") and self.model is not None: + if hasattr(self.model, "resolution"): + self.model.resolution = _resolution + model_args = getattr(self.model, "args", None) + if model_args is not None: + if hasattr(model_args, "resolution"): + model_args.resolution = _resolution + if hasattr(model_args, "positional_encoding_size"): + model_args.positional_encoding_size = new_pe config = self.get_train_config(**kwargs) if config.batch_size == "auto": # Auto-batch probing runs forward/backward on the actual model, which diff --git a/tests/training/test_detr_shim.py b/tests/training/test_detr_shim.py index f841e812..eb5d1c41 100644 --- a/tests/training/test_detr_shim.py +++ b/tests/training/test_detr_shim.py @@ -29,7 +29,7 @@ import pytest import torch -from rfdetr.config import RFDETRBaseConfig, TrainConfig +from rfdetr.config import RFDETRBaseConfig, RFDETRSmallConfig, TrainConfig from rfdetr.detr import RFDETR, RFDETRLarge from rfdetr.detr import logger as detr_logger from rfdetr.training.auto_batch import AutoBatchResult @@ -527,6 +527,88 @@ def test_do_benchmark_true_emits_deprecation_warning(self, tmp_path, patch_lit): depr = [x for x in w if issubclass(x.category, DeprecationWarning)] assert any("do_benchmark" in str(d.message) or "rfdetr benchmark" in str(d.message) for d in depr) + def test_resolution_kwarg_updates_model_config_resolution(self, tmp_path, patch_lit): + """resolution kwarg is applied to model_config.resolution before training.""" + mock_self = _make_rfdetr_self(tmp_path) + block_size = mock_self.model_config.patch_size * mock_self.model_config.num_windows + valid_resolution = block_size * 11 # guaranteed divisible and different from default + p_mod, p_dm, p_bt, *_ = patch_lit + with p_mod, p_dm, p_bt: + RFDETR.train(mock_self, resolution=valid_resolution) + assert mock_self.model_config.resolution == valid_resolution + + def test_resolution_kwarg_does_not_implicitly_update_positional_encoding_size(self, tmp_path, patch_lit): + 
"""Pretrained-specific PE (RFDETRBase DINOv2=37) is preserved when resolution is overridden.""" + mock_self = _make_rfdetr_self(tmp_path) + # RFDETRBaseConfig: PE=37 (DINOv2 native 518//14), resolution=560, patch_size=14. + # PE != resolution // patch_size, so the smart PE guard leaves PE unchanged. + original_pe = mock_self.model_config.positional_encoding_size + block_size = mock_self.model_config.patch_size * mock_self.model_config.num_windows + valid_override_resolution = block_size * 11 # different from default 560 + p_mod, p_dm, p_bt, *_ = patch_lit + with p_mod, p_dm, p_bt: + RFDETR.train(mock_self, resolution=valid_override_resolution) + assert mock_self.model_config.positional_encoding_size == original_pe + + def test_resolution_kwarg_updates_positional_encoding_size_for_formula_derived_config(self, tmp_path, patch_lit): + """For configs where PE == resolution // patch_size, resolution override updates PE.""" + # RFDETRSmallConfig: patch_size=16, num_windows=2, resolution=512, PE=32=512//16. 
+ mock_self = _make_rfdetr_self(tmp_path) + mock_self.model_config = RFDETRSmallConfig(pretrain_weights=None, num_classes=3, device="cpu") + block_size = mock_self.model_config.patch_size * mock_self.model_config.num_windows + new_resolution = block_size * 21 # 672 for Small — valid and different from default 512 + expected_pe = new_resolution // mock_self.model_config.patch_size + p_mod, p_dm, p_bt, *_ = patch_lit + with p_mod, p_dm, p_bt: + RFDETR.train(mock_self, resolution=new_resolution) + assert mock_self.model_config.positional_encoding_size == expected_pe + + def test_resolution_kwarg_does_not_reach_get_train_config(self, tmp_path, patch_lit): + """resolution kwarg is popped before get_train_config is called.""" + mock_self = _make_rfdetr_self(tmp_path) + block_size = mock_self.model_config.patch_size * mock_self.model_config.num_windows + p_mod, p_dm, p_bt, *_ = patch_lit + with p_mod, p_dm, p_bt: + RFDETR.train(mock_self, resolution=block_size * 10) + assert "resolution" not in mock_self.get_train_config.call_args.kwargs + + def test_resolution_indivisible_raises_value_error(self, tmp_path, patch_lit): + """resolution not divisible by patch_size * num_windows raises ValueError.""" + mock_self = _make_rfdetr_self(tmp_path) + block_size = mock_self.model_config.patch_size * mock_self.model_config.num_windows + indivisible = block_size * 10 + 1 # guaranteed not divisible by block_size + p_mod, p_dm, p_bt, *_ = patch_lit + with p_mod, p_dm, p_bt, pytest.raises(ValueError, match=f"resolution={indivisible}"): + RFDETR.train(mock_self, resolution=indivisible) + + def test_resolution_none_leaves_model_config_unchanged(self, tmp_path, patch_lit): + """Omitting resolution leaves model_config.resolution unchanged.""" + mock_self = _make_rfdetr_self(tmp_path) + original_resolution = mock_self.model_config.resolution + p_mod, p_dm, p_bt, *_ = patch_lit + with p_mod, p_dm, p_bt: + RFDETR.train(mock_self) + assert mock_self.model_config.resolution == original_resolution 
+ + @pytest.mark.parametrize( + "bad_resolution", + [ + pytest.param(0, id="zero"), + pytest.param(-56, id="negative"), + pytest.param(True, id="bool_true"), + pytest.param(False, id="bool_false"), + pytest.param(1.5, id="non_integer_float"), + pytest.param(560.0, id="whole_number_float"), + pytest.param("560", id="string"), + ], + ) + def test_resolution_invalid_type_or_value_raises_value_error(self, tmp_path, patch_lit, bad_resolution): + """Non-positive, bool, or non-integer resolution raises ValueError before divisibility check.""" + mock_self = _make_rfdetr_self(tmp_path) + p_mod, p_dm, p_bt, *_ = patch_lit + with p_mod, p_dm, p_bt, pytest.raises(ValueError, match="resolution must be a positive integer"): + RFDETR.train(mock_self, resolution=bad_resolution) + def test_returns_none(self, tmp_path, patch_lit): """RFDETR.train() returns None.""" mock_self = _make_rfdetr_self(tmp_path)