Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 64 additions & 3 deletions src/rfdetr/detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,9 +434,20 @@ def train(self, **kwargs):
"""Train an RF-DETR model via the PyTorch Lightning stack.

All keyword arguments are forwarded to :meth:`get_train_config` to build
a :class:`~rfdetr.config.TrainConfig`. Several legacy kwargs are absorbed
so existing call-sites do not break:

a :class:`~rfdetr.config.TrainConfig`. Several kwargs are absorbed and
handled specially so that existing call-sites do not break:

* ``resolution`` — updates the model's input resolution by mutating
:attr:`model_config.resolution` in place before the train config is
built. This change persists on :attr:`model_config` after
:meth:`train` returns. The value must be a positive integer divisible
by ``patch_size * num_windows`` for the model variant; a
:class:`ValueError` is raised otherwise.
:attr:`model_config.positional_encoding_size` is also updated when
the config derives it formulaically (``PE == resolution //
patch_size``); configs with a pretrained-specific PE value (e.g.
``RFDETRBase`` uses DINOv2's PE=37 at 518 px) are left unchanged to
preserve checkpoint compatibility.
* ``device`` — normalized via :class:`torch.device` and mapped to PyTorch
Lightning trainer arguments. ``"cpu"`` becomes ``accelerator="cpu"``;
``"cuda"`` and ``"cuda:N"`` become ``accelerator="gpu"`` and optionally
Expand All @@ -457,6 +468,8 @@ def train(self, **kwargs):
Raises:
ImportError: If training dependencies are not installed. Install with
``pip install "rfdetr[train,loggers]"``.
ValueError: If ``resolution`` is not a positive integer or is not
divisible by ``patch_size * num_windows`` for the model variant.

"""
# Both imports are grouped in a single try block because they both live in
Expand Down Expand Up @@ -508,6 +521,54 @@ def train(self, **kwargs):
stacklevel=2,
)

# Apply resolution override to model_config before building the train config.
# resolution is a ModelConfig field, not a TrainConfig field, so we pop it
# here to avoid it being silently ignored by TrainConfig.
_resolution = kwargs.pop("resolution", None)
if _resolution is not None:
if isinstance(_resolution, bool):
raise ValueError("resolution must be a positive integer")
try:
_resolution = operator.index(_resolution)
except TypeError as error:
raise ValueError("resolution must be a positive integer") from error
if _resolution <= 0:
raise ValueError("resolution must be a positive integer")
block_size = self.model_config.patch_size * self.model_config.num_windows
if _resolution % block_size != 0:
raise ValueError(
f"resolution={_resolution} is not divisible by "
f"patch_size ({self.model_config.patch_size}) * num_windows "
f"({self.model_config.num_windows}) = {block_size}. "
f"Choose a resolution that is a multiple of {block_size}."
)
# Smart PE update: only recompute positional_encoding_size when the
# current config derives it formulaically (PE == resolution // patch_size).
# Configs with a pretrained-specific PE (e.g. RFDETRBase uses DINOv2's
# PE=37 at 518 px, training at 560 px) must not have PE silently changed
# — doing so causes shape mismatches when loading pretrained checkpoints.
_current_pe = self.model_config.positional_encoding_size
_derived_pe = self.model_config.resolution // self.model_config.patch_size
if _current_pe == _derived_pe:
# Formula-derived: update PE proportionally to the new resolution.
new_pe = _resolution // self.model_config.patch_size
self.model_config.positional_encoding_size = new_pe
else:
# Pretrained-specific PE; leave it unchanged.
new_pe = _current_pe
self.model_config.resolution = _resolution

# Keep the cached inference/export context in sync with model_config so
# predict()/export()/deployment all see the same resolution metadata.
if hasattr(self, "model") and self.model is not None:
if hasattr(self.model, "resolution"):
self.model.resolution = _resolution
model_args = getattr(self.model, "args", None)
if model_args is not None:
if hasattr(model_args, "resolution"):
model_args.resolution = _resolution
if hasattr(model_args, "positional_encoding_size"):
model_args.positional_encoding_size = new_pe
config = self.get_train_config(**kwargs)
if config.batch_size == "auto":
# Auto-batch probing runs forward/backward on the actual model, which
Expand Down
84 changes: 83 additions & 1 deletion tests/training/test_detr_shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
import pytest
import torch

from rfdetr.config import RFDETRBaseConfig, TrainConfig
from rfdetr.config import RFDETRBaseConfig, RFDETRSmallConfig, TrainConfig
from rfdetr.detr import RFDETR, RFDETRLarge
from rfdetr.detr import logger as detr_logger
from rfdetr.training.auto_batch import AutoBatchResult
Expand Down Expand Up @@ -527,6 +527,88 @@ def test_do_benchmark_true_emits_deprecation_warning(self, tmp_path, patch_lit):
depr = [x for x in w if issubclass(x.category, DeprecationWarning)]
assert any("do_benchmark" in str(d.message) or "rfdetr benchmark" in str(d.message) for d in depr)

