Skip to content
Merged
Show file tree
Hide file tree
Changes from 32 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
2f814fd
fix: defer CUDA init to enable DDP training in notebooks
mfazrinizar Apr 5, 2026
9814557
fix: skip CUDA bf16 probe for ddp_notebook strategy
mfazrinizar Apr 6, 2026
9104d43
fix: eliminate all CUDA driver context leaks before DDP fork
mfazrinizar Apr 6, 2026
34ab21b
fix: use overridden num_workers in all dataloaders for ddp_notebook
mfazrinizar Apr 6, 2026
6ce6ad0
fix: possible thread-state corruption from fork()
mfazrinizar Apr 6, 2026
ed99190
revert: remove torch.set_num_threads that crashes forked DDP children
mfazrinizar Apr 6, 2026
a31dcb2
fix: use spawn-based DDP for ddp_notebook to avoid OpenMP SIGABRT
mfazrinizar Apr 6, 2026
728c1e5
fix: adding logger for ddp_notebook strategy
mfazrinizar Apr 6, 2026
a464cf2
fix: use spawn-based DDP for ddp_notebook to avoid OpenMP SIGABRT
mfazrinizar Apr 6, 2026
08af3c5
fix: remove unnecessary num_workers=0 override for ddp_notebook
mfazrinizar Apr 6, 2026
bcdfd0a
fix: use standard precision probing for DDP and guard auto-batch
mfazrinizar Apr 6, 2026
c4c88f2
Merge branch 'develop' into fix/ddp-notebook-cuda-init
mfazrinizar Apr 6, 2026
e7a84d0
fix(pre-commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] Apr 6, 2026
e67ac24
style: fix ruff E402 imports and codespell in DDP tests
mfazrinizar Apr 6, 2026
1927ca5
fix: handle None from torch.accelerator on CPU-only environments
mfazrinizar Apr 7, 2026
d798aed
fix: guard torch.accelerator access before current_accelerator check
Borda Apr 8, 2026
ea8eddf
fix: replace assert with RuntimeError in _NotebookSpawnDDPStrategy
Borda Apr 8, 2026
ef80e40
Apply suggestions from code review
Borda Apr 8, 2026
28582bf
fix(pre-commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] Apr 8, 2026
680b308
Merge branch 'fix/ddp-notebook-cuda-init' of https://github.com/mfazr…
Borda Apr 8, 2026
10b35fc
docs: note private PTL launcher API risk in trainer.py
Borda Apr 8, 2026
9711ab9
docs: update _build_model_context docstring for lazy device placement
Borda Apr 8, 2026
8602d95
fix: add Any type annotation to _ensure_model_on_device parameter
Borda Apr 8, 2026
3359d21
docs: explain why CUDA calls in _resolve_precision are safe with spaw…
Borda Apr 8, 2026
a16ba34
docs: consolidate duplicated OMP fork explanation in trainer.py
Borda Apr 8, 2026
d309dbc
test: add coverage for _ensure_model_on_device auto-batch path + _det…
Borda Apr 8, 2026
16023ab
lint: fix import ordering in test_config.py (I001)
Borda Apr 8, 2026
31a6fb7
Merge branch 'fix/ddp-notebook-cuda-init' of https://github.com/mfazr…
Borda Apr 8, 2026
c74e181
Apply suggestions from code review
Borda Apr 8, 2026
42fd15c
fix(pre-commit): 🎨 auto format pre-commit hooks
pre-commit-ci[bot] Apr 8, 2026
908dfef
refactor(tests): convert mocking to @patch decorator style
Borda Apr 8, 2026
903e962
refactor(tests): convert test_amp_true_ddp_notebook_probes_bf16_norma…
Borda Apr 8, 2026
73a073d
Apply suggestions from code review
Borda Apr 8, 2026
3307018
fix: address PR #928 unresolved reviews and CPU CI failure
Borda Apr 8, 2026
b4e82e4
Merge branch 'develop' into fix/ddp-notebook-cuda-init
Borda Apr 8, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 31 additions & 1 deletion src/rfdetr/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,37 @@
import torch
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

