diff --git a/src/rfdetr/config.py b/src/rfdetr/config.py index c45114bf..acf4a5ee 100644 --- a/src/rfdetr/config.py +++ b/src/rfdetr/config.py @@ -12,7 +12,37 @@ import torch from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator -DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu" + +def _detect_device() -> str: + """Detect the best available device **without** initialising the CUDA runtime. + + ``torch.cuda.is_available()`` creates a CUDA driver context that makes + ``_is_in_bad_fork()`` return ``True`` in child processes. This breaks + fork-based DDP strategies (e.g. ``ddp_notebook``) in notebook environments. + + We defer to :func:`torch.accelerator.current_accelerator` (PyTorch ≥ 2.4) + when available — it queries the driver through NVML without creating a + primary context. On older builds we fall back to ``torch.cuda.is_available()``. + """ + accelerator = getattr(torch, "accelerator", None) + current_accelerator = getattr(accelerator, "current_accelerator", None) + if current_accelerator is not None: + try: + accel = current_accelerator() + if accel is not None: + return str(accel) + return "cpu" + except RuntimeError: + return "cpu" + # Fallback for PyTorch < 2.4 — this DOES create a CUDA driver context. + if torch.cuda.is_available(): + return "cuda" + if torch.backends.mps.is_available(): + return "mps" + return "cpu" + + +DEVICE: str = _detect_device() class BaseConfig(BaseModel): diff --git a/src/rfdetr/detr.py b/src/rfdetr/detr.py index 19cbabf8..949489e8 100644 --- a/src/rfdetr/detr.py +++ b/src/rfdetr/detr.py @@ -182,6 +182,28 @@ def _resolve_patch_size(patch_size: int | None, model_config: object, caller: st return patch_size +def _ensure_model_on_device(model_ctx: Any) -> None: + """Move model weights to the target device recorded in *model_ctx*. + + ``_build_model_context`` intentionally keeps the ``nn.Module`` on CPU so + that ``RFDETR.__init__`` does not initialise CUDA (which would prevent DDP + strategies from forking in notebook environments). This helper performs + the deferred ``.to(device)`` on first use. + + It is safe to call on duck-typed stand-ins (e.g. ``SimpleNamespace``); the + function silently returns when the expected attributes are missing. + """ + target = getattr(model_ctx, "device", None) + inner = getattr(model_ctx, "model", None) + if target is None or inner is None or not hasattr(inner, "parameters"): + return + if isinstance(target, str): + target = torch.device(target) + first_param = next(inner.parameters(), None) + if first_param is not None and first_param.device != target: + model_ctx.model = inner.to(target) + + class RFDETR: """The base RF-DETR class implements the core methods for training RF-DETR models, running inference on the models, optimising models, and uploading trained @@ -488,6 +510,10 @@ def train(self, **kwargs): config = self.get_train_config(**kwargs) if config.batch_size == "auto": + # Auto-batch probing runs forward/backward on the actual model, which + # must be on the target device (typically CUDA). Lazy placement keeps + # the model on CPU until first use — move it now. + _ensure_model_on_device(self.model) auto_batch = resolve_auto_batch_config( model_context=self.model, model_config=self.model_config, @@ -585,6 +611,7 @@ def optimize_for_inference( # Clear any previously optimized state before starting a new optimization run. self.remove_optimized_model() + _ensure_model_on_device(self.model) device = self.model.device cuda_ctx = torch.cuda.device(device) if device.type == "cuda" else contextlib.nullcontext() @@ -1003,6 +1030,8 @@ def predict( """ import supervision as sv + _ensure_model_on_device(self.model) + patch_size = _resolve_patch_size(patch_size, self.model_config, "predict") num_windows = getattr(self.model_config, "num_windows", 1) if isinstance(num_windows, bool) or not isinstance(num_windows, int) or num_windows <= 0: diff --git a/src/rfdetr/inference.py b/src/rfdetr/inference.py index eead82be..bfad5c24 100644 --- a/src/rfdetr/inference.py +++ b/src/rfdetr/inference.py @@ -71,13 +71,19 @@ def _build_model_context(model_config: ModelConfig) -> ModelContext: """Build a ModelContext from ModelConfig without using legacy main.py:Model. Replicates ``Model.__init__`` logic: builds the nn.Module, optionally loads - pretrain weights and applies LoRA, then moves the model to the target device. + pretrain weights and applies LoRA. The model is intentionally kept on CPU; + :func:`_ensure_model_on_device` in ``detr.py`` performs the deferred + ``.to(device)`` on the first ``predict()`` / ``export()`` / + ``optimize_for_inference()`` call. Keeping construction CPU-only prevents + CUDA initialisation during ``__init__``, which would block DDP strategies + (``ddp_notebook``, ``ddp_spawn``) from spawning child processes in notebook + environments. Args: model_config: Architecture configuration. Returns: - Fully initialised ModelContext ready for inference or training. + ModelContext with the model on CPU, ready for lazy device placement. """ from rfdetr._namespace import _namespace_from_configs @@ -99,7 +105,11 @@ def _build_model_context(model_config: ModelConfig) -> ModelContext: apply_lora(nn_model) device = torch.device(args.device) - nn_model = nn_model.to(device) + # Keep the model on CPU here; predict() / export() / optimize_for_inference() + # will lazily move it to the target device on first use. Eagerly calling + # .to("cuda") would initialise the CUDA runtime during __init__(), which + # prevents DDP strategies (ddp_notebook, ddp_spawn) from forking/spawning + # child processes in notebook environments. postprocess = PostProcess(num_select=args.num_select) return ModelContext( diff --git a/src/rfdetr/training/module_data.py b/src/rfdetr/training/module_data.py index e081a4aa..f84dfe19 100644 --- a/src/rfdetr/training/module_data.py +++ b/src/rfdetr/training/module_data.py @@ -41,16 +41,25 @@ def __init__(self, model_config: ModelConfig, train_config: TrainConfig) -> None self._dataset_val: Optional[torch.utils.data.Dataset] = None self._dataset_test: Optional[torch.utils.data.Dataset] = None - num_workers = self.train_config.num_workers + self._num_workers: int = self.train_config.num_workers + + # Use the fork-safe DEVICE constant instead of torch.cuda.is_available(), + # which creates a CUDA driver context that breaks fork-based DDP. + from rfdetr.config import DEVICE + + accelerator = str(self.train_config.accelerator).lower() + uses_cuda_accelerator = accelerator in {"auto", "gpu", "cuda"} self._pin_memory: bool = ( - torch.cuda.is_available() if self.train_config.pin_memory is None else bool(self.train_config.pin_memory) + (DEVICE == "cuda" and uses_cuda_accelerator) + if self.train_config.pin_memory is None + else bool(self.train_config.pin_memory) ) self._persistent_workers: bool = ( - num_workers > 0 + self._num_workers > 0 if self.train_config.persistent_workers is None else bool(self.train_config.persistent_workers) ) - if num_workers > 0: + if self._num_workers > 0: self._prefetch_factor = ( self.train_config.prefetch_factor if self.train_config.prefetch_factor is not None else 2 ) @@ -104,7 +113,7 @@ def train_dataloader(self) -> DataLoader: dataset = self._dataset_train batch_size = self.train_config.batch_size effective_batch_size = batch_size * self.train_config.grad_accum_steps - num_workers = self.train_config.num_workers + num_workers = self._num_workers if len(dataset) < effective_batch_size * _MIN_TRAIN_BATCHES: logger.info( @@ -152,7 +161,7 @@ def val_dataloader(self) -> DataLoader: sampler=torch.utils.data.SequentialSampler(self._dataset_val), drop_last=False, collate_fn=collate_fn, - num_workers=self.train_config.num_workers, + num_workers=self._num_workers, pin_memory=self._pin_memory, persistent_workers=self._persistent_workers, prefetch_factor=self._prefetch_factor, @@ -170,7 +179,7 @@ def test_dataloader(self) -> DataLoader: sampler=torch.utils.data.SequentialSampler(self._dataset_test), drop_last=False, collate_fn=collate_fn, - num_workers=self.train_config.num_workers, + num_workers=self._num_workers, pin_memory=self._pin_memory, persistent_workers=self._persistent_workers, prefetch_factor=self._prefetch_factor, @@ -188,7 +197,7 @@ def predict_dataloader(self) -> DataLoader: sampler=torch.utils.data.SequentialSampler(self._dataset_val), drop_last=False, collate_fn=collate_fn, - num_workers=self.train_config.num_workers, + num_workers=self._num_workers, pin_memory=self._pin_memory, persistent_workers=self._persistent_workers, prefetch_factor=self._prefetch_factor, diff --git a/src/rfdetr/training/module_model.py b/src/rfdetr/training/module_model.py index 49b934b9..62420ed8 100644 --- a/src/rfdetr/training/module_model.py +++ b/src/rfdetr/training/module_model.py @@ -67,7 +67,15 @@ def __init__(self, model_config: ModelConfig, train_config: TrainConfig) -> None # torch.compile is opt-in: set model_config.compile=True to enable. # Only enabled on CUDA; MPS and CPU do not benefit from compilation. - compile_enabled = model_config.compile and torch.cuda.is_available() and not train_config.multi_scale + # Use the fork-safe DEVICE constant instead of torch.cuda.is_available(), + # which creates a CUDA driver context that breaks fork-based DDP. + from rfdetr.config import DEVICE + + accelerator = str(train_config.accelerator).lower() + uses_cuda_accelerator = accelerator in {"auto", "gpu", "cuda"} + compile_enabled = ( + model_config.compile and DEVICE == "cuda" and uses_cuda_accelerator and not train_config.multi_scale + ) if model_config.compile and train_config.multi_scale: logger.info("Disabling torch.compile because multi_scale=True introduces dynamic input shapes.") if compile_enabled: diff --git a/src/rfdetr/training/trainer.py b/src/rfdetr/training/trainer.py index b7cc9d26..6c4ad459 100644 --- a/src/rfdetr/training/trainer.py +++ b/src/rfdetr/training/trainer.py @@ -14,6 +14,15 @@ from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar, TQDMProgressBar from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme from pytorch_lightning.loggers import CSVLogger, MLFlowLogger, TensorBoardLogger, WandbLogger +from pytorch_lightning.strategies import DDPStrategy as _DDPStrategy + +# _MultiProcessingLauncher is a private PTL API (leading underscore) that may change +# in minor PTL releases within the >=2.6,<3 range. No public equivalent exists in +# PTL 2.x. Monitor PTL changelogs when bumping the lower bound. +try: + from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher +except ImportError: # pragma: no cover - exercised in unit tests via monkeypatch + _MultiProcessingLauncher = None # type: ignore[assignment] from rfdetr.config import ModelConfig, TrainConfig from rfdetr.training.callbacks import ( @@ -28,6 +37,57 @@ _logger = get_logger() +# --------------------------------------------------------------------------- +# Notebook-safe spawn-based DDP +# --------------------------------------------------------------------------- +# ``ddp_notebook`` maps to fork-based DDP which is fundamentally unsafe: +# PyTorch's OpenMP thread pool (created during model construction) cannot +# survive fork() — the worker threads become zombie handles, causing +# "Invalid thread pool!" SIGABRT when the autograd engine initialises in +# the forked child. +# +# PTL considers ``start_method="spawn"`` incompatible with interactive +# environments and raises ``MisconfigurationException`` if used in Jupyter. +# However, PTL's own ``_wrapping_function`` is the entry-point for spawned +# children — no ``if __name__ == "__main__"`` guard is required — so spawn +# is perfectly safe here. +# +# Classes MUST live at module level (not inside a function) so that Python's +# pickle can serialise them for the spawned child processes. + + +if _MultiProcessingLauncher is not None: + + class _InteractiveSpawnLauncher(_MultiProcessingLauncher): + """Spawn launcher that reports itself as interactive-compatible.""" + + @property + def is_interactive_compatible(self) -> bool: # type: ignore[override] + return True + +else: + _InteractiveSpawnLauncher = None + + +class _NotebookSpawnDDPStrategy(_DDPStrategy): + """Spawn-based DDP strategy that works inside Jupyter / Kaggle notebooks.""" + + def _configure_launcher(self) -> None: + if self.cluster_environment is None: + raise RuntimeError( + "_NotebookSpawnDDPStrategy requires a cluster environment; " + "ensure the strategy is initialised through PTL's Trainer." + ) + if _InteractiveSpawnLauncher is None: + raise RuntimeError( + "Notebook spawn strategy requires " + "pytorch_lightning.strategies.launchers.multiprocessing._MultiProcessingLauncher. " + "Your installed PyTorch Lightning version changed this private API; " + "pin/upgrade PTL to a compatible version in the supported >=2.6,<3 range." + ) + self._launcher = _InteractiveSpawnLauncher(self, start_method=self._start_method) + + def build_trainer( train_config: TrainConfig, model_config: ModelConfig, @@ -69,12 +129,21 @@ def build_trainer( def _resolve_precision() -> str: if not model_config.amp: return "32-true" + # Ampere+ GPUs support bf16-mixed which is scaler-free — + # no GradScaler.scale/unscale/update overhead per optimizer step. + # BF16 is safe for fine-tuning (pretrained weights loaded by default). + # Training from random init with very small LR may underflow; callers + # can override via trainer_kwargs(precision="16-mixed") if needed. + # + # Note: torch.cuda.is_available() and torch.cuda.is_bf16_supported() both + # create a CUDA driver context in the parent process. This is intentional + # and safe for the multi-process launch modes we rely on here because we + # avoid fork-based launching in notebook contexts (see + # _NotebookSpawnDDPStrategy above), and spawn/subprocess-based launchers + # start child processes with a fresh CUDA state regardless of what the + # parent has initialised. If a fork-based path is ever added, this + # precision check must be moved into the child process. if torch.cuda.is_available(): - # Ampere+ GPUs support bf16-mixed which is scaler-free — - # no GradScaler.scale/unscale/update overhead per optimizer step. - # BF16 is safe for fine-tuning (pretrained weights loaded by default). - # Training from random init with very small LR may underflow; callers - # can override via trainer_kwargs(precision="16-mixed") if needed. if torch.cuda.is_bf16_supported(): return "bf16-mixed" return "16-mixed" @@ -84,6 +153,15 @@ def _resolve_precision() -> str: # --- Strategy + EMA sharding guard --- strategy = tc.strategy + + # Transparently replace fork-based DDP with spawn-based DDP — see the + # module-level comment block above _InteractiveSpawnLauncher for rationale. + if strategy in ("ddp_notebook", "ddp_spawn"): + strategy = _NotebookSpawnDDPStrategy(start_method="spawn", find_unused_parameters=True) + _logger.info( + "%s → spawn-based DDP to avoid OpenMP thread pool corruption after fork.", + tc.strategy, + ) sharded = any(s in str(strategy).lower() for s in ("fsdp", "deepspeed")) enable_ema = bool(tc.use_ema) and not sharded if tc.use_ema and sharded: diff --git a/tests/models/test_config.py b/tests/models/test_config.py index 15608c32..8484db8d 100644 --- a/tests/models/test_config.py +++ b/tests/models/test_config.py @@ -4,6 +4,8 @@ # Licensed under the Apache License, Version 2.0 [see LICENSE for details] # ------------------------------------------------------------------------ +from unittest.mock import MagicMock, patch + import pytest import torch from pydantic import ValidationError @@ -19,6 +21,7 @@ RFDETRSegXLargeConfig, SegmentationTrainConfig, TrainConfig, + _detect_device, ) @@ -368,3 +371,29 @@ def test_default_cls_loss_coef_no_warning(self, recwarn) -> None: RFDETRBaseConfig(pretrain_weights=None, device="cpu") depr_warnings = [w for w in recwarn.list if issubclass(w.category, DeprecationWarning)] assert not depr_warnings, f"Unexpected DeprecationWarning: {depr_warnings}" + + +class TestDetectDevice: + """Tests for _detect_device() covering PyTorch accelerator detection paths.""" + + @patch("rfdetr.config.torch") + def test_falls_back_to_cuda_when_accelerator_module_absent(self, mock_torch: MagicMock) -> None: + """Returns 'cuda' via legacy fallback when torch.accelerator lacks current_accelerator (PyTorch < 2.4).""" + mock_torch.accelerator = MagicMock(spec=[]) # no current_accelerator → hasattr returns False → fallback + mock_torch.cuda.is_available.return_value = True + mock_torch.backends.mps.is_available.return_value = False + assert _detect_device() == "cuda" + + @patch("rfdetr.config.torch") + def test_returns_cpu_when_current_accelerator_raises(self, mock_torch: MagicMock) -> None: + """Returns 'cpu' directly from the except handler when current_accelerator() raises RuntimeError.""" + mock_torch.accelerator.current_accelerator.side_effect = RuntimeError("no device") + assert _detect_device() == "cpu" + + @patch("rfdetr.config.torch") + def test_returns_cpu_when_no_gpu_available(self, mock_torch: MagicMock) -> None: + """Returns 'cpu' when accelerator is absent and neither CUDA nor MPS is available.""" + mock_torch.accelerator = MagicMock(spec=[]) # no current_accelerator → fallback branch + mock_torch.cuda.is_available.return_value = False + mock_torch.backends.mps.is_available.return_value = False + assert _detect_device() == "cpu" diff --git a/tests/models/test_optimize_for_inference.py b/tests/models/test_optimize_for_inference.py index 5a3dd73b..f2d1090b 100644 --- a/tests/models/test_optimize_for_inference.py +++ b/tests/models/test_optimize_for_inference.py @@ -114,11 +114,20 @@ def test_string_dtype_variants_are_accepted(self, dtype_str: str) -> None: class TestOptimizeForInferenceCudaDeviceContext: """Verify that optimize_for_inference wraps operations in the correct device context.""" - def test_cuda_device_context_manager_is_used_for_cuda_device(self) -> None: + @patch("rfdetr.detr._ensure_model_on_device") + @patch("rfdetr.detr.deepcopy") + @patch("torch.cuda.device") + def test_cuda_device_context_manager_is_used_for_cuda_device( + self, + mock_cuda_device, + mock_deepcopy, + _mock_ensure_model_on_device, + ) -> None: """torch.cuda.device() context should be entered when model is on CUDA.""" rfdetr = _FakeRFDETR() # Simulate a CUDA device without actually requiring CUDA hardware rfdetr.model.device = torch.device("cuda", 0) + mock_deepcopy.return_value = rfdetr.model.model entered_devices: list[torch.device] = [] @@ -132,11 +141,8 @@ def __enter__(self): def __exit__(self, *args): pass - with ( - patch("torch.cuda.device", side_effect=_CapturingDeviceCtx), - patch("rfdetr.detr.deepcopy", return_value=rfdetr.model.model), - ): - rfdetr.optimize_for_inference(compile=False, dtype=torch.float32) + mock_cuda_device.side_effect = _CapturingDeviceCtx + rfdetr.optimize_for_inference(compile=False, dtype=torch.float32) assert len(entered_devices) == 1 assert entered_devices[0] == torch.device("cuda", 0) @@ -155,11 +161,20 @@ def test_nullcontext_used_for_cpu_device(self) -> None: mock_cuda_device.assert_not_called() - def test_cuda_device_context_uses_model_device(self) -> None: + @patch("rfdetr.detr._ensure_model_on_device") + @patch("rfdetr.detr.deepcopy") + @patch("torch.cuda.device") + def test_cuda_device_context_uses_model_device( + self, + mock_cuda_device, + mock_deepcopy, + _mock_ensure_model_on_device, + ) -> None: """The device passed to torch.cuda.device() should match self.model.device.""" rfdetr = _FakeRFDETR() expected_device = torch.device("cuda", 2) rfdetr.model.device = expected_device + mock_deepcopy.return_value = rfdetr.model.model captured: dict[str, torch.device] = {} @@ -173,11 +188,8 @@ def __enter__(self): def __exit__(self, *args): pass - with ( - patch("torch.cuda.device", side_effect=_CapturingCtx), - patch("rfdetr.detr.deepcopy", return_value=rfdetr.model.model), - ): - rfdetr.optimize_for_inference(compile=False) + mock_cuda_device.side_effect = _CapturingCtx + rfdetr.optimize_for_inference(compile=False) assert captured.get("device") == expected_device diff --git a/tests/training/test_auto_batch.py b/tests/training/test_auto_batch.py index a1e4306b..51b991c8 100644 --- a/tests/training/test_auto_batch.py +++ b/tests/training/test_auto_batch.py @@ -10,6 +10,7 @@ import pytest import torch +from rfdetr.detr import RFDETR from rfdetr.training import auto_batch from rfdetr.training.auto_batch import AutoBatchResult @@ -139,6 +140,58 @@ def test_resolve_auto_batch_config_returns_expected_values(): assert result.device_name == "Fake GPU" +@patch("rfdetr.detr.is_main_process", return_value=False) +@patch("rfdetr.training.auto_batch.resolve_auto_batch_config") +@patch("rfdetr.training.build_trainer") +@patch("rfdetr.training.RFDETRDataModule") +@patch("rfdetr.training.RFDETRModelModule") +@patch("rfdetr.detr._ensure_model_on_device") +def test_train_auto_batch_ensures_model_on_device_before_resolve( + mock_ensure: MagicMock, + _mock_module: MagicMock, + _mock_data_module: MagicMock, + _mock_build_trainer: MagicMock, + mock_resolve: MagicMock, + _mock_is_main: MagicMock, +) -> None: + """_ensure_model_on_device must be called before resolve_auto_batch_config when batch_size='auto'.""" + auto_result = SimpleNamespace(safe_micro_batch=4, recommended_grad_accum_steps=1, effective_batch_size=4) + call_order: list[str] = [] + + def _ensure_side_effect(model: object) -> None: + call_order.append("ensure") + + def _resolve_side_effect(**_kwargs: object) -> object: + call_order.append("resolve") + return auto_result + + mock_ensure.side_effect = _ensure_side_effect + mock_resolve.side_effect = _resolve_side_effect + + train_config = SimpleNamespace( + batch_size="auto", + grad_accum_steps=99, + dataset_dir=None, + resume=None, + class_names=None, + ) + mock_self = MagicMock() + mock_self.model_config = SimpleNamespace(model_name=None) + mock_self.get_train_config.return_value = train_config + + RFDETR.train(mock_self) + + assert train_config.batch_size == 4 + assert train_config.grad_accum_steps == 1 + mock_ensure.assert_called_once_with(mock_self.model) + mock_resolve.assert_called_once_with( + model_context=mock_self.model, + model_config=mock_self.model_config, + train_config=train_config, + ) + assert call_order == ["ensure", "resolve"] + + @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required for segmentation probe") def test_probe_step_with_real_segmentation_criterion(tmp_path): """Run one probe step with real segmentation model and criterion so loss_masks and t['masks'] are exercised.""" diff --git a/tests/training/test_build_trainer.py b/tests/training/test_build_trainer.py index 223f9df9..a090d041 100644 --- a/tests/training/test_build_trainer.py +++ b/tests/training/test_build_trainer.py @@ -7,6 +7,7 @@ """Tests for build_trainer() — PTL Ch3/T5 (callbacks) and Ch4/T1 (precision, loggers, trainer kwargs).""" import warnings +from unittest.mock import MagicMock, patch import pytest from pytorch_lightning.callbacks import ModelCheckpoint @@ -235,13 +236,7 @@ def test_amp_true_cpu_gives_32_true(self, tmp_path): assert trainer.precision == "32-true" def test_amp_true_cuda_no_bf16_gives_16_mixed(self, tmp_path): - """amp=True with CUDA but no bf16 support must produce '16-mixed'. - - We mock the Trainer constructor to capture the precision kwarg rather - than inspecting trainer.precision after construction: PTL may re-detect - hardware bf16 support during __init__ and normalise the precision string - on machines that happen to have a bf16-capable GPU. - """ + """amp=True with CUDA but no bf16 support must produce '16-mixed'.""" import unittest.mock as mock captured: dict = {} @@ -259,12 +254,7 @@ def _fake_trainer(**kwargs): assert captured["precision"] == "16-mixed" def test_amp_true_cuda_bf16_supported_gives_bf16_mixed(self, tmp_path): - """amp=True with CUDA + bf16 hardware produces 'bf16-mixed' (scaler-free). - - On Ampere+ GPUs (bf16 supported) we select bf16-mixed to eliminate - GradScaler overhead. Fine-tuning from pretrained weights is safe with - BF16; callers training from scratch can override via trainer_kwargs. - """ + """amp=True with CUDA + bf16 hardware produces 'bf16-mixed'.""" import unittest.mock as mock captured: dict = {} @@ -281,6 +271,80 @@ def _fake_trainer(**kwargs): build_trainer(_tc(tmp_path, use_ema=False), _mc(amp=True)) assert captured["precision"] == "bf16-mixed" + @patch("torch.cuda.is_available", return_value=True) + @patch("torch.cuda.is_bf16_supported", return_value=False) + @patch("rfdetr.training.trainer.Trainer") + def test_amp_true_ddp_notebook_probes_bf16_normally( + self, mock_trainer: MagicMock, _mock_bf16: MagicMock, _mock_cuda: MagicMock, tmp_path + ): + """ddp_notebook uses standard precision probing (spawn makes CUDA init safe). + + With spawn-based DDP, child processes start fresh — CUDA init in the + parent does not propagate. So ``is_bf16_supported()`` is safe to call + and pre-Ampere GPUs correctly get ``16-mixed`` instead of the slower + bf16 emulation path. Simulates pre-Ampere GPU: CUDA available, bf16 NOT supported. + """ + captured: dict = {} + + def _fake_trainer(**kwargs): + captured.update(kwargs) + return MagicMock() + + mock_trainer.side_effect = _fake_trainer + build_trainer( + _tc(tmp_path, use_ema=False, strategy="ddp_notebook"), + _mc(amp=True), + ) + assert captured["precision"] == "16-mixed" + + @pytest.mark.parametrize("strategy_name", ["ddp_notebook", "ddp_spawn"]) + def test_ddp_notebook_and_spawn_use_interactive_spawn(self, tmp_path, strategy_name): + """ddp_notebook and ddp_spawn must be replaced with interactive spawn DDPStrategy. + + Fork-based DDP inherits the parent's OpenMP thread pool which is + invalid after fork, causing SIGABRT in the autograd engine. + ddp_spawn is blocked by PTL in notebooks without the override. + """ + import unittest.mock as mock + + from pytorch_lightning.strategies import DDPStrategy + + captured: dict = {} + + def _fake_trainer(**kwargs): + captured.update(kwargs) + return mock.MagicMock() + + with mock.patch("rfdetr.training.trainer.Trainer", side_effect=_fake_trainer): + build_trainer( + _tc(tmp_path, use_ema=False, strategy=strategy_name), + _mc(amp=True), + ) + strategy_obj = captured["strategy"] + assert isinstance(strategy_obj, DDPStrategy) + assert strategy_obj._start_method == "spawn" + assert strategy_obj._ddp_kwargs.get("find_unused_parameters") is True + + @patch("rfdetr.training.trainer._InteractiveSpawnLauncher", None) + def test_ddp_notebook_raises_clear_error_when_private_launcher_is_missing(self, tmp_path): + """Missing private PTL launcher should raise a targeted compatibility error.""" + captured: dict = {} + + def _fake_trainer(**kwargs): + captured.update(kwargs) + return MagicMock() + + with patch("rfdetr.training.trainer.Trainer", side_effect=_fake_trainer): + build_trainer( + _tc(tmp_path, use_ema=False, strategy="ddp_notebook"), + _mc(amp=True), + ) + + strategy = captured["strategy"] + strategy.cluster_environment = object() + with pytest.raises(RuntimeError, match="private API"): + strategy._configure_launcher() + class TestBuildTrainerEMAShardingGuard: """EMA must be disabled and a UserWarning emitted for sharded strategies. diff --git a/tests/training/test_module_data.py b/tests/training/test_module_data.py index 4b4198b4..d2a95463 100644 --- a/tests/training/test_module_data.py +++ b/tests/training/test_module_data.py @@ -178,12 +178,35 @@ def test_pin_memory_override_is_respected(self, build_datamodule, base_train_con dm = build_datamodule(train_config=tc) assert dm._pin_memory is False + @patch("rfdetr.config.DEVICE", "cuda") + def test_pin_memory_defaults_to_false_when_accelerator_is_cpu(self, build_datamodule, base_train_config): + """Default pin_memory stays off when training is explicitly CPU-only.""" + tc = base_train_config(pin_memory=None, accelerator="cpu") + dm = build_datamodule(train_config=tc) + assert dm._pin_memory is False + def test_persistent_workers_override_is_respected(self, build_datamodule, base_train_config): """persistent_workers can be explicitly overridden from TrainConfig.""" tc = base_train_config(num_workers=2, persistent_workers=False) dm = build_datamodule(train_config=tc) assert dm._persistent_workers is False + def test_ddp_notebook_preserves_num_workers(self, build_datamodule, base_train_config): + """ddp_notebook keeps num_workers as configured (spawn-based DDP + children initialise CUDA fresh; DataLoader fork workers are CPU-only + and never touch CUDA, so nested forks are safe).""" + tc = base_train_config(num_workers=4, strategy="ddp_notebook") + dm = build_datamodule(train_config=tc) + assert dm._num_workers == 4 + assert dm._prefetch_factor == 2 + + def test_other_strategy_preserves_num_workers(self, build_datamodule, base_train_config): + """Non-ddp_notebook strategies also keep num_workers as configured.""" + tc = base_train_config(num_workers=4, strategy="ddp") + dm = build_datamodule(train_config=tc) + assert dm._num_workers == 4 + assert dm._prefetch_factor == 2 # default prefetch_factor for num_workers>0 + class TestSetup: """setup(stage) builds the correct dataset(s) for each PTL stage.""" diff --git a/tests/training/test_module_model.py b/tests/training/test_module_model.py index 26c8b119..6a07db70 100644 --- a/tests/training/test_module_model.py +++ b/tests/training/test_module_model.py @@ -185,12 +185,21 @@ def test_compile_runs_when_enabled_and_static_shapes(self, tmp_path): mc = _base_model_config(compile=True) tc = _base_train_config(tmp_path, multi_scale=False) with ( - patch("torch.cuda.is_available", return_value=True), + patch("rfdetr.config.DEVICE", "cuda"), patch("rfdetr.training.module_model.torch.compile", side_effect=lambda m, **_: m) as mock_compile, ): _build_module(model_config=mc, train_config=tc, tmp_path=tmp_path) mock_compile.assert_called_once() + @patch("rfdetr.training.module_model.torch.compile") + @patch("rfdetr.config.DEVICE", "cuda") + def test_compile_disabled_when_train_accelerator_is_cpu(self, _mock_compile: MagicMock, tmp_path): + """compile stays disabled when training is explicitly forced to CPU.""" + mc = _base_model_config(compile=True) + tc = _base_train_config(tmp_path, multi_scale=False, accelerator="cpu") + _build_module(model_config=mc, train_config=tc, tmp_path=tmp_path) + _mock_compile.assert_not_called() + class TestLoadPretrainWeights: """Tests for _load_pretrain_weights() — covers checkpoint validation, detection-head