diff --git a/CHANGELOG.md b/CHANGELOG.md index 9c645d16..169f244e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- `augmentation_backend` field on `TrainConfig` (`"cpu"` / `"auto"` / `"gpu"`): opt-in GPU-side augmentation via [Kornia](https://kornia.readthedocs.io) applied in `RFDETRDataModule.on_after_batch_transfer` after the batch is resident on the GPU. CPU path is unchanged and remains the default. Install with `pip install 'rfdetr[kornia]'`. Phase 1 supports detection only; segmentation mask support is planned for Phase 2. - `BuilderArgs` — a `@runtime_checkable` `typing.Protocol` documenting the minimum attribute set consumed by `build_model()`, `build_backbone()`, `build_transformer()`, and `build_criterion_and_postprocessors()`. Enables static type-checker support for custom builder integrations. Exported from `rfdetr.models`. (#841) - `build_model_from_config(model_config, train_config=None, defaults=MODEL_DEFAULTS)` — config-native alternative to `build_model(build_namespace(mc, tc))`; accepts Pydantic config objects directly and constructs the internal namespace automatically. Exported from `rfdetr.models`. (#845) - `build_criterion_from_config(model_config, train_config, defaults=MODEL_DEFAULTS)` — config-native alternative to `build_criterion_and_postprocessors(build_namespace(mc, tc))`; returns a `(SetCriterion, PostProcess)` tuple. Exported from `rfdetr.models`. (#845) diff --git a/pyproject.toml b/pyproject.toml index 8cc5ecf6..a04fe12b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,6 +73,9 @@ trt = [ "tensorrt>=8.6.1", "polygraphy", ] +kornia = [ + "kornia>=0.7,<1", # GPU-side augmentation via on_after_batch_transfer +] loggers = [ "tensorboard>=2.13.0", "protobuf>=3.20.0,<4.0.0", # Pins protobuf below 4.x to avoid TensorBoard descriptor crash with protobuf>=4 (see #844) diff --git a/src/rfdetr/config.py b/src/rfdetr/config.py index acf4a5ee..cee24584 100644 --- a/src/rfdetr/config.py +++ b/src/rfdetr/config.py @@ -449,6 +449,7 @@ class TrainConfig(BaseModel): eval_interval: int = 1 log_per_class_metrics: bool = True aug_config: Optional[Dict[str, Any]] = None + augmentation_backend: Literal["cpu", "auto", "gpu"] = "cpu" @model_validator(mode="after") def _warn_deprecated_train_config_fields(self) -> "TrainConfig": diff --git a/src/rfdetr/datasets/aug_config.py b/src/rfdetr/datasets/aug_config.py index d4abcb03..4cc72910 100644 --- a/src/rfdetr/datasets/aug_config.py +++ b/src/rfdetr/datasets/aug_config.py @@ -61,6 +61,27 @@ "YourCustomTransform", # Add here } ``` + +## Kornia GPU Backend + +When ``augmentation_backend="auto"`` or ``"gpu"`` is set in ``TrainConfig``, augmentations +run on the GPU via Kornia instead of Albumentations. + +**Supported transforms** (all presets): + +| Preset key | Kornia equivalent | Notes | +|---|---|---| +| ``HorizontalFlip`` | ``K.RandomHorizontalFlip`` | Direct | +| ``VerticalFlip`` | ``K.RandomVerticalFlip`` | Direct | +| ``Rotate`` | ``K.RandomRotation`` | ``limit`` may be scalar or tuple | +| ``Affine`` | ``K.RandomAffine`` | ``translate_percent`` treated as fraction | +| ``ColorJitter`` | ``K.ColorJiggle`` | Same multiplicative semantics | +| ``RandomBrightnessContrast`` | ``K.ColorJiggle`` | ``brightness_limit`` / ``contrast_limit`` direct | +| ``GaussianBlur`` | ``K.RandomGaussianBlur`` | ``blur_limit`` rounded up to odd; ``sigma=(0.1, 2.0)`` | +| ``GaussNoise`` | ``K.RandomGaussianNoise`` | Upper bound of ``std_range`` used as fixed std | + +**Phase 1 limitation**: Segmentation models (``segmentation_head=True``) skip GPU augmentation; +CPU Albumentations are used instead. Mask support is planned for Phase 2. """ # --------------------------------------------------------------------------- diff --git a/src/rfdetr/datasets/coco.py b/src/rfdetr/datasets/coco.py index 5b747882..7688ae5c 100644 --- a/src/rfdetr/datasets/coco.py +++ b/src/rfdetr/datasets/coco.py @@ -392,6 +392,7 @@ def make_coco_transforms( patch_size: int = 16, num_windows: int = 4, aug_config: Optional[Dict[str, Dict[str, Any]]] = None, + gpu_postprocess: bool = False, ) -> Compose: """Build the standard COCO transform pipeline for a given dataset split. @@ -425,6 +426,10 @@ def make_coco_transforms( :class:`~rfdetr.datasets.transforms.AlbumentationsWrapper`. Falls back to the default :data:`~rfdetr.datasets.aug_config.AUG_CONFIG` when ``None``. + gpu_postprocess: When ``True``, skip Albumentations augmentation wrappers and + ``Normalize`` from the CPU pipeline. The ``RFDETRDataModule`` then applies + both augmentation and normalization on the GPU in + ``on_after_batch_transfer``. Has no effect on val/test splits. Returns: A :class:`torchvision.transforms.v2.Compose` pipeline ready to be passed @@ -450,8 +455,14 @@ def make_coco_transforms( resize_wrappers = AlbumentationsWrapper.from_config( _build_train_resize_config(scales, square=False, max_size=1333) ) - aug_wrappers = AlbumentationsWrapper.from_config(resolved_aug_config) - return Compose([*resize_wrappers, *aug_wrappers, to_image, to_float, normalize]) + pipeline = [*resize_wrappers] + if not gpu_postprocess: + aug_wrappers = AlbumentationsWrapper.from_config(resolved_aug_config) + pipeline += [*aug_wrappers] + pipeline += [to_image, to_float] + if not gpu_postprocess: + pipeline += [normalize] + return Compose(pipeline) if image_set in ("val", "test"): resize_wrappers = AlbumentationsWrapper.from_config( @@ -477,6 +488,7 @@ def make_coco_transforms_square_div_64( patch_size: int = 16, num_windows: int = 4, aug_config: Optional[Dict[str, Dict[str, Any]]] = None, + gpu_postprocess: bool = False, ) -> Compose: """ Create COCO transforms with square resizing where the output size is divisible by 64. @@ -506,6 +518,10 @@ def make_coco_transforms_square_div_64( aug_config: Augmentation configuration dictionary compatible with :class:`~rfdetr.datasets.transforms.AlbumentationsWrapper`. If ``None``, the default :data:`~rfdetr.datasets.aug_config.AUG_CONFIG` is used. + gpu_postprocess: When ``True``, skip Albumentations augmentation wrappers and + ``Normalize`` from the CPU pipeline. The ``RFDETRDataModule`` then applies + both augmentation and normalization on the GPU in + ``on_after_batch_transfer``. Has no effect on val/test splits. Returns: A ``Compose`` object containing the composed image transforms appropriate @@ -526,8 +542,14 @@ def make_coco_transforms_square_div_64( if image_set == "train": resolved_aug_config = aug_config if aug_config is not None else AUG_CONFIG resize_wrappers = AlbumentationsWrapper.from_config(_build_train_resize_config(scales, square=True)) - aug_wrappers = AlbumentationsWrapper.from_config(resolved_aug_config) - return Compose([*resize_wrappers, *aug_wrappers, to_image, to_float, normalize]) + pipeline = [*resize_wrappers] + if not gpu_postprocess: + aug_wrappers = AlbumentationsWrapper.from_config(resolved_aug_config) + pipeline += [*aug_wrappers] + pipeline += [to_image, to_float] + if not gpu_postprocess: + pipeline += [normalize] + return Compose(pipeline) if image_set in ("val", "test", "val_speed"): resize_wrappers = AlbumentationsWrapper.from_config([{"Resize": {"height": resolution, "width": resolution}}]) @@ -554,6 +576,32 @@ def build_coco(image_set: str, args: Any, resolution: int) -> CocoDetection: square_resize_div_64 = getattr(args, "square_resize_div_64", False) include_masks = getattr(args, "segmentation_head", False) aug_config = getattr(args, "aug_config", None) + augmentation_backend = getattr(args, "augmentation_backend", "cpu") + resolved_augmentation_backend = augmentation_backend + if include_masks and augmentation_backend != "cpu": + logger.warning( + "Segmentation training does not currently support GPU postprocess transforms; " + "forcing augmentation_backend='cpu' to retain CPU transforms and normalization." + ) + resolved_augmentation_backend = "cpu" + if hasattr(args, "augmentation_backend"): + setattr(args, "augmentation_backend", "cpu") + if resolved_augmentation_backend == "auto": + gpu_available = torch.cuda.is_available() + if gpu_available: + try: + import kornia # type: ignore[import-not-found] + except ImportError: + gpu_available = False + if not gpu_available: + logger.warning( + "augmentation_backend='auto' resolved to 'cpu' because CUDA or kornia is unavailable; " + "disabling GPU postprocess transforms and retaining CPU normalization." + ) + resolved_augmentation_backend = "cpu" + if hasattr(args, "augmentation_backend"): + setattr(args, "augmentation_backend", "cpu") + gpu_postprocess = resolved_augmentation_backend != "cpu" and not include_masks if square_resize_div_64: logger.info(f"Building COCO {image_set} dataset with square resize at resolution {resolution}") @@ -569,6 +617,7 @@ def build_coco(image_set: str, args: Any, resolution: int) -> CocoDetection: patch_size=args.patch_size, num_windows=args.num_windows, aug_config=aug_config, + gpu_postprocess=gpu_postprocess, ), include_masks=include_masks, ) @@ -586,6 +635,7 @@ def build_coco(image_set: str, args: Any, resolution: int) -> CocoDetection: patch_size=args.patch_size, num_windows=args.num_windows, aug_config=aug_config, + gpu_postprocess=gpu_postprocess, ), include_masks=include_masks, ) @@ -618,6 +668,7 @@ def build_roboflow_from_coco(image_set: str, args: Any, resolution: int) -> Coco patch_size = getattr(args, "patch_size", 16) num_windows = getattr(args, "num_windows", 4) aug_config = getattr(args, "aug_config", None) + gpu_postprocess = getattr(args, "augmentation_backend", "cpu") != "cpu" and not include_masks if square_resize_div_64: logger.info(f"Building Roboflow {image_set} dataset with square resize at resolution {resolution}") @@ -633,6 +684,7 @@ def build_roboflow_from_coco(image_set: str, args: Any, resolution: int) -> Coco patch_size=patch_size, num_windows=num_windows, aug_config=aug_config, + gpu_postprocess=gpu_postprocess, ), include_masks=include_masks, remap_category_ids=True, @@ -651,6 +703,7 @@ def build_roboflow_from_coco(image_set: str, args: Any, resolution: int) -> Coco patch_size=patch_size, num_windows=num_windows, aug_config=aug_config, + gpu_postprocess=gpu_postprocess, ), include_masks=include_masks, remap_category_ids=True, diff --git a/src/rfdetr/datasets/kornia_transforms.py b/src/rfdetr/datasets/kornia_transforms.py new file mode 100644 index 00000000..370e9c1b --- /dev/null +++ b/src/rfdetr/datasets/kornia_transforms.py @@ -0,0 +1,373 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +"""Kornia-based GPU augmentation pipeline for RF-DETR detection training. + +This module provides GPU-side augmentation as an alternative to the CPU-based +Albumentations pipeline. All transforms run on the device where the batch +already resides (typically CUDA), avoiding a CPU-GPU round-trip per sample. + +Phase 1 supports detection bounding boxes only; segmentation masks are +deferred to phase 2. + +Usage:: + + from rfdetr.datasets.kornia_transforms import ( + build_kornia_pipeline, + build_normalize, + collate_boxes, + unpack_boxes, + ) + + pipeline = build_kornia_pipeline(aug_config, resolution=560) + normalize = build_normalize() + + # In on_after_batch_transfer: + boxes_padded, valid = collate_boxes(targets, device) + img_aug, boxes_aug = pipeline(img, boxes_padded) + img_aug = normalize(img_aug) + targets = unpack_boxes(boxes_aug, valid, targets, H, W) +""" + +from __future__ import annotations + +from typing import Any, Callable, Dict, List, Tuple + +import torch +from torch import Tensor + +from rfdetr.utilities.logger import get_logger + +logger = get_logger() + +#: ImageNet channel-wise mean (RGB order). +IMAGENET_MEAN = (0.485, 0.456, 0.406) +#: ImageNet channel-wise standard deviation (RGB order). +IMAGENET_STD = (0.229, 0.224, 0.225) + + +def _require_kornia() -> None: + """Verify that Kornia is importable, raising a clear error if not. + + Raises: + ImportError: When ``kornia`` is not installed, with an install hint. + """ + try: + import kornia.augmentation # noqa: F401 + except ImportError as e: + raise ImportError("GPU augmentation requires kornia. Install with: pip install 'rfdetr[kornia]'") from e + + +# --------------------------------------------------------------------------- +# Registry: Albumentations key -> Kornia factory +# --------------------------------------------------------------------------- + + +def _make_horizontal_flip(params: Dict[str, Any]) -> Any: + """Build a ``K.RandomHorizontalFlip`` from aug_config params.""" + import kornia.augmentation as K + + return K.RandomHorizontalFlip(p=params.get("p", 0.5)) + + +def _make_vertical_flip(params: Dict[str, Any]) -> Any: + """Build a ``K.RandomVerticalFlip`` from aug_config params.""" + import kornia.augmentation as K + + return K.RandomVerticalFlip(p=params.get("p", 0.5)) + + +def _make_rotate(params: Dict[str, Any]) -> Any: + """Build a ``K.RandomRotation`` from aug_config params. + + The ``limit`` parameter may be a scalar (symmetric range) or a tuple. + """ + import kornia.augmentation as K + + limit = params.get("limit", 15) + if isinstance(limit, (list, tuple)): + degrees = tuple(limit) + else: + degrees = (-limit, limit) + return K.RandomRotation(degrees=degrees, p=params.get("p", 0.5)) + + +def _make_affine(params: Dict[str, Any]) -> Any: + """Build a ``K.RandomAffine`` from aug_config params.""" + import kornia.augmentation as K + + return K.RandomAffine( + degrees=params.get("rotate", (-15, 15)), + translate=params.get("translate_percent", None), + scale=params.get("scale", None), + shear=params.get("shear", None), + p=params.get("p", 0.5), + ) + + +def _make_color_jitter(params: Dict[str, Any]) -> Any: + """Build a ``K.ColorJiggle`` from aug_config ``ColorJitter`` params. + + Note: Kornia >=0.7 uses ``ColorJiggle``; the ``ColorJitter`` alias was + added in later versions. We use ``ColorJiggle`` for broad compatibility. + """ + import kornia.augmentation as K + + return K.ColorJiggle( + brightness=params.get("brightness", 0.0), + contrast=params.get("contrast", 0.0), + saturation=params.get("saturation", 0.0), + hue=params.get("hue", 0.0), + p=params.get("p", 0.5), + ) + + +def _make_random_brightness_contrast(params: Dict[str, Any]) -> Any: + """Build a ``K.ColorJiggle`` from ``RandomBrightnessContrast`` params.""" + import kornia.augmentation as K + + return K.ColorJiggle( + brightness=params.get("brightness_limit", 0.2), + contrast=params.get("contrast_limit", 0.2), + p=params.get("p", 0.5), + ) + + +def _make_gaussian_blur(params: Dict[str, Any]) -> Any: + """Build a ``K.RandomGaussianBlur`` from aug_config params. + + ``blur_limit`` is rounded up to an odd number for the kernel size. + """ + import kornia.augmentation as K + + blur_limit = params.get("blur_limit", 3) + # Ensure blur_limit is odd and at least 3 (Kornia requires kernel_size >= 3) + if blur_limit % 2 == 0: + blur_limit = blur_limit + 1 + blur_limit = max(3, blur_limit) + return K.RandomGaussianBlur( + kernel_size=(blur_limit, blur_limit), + sigma=(0.1, 2.0), + p=params.get("p", 0.5), + ) + + +def _make_gauss_noise(params: Dict[str, Any]) -> Any: + """Build a ``K.RandomGaussianNoise`` from aug_config params. + + Kornia takes a single ``std`` value; we use the upper bound of + ``std_range`` as an acceptable approximation. + """ + import kornia.augmentation as K + + std_range = params.get("std_range", (0.01, 0.05)) + return K.RandomGaussianNoise( + std=std_range[1], + p=params.get("p", 0.5), + ) + + +_REGISTRY: Dict[str, Callable[[Dict[str, Any]], Any]] = { + "HorizontalFlip": _make_horizontal_flip, + "VerticalFlip": _make_vertical_flip, + "Rotate": _make_rotate, + "Affine": _make_affine, + "ColorJitter": _make_color_jitter, + "RandomBrightnessContrast": _make_random_brightness_contrast, + "GaussianBlur": _make_gaussian_blur, + "GaussNoise": _make_gauss_noise, +} + + +# --------------------------------------------------------------------------- +# Pipeline builders +# --------------------------------------------------------------------------- + + +def build_kornia_pipeline( + aug_config: Dict[str, Dict[str, Any]], + resolution: int, +) -> Any: + """Build a Kornia ``AugmentationSequential`` from an aug_config dict. + + Each key in *aug_config* is looked up in ``_REGISTRY`` and instantiated + with the corresponding parameter dict. Unknown keys raise ``ValueError``. + + Args: + aug_config: Mapping of augmentation names to parameter dicts, identical + to the format accepted by the Albumentations path (e.g. + ``{"HorizontalFlip": {"p": 0.5}}``). + resolution: Target image resolution in pixels (currently reserved for + future resolution-aware augmentations). + + Returns: + A ``kornia.augmentation.AugmentationSequential`` instance configured + with ``data_keys=["input", "bbox_xyxy"]``. + + Raises: + ValueError: If *aug_config* contains an unsupported augmentation key. + """ + _require_kornia() + import kornia.augmentation as K + + transforms: List[Any] = [] + for name, params in aug_config.items(): + factory = _REGISTRY.get(name) + if factory is None: + raise ValueError( + f"Unknown augmentation key {name!r} for Kornia GPU backend. Supported keys: {sorted(_REGISTRY)}." + ) + transforms.append(factory(params)) + + return K.AugmentationSequential( + *transforms, + data_keys=["input", "bbox_xyxy"], + ) + + +def build_normalize( + mean: Tuple[float, ...] = IMAGENET_MEAN, + std: Tuple[float, ...] = IMAGENET_STD, +) -> Any: + """Build a Kornia ``Normalize`` transform for GPU-side normalization. + + Args: + mean: Per-channel mean values. Defaults to ImageNet statistics. + std: Per-channel standard deviation values. Defaults to ImageNet + statistics. + + Returns: + A ``kornia.augmentation.Normalize`` instance. + """ + _require_kornia() + import kornia.augmentation as K + + return K.Normalize( + mean=mean, + std=std, + ) + + +# --------------------------------------------------------------------------- +# Bounding-box utilities +# --------------------------------------------------------------------------- + + +def collate_boxes( + targets: List[Dict[str, Any]], + device: torch.device, +) -> Tuple[Tensor, Tensor]: + """Pack variable-length xyxy boxes into a padded tensor and valid mask. + + Kornia ``AugmentationSequential`` expects boxes as ``[B, N_max, 4]``. + This function zero-pads each image's boxes to the maximum count in the + batch and returns a boolean mask indicating which entries are real. + + Args: + targets: List of target dicts (one per image), each containing a + ``"boxes"`` key with an ``[N_i, 4]`` tensor in xyxy format. + device: Device on which to allocate the output tensors. + + Returns: + Tuple of: + - ``boxes_padded`` — ``[B, N_max, 4]`` float tensor (zero-padded). + - ``valid_mask`` — ``[B, N_max]`` bool tensor (``True`` = real box). + + When ``B == 0`` or all images have zero boxes, both tensors have + ``N_max == 0``. + """ + if len(targets) == 0: + return ( + torch.zeros(0, 0, 4, device=device), + torch.zeros(0, 0, dtype=torch.bool, device=device), + ) + + box_counts = [t["boxes"].shape[0] for t in targets] + n_max = max(box_counts) if box_counts else 0 + B = len(targets) + + if n_max == 0: + return ( + torch.zeros(B, 0, 4, device=device), + torch.zeros(B, 0, dtype=torch.bool, device=device), + ) + + boxes_padded = torch.zeros(B, n_max, 4, device=device) + valid_mask = torch.zeros(B, n_max, dtype=torch.bool, device=device) + + for i, t in enumerate(targets): + n = t["boxes"].shape[0] + if n > 0: + boxes_padded[i, :n] = t["boxes"] + valid_mask[i, :n] = True + + return boxes_padded, valid_mask + + +def unpack_boxes( + boxes_aug: Tensor, + valid: Tensor, + targets: List[Dict[str, Any]], + H: int, + W: int, +) -> List[Dict[str, Any]]: + """Unpack augmented boxes, clamp to image bounds, and remove zero-area boxes. + + After Kornia augmentation the padded ``[B, N_max, 4]`` tensor is unpacked + back into per-image target dicts. Boxes are clamped to ``[0, W] x [0, H]`` + and any that collapse to zero area are removed along with their + corresponding ``labels``, ``area``, and ``iscrowd`` entries. + + Args: + boxes_aug: Augmented boxes tensor ``[B, N_max, 4]`` in xyxy format. + valid: Boolean mask ``[B, N_max]`` from :func:`collate_boxes`. + targets: Original target dicts; each dict is shallow-copied before + modification — the input list itself is not mutated. + H: Image height in pixels (for clamping). + W: Image width in pixels (for clamping). + + Returns: + A new list of target dicts with updated ``boxes``, ``labels``, + ``area``, and ``iscrowd`` entries. + """ + new_targets: List[Dict[str, Any]] = [] + for i, t in enumerate(targets): + t = t.copy() + n_orig = t["boxes"].shape[0] + + if n_orig == 0 or valid.shape[1] == 0: + new_targets.append(t) + continue + + # Extract valid boxes for this image + v = valid[i, :n_orig] + boxes_i = boxes_aug[i, :n_orig] + + # Clamp to image boundaries + boxes_i = boxes_i.clone() + boxes_i[:, 0].clamp_(min=0, max=W) + boxes_i[:, 1].clamp_(min=0, max=H) + boxes_i[:, 2].clamp_(min=0, max=W) + boxes_i[:, 3].clamp_(min=0, max=H) + + # Remove zero-area boxes (after clamping) + widths = boxes_i[:, 2] - boxes_i[:, 0] + heights = boxes_i[:, 3] - boxes_i[:, 1] + keep = v & (widths > 0) & (heights > 0) + + t["boxes"] = boxes_i[keep] + if "labels" in t: + t["labels"] = t["labels"][keep] + if "area" in t: + # Recompute area from clamped boxes + kept_boxes = t["boxes"] + t["area"] = (kept_boxes[:, 2] - kept_boxes[:, 0]) * (kept_boxes[:, 3] - kept_boxes[:, 1]) + if "iscrowd" in t: + t["iscrowd"] = t["iscrowd"][keep] + + new_targets.append(t) + + return new_targets diff --git a/src/rfdetr/datasets/o365.py b/src/rfdetr/datasets/o365.py index e3bfa772..3da15e2d 100644 --- a/src/rfdetr/datasets/o365.py +++ b/src/rfdetr/datasets/o365.py @@ -15,9 +15,12 @@ from PIL import Image from rfdetr.datasets.coco import CocoDetection, make_coco_transforms, make_coco_transforms_square_div_64 +from rfdetr.utilities.logger import get_logger Image.MAX_IMAGE_PIXELS = None +logger = get_logger() + def build_o365_raw(image_set: str, args: Any, resolution: int) -> CocoDetection: root = Path(getattr(args, "dataset_dir", None) or args.coco_path) @@ -28,13 +31,50 @@ def build_o365_raw(image_set: str, args: Any, resolution: int) -> CocoDetection: img_folder, ann_file = PATHS[image_set] square_resize_div_64 = getattr(args, "square_resize_div_64", False) + augmentation_backend = getattr(args, "augmentation_backend", "cpu") + resolved_backend = augmentation_backend + + if augmentation_backend == "auto": + # Resolve 'auto' based on CUDA and kornia availability + has_cuda = False + has_kornia = False + try: + import torch + + has_cuda = bool(torch.cuda.is_available()) + except Exception: + has_cuda = False + + try: + import kornia.augmentation # noqa: F401 + + has_kornia = True + except Exception: + has_kornia = False + + if has_cuda and has_kornia: + resolved_backend = "gpu" + else: + resolved_backend = "cpu" + + if resolved_backend != "cpu": + logger.warning( + "O365 dataset does not support custom aug_config in Phase 1 GPU augmentation; " + "Albumentations augmentation is skipped and normalization runs on GPU. " + "Pass augmentation_backend='cpu' for full CPU augmentation pipeline with O365." + ) + gpu_postprocess = resolved_backend != "cpu" if square_resize_div_64: dataset = CocoDetection( img_folder, ann_file, transforms=make_coco_transforms_square_div_64( - image_set, resolution, multi_scale=args.multi_scale, expanded_scales=args.expanded_scales + image_set, + resolution, + multi_scale=args.multi_scale, + expanded_scales=args.expanded_scales, + gpu_postprocess=gpu_postprocess, ), ) else: @@ -42,7 +82,11 @@ def build_o365_raw(image_set: str, args: Any, resolution: int) -> CocoDetection: img_folder, ann_file, transforms=make_coco_transforms( - image_set, resolution, multi_scale=args.multi_scale, expanded_scales=args.expanded_scales + image_set, + resolution, + multi_scale=args.multi_scale, + expanded_scales=args.expanded_scales, + gpu_postprocess=gpu_postprocess, ), ) return dataset diff --git a/src/rfdetr/datasets/yolo.py b/src/rfdetr/datasets/yolo.py index 86a74563..a3f3610f 100644 --- a/src/rfdetr/datasets/yolo.py +++ b/src/rfdetr/datasets/yolo.py @@ -708,6 +708,7 @@ def build_roboflow_from_yolo(image_set: str, args: Any, resolution: int) -> Yolo patch_size = getattr(args, "patch_size", None) num_windows = getattr(args, "num_windows", None) aug_config = getattr(args, "aug_config", None) + gpu_postprocess = getattr(args, "augmentation_backend", "cpu") != "cpu" and not include_masks if square_resize_div_64: dataset = YoloDetection( @@ -723,6 +724,7 @@ def build_roboflow_from_yolo(image_set: str, args: Any, resolution: int) -> Yolo patch_size=patch_size, num_windows=num_windows, aug_config=aug_config, + gpu_postprocess=gpu_postprocess, ), include_masks=include_masks, ) @@ -740,6 +742,7 @@ def build_roboflow_from_yolo(image_set: str, args: Any, resolution: int) -> Yolo patch_size=patch_size, num_windows=num_windows, aug_config=aug_config, + gpu_postprocess=gpu_postprocess, ), include_masks=include_masks, ) diff --git a/src/rfdetr/training/module_data.py b/src/rfdetr/training/module_data.py index 96cca6af..2a2458da 100644 --- a/src/rfdetr/training/module_data.py +++ b/src/rfdetr/training/module_data.py @@ -16,6 +16,7 @@ from rfdetr._namespace import _namespace_from_configs from rfdetr.config import ModelConfig, TrainConfig from rfdetr.datasets import build_dataset +from rfdetr.datasets.aug_config import AUG_CONFIG from rfdetr.utilities.logger import get_logger from rfdetr.utilities.tensors import collate_fn @@ -92,6 +93,41 @@ def __getitem__(self, idx: int) -> Any: return self._dataset[dataset_idx] +def _resolve_augmentation_backend(backend: str) -> str: + """Resolve ``"auto"`` to ``"cpu"`` or ``"gpu"`` based on runtime availability. + + For ``"cpu"`` and ``"gpu"`` the value is returned unchanged. For + ``"auto"`` the function checks CUDA and kornia availability and returns + ``"gpu"`` only when both are present; otherwise ``"cpu"``. + + Called before dataset construction so that ``gpu_postprocess`` in the + dataset builders always matches what the DataModule will actually do in + ``on_after_batch_transfer``. + + Args: + backend: Value of ``TrainConfig.augmentation_backend``. + + Returns: + Resolved backend string, either ``"cpu"`` or ``"gpu"``. + + Examples: + >>> _resolve_augmentation_backend("cpu") + 'cpu' + >>> _resolve_augmentation_backend("gpu") + 'gpu' + """ + if backend != "auto": + return backend + if not torch.cuda.is_available(): + return "cpu" + try: + import kornia.augmentation # noqa: F401 + + return "gpu" + except ImportError: + return "cpu" + + class RFDETRDataModule(LightningDataModule): """LightningDataModule wrapping RF-DETR dataset construction and data loading. @@ -109,6 +145,15 @@ def __init__(self, model_config: ModelConfig, train_config: TrainConfig) -> None self._dataset_val: Optional[torch.utils.data.Dataset] = None self._dataset_test: Optional[torch.utils.data.Dataset] = None + # GPU augmentation pipeline (Kornia); built lazily in setup("fit"). + self._kornia_pipeline: Optional[Any] = None + self._kornia_normalize: Optional[Any] = None + # Sentinel: True once _setup_kornia_pipeline has run (even on fallback paths + # where _kornia_pipeline stays None), preventing redundant re-runs on repeated + # setup("fit") calls (e.g. during validation loops in some PTL strategies). + self._kornia_setup_done: bool = False + + num_workers = self.train_config.num_workers self._num_workers: int = self.train_config.num_workers # Use the fork-safe DEVICE constant instead of torch.cuda.is_available(), @@ -152,10 +197,24 @@ def setup(self, stage: str) -> None: resolution = self.model_config.resolution ns = _namespace_from_configs(self.model_config, self.train_config) if stage == "fit": + # Resolve 'auto' to an actual backend before building datasets so that + # gpu_postprocess in dataset builders always matches what the DataModule + # will actually do in on_after_batch_transfer. Without this, 'auto' on + # a machine without CUDA/kornia would strip CPU Normalize from datasets + # while _kornia_pipeline stays None, leaving training inputs unnormalized. + resolved = _resolve_augmentation_backend(self.train_config.augmentation_backend) + if resolved != self.train_config.augmentation_backend: + ns.augmentation_backend = resolved if self._dataset_train is None: self._dataset_train = build_dataset("train", ns, resolution) if self._dataset_val is None: self._dataset_val = build_dataset("val", ns, resolution) + # Build Kornia GPU augmentation pipeline (once). + # Use _kornia_setup_done (not _kornia_pipeline is None) so that fallback + # paths — where the pipeline stays None — do not re-run on every setup("fit"). + if not self._kornia_setup_done: + self._setup_kornia_pipeline() + self._kornia_setup_done = True elif stage == "validate": if self._dataset_val is None: self._dataset_val = build_dataset("val", ns, resolution) @@ -282,6 +341,81 @@ def predict_dataloader(self) -> DataLoader: prefetch_factor=self._prefetch_factor, ) + def _setup_kornia_pipeline(self) -> None: + """Resolve augmentation backend and build the Kornia pipeline if applicable. + + Called once during ``setup("fit")``. When ``augmentation_backend`` + is ``"cpu"`` this is a no-op. For ``"auto"`` the method falls back + silently when CUDA or Kornia are unavailable. For ``"gpu"`` missing + requirements raise hard errors. + """ + backend = self.train_config.augmentation_backend + if backend == "cpu": + return + + if backend == "auto": + if not torch.cuda.is_available(): + logger.warning("augmentation_backend='auto': no CUDA, falling back to CPU augmentation") + return + try: + import kornia.augmentation + except ImportError: + logger.warning("augmentation_backend='auto': kornia not installed, using CPU augmentation") + return + elif backend == "gpu": + if not torch.cuda.is_available(): + raise RuntimeError("augmentation_backend='gpu' requires a CUDA device") + try: + import kornia.augmentation # noqa: F401 + except ImportError as e: + raise ImportError("GPU augmentation requires kornia. Install with: pip install 'rfdetr[kornia]'") from e + + from rfdetr.datasets.kornia_transforms import build_kornia_pipeline, build_normalize + + self._kornia_pipeline = build_kornia_pipeline( + self.train_config.aug_config if self.train_config.aug_config is not None else AUG_CONFIG, + self.model_config.resolution, + ) + self._kornia_normalize = build_normalize() + logger.info("Kornia GPU augmentation pipeline built (backend=%s)", backend) + + def on_after_batch_transfer(self, batch: Tuple, dataloader_idx: int) -> Tuple: + """Apply Kornia GPU augmentation after the batch is transferred to device. + + When ``_kornia_pipeline`` is set and the trainer is in training mode, + augmentation and normalization are applied on the GPU. Validation + and test batches pass through unchanged. + + Segmentation models skip GPU augmentation in phase 1 with a warning. + + Args: + batch: Tuple of ``(NestedTensor, list[dict])`` already on device. + dataloader_idx: Index of the current dataloader. + + Returns: + The (possibly augmented) batch. + """ + if self.trainer.training and self._kornia_pipeline is not None: + if self.model_config.segmentation_head: + logger.warning_once("Kornia GPU augmentation skipped for segmentation models (phase 2)") + return batch + + from rfdetr.datasets.kornia_transforms import collate_boxes, unpack_boxes + from rfdetr.utilities.tensors import NestedTensor + + samples, targets = batch + img = samples.tensors # [B, C, H, W] + # Move Kornia modules to the batch device (no-op if already there). + # nn.Module.to() is in-place; no reassignment needed. + self._kornia_pipeline.to(img.device) + self._kornia_normalize.to(img.device) + boxes_padded, valid = collate_boxes(targets, img.device) + img_aug, boxes_aug = self._kornia_pipeline(img, boxes_padded) + img_aug = self._kornia_normalize(img_aug) + targets = unpack_boxes(boxes_aug, valid, targets, *img_aug.shape[-2:]) + batch = (NestedTensor(img_aug, samples.mask), targets) + return batch + # ------------------------------------------------------------------ # Properties # ------------------------------------------------------------------ diff --git a/tests/datasets/test_coco.py b/tests/datasets/test_coco.py index 62df944d..27a5c874 100644 --- a/tests/datasets/test_coco.py +++ b/tests/datasets/test_coco.py @@ -217,3 +217,112 @@ def test_unsorted_category_ids_return_id_sorted_class_order(self, tmp_path: Path _write_coco_json(tmp_path / "train" / "_annotations.coco.json", categories) result = RFDETR._load_classes(str(tmp_path)) assert result == ["car", "truck", "person"] + + +# --------------------------------------------------------------------------- +# TestBuildO365RawGpuBackend — validates that build_o365_raw emits a WARNING +# and passes gpu_postprocess when augmentation_backend != 'cpu'. +# --------------------------------------------------------------------------- + + +class TestBuildO365RawGpuBackend: + """build_o365_raw warns and wires gpu_postprocess for non-cpu backends.""" + + class _FakeArgs: + """Minimal args stub for build_o365_raw.""" + + def __init__(self, augmentation_backend="cpu", square_resize_div_64=False): + self.augmentation_backend = augmentation_backend + self.square_resize_div_64 = square_resize_div_64 + self.multi_scale = False + self.expanded_scales = False + self.dataset_dir = "/nonexistent/o365" + self.coco_path = "/nonexistent/o365" + + def _call_build_o365_raw(self, augmentation_backend, square_resize_div_64=False): + """Call build_o365_raw with mocked CocoDetection and transform builders.""" + from unittest.mock import MagicMock, patch + + from rfdetr.datasets.o365 import build_o365_raw + + args = self._FakeArgs(augmentation_backend=augmentation_backend, square_resize_div_64=square_resize_div_64) + fake_dataset = MagicMock() + + with ( + patch("rfdetr.datasets.o365.CocoDetection", return_value=fake_dataset), + patch("rfdetr.datasets.o365.make_coco_transforms") as mock_transform, + patch("rfdetr.datasets.o365.make_coco_transforms_square_div_64") as mock_sq_transform, + ): + mock_transform.return_value = MagicMock() + mock_sq_transform.return_value = MagicMock() + result = build_o365_raw("train", args, resolution=640) + return result, mock_transform, mock_sq_transform + + def test_cpu_backend_no_warning(self): + """cpu backend does not call logger.warning with O365 content.""" + from unittest.mock import patch + + with patch("rfdetr.datasets.o365.logger") as mock_logger: + self._call_build_o365_raw("cpu") + o365_warns = [c for c in mock_logger.warning.call_args_list if "O365" in str(c)] + assert len(o365_warns) == 0, "cpu backend must not warn about O365 GPU augmentation" + + def test_auto_backend_emits_warning(self): + """auto + CUDA + kornia available: logger.warning about O365 Phase 1 limitation.""" + import sys + from unittest.mock import MagicMock, patch + + with ( + patch("torch.cuda.is_available", return_value=True), + patch.dict(sys.modules, {"kornia": MagicMock(), "kornia.augmentation": MagicMock()}), + patch("rfdetr.datasets.o365.logger") as mock_logger, + ): + self._call_build_o365_raw("auto") + o365_warns = [c for c in mock_logger.warning.call_args_list if "O365" in str(c)] + assert len(o365_warns) >= 1, "auto backend must warn about O365 GPU aug limitation" + + def test_auto_backend_no_cuda_no_warning(self): + """auto + no CUDA: resolves to cpu, no O365 warning emitted.""" + from unittest.mock import patch + + with ( + patch("torch.cuda.is_available", return_value=False), + patch("rfdetr.datasets.o365.logger") as mock_logger, + ): + self._call_build_o365_raw("auto") + o365_warns = [c for c in mock_logger.warning.call_args_list if "O365" in str(c)] + assert len(o365_warns) == 0, "auto + no CUDA must not warn about O365 GPU aug" + + def test_gpu_postprocess_false_for_cpu_backend(self): + """cpu backend passes gpu_postprocess=False (or omits it) to make_coco_transforms.""" + _, mock_transform, _ = self._call_build_o365_raw("cpu") + call_kwargs = mock_transform.call_args.kwargs if mock_transform.call_args else {} + assert call_kwargs.get("gpu_postprocess", False) is False + + def test_gpu_postprocess_true_for_auto_backend(self): + """auto + CUDA + kornia available: gpu_postprocess=True passed to make_coco_transforms.""" + import sys + from unittest.mock import MagicMock, patch + + with ( + patch("torch.cuda.is_available", return_value=True), + patch.dict(sys.modules, {"kornia": MagicMock(), "kornia.augmentation": MagicMock()}), + ): + _, mock_transform, _ = self._call_build_o365_raw("auto") + call_kwargs = mock_transform.call_args.kwargs if mock_transform.call_args else {} + assert call_kwargs.get("gpu_postprocess", False) is True + + def test_gpu_postprocess_false_for_auto_no_cuda(self): + """auto + no CUDA: gpu_postprocess=False so CPU Normalize is retained.""" + from unittest.mock import patch + + with patch("torch.cuda.is_available", return_value=False): + _, mock_transform, _ = self._call_build_o365_raw("auto") + call_kwargs = mock_transform.call_args.kwargs if mock_transform.call_args else {} + assert call_kwargs.get("gpu_postprocess", False) is False, "auto + no CUDA must not strip CPU Normalize" + + def test_square_resize_uses_square_transform(self): + """square_resize_div_64=True delegates to make_coco_transforms_square_div_64.""" + _, mock_transform, mock_sq_transform = self._call_build_o365_raw("cpu", square_resize_div_64=True) + mock_sq_transform.assert_called_once() + mock_transform.assert_not_called() diff --git a/tests/datasets/test_kornia_transforms.py b/tests/datasets/test_kornia_transforms.py new file mode 100644 index 00000000..5e17c3e1 --- /dev/null +++ b/tests/datasets/test_kornia_transforms.py @@ -0,0 +1,469 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +"""Tests for Kornia GPU augmentation pipeline builder and bbox utilities. + +All tests in this module are CPU-compatible — Kornia operates on CPU tensors +identically to GPU tensors, so no ``@pytest.mark.gpu`` is needed. +""" + +import pytest +import torch + +from rfdetr.datasets.aug_config import ( + AUG_AERIAL, + AUG_AGGRESSIVE, + AUG_CONSERVATIVE, + AUG_INDUSTRIAL, +) + +# --------------------------------------------------------------------------- +# TestBuildKorniaPipeline — validates the factory that translates aug_config +# dicts into a Kornia AugmentationSequential pipeline. +# --------------------------------------------------------------------------- + + +class TestBuildKorniaPipeline: + """build_kornia_pipeline returns a valid pipeline for every preset and + rejects unknown transform keys with a clear error.""" + + @pytest.fixture(autouse=True) + def _require_kornia(self): + pytest.importorskip("kornia") + + @pytest.mark.parametrize( + "config,config_name", + [ + pytest.param(AUG_CONSERVATIVE, "AUG_CONSERVATIVE", id="conservative"), + pytest.param(AUG_AGGRESSIVE, "AUG_AGGRESSIVE", id="aggressive"), + pytest.param(AUG_AERIAL, "AUG_AERIAL", id="aerial"), + pytest.param(AUG_INDUSTRIAL, "AUG_INDUSTRIAL", id="industrial"), + ], + ) + def test_each_preset_config(self, config, config_name): + """Each named preset builds a pipeline without errors.""" + from rfdetr.datasets.kornia_transforms import build_kornia_pipeline + + pipeline = build_kornia_pipeline(config, 560) + assert pipeline is not None, f"build_kornia_pipeline({config_name}, 560) must return a non-None pipeline" + + def test_unknown_key_raises_value_error(self): + """An unrecognised transform key raises ValueError immediately.""" + from rfdetr.datasets.kornia_transforms import build_kornia_pipeline + + with pytest.raises(ValueError, match="FooBarTransform"): + build_kornia_pipeline({"FooBarTransform": {"p": 0.5}}, 560) + + def test_empty_config_returns_pipeline(self): + """An empty config dict returns a valid (no-op) pipeline, not None.""" + from rfdetr.datasets.kornia_transforms import build_kornia_pipeline + + pipeline = build_kornia_pipeline({}, 560) + assert pipeline is not None, "Empty config must still return a pipeline object" + + def test_known_plus_unknown_raises(self): + """Mixing a valid key with an unknown key still raises ValueError.""" + from rfdetr.datasets.kornia_transforms import build_kornia_pipeline + + mixed = {"HorizontalFlip": {"p": 0.5}, "BogusTransform": {"p": 0.3}} + with pytest.raises(ValueError, match="BogusTransform"): + build_kornia_pipeline(mixed, 560) + + +# --------------------------------------------------------------------------- +# TestCollateBoxes — validates packing of variable-length per-image boxes +# into a zero-padded [B, N_max, 4] tensor with a boolean validity mask. +# --------------------------------------------------------------------------- + + +class TestCollateBoxes: + """collate_boxes packs variable-length boxes into [B, N_max, 4] with mask.""" + + @pytest.fixture(autouse=True) + def _require_kornia(self): + pytest.importorskip("kornia") + + def _make_targets(self, box_counts): + """Build a list of target dicts with the given per-image box counts. + + Each box is a valid xyxy rectangle within a 100x100 image. + """ + targets = [] + for n in box_counts: + boxes = ( + torch.tensor([[10.0, 10.0, 50.0, 50.0]] * n, dtype=torch.float32) + if n > 0 + else torch.zeros(0, 4, dtype=torch.float32) + ) + targets.append({"boxes": boxes}) + return targets + + def test_normal_batch(self): + """Batch of 2 images: output shape is [2, N_max, 4] with valid mask [2, N_max].""" + from rfdetr.datasets.kornia_transforms import collate_boxes + + targets = self._make_targets([2, 3]) + boxes_padded, valid = collate_boxes(targets, torch.device("cpu")) + + assert boxes_padded.shape == (2, 3, 4), f"Expected shape (2, 3, 4), got {boxes_padded.shape}" + assert valid.shape == (2, 3), f"Expected valid shape (2, 3), got {valid.shape}" + assert valid.dtype == torch.bool + + def test_b_zero(self): + """Empty target list produces shape [0, 0, 4] and valid [0, 0].""" + from rfdetr.datasets.kornia_transforms import collate_boxes + + boxes_padded, valid = collate_boxes([], torch.device("cpu")) + + assert boxes_padded.shape == (0, 0, 4), f"Expected (0, 0, 4) for empty batch, got {boxes_padded.shape}" + assert valid.shape == (0, 0), f"Expected valid (0, 0) for empty batch, got {valid.shape}" + + def test_n_zero_per_image(self): + """One image with 0 boxes: shape [1, 0, 4], valid all-False.""" + from rfdetr.datasets.kornia_transforms import collate_boxes + + targets = self._make_targets([0]) + boxes_padded, valid = collate_boxes(targets, torch.device("cpu")) + + assert boxes_padded.shape == (1, 0, 4), f"Expected (1, 0, 4), got {boxes_padded.shape}" + assert valid.shape == (1, 0), f"Expected (1, 0), got {valid.shape}" + + def test_single_image(self): + """B=1 with 3 boxes: output shape is [1, 3, 4].""" + from rfdetr.datasets.kornia_transforms import collate_boxes + + targets = self._make_targets([3]) + boxes_padded, valid = collate_boxes(targets, torch.device("cpu")) + + assert boxes_padded.shape == (1, 3, 4) + assert valid.shape == (1, 3) + + def test_valid_mask_matches_box_count(self): + """The valid mask has True for real boxes and False for padding.""" + from rfdetr.datasets.kornia_transforms import collate_boxes + + targets = self._make_targets([1, 3]) + _, valid = collate_boxes(targets, torch.device("cpu")) + + # Image 0: 1 real box, 2 padding → [True, False, False] + assert valid[0].tolist() == [True, False, False], f"Image 0 valid mask wrong: {valid[0].tolist()}" + # Image 1: 3 real boxes, 0 padding → [True, True, True] + assert valid[1].tolist() == [True, True, True], f"Image 1 valid mask wrong: {valid[1].tolist()}" + + +# --------------------------------------------------------------------------- +# TestUnpackBoxes — validates the inverse: writing augmented boxes back into +# per-image target dicts with clamping, zero-area removal, and label sync. +# --------------------------------------------------------------------------- + + +class TestUnpackBoxes: + """unpack_boxes writes augmented boxes back and removes zero-area entries.""" + + @pytest.fixture(autouse=True) + def _require_kornia(self): + pytest.importorskip("kornia") + + def _make_inputs(self, boxes_aug, valid_mask, original_targets, H=100, W=100): + """Return tensors suitable for unpack_boxes.""" + boxes_tensor = torch.tensor(boxes_aug, dtype=torch.float32) + valid_tensor = torch.tensor(valid_mask, dtype=torch.bool) + return boxes_tensor, valid_tensor, original_targets, H, W + + def test_all_boxes_removed_after_aug(self): + """When all augmented boxes are zero-area, output targets have empty boxes.""" + from rfdetr.datasets.kornia_transforms import unpack_boxes + + # B=1, N=2: both boxes are zero-area (x1==x2 or y1==y2) + boxes_aug = [[[10.0, 10.0, 10.0, 10.0], [20.0, 20.0, 20.0, 20.0]]] + valid = [[True, True]] + targets = [ + { + "boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0], [20.0, 20.0, 60.0, 60.0]]), + "labels": torch.tensor([1, 2]), + "area": torch.tensor([1600.0, 1600.0]), + "iscrowd": torch.tensor([0, 0]), + } + ] + boxes_t, valid_t, tgts, H, W = self._make_inputs(boxes_aug, valid, targets) + result = unpack_boxes(boxes_t, valid_t, tgts, H, W) + + assert result[0]["boxes"].shape[0] == 0, ( + f"Expected 0 boxes after zero-area removal, got {result[0]['boxes'].shape[0]}" + ) + assert result[0]["labels"].shape[0] == 0 + + def test_partial_removal(self): + """Some boxes survive, some removed; labels/area/iscrowd synced.""" + from rfdetr.datasets.kornia_transforms import unpack_boxes + + # Box 0: valid, non-zero area; Box 1: zero-area + boxes_aug = [[[10.0, 10.0, 50.0, 50.0], [30.0, 30.0, 30.0, 30.0]]] + valid = [[True, True]] + targets = [ + { + "boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0], [30.0, 30.0, 70.0, 70.0]]), + "labels": torch.tensor([1, 2]), + "area": torch.tensor([1600.0, 1600.0]), + "iscrowd": torch.tensor([0, 1]), + } + ] + boxes_t, valid_t, tgts, H, W = self._make_inputs(boxes_aug, valid, targets) + result = unpack_boxes(boxes_t, valid_t, tgts, H, W) + + assert result[0]["boxes"].shape[0] == 1, f"Expected 1 surviving box, got {result[0]['boxes'].shape[0]}" + assert result[0]["labels"].tolist() == [1] + + def test_labels_area_iscrowd_sync(self): + """When boxes are removed, labels/area/iscrowd entries are also removed.""" + from rfdetr.datasets.kornia_transforms import unpack_boxes + + # Box 0: zero-area (removed), Box 1: valid + boxes_aug = [[[5.0, 5.0, 5.0, 5.0], [10.0, 10.0, 40.0, 40.0]]] + valid = [[True, True]] + targets = [ + { + "boxes": torch.tensor([[5.0, 5.0, 30.0, 30.0], [10.0, 10.0, 40.0, 40.0]]), + "labels": torch.tensor([7, 9]), + "area": torch.tensor([625.0, 900.0]), + "iscrowd": torch.tensor([0, 1]), + } + ] + boxes_t, valid_t, tgts, H, W = self._make_inputs(boxes_aug, valid, targets) + result = unpack_boxes(boxes_t, valid_t, tgts, H, W) + + assert result[0]["labels"].tolist() == [9], ( + f"Expected label [9] after removal of box 0, got {result[0]['labels'].tolist()}" + ) + assert result[0]["area"].shape[0] == 1 + assert result[0]["iscrowd"].tolist() == [1] + + def test_boxes_clamped_to_image_bounds(self): + """Boxes outside [0,W]x[0,H] are clamped to image bounds.""" + from rfdetr.datasets.kornia_transforms import unpack_boxes + + # Box extends beyond 100x100 image + boxes_aug = [[[-10.0, -5.0, 120.0, 110.0]]] + valid = [[True]] + targets = [ + { + "boxes": torch.tensor([[0.0, 0.0, 90.0, 90.0]]), + "labels": torch.tensor([1]), + "area": torch.tensor([8100.0]), + "iscrowd": torch.tensor([0]), + } + ] + H, W = 100, 100 + boxes_t, valid_t, tgts, H, W = self._make_inputs(boxes_aug, valid, targets, H, W) + result = unpack_boxes(boxes_t, valid_t, tgts, H, W) + + result_boxes = result[0]["boxes"] + assert result_boxes.shape[0] == 1, "Clamped box should survive (non-zero area)" + # Verify clamping: x1>=0, y1>=0, x2<=W, y2<=H + assert result_boxes[0, 0].item() >= 0.0, "x1 not clamped to >= 0" + assert result_boxes[0, 1].item() >= 0.0, "y1 not clamped to >= 0" + assert result_boxes[0, 2].item() <= W, f"x2 not clamped to <= {W}" + assert result_boxes[0, 3].item() <= H, f"y2 not clamped to <= {H}" + + +# --------------------------------------------------------------------------- +# TestRotateFactory — validates the Rotate parameter translation from +# Albumentations-style limit (scalar or tuple) to Kornia RandomRotation. +# --------------------------------------------------------------------------- + + +class TestRotateFactory: + """Rotate factory translates limit (scalar or tuple) to K.RandomRotation(degrees=...).""" + + @pytest.fixture(autouse=True) + def _require_kornia(self): + pytest.importorskip("kornia") + + def test_limit_as_scalar(self): + """Rotate(limit=45) produces K.RandomRotation(degrees=(-45, 45)).""" + from rfdetr.datasets.kornia_transforms import build_kornia_pipeline + + # Build a pipeline with just Rotate(limit=45) + pipeline = build_kornia_pipeline({"Rotate": {"limit": 45, "p": 1.0}}, 560) + assert pipeline is not None + + # Inspect the pipeline's children to find the RandomRotation and check degrees + import kornia.augmentation as K + + rotation_augs = [child for child in pipeline.children() if isinstance(child, K.RandomRotation)] + assert len(rotation_augs) == 1, f"Expected exactly 1 RandomRotation, found {len(rotation_augs)}" + degrees = rotation_augs[0].flags["degrees"] + # degrees should be a tensor representing (-45, 45) + assert float(degrees[0]) == pytest.approx(-45.0, abs=0.1) + assert float(degrees[1]) == pytest.approx(45.0, abs=0.1) + + def test_limit_as_tuple(self): + """Rotate(limit=(90, 90)) produces K.RandomRotation(degrees=(90, 90)).""" + from rfdetr.datasets.kornia_transforms import build_kornia_pipeline + + pipeline = build_kornia_pipeline({"Rotate": {"limit": (90, 90), "p": 1.0}}, 560) + assert pipeline is not None + + import kornia.augmentation as K + + rotation_augs = [child for child in pipeline.children() if isinstance(child, K.RandomRotation)] + assert len(rotation_augs) == 1 + degrees = rotation_augs[0].flags["degrees"] + assert float(degrees[0]) == pytest.approx(90.0, abs=0.1) + assert float(degrees[1]) == pytest.approx(90.0, abs=0.1) + + +# --------------------------------------------------------------------------- +# TestGpuPostprocessFlag — validates that make_coco_transforms respects the +# gpu_postprocess flag to omit augmentation and normalization from CPU path. +# --------------------------------------------------------------------------- + + +class TestGpuPostprocessFlag: + """gpu_postprocess flag controls whether aug + normalize appear in CPU pipeline.""" + + def test_gpu_postprocess_true_omits_aug_and_normalize_from_train(self): + """gpu_postprocess=True: train pipeline has no Normalize; fewer AlbumentationsWrappers (no aug_wrappers).""" + from rfdetr.datasets.coco import make_coco_transforms + from rfdetr.datasets.transforms import AlbumentationsWrapper, Normalize + + pipeline_gpu = make_coco_transforms("train", 560, gpu_postprocess=True) + pipeline_cpu = make_coco_transforms("train", 560, gpu_postprocess=False) + + steps_gpu = pipeline_gpu.transforms + steps_cpu = pipeline_cpu.transforms + + normalize_gpu = [s for s in steps_gpu if isinstance(s, Normalize)] + assert len(normalize_gpu) == 0, "gpu_postprocess=True must omit Normalize from train pipeline" + + # Resize wrappers (AlbumentationsWrapper) remain; aug wrappers are removed. + # Default AUG_CONFIG adds 1 aug wrapper, so gpu version must have fewer wrappers. + n_alb_gpu = sum(isinstance(s, AlbumentationsWrapper) for s in steps_gpu) + n_alb_cpu = sum(isinstance(s, AlbumentationsWrapper) for s in steps_cpu) + assert n_alb_gpu < n_alb_cpu, "gpu_postprocess=True must remove aug AlbumentationsWrappers from train pipeline" + + def test_gpu_postprocess_false_includes_aug_and_normalize_from_train(self): + """gpu_postprocess=False (default): train pipeline includes Normalize.""" + from rfdetr.datasets.coco import make_coco_transforms + from rfdetr.datasets.transforms import Normalize + + pipeline = make_coco_transforms("train", 560, gpu_postprocess=False) + steps = pipeline.transforms + + normalize_steps = [s for s in steps if isinstance(s, Normalize)] + assert len(normalize_steps) > 0, "gpu_postprocess=False must include Normalize in train pipeline" + + def test_val_path_unaffected_by_gpu_postprocess(self): + """Val pipeline is unchanged regardless of gpu_postprocess value.""" + from rfdetr.datasets.coco import make_coco_transforms + from rfdetr.datasets.transforms import Normalize + + pipeline_default = make_coco_transforms("val", 560, gpu_postprocess=False) + pipeline_gpu = make_coco_transforms("val", 560, gpu_postprocess=True) + + # Both should have Normalize (val is never stripped) + norm_default = [s for s in pipeline_default.transforms if isinstance(s, Normalize)] + norm_gpu = [s for s in pipeline_gpu.transforms if isinstance(s, Normalize)] + + assert len(norm_default) > 0, "Val pipeline (default) must include Normalize" + assert len(norm_gpu) > 0, "Val pipeline (gpu_postprocess=True) must include Normalize" + + # Same number of pipeline steps + assert len(pipeline_default.transforms) == len(pipeline_gpu.transforms), ( + "Val pipeline step count must be identical regardless of gpu_postprocess" + ) + + +# --------------------------------------------------------------------------- +# TestGaussianBlurMinKernel — validates that blur_limit < 3 is clamped so +# Kornia never receives an invalid kernel_size < 3. +# --------------------------------------------------------------------------- + + +class TestGaussianBlurMinKernel: + """_make_gaussian_blur enforces kernel_size >= 3 regardless of blur_limit.""" + + @pytest.fixture(autouse=True) + def _require_kornia(self): + pytest.importorskip("kornia") + + @pytest.mark.parametrize( + "blur_limit", + [pytest.param(1, id="blur_limit_1"), pytest.param(2, id="blur_limit_2")], + ) + def test_small_blur_limit_produces_valid_kernel(self, blur_limit): + """blur_limit below 3 must be clamped so the resulting kernel_size >= 3.""" + from rfdetr.datasets.kornia_transforms import build_kornia_pipeline + + # Should not raise; previously blur_limit=1 produced kernel_size=(3,1) + pipeline = build_kornia_pipeline({"GaussianBlur": {"blur_limit": blur_limit, "p": 1.0}}, 560) + assert pipeline is not None + + import kornia.augmentation as K + + blur_augs = [c for c in pipeline.children() if isinstance(c, K.RandomGaussianBlur)] + assert len(blur_augs) == 1 + ks = blur_augs[0].flags["kernel_size"] + assert int(ks[0]) >= 3, f"kernel_size[0]={int(ks[0])} must be >= 3" + assert int(ks[1]) >= 3, f"kernel_size[1]={int(ks[1])} must be >= 3" + + def test_blur_limit_3_unchanged(self): + """blur_limit=3 (default) passes through without modification.""" + from rfdetr.datasets.kornia_transforms import build_kornia_pipeline + + pipeline = build_kornia_pipeline({"GaussianBlur": {"blur_limit": 3, "p": 1.0}}, 560) + import kornia.augmentation as K + + blur_augs = [c for c in pipeline.children() if isinstance(c, K.RandomGaussianBlur)] + ks = blur_augs[0].flags["kernel_size"] + assert int(ks[0]) == 3 + assert int(ks[1]) == 3 + + +# --------------------------------------------------------------------------- +# TestKorniaPipelineForwardPass — validates that a built pipeline produces +# output of the correct shape and dtype on CPU tensors. +# --------------------------------------------------------------------------- + + +class TestKorniaPipelineForwardPass: + """build_kornia_pipeline output passes through without shape/dtype errors.""" + + @pytest.fixture(autouse=True) + def _require_kornia(self): + pytest.importorskip("kornia") + + def test_forward_pass_shape_and_dtype(self): + """Pipeline output images have same shape as input; boxes shape is [B, N, 4].""" + from rfdetr.datasets.kornia_transforms import build_kornia_pipeline + + pipeline = build_kornia_pipeline({"HorizontalFlip": {"p": 1.0}}, resolution=64) + + B, C, H, W = 2, 3, 64, 64 + img = torch.rand(B, C, H, W) + boxes = torch.tensor([[[0.0, 0.0, 32.0, 32.0]], [[10.0, 10.0, 50.0, 50.0]]], dtype=torch.float32) + + img_out, boxes_out = pipeline(img, boxes) + + assert img_out.shape == (B, C, H, W), f"Image shape changed: {img_out.shape}" + assert img_out.dtype == torch.float32 + assert boxes_out.shape == (B, 1, 4), f"Boxes shape wrong: {boxes_out.shape}" + + def test_forward_pass_empty_boxes(self): + """Pipeline handles a batch where N_max=0 (no boxes) without error.""" + from rfdetr.datasets.kornia_transforms import build_kornia_pipeline + + pipeline = build_kornia_pipeline({"HorizontalFlip": {"p": 1.0}}, resolution=32) + + B, C, H, W = 2, 3, 32, 32 + img = torch.rand(B, C, H, W) + # [B, 0, 4] — no boxes + boxes = torch.zeros(B, 0, 4, dtype=torch.float32) + + img_out, boxes_out = pipeline(img, boxes) + + assert img_out.shape == (B, C, H, W) + assert boxes_out.shape == (B, 0, 4) diff --git a/tests/training/conftest.py b/tests/training/conftest.py index c642bdf6..06788709 100644 --- a/tests/training/conftest.py +++ b/tests/training/conftest.py @@ -117,3 +117,22 @@ def _restore_rfdetr_module_trainer_property(): if "trainer" in RFDETRModelModule.__dict__: delattr(RFDETRModelModule, "trainer") + + +@pytest.fixture(autouse=True) +def _restore_rfdetr_datamodule_trainer_property(): + """Restore RFDETRDataModule.trainer to the LightningDataModule parent property after each test. + + Tests that mock the ``trainer`` property on ``RFDETRDataModule`` (e.g. for + ``on_after_batch_transfer`` tests) patch it at the class level. Without + cleanup this mutates the class for the remainder of the session. + + This fixture deletes any class-level override from ``RFDETRDataModule.__dict__`` + after every test, mirroring the ``_restore_rfdetr_module_trainer_property`` + pattern above. + """ + yield + from rfdetr.training.module_data import RFDETRDataModule + + if "trainer" in RFDETRDataModule.__dict__: + delattr(RFDETRDataModule, "trainer") diff --git a/tests/training/test_module_data.py b/tests/training/test_module_data.py index 545d2538..24b930d5 100644 --- a/tests/training/test_module_data.py +++ b/tests/training/test_module_data.py @@ -779,3 +779,355 @@ def test_preserves_nested_tensor_type(self, build_datamodule): result_samples, _ = dm.transfer_batch_to_device((samples, targets), torch.device("cpu"), dataloader_idx=0) assert isinstance(result_samples, NestedTensor) + + +# --------------------------------------------------------------------------- +# TestBackendResolution — validates augmentation_backend logic in setup("fit") +# --------------------------------------------------------------------------- + + +class TestBackendResolution: + """Backend resolution selects Kornia, CPU, or raises depending on environment. + + All tests run on CPU CI by mocking ``torch.cuda.is_available`` and the + ``kornia`` import as needed. + """ + + def _build_dm_with_backend(self, tmp_path, augmentation_backend="cpu"): + """Construct a DataModule with the given augmentation_backend.""" + mc = _base_model_config() + tc = _base_train_config(tmp_path, augmentation_backend=augmentation_backend) + from rfdetr.training.module_data import RFDETRDataModule + + return RFDETRDataModule(mc, tc) + + def _setup_with_mock_build(self, dm): + """Call setup('fit') with build_dataset mocked to avoid real I/O.""" + fake_train = _fake_dataset(100) + fake_val = _fake_dataset(20) + + def _build(image_set, args, resolution): + return fake_train if image_set == "train" else fake_val + + with patch("rfdetr.training.module_data.build_dataset", side_effect=_build): + dm.setup("fit") + return dm + + def test_auto_no_cuda_falls_back_to_cpu(self, tmp_path): + """auto + no CUDA: _kornia_pipeline stays None, no error.""" + dm = self._build_dm_with_backend(tmp_path, "auto") + with patch("torch.cuda.is_available", return_value=False): + dm = self._setup_with_mock_build(dm) + assert getattr(dm, "_kornia_pipeline", None) is None, ( + "auto backend with no CUDA must not build a Kornia pipeline" + ) + + def test_auto_no_kornia_falls_back_to_cpu(self, tmp_path): + """auto + CUDA available but kornia not installed: fallback to CPU.""" + dm = self._build_dm_with_backend(tmp_path, "auto") + + original_import = __builtins__.__import__ if hasattr(__builtins__, "__import__") else __import__ + + def _mock_import(name, *args, **kwargs): + if name == "kornia" or name.startswith("kornia."): + raise ImportError("No module named 'kornia'") + return original_import(name, *args, **kwargs) + + with ( + patch("torch.cuda.is_available", return_value=True), + patch("builtins.__import__", side_effect=_mock_import), + ): + dm = self._setup_with_mock_build(dm) + + assert getattr(dm, "_kornia_pipeline", None) is None, ( + "auto backend with kornia missing must fall back to CPU (pipeline=None)" + ) + + def test_gpu_no_cuda_raises_runtime_error(self, tmp_path): + """gpu + no CUDA: must raise RuntimeError.""" + dm = self._build_dm_with_backend(tmp_path, "gpu") + with ( + patch("torch.cuda.is_available", return_value=False), + pytest.raises(RuntimeError, match="CUDA"), + ): + self._setup_with_mock_build(dm) + + def test_gpu_no_kornia_raises_import_error(self, tmp_path): + """gpu + CUDA but no kornia: must raise ImportError with install hint.""" + dm = self._build_dm_with_backend(tmp_path, "gpu") + + original_import = __builtins__.__import__ if hasattr(__builtins__, "__import__") else __import__ + + def _mock_import(name, *args, **kwargs): + if name == "kornia" or name.startswith("kornia."): + raise ImportError("No module named 'kornia'") + return original_import(name, *args, **kwargs) + + with ( + patch("torch.cuda.is_available", return_value=True), + patch("builtins.__import__", side_effect=_mock_import), + pytest.raises(ImportError, match="rfdetr\\[kornia\\]"), + ): + self._setup_with_mock_build(dm) + + def test_cpu_backend_builds_no_pipeline(self, tmp_path): + """Default cpu backend: _kornia_pipeline stays None.""" + dm = self._build_dm_with_backend(tmp_path, "cpu") + dm = self._setup_with_mock_build(dm) + assert getattr(dm, "_kornia_pipeline", None) is None, "cpu backend must never build a Kornia pipeline" + + def test_gpu_path_uses_aug_config_fallback(self, tmp_path): + """When aug_config=None (default), GPU path passes AUG_CONFIG to build_kornia_pipeline.""" + import sys + from unittest.mock import MagicMock, patch + + from rfdetr.datasets.aug_config import AUG_CONFIG + + dm = self._build_dm_with_backend(tmp_path, "auto") + assert dm.train_config.aug_config is None, "precondition: aug_config must be None for this test" + + captured = {} + + def _fake_build_kornia(aug_cfg, resolution): + captured["aug_config"] = aug_cfg + return MagicMock() + + with ( + patch("torch.cuda.is_available", return_value=True), + patch("rfdetr.training.module_data.build_dataset", side_effect=lambda *a, **k: _fake_dataset(10)), + patch.dict(sys.modules, {"kornia": MagicMock(), "kornia.augmentation": MagicMock()}), + patch("rfdetr.datasets.kornia_transforms.build_kornia_pipeline", side_effect=_fake_build_kornia), + patch("rfdetr.datasets.kornia_transforms.build_normalize", return_value=MagicMock()), + ): + dm.setup("fit") + + assert captured.get("aug_config") is AUG_CONFIG, ( + "GPU path must fall back to AUG_CONFIG when train_config.aug_config is None" + ) + + def test_auto_no_cuda_does_not_strip_cpu_normalize(self, tmp_path): + """auto + no CUDA: gpu_postprocess must be False so CPU Normalize is retained.""" + dm = self._build_dm_with_backend(tmp_path, "auto") + captured_gpu_postprocess = {} + + def _spy_build(image_set, args, resolution): + captured_gpu_postprocess[image_set] = getattr(args, "augmentation_backend", "cpu") + return _fake_dataset(10) + + with ( + patch("torch.cuda.is_available", return_value=False), + patch("rfdetr.training.module_data.build_dataset", side_effect=_spy_build), + ): + dm.setup("fit") + + # When CUDA is unavailable, resolved backend must be 'cpu' so datasets are + # built with gpu_postprocess=False and CPU Normalize is not stripped. + assert captured_gpu_postprocess.get("train") == "cpu", ( + "auto + no CUDA must resolve to cpu before dataset build to preserve CPU Normalize" + ) + + def test_resolve_augmentation_backend_auto_no_cuda(self): + """_resolve_augmentation_backend returns 'cpu' for auto when CUDA is absent.""" + from rfdetr.training.module_data import _resolve_augmentation_backend + + with patch("torch.cuda.is_available", return_value=False): + assert _resolve_augmentation_backend("auto") == "cpu" + + def test_resolve_augmentation_backend_cpu_passthrough(self): + """_resolve_augmentation_backend passes 'cpu' through unchanged.""" + from rfdetr.training.module_data import _resolve_augmentation_backend + + assert _resolve_augmentation_backend("cpu") == "cpu" + + def test_resolve_augmentation_backend_gpu_passthrough(self): + """_resolve_augmentation_backend passes 'gpu' through unchanged.""" + from rfdetr.training.module_data import _resolve_augmentation_backend + + assert _resolve_augmentation_backend("gpu") == "gpu" + + +# --------------------------------------------------------------------------- +# TestOnAfterBatchTransfer — validates GPU-side augmentation hook +# --------------------------------------------------------------------------- + + +class TestOnAfterBatchTransfer: + """on_after_batch_transfer applies Kornia augmentation only during training. + + Uses CPU tensors with a mocked pipeline — no real GPU or Kornia needed. + """ + + def _build_dm(self, tmp_path, segmentation_head=False): + """Construct a DataModule for on_after_batch_transfer tests.""" + mc = _base_model_config(segmentation_head=segmentation_head) + tc = _base_train_config(tmp_path) + from rfdetr.training.module_data import RFDETRDataModule + + return RFDETRDataModule(mc, tc) + + def _attach_mock_trainer(self, dm, training=True): + """Attach a mock trainer with the given training state to the DataModule.""" + mock_trainer = MagicMock(training=training) + type(dm).trainer = property(lambda self: mock_trainer) + return dm + + def _make_kornia_batch(self, batch_size=2, h=16, w=16): + """Build a batch with xyxy boxes suitable for on_after_batch_transfer. + + Returns (NestedTensor, targets) where boxes are in absolute xyxy format + and pixel values are in [0, 1] (pre-normalization). + """ + tensors = torch.rand(batch_size, 3, h, w) # [0, 1] range + mask = torch.zeros(batch_size, h, w, dtype=torch.bool) + samples = NestedTensor(tensors, mask) + targets = [ + { + "boxes": torch.tensor([[2.0, 2.0, 10.0, 10.0]], dtype=torch.float32), + "labels": torch.tensor([1]), + "area": torch.tensor([64.0]), + "iscrowd": torch.tensor([0]), + "image_id": torch.tensor(i), + "orig_size": torch.tensor([h, w]), + } + for i in range(batch_size) + ] + return samples, targets + + def test_training_true_applies_augmentation(self, tmp_path): + """When training=True and _kornia_pipeline is set, augmentation is applied.""" + dm = self._build_dm(tmp_path) + dm = self._attach_mock_trainer(dm, training=True) + + samples, targets = self._make_kornia_batch() + img_aug = samples.tensors.clone() + # Mock pipeline returns (augmented_images, augmented_boxes) + boxes_padded = torch.tensor([[[2.0, 2.0, 10.0, 10.0]]] * 2) + mock_pipeline = MagicMock(return_value=(img_aug, boxes_padded)) + dm._kornia_pipeline = mock_pipeline + + # Mock normalize to be a passthrough + dm._kornia_normalize = MagicMock(side_effect=lambda x: x) + + result = dm.on_after_batch_transfer((samples, targets), dataloader_idx=0) + + mock_pipeline.assert_called_once() + dm._kornia_normalize.assert_called_once() + assert isinstance(result, tuple) + assert len(result) == 2 + + def test_training_false_skips_augmentation(self, tmp_path): + """When training=False, batch is returned unchanged.""" + dm = self._build_dm(tmp_path) + dm = self._attach_mock_trainer(dm, training=False) + + samples, targets = self._make_kornia_batch() + mock_pipeline = MagicMock() + dm._kornia_pipeline = mock_pipeline + dm._kornia_normalize = MagicMock() + + result = dm.on_after_batch_transfer((samples, targets), dataloader_idx=0) + + mock_pipeline.assert_not_called() + # Batch returned as-is + result_samples, result_targets = result + assert result_samples is samples + assert result_targets is targets + + def test_segmentation_model_skips_augmentation(self, tmp_path): + """When segmentation_head=True, pipeline is not called even during training.""" + dm = self._build_dm(tmp_path, segmentation_head=True) + dm = self._attach_mock_trainer(dm, training=True) + + samples, targets = self._make_kornia_batch() + mock_pipeline = MagicMock() + dm._kornia_pipeline = mock_pipeline + dm._kornia_normalize = MagicMock() + + dm.on_after_batch_transfer((samples, targets), dataloader_idx=0) + + mock_pipeline.assert_not_called() + + def test_returns_nested_tensor_in_batch(self, tmp_path): + """Output batch still has NestedTensor as first element after augmentation.""" + dm = self._build_dm(tmp_path) + dm = self._attach_mock_trainer(dm, training=True) + + samples, targets = self._make_kornia_batch() + img_aug = samples.tensors.clone() + boxes_padded = torch.tensor([[[2.0, 2.0, 10.0, 10.0]]] * 2) + dm._kornia_pipeline = MagicMock(return_value=(img_aug, boxes_padded)) + dm._kornia_normalize = MagicMock(side_effect=lambda x: x) + + result_samples, _ = dm.on_after_batch_transfer((samples, targets), dataloader_idx=0) + + assert isinstance(result_samples, NestedTensor), f"Expected NestedTensor, got {type(result_samples).__name__}" + + +# --------------------------------------------------------------------------- +# TestKorniaSetupDoneSentinel — validates the _kornia_setup_done guard +# --------------------------------------------------------------------------- + + +class TestKorniaSetupDoneSentinel: + """_kornia_setup_done prevents _setup_kornia_pipeline re-running on repeated setup('fit') calls.""" + + def _build_dm(self, tmp_path, augmentation_backend="auto"): + mc = _base_model_config() + tc = _base_train_config(tmp_path, augmentation_backend=augmentation_backend) + from rfdetr.training.module_data import RFDETRDataModule + + return RFDETRDataModule(mc, tc) + + def _setup_fit_with_mocks(self, dm): + """Call setup('fit') with build_dataset and cuda mocked (no CUDA → fallback).""" + fake_train = _fake_dataset(100) + fake_val = _fake_dataset(20) + + def _build(image_set, args, resolution): + return fake_train if image_set == "train" else fake_val + + with ( + patch("rfdetr.training.module_data.build_dataset", side_effect=_build), + patch("torch.cuda.is_available", return_value=False), + ): + dm.setup("fit") + return dm + + def test_sentinel_starts_false(self, tmp_path): + """_kornia_setup_done is False immediately after __init__.""" + dm = self._build_dm(tmp_path) + assert dm._kornia_setup_done is False + + def test_sentinel_set_after_fit(self, tmp_path): + """_kornia_setup_done becomes True after the first setup('fit').""" + dm = self._build_dm(tmp_path) + dm = self._setup_fit_with_mocks(dm) + assert dm._kornia_setup_done is True + + def test_setup_kornia_pipeline_not_called_twice(self, tmp_path): + """Calling setup('fit') twice only calls _setup_kornia_pipeline once.""" + dm = self._build_dm(tmp_path) + call_count = 0 + original_setup = dm._setup_kornia_pipeline + + def _counting_setup(): + nonlocal call_count + call_count += 1 + original_setup() + + dm._setup_kornia_pipeline = _counting_setup + + fake_train = _fake_dataset(100) + fake_val = _fake_dataset(20) + + def _build(image_set, args, resolution): + return fake_train if image_set == "train" else fake_val + + with ( + patch("rfdetr.training.module_data.build_dataset", side_effect=_build), + patch("torch.cuda.is_available", return_value=False), + ): + dm.setup("fit") + dm.setup("fit") + + assert call_count == 1, f"_setup_kornia_pipeline called {call_count} times; expected exactly 1"