def _detect_device() -> str:
"""Detect the best available device **without** initialising the CUDA runtime.

``torch.cuda.is_available()`` creates a CUDA driver context that makes
``_is_in_bad_fork()`` return ``True`` in child processes. This breaks
fork-based DDP strategies (e.g. ``ddp_notebook``) in notebook environments.

We defer to :func:`torch.accelerator.current_accelerator` (PyTorch ≥ 2.4)
when available — it queries the driver through NVML without creating a
primary context. On older builds we fall back to ``torch.cuda.is_available()``.
"""
accelerator = getattr(torch, "accelerator", None)
current_accelerator = getattr(accelerator, "current_accelerator", None)
if current_accelerator is not None:
try:
accel = current_accelerator()
if accel is not None:
return str(accel)
return "cpu"
except RuntimeError:
return "cpu"
# Fallback for PyTorch < 2.4 — this DOES create a CUDA driver context.
if torch.cuda.is_available():
return "cuda"
if torch.backends.mps.is_available():
return "mps"
return "cpu"


DEVICE: str = _detect_device()


class BaseConfig(BaseModel):
Expand Down
30 changes: 30 additions & 0 deletions src/rfdetr/detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,28 @@ def _resolve_patch_size(patch_size: int | None, model_config: object, caller: st
return patch_size


def _ensure_model_on_device(model_ctx: Any) -> None:
"""Move model weights to the target device recorded in *model_ctx*.

``_build_model_context`` intentionally keeps the ``nn.Module`` on CPU so
that ``RFDETR.__init__`` does not initialise CUDA (which would prevent DDP
strategies from forking in notebook environments). This helper performs
the deferred ``.to(device)`` on first use.

It is safe to call on duck-typed stand-ins (e.g. ``SimpleNamespace``); the
function silently returns when the expected attributes are missing.
"""
target = getattr(model_ctx, "device", None)
inner = getattr(model_ctx, "model", None)
if target is None or inner is None or not hasattr(inner, "parameters"):
return
if isinstance(target, str):
target = torch.device(target)
first_param = next(inner.parameters(), None)
if first_param is not None and first_param.device != target:
model_ctx.model = inner.to(target)


class RFDETR:
"""The base RF-DETR class implements the core methods for training RF-DETR models,
running inference on the models, optimising models, and uploading trained
Expand Down Expand Up @@ -488,6 +510,10 @@ def train(self, **kwargs):

config = self.get_train_config(**kwargs)
if config.batch_size == "auto":
# Auto-batch probing runs forward/backward on the actual model, which
# must be on the target device (typically CUDA). Lazy placement keeps
# the model on CPU until first use — move it now.
_ensure_model_on_device(self.model)
auto_batch = resolve_auto_batch_config(
model_context=self.model,
model_config=self.model_config,
Expand Down Expand Up @@ -585,6 +611,7 @@ def optimize_for_inference(
# Clear any previously optimized state before starting a new optimization run.
self.remove_optimized_model()

_ensure_model_on_device(self.model)
device = self.model.device
cuda_ctx = torch.cuda.device(device) if device.type == "cuda" else contextlib.nullcontext()

Expand Down Expand Up @@ -700,6 +727,7 @@ def export(
)
raise

_ensure_model_on_device(self.model)
device = self.model.device
model = deepcopy(self.model.model.to("cpu"))
model.to(device)
Expand Down Expand Up @@ -1003,6 +1031,8 @@ def predict(
"""
import supervision as sv

_ensure_model_on_device(self.model)

patch_size = _resolve_patch_size(patch_size, self.model_config, "predict")
num_windows = getattr(self.model_config, "num_windows", 1)
if isinstance(num_windows, bool) or not isinstance(num_windows, int) or num_windows <= 0:
Expand Down
16 changes: 13 additions & 3 deletions src/rfdetr/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,19 @@ def _build_model_context(model_config: ModelConfig) -> ModelContext:
"""Build a ModelContext from ModelConfig without using legacy main.py:Model.

Replicates ``Model.__init__`` logic: builds the nn.Module, optionally loads
pretrain weights and applies LoRA, then moves the model to the target device.
pretrain weights and applies LoRA. The model is intentionally kept on CPU;
:func:`_ensure_model_on_device` in ``detr.py`` performs the deferred
``.to(device)`` on the first ``predict()`` / ``export()`` /
``optimize_for_inference()`` call. Keeping construction CPU-only prevents
CUDA initialisation during ``__init__``, which would block DDP strategies
(``ddp_notebook``, ``ddp_spawn``) from spawning child processes in notebook
environments.

Args:
model_config: Architecture configuration.

Returns:
Fully initialised ModelContext ready for inference or training.
ModelContext with the model on CPU, ready for lazy device placement.
"""
from rfdetr._namespace import _namespace_from_configs

Expand All @@ -99,7 +105,11 @@ def _build_model_context(model_config: ModelConfig) -> ModelContext:
apply_lora(nn_model)

device = torch.device(args.device)
nn_model = nn_model.to(device)
# Keep the model on CPU here; predict() / export() / optimize_for_inference()
# will lazily move it to the target device on first use. Eagerly calling
# .to("cuda") would initialise the CUDA runtime during __init__(), which
# prevents DDP strategies (ddp_notebook, ddp_spawn) from forking/spawning
# child processes in notebook environments.
postprocess = PostProcess(num_select=args.num_select)

return ModelContext(
Expand Down
21 changes: 13 additions & 8 deletions src/rfdetr/training/module_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,21 @@ def __init__(self, model_config: ModelConfig, train_config: TrainConfig) -> None
self._dataset_val: Optional[torch.utils.data.Dataset] = None
self._dataset_test: Optional[torch.utils.data.Dataset] = None

num_workers = self.train_config.num_workers
self._num_workers: int = self.train_config.num_workers

# Use the fork-safe DEVICE constant instead of torch.cuda.is_available(),
# which creates a CUDA driver context that breaks fork-based DDP.
from rfdetr.config import DEVICE

self._pin_memory: bool = (
torch.cuda.is_available() if self.train_config.pin_memory is None else bool(self.train_config.pin_memory)
(DEVICE == "cuda") if self.train_config.pin_memory is None else bool(self.train_config.pin_memory)
)
self._persistent_workers: bool = (
num_workers > 0
self._num_workers > 0
if self.train_config.persistent_workers is None
else bool(self.train_config.persistent_workers)
)
if num_workers > 0:
if self._num_workers > 0:
self._prefetch_factor = (
self.train_config.prefetch_factor if self.train_config.prefetch_factor is not None else 2
)
Expand Down Expand Up @@ -104,7 +109,7 @@ def train_dataloader(self) -> DataLoader:
dataset = self._dataset_train
batch_size = self.train_config.batch_size
effective_batch_size = batch_size * self.train_config.grad_accum_steps
num_workers = self.train_config.num_workers
num_workers = self._num_workers

if len(dataset) < effective_batch_size * _MIN_TRAIN_BATCHES:
logger.info(
Expand Down Expand Up @@ -152,7 +157,7 @@ def val_dataloader(self) -> DataLoader:
sampler=torch.utils.data.SequentialSampler(self._dataset_val),
drop_last=False,
collate_fn=collate_fn,
num_workers=self.train_config.num_workers,
num_workers=self._num_workers,
pin_memory=self._pin_memory,
persistent_workers=self._persistent_workers,
prefetch_factor=self._prefetch_factor,
Expand All @@ -170,7 +175,7 @@ def test_dataloader(self) -> DataLoader:
sampler=torch.utils.data.SequentialSampler(self._dataset_test),
drop_last=False,
collate_fn=collate_fn,
num_workers=self.train_config.num_workers,
num_workers=self._num_workers,
pin_memory=self._pin_memory,
persistent_workers=self._persistent_workers,
prefetch_factor=self._prefetch_factor,
Expand All @@ -188,7 +193,7 @@ def predict_dataloader(self) -> DataLoader:
sampler=torch.utils.data.SequentialSampler(self._dataset_val),
drop_last=False,
collate_fn=collate_fn,
num_workers=self.train_config.num_workers,
num_workers=self._num_workers,
pin_memory=self._pin_memory,
persistent_workers=self._persistent_workers,
prefetch_factor=self._prefetch_factor,
Expand Down
6 changes: 5 additions & 1 deletion src/rfdetr/training/module_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ def __init__(self, model_config: ModelConfig, train_config: TrainConfig) -> None

# torch.compile is opt-in: set model_config.compile=True to enable.
# Only enabled on CUDA; MPS and CPU do not benefit from compilation.
compile_enabled = model_config.compile and torch.cuda.is_available() and not train_config.multi_scale
# Use the fork-safe DEVICE constant instead of torch.cuda.is_available(),
# which creates a CUDA driver context that breaks fork-based DDP.
from rfdetr.config import DEVICE

compile_enabled = model_config.compile and DEVICE == "cuda" and not train_config.multi_scale
if model_config.compile and train_config.multi_scale:
logger.info("Disabling torch.compile because multi_scale=True introduces dynamic input shapes.")
if compile_enabled:
Expand Down
71 changes: 66 additions & 5 deletions src/rfdetr/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar, TQDMProgressBar
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
from pytorch_lightning.loggers import CSVLogger, MLFlowLogger, TensorBoardLogger, WandbLogger
from pytorch_lightning.strategies import DDPStrategy as _DDPStrategy

# _MultiProcessingLauncher is a private PTL API (leading underscore) that may change
# in minor PTL releases within the >=2.6,<3 range. No public equivalent exists in
# PTL 2.x. Monitor PTL changelogs when bumping the lower bound.
from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher

from rfdetr.config import ModelConfig, TrainConfig
from rfdetr.training.callbacks import (
Expand All @@ -28,6 +34,45 @@
_logger = get_logger()


# ---------------------------------------------------------------------------
# Notebook-safe spawn-based DDP
# ---------------------------------------------------------------------------
# ``ddp_notebook`` maps to fork-based DDP which is fundamentally unsafe:
# PyTorch's OpenMP thread pool (created during model construction) cannot
# survive fork() — the worker threads become zombie handles, causing
# "Invalid thread pool!" SIGABRT when the autograd engine initialises in
# the forked child.
#
# PTL considers ``start_method="spawn"`` incompatible with interactive
# environments and raises ``MisconfigurationException`` if used in Jupyter.
# However, PTL's own ``_wrapping_function`` is the entry-point for spawned
# children — no ``if __name__ == "__main__"`` guard is required — so spawn
# is perfectly safe here.
#
# Classes MUST live at module level (not inside a function) so that Python's
# pickle can serialise them for the spawned child processes.


class _InteractiveSpawnLauncher(_MultiProcessingLauncher):
    """Spawn launcher that reports itself as interactive-compatible."""

    # PTL gates spawn-based launching in notebooks on this property (see the
    # module-level comment above: PTL raises MisconfigurationException for
    # spawn in Jupyter).  Forcing it to True lifts that guard; spawn is safe
    # here because PTL's own ``_wrapping_function`` is the child entry-point,
    # so no ``if __name__ == "__main__"`` guard is needed.
    @property
    def is_interactive_compatible(self) -> bool:  # type: ignore[override]
        return True


class _NotebookSpawnDDPStrategy(_DDPStrategy):
    """Spawn-based DDP strategy that works inside Jupyter / Kaggle notebooks."""

    def _configure_launcher(self) -> None:
        # ``cluster_environment`` is normally injected by PTL's Trainer during
        # strategy setup; raise eagerly (not assert — asserts vanish under -O)
        # if the strategy was constructed outside that flow.
        if self.cluster_environment is None:
            raise RuntimeError(
                "_NotebookSpawnDDPStrategy requires a cluster environment; "
                "ensure the strategy is initialised through PTL's Trainer."
            )
        # Swap in the interactive-compatible launcher so spawn is permitted in
        # notebooks; ``_start_method`` comes from the ``start_method="spawn"``
        # constructor argument used at the call site in build_trainer.
        self._launcher = _InteractiveSpawnLauncher(self, start_method=self._start_method)


def build_trainer(
train_config: TrainConfig,
model_config: ModelConfig,
Expand Down Expand Up @@ -69,12 +114,19 @@ def build_trainer(
def _resolve_precision() -> str:
if not model_config.amp:
return "32-true"
# Ampere+ GPUs support bf16-mixed which is scaler-free —
# no GradScaler.scale/unscale/update overhead per optimizer step.
# BF16 is safe for fine-tuning (pretrained weights loaded by default).
# Training from random init with very small LR may underflow; callers
# can override via trainer_kwargs(precision="16-mixed") if needed.
#
# Note: torch.cuda.is_available() and torch.cuda.is_bf16_supported() both
# create a CUDA driver context in the parent process. This is intentional
# and safe: all DDP paths use spawn-based strategies (see _NotebookSpawnDDPStrategy
# above) so spawned children start with a fresh CUDA state regardless of
# what the parent has initialised. If a fork-based path is ever added,
# this precision check must be moved into the child process.
if torch.cuda.is_available():
# Ampere+ GPUs support bf16-mixed which is scaler-free —
# no GradScaler.scale/unscale/update overhead per optimizer step.
# BF16 is safe for fine-tuning (pretrained weights loaded by default).
# Training from random init with very small LR may underflow; callers
# can override via trainer_kwargs(precision="16-mixed") if needed.
if torch.cuda.is_bf16_supported():
return "bf16-mixed"
return "16-mixed"
Expand All @@ -84,6 +136,15 @@ def _resolve_precision() -> str:

# --- Strategy + EMA sharding guard ---
strategy = tc.strategy

# Transparently replace fork-based DDP with spawn-based DDP — see the
# module-level comment block above _InteractiveSpawnLauncher for rationale.
if strategy in ("ddp_notebook", "ddp_spawn"):
strategy = _NotebookSpawnDDPStrategy(start_method="spawn", find_unused_parameters=True)
_logger.info(
"%s → spawn-based DDP to avoid OpenMP thread pool corruption after fork.",
tc.strategy,
)
sharded = any(s in str(strategy).lower() for s in ("fsdp", "deepspeed"))
enable_ema = bool(tc.use_ema) and not sharded
if tc.use_ema and sharded:
Expand Down
29 changes: 29 additions & 0 deletions tests/models/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

from unittest.mock import MagicMock, patch

import pytest
import torch
from pydantic import ValidationError
Expand All @@ -19,6 +21,7 @@
RFDETRSegXLargeConfig,
SegmentationTrainConfig,
TrainConfig,
_detect_device,
)


Expand Down Expand Up @@ -368,3 +371,29 @@ def test_default_cls_loss_coef_no_warning(self, recwarn) -> None:
RFDETRBaseConfig(pretrain_weights=None, device="cpu")
depr_warnings = [w for w in recwarn.list if issubclass(w.category, DeprecationWarning)]
assert not depr_warnings, f"Unexpected DeprecationWarning: {depr_warnings}"


class TestDetectDevice:
    """Tests for _detect_device() covering PyTorch accelerator detection paths."""

    @patch("rfdetr.config.torch")
    def test_falls_back_to_cuda_when_accelerator_module_absent(self, mock_torch: MagicMock) -> None:
        """Legacy fallback returns 'cuda' when torch.accelerator lacks current_accelerator."""
        # spec=[] makes attribute access raise AttributeError, so the getattr
        # chain in _detect_device() yields None and the fallback branch runs.
        mock_torch.accelerator = MagicMock(spec=[])
        mock_torch.backends.mps.is_available.return_value = False
        mock_torch.cuda.is_available.return_value = True
        assert _detect_device() == "cuda"

    @patch("rfdetr.config.torch")
    def test_returns_cpu_when_current_accelerator_raises(self, mock_torch: MagicMock) -> None:
        """The RuntimeError handler returns 'cpu' when current_accelerator() fails."""
        mock_torch.accelerator.current_accelerator.side_effect = RuntimeError("no device")
        assert _detect_device() == "cpu"

    @patch("rfdetr.config.torch")
    def test_returns_cpu_when_no_gpu_available(self, mock_torch: MagicMock) -> None:
        """With no accelerator module and neither CUDA nor MPS, 'cpu' is returned."""
        mock_torch.accelerator = MagicMock(spec=[])  # force the legacy fallback branch
        mock_torch.backends.mps.is_available.return_value = False
        mock_torch.cuda.is_available.return_value = False
        assert _detect_device() == "cpu"
Loading
Loading