Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion src/rfdetr/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,35 @@
import torch
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator

DEVICE = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

def _detect_device() -> str:
"""Detect the best available device **without** initialising the CUDA runtime.

``torch.cuda.is_available()`` creates a CUDA driver context that makes
``_is_in_bad_fork()`` return ``True`` in child processes. This breaks
fork-based DDP strategies (e.g. ``ddp_notebook``) in notebook environments.

We defer to :func:`torch.accelerator.current_accelerator` (PyTorch ≥ 2.4)
when available — it queries the driver through NVML without creating a
primary context. On older builds we fall back to ``torch.cuda.is_available()``.
"""
if hasattr(torch.accelerator, "current_accelerator"):
try:
accel = torch.accelerator.current_accelerator()
if accel is not None:
return str(accel)
return "cpu"
except RuntimeError:
return "cpu"
# Fallback for PyTorch < 2.4 — this DOES create a CUDA driver context.
if torch.cuda.is_available():
return "cuda"
if torch.backends.mps.is_available():
return "mps"
return "cpu"


DEVICE: str = _detect_device()


class BaseConfig(BaseModel):
Expand Down
30 changes: 30 additions & 0 deletions src/rfdetr/detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,28 @@ def _resolve_patch_size(patch_size: int | None, model_config: object, caller: st
return patch_size


def _ensure_model_on_device(model_ctx) -> None:
"""Move model weights to the target device recorded in *model_ctx*.

``_build_model_context`` intentionally keeps the ``nn.Module`` on CPU so
that ``RFDETR.__init__`` does not initialise CUDA (which would prevent DDP
strategies from forking in notebook environments). This helper performs
the deferred ``.to(device)`` on first use.

It is safe to call on duck-typed stand-ins (e.g. ``SimpleNamespace``); the
function silently returns when the expected attributes are missing.
"""
target = getattr(model_ctx, "device", None)
inner = getattr(model_ctx, "model", None)
if target is None or inner is None or not hasattr(inner, "parameters"):
return
if isinstance(target, str):
target = torch.device(target)
first_param = next(inner.parameters(), None)
if first_param is not None and first_param.device != target:
model_ctx.model = inner.to(target)


class RFDETR:
"""The base RF-DETR class implements the core methods for training RF-DETR models,
running inference on the models, optimising models, and uploading trained
Expand Down Expand Up @@ -488,6 +510,10 @@ def train(self, **kwargs):

config = self.get_train_config(**kwargs)
if config.batch_size == "auto":
# Auto-batch probing runs forward/backward on the actual model, which
# must be on the target device (typically CUDA). Lazy placement keeps
# the model on CPU until first use — move it now.
_ensure_model_on_device(self.model)
auto_batch = resolve_auto_batch_config(
model_context=self.model,
model_config=self.model_config,
Expand Down Expand Up @@ -585,6 +611,7 @@ def optimize_for_inference(
# Clear any previously optimized state before starting a new optimization run.
self.remove_optimized_model()

_ensure_model_on_device(self.model)
device = self.model.device
cuda_ctx = torch.cuda.device(device) if device.type == "cuda" else contextlib.nullcontext()

Expand Down Expand Up @@ -700,6 +727,7 @@ def export(
)
raise

_ensure_model_on_device(self.model)
device = self.model.device
model = deepcopy(self.model.model.to("cpu"))
model.to(device)
Expand Down Expand Up @@ -1003,6 +1031,8 @@ def predict(
"""
import supervision as sv

_ensure_model_on_device(self.model)

patch_size = _resolve_patch_size(patch_size, self.model_config, "predict")
num_windows = getattr(self.model_config, "num_windows", 1)
if isinstance(num_windows, bool) or not isinstance(num_windows, int) or num_windows <= 0:
Expand Down
6 changes: 5 additions & 1 deletion src/rfdetr/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,11 @@ def _build_model_context(model_config: ModelConfig) -> ModelContext:
apply_lora(nn_model)

device = torch.device(args.device)
nn_model = nn_model.to(device)
# Keep the model on CPU here; predict() / export() / optimize_for_inference()
# will lazily move it to the target device on first use. Eagerly calling
# .to("cuda") would initialise the CUDA runtime during __init__(), which
# prevents DDP strategies (ddp_notebook, ddp_spawn) from forking/spawning
# child processes in notebook environments.
postprocess = PostProcess(num_select=args.num_select)

return ModelContext(
Expand Down
21 changes: 13 additions & 8 deletions src/rfdetr/training/module_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,16 +41,21 @@ def __init__(self, model_config: ModelConfig, train_config: TrainConfig) -> None
self._dataset_val: Optional[torch.utils.data.Dataset] = None
self._dataset_test: Optional[torch.utils.data.Dataset] = None

num_workers = self.train_config.num_workers
self._num_workers: int = self.train_config.num_workers

# Use the fork-safe DEVICE constant instead of torch.cuda.is_available(),
# which creates a CUDA driver context that breaks fork-based DDP.
from rfdetr.config import DEVICE

self._pin_memory: bool = (
torch.cuda.is_available() if self.train_config.pin_memory is None else bool(self.train_config.pin_memory)
(DEVICE == "cuda") if self.train_config.pin_memory is None else bool(self.train_config.pin_memory)
)
self._persistent_workers: bool = (
num_workers > 0
self._num_workers > 0
if self.train_config.persistent_workers is None
else bool(self.train_config.persistent_workers)
)
if num_workers > 0:
if self._num_workers > 0:
self._prefetch_factor = (
self.train_config.prefetch_factor if self.train_config.prefetch_factor is not None else 2
)
Expand Down Expand Up @@ -104,7 +109,7 @@ def train_dataloader(self) -> DataLoader:
dataset = self._dataset_train
batch_size = self.train_config.batch_size
effective_batch_size = batch_size * self.train_config.grad_accum_steps
num_workers = self.train_config.num_workers
num_workers = self._num_workers

if len(dataset) < effective_batch_size * _MIN_TRAIN_BATCHES:
logger.info(
Expand Down Expand Up @@ -152,7 +157,7 @@ def val_dataloader(self) -> DataLoader:
sampler=torch.utils.data.SequentialSampler(self._dataset_val),
drop_last=False,
collate_fn=collate_fn,
num_workers=self.train_config.num_workers,
num_workers=self._num_workers,
pin_memory=self._pin_memory,
persistent_workers=self._persistent_workers,
prefetch_factor=self._prefetch_factor,
Expand All @@ -170,7 +175,7 @@ def test_dataloader(self) -> DataLoader:
sampler=torch.utils.data.SequentialSampler(self._dataset_test),
drop_last=False,
collate_fn=collate_fn,
num_workers=self.train_config.num_workers,
num_workers=self._num_workers,
pin_memory=self._pin_memory,
persistent_workers=self._persistent_workers,
prefetch_factor=self._prefetch_factor,
Expand All @@ -188,7 +193,7 @@ def predict_dataloader(self) -> DataLoader:
sampler=torch.utils.data.SequentialSampler(self._dataset_val),
drop_last=False,
collate_fn=collate_fn,
num_workers=self.train_config.num_workers,
num_workers=self._num_workers,
pin_memory=self._pin_memory,
persistent_workers=self._persistent_workers,
prefetch_factor=self._prefetch_factor,
Expand Down
6 changes: 5 additions & 1 deletion src/rfdetr/training/module_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,11 @@ def __init__(self, model_config: ModelConfig, train_config: TrainConfig) -> None

# torch.compile is opt-in: set model_config.compile=True to enable.
# Only enabled on CUDA; MPS and CPU do not benefit from compilation.
compile_enabled = model_config.compile and torch.cuda.is_available() and not train_config.multi_scale
# Use the fork-safe DEVICE constant instead of torch.cuda.is_available(),
# which creates a CUDA driver context that breaks fork-based DDP.
from rfdetr.config import DEVICE

compile_enabled = model_config.compile and DEVICE == "cuda" and not train_config.multi_scale
if model_config.compile and train_config.multi_scale:
logger.info("Disabling torch.compile because multi_scale=True introduces dynamic input shapes.")
if compile_enabled:
Expand Down
68 changes: 63 additions & 5 deletions src/rfdetr/training/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,10 @@
from pytorch_lightning.callbacks import ModelCheckpoint, RichProgressBar, TQDMProgressBar
from pytorch_lightning.callbacks.progress.rich_progress import RichProgressBarTheme
from pytorch_lightning.loggers import CSVLogger, MLFlowLogger, TensorBoardLogger, WandbLogger
from pytorch_lightning.strategies import DDPStrategy as _DDPStrategy
from pytorch_lightning.strategies.launchers.multiprocessing import (
_MultiProcessingLauncher,
)

from rfdetr.config import ModelConfig, TrainConfig
from rfdetr.training.callbacks import (
Expand All @@ -28,6 +32,41 @@
_logger = get_logger()


# ---------------------------------------------------------------------------
# Notebook-safe spawn-based DDP
# ---------------------------------------------------------------------------
# ``ddp_notebook`` maps to fork-based DDP which is fundamentally unsafe:
# PyTorch's OpenMP thread pool (created during model construction) cannot
# survive fork() — the worker threads become zombie handles, causing
# "Invalid thread pool!" SIGABRT when the autograd engine initialises in
# the forked child.
#
# PTL considers ``start_method="spawn"`` incompatible with interactive
# environments and raises ``MisconfigurationException`` if used in Jupyter.
# However, PTL's own ``_wrapping_function`` is the entry-point for spawned
# children — no ``if __name__ == "__main__"`` guard is required — so spawn
# is perfectly safe here.
#
# Classes MUST live at module level (not inside a function) so that Python's
# pickle can serialise them for the spawned child processes.


class _InteractiveSpawnLauncher(_MultiProcessingLauncher):
    """Multiprocessing launcher whose interactive-environment check always passes.

    PTL normally refuses spawn launchers inside Jupyter because spawned
    children usually need an ``if __name__ == "__main__"`` guard; that does
    not apply here (PTL's own ``_wrapping_function`` is the child
    entry-point), so the guard is deliberately overridden.
    """

    @property
    def is_interactive_compatible(self) -> bool:  # type: ignore[override]
        # Unconditionally opt in to interactive (notebook) environments.
        return True


class _NotebookSpawnDDPStrategy(_DDPStrategy):
    """Spawn-based DDP strategy usable from Jupyter / Kaggle notebooks."""

    def _configure_launcher(self) -> None:
        # PTL sets up the cluster environment before launch; keep that
        # invariant explicit before handing it to the launcher.
        assert self.cluster_environment is not None
        launcher = _InteractiveSpawnLauncher(self, start_method=self._start_method)
        self._launcher = launcher


def build_trainer(
train_config: TrainConfig,
model_config: ModelConfig,
Expand Down Expand Up @@ -69,12 +108,12 @@ def build_trainer(
def _resolve_precision() -> str:
if not model_config.amp:
return "32-true"
# Ampere+ GPUs support bf16-mixed which is scaler-free —
# no GradScaler.scale/unscale/update overhead per optimizer step.
# BF16 is safe for fine-tuning (pretrained weights loaded by default).
# Training from random init with very small LR may underflow; callers
# can override via trainer_kwargs(precision="16-mixed") if needed.
if torch.cuda.is_available():
# Ampere+ GPUs support bf16-mixed which is scaler-free —
# no GradScaler.scale/unscale/update overhead per optimizer step.
# BF16 is safe for fine-tuning (pretrained weights loaded by default).
# Training from random init with very small LR may underflow; callers
# can override via trainer_kwargs(precision="16-mixed") if needed.
if torch.cuda.is_bf16_supported():
return "bf16-mixed"
return "16-mixed"
Expand All @@ -84,6 +123,25 @@ def _resolve_precision() -> str:

# --- Strategy + EMA sharding guard ---
strategy = tc.strategy

# ``ddp_notebook`` maps to fork-based DDP which is fundamentally unsafe:
# PyTorch's OpenMP thread pool (created during model construction) cannot
# survive fork() — the worker threads become zombie handles, causing
# "Invalid thread pool!" SIGABRT when the autograd engine initialises in
# the forked child. ``ddp_spawn`` is safe but PTL blocks it in notebooks.
#
# Both are replaced with a spawn-based strategy whose launcher is marked
# interactive-compatible. PTL's ``_wrapping_function`` is the entry-point
# for spawned children, so no ``if __name__ == "__main__"`` guard is needed.
if strategy in ("ddp_notebook", "ddp_spawn"):
strategy = _NotebookSpawnDDPStrategy(
start_method="spawn",
find_unused_parameters=True,
)
_logger.info(
"%s → spawn-based DDP to avoid OpenMP thread pool corruption after fork.",
tc.strategy,
)
sharded = any(s in str(strategy).lower() for s in ("fsdp", "deepspeed"))
enable_ema = bool(tc.use_ema) and not sharded
if tc.use_ema and sharded:
Expand Down
71 changes: 58 additions & 13 deletions tests/training/test_build_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,13 +235,7 @@ def test_amp_true_cpu_gives_32_true(self, tmp_path):
assert trainer.precision == "32-true"

def test_amp_true_cuda_no_bf16_gives_16_mixed(self, tmp_path):
"""amp=True with CUDA but no bf16 support must produce '16-mixed'.

We mock the Trainer constructor to capture the precision kwarg rather
than inspecting trainer.precision after construction: PTL may re-detect
hardware bf16 support during __init__ and normalise the precision string
on machines that happen to have a bf16-capable GPU.
"""
"""amp=True with CUDA but no bf16 support must produce '16-mixed'."""
import unittest.mock as mock

captured: dict = {}
Expand All @@ -259,12 +253,7 @@ def _fake_trainer(**kwargs):
assert captured["precision"] == "16-mixed"

def test_amp_true_cuda_bf16_supported_gives_bf16_mixed(self, tmp_path):
"""amp=True with CUDA + bf16 hardware produces 'bf16-mixed' (scaler-free).

On Ampere+ GPUs (bf16 supported) we select bf16-mixed to eliminate
GradScaler overhead. Fine-tuning from pretrained weights is safe with
BF16; callers training from scratch can override via trainer_kwargs.
"""
"""amp=True with CUDA + bf16 hardware produces 'bf16-mixed'."""
import unittest.mock as mock

captured: dict = {}
Expand All @@ -281,6 +270,62 @@ def _fake_trainer(**kwargs):
build_trainer(_tc(tmp_path, use_ema=False), _mc(amp=True))
assert captured["precision"] == "bf16-mixed"

def test_amp_true_ddp_notebook_probes_bf16_normally(self, tmp_path):
    """ddp_notebook uses standard precision probing (spawn makes CUDA init safe).

    With spawn-based DDP, child processes start fresh — CUDA init in the
    parent does not propagate. So ``is_bf16_supported()`` is safe to call
    and pre-Ampere GPUs correctly get ``16-mixed`` instead of the slower
    bf16 emulation path.
    """
    import unittest.mock as mock

    seen_kwargs: dict = {}

    def _capture_trainer(**kwargs):
        seen_kwargs.update(kwargs)
        return mock.MagicMock()

    # Simulate a pre-Ampere GPU: CUDA available, bf16 NOT supported.
    with (
        mock.patch("torch.cuda.is_available", return_value=True),
        mock.patch("torch.cuda.is_bf16_supported", return_value=False),
        mock.patch("rfdetr.training.trainer.Trainer", side_effect=_capture_trainer),
    ):
        build_trainer(
            _tc(tmp_path, use_ema=False, strategy="ddp_notebook"),
            _mc(amp=True),
        )

    assert seen_kwargs["precision"] == "16-mixed"

@pytest.mark.parametrize("strategy_name", ["ddp_notebook", "ddp_spawn"])
def test_ddp_notebook_and_spawn_use_interactive_spawn(self, tmp_path, strategy_name):
    """ddp_notebook and ddp_spawn must be replaced with interactive spawn DDPStrategy.

    Fork-based DDP inherits the parent's OpenMP thread pool which is
    invalid after fork, causing SIGABRT in the autograd engine.
    ddp_spawn is blocked by PTL in notebooks without the override.
    """
    import unittest.mock as mock

    from pytorch_lightning.strategies import DDPStrategy

    seen_kwargs: dict = {}

    def _capture_trainer(**kwargs):
        seen_kwargs.update(kwargs)
        return mock.MagicMock()

    with mock.patch("rfdetr.training.trainer.Trainer", side_effect=_capture_trainer):
        build_trainer(
            _tc(tmp_path, use_ema=False, strategy=strategy_name),
            _mc(amp=True),
        )

    strategy = seen_kwargs["strategy"]
    assert isinstance(strategy, DDPStrategy)
    assert strategy._start_method == "spawn"
    assert strategy._ddp_kwargs.get("find_unused_parameters") is True


class TestBuildTrainerEMAShardingGuard:
"""EMA must be disabled and a UserWarning emitted for sharded strategies.
Expand Down
Loading
Loading