def test_resolution_kwarg_updates_model_config_resolution(self, tmp_path, patch_lit):
    """resolution kwarg is applied to model_config.resolution before training."""
    shim = _make_rfdetr_self(tmp_path)
    step = shim.model_config.patch_size * shim.model_config.num_windows
    # A multiple of the block size is always accepted; x11 keeps it distinct
    # from the config's default resolution.
    target = step * 11
    patch_module, patch_datamodule, patch_trainer, *_ = patch_lit
    with patch_module, patch_datamodule, patch_trainer:
        RFDETR.train(shim, resolution=target)
    assert shim.model_config.resolution == target

def test_resolution_kwarg_does_not_implicitly_update_positional_encoding_size(self, tmp_path, patch_lit):
    """Pretrained-specific PE (RFDETRBase DINOv2=37) is preserved when resolution is overridden."""
    shim = _make_rfdetr_self(tmp_path)
    # RFDETRBaseConfig ships PE=37 (DINOv2 native 518//14) alongside
    # resolution=560 and patch_size=14, so PE != resolution // patch_size and
    # the smart PE guard must leave PE untouched.
    pe_before = shim.model_config.positional_encoding_size
    step = shim.model_config.patch_size * shim.model_config.num_windows
    override = step * 11  # valid multiple, different from the default 560
    patch_module, patch_datamodule, patch_trainer, *_ = patch_lit
    with patch_module, patch_datamodule, patch_trainer:
        RFDETR.train(shim, resolution=override)
    assert shim.model_config.positional_encoding_size == pe_before

def test_resolution_kwarg_updates_positional_encoding_size_for_formula_derived_config(self, tmp_path, patch_lit):
    """For configs where PE == resolution // patch_size, resolution override updates PE."""
    # RFDETRSmallConfig is formula-derived: patch_size=16, num_windows=2,
    # resolution=512, PE = 512 // 16 = 32 — so a resolution override should
    # recompute PE proportionally.
    shim = _make_rfdetr_self(tmp_path)
    shim.model_config = RFDETRSmallConfig(pretrain_weights=None, num_classes=3, device="cpu")
    step = shim.model_config.patch_size * shim.model_config.num_windows
    override = step * 21  # 672 for Small — valid and different from default 512
    pe_after = override // shim.model_config.patch_size
    patch_module, patch_datamodule, patch_trainer, *_ = patch_lit
    with patch_module, patch_datamodule, patch_trainer:
        RFDETR.train(shim, resolution=override)
    assert shim.model_config.positional_encoding_size == pe_after

def test_resolution_kwarg_does_not_reach_get_train_config(self, tmp_path, patch_lit):
    """resolution kwarg is popped before get_train_config is called."""
    shim = _make_rfdetr_self(tmp_path)
    step = shim.model_config.patch_size * shim.model_config.num_windows
    patch_module, patch_datamodule, patch_trainer, *_ = patch_lit
    with patch_module, patch_datamodule, patch_trainer:
        RFDETR.train(shim, resolution=step * 10)
    forwarded = shim.get_train_config.call_args.kwargs
    assert "resolution" not in forwarded

def test_resolution_indivisible_raises_value_error(self, tmp_path, patch_lit):
    """resolution not divisible by patch_size * num_windows raises ValueError."""
    shim = _make_rfdetr_self(tmp_path)
    step = shim.model_config.patch_size * shim.model_config.num_windows
    # Off-by-one past a clean multiple can never be divisible by the block size.
    bad = step * 10 + 1
    patch_module, patch_datamodule, patch_trainer, *_ = patch_lit
    with patch_module, patch_datamodule, patch_trainer, pytest.raises(ValueError, match=f"resolution={bad}"):
        RFDETR.train(shim, resolution=bad)

def test_resolution_none_leaves_model_config_unchanged(self, tmp_path, patch_lit):
    """Omitting resolution leaves model_config.resolution unchanged."""
    shim = _make_rfdetr_self(tmp_path)
    resolution_before = shim.model_config.resolution
    patch_module, patch_datamodule, patch_trainer, *_ = patch_lit
    with patch_module, patch_datamodule, patch_trainer:
        RFDETR.train(shim)
    assert shim.model_config.resolution == resolution_before

@pytest.mark.parametrize(
    "bad_resolution",
    [
        pytest.param(0, id="zero"),
        pytest.param(-56, id="negative"),
        pytest.param(True, id="bool_true"),
        pytest.param(False, id="bool_false"),
        pytest.param(1.5, id="non_integer_float"),
        pytest.param(560.0, id="whole_number_float"),
        pytest.param("560", id="string"),
    ],
)
def test_resolution_invalid_type_or_value_raises_value_error(self, tmp_path, patch_lit, bad_resolution):
    """Non-positive, bool, or non-integer resolution raises ValueError before divisibility check."""
    shim = _make_rfdetr_self(tmp_path)
    patch_module, patch_datamodule, patch_trainer, *_ = patch_lit
    expected = pytest.raises(ValueError, match="resolution must be a positive integer")
    with patch_module, patch_datamodule, patch_trainer, expected:
        RFDETR.train(shim, resolution=bad_resolution)

def test_returns_none(self, tmp_path, patch_lit):
"""RFDETR.train() returns None."""
mock_self = _make_rfdetr_self(tmp_path)
Expand Down
Loading