diff --git a/pyproject.toml b/pyproject.toml index 89eb66abc..b9e6a3eb7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -187,6 +187,7 @@ filterwarnings = [ [tool.codespell] skip = "*.pth" +ignore-words-list = "dota" [tool.mypy] python_version = "3.10" @@ -222,6 +223,7 @@ overrides = [ "rfdetr.config", "rfdetr.datasets._develop", "rfdetr.datasets.coco", + "rfdetr.datasets.dota_detection", "rfdetr.datasets.save_grids", "rfdetr.datasets.synthetic", "rfdetr.datasets.transforms", diff --git a/src/rfdetr/_namespace.py b/src/rfdetr/_namespace.py index a149def99..2d95826ee 100644 --- a/src/rfdetr/_namespace.py +++ b/src/rfdetr/_namespace.py @@ -41,6 +41,7 @@ "num_queries", "num_select", "num_windows", + "oriented", "out_feature_indexes", "patch_size", "positional_encoding_size", diff --git a/src/rfdetr/config.py b/src/rfdetr/config.py index 8ba2e8ee9..87fb33702 100644 --- a/src/rfdetr/config.py +++ b/src/rfdetr/config.py @@ -80,6 +80,7 @@ class ModelConfig(BaseConfig): ia_bce_loss: bool = True cls_loss_coef: float = 1.0 segmentation_head: bool = False + oriented: bool = False mask_downsample_ratio: int = 4 backbone_lora: bool = False freeze_encoder: bool = False @@ -390,7 +391,7 @@ class TrainConfig(BaseModel): ia_bce_loss: bool = True cls_loss_coef: float = 1.0 num_select: int = 300 - dataset_file: Literal["coco", "o365", "roboflow", "yolo"] = "roboflow" + dataset_file: Literal["coco", "o365", "roboflow", "yolo", "dota"] = "roboflow" square_resize_div_64: bool = True dataset_dir: str output_dir: str = "output" diff --git a/src/rfdetr/datasets/__init__.py b/src/rfdetr/datasets/__init__.py index 2ece8d149..bf5e52b0a 100644 --- a/src/rfdetr/datasets/__init__.py +++ b/src/rfdetr/datasets/__init__.py @@ -21,6 +21,7 @@ from torch.utils.data import Dataset, Subset from rfdetr.datasets.coco import build_coco, build_roboflow_from_coco +from rfdetr.datasets.dota_detection import DotaDetection, build_dota from rfdetr.datasets.o365 import build_o365 from rfdetr.datasets.yolo 
import YoloDetection, build_roboflow_from_yolo @@ -33,6 +34,8 @@ def get_coco_api_from_dataset(dataset: Dataset[Any]) -> Optional[Any]: return dataset.coco if isinstance(dataset, YoloDetection): return dataset.coco + if isinstance(dataset, DotaDetection): + return None return None @@ -92,4 +95,6 @@ def build_dataset(image_set: str, args: Any, resolution: int) -> Dataset[Any]: return build_roboflow(image_set, args, resolution) if args.dataset_file == "yolo": return build_roboflow_from_yolo(image_set, args, resolution) + if args.dataset_file == "dota": + return build_dota(image_set, args, resolution) raise ValueError(f"dataset {args.dataset_file} not supported") diff --git a/src/rfdetr/datasets/dota_detection.py b/src/rfdetr/datasets/dota_detection.py new file mode 100644 index 000000000..f49b24cef --- /dev/null +++ b/src/rfdetr/datasets/dota_detection.py @@ -0,0 +1,277 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. 
+# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +"""DOTA v1.0 dataset loader for oriented object detection.""" + +from pathlib import Path +from typing import Any + +import torch +from PIL import Image +from torch.utils.data import Dataset +from torchvision.transforms.v2 import Compose, ToDtype, ToImage + +from rfdetr.datasets.transforms import AlbumentationsWrapper +from rfdetr.utilities.logger import get_logger +from rfdetr.utilities.rotated_box_ops import corners_to_cxcywha + +logger = get_logger() + +DOTA_V1_CLASSES = ( + "baseball-diamond", + "basketball-court", + "bridge", + "ground-track-field", + "harbor", + "helicopter", + "large-vehicle", + "plane", + "roundabout", + "ship", + "small-vehicle", + "soccer-ball-field", + "storage-tank", + "swimming-pool", + "tennis-court", +) + + +def parse_dota_annotation(ann_path: Path) -> list[dict[str, Any]]: + """Parse a DOTA annotation text file. + + Each line after the optional header has the format: + ``x1 y1 x2 y2 x3 y3 x4 y4 category difficulty`` + + Args: + ann_path: Path to the annotation ``.txt`` file. + + Returns: + List of annotation dicts with keys ``corners`` (8 floats) + ``category`` (str), and ``difficulty`` (int). + """ + annotations: list[dict[str, Any]] = [] + with open(ann_path) as f: + for line in f: + line = line.strip() + if not line: + continue + parts = line.split() + if len(parts) < 9: + continue + try: + coords = [float(parts[i]) for i in range(8)] + except ValueError: + continue + category = parts[8] + try: + difficulty = int(parts[9]) if len(parts) > 9 else 0 + except ValueError: + difficulty = 0 + annotations.append( + { + "corners": coords, + "category": category, + "difficulty": difficulty, + } + ) + return annotations + + +def corners_list_to_tensor(corners: list[float]) -> torch.Tensor: + """Convert flat 8-element corner list to ``(4, 2)`` tensor. 
+ + Args: + corners: Flat list ``[x1, y1, x2, y2, x3, y3, x4, y4]``. + + Returns: + Tensor of shape ``(4, 2)``. + """ + return torch.tensor(corners, dtype=torch.float32).reshape(4, 2) + + +class DotaDetection(Dataset): + """DOTA v1.0 dataset for oriented object detection. + + Expects the standard DOTA directory layout:: + + root/ + images/ + P0001.png + P0002.png + ... + labelTxt/ + P0001.txt + P0002.txt + ... + + Each annotation file contains one object per line with 4-corner polygon + coordinates, a category name, and a difficulty flag. + + Args: + root: Path to the split directory (e.g. ``dota/train``). + transforms: Transform pipeline applied to ``(image, target)`` pairs. + class_names: Ordered tuple of class names. Defaults to DOTA v1.0 classes. + include_difficult: If ``True``, include objects marked as difficult. + """ + + def __init__( + self, + root: str | Path, + transforms: Compose | None = None, + class_names: tuple[str, ...] = DOTA_V1_CLASSES, + include_difficult: bool = False, + ) -> None: + self.root = Path(root) + self._transforms = transforms + self.class_names = class_names + self.class_to_idx = {name: i for i, name in enumerate(class_names)} + self.include_difficult = include_difficult + + self.images_dir = self.root / "images" + self.labels_dir = self.root / "labelTxt" + + if not self.images_dir.exists(): + raise FileNotFoundError(f"Images directory not found: {self.images_dir}") + if not self.labels_dir.exists(): + raise FileNotFoundError(f"Labels directory not found: {self.labels_dir}") + + self.image_files = sorted( + p for p in self.images_dir.iterdir() if p.suffix.lower() in (".png", ".jpg", ".jpeg", ".bmp", ".tif") + ) + logger.info(f"DOTA dataset loaded: {len(self.image_files)} images, {len(class_names)} classes") + + def __len__(self) -> int: + return len(self.image_files) + + def __getitem__(self, idx: int) -> tuple[Any, dict[str, Any]]: + img_path = self.image_files[idx] + ann_path = self.labels_dir / f"{img_path.stem}.txt" + + image = 
Image.open(img_path).convert("RGB") + + annotations = parse_dota_annotation(ann_path) if ann_path.exists() else [] + + corners_list = [] + labels = [] + for ann in annotations: + if not self.include_difficult and ann["difficulty"] == 1: + continue + cat = ann["category"] + if cat not in self.class_to_idx: + continue + corners_list.append(corners_list_to_tensor(ann["corners"])) + labels.append(self.class_to_idx[cat]) + + if corners_list: + all_corners = torch.stack(corners_list) + boxes_obb = corners_to_cxcywha(all_corners) + else: + all_corners = torch.zeros((0, 4, 2), dtype=torch.float32) + boxes_obb = torch.zeros((0, 5), dtype=torch.float32) + + target: dict[str, Any] = { + "boxes_obb": boxes_obb, + "corners": all_corners, + "labels": torch.tensor(labels, dtype=torch.int64), + "image_id": torch.tensor([idx]), + } + + if self._transforms is not None: + image, target = self._transforms(image, target) + + return image, target + + +def make_dota_transforms( + image_set: str, + resolution: int, +) -> Compose: + """Build transform pipeline for DOTA dataset. + + Args: + image_set: Split identifier — ``"train"`` or ``"val"``. + resolution: Target square resolution in pixels. + + Returns: + Composed transform pipeline. 
+ """ + to_image = ToImage() + to_float = ToDtype(torch.float32, scale=True) + normalize = DotaNormalize() + + if image_set == "train": + resize_wrappers = AlbumentationsWrapper.from_config( + [ + {"Resize": {"height": resolution, "width": resolution}}, + ] + ) + aug_wrappers = AlbumentationsWrapper.from_config( + [ + {"HorizontalFlip": {"p": 0.5}}, + {"VerticalFlip": {"p": 0.5}}, + {"RandomRotate90": {"p": 0.5}}, + ] + ) + return Compose([*resize_wrappers, *aug_wrappers, to_image, to_float, normalize]) + + resize_wrappers = AlbumentationsWrapper.from_config( + [ + {"Resize": {"height": resolution, "width": resolution}}, + ] + ) + return Compose([*resize_wrappers, to_image, to_float, normalize]) + + +class DotaNormalize: + """Normalize images and convert OBB corners to normalized cxcywha format. + + After geometric augmentations, recomputes ``boxes_obb`` from the + (potentially transformed) ``corners`` keypoints, then normalizes + spatial coordinates by image size. + """ + + def __init__( + self, + mean: tuple[float, ...] = (0.485, 0.456, 0.406), + std: tuple[float, ...] = (0.229, 0.224, 0.225), + ) -> None: + from torchvision.transforms import Normalize as _TVNormalize + + self._normalize = _TVNormalize(mean, std) + + def __call__( + self, image: torch.Tensor, target: dict[str, Any] | None = None + ) -> tuple[torch.Tensor, dict[str, Any] | None]: + image = self._normalize(image) + if target is None: + return image, None + target = target.copy() + h, w = image.shape[-2:] + + if "corners" in target and len(target["corners"]) > 0: + corners = target["corners"] + boxes_obb = corners_to_cxcywha(corners) + scale = torch.tensor([w, h, w, h, 1.0], dtype=boxes_obb.dtype) + boxes_obb = boxes_obb / scale + target["boxes_obb"] = boxes_obb + + return image, target + + +def build_dota(image_set: str, args: Any, resolution: int) -> DotaDetection: + """Build a DOTA dataset for the given split. + + Args: + image_set: Split identifier — ``"train"`` or ``"val"``. 
+ args: Namespace with ``dataset_dir`` attribute. + resolution: Target resolution in pixels. + + Returns: + Configured DotaDetection dataset. + """ + root = Path(args.dataset_dir) / image_set + transforms = make_dota_transforms(image_set, resolution) + return DotaDetection(root=root, transforms=transforms) diff --git a/src/rfdetr/models/_types.py b/src/rfdetr/models/_types.py index 3b7db2510..070e9ab3a 100644 --- a/src/rfdetr/models/_types.py +++ b/src/rfdetr/models/_types.py @@ -63,6 +63,7 @@ class BuilderArgs(Protocol): ia_bce_loss: bool cls_loss_coef: float segmentation_head: bool + oriented: bool mask_downsample_ratio: int num_queries: int num_select: int diff --git a/src/rfdetr/models/criterion.py b/src/rfdetr/models/criterion.py index 1f39371c1..dffe61d92 100644 --- a/src/rfdetr/models/criterion.py +++ b/src/rfdetr/models/criterion.py @@ -22,6 +22,7 @@ from rfdetr.models.math import accuracy from rfdetr.utilities import box_ops from rfdetr.utilities.distributed import get_world_size, is_dist_avail_and_initialized +from rfdetr.utilities.rotated_box_ops import kld_loss, probiou def sigmoid_focal_loss(inputs, targets, num_boxes, alpha: float = 0.25, gamma: float = 2): @@ -164,6 +165,7 @@ def __init__( self.use_varifocal_loss = use_varifocal_loss self.use_position_supervised_loss = use_position_supervised_loss self.ia_bce_loss = ia_bce_loss + self.oriented = getattr(matcher, "oriented", False) self.mask_point_sample_ratio = mask_point_sample_ratio def loss_labels(self, outputs, targets, indices, num_boxes, log=True): @@ -180,14 +182,18 @@ def loss_labels(self, outputs, targets, indices, num_boxes, log=True): alpha = self.focal_alpha gamma = 2 src_boxes = outputs["pred_boxes"][idx] - target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0) + tgt_key = "boxes_obb" if self.oriented else "boxes" + target_boxes = torch.cat([t[tgt_key][i] for t, (_, i) in zip(targets, indices)], dim=0) - iou_targets = torch.diag( - box_ops.box_iou( - 
box_ops.box_cxcywh_to_xyxy(src_boxes.detach()), - box_ops.box_cxcywh_to_xyxy(target_boxes), - )[0] - ) + if self.oriented: + iou_targets = probiou(src_boxes.detach(), target_boxes) + else: + iou_targets = torch.diag( + box_ops.box_iou( + box_ops.box_cxcywh_to_xyxy(src_boxes.detach()), + box_ops.box_cxcywh_to_xyxy(target_boxes), + )[0] + ) pos_ious = iou_targets.clone().detach() prob = src_logits.sigmoid() # init positive weights and negative weights @@ -326,27 +332,34 @@ def loss_cardinality(self, outputs, targets, indices, num_boxes): return losses def loss_boxes(self, outputs, targets, indices, num_boxes): - """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss - targets dicts must contain the key "boxes" containing a tensor of dim [nb_target_boxes, 4] - The target boxes are expected in format (center_x, center_y, w, h), normalized by the image size. + """Compute the losses related to the bounding boxes, the L1 regression loss and the GIoU loss. + + For oriented boxes, uses KLD loss instead of GIoU, and L1 on spatial dims only. 
""" assert "pred_boxes" in outputs idx = self._get_src_permutation_idx(indices) src_boxes = outputs["pred_boxes"][idx] - target_boxes = torch.cat([t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0) - - loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none") + tgt_key = "boxes_obb" if self.oriented else "boxes" + target_boxes = torch.cat([t[tgt_key][i] for t, (_, i) in zip(targets, indices)], dim=0) losses = {} - losses["loss_bbox"] = loss_bbox.sum() / num_boxes - loss_giou = 1 - torch.diag( - box_ops.generalized_box_iou( - box_ops.box_cxcywh_to_xyxy(src_boxes), - box_ops.box_cxcywh_to_xyxy(target_boxes), + if self.oriented: + loss_bbox = F.l1_loss(src_boxes[..., :4], target_boxes[..., :4], reduction="none") + losses["loss_bbox"] = loss_bbox.sum() / num_boxes + loss_kld = kld_loss(src_boxes, target_boxes) + losses["loss_giou"] = loss_kld.sum() / num_boxes + else: + loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none") + losses["loss_bbox"] = loss_bbox.sum() / num_boxes + loss_giou = 1 - torch.diag( + box_ops.generalized_box_iou( + box_ops.box_cxcywh_to_xyxy(src_boxes), + box_ops.box_cxcywh_to_xyxy(target_boxes), + ) ) - ) - losses["loss_giou"] = loss_giou.sum() / num_boxes + losses["loss_giou"] = loss_giou.sum() / num_boxes + return losses def loss_masks(self, outputs, targets, indices, num_boxes): diff --git a/src/rfdetr/models/heads/detection.py b/src/rfdetr/models/heads/detection.py index 3df8c3c9d..4afe96b37 100644 --- a/src/rfdetr/models/heads/detection.py +++ b/src/rfdetr/models/heads/detection.py @@ -6,6 +6,9 @@ """Detection head: bounding-box regression + classification projections.""" +import math + +import torch import torch.nn as nn from rfdetr.models.math import MLP @@ -20,14 +23,18 @@ class DetectionHead(nn.Module): Args: hidden_dim: Feature dimension coming from the transformer decoder. num_classes: Number of object classes (excluding background). 
+ oriented: If ``True``, add an angle prediction head for oriented + bounding boxes. """ - def __init__(self, hidden_dim: int, num_classes: int) -> None: + def __init__(self, hidden_dim: int, num_classes: int, oriented: bool = False) -> None: super().__init__() self.class_embed = nn.Linear(hidden_dim, num_classes) self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) + self.oriented = oriented + self.angle_embed = MLP(hidden_dim, hidden_dim, 1, 3) if oriented else None - def forward(self, hs): + def forward(self, hs: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: """Project decoder hidden states to class logits and box coordinates. Args: @@ -36,9 +43,13 @@ def forward(self, hs): Returns: Tuple of ``(outputs_class, outputs_coord)`` where ``outputs_class`` has shape ``(B, N, num_classes)`` and - ``outputs_coord`` has shape ``(B, N, 4)`` in ``[cx, cy, w, h]`` - normalised to ``[0, 1]``. + ``outputs_coord`` has shape ``(B, N, 4)`` or ``(B, N, 5)`` + when oriented. Box format is ``[cx, cy, w, h]`` normalised + to ``[0, 1]``, with an optional angle in ``[0, pi)``. """ outputs_class = self.class_embed(hs) outputs_coord = self.bbox_embed(hs).sigmoid() + if self.angle_embed is not None: + angle = self.angle_embed(hs).sigmoid() * math.pi + outputs_coord = torch.cat([outputs_coord, angle], dim=-1) return outputs_class, outputs_coord diff --git a/src/rfdetr/models/lwdetr.py b/src/rfdetr/models/lwdetr.py index 3b14a6b48..f22a11494 100644 --- a/src/rfdetr/models/lwdetr.py +++ b/src/rfdetr/models/lwdetr.py @@ -101,6 +101,7 @@ def __init__( two_stage=False, lite_refpoint_refine=False, bbox_reparam=False, + oriented=False, ): """Initializes the model. Parameters: @@ -112,6 +113,7 @@ def __init__( aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used. group_detr: Number of groups to speed detr training. Default is 1. lite_refpoint_refine: TODO + oriented: If True, add an angle prediction head for oriented bounding boxes. 
""" super().__init__() self.num_queries = num_queries @@ -119,6 +121,8 @@ def __init__( hidden_dim = transformer.d_model self.class_embed = nn.Linear(hidden_dim, num_classes) self.bbox_embed = MLP(hidden_dim, hidden_dim, 4, 3) + self.oriented = oriented + self.angle_embed = MLP(hidden_dim, hidden_dim, 1, 3) if oriented else None self.segmentation_head = segmentation_head query_dim = 4 @@ -148,6 +152,10 @@ def __init__( nn.init.constant_(self.bbox_embed.layers[-1].weight.data, 0) nn.init.constant_(self.bbox_embed.layers[-1].bias.data, 0) + if self.angle_embed is not None: + nn.init.constant_(self.angle_embed.layers[-1].weight.data, 0) + nn.init.constant_(self.angle_embed.layers[-1].bias.data, 0) + # two_stage self.two_stage = two_stage if self.two_stage: @@ -240,6 +248,10 @@ def forward(self, samples: NestedTensor, targets=None): else: outputs_coord = (self.bbox_embed(hs) + ref_unsigmoid).sigmoid() + if self.angle_embed is not None: + angle = self.angle_embed(hs).sigmoid() * math.pi + outputs_coord = torch.cat([outputs_coord, angle], dim=-1) + outputs_class = self.class_embed(hs) if self.segmentation_head is not None: @@ -306,6 +318,9 @@ def forward_export(self, tensors): outputs_coord = torch.concat([outputs_coord_cxcy, outputs_coord_wh], dim=-1) else: outputs_coord = (self.bbox_embed(hs) + ref_unsigmoid).sigmoid() + if self.angle_embed is not None: + angle = self.angle_embed(hs).sigmoid() * math.pi + outputs_coord = torch.cat([outputs_coord, angle], dim=-1) outputs_class = self.class_embed(hs) if self.segmentation_head is not None: outputs_masks = self.segmentation_head( @@ -461,6 +476,7 @@ def build_model(args: "BuilderArgs"): two_stage=args.two_stage, lite_refpoint_refine=args.lite_refpoint_refine, bbox_reparam=args.bbox_reparam, + oriented=getattr(args, "oriented", False), ) return model @@ -515,7 +531,8 @@ def build_criterion_and_postprocessors(args: "BuilderArgs"): ia_bce_loss=args.ia_bce_loss, ) criterion.to(device) - postprocess = 
PostProcess(num_select=args.num_select) + oriented = getattr(args, "oriented", False) + postprocess = PostProcess(num_select=args.num_select, oriented=oriented) return criterion, postprocess diff --git a/src/rfdetr/models/matcher.py b/src/rfdetr/models/matcher.py index ca132b448..dee9520a4 100644 --- a/src/rfdetr/models/matcher.py +++ b/src/rfdetr/models/matcher.py @@ -29,6 +29,7 @@ from rfdetr.models.heads.segmentation import point_sample from rfdetr.utilities.box_ops import batch_dice_loss, batch_sigmoid_ce_loss, box_cxcywh_to_xyxy, generalized_box_iou from rfdetr.utilities.logger import get_logger +from rfdetr.utilities.rotated_box_ops import gwd_pairwise logger = get_logger() _SANITIZED_COST_MARGIN = 1.0 @@ -52,6 +53,7 @@ def __init__( mask_point_sample_ratio: int = 16, cost_mask_ce: float = 1, cost_mask_dice: float = 1, + oriented: bool = False, ): """Creates the matcher. @@ -65,6 +67,7 @@ def __init__( mask_point_sample_ratio: Downsampling ratio for mask point sampling. cost_mask_ce: Relative weight of the binary cross-entropy mask cost. cost_mask_dice: Relative weight of the Dice mask cost. + oriented: If ``True``, use GWD cost instead of GIoU for rotated boxes. 
""" super().__init__() self.cost_class = cost_class @@ -75,6 +78,7 @@ def __init__( self.mask_point_sample_ratio = mask_point_sample_ratio self.cost_mask_ce = cost_mask_ce self.cost_mask_dice = cost_mask_dice + self.oriented = oriented self._warned_non_finite_costs = False @staticmethod @@ -142,17 +146,20 @@ def forward(self, outputs, targets, group_detr=1): # We flatten to compute the cost matrices in a batch flat_pred_logits = outputs["pred_logits"].flatten(0, 1) out_prob = flat_pred_logits.sigmoid() # [batch_size * num_queries, num_classes] - out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4] + out_bbox = outputs["pred_boxes"].flatten(0, 1) # [batch_size * num_queries, 4 or 5] # Also concat the target labels and boxes tgt_ids = torch.cat([v["labels"] for v in targets]) - tgt_bbox = torch.cat([v["boxes"] for v in targets]) + tgt_bbox_key = "boxes_obb" if self.oriented else "boxes" + tgt_bbox = torch.cat([v[tgt_bbox_key] for v in targets]) masks_present = "masks" in targets[0] - # Compute the giou cost between boxes - giou = generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) - cost_giou = -giou + if self.oriented: + cost_giou = gwd_pairwise(out_bbox, tgt_bbox) + else: + giou = generalized_box_iou(box_cxcywh_to_xyxy(out_bbox), box_cxcywh_to_xyxy(tgt_bbox)) + cost_giou = -giou # Compute the classification cost. 
alpha = 0.25 @@ -165,8 +172,10 @@ def forward(self, outputs, targets, group_detr=1): pos_cost_class = alpha * ((1 - out_prob) ** gamma) * (-F.logsigmoid(flat_pred_logits)) cost_class = pos_cost_class[:, tgt_ids] - neg_cost_class[:, tgt_ids] - # Compute the L1 cost between boxes - cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1) + # Compute the L1 cost between boxes (spatial dims only for oriented) + out_bbox_spatial = out_bbox[..., :4] if self.oriented else out_bbox + tgt_bbox_spatial = tgt_bbox[..., :4] if self.oriented else tgt_bbox + cost_bbox = torch.cdist(out_bbox_spatial, tgt_bbox_spatial, p=1) if masks_present: tgt_masks = torch.cat([v["masks"] for v in targets]) @@ -227,7 +236,7 @@ def forward(self, outputs, targets, group_detr=1): self._warned_non_finite_costs = True C = self._sanitize_cost_matrix(C) - sizes = [len(v["boxes"]) for v in targets] + sizes = [len(v[tgt_bbox_key]) for v in targets] indices = [] g_num_queries = num_queries // group_detr C_list = C.split(g_num_queries, dim=1) @@ -248,6 +257,7 @@ def forward(self, outputs, targets, group_detr=1): def build_matcher(args): + oriented = getattr(args, "oriented", False) if args.segmentation_head: return HungarianMatcher( cost_class=args.set_cost_class, @@ -257,6 +267,7 @@ def build_matcher(args): cost_mask_ce=args.mask_ce_loss_coef, cost_mask_dice=args.mask_dice_loss_coef, mask_point_sample_ratio=args.mask_point_sample_ratio, + oriented=oriented, ) else: return HungarianMatcher( @@ -264,4 +275,5 @@ def build_matcher(args): cost_bbox=args.set_cost_bbox, cost_giou=args.set_cost_giou, focal_alpha=args.focal_alpha, + oriented=oriented, ) diff --git a/src/rfdetr/models/postprocess.py b/src/rfdetr/models/postprocess.py index 07094b24f..7c6d5bcf5 100644 --- a/src/rfdetr/models/postprocess.py +++ b/src/rfdetr/models/postprocess.py @@ -15,14 +15,16 @@ from torch import nn from rfdetr.utilities import box_ops +from rfdetr.utilities.rotated_box_ops import box_cxcywha_to_corners class PostProcess(nn.Module): 
"""This module converts the model's output into the format expected by the coco api""" - def __init__(self, num_select=300) -> None: + def __init__(self, num_select=300, oriented=False) -> None: super().__init__() self.num_select = num_select + self.oriented = oriented @torch.no_grad() def forward(self, outputs, target_sizes): @@ -44,6 +46,20 @@ def forward(self, outputs, target_sizes): scores = topk_values topk_boxes = topk_indexes // out_logits.shape[2] labels = topk_indexes % out_logits.shape[2] + + if self.oriented: + box_dim = out_bbox.shape[-1] + obb = torch.gather(out_bbox, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, box_dim)).clone() + img_h, img_w = target_sizes.unbind(1) + scale_fct = torch.stack([img_w, img_h, img_w, img_h], dim=1) + obb[..., :4] = obb[..., :4] * scale_fct[:, None, :] + corners = box_cxcywha_to_corners(obb) + results = [ + {"scores": sc, "labels": lb, "boxes_obb": ob, "corners": cn} + for sc, lb, ob, cn in zip(scores, labels, obb, corners) + ] + return results + boxes = box_ops.box_cxcywh_to_xyxy(out_bbox) boxes = torch.gather(boxes, 1, topk_boxes.unsqueeze(-1).repeat(1, 1, 4)) diff --git a/src/rfdetr/utilities/rotated_box_ops.py b/src/rfdetr/utilities/rotated_box_ops.py new file mode 100644 index 000000000..9e57333fb --- /dev/null +++ b/src/rfdetr/utilities/rotated_box_ops.py @@ -0,0 +1,259 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. +# Licensed under the Apache License, Version 2.0 [see LICENSE for details] +# ------------------------------------------------------------------------ + +"""Utilities for oriented (rotated) bounding box manipulation, IoU, and losses.""" + +import math + +import torch + + +def normalize_angle(angle: torch.Tensor) -> torch.Tensor: + """Normalize angles to [0, pi) with pi-periodicity. + + Args: + angle: Angles in radians, arbitrary range. + + Returns: + Angles clamped to [0, pi). 
+ """ + return angle - torch.floor(angle / math.pi) * math.pi + + +def box_cxcywha_to_corners(boxes: torch.Tensor) -> torch.Tensor: + """Convert oriented boxes from center format to four corner points. + + Args: + boxes: Tensor of shape ``(..., 5)`` as ``[cx, cy, w, h, angle]`` + where angle is in radians. + + Returns: + Tensor of shape ``(..., 4, 2)`` with four corner coordinates. + """ + cx, cy, w, h, angle = boxes.unbind(-1) + cos_a = torch.cos(angle) + sin_a = torch.sin(angle) + + hw = w / 2 + hh = h / 2 + + dx_w = hw * cos_a + dy_w = hw * sin_a + dx_h = hh * -sin_a + dy_h = hh * cos_a + + c1 = torch.stack([cx - dx_w - dx_h, cy - dy_w - dy_h], dim=-1) + c2 = torch.stack([cx + dx_w - dx_h, cy + dy_w - dy_h], dim=-1) + c3 = torch.stack([cx + dx_w + dx_h, cy + dy_w + dy_h], dim=-1) + c4 = torch.stack([cx - dx_w + dx_h, cy - dy_w + dy_h], dim=-1) + + return torch.stack([c1, c2, c3, c4], dim=-2) + + +def corners_to_cxcywha(corners: torch.Tensor) -> torch.Tensor: + """Convert four corner points to oriented box center format. + + Uses the first edge (corner0 -> corner1) as the width direction to + derive the rotation angle. + + Args: + corners: Tensor of shape ``(..., 4, 2)`` with four corner points + ordered sequentially around the box. + + Returns: + Tensor of shape ``(..., 5)`` as ``[cx, cy, w, h, angle]``. + """ + cx = corners[..., :, 0].mean(dim=-1) + cy = corners[..., :, 1].mean(dim=-1) + + edge_w = corners[..., 1, :] - corners[..., 0, :] + edge_h = corners[..., 3, :] - corners[..., 0, :] + + w = torch.linalg.norm(edge_w, dim=-1) + h = torch.linalg.norm(edge_h, dim=-1) + + angle = torch.atan2(edge_w[..., 1], edge_w[..., 0]) + angle = normalize_angle(angle) + + return torch.stack([cx, cy, w, h, angle], dim=-1) + + +def _obb_to_gaussian( + boxes: torch.Tensor, +) -> tuple[torch.Tensor, torch.Tensor]: + """Convert oriented boxes to 2D Gaussian distributions. 
+ + Each box ``[cx, cy, w, h, angle]`` maps to a Gaussian with mean + ``(cx, cy)`` and covariance ``R @ diag(w^2/4, h^2/4) @ R^T``. + + Args: + boxes: Tensor of shape ``(..., 5)``. + + Returns: + Tuple of ``(mu, sigma)`` where ``mu`` has shape ``(..., 2)`` and + ``sigma`` has shape ``(..., 2, 2)``. + """ + cx, cy, w, h, angle = boxes.unbind(-1) + mu = torch.stack([cx, cy], dim=-1) + + cos_a = torch.cos(angle) + sin_a = torch.sin(angle) + + w = w.clamp(min=1e-6) + h = h.clamp(min=1e-6) + var_w = (w * w) / 4 + var_h = (h * h) / 4 + + a = var_w * cos_a * cos_a + var_h * sin_a * sin_a + b = (var_w - var_h) * cos_a * sin_a + d = var_w * sin_a * sin_a + var_h * cos_a * cos_a + + sigma = torch.stack([a, b, b, d], dim=-1).reshape(*boxes.shape[:-1], 2, 2) + + return mu, sigma + + +def gwd_loss(pred: torch.Tensor, target: torch.Tensor, tau: float = 1.0) -> torch.Tensor: + """Gaussian Wasserstein Distance between paired oriented boxes. + + Uses the closed-form 2nd Wasserstein distance between two 2D Gaussians + derived from the box parameters. + + Args: + pred: Predicted boxes ``(..., 5)`` as ``[cx, cy, w, h, angle]``. + target: Target boxes ``(..., 5)``, same shape as pred. + tau: Temperature parameter for loss normalization. + + Returns: + Per-box GWD loss, same leading shape as inputs. 
+ """ + mu_p, sigma_p = _obb_to_gaussian(pred) + mu_t, sigma_t = _obb_to_gaussian(target) + + diff_mu = mu_p - mu_t + term_center = (diff_mu * diff_mu).sum(dim=-1) + + trace_p = sigma_p[..., 0, 0] + sigma_p[..., 1, 1] + trace_t = sigma_t[..., 0, 0] + sigma_t[..., 1, 1] + + product = torch.bmm(sigma_p.reshape(-1, 2, 2), sigma_t.reshape(-1, 2, 2)).reshape(*sigma_p.shape) + trace_product = product[..., 0, 0] + product[..., 1, 1] + det_p = sigma_p[..., 0, 0] * sigma_p[..., 1, 1] - sigma_p[..., 0, 1] * sigma_p[..., 1, 0] + det_t = sigma_t[..., 0, 0] * sigma_t[..., 1, 1] - sigma_t[..., 0, 1] * sigma_t[..., 1, 0] + det_sqrt = (det_p.clamp(min=1e-8) * det_t.clamp(min=1e-8)).sqrt() + trace_sqrt = (trace_product + 2 * det_sqrt).clamp(min=1e-8).sqrt() + + w2 = (term_center + trace_p + trace_t - 2 * trace_sqrt).clamp(min=0) + + return 1 - 1 / (tau + torch.log1p(w2)) + + +def kld_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """KL Divergence loss between oriented boxes modeled as 2D Gaussians. + + Scale-invariant and aspect-ratio adaptive: elongated objects receive + stronger angular gradients. + + Args: + pred: Predicted boxes ``(..., 5)`` as ``[cx, cy, w, h, angle]``. + target: Target boxes ``(..., 5)``, same shape as pred. + + Returns: + Per-box KLD loss, same leading shape as inputs. 
+ """ + mu_p, sigma_p = _obb_to_gaussian(pred) + mu_t, sigma_t = _obb_to_gaussian(target) + + det_p = (sigma_p[..., 0, 0] * sigma_p[..., 1, 1] - sigma_p[..., 0, 1] ** 2).clamp(min=1e-8) + det_t = (sigma_t[..., 0, 0] * sigma_t[..., 1, 1] - sigma_t[..., 0, 1] ** 2).clamp(min=1e-8) + + inv_t00 = sigma_t[..., 1, 1] / det_t + inv_t01 = -sigma_t[..., 0, 1] / det_t + inv_t11 = sigma_t[..., 0, 0] / det_t + + trace_term = inv_t00 * sigma_p[..., 0, 0] + 2 * inv_t01 * sigma_p[..., 0, 1] + inv_t11 * sigma_p[..., 1, 1] + + diff = mu_p - mu_t + mahal_term = inv_t00 * diff[..., 0] ** 2 + 2 * inv_t01 * diff[..., 0] * diff[..., 1] + inv_t11 * diff[..., 1] ** 2 + + log_det_term = torch.log(det_t) - torch.log(det_p) + + kld = 0.5 * (trace_term + mahal_term + log_det_term - 2) + + return torch.log1p(kld) + + +def probiou(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: + """Probabilistic IoU via Bhattacharyya coefficient between Gaussian-encoded boxes. + + Returns a similarity score in ``[0, 1]`` where 1 means identical boxes. + + Args: + pred: Predicted boxes ``(..., 5)`` as ``[cx, cy, w, h, angle]``. + target: Target boxes ``(..., 5)``, same shape as pred. + + Returns: + Per-box ProbIoU similarity, same leading shape as inputs. 
+ """ + mu_p, sigma_p = _obb_to_gaussian(pred) + mu_t, sigma_t = _obb_to_gaussian(target) + + sigma_avg = (sigma_p + sigma_t) / 2 + det_avg = (sigma_avg[..., 0, 0] * sigma_avg[..., 1, 1] - sigma_avg[..., 0, 1] ** 2).clamp(min=1e-8) + det_p = (sigma_p[..., 0, 0] * sigma_p[..., 1, 1] - sigma_p[..., 0, 1] ** 2).clamp(min=1e-8) + det_t = (sigma_t[..., 0, 0] * sigma_t[..., 1, 1] - sigma_t[..., 0, 1] ** 2).clamp(min=1e-8) + + inv_avg00 = sigma_avg[..., 1, 1] / det_avg + inv_avg01 = -sigma_avg[..., 0, 1] / det_avg + inv_avg11 = sigma_avg[..., 0, 0] / det_avg + + diff = mu_p - mu_t + mahal = inv_avg00 * diff[..., 0] ** 2 + 2 * inv_avg01 * diff[..., 0] * diff[..., 1] + inv_avg11 * diff[..., 1] ** 2 + + log_coeff = 0.5 * (torch.log(det_avg) - 0.5 * (torch.log(det_p) + torch.log(det_t))) + bd = 0.125 * mahal + log_coeff + + hd_squared = (1 - torch.exp(-bd.clamp(max=50))).clamp(min=0) + return 1 - hd_squared + + +def gwd_pairwise(boxes1: torch.Tensor, boxes2: torch.Tensor, tau: float = 1.0) -> torch.Tensor: + """Pairwise GWD cost matrix for Hungarian matching. + + Args: + boxes1: Predicted boxes of shape ``(N, 5)``. + boxes2: Target boxes of shape ``(M, 5)``. + tau: Temperature parameter. + + Returns: + Cost matrix of shape ``(N, M)``. 
+ """ + mu_p, sigma_p = _obb_to_gaussian(boxes1) + mu_t, sigma_t = _obb_to_gaussian(boxes2) + + diff_mu = mu_p[:, None, :] - mu_t[None, :, :] + term_center = (diff_mu * diff_mu).sum(dim=-1) + + trace_p = sigma_p[..., 0, 0] + sigma_p[..., 1, 1] + trace_t = sigma_t[..., 0, 0] + sigma_t[..., 1, 1] + + n, m = boxes1.shape[0], boxes2.shape[0] + + sp_exp = sigma_p[:, None, :, :].expand(n, m, 2, 2).reshape(n * m, 2, 2) + st_exp = sigma_t[None, :, :, :].expand(n, m, 2, 2).reshape(n * m, 2, 2) + + product = torch.bmm(sp_exp, st_exp).reshape(n, m, 2, 2) + trace_product = product[..., 0, 0] + product[..., 1, 1] + + det_p = (sigma_p[..., 0, 0] * sigma_p[..., 1, 1] - sigma_p[..., 0, 1] ** 2).clamp(min=1e-8) + det_t = (sigma_t[..., 0, 0] * sigma_t[..., 1, 1] - sigma_t[..., 0, 1] ** 2).clamp(min=1e-8) + + det_sqrt = (det_p[:, None] * det_t[None, :]).sqrt() + trace_sqrt = (trace_product + 2 * det_sqrt).clamp(min=1e-8).sqrt() + + w2 = (term_center + trace_p[:, None] + trace_t[None, :] - 2 * trace_sqrt).clamp(min=0) + + return 1 - 1 / (tau + torch.log1p(w2)) diff --git a/tests/datasets/test_dota_detection.py b/tests/datasets/test_dota_detection.py new file mode 100644 index 000000000..c01037565 --- /dev/null +++ b/tests/datasets/test_dota_detection.py @@ -0,0 +1,198 @@ +# ------------------------------------------------------------------------ +# RF-DETR +# Copyright (c) 2025 Roboflow. All Rights Reserved. 
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

import math
import types
from pathlib import Path

import pytest
import torch
from PIL import Image

from rfdetr.datasets.dota_detection import (
    DOTA_V1_CLASSES,
    DotaDetection,
    DotaNormalize,
    build_dota,
    corners_list_to_tensor,
    make_dota_transforms,
    parse_dota_annotation,
)


@pytest.fixture()
def dota_root(tmp_path: Path) -> Path:
    """Create a minimal DOTA directory with one image and annotation."""
    images_dir = tmp_path / "images"
    labels_dir = tmp_path / "labelTxt"
    images_dir.mkdir()
    labels_dir.mkdir()

    Image.new("RGB", (100, 100), color="red").save(images_dir / "P0001.png")

    # Two easy objects (difficulty 0) and one difficult object (difficulty 1).
    ann_text = "10 10 50 10 50 40 10 40 plane 0\n60 60 90 60 90 90 60 90 ship 0\n20 20 30 20 30 30 20 30 plane 1\n"
    (labels_dir / "P0001.txt").write_text(ann_text)

    return tmp_path


class TestParseDotaAnnotation:
    """Parsing of DOTA labelTxt annotation files."""

    def test_parses_valid_lines(self, dota_root: Path) -> None:
        ann_path = dota_root / "labelTxt" / "P0001.txt"
        annotations = parse_dota_annotation(ann_path)
        assert len(annotations) == 3

    def test_annotation_fields(self, dota_root: Path) -> None:
        ann_path = dota_root / "labelTxt" / "P0001.txt"
        ann = parse_dota_annotation(ann_path)[0]
        assert ann["category"] == "plane"
        assert ann["difficulty"] == 0
        assert len(ann["corners"]) == 8

    def test_skips_short_lines(self, tmp_path: Path) -> None:
        ann_path = tmp_path / "test.txt"
        ann_path.write_text("10 20 30\n10 10 50 10 50 40 10 40 plane 0\n")
        annotations = parse_dota_annotation(ann_path)
        assert len(annotations) == 1

    def test_empty_file(self, tmp_path: Path) -> None:
        ann_path = tmp_path / "empty.txt"
        ann_path.write_text("")
        assert parse_dota_annotation(ann_path) == []

    def test_difficulty_defaults_to_zero(self, tmp_path: Path) -> None:
        ann_path = tmp_path / "no_diff.txt"
        ann_path.write_text("10 10 50 10 50 40 10 40 plane\n")
        ann = parse_dota_annotation(ann_path)[0]
        assert ann["difficulty"] == 0


class TestCornersListToTensor:
    """Conversion of a flat 8-value corner list to a (4, 2) tensor."""

    def test_shape(self) -> None:
        result = corners_list_to_tensor([0, 0, 10, 0, 10, 5, 0, 5])
        assert result.shape == (4, 2)

    def test_values(self) -> None:
        result = corners_list_to_tensor([1, 2, 3, 4, 5, 6, 7, 8])
        expected = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]], dtype=torch.float32)
        assert torch.equal(result, expected)


class TestDotaDetection:
    """Dataset construction, filtering, and target layout."""

    def test_len(self, dota_root: Path) -> None:
        dataset = DotaDetection(root=dota_root)
        assert len(dataset) == 1

    def test_getitem_returns_image_and_target(self, dota_root: Path) -> None:
        dataset = DotaDetection(root=dota_root)
        image, target = dataset[0]
        assert isinstance(image, Image.Image)
        assert "boxes_obb" in target
        assert "labels" in target
        assert "corners" in target

    def test_filters_difficult_by_default(self, dota_root: Path) -> None:
        dataset = DotaDetection(root=dota_root)
        _, target = dataset[0]
        assert target["labels"].shape[0] == 2

    def test_includes_difficult_when_flag_set(self, dota_root: Path) -> None:
        dataset = DotaDetection(root=dota_root, include_difficult=True)
        _, target = dataset[0]
        assert target["labels"].shape[0] == 3

    def test_boxes_obb_shape(self, dota_root: Path) -> None:
        dataset = DotaDetection(root=dota_root)
        _, target = dataset[0]
        assert target["boxes_obb"].shape == (2, 5)

    def test_corners_shape(self, dota_root: Path) -> None:
        dataset = DotaDetection(root=dota_root)
        _, target = dataset[0]
        assert target["corners"].shape == (2, 4, 2)

    def test_labels_are_valid_indices(self, dota_root: Path) -> None:
        dataset = DotaDetection(root=dota_root)
        _, target = dataset[0]
        assert (target["labels"] >= 0).all()
        assert (target["labels"] < len(DOTA_V1_CLASSES)).all()

    def test_skips_unknown_categories(self, dota_root: Path) -> None:
        (dota_root / "labelTxt" / "P0001.txt").write_text("10 10 50 10 50 40 10 40 unknown_category 0\n")
        dataset = DotaDetection(root=dota_root)
        _, target = dataset[0]
        assert target["labels"].shape[0] == 0

    def test_missing_images_dir_raises(self, tmp_path: Path) -> None:
        (tmp_path / "labelTxt").mkdir()
        with pytest.raises(FileNotFoundError):
            DotaDetection(root=tmp_path)

    def test_missing_annotation_file_returns_empty(self, dota_root: Path) -> None:
        (dota_root / "labelTxt" / "P0001.txt").unlink()
        dataset = DotaDetection(root=dota_root)
        _, target = dataset[0]
        assert target["labels"].shape[0] == 0

    def test_axis_aligned_box_angle_near_zero(self, dota_root: Path) -> None:
        (dota_root / "labelTxt" / "P0001.txt").write_text("0 0 10 0 10 5 0 5 plane 0\n")
        dataset = DotaDetection(root=dota_root)
        _, target = dataset[0]
        angle = target["boxes_obb"][0, 4].item()
        # Angle is pi-periodic, so "axis aligned" may land near 0 or near pi.
        assert abs(angle) < 0.01 or abs(angle - math.pi) < 0.01


class TestDotaNormalize:
    """Normalization of corner coordinates into relative OBB parameters."""

    def test_normalizes_boxes(self) -> None:
        normalize = DotaNormalize()
        image = torch.rand(3, 100, 200)
        corners = torch.tensor([[[10, 10], [50, 10], [50, 40], [10, 40]]], dtype=torch.float32)
        target = {"corners": corners, "boxes_obb": torch.zeros(1, 5)}

        _, target_out = normalize(image, target)
        obb = target_out["boxes_obb"]
        assert obb[0, 0].item() < 1.0
        assert obb[0, 1].item() < 1.0

    def test_none_target_passthrough(self) -> None:
        normalize = DotaNormalize()
        image = torch.rand(3, 100, 100)
        _, target_out = normalize(image, None)
        assert target_out is None

    def test_empty_corners(self) -> None:
        normalize = DotaNormalize()
        image = torch.rand(3, 100, 100)
        target = {"corners": torch.zeros(0, 4, 2), "boxes_obb": torch.zeros(0, 5)}
        _, target_out = normalize(image, target)
        assert target_out["boxes_obb"].shape == (0, 5)


class TestMakeDotaTransforms:
    """Transform pipelines exist for both splits."""

    def test_train_returns_compose(self) -> None:
        transforms = make_dota_transforms("train", 512)
        assert transforms is not None

    def test_val_returns_compose(self) -> None:
        transforms = make_dota_transforms("val", 512)
        assert transforms is not None


class TestBuildDota:
    """End-to-end builder wiring from an args namespace."""

    def test_builds_dataset(self, dota_root: Path) -> None:
        # build_dota expects <dataset_dir>/<split>/{images,labelTxt}.
        split_root = dota_root.parent / "train"
        (split_root / "images").mkdir(parents=True, exist_ok=True)
        (split_root / "labelTxt").mkdir(exist_ok=True)
        Image.new("RGB", (50, 50), color="blue").save(split_root / "images" / "test.png")
        (split_root / "labelTxt" / "test.txt").write_text("")

        args = types.SimpleNamespace(dataset_dir=str(dota_root.parent))
        dataset = build_dota("train", args, 256)
        assert isinstance(dataset, DotaDetection)
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

from unittest.mock import MagicMock

import torch

from rfdetr.models.lwdetr import LWDETR


def _make_exportable_oriented_lwdetr() -> LWDETR:
    """Build a minimal oriented LWDETR that can run forward_export."""
    hidden_dim, num_queries, num_classes = 8, 4, 3

    backbone = MagicMock()
    backbone.return_value = (
        [torch.randn(1, hidden_dim, 4, 4)],  # src feature maps
        None,
        [torch.randn(1, hidden_dim, 4, 4)],  # positional encodings
    )

    transformer = MagicMock()
    transformer.d_model = hidden_dim
    transformer.return_value = (
        torch.randn(1, 1, num_queries, hidden_dim),  # hs
        torch.randn(1, 1, num_queries, 4),           # reference points
        None,
        None,
    )
    transformer.decoder = MagicMock()
    transformer.decoder.bbox_embed = None

    return LWDETR(
        backbone=backbone,
        transformer=transformer,
        segmentation_head=None,
        num_classes=num_classes,
        num_queries=num_queries,
        group_detr=1,
        oriented=True,
        bbox_reparam=False,
    )


def _export_outputs() -> tuple[torch.Tensor, torch.Tensor]:
    """Run forward_export on a fresh oriented model and return (dets, labels)."""
    model = _make_exportable_oriented_lwdetr()
    model.eval()
    model._export = True
    return model.forward_export(torch.randn(1, 3, 32, 32))


class TestOrientedExportForward:
    def test_forward_export_output_has_angle(self) -> None:
        dets, labels = _export_outputs()
        assert dets.shape[-1] == 5
        assert labels.shape[-1] == 3

    def test_forward_export_angle_range(self) -> None:
        dets, _ = _export_outputs()
        angles = dets[..., 4]
        assert (angles >= 0).all()
        assert (angles < 3.2).all()
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

import math
from unittest.mock import MagicMock

import torch

from rfdetr.models.heads.detection import DetectionHead
from rfdetr.models.lwdetr import LWDETR


def _make_oriented_lwdetr(num_classes: int = 91, oriented: bool = True) -> LWDETR:
    """Construct a minimal LWDETR (oriented by default) for testing.

    Backbone and transformer are mocks; only constructor wiring is exercised.
    The ``oriented`` parameter lets both the oriented and standard variants
    share this builder instead of duplicating the mock setup.
    """
    hidden_dim = 8
    backbone = MagicMock()
    transformer = MagicMock()
    transformer.d_model = hidden_dim
    transformer.decoder = MagicMock()
    transformer.decoder.bbox_embed = None
    return LWDETR(
        backbone=backbone,
        transformer=transformer,
        segmentation_head=None,
        num_classes=num_classes,
        num_queries=4,
        group_detr=1,
        oriented=oriented,
    )


class TestDetectionHeadOriented:
    def test_standard_head_output_shape(self) -> None:
        head = DetectionHead(hidden_dim=16, num_classes=10)
        hs = torch.randn(2, 5, 16)
        cls_out, coord_out = head(hs)
        assert cls_out.shape == (2, 5, 10)
        assert coord_out.shape == (2, 5, 4)

    def test_oriented_head_output_shape(self) -> None:
        head = DetectionHead(hidden_dim=16, num_classes=10, oriented=True)
        hs = torch.randn(2, 5, 16)
        cls_out, coord_out = head(hs)
        assert cls_out.shape == (2, 5, 10)
        # Oriented coords carry a fifth channel for the angle.
        assert coord_out.shape == (2, 5, 5)

    def test_oriented_angle_range(self) -> None:
        head = DetectionHead(hidden_dim=16, num_classes=10, oriented=True)
        hs = torch.randn(2, 5, 16)
        _, coord_out = head(hs)
        angles = coord_out[..., 4]
        assert (angles >= 0).all()
        assert (angles < math.pi + 0.01).all()

    def test_oriented_has_angle_embed(self) -> None:
        head = DetectionHead(hidden_dim=16, num_classes=10, oriented=True)
        assert head.angle_embed is not None

    def test_standard_has_no_angle_embed(self) -> None:
        head = DetectionHead(hidden_dim=16, num_classes=10)
        assert head.angle_embed is None

    def test_oriented_gradients_flow(self) -> None:
        head = DetectionHead(hidden_dim=16, num_classes=10, oriented=True)
        hs = torch.randn(2, 5, 16, requires_grad=True)
        _, coord_out = head(hs)
        loss = coord_out.sum()
        loss.backward()
        assert hs.grad is not None
        assert torch.isfinite(hs.grad).all()


class TestLWDETROriented:
    def test_oriented_model_has_angle_embed(self) -> None:
        model = _make_oriented_lwdetr()
        assert model.angle_embed is not None
        assert model.oriented is True

    def test_non_oriented_model_has_no_angle_embed(self) -> None:
        model = _make_oriented_lwdetr(oriented=False)
        assert model.angle_embed is None
        assert model.oriented is False
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

import math

import torch

from rfdetr.models.criterion import SetCriterion
from rfdetr.models.matcher import HungarianMatcher

# Loss weights shared by every criterion built in this module.
_WEIGHTS = {"loss_ce": 1.0, "loss_bbox": 5.0, "loss_giou": 2.0}


def _make_oriented_matcher() -> HungarianMatcher:
    """Matcher configured for oriented (rotated) boxes."""
    return HungarianMatcher(cost_class=1, cost_bbox=5, cost_giou=2, oriented=True)


def _make_oriented_criterion() -> SetCriterion:
    """Criterion wired to an oriented matcher with the standard loss weights."""
    return SetCriterion(
        num_classes=16,
        matcher=_make_oriented_matcher(),
        weight_dict=dict(_WEIGHTS),
        focal_alpha=0.25,
        losses=["labels", "boxes"],
        ia_bce_loss=False,
    )


def _two_obb_targets() -> list:
    """Single-image target list with two oriented ground-truth boxes."""
    boxes = torch.tensor(
        [
            [0.5, 0.5, 0.2, 0.1, 0.3],
            [0.3, 0.7, 0.15, 0.1, 1.0],
        ]
    )
    return [{"labels": torch.tensor([0, 3]), "boxes_obb": boxes}]


class TestOrientedMatcher:
    def test_returns_valid_indices(self) -> None:
        outputs = {
            "pred_logits": torch.randn(1, 10, 16),
            "pred_boxes": torch.rand(1, 10, 5),
        }
        indices = _make_oriented_matcher()(outputs, _two_obb_targets())
        assert len(indices) == 1
        src_idx, tgt_idx = indices[0]
        assert len(src_idx) == 2
        assert len(tgt_idx) == 2

    def test_oriented_flag_stored(self) -> None:
        assert _make_oriented_matcher().oriented is True

    def test_non_oriented_default(self) -> None:
        assert HungarianMatcher().oriented is False


class TestOrientedCriterion:
    def test_loss_boxes_returns_kld(self) -> None:
        criterion = _make_oriented_criterion()
        pred_boxes = torch.rand(1, 10, 5) * 0.5 + 0.1
        # Scale the fifth channel into a valid [0, pi) angle range.
        pred_boxes[..., 4] = pred_boxes[..., 4] * math.pi
        outputs = {"pred_logits": torch.randn(1, 10, 16), "pred_boxes": pred_boxes}
        indices = [(torch.tensor([0, 1]), torch.tensor([0, 1]))]

        losses = criterion.loss_boxes(outputs, _two_obb_targets(), indices, num_boxes=2)
        assert "loss_bbox" in losses
        assert "loss_giou" in losses
        assert losses["loss_bbox"].item() >= 0
        assert losses["loss_giou"].item() >= 0

    def test_oriented_flag_propagated(self) -> None:
        assert _make_oriented_criterion().oriented is True

    def test_non_oriented_criterion(self) -> None:
        criterion = SetCriterion(
            num_classes=91,
            matcher=HungarianMatcher(),
            weight_dict=dict(_WEIGHTS),
            focal_alpha=0.25,
            losses=["labels", "boxes"],
        )
        assert criterion.oriented is False
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

import torch

from rfdetr.models.postprocess import PostProcess


def _run_postprocess(oriented: bool, pred_boxes: torch.Tensor, target_sizes: torch.Tensor) -> list:
    """Run PostProcess(num_select=5) over random logits and the given boxes."""
    pp = PostProcess(num_select=5, oriented=oriented)
    batch = pred_boxes.shape[0]
    outputs = {
        "pred_logits": torch.randn(batch, 10, 3),
        "pred_boxes": pred_boxes,
    }
    return pp(outputs, target_sizes)


class TestPostProcessOriented:
    def test_oriented_output_keys(self) -> None:
        results = _run_postprocess(True, torch.rand(1, 10, 5), torch.tensor([[480, 640]]))
        assert len(results) == 1
        assert "boxes_obb" in results[0]
        assert "corners" in results[0]
        assert "scores" in results[0]
        assert "labels" in results[0]

    def test_oriented_obb_shape(self) -> None:
        results = _run_postprocess(True, torch.rand(1, 10, 5), torch.tensor([[480, 640]]))
        assert results[0]["boxes_obb"].shape == (5, 5)
        assert results[0]["corners"].shape == (5, 4, 2)

    def test_oriented_scales_spatial_dims(self) -> None:
        # Boxes at relative (0.5, 0.5) must scale to half the target extent.
        results = _run_postprocess(True, torch.full((1, 10, 5), 0.5), torch.tensor([[100, 200]]))
        obb = results[0]["boxes_obb"]
        assert torch.allclose(obb[0, 0], torch.tensor(100.0), atol=1.0)
        assert torch.allclose(obb[0, 1], torch.tensor(50.0), atol=1.0)

    def test_standard_postprocess_unchanged(self) -> None:
        results = _run_postprocess(False, torch.rand(1, 10, 4), torch.tensor([[480, 640]]))
        assert "boxes" in results[0]
        assert "boxes_obb" not in results[0]

    def test_batch_support(self) -> None:
        sizes = torch.tensor([[480, 640], [320, 320], [600, 800]])
        results = _run_postprocess(True, torch.rand(3, 10, 5), sizes)
        assert len(results) == 3


# --- file boundary in the original patch: tests/training/test_obb_integration.py ---
# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

from rfdetr._namespace import _namespace_from_configs
from rfdetr.config import RFDETRBaseConfig, TrainConfig


class TestOrientedNamespace:
    def test_oriented_false_by_default(self) -> None:
        model_cfg = RFDETRBaseConfig(pretrain_weights=None)
        train_cfg = TrainConfig(dataset_dir="/tmp/fake")
        namespace = _namespace_from_configs(model_cfg, train_cfg)
        assert namespace.oriented is False

    def test_oriented_forwarded_to_namespace(self) -> None:
        model_cfg = RFDETRBaseConfig(pretrain_weights=None, oriented=True)
        train_cfg = TrainConfig(dataset_dir="/tmp/fake")
        namespace = _namespace_from_configs(model_cfg, train_cfg)
        assert namespace.oriented is True


class TestOrientedConfig:
    def test_oriented_default_false(self) -> None:
        model_cfg = RFDETRBaseConfig(pretrain_weights=None)
        assert model_cfg.oriented is False

    def test_oriented_can_be_set(self) -> None:
        model_cfg = RFDETRBaseConfig(pretrain_weights=None, oriented=True)
        assert model_cfg.oriented is True

    def test_dota_dataset_file_accepted(self) -> None:
        train_cfg = TrainConfig(dataset_dir="/tmp/fake", dataset_file="dota")
        assert train_cfg.dataset_file == "dota"
# ------------------------------------------------------------------------
# RF-DETR
# Copyright (c) 2025 Roboflow. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------

import math

import torch

from rfdetr.utilities.rotated_box_ops import (
    box_cxcywha_to_corners,
    corners_to_cxcywha,
    gwd_loss,
    gwd_pairwise,
    kld_loss,
    normalize_angle,
    probiou,
)


def _assert_roundtrip(boxes: torch.Tensor) -> None:
    """Corners encode/decode must reproduce the original cxcywha boxes."""
    corners = box_cxcywha_to_corners(boxes)
    assert torch.allclose(corners_to_cxcywha(corners), boxes, atol=1e-4)


class TestNormalizeAngle:
    """normalize_angle wraps any angle into [0, pi)."""

    def test_already_in_range(self) -> None:
        in_range = torch.tensor([0.0, 0.5, 1.0, math.pi - 0.01])
        assert torch.allclose(normalize_angle(in_range), in_range, atol=1e-6)

    def test_negative_angles(self) -> None:
        wrapped = normalize_angle(torch.tensor([-0.5])).item()
        assert 0 <= wrapped < math.pi

    def test_angles_beyond_pi(self) -> None:
        wrapped = normalize_angle(torch.tensor([math.pi + 0.5]))
        assert torch.allclose(wrapped, torch.tensor([0.5]), atol=1e-6)

    def test_pi_periodicity(self) -> None:
        base = torch.tensor([0.3])
        shifted = base + math.pi
        assert torch.allclose(normalize_angle(base), normalize_angle(shifted), atol=1e-6)

    def test_two_pi(self) -> None:
        wrapped = normalize_angle(torch.tensor([2 * math.pi]))
        assert torch.allclose(wrapped, torch.tensor([0.0]), atol=1e-6)


class TestBoxCxcywhaToCorners:
    """cxcywha -> 4-corner conversion."""

    def test_axis_aligned_box(self) -> None:
        corners = box_cxcywha_to_corners(torch.tensor([[10.0, 20.0, 6.0, 4.0, 0.0]]))
        assert corners.shape == (1, 4, 2)
        expected = torch.tensor([[[7.0, 18.0], [13.0, 18.0], [13.0, 22.0], [7.0, 22.0]]])
        assert torch.allclose(corners, expected, atol=1e-5)

    def test_90_degree_rotation(self) -> None:
        # A 4x2 box rotated a quarter turn spans 2 in x and 4 in y.
        corners = box_cxcywha_to_corners(torch.tensor([[0.0, 0.0, 4.0, 2.0, math.pi / 2]]))
        assert corners.shape == (1, 4, 2)
        xs, ys = corners[0, :, 0], corners[0, :, 1]
        assert torch.allclose(xs.min(), torch.tensor(-1.0), atol=1e-5)
        assert torch.allclose(xs.max(), torch.tensor(1.0), atol=1e-5)
        assert torch.allclose(ys.min(), torch.tensor(-2.0), atol=1e-5)
        assert torch.allclose(ys.max(), torch.tensor(2.0), atol=1e-5)

    def test_batch_shape(self) -> None:
        corners = box_cxcywha_to_corners(torch.rand(5, 3, 5))
        assert corners.shape == (5, 3, 4, 2)

    def test_center_is_mean_of_corners(self) -> None:
        corners = box_cxcywha_to_corners(torch.tensor([[5.0, 10.0, 8.0, 3.0, 0.7]]))
        centroid = corners[0].mean(dim=0)
        assert torch.allclose(centroid, torch.tensor([5.0, 10.0]), atol=1e-5)


class TestCornersToBoxCxcywha:
    """Corner -> cxcywha inversion must round-trip."""

    def test_roundtrip(self) -> None:
        _assert_roundtrip(torch.tensor([[10.0, 20.0, 6.0, 4.0, 0.5]]))

    def test_roundtrip_batch(self) -> None:
        _assert_roundtrip(
            torch.tensor(
                [
                    [5.0, 5.0, 10.0, 4.0, 0.0],
                    [15.0, 25.0, 8.0, 6.0, 1.0],
                    [0.0, 0.0, 3.0, 7.0, 2.5],
                ]
            )
        )

    def test_roundtrip_axis_aligned(self) -> None:
        _assert_roundtrip(torch.tensor([[0.0, 0.0, 4.0, 2.0, 0.0]]))


class TestGwdLoss:
    """Gaussian Wasserstein distance loss."""

    def test_identical_boxes_near_zero(self) -> None:
        same = torch.tensor([[10.0, 20.0, 6.0, 4.0, 0.5]])
        loss = gwd_loss(same, same)
        assert loss.shape == (1,)
        assert loss.item() < 0.01

    def test_different_boxes_positive(self) -> None:
        a = torch.tensor([[10.0, 20.0, 6.0, 4.0, 0.5]])
        b = torch.tensor([[15.0, 25.0, 8.0, 6.0, 1.0]])
        assert gwd_loss(a, b).item() > 0

    def test_angle_boundary_symmetry(self) -> None:
        # Angles just inside 0 and just inside pi describe nearly the same box.
        near_zero = torch.tensor([[10.0, 10.0, 6.0, 4.0, 0.01]])
        near_pi = torch.tensor([[10.0, 10.0, 6.0, 4.0, math.pi - 0.01]])
        assert gwd_loss(near_zero, near_pi).item() < 0.05

    def test_batch(self) -> None:
        a = torch.rand(10, 5) * 10
        a[..., 4] = a[..., 4] % math.pi
        b = torch.rand(10, 5) * 10
        b[..., 4] = b[..., 4] % math.pi
        assert gwd_loss(a, b).shape == (10,)

    def test_gradients_flow(self) -> None:
        a = torch.tensor([[10.0, 20.0, 6.0, 4.0, 0.5]], requires_grad=True)
        b = torch.tensor([[15.0, 25.0, 8.0, 6.0, 1.0]])
        gwd_loss(a, b).sum().backward()
        assert a.grad is not None
        assert torch.isfinite(a.grad).all()


class TestKldLoss:
    """Kullback-Leibler divergence loss."""

    def test_identical_boxes_zero(self) -> None:
        same = torch.tensor([[10.0, 20.0, 6.0, 4.0, 0.5]])
        loss = kld_loss(same, same)
        assert loss.shape == (1,)
        assert torch.allclose(loss, torch.tensor([0.0]), atol=1e-5)

    def test_different_boxes_positive(self) -> None:
        a = torch.tensor([[10.0, 20.0, 6.0, 4.0, 0.5]])
        b = torch.tensor([[15.0, 25.0, 8.0, 6.0, 1.0]])
        assert kld_loss(a, b).item() > 0

    def test_angle_sensitivity_scales_with_aspect_ratio(self) -> None:
        # Rotating an elongated box must cost more than rotating a square one.
        thin = torch.tensor([[10.0, 10.0, 20.0, 2.0, 0.0]])
        thin_rotated = torch.tensor([[10.0, 10.0, 20.0, 2.0, 0.3]])
        square = torch.tensor([[10.0, 10.0, 5.0, 5.0, 0.0]])
        square_rotated = torch.tensor([[10.0, 10.0, 5.0, 5.0, 0.3]])
        assert kld_loss(thin, thin_rotated).item() > kld_loss(square, square_rotated).item()

    def test_gradients_flow(self) -> None:
        a = torch.tensor([[10.0, 20.0, 6.0, 4.0, 0.5]], requires_grad=True)
        b = torch.tensor([[15.0, 25.0, 8.0, 6.0, 1.0]])
        kld_loss(a, b).sum().backward()
        assert a.grad is not None
        assert torch.isfinite(a.grad).all()


class TestProbiou:
    """Probabilistic IoU similarity."""

    def test_identical_boxes_one(self) -> None:
        same = torch.tensor([[10.0, 20.0, 6.0, 4.0, 0.5]])
        score = probiou(same, same)
        assert score.shape == (1,)
        assert torch.allclose(score, torch.tensor([1.0]), atol=1e-4)

    def test_far_apart_boxes_near_zero(self) -> None:
        a = torch.tensor([[0.0, 0.0, 2.0, 2.0, 0.0]])
        b = torch.tensor([[1000.0, 1000.0, 2.0, 2.0, 0.0]])
        assert probiou(a, b).item() < 0.01

    def test_range_zero_to_one(self) -> None:
        a = torch.rand(20, 5) * 10 + 1
        a[..., 4] = a[..., 4] % math.pi
        b = torch.rand(20, 5) * 10 + 1
        b[..., 4] = b[..., 4] % math.pi
        scores = probiou(a, b)
        assert (scores >= -0.01).all()
        assert (scores <= 1.01).all()

    def test_batch(self) -> None:
        a = torch.rand(8, 5) * 10 + 1
        b = torch.rand(8, 5) * 10 + 1
        assert probiou(a, b).shape == (8,)


class TestGwdPairwise:
    """Pairwise GWD cost matrix."""

    def test_output_shape(self) -> None:
        a = torch.rand(5, 5) * 10 + 1
        b = torch.rand(3, 5) * 10 + 1
        a[..., 4] = a[..., 4] % math.pi
        b[..., 4] = b[..., 4] % math.pi
        assert gwd_pairwise(a, b).shape == (5, 3)

    def test_diagonal_matches_paired(self) -> None:
        boxes = torch.tensor(
            [
                [10.0, 20.0, 6.0, 4.0, 0.5],
                [5.0, 5.0, 3.0, 7.0, 1.2],
            ]
        )
        cost_matrix = gwd_pairwise(boxes, boxes)
        assert torch.allclose(torch.diag(cost_matrix), gwd_loss(boxes, boxes), atol=1e-5)

    def test_self_cost_near_zero_on_diagonal(self) -> None:
        boxes = torch.tensor(
            [
                [10.0, 20.0, 6.0, 4.0, 0.5],
                [5.0, 5.0, 3.0, 7.0, 1.2],
            ]
        )
        cost_matrix = gwd_pairwise(boxes, boxes)
        assert torch.allclose(torch.diag(cost_matrix), torch.zeros(2), atol=0.01)


class TestEdgeCases:
    """Degenerate and extreme inputs must stay finite."""

    def test_zero_size_box_gwd_no_crash(self) -> None:
        degenerate = torch.tensor([[5.0, 5.0, 0.0, 0.0, 0.5]])
        reference = torch.tensor([[5.0, 5.0, 3.0, 2.0, 0.5]])
        assert torch.isfinite(gwd_loss(degenerate, reference)).all()

    def test_zero_size_box_kld_no_crash(self) -> None:
        degenerate = torch.tensor([[5.0, 5.0, 0.0, 0.0, 0.5]])
        reference = torch.tensor([[5.0, 5.0, 3.0, 2.0, 0.5]])
        assert torch.isfinite(kld_loss(degenerate, reference)).all()

    def test_zero_size_box_probiou_no_crash(self) -> None:
        degenerate = torch.tensor([[5.0, 5.0, 0.0, 0.0, 0.5]])
        reference = torch.tensor([[5.0, 5.0, 3.0, 2.0, 0.5]])
        assert torch.isfinite(probiou(degenerate, reference)).all()

    def test_very_large_boxes(self) -> None:
        big = torch.tensor([[500.0, 500.0, 1000.0, 800.0, 0.5]])
        assert gwd_loss(big, big).item() < 0.01
        assert kld_loss(big, big).item() < 0.01
        assert probiou(big, big).item() > 0.99

    def test_single_element_tensors(self) -> None:
        a = torch.tensor([5.0, 5.0, 3.0, 2.0, 0.5]).unsqueeze(0)
        b = torch.tensor([5.0, 5.0, 3.0, 2.0, 0.5]).unsqueeze(0)
        assert gwd_loss(a, b).shape == (1,)
        assert kld_loss(a, b).shape == (1,)
        assert probiou(a, b).shape == (1,)