Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- `augmentation_backend` field on `TrainConfig` (`"cpu"` / `"auto"` / `"gpu"`): opt-in GPU-side augmentation via [Kornia](https://kornia.readthedocs.io) applied in `RFDETRDataModule.on_after_batch_transfer` after the batch is resident on the GPU. CPU path is unchanged and remains the default. Install with `pip install 'rfdetr[kornia]'`. Phase 1 supports detection only; segmentation mask support is planned for Phase 2.
- `RFDETR.predict(shape=...)` — optional `(height, width)` tuple overrides the default inference resolution; useful for matching the resolution used when exporting the model. Both dimensions must be positive integers divisible by 14. (closes #682)
- `BuilderArgs` — a `@runtime_checkable` `typing.Protocol` documenting the minimum attribute set consumed by `build_model()`, `build_backbone()`, `build_transformer()`, and `build_criterion_and_postprocessors()`. Enables static type-checker support for custom builder integrations. Exported from `rfdetr.models`.
- `build_model_from_config(model_config, train_config=None, defaults=MODEL_DEFAULTS)` — config-native alternative to `build_model(build_namespace(mc, tc))`; accepts Pydantic config objects directly and constructs the internal namespace automatically. Exported from `rfdetr.models`.
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,9 @@ trt = [
"tensorrt>=8.6.1",
"polygraphy",
]
kornia = [
"kornia>=0.7,<1", # GPU-side augmentation via on_after_batch_transfer
]
loggers = [
"tensorboard>=2.13.0",
"protobuf>=3.20.0,<4.0.0", # Pins protobuf below 4.x to avoid TensorBoard descriptor crash with protobuf>=4 (see #844)
Expand Down
1 change: 1 addition & 0 deletions src/rfdetr/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -410,6 +410,7 @@ class TrainConfig(BaseModel):
eval_interval: int = 1
log_per_class_metrics: bool = True
aug_config: Optional[Dict[str, Any]] = None
augmentation_backend: Literal["cpu", "auto", "gpu"] = "cpu"

@model_validator(mode="after")
def _warn_deprecated_train_config_fields(self) -> "TrainConfig":
Expand Down
21 changes: 21 additions & 0 deletions src/rfdetr/datasets/aug_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,27 @@
"YourCustomTransform", # Add here
}
```
## Kornia GPU Backend
When ``augmentation_backend="auto"`` or ``"gpu"`` is set in ``TrainConfig``, augmentations
run on the GPU via Kornia instead of Albumentations.
**Supported transforms** (all presets):
| Preset key | Kornia equivalent | Notes |
|---|---|---|
| ``HorizontalFlip`` | ``K.RandomHorizontalFlip`` | Direct |
| ``VerticalFlip`` | ``K.RandomVerticalFlip`` | Direct |
| ``Rotate`` | ``K.RandomRotation`` | ``limit`` may be scalar or tuple |
| ``Affine`` | ``K.RandomAffine`` | ``translate_percent`` treated as fraction |
| ``ColorJitter`` | ``K.ColorJiggle`` | Same multiplicative semantics |
| ``RandomBrightnessContrast`` | ``K.ColorJiggle`` | ``brightness_limit`` / ``contrast_limit`` direct |
| ``GaussianBlur`` | ``K.RandomGaussianBlur`` | ``blur_limit`` rounded up to odd; ``sigma=(0.1, 2.0)`` |
| ``GaussNoise`` | ``K.RandomGaussianNoise`` | Upper bound of ``std_range`` used as fixed std |
**Phase 1 limitation**: Segmentation models (``segmentation_head=True``) skip GPU augmentation;
CPU Albumentations are used instead. Mask support is planned for Phase 2.
"""

# ---------------------------------------------------------------------------
Expand Down
28 changes: 24 additions & 4 deletions src/rfdetr/datasets/coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,7 @@ def make_coco_transforms(
patch_size: int = 16,
num_windows: int = 4,
aug_config: Optional[Dict[str, Dict[str, Any]]] = None,
gpu_postprocess: bool = False,
) -> Compose:
"""Build the standard COCO transform pipeline for a given dataset split.

Expand Down Expand Up @@ -398,8 +399,14 @@ def make_coco_transforms(
resize_wrappers = AlbumentationsWrapper.from_config(
_build_train_resize_config(scales, square=False, max_size=1333)
)
aug_wrappers = AlbumentationsWrapper.from_config(resolved_aug_config)
return Compose([*resize_wrappers, *aug_wrappers, to_image, to_float, normalize])
pipeline = [*resize_wrappers]
if not gpu_postprocess:
aug_wrappers = AlbumentationsWrapper.from_config(resolved_aug_config)
pipeline += [*aug_wrappers]
pipeline += [to_image, to_float]
if not gpu_postprocess:
pipeline += [normalize]
return Compose(pipeline)

if image_set in ("val", "test"):
resize_wrappers = AlbumentationsWrapper.from_config(
Expand All @@ -425,6 +432,7 @@ def make_coco_transforms_square_div_64(
patch_size: int = 16,
num_windows: int = 4,
aug_config: Optional[Dict[str, Dict[str, Any]]] = None,
gpu_postprocess: bool = False,
) -> Compose:
"""
Create COCO transforms with square resizing where the output size is divisible by 64.
Expand Down Expand Up @@ -474,8 +482,14 @@ def make_coco_transforms_square_div_64(
if image_set == "train":
resolved_aug_config = aug_config if aug_config is not None else AUG_CONFIG
resize_wrappers = AlbumentationsWrapper.from_config(_build_train_resize_config(scales, square=True))
aug_wrappers = AlbumentationsWrapper.from_config(resolved_aug_config)
return Compose([*resize_wrappers, *aug_wrappers, to_image, to_float, normalize])
pipeline = [*resize_wrappers]
if not gpu_postprocess:
aug_wrappers = AlbumentationsWrapper.from_config(resolved_aug_config)
pipeline += [*aug_wrappers]
pipeline += [to_image, to_float]
if not gpu_postprocess:
pipeline += [normalize]
return Compose(pipeline)

if image_set in ("val", "test", "val_speed"):
resize_wrappers = AlbumentationsWrapper.from_config([{"Resize": {"height": resolution, "width": resolution}}])
Expand All @@ -502,6 +516,7 @@ def build_coco(image_set: str, args: Any, resolution: int) -> CocoDetection:
square_resize_div_64 = getattr(args, "square_resize_div_64", False)
include_masks = getattr(args, "segmentation_head", False)
aug_config = getattr(args, "aug_config", None)
gpu_postprocess = getattr(args, "augmentation_backend", "cpu") != "cpu"

if square_resize_div_64:
logger.info(f"Building COCO {image_set} dataset with square resize at resolution {resolution}")
Expand All @@ -517,6 +532,7 @@ def build_coco(image_set: str, args: Any, resolution: int) -> CocoDetection:
patch_size=args.patch_size,
num_windows=args.num_windows,
aug_config=aug_config,
gpu_postprocess=gpu_postprocess,
),
include_masks=include_masks,
)
Expand All @@ -534,6 +550,7 @@ def build_coco(image_set: str, args: Any, resolution: int) -> CocoDetection:
patch_size=args.patch_size,
num_windows=args.num_windows,
aug_config=aug_config,
gpu_postprocess=gpu_postprocess,
),
include_masks=include_masks,
)
Expand Down Expand Up @@ -566,6 +583,7 @@ def build_roboflow_from_coco(image_set: str, args: Any, resolution: int) -> Coco
patch_size = getattr(args, "patch_size", 16)
num_windows = getattr(args, "num_windows", 4)
aug_config = getattr(args, "aug_config", None)
gpu_postprocess = getattr(args, "augmentation_backend", "cpu") != "cpu"

if square_resize_div_64:
logger.info(f"Building Roboflow {image_set} dataset with square resize at resolution {resolution}")
Expand All @@ -581,6 +599,7 @@ def build_roboflow_from_coco(image_set: str, args: Any, resolution: int) -> Coco
patch_size=patch_size,
num_windows=num_windows,
aug_config=aug_config,
gpu_postprocess=gpu_postprocess,
),
include_masks=include_masks,
remap_category_ids=True,
Expand All @@ -599,6 +618,7 @@ def build_roboflow_from_coco(image_set: str, args: Any, resolution: int) -> Coco
patch_size=patch_size,
num_windows=num_windows,
aug_config=aug_config,
gpu_postprocess=gpu_postprocess,
),
include_masks=include_masks,
remap_category_ids=True,
Expand Down
3 changes: 3 additions & 0 deletions src/rfdetr/datasets/yolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -846,6 +846,7 @@ def build_roboflow_from_yolo(image_set: str, args: Any, resolution: int) -> Yolo
patch_size = getattr(args, "patch_size", None)
num_windows = getattr(args, "num_windows", None)
aug_config = getattr(args, "aug_config", None)
gpu_postprocess = getattr(args, "augmentation_backend", "cpu") != "cpu"

if square_resize_div_64:
dataset = YoloDetection(
Expand All @@ -861,6 +862,7 @@ def build_roboflow_from_yolo(image_set: str, args: Any, resolution: int) -> Yolo
patch_size=patch_size,
num_windows=num_windows,
aug_config=aug_config,
gpu_postprocess=gpu_postprocess,
),
include_masks=include_masks,
)
Expand All @@ -878,6 +880,7 @@ def build_roboflow_from_yolo(image_set: str, args: Any, resolution: int) -> Yolo
patch_size=patch_size,
num_windows=num_windows,
aug_config=aug_config,
gpu_postprocess=gpu_postprocess,
),
include_masks=include_masks,
)
Expand Down
80 changes: 79 additions & 1 deletion src/rfdetr/training/module_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

"""LightningDataModule for RF-DETR dataset construction and loaders."""

from typing import List, Optional, Tuple
from typing import Any, List, Optional, Tuple

import torch
import torch.utils.data
Expand Down Expand Up @@ -41,6 +41,10 @@ def __init__(self, model_config: ModelConfig, train_config: TrainConfig) -> None
self._dataset_val: Optional[torch.utils.data.Dataset] = None
self._dataset_test: Optional[torch.utils.data.Dataset] = None

# GPU augmentation pipeline (Kornia); built lazily in setup("fit").
self._kornia_pipeline: Optional[Any] = None
self._kornia_normalize: Optional[Any] = None

num_workers = self.train_config.num_workers
self._pin_memory: bool = (
torch.cuda.is_available() if self.train_config.pin_memory is None else bool(self.train_config.pin_memory)
Expand Down Expand Up @@ -79,6 +83,9 @@ def setup(self, stage: str) -> None:
self._dataset_train = build_dataset("train", ns, resolution)
if self._dataset_val is None:
self._dataset_val = build_dataset("val", ns, resolution)
# Build Kornia GPU augmentation pipeline (once).
if self._kornia_pipeline is None:
self._setup_kornia_pipeline()
elif stage == "validate":
if self._dataset_val is None:
self._dataset_val = build_dataset("val", ns, resolution)
Expand Down Expand Up @@ -194,6 +201,77 @@ def predict_dataloader(self) -> DataLoader:
prefetch_factor=self._prefetch_factor,
)

def _setup_kornia_pipeline(self) -> None:
    """Resolve the augmentation backend and build the Kornia pipeline if applicable.

    Called once during ``setup("fit")``. When ``augmentation_backend`` is
    ``"cpu"`` this is a no-op. For ``"auto"`` the method falls back to CPU
    augmentation when CUDA or Kornia is unavailable. For ``"gpu"`` missing
    requirements raise hard errors.

    Raises:
        RuntimeError: If ``augmentation_backend="gpu"`` and no CUDA device exists.
        ImportError: If ``augmentation_backend="gpu"`` and kornia is not installed.
    """
    backend = self.train_config.augmentation_backend
    if backend == "cpu":
        return

    # Probe the two requirements once; "gpu" fails hard, "auto" degrades.
    # NOTE(review): by the time this runs, the train dataset has already been
    # built with gpu_postprocess=True (build_coco treats any backend != "cpu"
    # as GPU post-processing), i.e. with CPU-side normalization stripped. The
    # "auto" fallbacks below therefore leave training batches UN-normalized.
    # The backend should be resolved *before* build_dataset — TODO fix upstream.
    if not torch.cuda.is_available():
        if backend == "gpu":
            raise RuntimeError("augmentation_backend='gpu' requires a CUDA device")
        logger.info("augmentation_backend='auto': no CUDA, using CPU augmentation")
        return
    try:
        import kornia.augmentation  # noqa: F401
    except ImportError as e:
        if backend == "gpu":
            raise ImportError("GPU augmentation requires kornia. Install with: pip install 'rfdetr[kornia]'") from e
        logger.warning("augmentation_backend='auto': kornia not installed, using CPU augmentation")
        return

    # Deferred import: kornia_transforms itself depends on kornia.
    from rfdetr.datasets.kornia_transforms import build_kornia_pipeline, build_normalize

    self._kornia_pipeline = build_kornia_pipeline(
        self.train_config.aug_config or {},
        self.model_config.resolution,
    )
    self._kornia_normalize = build_normalize()
    logger.info("Kornia GPU augmentation pipeline built (backend=%s)", backend)

def on_after_batch_transfer(self, batch: Tuple, dataloader_idx: int) -> Tuple:
    """Apply Kornia GPU augmentation after the batch is transferred to device.

    When ``_kornia_pipeline`` is set and the trainer is in training mode,
    augmentation and normalization are applied on the GPU. Validation
    and test batches pass through unchanged.

    Segmentation models skip GPU *augmentation* in phase 1 with a warning,
    but normalization still runs: the train dataset was built with
    ``gpu_postprocess=True`` (CPU-side normalize stripped), so skipping it
    here would feed un-normalized images to the model.

    Args:
        batch: Tuple of ``(NestedTensor, list[dict])`` already on device.
        dataloader_idx: Index of the current dataloader.

    Returns:
        The (possibly augmented) batch.
    """
    if self.trainer.training and self._kornia_pipeline is not None:
        from rfdetr.datasets.kornia_transforms import collate_boxes, unpack_boxes
        from rfdetr.utilities.tensors import NestedTensor

        samples, targets = batch
        img = samples.tensors  # [B, C, H, W]

        if self.model_config.segmentation_head:
            logger.warning_once("Kornia GPU augmentation skipped for segmentation models (phase 2)")
            # BUGFIX: still normalize — CPU normalization was stripped from the
            # dataset pipeline for every backend != "cpu" (see build_coco).
            return (NestedTensor(self._kornia_normalize(img), samples.mask), targets)

        boxes_padded, valid = collate_boxes(targets, img.device)
        img_aug, boxes_aug = self._kornia_pipeline(img, boxes_padded)
        img_aug = self._kornia_normalize(img_aug)
        targets = unpack_boxes(boxes_aug, valid, targets, *img_aug.shape[-2:])
        batch = (NestedTensor(img_aug, samples.mask), targets)
    return batch

# ------------------------------------------------------------------
# Properties
# ------------------------------------------------------------------
Expand Down
19 changes: 19 additions & 0 deletions tests/training/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,3 +117,22 @@ def _restore_rfdetr_module_trainer_property():

if "trainer" in RFDETRModelModule.__dict__:
delattr(RFDETRModelModule, "trainer")


@pytest.fixture(autouse=True)
def _restore_rfdetr_datamodule_trainer_property():
    """Remove any class-level ``trainer`` override from ``RFDETRDataModule`` after each test.

    Tests exercising ``on_after_batch_transfer`` may patch the ``trainer``
    property directly on the ``RFDETRDataModule`` class; left in place, that
    override would leak into every subsequent test in the session.

    After the test runs, this fixture drops the override from the class
    (restoring the inherited Lightning property), mirroring
    ``_restore_rfdetr_module_trainer_property`` above.
    """
    yield
    from rfdetr.training.module_data import RFDETRDataModule

    # EAFP: delattr only removes a *class-level* entry; if no override was
    # installed it raises AttributeError, which we treat as "nothing to do".
    try:
        delattr(RFDETRDataModule, "trainer")
    except AttributeError:
        pass
Loading
Loading