Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 30 additions & 3 deletions src/rfdetr/detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -412,9 +412,18 @@ def train(self, **kwargs):
"""Train an RF-DETR model via the PyTorch Lightning stack.

All keyword arguments are forwarded to :meth:`get_train_config` to build
a :class:`~rfdetr.config.TrainConfig`. Several legacy kwargs are absorbed
so existing call-sites do not break:

a :class:`~rfdetr.config.TrainConfig`. Several kwargs are absorbed and
handled specially so that existing call-sites do not break:

* ``resolution`` — overrides the model's input resolution for this
training run. The value is applied directly to
:attr:`model_config.resolution` (and
:attr:`model_config.positional_encoding_size` is updated accordingly)
before the train config is built. The resolution must be divisible
by ``patch_size × num_windows`` for the model variant; a
:class:`ValueError` is raised otherwise. Use this kwarg, rather than
setting the value at construction time, to change the training
resolution.
* ``device`` — normalized via :class:`torch.device` and mapped to PyTorch
Lightning trainer arguments. ``"cpu"`` becomes ``accelerator="cpu"``;
``"cuda"`` and ``"cuda:N"`` become ``accelerator="gpu"`` and optionally
Expand All @@ -435,6 +444,8 @@ def train(self, **kwargs):
Raises:
ImportError: If training dependencies are not installed. Install with
``pip install "rfdetr[train,loggers]"``.
ValueError: If ``resolution`` is not divisible by
``patch_size × num_windows`` for the model variant.

"""
# Both imports are grouped in a single try block because they both live in
Expand Down Expand Up @@ -486,6 +497,22 @@ def train(self, **kwargs):
stacklevel=2,
)

# Apply resolution override to model_config before building the train config.
# resolution is a ModelConfig field, not a TrainConfig field, so we pop it
# here to avoid it being silently ignored by TrainConfig.
_resolution = kwargs.pop("resolution", None)
if _resolution is not None:
block_size = self.model_config.patch_size * self.model_config.num_windows
if _resolution % block_size != 0:
raise ValueError(
f"resolution={_resolution} is not divisible by "
f"patch_size ({self.model_config.patch_size}) × num_windows "
f"({self.model_config.num_windows}) = {block_size}. "
f"Choose a resolution that is a multiple of {block_size}."
)
self.model_config.resolution = _resolution
self.model_config.positional_encoding_size = _resolution // self.model_config.patch_size

config = self.get_train_config(**kwargs)
if config.batch_size == "auto":
auto_batch = resolve_auto_batch_config(
Expand Down
46 changes: 46 additions & 0 deletions tests/training/test_detr_shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,6 +527,52 @@ def test_do_benchmark_true_emits_deprecation_warning(self, tmp_path, patch_lit):
depr = [x for x in w if issubclass(x.category, DeprecationWarning)]
assert any("do_benchmark" in str(d.message) or "rfdetr benchmark" in str(d.message) for d in depr)

def test_resolution_kwarg_updates_model_config_resolution(self, tmp_path, patch_lit):
    """Passing resolution=... must land on model_config.resolution before training starts."""
    shim = _make_rfdetr_self(tmp_path)
    # For RFDETRBaseConfig, patch_size (14) * num_windows (4) gives a block
    # size of 56, and 560 is an exact multiple of it — a valid override.
    mod_patch, dm_patch, bt_patch = patch_lit[0], patch_lit[1], patch_lit[2]
    with mod_patch, dm_patch, bt_patch:
        RFDETR.train(shim, resolution=560)
    assert shim.model_config.resolution == 560

def test_resolution_kwarg_updates_positional_encoding_size(self, tmp_path, patch_lit):
    """Overriding resolution recomputes positional_encoding_size as resolution // patch_size."""
    shim = _make_rfdetr_self(tmp_path)
    # Base variant: block size 14 * 4 == 56 divides 560, so the override is
    # accepted; the encoding grid should then become 560 // 14 == 40.
    mod_patch, dm_patch, bt_patch = patch_lit[0], patch_lit[1], patch_lit[2]
    with mod_patch, dm_patch, bt_patch:
        RFDETR.train(shim, resolution=560)
    expected = 560 // shim.model_config.patch_size
    assert shim.model_config.positional_encoding_size == expected

def test_resolution_kwarg_does_not_reach_get_train_config(self, tmp_path, patch_lit):
    """train() must pop the resolution kwarg so get_train_config never receives it."""
    shim = _make_rfdetr_self(tmp_path)
    mod_patch, dm_patch, bt_patch = patch_lit[0], patch_lit[1], patch_lit[2]
    with mod_patch, dm_patch, bt_patch:
        RFDETR.train(shim, resolution=560)
    forwarded = shim.get_train_config.call_args.kwargs
    assert "resolution" not in forwarded

def test_resolution_indivisible_raises_value_error(self, tmp_path, patch_lit):
    """An override not divisible by patch_size × num_windows is rejected with ValueError."""
    shim = _make_rfdetr_self(tmp_path)
    # Base variant block size is 14 * 4 == 56; 570 leaves a remainder, so
    # train() must refuse it before any training machinery runs.
    mod_patch, dm_patch, bt_patch = patch_lit[0], patch_lit[1], patch_lit[2]
    with mod_patch, dm_patch, bt_patch, pytest.raises(ValueError, match="resolution=570"):
        RFDETR.train(shim, resolution=570)

def test_resolution_none_leaves_model_config_unchanged(self, tmp_path, patch_lit):
    """Without a resolution kwarg, model_config.resolution is left untouched."""
    shim = _make_rfdetr_self(tmp_path)
    before = shim.model_config.resolution
    mod_patch, dm_patch, bt_patch = patch_lit[0], patch_lit[1], patch_lit[2]
    with mod_patch, dm_patch, bt_patch:
        RFDETR.train(shim)
    assert shim.model_config.resolution == before

def test_returns_none(self, tmp_path, patch_lit):
"""RFDETR.train() returns None."""
mock_self = _make_rfdetr_self(tmp_path)
Expand Down